diff --git a/src/MaxKB-1.7.2/.dockerignore b/src/MaxKB-1.7.2/.dockerignore new file mode 100644 index 0000000..305b3d8 --- /dev/null +++ b/src/MaxKB-1.7.2/.dockerignore @@ -0,0 +1,2 @@ +.git* +.idea* diff --git a/src/MaxKB-1.7.2/.github/ISSUE_TEMPLATE/bug.yml b/src/MaxKB-1.7.2/.github/ISSUE_TEMPLATE/bug.yml new file mode 100644 index 0000000..77d9aec --- /dev/null +++ b/src/MaxKB-1.7.2/.github/ISSUE_TEMPLATE/bug.yml @@ -0,0 +1,61 @@ +name: BUG 提交 +description: 提交产品缺陷帮助我们更好的改进 +title: "[BUG]" +labels: "类型: 缺陷" +assignees: zyyfit +body: + - type: markdown + id: contacts_title + attributes: + value: "## 联系方式" + - type: input + id: contacts + validations: + required: false + attributes: + label: "联系方式" + description: "可以快速联系到您的方式:交流群号及昵称、邮箱等" + - type: markdown + id: environment + attributes: + value: "## 环境信息" + - type: input + id: version + validations: + required: true + attributes: + label: "MaxKB 版本" + description: "登录 MaxKB Web 控制台,在右上角关于页面查看当前版本。" + - type: markdown + id: details + attributes: + value: "## 详细信息" + - type: textarea + id: what-happened + attributes: + label: "问题描述" + description: "简要描述您碰到的问题" + validations: + required: true + - type: textarea + id: how-happened + attributes: + label: "重现步骤" + description: "如果操作可以重现该问题" + validations: + required: true + - type: textarea + id: expect + attributes: + label: "期待的正确结果" + - type: textarea + id: logs + attributes: + label: "相关日志输出" + description: "请复制并粘贴任何相关的日志输出。 这将自动格式化为代码,因此无需反引号。" + render: shell + - type: textarea + id: additional-information + attributes: + label: "附加信息" + description: "如果你还有其他需要提供的信息,可以在这里填写(可以提供截图、视频等)。" diff --git a/src/MaxKB-1.7.2/.github/ISSUE_TEMPLATE/config.yml b/src/MaxKB-1.7.2/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000..cd5a52f --- /dev/null +++ b/src/MaxKB-1.7.2/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: false +contact_links: + - name: 对 MaxKB 项目有其他问题 + url: https://bbs.fit2cloud.com/c/mk/11 + about: 如果你对 MaxKB 
有其他想要提问的,我们欢迎到我们的官方社区进行提问。 \ No newline at end of file diff --git a/src/MaxKB-1.7.2/.github/ISSUE_TEMPLATE/feature.yml b/src/MaxKB-1.7.2/.github/ISSUE_TEMPLATE/feature.yml new file mode 100644 index 0000000..3c015c4 --- /dev/null +++ b/src/MaxKB-1.7.2/.github/ISSUE_TEMPLATE/feature.yml @@ -0,0 +1,36 @@ +name: 需求建议 +description: 提出针对本项目的想法和建议 +title: "[FEATURE]" +labels: enhancement +assignees: baixin513 +body: + - type: markdown + id: environment + attributes: + value: "## 环境信息" + - type: input + id: version + validations: + required: true + attributes: + label: "MaxKB 版本" + description: "登录 MaxKB Web 控制台,在右上角关于页面查看当前版本。" + - type: markdown + id: details + attributes: + value: "## 详细信息" + - type: textarea + id: description + attributes: + label: "请描述您的需求或者改进建议" + validations: + required: true + - type: textarea + id: solution + attributes: + label: "请描述你建议的实现方案" + - type: textarea + id: additional-information + attributes: + label: "附加信息" + description: "如果你还有其他需要提供的信息,可以在这里填写(可以提供截图、视频等)。" \ No newline at end of file diff --git a/src/MaxKB-1.7.2/.github/PULL_REQUEST_TEMPLATE.md b/src/MaxKB-1.7.2/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..1106f42 --- /dev/null +++ b/src/MaxKB-1.7.2/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,9 @@ +#### What this PR does / why we need it? + +#### Summary of your change + +#### Please indicate you've done the following: + +- [ ] Made sure tests are passing and test coverage is added if needed. +- [ ] Made sure commit message follow the rule of [Conventional Commits specification](https://www.conventionalcommits.org/). +- [ ] Considered the docs impact and opened a new docs issue or PR with docs changes if needed. 
\ No newline at end of file diff --git a/src/MaxKB-1.7.2/.github/workflows/build-and-push-python-pg.yml b/src/MaxKB-1.7.2/.github/workflows/build-and-push-python-pg.yml new file mode 100644 index 0000000..6b79916 --- /dev/null +++ b/src/MaxKB-1.7.2/.github/workflows/build-and-push-python-pg.yml @@ -0,0 +1,63 @@ +name: build-and-push-python-pg + +on: + workflow_dispatch: + inputs: + architecture: + description: 'Architecture' + required: true + default: 'linux/amd64' + type: choice + options: + - linux/amd64 + - linux/arm64 + - linux/amd64,linux/arm64 +jobs: + build-and-push-python-pg-to-ghcr: + runs-on: ubuntu-latest + steps: + - name: Check Disk Space + run: df -h + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@main + with: + tool-cache: true + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: true + swap-storage: true + - name: Check Disk Space + run: df -h + - name: Checkout + uses: actions/checkout@v4 + with: + ref: main + - name: Prepare + id: prepare + run: | + DOCKER_IMAGE=ghcr.io/1panel-dev/maxkb-python-pg + DOCKER_PLATFORMS=${{ github.event.inputs.architecture }} + TAG_NAME=python3.11-pg15.8 + DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:latest" + echo ::set-output name=docker_image::${DOCKER_IMAGE} + echo ::set-output name=version::${TAG_NAME} + echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} --no-cache \ + --build-arg VERSION=${TAG_NAME} \ + --build-arg BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ') \ + --build-arg VCS_REF=${GITHUB_SHA::8} \ + ${DOCKER_IMAGE_TAGS} . 
+ - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GH_TOKEN }} + - name: Docker Buildx (build-and-push) + run: | + docker buildx build --output "type=image,push=true" ${{ steps.prepare.outputs.buildx_args }} -f installer/Dockerfile-python-pg \ No newline at end of file diff --git a/src/MaxKB-1.7.2/.github/workflows/build-and-push-vector-model.yml b/src/MaxKB-1.7.2/.github/workflows/build-and-push-vector-model.yml new file mode 100644 index 0000000..a054bc7 --- /dev/null +++ b/src/MaxKB-1.7.2/.github/workflows/build-and-push-vector-model.yml @@ -0,0 +1,68 @@ +name: build-and-push-vector-model + +on: + workflow_dispatch: + inputs: + dockerImageTag: + description: 'Docker Image Tag' + default: 'v1.0.1' + required: true + architecture: + description: 'Architecture' + required: true + default: 'linux/amd64' + type: choice + options: + - linux/amd64 + - linux/arm64 + - linux/amd64,linux/arm64 + +jobs: + build-and-push-vector-model-to-ghcr: + runs-on: ubuntu-latest + steps: + - name: Check Disk Space + run: df -h + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@main + with: + tool-cache: true + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: true + swap-storage: true + - name: Check Disk Space + run: df -h + - name: Checkout + uses: actions/checkout@v4 + with: + ref: main + - name: Prepare + id: prepare + run: | + DOCKER_IMAGE=ghcr.io/1panel-dev/maxkb-vector-model + DOCKER_PLATFORMS=${{ github.event.inputs.architecture }} + TAG_NAME=${{ github.event.inputs.dockerImageTag }} + DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:latest" + echo ::set-output name=docker_image::${DOCKER_IMAGE} + echo ::set-output name=version::${TAG_NAME} + echo ::set-output 
name=buildx_args::--platform ${DOCKER_PLATFORMS} --no-cache \ + --build-arg VERSION=${TAG_NAME} \ + --build-arg BUILD_DATE=$(date -u +'%Y-%m-%dT%H:%M:%SZ') \ + --build-arg VCS_REF=${GITHUB_SHA::8} \ + ${DOCKER_IMAGE_TAGS} . + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GH_TOKEN }} + - name: Docker Buildx (build-and-push) + run: | + docker buildx build --output "type=image,push=true" ${{ steps.prepare.outputs.buildx_args }} -f installer/Dockerfile-vector-model \ No newline at end of file diff --git a/src/MaxKB-1.7.2/.github/workflows/build-and-push.yml b/src/MaxKB-1.7.2/.github/workflows/build-and-push.yml new file mode 100644 index 0000000..04e49a1 --- /dev/null +++ b/src/MaxKB-1.7.2/.github/workflows/build-and-push.yml @@ -0,0 +1,141 @@ +name: build-and-push + +run-name: 构建镜像并推送仓库 ${{ github.event.inputs.dockerImageTag }} (${{ github.event.inputs.registry }}) + +on: + workflow_dispatch: + inputs: + dockerImageTag: + description: 'Docker Image Tag' + default: 'v1.6.0-dev' + required: true + architecture: + description: 'Architecture' + required: true + default: 'linux/amd64' + type: choice + options: + - linux/amd64 + - linux/arm64 + - linux/amd64,linux/arm64 + registry: + description: 'Push To Registry' + required: true + default: 'fit2cloud-registry' + type: choice + options: + - fit2cloud-registry + - dockerhub + - dockerhub, fit2cloud-registry + +jobs: + build-and-push-to-fit2cloud-registry: + if: ${{ contains(github.event.inputs.registry, 'fit2cloud') }} + runs-on: ubuntu-latest + steps: + - name: Check Disk Space + run: df -h + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@main + with: + tool-cache: true + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: true + 
swap-storage: true + - name: Check Disk Space + run: df -h + - name: Checkout + uses: actions/checkout@v4 + with: + ref: ${{ github.ref_name }} + - name: Prepare + id: prepare + run: | + DOCKER_IMAGE=${{ secrets.FIT2CLOUD_REGISTRY_HOST }}/maxkb/maxkb + DOCKER_PLATFORMS=${{ github.event.inputs.architecture }} + TAG_NAME=${{ github.event.inputs.dockerImageTag }} + if [[ ${TAG_NAME} == *dev* ]]; then + DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME}" + else + DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:latest" + fi + echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} \ + --build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=${GITHUB_SHA::8} --no-cache \ + ${DOCKER_IMAGE_TAGS} . + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GH_TOKEN }} + - name: Login to FIT2CLOUD Registry + uses: docker/login-action@v3 + with: + registry: ${{ secrets.FIT2CLOUD_REGISTRY_HOST }} + username: ${{ secrets.FIT2CLOUD_REGISTRY_USERNAME }} + password: ${{ secrets.FIT2CLOUD_REGISTRY_PASSWORD }} + - name: Docker Buildx (build-and-push) + run: | + docker buildx build --output "type=image,push=true" ${{ steps.prepare.outputs.buildx_args }} -f installer/Dockerfile + + build-and-push-to-dockerhub: + if: ${{ contains(github.event.inputs.registry, 'dockerhub') }} + runs-on: ubuntu-latest + steps: + - name: Check Disk Space + run: df -h + - name: Free Disk Space (Ubuntu) + uses: jlumbroso/free-disk-space@main + with: + tool-cache: true + android: true + dotnet: true + haskell: true + large-packages: true + docker-images: true + swap-storage: true + - name: Check Disk Space + run: df -h + - name: Checkout + 
uses: actions/checkout@v4 + with: + ref: ${{ github.ref_name }} + - name: Prepare + id: prepare + run: | + DOCKER_IMAGE=1panel/maxkb + DOCKER_PLATFORMS=${{ github.event.inputs.architecture }} + TAG_NAME=${{ github.event.inputs.dockerImageTag }} + if [[ ${TAG_NAME} == *dev* ]]; then + DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME}" + else + DOCKER_IMAGE_TAGS="--tag ${DOCKER_IMAGE}:${TAG_NAME} --tag ${DOCKER_IMAGE}:latest" + fi + echo ::set-output name=buildx_args::--platform ${DOCKER_PLATFORMS} \ + --build-arg DOCKER_IMAGE_TAG=${{ github.event.inputs.dockerImageTag }} --build-arg BUILD_AT=$(TZ=Asia/Shanghai date +'%Y-%m-%dT%H:%M') --build-arg GITHUB_COMMIT=${GITHUB_SHA::8} --no-cache \ + ${DOCKER_IMAGE_TAGS} . + - name: Set up QEMU + uses: docker/setup-qemu-action@v3 + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v3 + - name: Login to GitHub Container Registry + uses: docker/login-action@v3 + with: + registry: ghcr.io + username: ${{ github.actor }} + password: ${{ secrets.GH_TOKEN }} + - name: Login to Docker Hub + uses: docker/login-action@v3 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + - name: Docker Buildx (build-and-push) + run: | + docker buildx build --output "type=image,push=true" ${{ steps.prepare.outputs.buildx_args }} -f installer/Dockerfile diff --git a/src/MaxKB-1.7.2/.github/workflows/create-pr-from-push.yml b/src/MaxKB-1.7.2/.github/workflows/create-pr-from-push.yml new file mode 100644 index 0000000..3e5ed91 --- /dev/null +++ b/src/MaxKB-1.7.2/.github/workflows/create-pr-from-push.yml @@ -0,0 +1,17 @@ +on: + push: + branches: + - 'pr@**' + - 'repr@**' + +name: 针对特定分支名自动创建 PR + +jobs: + generic_handler: + name: 自动创建 PR + runs-on: ubuntu-latest + steps: + - name: Create pull request + uses: jumpserver/action-generic-handler@master + env: + GITHUB_TOKEN: ${{ secrets.GH_TOKEN }} diff --git a/src/MaxKB-1.7.2/.github/workflows/sync2gitee.yml 
b/src/MaxKB-1.7.2/.github/workflows/sync2gitee.yml new file mode 100644 index 0000000..186cf15 --- /dev/null +++ b/src/MaxKB-1.7.2/.github/workflows/sync2gitee.yml @@ -0,0 +1,16 @@ +name: sync2gitee +on: [push] + +jobs: + repo-sync: + runs-on: ubuntu-latest + steps: + - name: Mirror the Github organization repos to Gitee. + uses: Yikun/hub-mirror-action@master + with: + src: 'github/1Panel-dev' + dst: 'gitee/fit2cloud-feizhiyun' + dst_key: ${{ secrets.GITEE_PRIVATE_KEY }} + dst_token: ${{ secrets.GITEE_TOKEN }} + static_list: "MaxKB" + force_update: true \ No newline at end of file diff --git a/src/MaxKB-1.7.2/.github/workflows/typos_check.yml b/src/MaxKB-1.7.2/.github/workflows/typos_check.yml new file mode 100644 index 0000000..0acbca9 --- /dev/null +++ b/src/MaxKB-1.7.2/.github/workflows/typos_check.yml @@ -0,0 +1,18 @@ +name: Typos Check +on: + push: + branches: + - main + pull_request: + types: [opened, synchronize, reopened] + +jobs: + run: + name: Spell Check with Typos + runs-on: ubuntu-latest + steps: + - name: Checkout Actions Repository + uses: actions/checkout@v2 + + - name: Check spelling + uses: crate-ci/typos@master diff --git a/src/MaxKB-1.7.2/.gitignore b/src/MaxKB-1.7.2/.gitignore new file mode 100644 index 0000000..af87ce4 --- /dev/null +++ b/src/MaxKB-1.7.2/.gitignore @@ -0,0 +1,185 @@ +# Mac +.DS_Store +*/.DS_Store + +# VS Code +.vscode +*.project +*.factorypath + +# IntelliJ IDEA +.idea/* +!.idea/icon.png +*.iws +*.iml +*.ipr + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script forms a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ +ui/package-lock.json +ui/node_modules +ui/dist +apps/static +models/ +!apps/**/models/ +data +.dev +poetry.lock +apps/setting/models_provider/impl/*/icon/ \ No newline at end of file diff --git a/src/MaxKB-1.7.2/.idea/icon.png b/src/MaxKB-1.7.2/.idea/icon.png new file mode 100644 index 0000000..7d9781e Binary files /dev/null and b/src/MaxKB-1.7.2/.idea/icon.png differ diff --git a/src/MaxKB-1.7.2/.typos.toml b/src/MaxKB-1.7.2/.typos.toml new file mode 100644 index 0000000..08f67e5 --- /dev/null +++ b/src/MaxKB-1.7.2/.typos.toml @@ -0,0 +1,4 @@ +[files] +extend-exclude = [ + 'apps/setting/models_provider/impl/*/icon/*' +] \ No newline at end of file diff --git a/src/MaxKB-1.7.2/CODE_OF_CONDUCT.md b/src/MaxKB-1.7.2/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..67b063a --- /dev/null +++ b/src/MaxKB-1.7.2/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body 
+size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. 
+ +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. +Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +support@fit2cloud.com. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. 
+ +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. \ No newline at end of file diff --git a/src/MaxKB-1.7.2/CONTRIBUTING.md b/src/MaxKB-1.7.2/CONTRIBUTING.md new file mode 100644 index 0000000..d7663b4 --- /dev/null +++ b/src/MaxKB-1.7.2/CONTRIBUTING.md @@ -0,0 +1,30 @@ +# Contributing + +As a contributor, you should agree that: + +- The producer can adjust the open-source agreement to be more strict or relaxed as deemed necessary. +- Your contributed code may be used for commercial purposes, including but not limited to its cloud business operations. + +## Create pull request +PRs are always welcome, even if they only contain small fixes like typos or a few lines of code. 
If a change will require significant effort, please document it as an issue and get a discussion going before starting to work on it. + +Please submit a PR broken down into small changes bit by bit. A PR consisting of a lot of features and code changes may be hard to review. It is recommended to submit PRs in an incremental fashion. + +This [development guideline](https://github.com/1Panel-dev/MaxKB/wiki/3-%E5%BC%80%E5%8F%91%E7%8E%AF%E5%A2%83%E6%90%AD%E5%BB%BA) contains information about the repository structure, how to set up the development environment, how to run it, and more. + +Note: If you split your pull request into small changes, please make sure that none of the changes will break anything once merged into the main branch. Otherwise, it cannot be merged until the feature is complete. + +## Report issues +Reporting an issue is a great way to contribute. Well-written and complete bug reports are always welcome! Please open an issue and follow the template to fill in the required information. + +Before opening any issue, please look up the existing issues to avoid submitting a duplicate. +If you find a match, you can "subscribe" to it to get notified on updates. If you have additional helpful information about the issue, please leave a comment. + +When reporting issues, always include: + +* Which version you are using. +* Steps to reproduce the issue. +* Screenshots or log files, if needed. + +Because the issues are open to the public, when submitting files, be sure to remove any sensitive information, e.g. user name, password, IP address, and company name. You can +replace those parts with "REDACTED" or other strings like "****". diff --git a/src/MaxKB-1.7.2/LICENSE b/src/MaxKB-1.7.2/LICENSE new file mode 100644 index 0000000..f288702 --- /dev/null +++ b/src/MaxKB-1.7.2/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. 
+ Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. 
+ + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. 
+ + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. 
+ + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. 
The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. 
+ + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. 
+ + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. 
This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. 
For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. 
Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. 
+ + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. 
Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. 
+ + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. 
+ + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. 
The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. 
THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. 
+ + <one line to give the program's name and a brief idea of what it does.> + Copyright (C) <year> <name of author> + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see <https://www.gnu.org/licenses/>. + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + <program> Copyright (C) <year> <name of author> + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +<https://www.gnu.org/licenses/>. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/src/MaxKB-1.7.2/README.md b/src/MaxKB-1.7.2/README.md new file mode 100644 index 0000000..ae75baf --- /dev/null +++ b/src/MaxKB-1.7.2/README.md @@ -0,0 +1,86 @@ +[English](README_EN.md) | [中文](README.md) + +

MaxKB

+

基于大语言模型和 RAG 的知识库问答系统

+

+ 1Panel-dev%2FMaxKB | Trendshift + 1Panel-dev%2FMaxKB | Aliyun +

+

+ License: GPL v3 + Codacy + Latest release + Stars + Download +

+
+ +MaxKB = Max Knowledge Base,是一款基于大语言模型和 RAG 的开源知识库问答系统,广泛应用于企业内部知识库、客户服务、学术研究与教育等场景。 + +- **开箱即用**:支持直接上传文档 / 自动爬取在线文档,支持文本自动拆分、向量化和 RAG(检索增强生成),有效减少大模型幻觉,智能问答交互体验好; +- **模型中立**:支持对接各种大模型,包括本地私有大模型(Llama 3 / Qwen 2 等)、国内公共大模型(通义千问 / 腾讯混元 / 字节豆包 / 百度千帆 / 智谱 AI / Kimi 等)和国外公共大模型(OpenAI / Claude / Gemini 等); +- **灵活编排**:内置强大的工作流引擎和函数库,支持编排 AI 工作过程,满足复杂业务场景下的需求; +- **无缝嵌入**:支持零编码快速嵌入到第三方业务系统,让已有系统快速拥有智能问答能力,提高用户满意度。 + +三分钟视频介绍:https://www.bilibili.com/video/BV18JypYeEkj/ + +## 快速开始 + +``` +docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data -v ~/.python-packages:/opt/maxkb/app/sandbox/python-packages cr2.fit2cloud.com/1panel/maxkb + +# 用户名: admin +# 密码: MaxKB@123.. +``` + +- 你也可以通过 [1Panel 应用商店](https://apps.fit2cloud.com/1panel) 快速部署 MaxKB + Ollama + Llama 3 / Qwen 2,快速上线基于本地大模型的 AI 知识库问答系统; +- 如果是内网环境,推荐使用 [离线安装包](https://community.fit2cloud.com/#/products/maxkb/downloads) 进行安装部署; +- 你也可以在线体验:[DataEase 小助手](https://dataease.io/docs/v2/),它是基于 MaxKB 搭建的智能 AI 问答系统,已经嵌入到 DataEase 产品及在线文档中; +- MaxKB 产品版本分为社区版和专业版,详情请参见:[MaxKB 产品版本对比](https://maxkb.cn/pricing.html)。 + +如你有更多问题,可以查看使用手册,或者通过论坛与我们交流。 + +- [使用手册](https://maxkb.cn/docs/) +- [论坛求助](https://bbs.fit2cloud.com/c/mk/11) +- 技术交流群 + + + +## UI 展示 + + + + + + + + + + +
MaxKB Demo1MaxKB Demo2
MaxKB Demo3MaxKB Demo4
+ +## 技术栈 + +- 前端:[Vue.js](https://cn.vuejs.org/) +- 后端:[Python / Django](https://www.djangoproject.com/) +- LangChain:[LangChain](https://www.langchain.com/) +- 向量数据库:[PostgreSQL / pgvector](https://www.postgresql.org/) +- 大模型:各种本地私有或者公共大模型 + +## 飞致云的其他明星项目 + +- [1Panel](https://github.com/1panel-dev/1panel/) - 现代化、开源的 Linux 服务器运维管理面板 +- [JumpServer](https://github.com/jumpserver/jumpserver/) - 广受欢迎的开源堡垒机 +- [DataEase](https://github.com/dataease/dataease/) - 人人可用的开源数据可视化分析工具 +- [MeterSphere](https://github.com/metersphere/metersphere/) - 新一代的开源持续测试工具 +- [Halo](https://github.com/halo-dev/halo/) - 强大易用的开源建站工具 + +## License + +Copyright (c) 2014-2024 飞致云 FIT2CLOUD, All rights reserved. + +Licensed under The GNU General Public License version 3 (GPLv3) (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + +<https://www.gnu.org/licenses/gpl-3.0.html> + +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. diff --git a/src/MaxKB-1.7.2/README_EN.md b/src/MaxKB-1.7.2/README_EN.md new file mode 100644 index 0000000..8722c5c --- /dev/null +++ b/src/MaxKB-1.7.2/README_EN.md @@ -0,0 +1,78 @@ +

MaxKB

+

Knowledge base, question answering system, based on LLM large language models

+

1Panel-dev%2FMaxKB | Trendshift

+

+ License: GPL v3 + Codacy + Latest release + Stars + Download +

+
+ +MaxKB = Max Knowledge Base. It is an open source knowledge base question and answer system based on large language models (LLMs). It is widely used in enterprise internal knowledge bases, customer service, academic research, education and other scenarios. + +- **Out-of-the-box**: Supports direct uploading of documents, automatic crawling of online documents, automatic text splitting, vectorization, RAG (retrieval-augmented generation), and a good interactive experience in intelligent question and answer; +- **Model neutral**: Supports docking with various large language models, including local private large models (Llama 3/Qwen 2, etc.), domestic public large models (Tongyi Qianwen/Zhipu AI/Baidu Qianfan/Kimi/DeepSeek, etc.) and foreign public large models (OpenAI / Azure OpenAI / Gemini, etc.); +- **Flexible Orchestration**: Built-in powerful workflow engine supports the orchestration of AI work processes to meet the needs of complex business scenarios; +- **Seamless Embedding**: Supports rapid embedding into third-party business systems with zero coding, allowing existing systems to quickly have intelligent question and answer capabilities and improve user satisfaction. +## Quick start + +``` +docker run -d --name=maxkb --restart=always -p 8080:8080 -v ~/.maxkb:/var/lib/postgresql/data -v ~/.python-packages:/opt/maxkb/app/sandbox/python-packages cr2.fit2cloud.com/1panel/maxkb + +# username: admin +# pass: MaxKB@123.. +``` + +- You can also quickly deploy MaxKB + Ollama + Llama 3 through [1Panel App Store](https://apps.fit2cloud.com/1panel). 
A knowledge base question and answer system based on a local large model can be launched within 30 minutes and embedded into In third-party business systems; +- If it is an intranet environment, it is recommended to use [offline installation package](https://community.fit2cloud.com/#/products/maxkb/downloads) for installation and deployment; +- You can also experience it online: [DataEase Assistant](https://dataease.io/docs/v2/), which is an intelligent question and answer system based on MaxKB and has been embedded in DataEase products and online documents.; +- MaxKB's product version is divided into community version and professional version. For details, please see: [MaxKB product version comparison](https://maxkb.cn/pricing.html). + +If you have more questions, you can check the user manual or communicate with us through the forum. If you need to build a technical blog or knowledge base, it is recommended to use [Halo open source website building tool](https://github.com/halo-dev/halo/). You can experience Feizhiyun’s official [Technical Blog](https://blog.fit2cloud.com/) and [Knowledge Base](https://kb.fit2cloud.com) cases. +- [Docs](https://maxkb.cn/docs/) +- [Demo Vid](https://www.bilibili.com/video/BV1BE421M7YM/) +- [Forum](https://bbs.fit2cloud.com/c/mk/11) +- Technical exchange group + + + +## UI Screenshots + + + + + + + + + + +
MaxKB Demo1MaxKB Demo2
MaxKB Demo3MaxKB Demo4
+ +## Stack Used + +- Frontend:[Vue.js](https://cn.vuejs.org/) +- Backend:[Python / Django](https://www.djangoproject.com/) +- LangChain:[LangChain](https://www.langchain.com/) +- Vector DB:[PostgreSQL / pgvector](https://www.postgresql.org/) +- Large models: various local private or public large models + +## Other Projects From Feizhiyun + +- [1Panel](https://github.com/1panel-dev/1panel/) - Modern, open source Linux server operation and maintenance management panel +- [JumpServer](https://github.com/jumpserver/jumpserver/) - Popular open source bastion host +- [DataEase](https://github.com/dataease/dataease/) - Open source data visualization analysis tools available to everyone +- [MeterSphere](https://github.com/metersphere/metersphere/) - New generation of open-source test tools +- [Halo](https://github.com/halo-dev/halo/) - Powerful and easy-to-use open source website building tool + +## License + +Copyright (c) 2014-2024 Feizhiyun FIT2CLOUD, All rights reserved. + +Licensed under The GNU General Public License version 3 (GPLv3) (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at + + + +Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. diff --git a/src/MaxKB-1.7.2/SECURITY.md b/src/MaxKB-1.7.2/SECURITY.md new file mode 100644 index 0000000..22be037 --- /dev/null +++ b/src/MaxKB-1.7.2/SECURITY.md @@ -0,0 +1,17 @@ +# 安全说明 + +如果您发现安全问题,请直接联系我们: + +- support@fit2cloud.com +- 400-052-0755 + +感谢您的支持! + +# Security Policy + +All security bugs should be reported to the contact as below: + +- support@fit2cloud.com +- 400-052-0755 + +Thanks for your support! 
class ParagraphPipelineModel:
    """A knowledge-base paragraph plus the retrieval metadata the chat pipeline carries with it.

    Instances are normally created through ``ParagraphPipelineModel.builder()``.
    """

    def __init__(self, _id: str, document_id: str, dataset_id: str, content: str, title: str, status: str,
                 is_active: bool, comprehensive_score: float, similarity: float, dataset_name: str, document_name: str,
                 hit_handling_method: str, directly_return_similarity: float):
        """
        :param _id: paragraph id
        :param document_id: id of the owning document
        :param dataset_id: id of the owning dataset
        :param content: paragraph text
        :param title: paragraph title
        :param status: paragraph status value
        :param is_active: whether the paragraph is enabled
        :param comprehensive_score: combined retrieval score
        :param similarity: similarity score of the hit
        :param dataset_name: display name of the dataset
        :param document_name: display name of the document
        :param hit_handling_method: how a hit on this paragraph is handled (e.g. 'directly_return')
        :param directly_return_similarity: similarity threshold used together with 'directly_return'
        """
        self.id = _id
        self.document_id = document_id
        self.dataset_id = dataset_id
        self.content = content
        self.title = title
        # Fix: a stray trailing comma used to store this as the 1-tuple ``(status,)``,
        # which then leaked into ``to_dict()``.
        self.status = status
        self.is_active = is_active
        self.comprehensive_score = comprehensive_score
        self.similarity = similarity
        self.dataset_name = dataset_name
        self.document_name = document_name
        self.hit_handling_method = hit_handling_method
        self.directly_return_similarity = directly_return_similarity

    def to_dict(self):
        """Return the paragraph as a plain dict.

        ``hit_handling_method`` and ``directly_return_similarity`` are not part of the
        output; the key set is kept exactly as before so existing consumers are unaffected.
        """
        return {
            'id': self.id,
            'document_id': self.document_id,
            'dataset_id': self.dataset_id,
            'content': self.content,
            'title': self.title,
            'status': self.status,
            'is_active': self.is_active,
            'comprehensive_score': self.comprehensive_score,
            'similarity': self.similarity,
            'dataset_name': self.dataset_name,
            'document_name': self.document_name
        }

    class builder:
        """Fluent builder; every ``add_*`` method returns the builder itself so calls can be chained."""

        def __init__(self):
            self.similarity = None
            self.paragraph = {}
            self.comprehensive_score = None
            self.document_name = None
            self.dataset_name = None
            self.hit_handling_method = None
            # Default threshold used when the caller never sets one.
            self.directly_return_similarity = 0.9

        def add_paragraph(self, paragraph):
            """Accept either a ``Paragraph`` model instance or an already-flattened mapping."""
            if isinstance(paragraph, Paragraph):
                self.paragraph = {'id': paragraph.id,
                                  'document_id': paragraph.document_id,
                                  'dataset_id': paragraph.dataset_id,
                                  'content': paragraph.content,
                                  'title': paragraph.title,
                                  'status': paragraph.status,
                                  'is_active': paragraph.is_active,
                                  }
            else:
                self.paragraph = paragraph
            return self

        def add_dataset_name(self, dataset_name):
            self.dataset_name = dataset_name
            return self

        def add_document_name(self, document_name):
            self.document_name = document_name
            return self

        def add_hit_handling_method(self, hit_handling_method):
            self.hit_handling_method = hit_handling_method
            return self

        def add_directly_return_similarity(self, directly_return_similarity):
            self.directly_return_similarity = directly_return_similarity
            return self

        def add_comprehensive_score(self, comprehensive_score: float):
            self.comprehensive_score = comprehensive_score
            return self

        def add_similarity(self, similarity: float):
            self.similarity = similarity
            return self

        def build(self):
            """Create the model; the three id fields are always converted with ``str()``."""
            return ParagraphPipelineModel(str(self.paragraph.get('id')), str(self.paragraph.get('document_id')),
                                          str(self.paragraph.get('dataset_id')),
                                          self.paragraph.get('content'), self.paragraph.get('title'),
                                          self.paragraph.get('status'),
                                          self.paragraph.get('is_active'),
                                          self.comprehensive_score, self.similarity, self.dataset_name,
                                          self.document_name, self.hit_handling_method, self.directly_return_similarity)
class PipelineManage:
    """Runs a list of chat-pipeline steps in order against one shared context dict."""

    def __init__(self, step_list: List[Type[IBaseChatPipelineStep]],
                 base_to_response: BaseToResponse = SystemToResponse()):
        # Shared pipeline context, seeded with the token counters.
        shared_context = {'message_tokens': 0, 'answer_tokens': 0}
        # One fresh step instance per configured step class.
        self.step_list = [step_class() for step_class in step_list]
        self.context = shared_context
        self.base_to_response = base_to_response

    def run(self, context: Dict = None):
        """Record the start time, merge ``context`` into the shared context, then run each step in order."""
        self.context['start_time'] = time.time()
        if context is not None:
            self.context.update(context.items())
        for pipeline_step in self.step_list:
            pipeline_step.run(self)

    def get_details(self):
        """Collect the details of every step, keyed by each detail's ``step_type``.

        Steps whose ``get_details`` returns ``None`` are left out; when two steps report
        the same ``step_type`` the later one wins.
        """
        collected = [pipeline_step.get_details(self) for pipeline_step in self.step_list]
        details = {}
        for detail in collected:
            if detail is None:
                continue
            details[detail.get('step_type')] = detail
        return details

    def get_base_to_response(self):
        """Return the response formatter configured for this pipeline."""
        return self.base_to_response

    class builder:
        """Fluent builder for ``PipelineManage``."""

        def __init__(self):
            self.step_list: List[Type[IBaseChatPipelineStep]] = []
            self.base_to_response = SystemToResponse()

        def append_step(self, step: Type[IBaseChatPipelineStep]):
            self.step_list.append(step)
            return self

        def add_base_to_response(self, base_to_response: BaseToResponse):
            self.base_to_response = base_to_response
            return self

        def build(self):
            return PipelineManage(step_list=self.step_list, base_to_response=self.base_to_response)
class ModelField(serializers.Field):
    """Serializer field that passes a chat-model instance through unchanged."""

    # Fix: ``Field.fail`` takes a KEY into ``error_messages``. Passing the message text
    # itself made DRF raise AssertionError (unknown key) instead of ValidationError.
    default_error_messages = {'invalid_type': '模型类型错误'}

    def to_internal_value(self, data):
        if not isinstance(data, BaseChatModel):
            self.fail('invalid_type')
        return data

    def to_representation(self, value):
        return value


class MessageField(serializers.Field):
    """Serializer field that passes a LangChain message instance through unchanged."""

    # Same fix as ``ModelField``: register the message under a real error key.
    default_error_messages = {'invalid_type': 'message类型错误'}

    def to_internal_value(self, data):
        if not isinstance(data, BaseMessage):
            self.fail('invalid_type')
        return data

    def to_representation(self, value):
        return value


class PostResponseHandler:
    """Callback invoked by the chat step once the answer text is complete."""

    @abstractmethod
    def handler(self, chat_id, chat_record_id, paragraph_list: List[ParagraphPipelineModel], problem_text: str,
                answer_text,
                manage, step, padding_problem_text: str = None, client_id=None, **kwargs):
        pass


class IChatStep(IBaseChatPipelineStep):
    """Pipeline step interface for the chat stage: validates its arguments and stores the result
    of ``execute`` in ``manage.context['chat_result']``."""

    class InstanceSerializer(serializers.Serializer):
        # Conversation message list
        message_list = serializers.ListField(required=True, child=MessageField(required=True),
                                             error_messages=ErrMessage.list("对话列表"))
        model_id = serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("模型id"))
        # Paragraph list
        paragraph_list = serializers.ListField(error_messages=ErrMessage.list("段落列表"))
        # Chat id
        chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))
        # User question (fix: a CharField needs the char message set, not the uuid one)
        problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题"))
        # Post-response handler (fix: the label was copy-pasted from the field above)
        post_response_handler = InstanceField(model_type=PostResponseHandler,
                                              error_messages=ErrMessage.base("后置处理器"))
        # Completed (rewritten) question
        padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.base("补全问题"))
        # Whether to answer as a stream
        stream = serializers.BooleanField(required=False, error_messages=ErrMessage.base("流式输出"))
        client_id = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端id"))
        client_type = serializers.CharField(required=True, error_messages=ErrMessage.char("客户端类型"))
        # Behaviour when no referenced paragraph was found
        no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置"))

        user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))

        model_params_setting = serializers.DictField(required=False, allow_null=True,
                                                     error_messages=ErrMessage.dict("模型参数设置"))

        def is_valid(self, *, raise_exception=False):
            """Validate the step arguments.

            Validation errors are always raised, whatever ``raise_exception`` says; this
            keeps the original behaviour callers rely on. Fix: the method now returns
            ``True`` on success instead of ``None``, matching the DRF contract.
            """
            super().is_valid(raise_exception=True)
            message_list: List = self.initial_data.get('message_list')
            # Second check on the raw input, kept from the original implementation.
            for message in message_list:
                if not isinstance(message, BaseMessage):
                    raise Exception("message 类型错误")
            return True

    def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]:
        return self.InstanceSerializer

    def _run(self, manage: PipelineManage):
        chat_result = self.execute(**self.context['step_args'], manage=manage)
        manage.context['chat_result'] = chat_result

    @abstractmethod
    def execute(self, message_list: List[BaseMessage],
                chat_id, problem_text,
                post_response_handler: PostResponseHandler,
                model_id: str = None,
                user_id: str = None,
                paragraph_list=None,
                manage: PipelineManage = None,
                padding_problem_text: str = None, stream: bool = True, client_id=None, client_type=None,
                no_references_setting=None, model_params_setting=None, **kwargs):
        """Produce the chat result for the validated step arguments (implemented by subclasses)."""
        pass
def add_access_num(client_id=None, client_type=None):
    """Increment the total and intraday access counters of a public-access client.

    Does nothing unless ``client_type`` is the application access-token type and a client
    row with ``client_id`` exists.
    """
    if client_type != AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
        return
    access_client = QuerySet(ApplicationPublicAccessClient).filter(id=client_id).first()
    if access_client is None:
        return
    access_client.access_num += 1
    access_client.intraday_access_num += 1
    access_client.save()


def write_context(step, manage, request_token, response_token, all_text):
    """Record token usage, answer text and elapsed time on the step, and add them to the pipeline totals.

    :param step: the running step; its ``context`` must already contain ``start_time``
    :param manage: the pipeline manager; its ``context`` holds ``start_time`` and the running token totals
    :param request_token: tokens consumed by the request messages
    :param response_token: tokens consumed by the answer
    :param all_text: the complete answer text
    """
    step_ctx = step.context
    step_ctx['message_tokens'] = request_token
    step_ctx['answer_tokens'] = response_token
    finished_at = time.time()
    step_ctx['answer_text'] = all_text
    step_ctx['run_time'] = finished_at - step_ctx['start_time']

    manage_ctx = manage.context
    manage_ctx['run_time'] = finished_at - manage_ctx['start_time']
    manage_ctx['message_tokens'] += request_token
    manage_ctx['answer_tokens'] += response_token
class BaseChatStep(IChatStep):
    """Default chat step.

    The answer comes from the first applicable source, in this order: a paragraph marked
    'directly_return' that meets its similarity threshold, the designated no-reference
    answer, a fixed notice when no model is configured, or the chat model itself.
    """

    _NO_MODEL_MESSAGE = '抱歉,没有配置 AI 模型,无法优化引用分段,请先去应用中设置 AI 模型。'

    def execute(self, message_list: List[BaseMessage],
                chat_id,
                problem_text,
                post_response_handler: PostResponseHandler,
                model_id: str = None,
                user_id: str = None,
                paragraph_list=None,
                manage: PipelineManage = None,
                padding_problem_text: str = None,
                stream: bool = True,
                client_id=None, client_type=None,
                no_references_setting=None,
                model_params_setting=None,
                **kwargs):
        """Resolve the chat model (if any) and dispatch to the streaming or blocking implementation."""
        # Fix: model_params_setting is optional/nullable; ``**None`` raised TypeError.
        chat_model = get_model_instance_by_model_user_id(model_id, user_id,
                                                         **(model_params_setting or {})) if model_id is not None else None
        if stream:
            return self.execute_stream(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                       paragraph_list,
                                       manage, padding_problem_text, client_id, client_type, no_references_setting)
        else:
            return self.execute_block(message_list, chat_id, problem_text, post_response_handler, chat_model,
                                      paragraph_list,
                                      manage, padding_problem_text, client_id, client_type, no_references_setting)

    def get_details(self, manage, **kwargs):
        """Return this step's execution details, read from the step and pipeline contexts."""
        return {
            'step_type': 'chat_step',
            'run_time': self.context['run_time'],
            'model_id': str(manage.context['model_id']),
            'message_list': self.reset_message_list(self.context['step_args'].get('message_list'),
                                                    self.context['answer_text']),
            'message_tokens': self.context['message_tokens'],
            'answer_tokens': self.context['answer_tokens'],
            'cost': 0,
        }

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        """Convert messages to ``{'role', 'content'}`` dicts and append the answer as a final 'ai' entry.

        Every message that is not a ``HumanMessage`` is labelled 'ai'.
        """
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content}
                  for message in message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result

    @staticmethod
    def _directly_return_contents(paragraph_list):
        """Contents of the paragraphs that must be returned verbatim.

        A paragraph qualifies when it is marked 'directly_return' and its similarity reaches
        its own ``directly_return_similarity`` threshold. Paragraphs without a score or
        threshold never qualify.
        """
        contents = []
        for paragraph in paragraph_list:
            if paragraph.hit_handling_method != 'directly_return':
                continue
            if paragraph.similarity is None or paragraph.directly_return_similarity is None:
                continue
            if paragraph.similarity >= paragraph.directly_return_similarity:
                contents.append(paragraph.content)
        return contents

    @staticmethod
    def _designated_answer(paragraph_list, no_references_setting, problem_text):
        """Return the configured fixed answer when nothing was retrieved, otherwise ``None``."""
        if len(paragraph_list) > 0 or no_references_setting is None:
            return None
        if no_references_setting.get('status') != 'designated_answer':
            return None
        return no_references_setting.get('value').replace('{question}', problem_text)

    @staticmethod
    def get_stream_result(message_list: List[BaseMessage],
                          chat_model: BaseChatModel = None,
                          paragraph_list=None,
                          no_references_setting=None,
                          problem_text=None):
        """Return ``(chunk_iterator, is_ai_chat)``; ``is_ai_chat`` is True only when the model was called."""
        if paragraph_list is None:
            paragraph_list = []
        direct_contents = BaseChatStep._directly_return_contents(paragraph_list)
        if len(direct_contents) > 0:
            return iter([AIMessageChunk(content=content) for content in direct_contents]), False
        designated = BaseChatStep._designated_answer(paragraph_list, no_references_setting, problem_text)
        if designated is not None:
            return iter([AIMessageChunk(content=designated)]), False
        if chat_model is None:
            return iter([AIMessageChunk(BaseChatStep._NO_MODEL_MESSAGE)]), False
        else:
            return chat_model.stream(message_list), True

    def execute_stream(self, message_list: List[BaseMessage],
                       chat_id,
                       problem_text,
                       post_response_handler: PostResponseHandler,
                       chat_model: BaseChatModel = None,
                       paragraph_list=None,
                       manage: PipelineManage = None,
                       padding_problem_text: str = None,
                       client_id=None, client_type=None,
                       no_references_setting=None):
        """Wrap the chunk generator in an SSE ``StreamingHttpResponse``."""
        chat_result, is_ai_chat = self.get_stream_result(message_list, chat_model, paragraph_list,
                                                         no_references_setting, problem_text)
        chat_record_id = uuid.uuid1()
        r = StreamingHttpResponse(
            streaming_content=event_content(chat_result, chat_id, chat_record_id, paragraph_list,
                                            post_response_handler, manage, self, chat_model, message_list, problem_text,
                                            padding_problem_text, client_id, client_type, is_ai_chat),
            content_type='text/event-stream;charset=utf-8')

        r['Cache-Control'] = 'no-cache'
        return r

    @staticmethod
    def get_block_result(message_list: List[BaseMessage],
                         chat_model: BaseChatModel = None,
                         paragraph_list=None,
                         no_references_setting=None,
                         problem_text=None):
        """Return ``(message, is_ai_chat)`` for a non-streaming answer.

        Fix: the 'directly_return' shortcut now honours the similarity threshold exactly as
        ``get_stream_result`` does. Previously any paragraph marked 'directly_return' was
        returned here regardless of its score, so the two modes could answer differently.
        Only the first qualifying paragraph is returned in blocking mode.
        """
        if paragraph_list is None:
            paragraph_list = []
        direct_contents = BaseChatStep._directly_return_contents(paragraph_list)
        if len(direct_contents) > 0:
            return AIMessage(content=direct_contents[0]), False
        designated = BaseChatStep._designated_answer(paragraph_list, no_references_setting, problem_text)
        if designated is not None:
            return AIMessage(designated), False
        if chat_model is None:
            return AIMessage(BaseChatStep._NO_MODEL_MESSAGE), False
        else:
            return chat_model.invoke(message_list), True

    def execute_block(self, message_list: List[BaseMessage],
                      chat_id,
                      problem_text,
                      post_response_handler: PostResponseHandler,
                      chat_model: BaseChatModel = None,
                      paragraph_list=None,
                      manage: PipelineManage = None,
                      padding_problem_text: str = None,
                      client_id=None, client_type=None, no_references_setting=None):
        """Produce the whole answer at once and return it as a block response.

        On failure the error text is recorded as the answer and an HTTP 500 block response
        is returned; the exception is not re-raised.
        """
        chat_record_id = uuid.uuid1()
        # Call the model
        try:
            chat_result, is_ai_chat = self.get_block_result(message_list, chat_model, paragraph_list,
                                                            no_references_setting, problem_text)
            if is_ai_chat:
                request_token = chat_model.get_num_tokens_from_messages(message_list)
                response_token = chat_model.get_num_tokens(chat_result.content)
            else:
                request_token = 0
                response_token = 0
            write_context(self, manage, request_token, response_token, chat_result.content)
            post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
                                          chat_result.content, manage, self, padding_problem_text, client_id)
            add_access_num(client_id, client_type)
            return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id),
                                                                   chat_result.content, True,
                                                                   request_token, response_token)
        except Exception as e:
            # Fix: log the failure like the streaming path does, so it is not lost silently.
            logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}')
            all_text = '异常' + str(e)
            write_context(self, manage, 0, 0, all_text)
            post_response_handler.handler(chat_id, chat_record_id, paragraph_list, problem_text,
                                          all_text, manage, self, padding_problem_text, client_id)
            add_access_num(client_id, client_type)
            return manage.get_base_to_response().to_block_response(str(chat_id), str(chat_record_id), all_text, True, 0,
                                                                   0, _status=status.HTTP_500_INTERNAL_SERVER_ERROR)
i_generate_human_message_step.py + @date:2024/1/9 18:15 + @desc: 生成对话模板 +""" +from abc import abstractmethod +from typing import Type, List + +from langchain.schema import BaseMessage +from rest_framework import serializers + +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel +from application.chat_pipeline.pipeline_manage import PipelineManage +from application.models import ChatRecord +from application.serializers.application_serializers import NoReferencesSetting +from common.field.common import InstanceField +from common.util.field_message import ErrMessage + + +class IGenerateHumanMessageStep(IBaseChatPipelineStep): + class InstanceSerializer(serializers.Serializer): + # 问题 + problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char("问题")) + # 段落列表 + paragraph_list = serializers.ListField(child=InstanceField(model_type=ParagraphPipelineModel, required=True), + error_messages=ErrMessage.list("段落列表")) + # 历史对答 + history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True), + error_messages=ErrMessage.list("历史对答")) + # 多轮对话数量 + dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) + # 最大携带知识库段落长度 + max_paragraph_char_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer( + "最大携带知识库段落长度")) + # 模板 + prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词")) + system = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("系统提示词(角色)")) + # 补齐问题 + padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("补齐问题")) + # 未查询到引用分段 + no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("无引用分段设置")) + + def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: + return self.InstanceSerializer + + def 
_run(self, manage: PipelineManage): + message_list = self.execute(**self.context['step_args']) + manage.context['message_list'] = message_list + + @abstractmethod + def execute(self, + problem_text: str, + paragraph_list: List[ParagraphPipelineModel], + history_chat_record: List[ChatRecord], + dialogue_number: int, + max_paragraph_char_number: int, + prompt: str, + padding_problem_text: str = None, + no_references_setting=None, + system=None, + **kwargs) -> List[BaseMessage]: + """ + + :param problem_text: 原始问题文本 + :param paragraph_list: 段落列表 + :param history_chat_record: 历史对话记录 + :param dialogue_number: 多轮对话数量 + :param max_paragraph_char_number: 最大段落长度 + :param prompt: 模板 + :param padding_problem_text 用户修改文本 + :param kwargs: 其他参数 + :param no_references_setting: 无引用分段设置 + :param system 系统提示称 + :return: + """ + pass diff --git a/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py b/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py new file mode 100644 index 0000000..68cfbbc --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/generate_human_message_step/impl/base_generate_human_message_step.py @@ -0,0 +1,73 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_generate_human_message_step.py.py + @date:2024/1/10 17:50 + @desc: +""" +from typing import List, Dict + +from langchain.schema import BaseMessage, HumanMessage +from langchain_core.messages import SystemMessage + +from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel +from application.chat_pipeline.step.generate_human_message_step.i_generate_human_message_step import \ + IGenerateHumanMessageStep +from application.models import ChatRecord +from common.util.split_model import flat_map + + +class BaseGenerateHumanMessageStep(IGenerateHumanMessageStep): + + def execute(self, problem_text: str, + paragraph_list: 
List[ParagraphPipelineModel], + history_chat_record: List[ChatRecord], + dialogue_number: int, + max_paragraph_char_number: int, + prompt: str, + padding_problem_text: str = None, + no_references_setting=None, + system=None, + **kwargs) -> List[BaseMessage]: + prompt = prompt if (paragraph_list is not None and len(paragraph_list) > 0) else no_references_setting.get( + 'value') + exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text + start_index = len(history_chat_record) - dialogue_number + history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))] + if system is not None and len(system) > 0: + return [SystemMessage(system), *flat_map(history_message), + self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list, + no_references_setting)] + + return [*flat_map(history_message), + self.to_human_message(prompt, exec_problem_text, max_paragraph_char_number, paragraph_list, + no_references_setting)] + + @staticmethod + def to_human_message(prompt: str, + problem: str, + max_paragraph_char_number: int, + paragraph_list: List[ParagraphPipelineModel], + no_references_setting: Dict): + if paragraph_list is None or len(paragraph_list) == 0: + if no_references_setting.get('status') == 'ai_questioning': + return HumanMessage( + content=no_references_setting.get('value').replace('{question}', problem)) + else: + return HumanMessage(content=prompt.replace('{data}', "").replace('{question}', problem)) + temp_data = "" + data_list = [] + for p in paragraph_list: + content = f"{p.title}:{p.content}" + temp_data += content + if len(temp_data) > max_paragraph_char_number: + row_data = content[0:max_paragraph_char_number - len(temp_data)] + data_list.append(f"{row_data}") + break + else: + data_list.append(f"{content}") + data = "\n".join(data_list) + return 
HumanMessage(content=prompt.replace('{data}', data).replace('{question}', problem)) diff --git a/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/reset_problem_step/__init__.py b/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/reset_problem_step/__init__.py new file mode 100644 index 0000000..5d9549c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/reset_problem_step/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/1/9 18:23 + @desc: +""" diff --git a/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py b/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py new file mode 100644 index 0000000..e12fd08 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/reset_problem_step/i_reset_problem_step.py @@ -0,0 +1,56 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_reset_problem_step.py + @date:2024/1/9 18:12 + @desc: 重写处理问题 +""" +from abc import abstractmethod +from typing import Type, List + +from langchain.chat_models.base import BaseChatModel +from rest_framework import serializers + +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep +from application.chat_pipeline.pipeline_manage import PipelineManage +from application.chat_pipeline.step.chat_step.i_chat_step import ModelField +from application.models import ChatRecord +from common.field.common import InstanceField +from common.util.field_message import ErrMessage + + +class IResetProblemStep(IBaseChatPipelineStep): + class InstanceSerializer(serializers.Serializer): + # 问题文本 + problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char("问题文本")) + # 历史对答 + history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True), + error_messages=ErrMessage.list("历史对答")) + # 大语言模型 + model_id =
serializers.UUIDField(required=False, allow_null=True, error_messages=ErrMessage.uuid("模型id")) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + problem_optimization_prompt = serializers.CharField(required=False, max_length=102400, + error_messages=ErrMessage.char("问题补全提示词")) + + def get_step_serializer(self, manage: PipelineManage) -> Type[serializers.Serializer]: + return self.InstanceSerializer + + def _run(self, manage: PipelineManage): + padding_problem = self.execute(**self.context.get('step_args')) + # 用户输入问题 + source_problem_text = self.context.get('step_args').get('problem_text') + self.context['problem_text'] = source_problem_text + self.context['padding_problem_text'] = padding_problem + manage.context['problem_text'] = source_problem_text + manage.context['padding_problem_text'] = padding_problem + # 累加tokens + manage.context['message_tokens'] = manage.context['message_tokens'] + self.context.get('message_tokens') + manage.context['answer_tokens'] = manage.context['answer_tokens'] + self.context.get('answer_tokens') + + @abstractmethod + def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None, + problem_optimization_prompt=None, + user_id=None, + **kwargs): + pass diff --git a/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py b/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py new file mode 100644 index 0000000..2d631e0 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/reset_problem_step/impl/base_reset_problem_step.py @@ -0,0 +1,65 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_reset_problem_step.py + @date:2024/1/10 14:35 + @desc: +""" +from typing import List + +from langchain.schema import HumanMessage + +from application.chat_pipeline.step.reset_problem_step.i_reset_problem_step import IResetProblemStep +from 
application.models import ChatRecord +from common.util.split_model import flat_map +from setting.models_provider.tools import get_model_instance_by_model_user_id + +prompt = ( + '()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在<data></data>标签中') + + +class BaseResetProblemStep(IResetProblemStep): + def execute(self, problem_text: str, history_chat_record: List[ChatRecord] = None, model_id: str = None, + problem_optimization_prompt=None, + user_id=None, + **kwargs) -> str: + chat_model = get_model_instance_by_model_user_id(model_id, user_id) if model_id is not None else None + start_index = len(history_chat_record) - 3 + history_message = [[history_chat_record[index].get_human_message(), history_chat_record[index].get_ai_message()] + for index in + range(start_index if start_index > 0 else 0, len(history_chat_record))] + reset_prompt = problem_optimization_prompt if problem_optimization_prompt else prompt + message_list = [*flat_map(history_message), + HumanMessage(content=reset_prompt.replace('{question}', problem_text))] + response = chat_model.invoke(message_list) + padding_problem = problem_text + if response.content.__contains__("<data>") and response.content.__contains__('</data>'): + padding_problem_data = response.content[ + response.content.index('<data>') + 6:response.content.index('</data>')] + if padding_problem_data is not None and len(padding_problem_data.strip()) > 0: + padding_problem = padding_problem_data + elif len(response.content) > 0: + padding_problem = response.content + + try: + request_token = chat_model.get_num_tokens_from_messages(message_list) + response_token = chat_model.get_num_tokens(padding_problem) + except Exception as e: + request_token = 0 + response_token = 0 + self.context['message_tokens'] = request_token + self.context['answer_tokens'] = response_token + return padding_problem + + def get_details(self, manage, **kwargs): + return { + 'step_type': 'problem_padding', + 'run_time': self.context['run_time'], + 'model_id': str(manage.context['model_id']) if
'model_id' in manage.context else None, + 'message_tokens': self.context['message_tokens'], + 'answer_tokens': self.context['answer_tokens'], + 'cost': 0, + 'padding_problem_text': self.context.get('padding_problem_text'), + 'problem_text': self.context.get("step_args").get('problem_text'), + } diff --git a/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/search_dataset_step/__init__.py b/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/search_dataset_step/__init__.py new file mode 100644 index 0000000..023c4bc --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/search_dataset_step/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/1/9 18:24 + @desc: +""" diff --git a/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py b/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py new file mode 100644 index 0000000..97da296 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/search_dataset_step/i_search_dataset_step.py @@ -0,0 +1,75 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_search_dataset_step.py + @date:2024/1/9 18:10 + @desc: 检索知识库 +""" +import re +from abc import abstractmethod +from typing import List, Type + +from django.core import validators +from rest_framework import serializers + +from application.chat_pipeline.I_base_chat_pipeline import IBaseChatPipelineStep, ParagraphPipelineModel +from application.chat_pipeline.pipeline_manage import PipelineManage +from common.util.field_message import ErrMessage + + +class ISearchDatasetStep(IBaseChatPipelineStep): + class InstanceSerializer(serializers.Serializer): + # 原始问题文本 + problem_text = serializers.CharField(required=True, error_messages=ErrMessage.char("问题")) + # 系统补全问题文本 + padding_problem_text = serializers.CharField(required=False, error_messages=ErrMessage.char("系统补全问题文本")) + # 需要查询的数据集id列表 + 
dataset_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list("数据集id列表")) + # 需要排除的文档id + exclude_document_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list("排除的文档id列表")) + # 需要排除向量id + exclude_paragraph_id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list("排除向量id列表")) + # 需要查询的条数 + top_n = serializers.IntegerField(required=True, + error_messages=ErrMessage.integer("引用分段数")) + # 相似度 0-1之间 + similarity = serializers.FloatField(required=True, max_value=1, min_value=0, + error_messages=ErrMessage.float("相似度")) + search_mode = serializers.CharField(required=True, validators=[ + validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"), + message="检索模式只支持embedding|keywords|blend", code=500) + ], error_messages=ErrMessage.char("检索模式")) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + + def get_step_serializer(self, manage: PipelineManage) -> Type[InstanceSerializer]: + return self.InstanceSerializer + + def _run(self, manage: PipelineManage): + paragraph_list = self.execute(**self.context['step_args']) + manage.context['paragraph_list'] = paragraph_list + self.context['paragraph_list'] = paragraph_list + + @abstractmethod + def execute(self, problem_text: str, dataset_id_list: list[str], exclude_document_id_list: list[str], + exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, + search_mode: str = None, + user_id=None, + **kwargs) -> List[ParagraphPipelineModel]: + """ + 关于 用户和补全问题 说明: 补全问题如果有就使用补全问题去查询 反之就用用户原始问题查询 + :param similarity: 相关性 + :param top_n: 查询多少条 + :param problem_text: 用户问题 + :param dataset_id_list: 需要查询的数据集id列表 + :param exclude_document_id_list: 需要排除的文档id + :param exclude_paragraph_id_list: 需要排除段落id + :param padding_problem_text 补全问题 +
:param search_mode 检索模式 + :param user_id 用户id + :return: 段落列表 + """ + pass diff --git a/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py b/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py new file mode 100644 index 0000000..c13b414 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/chat_pipeline/step/search_dataset_step/impl/base_search_dataset_step.py @@ -0,0 +1,134 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_search_dataset_step.py + @date:2024/1/10 10:33 + @desc: +""" +import os +from typing import List, Dict + +from django.db.models import QuerySet + +from application.chat_pipeline.I_base_chat_pipeline import ParagraphPipelineModel +from application.chat_pipeline.step.search_dataset_step.i_search_dataset_step import ISearchDatasetStep +from common.config.embedding_config import VectorStore, ModelManage +from common.db.search import native_search +from common.util.file_util import get_file_content +from dataset.models import Paragraph, DataSet +from embedding.models import SearchMode +from setting.models import Model +from setting.models_provider import get_model +from smartdoc.conf import PROJECT_DIR + + +def get_model_by_id(_id, user_id): + model = QuerySet(Model).filter(id=_id).first() + if model is None: + raise Exception("模型不存在") + if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id): + raise Exception(f"无权限使用此模型:{model.name}") + return model + + +def get_embedding_id(dataset_id_list): + dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list) + if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1: + raise Exception("关联知识库的向量模型不一致,无法召回分段。") + if len(dataset_list) == 0: + raise Exception("知识库设置错误,请重新设置知识库") + return dataset_list[0].embedding_mode_id + + +class BaseSearchDatasetStep(ISearchDatasetStep): + + def execute(self, problem_text: str, dataset_id_list: list[str], 
exclude_document_id_list: list[str], + exclude_paragraph_id_list: list[str], top_n: int, similarity: float, padding_problem_text: str = None, + search_mode: str = None, + user_id=None, + **kwargs) -> List[ParagraphPipelineModel]: + if len(dataset_id_list) == 0: + return [] + exec_problem_text = padding_problem_text if padding_problem_text is not None else problem_text + model_id = get_embedding_id(dataset_id_list) + model = get_model_by_id(model_id, user_id) + self.context['model_name'] = model.name + embedding_model = ModelManage.get_model(model_id, lambda _id: get_model(model)) + embedding_value = embedding_model.embed_query(exec_problem_text) + vector = VectorStore.get_embedding_vector() + embedding_list = vector.query(exec_problem_text, embedding_value, dataset_id_list, exclude_document_id_list, + exclude_paragraph_id_list, True, top_n, similarity, SearchMode(search_mode)) + if embedding_list is None: + return [] + paragraph_list = self.list_paragraph(embedding_list, vector) + result = [self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list] + return result + + @staticmethod + def reset_paragraph(paragraph: Dict, embedding_list: List) -> ParagraphPipelineModel: + filter_embedding_list = [embedding for embedding in embedding_list if + str(embedding.get('paragraph_id')) == str(paragraph.get('id'))] + if filter_embedding_list is not None and len(filter_embedding_list) > 0: + find_embedding = filter_embedding_list[-1] + return (ParagraphPipelineModel.builder() + .add_paragraph(paragraph) + .add_similarity(find_embedding.get('similarity')) + .add_comprehensive_score(find_embedding.get('comprehensive_score')) + .add_dataset_name(paragraph.get('dataset_name')) + .add_document_name(paragraph.get('document_name')) + .add_hit_handling_method(paragraph.get('hit_handling_method')) + .add_directly_return_similarity(paragraph.get('directly_return_similarity')) + .build()) + + @staticmethod + def get_similarity(paragraph, embedding_list: List): + 
filter_embedding_list = [embedding for embedding in embedding_list if + str(embedding.get('paragraph_id')) == str(paragraph.get('id'))] + if filter_embedding_list is not None and len(filter_embedding_list) > 0: + find_embedding = filter_embedding_list[-1] + return find_embedding.get('comprehensive_score') + return 0 + + @staticmethod + def list_paragraph(embedding_list: List, vector): + paragraph_id_list = [row.get('paragraph_id') for row in embedding_list] + if paragraph_id_list is None or len(paragraph_id_list) == 0: + return [] + paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list), + get_file_content( + os.path.join(PROJECT_DIR, "apps", "application", 'sql', + 'list_dataset_paragraph_by_paragraph_id.sql')), + with_table_name=True) + # 如果向量库中存在脏数据 直接删除 + if len(paragraph_list) != len(paragraph_id_list): + exist_paragraph_list = [row.get('id') for row in paragraph_list] + for paragraph_id in paragraph_id_list: + if not exist_paragraph_list.__contains__(paragraph_id): + vector.delete_by_paragraph_id(paragraph_id) + # 如果存在直接返回的则取直接返回段落 + hit_handling_method_paragraph = [paragraph for paragraph in paragraph_list if + (paragraph.get( + 'hit_handling_method') == 'directly_return' and BaseSearchDatasetStep.get_similarity( + paragraph, embedding_list) >= paragraph.get( + 'directly_return_similarity'))] + if len(hit_handling_method_paragraph) > 0: + # 找到评分最高的 + return [sorted(hit_handling_method_paragraph, + key=lambda p: BaseSearchDatasetStep.get_similarity(p, embedding_list))[-1]] + return paragraph_list + + def get_details(self, manage, **kwargs): + step_args = self.context['step_args'] + + return { + 'step_type': 'search_step', + 'paragraph_list': [row.to_dict() for row in self.context['paragraph_list']], + 'run_time': self.context['run_time'], + 'problem_text': step_args.get( + 'padding_problem_text') if 'padding_problem_text' in step_args else step_args.get('problem_text'), + 'model_name': self.context.get('model_name'), + 
'message_tokens': 0, + 'answer_tokens': 0, + 'cost': 0 + } diff --git a/src/MaxKB-1.7.2/apps/application/flow/__init__.py b/src/MaxKB-1.7.2/apps/application/flow/__init__.py new file mode 100644 index 0000000..328e8f8 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/flow/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/6/7 14:43 + @desc: +""" diff --git a/src/MaxKB-1.7.2/apps/application/flow/default_workflow.json b/src/MaxKB-1.7.2/apps/application/flow/default_workflow.json new file mode 100644 index 0000000..48ac23c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/flow/default_workflow.json @@ -0,0 +1,451 @@ +{ + "nodes": [ + { + "id": "base-node", + "type": "base-node", + "x": 360, + "y": 2810, + "properties": { + "config": { + + }, + "height": 825.6, + "stepName": "基本信息", + "node_data": { + "desc": "", + "name": "maxkbapplication", + "prologue": "您好,我是 MaxKB 小助手,您可以向我提出 MaxKB 使用问题。\n- MaxKB 主要功能有什么?\n- MaxKB 支持哪些大语言模型?\n- MaxKB 支持哪些文档类型?" 
+ }, + "input_field_list": [ + + ] + } + }, + { + "id": "start-node", + "type": "start-node", + "x": 430, + "y": 3660, + "properties": { + "config": { + "fields": [ + { + "label": "用户问题", + "value": "question" + } + ], + "globalFields": [ + { + "label": "当前时间", + "value": "time" + } + ] + }, + "fields": [ + { + "label": "用户问题", + "value": "question" + } + ], + "height": 276, + "stepName": "开始", + "globalFields": [ + { + "label": "当前时间", + "value": "time" + } + ] + } + }, + { + "id": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "type": "search-dataset-node", + "x": 840, + "y": 3210, + "properties": { + "config": { + "fields": [ + { + "label": "检索结果的分段列表", + "value": "paragraph_list" + }, + { + "label": "满足直接回答的分段列表", + "value": "is_hit_handling_method_list" + }, + { + "label": "检索结果", + "value": "data" + }, + { + "label": "满足直接回答的分段内容", + "value": "directly_return" + } + ] + }, + "height": 794, + "stepName": "知识库检索", + "node_data": { + "dataset_id_list": [ + + ], + "dataset_setting": { + "top_n": 3, + "similarity": 0.6, + "search_mode": "embedding", + "max_paragraph_char_number": 5000 + }, + "question_reference_address": [ + "start-node", + "question" + ], + "source_dataset_id_list": [ + + ] + } + } + }, + { + "id": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "type": "condition-node", + "x": 1490, + "y": 3210, + "properties": { + "width": 600, + "config": { + "fields": [ + { + "label": "分支名称", + "value": "branch_name" + } + ] + }, + "height": 543.675, + "stepName": "判断器", + "node_data": { + "branch": [ + { + "id": "1009", + "type": "IF", + "condition": "and", + "conditions": [ + { + "field": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "is_hit_handling_method_list" + ], + "value": "1", + "compare": "len_ge" + } + ] + }, + { + "id": "4908", + "type": "ELSE IF 1", + "condition": "and", + "conditions": [ + { + "field": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "paragraph_list" + ], + "value": "1", + "compare": "len_ge" + } + ] + }, + { + "id": "161", + "type": "ELSE", 
+ "condition": "and", + "conditions": [ + + ] + } + ] + }, + "branch_condition_list": [ + { + "index": 0, + "height": 121.225, + "id": "1009" + }, + { + "index": 1, + "height": 121.225, + "id": "4908" + }, + { + "index": 2, + "height": 44, + "id": "161" + } + ] + } + }, + { + "id": "4ffe1086-25df-4c85-b168-979b5bbf0a26", + "type": "reply-node", + "x": 2170, + "y": 2480, + "properties": { + "config": { + "fields": [ + { + "label": "内容", + "value": "answer" + } + ] + }, + "height": 378, + "stepName": "指定回复", + "node_data": { + "fields": [ + "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "directly_return" + ], + "content": "", + "reply_type": "referencing", + "is_result": true + } + } + }, + { + "id": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb", + "type": "ai-chat-node", + "x": 2160, + "y": 3200, + "properties": { + "config": { + "fields": [ + { + "label": "AI 回答内容", + "value": "answer" + } + ] + }, + "height": 763, + "stepName": "AI 对话", + "node_data": { + "prompt": "已知信息:\n{{知识库检索.data}}\n问题:\n{{开始.question}}", + "system": "", + "model_id": "", + "dialogue_number": 0, + "is_result": true + } + } + }, + { + "id": "309d0eef-c597-46b5-8d51-b9a28aaef4c7", + "type": "ai-chat-node", + "x": 2160, + "y": 3970, + "properties": { + "config": { + "fields": [ + { + "label": "AI 回答内容", + "value": "answer" + } + ] + }, + "height": 763, + "stepName": "AI 对话1", + "node_data": { + "prompt": "{{开始.question}}", + "system": "", + "model_id": "", + "dialogue_number": 0, + "is_result": true + } + } + } + ], + "edges": [ + { + "id": "7d0f166f-c472-41b2-b9a2-c294f4c83d73", + "type": "app-edge", + "sourceNodeId": "start-node", + "targetNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "startPoint": { + "x": 590, + "y": 3660 + }, + "endPoint": { + "x": 680, + "y": 3210 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 590, + "y": 3660 + }, + { + "x": 700, + "y": 3660 + }, + { + "x": 570, + "y": 3210 + }, + { + "x": 680, + "y": 3210 + } + ], + "sourceAnchorId": "start-node_right", + 
"targetAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_left" + }, + { + "id": "35cb86dd-f328-429e-a973-12fd7218b696", + "type": "app-edge", + "sourceNodeId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5", + "targetNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "startPoint": { + "x": 1000, + "y": 3210 + }, + "endPoint": { + "x": 1200, + "y": 3210 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1000, + "y": 3210 + }, + { + "x": 1110, + "y": 3210 + }, + { + "x": 1090, + "y": 3210 + }, + { + "x": 1200, + "y": 3210 + } + ], + "sourceAnchorId": "b931efe5-5b66-46e0-ae3b-0160cb18eeb5_right", + "targetAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_left" + }, + { + "id": "e8f6cfe6-7e48-41cd-abd3-abfb5304d0d8", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "4ffe1086-25df-4c85-b168-979b5bbf0a26", + "startPoint": { + "x": 1780, + "y": 3073.775 + }, + "endPoint": { + "x": 2010, + "y": 2480 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1780, + "y": 3073.775 + }, + { + "x": 1890, + "y": 3073.775 + }, + { + "x": 1900, + "y": 2480 + }, + { + "x": 2010, + "y": 2480 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_1009_right", + "targetAnchorId": "4ffe1086-25df-4c85-b168-979b5bbf0a26_left" + }, + { + "id": "994ff325-6f7a-4ebc-b61b-10e15519d6d2", + "type": "app-edge", + "sourceNodeId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb", + "startPoint": { + "x": 1780, + "y": 3203 + }, + "endPoint": { + "x": 2000, + "y": 3200 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1780, + "y": 3203 + }, + { + "x": 1890, + "y": 3203 + }, + { + "x": 1890, + "y": 3200 + }, + { + "x": 2000, + "y": 3200 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_4908_right", + "targetAnchorId": "f1f1ee18-5a02-46f6-b4e6-226253cdffbb_left" + }, + { + "id": "19270caf-bb9f-4ba7-9bf8-200aa70fecd5", + "type": "app-edge", + "sourceNodeId": 
"fc60863a-dec2-4854-9e5a-7a44b7187a2b", + "targetNodeId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7", + "startPoint": { + "x": 1780, + "y": 3293.6124999999997 + }, + "endPoint": { + "x": 2000, + "y": 3970 + }, + "properties": { + + }, + "pointsList": [ + { + "x": 1780, + "y": 3293.6124999999997 + }, + { + "x": 1890, + "y": 3293.6124999999997 + }, + { + "x": 1890, + "y": 3970 + }, + { + "x": 2000, + "y": 3970 + } + ], + "sourceAnchorId": "fc60863a-dec2-4854-9e5a-7a44b7187a2b_161_right", + "targetAnchorId": "309d0eef-c597-46b5-8d51-b9a28aaef4c7_left" + } + ] +} \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/application/flow/i_step_node.py b/src/MaxKB-1.7.2/apps/application/flow/i_step_node.py new file mode 100644 index 0000000..48c3733 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/flow/i_step_node.py @@ -0,0 +1,198 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_step_node.py + @date:2024/6/3 14:57 + @desc: +""" +import time +from abc import abstractmethod +from typing import Type, Dict, List + +from django.core import cache +from django.db.models import QuerySet +from rest_framework import serializers +from rest_framework.exceptions import ValidationError, ErrorDetail + +from application.models import ChatRecord +from application.models.api_key_model import ApplicationPublicAccessClient +from common.constants.authentication_type import AuthenticationType +from common.field.common import InstanceField +from common.util.field_message import ErrMessage + +chat_cache = cache.caches['chat_cache'] + + +def write_context(step_variable: Dict, global_variable: Dict, node, workflow): + if step_variable is not None: + for key in step_variable: + node.context[key] = step_variable[key] + if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'answer' in step_variable: + answer = step_variable['answer'] + yield answer + workflow.answer += answer + if global_variable is not None: + for key in global_variable: + 
workflow.context[key] = global_variable[key] + node.context['run_time'] = time.time() - node.context['start_time'] + + +class WorkFlowPostHandler: + def __init__(self, chat_info, client_id, client_type): + self.chat_info = chat_info + self.client_id = client_id + self.client_type = client_type + + def handler(self, chat_id, + chat_record_id, + answer, + workflow): + question = workflow.params['question'] + details = workflow.get_runtime_details() + message_tokens = sum([row.get('message_tokens') for row in details.values() if + 'message_tokens' in row and row.get('message_tokens') is not None]) + answer_tokens = sum([row.get('answer_tokens') for row in details.values() if + 'answer_tokens' in row and row.get('answer_tokens') is not None]) + chat_record = ChatRecord(id=chat_record_id, + chat_id=chat_id, + problem_text=question, + answer_text=answer, + details=details, + message_tokens=message_tokens, + answer_tokens=answer_tokens, + run_time=time.time() - workflow.context['start_time'], + index=0) + self.chat_info.append_chat_record(chat_record, self.client_id) + # 重新设置缓存 + chat_cache.set(chat_id, + self.chat_info, timeout=60 * 30) + if self.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value: + application_public_access_client = QuerySet(ApplicationPublicAccessClient).filter(id=self.client_id).first() + if application_public_access_client is not None: + application_public_access_client.access_num = application_public_access_client.access_num + 1 + application_public_access_client.intraday_access_num = application_public_access_client.intraday_access_num + 1 + application_public_access_client.save() + + +class NodeResult: + def __init__(self, node_variable: Dict, workflow_variable: Dict, + _write_context=write_context): + self._write_context = _write_context + self.node_variable = node_variable + self.workflow_variable = workflow_variable + + def write_context(self, node, workflow): + return self._write_context(self.node_variable, 
self.workflow_variable, node, workflow) + + def is_assertion_result(self): + return 'branch_id' in self.node_variable + + +class ReferenceAddressSerializer(serializers.Serializer): + node_id = serializers.CharField(required=True, error_messages=ErrMessage.char("节点id")) + fields = serializers.ListField( + child=serializers.CharField(required=True, error_messages=ErrMessage.char("节点字段")), required=True, + error_messages=ErrMessage.list("节点字段数组")) + + +class FlowParamsSerializer(serializers.Serializer): + # 历史对答 + history_chat_record = serializers.ListField(child=InstanceField(model_type=ChatRecord, required=True), + error_messages=ErrMessage.list("历史对答")) + + question = serializers.CharField(required=True, error_messages=ErrMessage.char("用户问题")) + + chat_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话id")) + + chat_record_id = serializers.CharField(required=True, error_messages=ErrMessage.char("对话记录id")) + + stream = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("流式输出")) + + client_id = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端id")) + + client_type = serializers.CharField(required=False, error_messages=ErrMessage.char("客户端类型")) + + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + re_chat = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("换个答案")) + + +class INode: + def __init__(self, node, workflow_params, workflow_manage): + # 当前步骤上下文,用于存储当前步骤信息 + self.status = 200 + self.err_message = '' + self.node = node + self.node_params = node.properties.get('node_data') + self.workflow_params = workflow_params + self.workflow_manage = workflow_manage + self.node_params_serializer = None + self.flow_params_serializer = None + self.context = {} + self.id = node.id + + def valid_args(self, node_params, flow_params): + flow_params_serializer_class = self.get_flow_params_serializer_class() + node_params_serializer_class =
self.get_node_params_serializer_class() + if flow_params_serializer_class is not None and flow_params is not None: + self.flow_params_serializer = flow_params_serializer_class(data=flow_params) + self.flow_params_serializer.is_valid(raise_exception=True) + if node_params_serializer_class is not None: + self.node_params_serializer = node_params_serializer_class(data=node_params) + self.node_params_serializer.is_valid(raise_exception=True) + if self.node.properties.get('status', 200) != 200: + raise ValidationError(ErrorDetail(f'节点{self.node.properties.get("stepName")} 不可用')) + + def get_reference_field(self, fields: List[str]): + return self.get_field(self.context, fields) + + @staticmethod + def get_field(obj, fields: List[str]): + for field in fields: + value = obj.get(field) + if value is None: + return None + else: + obj = value + return obj + + @abstractmethod + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + pass + + def get_flow_params_serializer_class(self) -> Type[serializers.Serializer]: + return FlowParamsSerializer + + def get_write_error_context(self, e): + self.status = 500 + self.err_message = str(e) + self.context['run_time'] = time.time() - self.context['start_time'] + + def write_error_context(answer, status=200): + pass + + return write_error_context + + def run(self) -> NodeResult: + """ + :return: 执行结果 + """ + start_time = time.time() + self.context['start_time'] = start_time + result = self._run() + self.context['run_time'] = time.time() - start_time + return result + + def _run(self): + result = self.execute() + return result + + def execute(self, **kwargs) -> NodeResult: + pass + + def get_details(self, index: int, **kwargs): + """ + 运行详情 + :return: 步骤详情 + """ + return {} diff --git a/src/MaxKB-1.7.2/apps/application/flow/step_node/__init__.py b/src/MaxKB-1.7.2/apps/application/flow/step_node/__init__.py new file mode 100644 index 0000000..6227381 --- /dev/null +++ 
def get_node(node_type):
    """
    Resolve a node class from its ``type`` identifier.

    :param node_type: value compared with the ``type`` attribute of each entry in ``node_list``
    :return: the first registered class whose ``type`` equals ``node_type``, otherwise None
    """
    return next((candidate for candidate in node_list if candidate.type == node_type), None)
class IChatNode(INode):
    """
    Interface of the AI chat node: declares the node type, the parameter serializer and the
    ``execute`` signature that concrete chat nodes implement.
    """
    type = 'ai-chat-node'

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        """
        :return: ChatNodeSerializer, used to validate this node's parameters
        """
        return ChatNodeSerializer

    def _run(self):
        """
        Call execute() with the validated node parameters followed by the validated flow
        parameters, both expanded as keyword arguments.
        """
        node_kwargs = self.node_params_serializer.data
        flow_kwargs = self.flow_params_serializer.data
        return self.execute(**node_kwargs, **flow_kwargs)

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id,
                chat_record_id,
                model_params_setting=None,
                **kwargs) -> NodeResult:
        """
        Chat node logic; left to subclasses, this declaration returns None.
        """
def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str):
    """
    Store the outcome of a chat call in the node context.

    @param node_variable: node data; 'chat_model', 'message_list', 'history_message' and
                          'question' are read from it
    @param workflow_variable: global data, only forwarded to NodeResult
    @param node: node whose context receives token counts, answer, history, question and run time
    @param workflow: workflow manager; its ``answer`` is extended when ``is_result`` is truthy
    @param answer: complete answer text
    """
    model = node_variable.get('chat_model')
    prompt_token_count = model.get_num_tokens_from_messages(node_variable.get('message_list'))
    answer_token_count = model.get_num_tokens(answer)
    ctx = node.context
    # Entries are written one by one so a failure part-way leaves the earlier ones in place.
    ctx['message_tokens'] = prompt_token_count
    ctx['answer_tokens'] = answer_token_count
    ctx['answer'] = answer
    ctx['history_message'] = node_variable['history_message']
    ctx['question'] = node_variable['question']
    ctx['run_time'] = time.time() - ctx['start_time']
    contributes_to_answer = workflow.is_result(node, NodeResult(node_variable, workflow_variable))
    if contributes_to_answer:
        workflow.answer += answer
class BaseChatNode(IChatNode):
    """
    Default AI chat node: renders the prompts, assembles the message list (optional system
    message, recent history, current question) and calls the chat model, streaming or not.
    """

    def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id,
                model_params_setting=None,
                **kwargs) -> NodeResult:
        """
        Call the chat model and wrap its response in a NodeResult.

        :param model_id: id of the model to instantiate
        :param system: system prompt template; ignored when None or empty
        :param prompt: user prompt template
        :param dialogue_number: how many of the latest history records to include
        :param history_chat_record: previous chat records, oldest first
        :param stream: truthy to call ``chat_model.stream``, otherwise ``chat_model.invoke``
        :param model_params_setting: model parameters; defaults are looked up when None
        :return: NodeResult carrying the model response and the matching context writer
        """
        if model_params_setting is None:
            model_params_setting = get_default_model_params_setting(model_id)
        chat_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'),
                                                         **model_params_setting)
        history_message = self.get_history_message(history_chat_record, dialogue_number)
        self.context['history_message'] = history_message
        question = self.generate_prompt_question(prompt)
        self.context['question'] = question.content
        message_list = self.generate_message_list(system, prompt, history_message)
        self.context['message_list'] = message_list
        # Only the model call and the context writer differ between the two modes;
        # the result payload is built once.
        if stream:
            response, writer = chat_model.stream(message_list), write_context_stream
        else:
            response, writer = chat_model.invoke(message_list), write_context
        return NodeResult({'result': response, 'chat_model': chat_model, 'message_list': message_list,
                           'history_message': history_message, 'question': question.content}, {},
                          _write_context=writer)

    @staticmethod
    def get_history_message(history_chat_record, dialogue_number):
        """
        Flatten the last ``dialogue_number`` records into [human, ai, human, ai, ...].

        :param history_chat_record: indexable sequence of records exposing
                                    ``get_human_message()`` and ``get_ai_message()``
        :param dialogue_number: number of trailing records to keep
        :return: list of messages, oldest first
        """
        record_count = len(history_chat_record)
        first_index = max(record_count - dialogue_number, 0)
        # Appending in place keeps this linear; the earlier reduce() copied the whole
        # accumulated list for every record.
        history_message = []
        for index in range(first_index, record_count):
            record = history_chat_record[index]
            history_message.append(record.get_human_message())
            history_message.append(record.get_ai_message())
        return history_message

    def generate_prompt_question(self, prompt):
        """
        :param prompt: user prompt template
        :return: HumanMessage holding the prompt rendered by the workflow manager
        """
        return HumanMessage(self.workflow_manage.generate_prompt(prompt))

    def generate_message_list(self, system: str, prompt: str, history_message):
        """
        Build the messages sent to the model.

        :param system: system prompt template; no SystemMessage is added when None or empty
        :param prompt: user prompt template
        :param history_message: messages placed before the current question
        :return: [SystemMessage?] + history + [HumanMessage]
        """
        messages = []
        if system is not None and len(system) > 0:
            messages.append(SystemMessage(self.workflow_manage.generate_prompt(system)))
        messages.extend(history_message)
        messages.append(HumanMessage(self.workflow_manage.generate_prompt(prompt)))
        return messages

    @staticmethod
    def reset_message_list(message_list: List[BaseMessage], answer_text):
        """
        Convert messages to role/content dicts and append the answer as an 'ai' entry.

        :param message_list: messages; HumanMessage instances map to 'user', all others to 'ai'
        :param answer_text: answer appended as the final entry
        :return: list of {'role', 'content'} dicts
        """
        result = [{'role': 'user' if isinstance(message, HumanMessage) else 'ai', 'content': message.content}
                  for message in message_list]
        result.append({'role': 'ai', 'content': answer_text})
        return result

    def get_details(self, index: int, **kwargs):
        """
        :param index: position value supplied by the caller
        :return: run details of this node, read from its context
        """
        history_message = self.context.get('history_message')
        if history_message is None:
            history_message = []
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'system': self.node_params.get('system'),
            'history_message': [{'content': message.content, 'role': message.type} for message in
                                history_message],
            'question': self.context.get('question'),
            'answer': self.context.get('answer'),
            'type': self.node.type,
            'message_tokens': self.context.get('message_tokens'),
            'answer_tokens': self.context.get('answer_tokens'),
            'status': self.status,
            'err_message': self.err_message
        }
-0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/6/7 14:43 + @desc: +""" +from .impl import * diff --git a/src/MaxKB-1.7.2/apps/application/flow/step_node/condition_node/compare/__init__.py b/src/MaxKB-1.7.2/apps/application/flow/step_node/condition_node/compare/__init__.py new file mode 100644 index 0000000..02d42a2 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/flow/step_node/condition_node/compare/__init__.py @@ -0,0 +1,28 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/6/7 14:43 + @desc: +""" + +from .contain_compare import * +from .equal_compare import * +from .gt_compare import * +from .ge_compare import * +from .le_compare import * +from .lt_compare import * +from .len_ge_compare import * +from .len_gt_compare import * +from .len_le_compare import * +from .len_lt_compare import * +from .len_equal_compare import * +from .is_not_null_compare import * +from .is_null_compare import * +from .not_contain_compare import * + +compare_handle_list = [GECompare(), GTCompare(), ContainCompare(), EqualCompare(), LTCompare(), LECompare(), + LenLECompare(), LenGECompare(), LenEqualCompare(), LenGTCompare(), LenLTCompare(), + IsNullCompare(), + IsNotNullCompare(), NotContainCompare()] diff --git a/src/MaxKB-1.7.2/apps/application/flow/step_node/condition_node/compare/compare.py b/src/MaxKB-1.7.2/apps/application/flow/step_node/condition_node/compare/compare.py new file mode 100644 index 0000000..6cbb4af --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/flow/step_node/condition_node/compare/compare.py @@ -0,0 +1,20 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: compare.py + @date:2024/6/7 14:37 + @desc: +""" +from abc import abstractmethod +from typing import List + + +class Compare: + @abstractmethod + def support(self, node_id, fields: List[str], source_value, compare, target_value): + pass + + @abstractmethod + def compare(self, source_value, compare, 
class ContainCompare(Compare):
    """
    Comparator for the ``contain`` operator: substring test for strings, element test
    (by string form) for other iterables.
    """

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        """
        :param compare: operator name taken from the condition
        :return: True when the operator is 'contain', otherwise False
        """
        return compare == 'contain'

    def compare(self, source_value, compare, target_value):
        """
        :param source_value: referenced value; a string, an iterable, or anything else
        :param target_value: value looked for, compared by its string form
        :return: True when found; False when not found or when ``source_value`` cannot be
                 iterated (None, numbers, ...)
        """
        needle = str(target_value)
        if isinstance(source_value, str):
            return needle in source_value
        try:
            return any(str(item) == needle for item in source_value)
        except TypeError:
            # A missing reference (None) or a scalar is not iterable. Treat it as
            # "does not contain", like the numeric/length comparators which return
            # False on bad input, instead of aborting the whole condition node.
            return False
class GECompare(Compare):
    """
    Comparator for the ``ge`` operator: numeric "greater than or equal".
    """

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        """
        :return: True when ``compare`` is 'ge', otherwise None
        """
        return True if compare == 'ge' else None

    def compare(self, source_value, compare, target_value):
        """
        :return: ``float(source_value) >= float(target_value)``; False when either value
                 cannot be converted to float
        """
        try:
            left = float(source_value)
            right = float(target_value)
        except Exception:
            return False
        return left >= right
class IsNotNullCompare(Compare):
    """
    Comparator for the ``is_not_null`` operator.
    """

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        """
        :param compare: operator name taken from the condition
        :return: True when the operator is 'is_not_null', otherwise False
        """
        return compare == 'is_not_null'

    def compare(self, source_value, compare, target_value):
        """
        :param source_value: referenced value
        :return: False for None and for empty sized values ('' / [] / {}); True otherwise
        """
        if source_value is None:
            return False
        try:
            return len(source_value) > 0
        except TypeError:
            # Numbers, booleans and other unsized objects have no length; a value that is
            # present but unsized counts as "not null" rather than crashing the condition.
            return True
class LenEqualCompare(Compare):
    """
    Comparator for the ``len_eq`` operator: length of the source equals the target number.
    """

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        """
        :return: True when ``compare`` is 'len_eq', otherwise None
        """
        return True if compare == 'len_eq' else None

    def compare(self, source_value, compare, target_value):
        """
        :return: ``len(source_value) == int(target_value)``; False when the length or the
                 integer conversion fails
        """
        try:
            actual_length = len(source_value)
            expected_length = int(target_value)
        except Exception:
            return False
        return actual_length == expected_length
class LenGTCompare(Compare):
    """
    Comparator for the ``len_gt`` operator: length of the source is greater than the target number.
    """

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        """
        :return: True when ``compare`` is 'len_gt', otherwise None
        """
        return True if compare == 'len_gt' else None

    def compare(self, source_value, compare, target_value):
        """
        :return: ``len(source_value) > int(target_value)``; False when the length or the
                 integer conversion fails
        """
        try:
            actual_length = len(source_value)
            threshold = int(target_value)
        except Exception:
            return False
        return actual_length > threshold
class LTCompare(Compare):
    """
    Comparator for the ``lt`` operator: numeric "less than".
    """

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        """
        :return: True when ``compare`` is 'lt', otherwise None
        """
        return True if compare == 'lt' else None

    def compare(self, source_value, compare, target_value):
        """
        :return: ``float(source_value) < float(target_value)``; False when either value
                 cannot be converted to float
        """
        try:
            left = float(source_value)
            right = float(target_value)
        except Exception:
            return False
        return left < right
class NotContainCompare(Compare):
    """
    Comparator for the ``not_contain`` operator: the negation of the substring / element
    test used by the ``contain`` operator.
    """

    def support(self, node_id, fields: List[str], source_value, compare, target_value):
        """
        :param compare: operator name taken from the condition
        :return: True when the operator is 'not_contain', otherwise False
        """
        return compare == 'not_contain'

    def compare(self, source_value, compare, target_value):
        """
        :param source_value: referenced value; a string, an iterable, or anything else
        :param target_value: value looked for, compared by its string form
        :return: True when the target is absent, including when ``source_value`` cannot be
                 iterated (None, numbers, ...); False when the target is present
        """
        needle = str(target_value)
        if isinstance(source_value, str):
            return needle not in source_value
        try:
            return not any(str(item) == needle for item in source_value)
        except TypeError:
            # None or a scalar cannot hold elements, so it does not contain the target.
            # Answering True avoids the TypeError that aborted the condition node.
            # NOTE(review): confirm that "missing value => not_contain is True" is the
            # routing the product expects.
            return True
class BaseConditionNode(IConditionNode):
    """
    Default condition node: evaluates the configured branches in order and reports the
    first branch whose conditions hold.
    """

    def execute(self, **kwargs) -> NodeResult:
        """
        :return: NodeResult whose node data holds 'branch_id' and 'branch_name' of the
                 matched branch; both are None when no branch matched
        """
        branch_list = self.node_params_serializer.data['branch']
        branch = self._execute(branch_list)
        if branch is None:
            # No branch matched: report that explicitly instead of failing with
            # "'NoneType' object has no attribute 'get'".
            # NOTE(review): confirm the workflow router treats branch_id None as
            # "no outgoing edge".
            return NodeResult({'branch_id': None, 'branch_name': None}, {})
        return NodeResult({'branch_id': branch.get('id'), 'branch_name': branch.get('type')}, {})

    def _execute(self, branch_list: List):
        """
        :param branch_list: branches in evaluation order
        :return: the first branch accepted by branch_assertion(), or None
        """
        for branch in branch_list:
            if self.branch_assertion(branch):
                return branch
        return None

    def branch_assertion(self, branch):
        """
        Evaluate every condition of a branch and combine the outcomes.

        :param branch: dict with 'conditions' (rows of field/compare/value) and 'condition'
        :return: all() of the outcomes when 'condition' is 'and', otherwise any()
        """
        # Every row is evaluated before combining (no short-circuit), as before.
        condition_list = [self.assertion(row.get('field'), row.get('compare'), row.get('value')) for row in
                          branch.get('conditions')]
        condition = branch.get('condition')
        return all(condition_list) if condition == 'and' else any(condition_list)

    def assertion(self, field_list: List[str], compare: str, value):
        """
        Evaluate one condition row.

        :param field_list: reference path; the first item is the node id, the rest the field path
        :param compare: operator name
        :param value: value to compare against
        :return: result of the first handler in compare_handle_list that supports the
                 operator; False when no handler supports it
        """
        node_id, field_path = field_list[0], field_list[1:]
        field_value = self.workflow_manage.get_reference_field(node_id, field_path)
        for compare_handler in compare_handle_list:
            if compare_handler.support(node_id, field_path, field_value, compare, value):
                return compare_handler.compare(field_value, compare, value)
        # Unknown operator: an explicit False replaces the former implicit None.
        return False

    def get_details(self, index: int, **kwargs):
        """
        :param index: position value supplied by the caller
        :return: run details of this node, read from its context
        """
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'branch_id': self.context.get('branch_id'),
            'branch_name': self.context.get('branch_name'),
            'type': self.node.type,
            'status': self.status,
            'err_message': self.err_message
        }
class ReplyNodeParamsSerializer(serializers.Serializer):
    """
    Parameters of the direct-reply node: either a reference to another node's field
    (``reply_type == 'referencing'``) or literal reply content.
    """
    # Reply type; 'referencing' selects the reference mode, anything else the content mode.
    reply_type = serializers.CharField(required=True, error_messages=ErrMessage.char("回复类型"))
    # Reference path used in reference mode.
    fields = serializers.ListField(required=False, error_messages=ErrMessage.list("引用字段"))
    # Reply content used in content mode.
    content = serializers.CharField(required=False, allow_blank=True, allow_null=True,
                                    error_messages=ErrMessage.char("直接回答内容"))
    # Whether this node's output is returned as content.
    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))

    def is_valid(self, *, raise_exception=False):
        """
        Validate the fields, then check that the data needed by the chosen reply type exists.

        Field validation always raises on failure, whatever ``raise_exception`` says; that
        behaviour is kept as it was.

        :return: True when everything is valid (the override used to return None, so
                 ``if serializer.is_valid():`` could never succeed)
        :raises AppApiException: when the reference path or the content is missing/invalid
        """
        super().is_valid(raise_exception=True)
        data = self.data
        if data.get('reply_type') == 'referencing':
            if 'fields' not in data:
                raise AppApiException(500, "引用字段不能为空")
            if len(data.get('fields')) < 2:
                raise AppApiException(500, "引用字段错误")
        else:
            if 'content' not in data or data.get('content') is None:
                raise AppApiException(500, "内容不能为空")
        return True
class BaseReplyNode(IReplyNode):
    """Direct-reply node: answers with a referenced field value or with the configured content."""

    def execute(self, reply_type, stream, fields=None, content=None, **kwargs) -> NodeResult:
        """Build the answer according to ``reply_type`` and wrap it in a ``NodeResult``."""
        if reply_type != 'referencing':
            return NodeResult({'answer': self.generate_reply_content(content)}, {})
        return NodeResult({'answer': self.get_reference_content(fields)}, {})

    def generate_reply_content(self, prompt):
        """Pass ``prompt`` through the workflow manager's ``generate_prompt``."""
        manager = self.workflow_manage
        return manager.generate_prompt(prompt)

    def get_reference_content(self, fields: List[str]):
        """Resolve the reference (first entry, then the remaining entries) and return it as ``str``."""
        referenced = self.workflow_manage.get_reference_field(fields[0], fields[1:])
        return str(referenced)

    def get_details(self, index: int, **kwargs):
        """Return the execution details recorded for this node."""
        return dict(
            name=self.node.properties.get('stepName'),
            index=index,
            run_time=self.context.get('run_time'),
            type=self.node.type,
            answer=self.context.get('answer'),
            status=self.status,
            err_message=self.err_message,
        )
class FunctionLibNodeParamsSerializer(serializers.Serializer):
    """Parameters of the function-library node."""
    function_lib_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid('函数库id'))
    input_field_list = InputField(required=True, many=True)
    is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容'))

    def is_valid(self, *, raise_exception=False):
        """
        Run field validation (always raising on failure), then make sure the
        referenced function-library row still exists.
        """
        super().is_valid(raise_exception=True)
        matches = QuerySet(FunctionLib).filter(id=self.data.get('function_lib_id'))
        if matches.first() is None:
            raise Exception('函数库已被删除')
execute(self, function_lib_id, input_field_list, **kwargs) -> NodeResult: + pass diff --git a/src/MaxKB-1.7.2/apps/application/flow/step_node/function_lib_node/impl/__init__.py b/src/MaxKB-1.7.2/apps/application/flow/step_node/function_lib_node/impl/__init__.py new file mode 100644 index 0000000..9668147 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/flow/step_node/function_lib_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/8/8 17:48 + @desc: +""" +from .base_function_lib_node import BaseFunctionLibNodeNode diff --git a/src/MaxKB-1.7.2/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py b/src/MaxKB-1.7.2/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py new file mode 100644 index 0000000..64e1c55 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/flow/step_node/function_lib_node/impl/base_function_lib_node.py @@ -0,0 +1,120 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: base_function_lib_node.py + @date:2024/8/8 17:49 + @desc: +""" +import json +import time +from typing import Dict + +from django.db.models import QuerySet + +from application.flow.i_step_node import NodeResult +from application.flow.step_node.function_lib_node.i_function_lib_node import IFunctionLibNode +from common.exception.app_exception import AppApiException +from common.util.function_code import FunctionExecutor +from function_lib.models.function import FunctionLib +from smartdoc.const import CONFIG + +function_executor = FunctionExecutor(CONFIG.get('SANDBOX')) + + +def write_context(step_variable: Dict, global_variable: Dict, node, workflow): + if step_variable is not None: + for key in step_variable: + node.context[key] = step_variable[key] + if workflow.is_result(node, NodeResult(step_variable, global_variable)) and 'result' in step_variable: + result = str(step_variable['result']) + '\n' + yield result + workflow.answer += result + 
def valid_reference_value(_type, value, name):
    """
    Check that a referenced value has the Python type declared for a field.

    Declared types map to ``int``, ``float``, ``dict``, ``list`` ('array')
    and ``str`` ('string'). A ``float`` field also accepts ``int`` values,
    matching the custom-value path, which converts with ``float(value)``.

    :param _type: declared field type ('int', 'float', 'dict', 'array', 'string')
    :param value: the value resolved from the reference
    :param name:  field name, used only in the error message
    :raises Exception: if ``_type`` is not supported or ``value`` has the wrong type
    """
    accepted_types = {
        'int': int,
        'float': (float, int),
        'dict': dict,
        'array': list,
        'string': str,
    }
    # Non-string (possibly unhashable) type names are simply unsupported.
    instance_type = accepted_types.get(_type) if isinstance(_type, str) else None
    if instance_type is None:
        # Single-argument exception so str(e) is the message itself, like the raise below.
        raise Exception(f'字段:{name}类型:{_type} 不支持的类型')
    if not isinstance(value, instance_type):
        raise Exception(f'字段:{name}类型:{_type}值:{value}类型错误')
def get_details(self, index: int, **kwargs):
    """
    Return the execution details recorded for this function-library node:
    step name, position, result, converted parameters, run time, node type,
    status and error message.
    """
    return dict(
        name=self.node.properties.get('stepName'),
        index=index,
        result=self.context.get('result'),
        params=self.context.get('params'),
        run_time=self.context.get('run_time'),
        type=self.node.type,
        status=self.status,
        err_message=self.err_message,
    )
class InputField(serializers.Serializer):
    """One input variable of a function node: name, declared type, value source and value."""
    name = serializers.CharField(required=True, error_messages=ErrMessage.char('变量名'))
    is_required = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("是否必填"))
    # The alternatives are grouped so that ^ and $ anchor the whole alternation.
    # Ungrouped, "^string|int|...|float$" accepts any value that merely contains "int".
    type = serializers.CharField(required=True, error_messages=ErrMessage.char("类型"), validators=[
        validators.RegexValidator(regex=re.compile("^(?:string|int|dict|array|float)$"),
                                  message="字段只支持string|int|dict|array|float", code=500)
    ])
    # Same grouping fix: only exactly "custom" or "reference" is accepted.
    source = serializers.CharField(required=True, error_messages=ErrMessage.char("来源"), validators=[
        validators.RegexValidator(regex=re.compile("^(?:custom|reference)$"),
                                  message="字段只支持custom|reference", code=500)
    ])
    value = ObjectField(required=True, error_messages=ErrMessage.char("变量值"), model_type_list=[str, list])

    def is_valid(self, *, raise_exception=False):
        """
        Run field validation (always raising on failure), then reject a
        required field whose value is ``None``.
        """
        super().is_valid(raise_exception=True)
        is_required = self.data.get('is_required')
        if is_required and self.data.get('value') is None:
            raise AppApiException(500, f'{self.data.get("name")}必填')
def valid_reference_value(_type, value, name):
    """
    Check that a referenced value has the Python type declared for a field.

    Declared types map to ``int``, ``float``, ``dict``, ``list`` ('array')
    and ``str`` ('string'). A ``float`` field also accepts ``int`` values,
    matching the custom-value path, which converts with ``float(value)``.

    :param _type: declared field type ('int', 'float', 'dict', 'array', 'string')
    :param value: the value resolved from the reference
    :param name:  field name, used only in the error message
    :raises Exception: if ``_type`` is not supported or ``value`` has the wrong type
    """
    accepted_types = {
        'int': int,
        'float': (float, int),
        'dict': dict,
        'array': list,
        'string': str,
    }
    # Non-string (possibly unhashable) type names are simply unsupported.
    instance_type = accepted_types.get(_type) if isinstance(_type, str) else None
    if instance_type is None:
        # Single-argument exception so str(e) is the message itself, like the raise below.
        raise Exception(f'字段:{name}类型:{_type} 不支持的类型')
    if not isinstance(value, instance_type):
        raise Exception(f'字段:{name}类型:{_type}值:{value}类型错误')
__init__.py + @date:2024/6/11 15:30 + @desc: +""" +from .impl import * diff --git a/src/MaxKB-1.7.2/apps/application/flow/step_node/question_node/i_question_node.py b/src/MaxKB-1.7.2/apps/application/flow/step_node/question_node/i_question_node.py new file mode 100644 index 0000000..054a7ad --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/flow/step_node/question_node/i_question_node.py @@ -0,0 +1,41 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: i_chat_node.py + @date:2024/6/4 13:58 + @desc: +""" +from typing import Type + +from rest_framework import serializers + +from application.flow.i_step_node import INode, NodeResult +from common.util.field_message import ErrMessage + + +class QuestionNodeSerializer(serializers.Serializer): + model_id = serializers.CharField(required=True, error_messages=ErrMessage.char("模型id")) + system = serializers.CharField(required=False, allow_blank=True, allow_null=True, + error_messages=ErrMessage.char("角色设定")) + prompt = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词")) + # 多轮对话数量 + dialogue_number = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("多轮对话数量")) + + is_result = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean('是否返回内容')) + model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.integer("模型参数相关设置")) + + +class IQuestionNode(INode): + type = 'question-node' + + def get_node_params_serializer_class(self) -> Type[serializers.Serializer]: + return QuestionNodeSerializer + + def _run(self): + return self.execute(**self.node_params_serializer.data, **self.flow_params_serializer.data) + + def execute(self, model_id, system, prompt, dialogue_number, history_chat_record, stream, chat_id, chat_record_id, + model_params_setting=None, + **kwargs) -> NodeResult: + pass diff --git a/src/MaxKB-1.7.2/apps/application/flow/step_node/question_node/impl/__init__.py 
b/src/MaxKB-1.7.2/apps/application/flow/step_node/question_node/impl/__init__.py new file mode 100644 index 0000000..d85aa87 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/flow/step_node/question_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:35 + @desc: +""" +from .base_question_node import BaseQuestionNode diff --git a/src/MaxKB-1.7.2/apps/application/flow/step_node/question_node/impl/base_question_node.py b/src/MaxKB-1.7.2/apps/application/flow/step_node/question_node/impl/base_question_node.py new file mode 100644 index 0000000..8e43a9b --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/flow/step_node/question_node/impl/base_question_node.py @@ -0,0 +1,144 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_question_node.py + @date:2024/6/4 14:30 + @desc: +""" +import time +from functools import reduce +from typing import List, Dict + +from django.db.models import QuerySet +from langchain.schema import HumanMessage, SystemMessage +from langchain_core.messages import BaseMessage + +from application.flow.i_step_node import NodeResult, INode +from application.flow.step_node.question_node.i_question_node import IQuestionNode +from setting.models import Model +from setting.models_provider import get_model_credential +from setting.models_provider.tools import get_model_instance_by_model_user_id + + +def _write_context(node_variable: Dict, workflow_variable: Dict, node: INode, workflow, answer: str): + chat_model = node_variable.get('chat_model') + message_tokens = chat_model.get_num_tokens_from_messages(node_variable.get('message_list')) + answer_tokens = chat_model.get_num_tokens(answer) + node.context['message_tokens'] = message_tokens + node.context['answer_tokens'] = answer_tokens + node.context['answer'] = answer + node.context['history_message'] = node_variable['history_message'] + node.context['question'] = node_variable['question'] + 
def write_context_stream(node_variable: Dict, workflow_variable: Dict, node: INode, workflow):
    """
    Write context data (streaming).

    Yields the ``content`` of each chunk of ``node_variable['result']`` as it
    is read and, once the stream is exhausted, passes the concatenated answer
    to ``_write_context``.
    @param node_variable: node data
    @param workflow_variable: global data
    @param node: node
    @param workflow: workflow manager
    """
    collected = ''
    stream = node_variable.get('result')
    for piece in stream:
        collected = collected + piece.content
        yield piece.content
    _write_context(node_variable, workflow_variable, node, workflow, collected)
def generate_message_list(self, system: str, prompt: str, history_message):
    """
    Build the message list sent to the chat model.

    The list is: an optional system message, then the history messages, then
    the user prompt. Both ``system`` and ``prompt`` are passed through
    ``workflow_manage.generate_prompt`` before being wrapped.

    The system message is included only when ``system`` is a non-empty
    string. The previous condition was inverted: it built a system message
    from an empty or ``None`` value and dropped a configured system prompt.

    :param system: role setting text; may be ``None`` or empty
    :param prompt: prompt template for the user message
    :param history_message: earlier messages placed between system and prompt
    :return: list of messages in send order
    """
    messages = []
    if system is not None and len(system) > 0:
        messages.append(SystemMessage(self.workflow_manage.generate_prompt(system)))
    messages.extend(history_message)
    messages.append(HumanMessage(self.workflow_manage.generate_prompt(prompt)))
    return messages
class RerankerSettingSerializer(serializers.Serializer):
    """Settings of the reranker node."""
    # Number of segments to return.
    top_n = serializers.IntegerField(required=True,
                                     error_messages=ErrMessage.integer("引用分段数"))
    # Similarity threshold; the validator accepts 0..2 (the original comment said 0-1).
    # The error label now names this field instead of repeating the top_n label.
    similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
                                        error_messages=ErrMessage.float("相似度"))
    # Integer field, so it uses the integer error messages like top_n.
    max_paragraph_char_number = serializers.IntegerField(required=True,
                                                         error_messages=ErrMessage.integer("最大引用分段字数"))
def merge_reranker_list(reranker_list, result=None):
    """
    Flatten arbitrarily nested reranker inputs into a flat list of strings.

    * nested lists are flattened recursively, preserving order;
    * dicts contribute ``title + content``, or ``str(dict)`` when both are
      missing or empty; a key present with value ``None`` counts as empty
      (previously ``None + str`` raised ``TypeError``);
    * any other value contributes ``str(value)``.

    :param reranker_list: list of referenced values (lists, dicts or scalars)
    :param result: accumulator used by the recursion; a new list when omitted
    :return: the accumulator holding the flattened strings
    """
    if result is None:
        result = []
    for document in reranker_list:
        if isinstance(document, list):
            merge_reranker_list(document, result)
        elif isinstance(document, dict):
            content = (document.get('title') or '') + (document.get('content') or '')
            result.append(str(document) if len(content) == 0 else content)
        else:
            result.append(str(document))
    return result
def get_details(self, index: int, **kwargs):
    """
    Return the execution details recorded for this reranker node: the input
    documents and question, the configured reranker settings, the filtered
    results, plus run time, node type, status and error message.
    """
    return dict(
        name=self.node.properties.get('stepName'),
        index=index,
        document_list=self.context.get('document_list'),
        question=self.context.get('question'),
        run_time=self.context.get('run_time'),
        type=self.node.type,
        reranker_setting=self.node_params_serializer.data.get('reranker_setting'),
        result_list=self.context.get('result_list'),
        result=self.context.get('result'),
        status=self.status,
        err_message=self.err_message,
    )
class DatasetSettingSerializer(serializers.Serializer):
    """Retrieval settings of the dataset-search node."""
    # Number of segments to query.
    top_n = serializers.IntegerField(required=True,
                                     error_messages=ErrMessage.integer("引用分段数"))
    # Similarity threshold; the validator accepts 0..2 (the original comment said 0-1).
    # The error label now names this field instead of repeating the top_n label.
    similarity = serializers.FloatField(required=True, max_value=2, min_value=0,
                                        error_messages=ErrMessage.float("相似度"))
    # The alternatives are grouped so that ^ and $ anchor the whole alternation.
    # Ungrouped, "^embedding|keywords|blend$" accepts e.g. "xkeywordsx".
    # The message now lists the modes this field actually accepts.
    search_mode = serializers.CharField(required=True, validators=[
        validators.RegexValidator(regex=re.compile("^(?:embedding|keywords|blend)$"),
                                  message="类型只支持embedding|keywords|blend", code=500)
    ], error_messages=ErrMessage.char("检索模式"))
    # Integer field, so it uses the integer error messages like top_n.
    max_paragraph_char_number = serializers.IntegerField(required=True,
                                                         error_messages=ErrMessage.integer("最大引用分段字数"))
self.node_params_serializer.data.get('question_reference_address')[0], + self.node_params_serializer.data.get('question_reference_address')[1:]) + exclude_paragraph_id_list = [] + if self.flow_params_serializer.data.get('re_chat', False): + history_chat_record = self.flow_params_serializer.data.get('history_chat_record', []) + paragraph_id_list = [p.get('id') for p in flat_map( + [get_paragraph_list(chat_record, self.node.id) for chat_record in history_chat_record if + chat_record.problem_text == question])] + exclude_paragraph_id_list = list(set(paragraph_id_list)) + + return self.execute(**self.node_params_serializer.data, question=str(question), + exclude_paragraph_id_list=exclude_paragraph_id_list) + + def execute(self, dataset_id_list, dataset_setting, question, + exclude_paragraph_id_list=None, + **kwargs) -> NodeResult: + pass diff --git a/src/MaxKB-1.7.2/apps/application/flow/step_node/search_dataset_node/impl/__init__.py b/src/MaxKB-1.7.2/apps/application/flow/step_node/search_dataset_node/impl/__init__.py new file mode 100644 index 0000000..a9cff0d --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/flow/step_node/search_dataset_node/impl/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/6/11 15:35 + @desc: +""" +from .base_search_dataset_node import BaseSearchDatasetNode diff --git a/src/MaxKB-1.7.2/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py b/src/MaxKB-1.7.2/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py new file mode 100644 index 0000000..693495a --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/flow/step_node/search_dataset_node/impl/base_search_dataset_node.py @@ -0,0 +1,129 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_search_dataset_node.py + @date:2024/6/4 11:56 + @desc: +""" +import os +from typing import List, Dict + +from django.db.models import QuerySet + +from 
class BaseSearchDatasetNode(ISearchDatasetStepNode):
    """Search-dataset node: embeds the question, queries the vector store and returns matching paragraphs."""

    def execute(self, dataset_id_list, dataset_setting, question,
                exclude_paragraph_id_list=None,
                **kwargs) -> NodeResult:
        """
        Retrieve the paragraphs that match ``question`` from the given datasets.

        @param dataset_id_list: ids of the datasets to search
        @param dataset_setting: dict read for top_n, similarity, search_mode and max_paragraph_char_number
        @param question: query text
        @param exclude_paragraph_id_list: paragraph ids passed to the vector query as exclusions
        @return: NodeResult with paragraph_list, is_hit_handling_method_list, data, directly_return, question
        """
        self.context['question'] = question
        if len(dataset_id_list) == 0:
            return get_none_result(question)
        model_id = get_embedding_id(dataset_id_list)
        embedding_model = get_model_instance_by_model_user_id(model_id, self.flow_params_serializer.data.get('user_id'))
        embedding_value = embedding_model.embed_query(question)
        vector = VectorStore.get_embedding_vector()
        # Documents that are switched off are excluded from the vector query.
        exclude_document_id_list = [str(document.id) for document in
                                    QuerySet(Document).filter(
                                        dataset_id__in=dataset_id_list,
                                        is_active=False)]
        embedding_list = vector.query(question, embedding_value, dataset_id_list, exclude_document_id_list,
                                      exclude_paragraph_id_list, True, dataset_setting.get('top_n'),
                                      dataset_setting.get('similarity'), SearchMode(dataset_setting.get('search_mode')))
        if embedding_list is None:
            return get_none_result(question)
        paragraph_list = self.list_paragraph(embedding_list, vector)
        # reset_paragraph returns None when no embedding row matches a paragraph;
        # such rows are dropped so the sort below cannot fail on a None entry.
        result = [row for row in
                  (self.reset_paragraph(paragraph, embedding_list) for paragraph in paragraph_list)
                  if row is not None]
        result = sorted(result, key=lambda p: p.get('similarity'), reverse=True)
        # 'data' is built from the similarity-sorted rows so that the character limit
        # cuts off the least relevant text and the order agrees with 'paragraph_list'.
        data = '\n'.join(
            [f"{reset_title(paragraph.get('title', ''))}{paragraph.get('content')}" for paragraph in
             result])[0:dataset_setting.get('max_paragraph_char_number', 5000)]
        hit_list = [row for row in result if row.get('is_hit_handling_method')]
        return NodeResult({'paragraph_list': result,
                           'is_hit_handling_method_list': hit_list,
                           'data': data,
                           'directly_return': '\n'.join([paragraph.get('content') for paragraph in hit_list]),
                           'question': question},
                          {})

    @staticmethod
    def reset_paragraph(paragraph: Dict, embedding_list: List):
        """
        Merge a paragraph row with its similarity score and stringify ids and timestamps.

        @param paragraph: paragraph row as a dict
        @param embedding_list: rows returned by the vector query (paragraph_id, similarity)
        @return: enriched dict, or None when no embedding row refers to the paragraph
        """
        paragraph_id = str(paragraph.get('id'))
        matches = [embedding for embedding in embedding_list if
                   str(embedding.get('paragraph_id')) == paragraph_id]
        if len(matches) == 0:
            return None
        # When several embedding rows refer to the paragraph, the last one is used.
        find_embedding = matches[-1]
        return {
            **paragraph,
            'similarity': find_embedding.get('similarity'),
            'is_hit_handling_method': find_embedding.get('similarity') > paragraph.get(
                'directly_return_similarity') and paragraph.get('hit_handling_method') == 'directly_return',
            'update_time': paragraph.get('update_time').strftime("%Y-%m-%d %H:%M:%S"),
            'create_time': paragraph.get('create_time').strftime("%Y-%m-%d %H:%M:%S"),
            'id': paragraph_id,
            'dataset_id': str(paragraph.get('dataset_id')),
            'document_id': str(paragraph.get('document_id'))
        }

    @staticmethod
    def list_paragraph(embedding_list: List, vector):
        """
        Load the paragraph rows referenced by ``embedding_list`` and purge stale vectors.

        @param embedding_list: rows returned by the vector query
        @param vector: vector store; delete_by_paragraph_id is called for ids with no paragraph row
        @return: list of paragraph rows (empty when embedding_list holds no ids)
        """
        paragraph_id_list = [row.get('paragraph_id') for row in embedding_list]
        if len(paragraph_id_list) == 0:
            return []
        paragraph_list = native_search(QuerySet(Paragraph).filter(id__in=paragraph_id_list),
                                       get_file_content(
                                           os.path.join(PROJECT_DIR, "apps", "application", 'sql',
                                                        'list_dataset_paragraph_by_paragraph_id.sql')),
                                       with_table_name=True)
        # Dirty data: the vector store references paragraphs that no longer exist, so delete them.
        if len(paragraph_list) != len(paragraph_id_list):
            # Ids are compared as strings, the same normalisation reset_paragraph uses,
            # so a UUID-vs-str representation difference cannot make a live paragraph
            # look stale and get its vectors deleted.
            exist_paragraph_ids = {str(row.get('id')) for row in paragraph_list}
            for paragraph_id in paragraph_id_list:
                if str(paragraph_id) not in exist_paragraph_ids:
                    vector.delete_by_paragraph_id(paragraph_id)
        return paragraph_list

    def get_details(self, index: int, **kwargs):
        """Return the execution details of this node; ``index`` is echoed back in the result."""
        return {
            'name': self.node.properties.get('stepName'),
            'question': self.context.get('question'),
            "index": index,
            'run_time': self.context.get('run_time'),
            'paragraph_list': self.context.get('paragraph_list'),
            'type': self.node.type,
            'status': self.status,
            'err_message': self.err_message
        }
def get_default_global_variable(input_field_list: List):
    """
    Collect the default values declared on the input fields.

    @param input_field_list: list of dicts read for 'variable' and 'default_value'
    @return: dict mapping each 'variable' to its 'default_value'; fields whose
             default is missing or None are left out, later duplicates win
    """
    defaults = {}
    for input_field in input_field_list:
        default_value = input_field.get('default_value', None)
        if default_value is None:
            continue
        defaults[input_field.get('variable')] = default_value
    return defaults
class BaseStartStepNode(IStarNode):
    """Start node implementation: records the question and initialises the workflow's global variables."""

    def get_node_params_serializer_class(self) -> Type[serializers.Serializer]:
        # Returns None: this node declares no parameter serializer class.
        pass

    def execute(self, question, **kwargs) -> NodeResult:
        """
        Start node: initialise the global variables.

        @param question: the user's question
        @return: NodeResult carrying the question as node variable and the merged globals as workflow variable
        """
        base_node = self.workflow_manage.get_base_node()
        # `or []` covers an 'input_field_list' key that is present but set to None.
        default_global_variable = get_default_global_variable(base_node.properties.get('input_field_list') or [])
        # Runtime globals are spread last, so they override configured defaults on a key clash.
        workflow_variable = {**default_global_variable, **get_global_variable(self)}
        return NodeResult({'question': question},
                          workflow_variable)

    def get_details(self, index: int, **kwargs):
        """Return the execution details of this node, including the current value of each declared global field."""
        # 'config' and 'globalFields' are treated as optional here (generate_prompt in
        # workflow_manage.py does the same), so a node without them yields an empty
        # list instead of raising while the details are being collected.
        config = self.node.properties.get('config') or {}
        workflow_context = self.workflow_manage.context
        global_fields = []
        for field in config.get('globalFields') or []:
            key = field['value']
            global_fields.append({
                'label': field['label'],
                'key': key,
                'value': workflow_context[key] if key in workflow_context else ''
            })
        return {
            'name': self.node.properties.get('stepName'),
            "index": index,
            "question": self.context.get('question'),
            'run_time': self.context.get('run_time'),
            'type': self.node.type,
            'status': self.status,
            'err_message': self.err_message,
            'global_fields': global_fields
        }
def to_stream_response(chat_id, chat_record_id, response: Iterator[BaseMessageChunk], workflow, write_context,
                       post_handler):
    """
    Wrap a chunk iterator in a server-sent-events HTTP response.

    @param chat_id: conversation id
    @param chat_record_id: chat record id
    @param response: iterator of message chunks
    @param workflow: workflow manager
    @param write_context: callback that writes the answer into the node context
    @param post_handler: post-processing handler
    @return: streaming HTTP response with caching disabled
    """
    # event_content is a generator, so nothing is consumed until the response is iterated.
    event_stream = event_content(chat_id, chat_record_id, response, workflow, write_context, post_handler)
    stream_response = StreamingHttpResponse(streaming_content=event_stream,
                                            content_type='text/event-stream;charset=utf-8',
                                            charset='utf-8')
    stream_response['Cache-Control'] = 'no-cache'
    return stream_response
str(chat_record_id), 'operate': True, + 'content': answer, 'is_end': True}) + + +def to_response_simple(chat_id, chat_record_id, response: BaseMessage, workflow, + post_handler: WorkFlowPostHandler): + answer = response.content + post_handler.handler(chat_id, chat_record_id, answer, workflow) + return result.success({'chat_id': str(chat_id), 'id': str(chat_record_id), 'operate': True, + 'content': answer, 'is_end': True}) + + +def to_stream_response_simple(stream_event): + r = StreamingHttpResponse( + streaming_content=stream_event, + content_type='text/event-stream;charset=utf-8', + charset='utf-8') + + r['Cache-Control'] = 'no-cache' + return r diff --git a/src/MaxKB-1.7.2/apps/application/flow/workflow_manage.py b/src/MaxKB-1.7.2/apps/application/flow/workflow_manage.py new file mode 100644 index 0000000..c75efbc --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/flow/workflow_manage.py @@ -0,0 +1,572 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: workflow_manage.py + @date:2024/1/9 17:40 + @desc: +""" +import json +import threading +import traceback +import uuid +from concurrent.futures import ThreadPoolExecutor +from functools import reduce +from typing import List, Dict + +from django.db.models import QuerySet +from langchain_core.prompts import PromptTemplate +from rest_framework import status +from rest_framework.exceptions import ErrorDetail, ValidationError + +from application.flow import tools +from application.flow.i_step_node import INode, WorkFlowPostHandler, NodeResult +from application.flow.step_node import get_node +from common.exception.app_exception import AppApiException +from common.handle.base_to_response import BaseToResponse +from common.handle.impl.response.system_to_response import SystemToResponse +from function_lib.models.function import FunctionLib +from setting.models import Model +from setting.models_provider import get_model_credential + +executor = ThreadPoolExecutor(max_workers=50) + + +class Edge: + def 
class Flow:
    """Workflow graph (nodes plus edges) together with the checks that validate it."""

    def __init__(self, nodes: List[Node], edges: List[Edge]):
        self.nodes = nodes
        self.edges = edges

    @staticmethod
    def new_instance(flow_obj: Dict):
        """Build a Flow from a dict holding 'nodes' and 'edges' lists of dicts."""
        nodes = [Node(node.get('id'), node.get('type'), **node)
                 for node in flow_obj.get('nodes')]
        edges = [Edge(edge.get('id'), edge.get('type'), **edge) for edge in flow_obj.get('edges')]
        return Flow(nodes, edges)

    def get_start_node(self):
        """Return the first node whose id is 'start-node' (IndexError when there is none)."""
        start_node_list = [node for node in self.nodes if node.id == 'start-node']
        return start_node_list[0]

    def get_search_node(self):
        """Return every node of type 'search-dataset-node'."""
        return [node for node in self.nodes if node.type == 'search-dataset-node']

    def is_valid(self):
        """
        Validate the workflow data; the first failing check raises.
        """
        self.is_valid_model_params()
        self.is_valid_start_node()
        self.is_valid_base_node()
        self.is_valid_work_flow()

    @staticmethod
    def is_valid_node_params(node: Node):
        # Instantiating the node implementation is what performs the parameter check here.
        get_node(node.type)(node, None, None)

    def is_valid_node(self, node: Node):
        """
        Validate one node: its parameters and its outgoing connections.

        @raise AppApiException: a condition branch has no outgoing edge, or a node
                                that is not an allowed end node has no outgoing edge
        """
        self.is_valid_node_params(node)
        if node.type == 'condition-node':
            branch_list = node.properties.get('node_data').get('branch')
            for branch in branch_list:
                source_anchor_id = f"{node.id}_{branch.get('id')}_right"
                # Edge only sets sourceAnchorId when the edge dict carries that key,
                # so it is read defensively instead of raising AttributeError.
                edge_list = [edge for edge in self.edges if
                             getattr(edge, 'sourceAnchorId', None) == source_anchor_id]
                if len(edge_list) == 0:
                    raise AppApiException(500,
                                          f'{node.properties.get("stepName")} 节点的{branch.get("type")}分支需要连接')

        else:
            edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id]
            if len(edge_list) == 0 and node.type not in end_nodes:
                raise AppApiException(500, f'{node.properties.get("stepName")} 节点不能当做结束节点')

    def get_next_nodes(self, node: Node):
        """
        Return the nodes targeted by the outgoing edges of ``node``, in edge order.

        @raise AppApiException: no target node exists and ``node`` is not an allowed end node
        """
        edge_list = [edge for edge in self.edges if edge.sourceNodeId == node.id]
        node_list = [target for edge in edge_list
                     for target in self.nodes if target.id == edge.targetNodeId]
        if len(node_list) == 0 and node.type not in end_nodes:
            raise AppApiException(500,
                                  '不存在的下一个节点')
        return node_list

    def is_valid_work_flow(self, up_node=None):
        """
        Validate every node reachable from ``up_node`` (the start node when omitted).

        Each node is validated once: without this, a cyclic graph recursed without
        bound and a node reachable over several paths was re-validated per path.
        """
        if up_node is None:
            up_node = self.get_start_node()
        self._validate_reachable(up_node, set())

    def _validate_reachable(self, node, visited_ids):
        """Depth-first validation of ``node`` and its successors, skipping ids already in ``visited_ids``."""
        if node.id in visited_ids:
            return
        visited_ids.add(node.id)
        self.is_valid_node(node)
        for next_node in self.get_next_nodes(node):
            self._validate_reachable(next_node, visited_ids)

    def is_valid_start_node(self):
        """Require exactly one node whose id is 'start-node'."""
        start_node_list = [node for node in self.nodes if node.id == 'start-node']
        if len(start_node_list) == 0:
            raise AppApiException(500, '开始节点必填')
        if len(start_node_list) > 1:
            raise AppApiException(500, '开始节点只能有一个')

    def is_valid_model_params(self):
        """
        Check model-backed nodes and function-library nodes.

        For 'ai-chat-node' and 'question-node' the referenced model must exist and the
        node status must be 200; a missing 'model_params_setting' is filled in with the
        form defaults. For 'function-lib-node' the referenced function library must exist.

        @raise ValidationError: any of the above checks fails
        """
        node_list = [node for node in self.nodes if (node.type == 'ai-chat-node' or node.type == 'question-node')]
        for node in node_list:
            model = QuerySet(Model).filter(id=node.properties.get('node_data', {}).get('model_id')).first()
            if model is None:
                raise ValidationError(ErrorDetail(f'节点{node.properties.get("stepName")} 模型不存在'))
            credential = get_model_credential(model.provider, model.model_type, model.model_name)
            model_params_setting = node.properties.get('node_data', {}).get('model_params_setting')
            model_params_setting_form = credential.get_model_params_setting_form(
                model.model_name)
            if model_params_setting is None:
                # Write the defaults back onto the node so later readers of node_data see them.
                model_params_setting = model_params_setting_form.get_default_form_data()
                node.properties.get('node_data', {})['model_params_setting'] = model_params_setting
            if node.properties.get('status', 200) != 200:
                raise ValidationError(ErrorDetail(f'节点{node.properties.get("stepName")} 不可用'))
        node_list = [node for node in self.nodes if (node.type == 'function-lib-node')]
        for node in node_list:
            function_lib_id = node.properties.get('node_data', {}).get('function_lib_id')
            if function_lib_id is None:
                raise ValidationError(ErrorDetail(f'节点{node.properties.get("stepName")} 函数库id不能为空'))
            f_lib = QuerySet(FunctionLib).filter(id=function_lib_id).first()
            if f_lib is None:
                raise ValidationError(ErrorDetail(f'节点{node.properties.get("stepName")} 函数库不可用'))

    def is_valid_base_node(self):
        """Require exactly one node whose id is 'base-node'."""
        base_node_list = [node for node in self.nodes if node.id == 'base-node']
        if len(base_node_list) == 0:
            raise AppApiException(500, '基本信息节点必填')
        if len(base_node_list) > 1:
            raise AppApiException(500, '基本信息节点只能有一个')
class NodeChunk:
    """Buffer of output chunks for one node, with a flag marking the buffer as finished."""

    def __init__(self):
        # Chunks in the order they were added.
        self.chunk_list = []
        # 0 while open; set to 200 by end().
        self.status = 0

    def add_chunk(self, chunk):
        """Append one chunk to the end of the buffer."""
        self.chunk_list.append(chunk)

    def end(self):
        """Mark the buffer as finished."""
        self.status = 200

    def is_end(self):
        """Return True once end() has been called."""
        finished = self.status == 200
        return finished
status.HTTP_500_INTERNAL_SERVER_ERROR) + + def run_stream(self): + """ + 流式响应 + @return: + """ + result = self.run_chain_async(None) + return tools.to_stream_response_simple(self.await_result(result)) + + def await_result(self, result): + try: + while await_result(result): + while True: + chunk = self.node_chunk_manage.pop() + if chunk is not None: + yield chunk + else: + break + while True: + chunk = self.node_chunk_manage.pop() + if chunk is None: + break + yield chunk + finally: + self.work_flow_post_handler.handler(self.params['chat_id'], self.params['chat_record_id'], + self.answer, + self) + yield self.get_chunk_content('', True) + + def run_chain_async(self, current_node): + future = executor.submit(self.run_chain, current_node) + return future + + def run_chain(self, current_node): + if current_node is None: + start_node = self.get_start_node() + current_node = get_node(start_node.type)(start_node, self.params, self) + node_result_future = self.run_node_future(current_node) + try: + is_stream = self.params.get('stream', True) + # 处理节点响应 + result = self.hand_event_node_result(current_node, + node_result_future) if is_stream else self.hand_node_result( + current_node, node_result_future) + with self.lock: + if current_node.status == 500: + return + node_list = self.get_next_node_list(current_node, result) + # 获取到可执行的子节点 + result_list = [] + for node in node_list: + result = self.run_chain_async(node) + result_list.append(result) + [r.result() for r in result_list] + if self.status == 0: + self.status = 200 + except Exception as e: + traceback.print_exc() + + def hand_node_result(self, current_node, node_result_future): + try: + current_result = node_result_future.result() + result = current_result.write_context(current_node, self) + if result is not None: + # 阻塞获取结果 + list(result) + # 添加节点 + self.node_context.append(current_node) + return current_result + except Exception as e: + # 添加节点 + self.node_context.append(current_node) + traceback.print_exc() + 
self.status = 500 + current_node.get_write_error_context(e) + self.answer += str(e) + + def hand_event_node_result(self, current_node, node_result_future): + node_chunk = NodeChunk() + try: + current_result = node_result_future.result() + result = current_result.write_context(current_node, self) + if result is not None: + if self.is_result(current_node, current_result): + self.node_chunk_manage.add_node_chunk(node_chunk) + for r in result: + chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], + self.params['chat_record_id'], + r, False, 0, 0) + node_chunk.add_chunk(chunk) + node_chunk.end() + else: + list(result) + # 添加节点 + self.node_context.append(current_node) + return current_result + except Exception as e: + # 添加节点 + self.node_context.append(current_node) + traceback.print_exc() + self.answer += str(e) + chunk = self.base_to_response.to_stream_chunk_response(self.params['chat_id'], + self.params['chat_record_id'], + str(e), False, 0, 0) + if not self.node_chunk_manage.contains(node_chunk): + self.node_chunk_manage.add_node_chunk(node_chunk) + node_chunk.add_chunk(chunk) + node_chunk.end() + current_node.get_write_error_context(e) + self.status = 500 + + def run_node_async(self, node): + future = executor.submit(self.run_node, node) + return future + + def run_node_future(self, node): + try: + node.valid_args(node.node_params, node.workflow_params) + result = self.run_node(node) + return NodeResultFuture(result, None, 200) + except Exception as e: + return NodeResultFuture(None, e, 500) + + def run_node(self, node): + result = node.run() + return result + + def is_result(self, current_node, current_node_result): + return current_node.node_params.get('is_result', not self._has_next_node( + current_node, current_node_result)) if current_node.node_params is not None else False + + def get_chunk_content(self, chunk, is_end=False): + return 'data: ' + json.dumps( + {'chat_id': self.params['chat_id'], 'id': self.params['chat_record_id'], 
'operate': True, + 'content': chunk, 'is_end': is_end}, ensure_ascii=False) + "\n\n" + + def _has_next_node(self, current_node, node_result: NodeResult | None): + """ + 是否有下一个可运行的节点 + """ + if node_result is not None and node_result.is_assertion_result(): + for edge in self.flow.edges: + if (edge.sourceNodeId == current_node.id and + f"{edge.sourceNodeId}_{node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId): + return True + else: + for edge in self.flow.edges: + if edge.sourceNodeId == current_node.id: + return True + + def has_next_node(self, node_result: NodeResult | None): + """ + 是否有下一个可运行的节点 + """ + return self._has_next_node(self.get_start_node() if self.current_node is None else self.current_node, + node_result) + + def get_runtime_details(self): + details_result = {} + for index in range(len(self.node_context)): + node = self.node_context[index] + details = node.get_details(index) + details_result[str(uuid.uuid1())] = details + return details_result + + def get_next_node(self): + """ + 获取下一个可运行的所有节点 + """ + if self.current_node is None: + node = self.get_start_node() + node_instance = get_node(node.type)(node, self.params, self) + return node_instance + if self.current_result is not None and self.current_result.is_assertion_result(): + for edge in self.flow.edges: + if (edge.sourceNodeId == self.current_node.id and + f"{edge.sourceNodeId}_{self.current_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId): + return self.get_node_cls_by_id(edge.targetNodeId) + else: + for edge in self.flow.edges: + if edge.sourceNodeId == self.current_node.id: + return self.get_node_cls_by_id(edge.targetNodeId) + + return None + + def dependent_node_been_executed(self, node_id): + """ + 判断依赖节点是否都已执行 + @param node_id: 需要判断的节点id + @return: + """ + up_node_id_list = [edge.sourceNodeId for edge in self.flow.edges if edge.targetNodeId == node_id] + return all([any([node.id == up_node_id for node in self.node_context]) for up_node_id in 
up_node_id_list]) + + def get_next_node_list(self, current_node, current_node_result): + """ + 获取下一个可执行节点列表 + @param current_node: 当前可执行节点 + @param current_node_result: 当前可执行节点结果 + @return: 可执行节点列表 + """ + node_list = [] + if current_node_result is not None and current_node_result.is_assertion_result(): + for edge in self.flow.edges: + if (edge.sourceNodeId == current_node.id and + f"{edge.sourceNodeId}_{current_node_result.node_variable.get('branch_id')}_right" == edge.sourceAnchorId): + if self.dependent_node_been_executed(edge.targetNodeId): + node_list.append(self.get_node_cls_by_id(edge.targetNodeId)) + else: + for edge in self.flow.edges: + if edge.sourceNodeId == current_node.id and self.dependent_node_been_executed(edge.targetNodeId): + node_list.append(self.get_node_cls_by_id(edge.targetNodeId)) + return node_list + + def get_reference_field(self, node_id: str, fields: List[str]): + """ + @param node_id: 节点id + @param fields: 字段 + @return: + """ + if node_id == 'global': + return INode.get_field(self.context, fields) + else: + return self.get_node_by_id(node_id).get_reference_field(fields) + + def generate_prompt(self, prompt: str): + """ + 格式化生成提示词 + @param prompt: 提示词信息 + @return: 格式化后的提示词 + """ + context = { + 'global': self.context, + } + + for node in self.node_context: + properties = node.node.properties + node_config = properties.get('config') + if node_config is not None: + fields = node_config.get('fields') + if fields is not None: + for field in fields: + globeLabel = f"{properties.get('stepName')}.{field.get('value')}" + globeValue = f"context['{node.id}'].{field.get('value')}" + prompt = prompt.replace(globeLabel, globeValue) + global_fields = node_config.get('globalFields') + if global_fields is not None: + for field in global_fields: + globeLabel = f"全局变量.{field.get('value')}" + globeValue = f"context['global'].{field.get('value')}" + prompt = prompt.replace(globeLabel, globeValue) + context[node.id] = node.context + prompt_template = 
PromptTemplate.from_template(prompt, template_format='jinja2') + + value = prompt_template.format(context=context) + return value + + def get_start_node(self): + """ + 获取启动节点 + @return: + """ + start_node_list = [node for node in self.flow.nodes if node.type == 'start-node'] + return start_node_list[0] + + def get_base_node(self): + """ + 获取基础节点 + @return: + """ + base_node_list = [node for node in self.flow.nodes if node.type == 'base-node'] + return base_node_list[0] + + def get_node_cls_by_id(self, node_id): + for node in self.flow.nodes: + if node.id == node_id: + node_instance = get_node(node.type)(node, + self.params, self) + return node_instance + return None + + def get_node_by_id(self, node_id): + for node in self.node_context: + if node.id == node_id: + return node + return None + + def get_node_reference(self, reference_address: Dict): + node = self.get_node_by_id(reference_address.get('node_id')) + return node.context[reference_address.get('node_field')] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0001_initial.py b/src/MaxKB-1.7.2/apps/application/migrations/0001_initial.py new file mode 100644 index 0000000..d7b627a --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0001_initial.py @@ -0,0 +1,134 @@ +# Generated by Django 4.1.10 on 2024-03-18 16:02 + +import application.models.application +import django.contrib.postgres.fields +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('dataset', '0001_initial'), + ('setting', '0001_initial'), + ('users', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='Application', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, 
verbose_name='主键id')), + ('name', models.CharField(max_length=128, verbose_name='应用名称')), + ('desc', models.CharField(default='', max_length=512, verbose_name='引用描述')), + ('prologue', models.CharField(default='', max_length=1024, verbose_name='开场白')), + ('dialogue_number', models.IntegerField(default=0, verbose_name='会话数量')), + ('dataset_setting', models.JSONField(default=application.models.application.get_dataset_setting_dict, verbose_name='数据集参数设置')), + ('model_setting', models.JSONField(default=application.models.application.get_model_setting_dict, verbose_name='模型参数相关设置')), + ('problem_optimization', models.BooleanField(default=False, verbose_name='问题优化')), + ('model', models.ForeignKey(blank=True, db_constraint=False, null=True, on_delete=django.db.models.deletion.SET_NULL, to='setting.model')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user')), + ], + options={ + 'db_table': 'application', + }, + ), + migrations.CreateModel( + name='Chat', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('abstract', models.CharField(max_length=256, verbose_name='摘要')), + ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')), + ], + options={ + 'db_table': 'application_chat', + }, + ), + migrations.CreateModel( + name='ApplicationAccessToken', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('application', models.OneToOneField(on_delete=django.db.models.deletion.CASCADE, primary_key=True, serialize=False, to='application.application', verbose_name='应用id')), + ('access_token', models.CharField(max_length=128, 
unique=True, verbose_name='用户公开访问 认证token')), + ('is_active', models.BooleanField(default=True, verbose_name='是否开启公开访问')), + ('access_num', models.IntegerField(default=100, verbose_name='访问次数')), + ('white_active', models.BooleanField(default=False, verbose_name='是否开启白名单')), + ('white_list', django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=128), default=list, size=None, verbose_name='白名单列表')), + ], + options={ + 'db_table': 'application_access_token', + }, + ), + migrations.CreateModel( + name='ChatRecord', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('vote_status', models.CharField(choices=[('-1', '未投票'), ('0', '赞同'), ('1', '反对')], default='-1', max_length=10, verbose_name='投票')), + ('problem_text', models.CharField(max_length=1024, verbose_name='问题')), + ('answer_text', models.CharField(max_length=4096, verbose_name='答案')), + ('message_tokens', models.IntegerField(default=0, verbose_name='请求token数量')), + ('answer_tokens', models.IntegerField(default=0, verbose_name='响应token数量')), + ('const', models.IntegerField(default=0, verbose_name='总费用')), + ('details', models.JSONField(default=dict, verbose_name='对话详情')), + ('improve_paragraph_id_list', django.contrib.postgres.fields.ArrayField(base_field=models.UUIDField(blank=True), default=list, size=None, verbose_name='改进标注列表')), + ('run_time', models.FloatField(default=0, verbose_name='运行时长')), + ('index', models.IntegerField(verbose_name='对话下标')), + ('chat', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.chat')), + ], + options={ + 'db_table': 'application_chat_record', + }, + ), + migrations.CreateModel( + name='ApplicationPublicAccessClient', + fields=[ + ('create_time', 
models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(primary_key=True, serialize=False, verbose_name='公共访问链接客户端id')), + ('access_num', models.IntegerField(default=0, verbose_name='访问总次数次数')), + ('intraday_access_num', models.IntegerField(default=0, verbose_name='当日访问次数')), + ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application', verbose_name='应用id')), + ], + options={ + 'db_table': 'application_public_access_client', + }, + ), + migrations.CreateModel( + name='ApplicationDatasetMapping', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')), + ('dataset', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='dataset.dataset')), + ], + options={ + 'db_table': 'application_dataset_mapping', + }, + ), + migrations.CreateModel( + name='ApplicationApiKey', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('secret_key', models.CharField(max_length=1024, unique=True, verbose_name='秘钥')), + ('is_active', models.BooleanField(default=True, verbose_name='是否开启')), + ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application', verbose_name='应用id')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='users.user', verbose_name='用户id')), + ], + 
options={ + 'db_table': 'application_api_key', + }, + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0002_chat_client_id.py b/src/MaxKB-1.7.2/apps/application/migrations/0002_chat_client_id.py new file mode 100644 index 0000000..37d900e --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0002_chat_client_id.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.13 on 2024-03-28 13:59 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='chat', + name='client_id', + field=models.UUIDField(default=None, null=True, verbose_name='客户端id'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0003_application_icon.py b/src/MaxKB-1.7.2/apps/application/migrations/0003_application_icon.py new file mode 100644 index 0000000..6e040be --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0003_application_icon.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.13 on 2024-04-23 11:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0002_chat_client_id'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='icon', + field=models.CharField(default='/ui/favicon.ico', max_length=256, verbose_name='应用icon'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0004_applicationaccesstoken_show_source.py b/src/MaxKB-1.7.2/apps/application/migrations/0004_applicationaccesstoken_show_source.py new file mode 100644 index 0000000..851d731 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0004_applicationaccesstoken_show_source.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.13 on 2024-04-25 11:28 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0003_application_icon'), + ] + + operations = [ + 
migrations.AddField( + model_name='applicationaccesstoken', + name='show_source', + field=models.BooleanField(default=False, verbose_name='是否显示知识来源'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0005_alter_chat_abstract_alter_chatrecord_answer_text.py b/src/MaxKB-1.7.2/apps/application/migrations/0005_alter_chat_abstract_alter_chatrecord_answer_text.py new file mode 100644 index 0000000..0643a39 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0005_alter_chat_abstract_alter_chatrecord_answer_text.py @@ -0,0 +1,23 @@ +# Generated by Django 4.1.13 on 2024-04-29 13:33 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0004_applicationaccesstoken_show_source'), + ] + + operations = [ + migrations.AlterField( + model_name='chat', + name='abstract', + field=models.CharField(max_length=1024, verbose_name='摘要'), + ), + migrations.AlterField( + model_name='chatrecord', + name='answer_text', + field=models.CharField(max_length=40960, verbose_name='答案'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0006_applicationapikey_allow_cross_domain_and_more.py b/src/MaxKB-1.7.2/apps/application/migrations/0006_applicationapikey_allow_cross_domain_and_more.py new file mode 100644 index 0000000..cd24c0e --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0006_applicationapikey_allow_cross_domain_and_more.py @@ -0,0 +1,24 @@ +# Generated by Django 4.1.13 on 2024-05-08 13:57 + +import django.contrib.postgres.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0005_alter_chat_abstract_alter_chatrecord_answer_text'), + ] + + operations = [ + migrations.AddField( + model_name='applicationapikey', + name='allow_cross_domain', + field=models.BooleanField(default=False, verbose_name='是否允许跨域'), + ), + migrations.AddField( + model_name='applicationapikey', + 
name='cross_domain_list', + field=django.contrib.postgres.fields.ArrayField(base_field=models.CharField(blank=True, max_length=128), default=list, size=None, verbose_name='跨域列表'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0007_alter_application_prologue.py b/src/MaxKB-1.7.2/apps/application/migrations/0007_alter_application_prologue.py new file mode 100644 index 0000000..27b519c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0007_alter_application_prologue.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.13 on 2024-05-24 11:00 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0006_applicationapikey_allow_cross_domain_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='application', + name='prologue', + field=models.CharField(default='', max_length=4096, verbose_name='开场白'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0008_chat_is_deleted.py b/src/MaxKB-1.7.2/apps/application/migrations/0008_chat_is_deleted.py new file mode 100644 index 0000000..5291c3f --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0008_chat_is_deleted.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.13 on 2024-06-13 11:46 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0007_alter_application_prologue'), + ] + + operations = [ + migrations.AddField( + model_name='chat', + name='is_deleted', + field=models.BooleanField(default=False, verbose_name=''), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0009_application_type_application_work_flow_and_more.py b/src/MaxKB-1.7.2/apps/application/migrations/0009_application_type_application_work_flow_and_more.py new file mode 100644 index 0000000..5d0bf0c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0009_application_type_application_work_flow_and_more.py @@ -0,0 +1,38 @@ 
+# Generated by Django 4.1.13 on 2024-06-25 16:30 + +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0008_chat_is_deleted'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='type', + field=models.CharField(choices=[('SIMPLE', '简易'), ('WORK_FLOW', '工作流')], default='SIMPLE', max_length=256, verbose_name='应用类型'), + ), + migrations.AddField( + model_name='application', + name='work_flow', + field=models.JSONField(default=dict, verbose_name='工作流数据'), + ), + migrations.CreateModel( + name='WorkFlowVersion', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('work_flow', models.JSONField(default=dict, verbose_name='工作流数据')), + ('application', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='application.application')), + ], + options={ + 'db_table': 'application_work_flow_version', + }, + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0010_alter_chatrecord_details.py b/src/MaxKB-1.7.2/apps/application/migrations/0010_alter_chatrecord_details.py new file mode 100644 index 0000000..e462780 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0010_alter_chatrecord_details.py @@ -0,0 +1,19 @@ +# Generated by Django 4.2.13 on 2024-07-15 15:52 + +import application.models.application +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0009_application_type_application_work_flow_and_more'), + ] + + operations = [ + migrations.AlterField( + model_name='chatrecord', + name='details', + field=models.JSONField(default=dict, encoder=application.models.application.DateEncoder, 
verbose_name='对话详情'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0011_application_model_params_setting.py b/src/MaxKB-1.7.2/apps/application/migrations/0011_application_model_params_setting.py new file mode 100644 index 0000000..440b94d --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0011_application_model_params_setting.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.15 on 2024-08-23 14:17 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0010_alter_chatrecord_details'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='model_params_setting', + field=models.JSONField(default={}, verbose_name='模型参数相关设置'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0012_application_stt_model_application_stt_model_enable_and_more.py b/src/MaxKB-1.7.2/apps/application/migrations/0012_application_stt_model_application_stt_model_enable_and_more.py new file mode 100644 index 0000000..f50c39d --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0012_application_stt_model_application_stt_model_enable_and_more.py @@ -0,0 +1,35 @@ +# Generated by Django 4.2.15 on 2024-09-05 14:35 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0006_alter_model_status'), + ('application', '0011_application_model_params_setting'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='stt_model', + field=models.ForeignKey(blank=True, db_constraint=False, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='stt_model_id', to='setting.model'), + ), + migrations.AddField( + model_name='application', + name='stt_model_enable', + field=models.BooleanField(default=False, verbose_name='语音识别模型是否启用'), + ), + migrations.AddField( + model_name='application', + name='tts_model', + 
field=models.ForeignKey(blank=True, db_constraint=False, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='tts_model_id', to='setting.model'), + ), + migrations.AddField( + model_name='application', + name='tts_model_enable', + field=models.BooleanField(default=False, verbose_name='语音合成模型是否启用'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0013_application_tts_type.py b/src/MaxKB-1.7.2/apps/application/migrations/0013_application_tts_type.py new file mode 100644 index 0000000..c64c8e7 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0013_application_tts_type.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.15 on 2024-09-12 11:01 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0012_application_stt_model_application_stt_model_enable_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='tts_type', + field=models.CharField(default='BROWSER', max_length=20, verbose_name='语音播放类型'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0014_application_problem_optimization_prompt.py b/src/MaxKB-1.7.2/apps/application/migrations/0014_application_problem_optimization_prompt.py new file mode 100644 index 0000000..e2efc10 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0014_application_problem_optimization_prompt.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.15 on 2024-09-13 18:57 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0013_application_tts_type'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='problem_optimization_prompt', + field=models.CharField(blank=True, default='()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中', max_length=102400, null=True, verbose_name='问题优化提示词'), + ), + ] diff --git 
def re_index_database(db_name):
    """
    Rebuild all indexes of ``db_name`` and refresh its recorded collation version.

    Opens a dedicated connection through ``get_connect``, runs both statements via
    ``sql_execute`` and always closes the connection afterwards.

    @param db_name: name of the database to connect to and reindex
    @raise Exception: whatever ``get_connect`` or ``sql_execute`` raises is propagated
    """
    db_conn = get_connect(db_name)
    try:
        sql_execute(db_conn, f'REINDEX DATABASE "{db_name}";',
                    f'ALTER DATABASE "{db_name}" REFRESH COLLATION VERSION;')
    finally:
        # The caller (re_index) only logs a failure and carries on with the next
        # database, so without this the connection would stay open after an error.
        db_conn.close()
b/src/MaxKB-1.7.2/apps/application/migrations/0016_alter_chatrecord_problem_text.py new file mode 100644 index 0000000..edda1e6 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0016_alter_chatrecord_problem_text.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.15 on 2024-09-26 13:19 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0015_re_database_index'), + ] + + operations = [ + migrations.AlterField( + model_name='chatrecord', + name='problem_text', + field=models.CharField(max_length=10240, verbose_name='问题'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0017_application_tts_model_params_setting.py b/src/MaxKB-1.7.2/apps/application/migrations/0017_application_tts_model_params_setting.py new file mode 100644 index 0000000..4342884 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0017_application_tts_model_params_setting.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.15 on 2024-10-16 13:10 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('application', '0016_alter_chatrecord_problem_text'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='tts_model_params_setting', + field=models.JSONField(default={}, verbose_name='模型参数相关设置'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/0018_workflowversion_name.py b/src/MaxKB-1.7.2/apps/application/migrations/0018_workflowversion_name.py new file mode 100644 index 0000000..51d0417 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/migrations/0018_workflowversion_name.py @@ -0,0 +1,38 @@ +# Generated by Django 4.2.15 on 2024-10-16 15:17 + +from django.db import migrations, models + +sql = """ +UPDATE "public".application_work_flow_version +SET "name" = TO_CHAR(create_time, 'YYYY-MM-DD HH24:MI:SS'); +""" + + +class Migration(migrations.Migration): + dependencies = [ + ('application', 
'0017_application_tts_model_params_setting'), + ] + + operations = [ + migrations.AddField( + model_name='application', + name='clean_time', + field=models.IntegerField(default=180, verbose_name='清理时间'), + ), + migrations.AddField( + model_name='workflowversion', + name='name', + field=models.CharField(default='', max_length=128, verbose_name='版本名称'), + ), + migrations.RunSQL(sql), + migrations.AddField( + model_name='workflowversion', + name='publish_user_id', + field=models.UUIDField(default=None, null=True, verbose_name='发布者id'), + ), + migrations.AddField( + model_name='workflowversion', + name='publish_user_name', + field=models.CharField(default='', max_length=128, verbose_name='发布者名称'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/application/migrations/__init__.py b/src/MaxKB-1.7.2/apps/application/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MaxKB-1.7.2/apps/application/models/__init__.py b/src/MaxKB-1.7.2/apps/application/models/__init__.py new file mode 100644 index 0000000..0d57976 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/models/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2023/9/25 14:25 + @desc: +""" +from .application import * diff --git a/src/MaxKB-1.7.2/apps/application/models/api_key_model.py b/src/MaxKB-1.7.2/apps/application/models/api_key_model.py new file mode 100644 index 0000000..965e1f1 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/models/api_key_model.py @@ -0,0 +1,59 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: api_key_model.py + @date:2023/11/14 17:15 + @desc: +""" +import uuid + +from django.contrib.postgres.fields import ArrayField +from django.db import models + +from application.models import Application +from common.mixins.app_model_mixin import AppModelMixin +from users.models import User + + +class ApplicationApiKey(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, 
default=uuid.uuid1, editable=False, verbose_name="主键id") + secret_key = models.CharField(max_length=1024, verbose_name="秘钥", unique=True) + user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户id") + application = models.ForeignKey(Application, on_delete=models.CASCADE, verbose_name="应用id") + is_active = models.BooleanField(default=True, verbose_name="是否开启") + allow_cross_domain = models.BooleanField(default=False, verbose_name="是否允许跨域") + cross_domain_list = ArrayField(verbose_name="跨域列表", + base_field=models.CharField(max_length=128, blank=True) + , default=list) + + class Meta: + db_table = "application_api_key" + + +class ApplicationAccessToken(AppModelMixin): + """ + 应用认证token + """ + application = models.OneToOneField(Application, primary_key=True, on_delete=models.CASCADE, verbose_name="应用id") + access_token = models.CharField(max_length=128, verbose_name="用户公开访问 认证token", unique=True) + is_active = models.BooleanField(default=True, verbose_name="是否开启公开访问") + access_num = models.IntegerField(default=100, verbose_name="访问次数") + white_active = models.BooleanField(default=False, verbose_name="是否开启白名单") + white_list = ArrayField(verbose_name="白名单列表", + base_field=models.CharField(max_length=128, blank=True) + , default=list) + show_source = models.BooleanField(default=False, verbose_name="是否显示知识来源") + + class Meta: + db_table = "application_access_token" + + +class ApplicationPublicAccessClient(AppModelMixin): + id = models.UUIDField(max_length=128, primary_key=True, verbose_name="公共访问链接客户端id") + application = models.ForeignKey(Application, on_delete=models.CASCADE, verbose_name="应用id") + access_num = models.IntegerField(default=0, verbose_name="访问总次数次数") + intraday_access_num = models.IntegerField(default=0, verbose_name="当日访问次数") + + class Meta: + db_table = "application_public_access_client" diff --git a/src/MaxKB-1.7.2/apps/application/models/application.py b/src/MaxKB-1.7.2/apps/application/models/application.py new file mode 100644 
def get_dataset_setting_dict():
    """
    Build the default value of ``Application.dataset_setting``.

    A new dict (with a new nested ``no_references_setting`` dict) is created on every
    call, which is what lets the function serve as the JSONField ``default`` callable.

    @return: dict holding the default retrieval settings
    """
    no_references_setting = dict(status='ai_questioning', value='{question}')
    return dict(
        top_n=3,
        similarity=0.6,
        max_paragraph_char_number=5000,
        search_mode='embedding',
        no_references_setting=no_references_setting,
    )
models.JSONField(verbose_name="模型参数相关设置", default={}) + problem_optimization = models.BooleanField(verbose_name="问题优化", default=False) + icon = models.CharField(max_length=256, verbose_name="应用icon", default="/ui/favicon.ico") + work_flow = models.JSONField(verbose_name="工作流数据", default=dict) + type = models.CharField(verbose_name="应用类型", choices=ApplicationTypeChoices.choices, + default=ApplicationTypeChoices.SIMPLE, max_length=256) + problem_optimization_prompt = models.CharField(verbose_name="问题优化提示词", max_length=102400, blank=True, + null=True, + default="()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中") + tts_model = models.ForeignKey(Model, related_name='tts_model_id', on_delete=models.SET_NULL, db_constraint=False, + blank=True, null=True) + stt_model = models.ForeignKey(Model, related_name='stt_model_id', on_delete=models.SET_NULL, db_constraint=False, + blank=True, null=True) + tts_model_enable = models.BooleanField(verbose_name="语音合成模型是否启用", default=False) + stt_model_enable = models.BooleanField(verbose_name="语音识别模型是否启用", default=False) + tts_type = models.CharField(verbose_name="语音播放类型", max_length=20, default="BROWSER") + clean_time = models.IntegerField(verbose_name="清理时间", default=180) + + @staticmethod + def get_default_model_prompt(): + return ('已知信息:' + '\n{data}' + '\n回答要求:' + '\n- 如果你不知道答案或者没有从获取答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。' + '\n- 避免提及你是从中获得的知识。' + '\n- 请保持答案与中描述的一致。' + '\n- 请使用markdown 语法优化答案的格式。' + '\n- 中的图片链接、链接地址和脚本语言请完整返回。' + '\n- 请使用与问题相同的语言来回答。' + '\n问题:' + '\n{question}') + + class Meta: + db_table = "application" + + +class WorkFlowVersion(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + application = models.ForeignKey(Application, on_delete=models.CASCADE) + name = models.CharField(verbose_name="版本名称", max_length=128, default="") + publish_user_id = models.UUIDField(verbose_name="发布者id", max_length=128, default=None, 
class DateEncoder(json.JSONEncoder):
    """
    JSON encoder for values that the stock encoder rejects.

    - ``uuid.UUID``          -> its canonical string form
    - ``datetime.datetime``  -> ``"YYYY-MM-DD HH:MM:SS"``
    - ``datetime.date``      -> ``"YYYY-MM-DD"``

    Anything else is delegated to ``json.JSONEncoder.default``, which raises TypeError.
    """

    def default(self, obj):
        if isinstance(obj, uuid.UUID):
            return str(obj)
        # datetime must be tested before date: datetime.datetime is a subclass of
        # datetime.date and would otherwise lose its time component.
        if isinstance(obj, datetime.datetime):
            return obj.strftime("%Y-%m-%d %H:%M:%S")
        if isinstance(obj, datetime.date):
            return obj.strftime("%Y-%m-%d")
        return super().default(obj)
models.IntegerField(verbose_name="请求token数量", default=0) + answer_tokens = models.IntegerField(verbose_name="响应token数量", default=0) + const = models.IntegerField(verbose_name="总费用", default=0) + details = models.JSONField(verbose_name="对话详情", default=dict, encoder=DateEncoder) + improve_paragraph_id_list = ArrayField(verbose_name="改进标注列表", + base_field=models.UUIDField(max_length=128, blank=True) + , default=list) + run_time = models.FloatField(verbose_name="运行时长", default=0) + index = models.IntegerField(verbose_name="对话下标") + + def get_human_message(self): + if 'problem_padding' in self.details: + return HumanMessage(content=self.details.get('problem_padding').get('padding_problem_text')) + return HumanMessage(content=self.problem_text) + + def get_ai_message(self): + return AIMessage(content=self.answer_text) + + class Meta: + db_table = "application_chat_record" diff --git a/src/MaxKB-1.7.2/apps/application/serializers/application_serializers.py b/src/MaxKB-1.7.2/apps/application/serializers/application_serializers.py new file mode 100644 index 0000000..32ea3ec --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/serializers/application_serializers.py @@ -0,0 +1,1054 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: application_serializers.py + @date:2023/11/7 10:02 + @desc: +""" +import datetime +import hashlib +import json +import os +import re +import uuid +from functools import reduce +from typing import Dict, List + +from django.conf import settings +from django.contrib.postgres.fields import ArrayField +from django.core import cache, validators +from django.core import signing +from django.db import transaction, models +from django.db.models import QuerySet +from django.http import HttpResponse +from django.template import Template, Context +from rest_framework import serializers + +from application.flow.workflow_manage import Flow +from application.models import Application, ApplicationDatasetMapping, ApplicationTypeChoices, WorkFlowVersion 
from application.models.api_key_model import ApplicationAccessToken, ApplicationApiKey
from common.cache_data.application_access_token_cache import get_application_access_token, del_application_access_token
from common.cache_data.application_api_key_cache import del_application_api_key, get_application_api_key
from common.config.embedding_config import VectorStore
from common.constants.authentication_type import AuthenticationType
from common.db.search import get_dynamics_model, native_search, native_page_search
from common.db.sql_execute import select_list
from common.exception.app_exception import AppApiException, NotFound404, AppUnauthorizedFailed
from common.field.common import UploadedImageField
from common.models.db_model_manage import DBModelManage
from common.util.common import valid_license, password_encrypt
from common.util.field_message import ErrMessage
from common.util.file_util import get_file_content
from dataset.models import DataSet, Document, Image
from dataset.serializers.common_serializers import list_paragraph, get_embedding_model_by_dataset_id_list
from embedding.models import SearchMode
from function_lib.serializers.function_lib_serializer import FunctionLibSerializer
from setting.models import AuthOperate
from setting.models.model_management import Model
from setting.models_provider import get_model_credential
from setting.models_provider.constants.model_provider_constants import ModelProvideConstants
from setting.models_provider.tools import get_model_instance_by_model_user_id
from setting.serializers.provider_serializers import ModelSerializer
from smartdoc.conf import PROJECT_DIR
from users.models import User

# Cache backend registered under the 'chat_cache' alias.
chat_cache = cache.caches['chat_cache']


class ModelDatasetAssociation(serializers.Serializer):
    """Checks that a model id and a list of dataset ids refer to existing rows.

    The model is looked up by id only; datasets are looked up by id AND
    ``user_id``, so a dataset owned by another user is reported as missing.
    """
    user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
    model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                     error_messages=ErrMessage.char("模型id"))
    dataset_id_list = serializers.ListSerializer(required=False,
                                                 child=serializers.UUIDField(required=True,
                                                                             error_messages=ErrMessage.uuid(
                                                                                 "知识库id")),
                                                 error_messages=ErrMessage.list("知识库列表"))

    def is_valid(self, *, raise_exception=True):
        """Run field validation, then the database existence checks.

        :param raise_exception: accepted for interface compatibility; field
            validation always raises on failure regardless of its value.
        :return: ``True`` when every check passes.
        :raises AppApiException: when the model or any dataset id is unknown.
        """
        super().is_valid(raise_exception=True)
        model_id = self.data.get('model_id')
        user_id = self.data.get('user_id')
        if model_id is not None and len(model_id) > 0:
            if not QuerySet(Model).filter(id=model_id).exists():
                raise AppApiException(500, f'模型不存在【{model_id}】')
        # dataset_id_list is optional: treat "not supplied" as an empty list
        # instead of failing on set(None). dict.fromkeys de-duplicates while
        # keeping the caller's order, so the reported id is deterministic.
        dataset_id_list = list(dict.fromkeys(self.data.get('dataset_id_list') or []))
        exist_dataset_ids = {str(dataset.id) for dataset in
                             QuerySet(DataSet).filter(id__in=dataset_id_list, user_id=user_id)}
        for dataset_id in dataset_id_list:
            if dataset_id not in exist_dataset_ids:
                raise AppApiException(500, f'知识库id不存在【{dataset_id}】')
        return True


class ApplicationSerializerModel(serializers.ModelSerializer):
    """Model serializer exposing every field of ``Application``."""

    class Meta:
        model = Application
        fields = "__all__"


class NoReferencesChoices(models.TextChoices):
    """Allowed values for ``NoReferencesSetting.status``."""
    ai_questioning = 'ai_questioning', 'ai回答'
    designated_answer = 'designated_answer', '指定回答'


class NoReferencesSetting(serializers.Serializer):
    """Nested "no references" setting: a status choice plus a prompt/answer text."""
    status = serializers.ChoiceField(required=True, choices=NoReferencesChoices.choices,
                                     error_messages=ErrMessage.char("无引用状态"))
    value = serializers.CharField(required=True, error_messages=ErrMessage.char("提示词"))


def valid_model_params_setting(model_id, model_params_setting):
    """Validate ``model_params_setting`` against the model's parameter form.

    Does nothing when ``model_id`` is ``None`` or the settings are ``None``/empty.

    :param model_id: primary key of the ``Model`` row.
    :param model_params_setting: dict of parameter values to validate.
    :raises AppApiException: when no model with ``model_id`` exists.
    """
    if model_id is None or model_params_setting is None or len(model_params_setting.keys()) == 0:
        return
    model = QuerySet(Model).filter(id=model_id).first()
    if model is None:
        # Previously this fell through to an AttributeError on `model.provider`.
        raise AppApiException(500, f'模型不存在【{model_id}】')
    credential = get_model_credential(model.provider, model.model_type, model.model_name)
    credential.get_model_params_setting_form(model.model_name).valid_form(model_params_setting)


class DatasetSettingSerializer(serializers.Serializer):
    """Retrieval settings of an application (top-n, similarity, search mode, ...)."""
    # NOTE(review): declared as a float although the label reads as a paragraph
    # count; left unchanged, confirm whether an integer field was intended.
    top_n = serializers.FloatField(required=True, max_value=100, min_value=1,
                                   error_messages=ErrMessage.float("引用分段数"))
    similarity = serializers.FloatField(required=True, max_value=1, min_value=0,
                                        error_messages=ErrMessage.float("相识度"))
    max_paragraph_char_number = serializers.IntegerField(required=True, min_value=500, max_value=100000,
                                                         error_messages=ErrMessage.integer("最多引用字符数"))
    # The alternatives are grouped so that ^ and $ apply to each of them.
    # Ungrouped ("^embedding|keywords|blend$") the pattern accepted any value
    # starting with "embedding", containing "keywords" or ending with "blend".
    search_mode = serializers.CharField(required=True, validators=[
        validators.RegexValidator(regex=re.compile("^(embedding|keywords|blend)$"),
                                  message="类型只支持embedding|keywords|blend", code=500)
    ], error_messages=ErrMessage.char("检索模式"))

    no_references_setting = NoReferencesSetting(required=True, error_messages=ErrMessage.base("未引用分段设置"))


class ModelSettingSerializer(serializers.Serializer):
    """Prompt settings of an application; each text is capped at 102400 characters."""
    prompt = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
                                   error_messages=ErrMessage.char("提示词"))
    system = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=102400,
                                   error_messages=ErrMessage.char("角色提示词"))
    no_references_prompt = serializers.CharField(required=True, max_length=102400, allow_null=True, allow_blank=True,
                                                 error_messages=ErrMessage.char("无引用分段提示词"))


class ApplicationWorkflowSerializer(serializers.Serializer):
    """Input validation for creating a WORK_FLOW type application."""
    name = serializers.CharField(required=True, max_length=64, min_length=1, error_messages=ErrMessage.char("应用名称"))
    desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                 max_length=256, min_length=1,
                                 error_messages=ErrMessage.char("应用描述"))
    work_flow = serializers.DictField(required=False, error_messages=ErrMessage.dict("工作流对象"))
    prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096,
                                     error_messages=ErrMessage.char("开场白"))

    @staticmethod
    def to_application_model(user_id: str, application: Dict):
        """Build an unsaved ``Application`` of type WORK_FLOW.

        Uses ``application['work_flow']`` when supplied; otherwise loads
        ``default_workflow.json`` and writes the application's name, desc and
        prologue into the ``node_data`` of its ``base-node``.

        :param user_id: owner of the new application.
        :param application: request payload (name, desc, prologue, work_flow).
        :return: an ``Application`` instance that has not been saved.
        """
        if application.get('work_flow') is not None:
            default_workflow = application.get('work_flow')
        else:
            default_workflow_json = get_file_content(
                os.path.join(PROJECT_DIR, "apps", "application", 'flow', 'default_workflow.json'))
            default_workflow = json.loads(default_workflow_json)
            # NOTE(review): the source view lost its indentation; this loop is
            # assumed to belong to the default-template branch only. Confirm.
            for node in default_workflow.get('nodes', []):
                if node.get('id') == 'base-node':
                    node.get('properties')['node_data'] = {"desc": application.get('desc'),
                                                           "name": application.get('name'),
                                                           "prologue": application.get('prologue')}
        return Application(id=uuid.uuid1(),
                           name=application.get('name'),
                           desc=application.get('desc'),
                           prologue="",
                           dialogue_number=0,
                           user_id=user_id, model_id=None,
                           dataset_setting={},
                           model_setting={},
                           problem_optimization=False,
                           type=ApplicationTypeChoices.WORK_FLOW,
                           work_flow=default_workflow
                           )


def get_base_node_work_flow(work_flow):
    """Return the last node whose id is ``'base-node'``, or ``None`` if there is none.

    :param work_flow: workflow dict with an optional ``'nodes'`` list.
    """
    # A workflow without a 'nodes' key is treated as having no nodes.
    node_list = work_flow.get('nodes') or []
    base_node_list = [node for node in node_list if node.get('id') == 'base-node']
    if base_node_list:
        return base_node_list[-1]
    return None


class ApplicationSerializer(serializers.Serializer):
    name = serializers.CharField(required=True, max_length=64, min_length=1, error_messages=ErrMessage.char("应用名称"))
    desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                 max_length=256, min_length=1,
                                 error_messages=ErrMessage.char("应用描述"))
    model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                     error_messages=ErrMessage.char("模型"))
    dialogue_number = serializers.IntegerField(required=True,
                                               min_value=0,
                                               max_value=1024,
                                               error_messages=ErrMessage.integer("历史聊天记录"))
    prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096,
                                     error_messages=ErrMessage.char("开场白"))
    dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
                                                 allow_null=True, error_messages=ErrMessage.list("关联知识库"))
    # Dataset-related settings
    dataset_setting = DatasetSettingSerializer(required=True)
    # Model-related settings
    model_setting = ModelSettingSerializer(required=True)
    # Question completion
problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全")) + problem_optimization_prompt = serializers.CharField(required=False, max_length=102400, + error_messages=ErrMessage.char("问题补全提示词")) + # 应用类型 + type = serializers.CharField(required=True, error_messages=ErrMessage.char("应用类型"), + validators=[ + validators.RegexValidator(regex=re.compile("^SIMPLE|WORK_FLOW$"), + message="应用类型只支持SIMPLE|WORK_FLOW", code=500) + ] + ) + model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.dict('模型参数')) + + def is_valid(self, *, user_id=None, raise_exception=False): + super().is_valid(raise_exception=True) + ModelDatasetAssociation(data={'user_id': user_id, 'model_id': self.data.get('model_id'), + 'dataset_id_list': self.data.get('dataset_id_list')}).is_valid() + + class Embed(serializers.Serializer): + host = serializers.CharField(required=True, error_messages=ErrMessage.char("主机")) + protocol = serializers.CharField(required=True, error_messages=ErrMessage.char("协议")) + token = serializers.CharField(required=True, error_messages=ErrMessage.char("token")) + + def get_embed(self, with_valid=True, params=None): + if params is None: + params = {} + if with_valid: + self.is_valid(raise_exception=True) + index_path = os.path.join(PROJECT_DIR, 'apps', "application", 'template', 'embed.js') + file = open(index_path, "r", encoding='utf-8') + content = file.read() + file.close() + application_access_token = QuerySet(ApplicationAccessToken).filter( + access_token=self.data.get('token')).first() + is_draggable = 'false' + show_guide = 'true' + float_icon = f"{self.data.get('protocol')}://{self.data.get('host')}/ui/MaxKB.gif" + xpack_cache = DBModelManage.get_model('xpack_cache') + X_PACK_LICENSE_IS_VALID = False if xpack_cache is None else xpack_cache.get('XPACK_LICENSE_IS_VALID', False) + # 获取接入的query参数 + query = self.get_query_api_input(application_access_token.application, params) + float_location = {"x": 
{"type": "right", "value": 0}, "y": {"type": "bottom", "value": 30}} + application_setting_model = DBModelManage.get_model('application_setting') + if application_setting_model is not None and X_PACK_LICENSE_IS_VALID: + application_setting = QuerySet(application_setting_model).filter( + application_id=application_access_token.application_id).first() + if application_setting is not None: + is_draggable = 'true' if application_setting.draggable else 'false' + if application_setting.float_icon is not None and len(application_setting.float_icon) > 0: + float_icon = f"{self.data.get('protocol')}://{self.data.get('host')}{application_setting.float_icon}" + show_guide = 'true' if application_setting.show_guide else 'false' + if application_setting.float_location is not None: + float_location = application_setting.float_location + + is_auth = 'true' if application_access_token is not None and application_access_token.is_active else 'false' + t = Template(content) + s = t.render( + Context( + {'is_auth': is_auth, 'protocol': self.data.get('protocol'), 'host': self.data.get('host'), + 'token': self.data.get('token'), + 'white_list_str': ",".join( + application_access_token.white_list if application_access_token.white_list is not None else []), + 'white_active': 'true' if application_access_token.white_active else 'false', + 'is_draggable': is_draggable, + 'float_icon': float_icon, + 'query': query, + 'show_guide': show_guide, + 'x_type': float_location.get('x', {}).get('type', 'right'), + 'x_value': float_location.get('x', {}).get('value', 0), + 'y_type': float_location.get('y', {}).get('type', 'bottom'), + 'y_value': float_location.get('y', {}).get('value', 30)})) + response = HttpResponse(s, status=200, headers={'Content-Type': 'text/javascript'}) + return response + + def get_query_api_input(self, application, params): + query = '' + if application.work_flow is not None: + work_flow = application.work_flow + if work_flow is not None: + for node in work_flow.get('nodes', 
[]): + if node['id'] == 'base-node': + input_field_list = node.get('properties', {}).get('api_input_field_list', + node.get('properties', {}).get( + 'input_field_list', [])) + if input_field_list is not None: + for field in input_field_list: + if field['assignment_method'] == 'api_input' and field['variable'] in params: + query += f"&{field['variable']}={params[field['variable']]}" + + return query + + class AccessTokenSerializer(serializers.Serializer): + application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.boolean("应用id")) + + class AccessTokenEditSerializer(serializers.Serializer): + access_token_reset = serializers.BooleanField(required=False, + error_messages=ErrMessage.boolean("重置Token")) + is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("是否开启")) + access_num = serializers.IntegerField(required=False, max_value=10000, + min_value=0, + error_messages=ErrMessage.integer("访问次数")) + white_active = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("是否开启白名单")) + white_list = serializers.ListSerializer(required=False, child=serializers.CharField(required=True, + error_messages=ErrMessage.char( + "白名单")), + error_messages=ErrMessage.list("白名单列表")), + show_source = serializers.BooleanField(required=False, + error_messages=ErrMessage.boolean("是否显示知识来源")) + + def edit(self, instance: Dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + ApplicationSerializer.AccessTokenSerializer.AccessTokenEditSerializer(data=instance).is_valid( + raise_exception=True) + + application_access_token = QuerySet(ApplicationAccessToken).get( + application_id=self.data.get('application_id')) + if 'is_active' in instance: + application_access_token.is_active = instance.get("is_active") + if 'access_token_reset' in instance and instance.get('access_token_reset'): + del_application_access_token(application_access_token.access_token) + application_access_token.access_token = 
hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24] + if 'access_num' in instance and instance.get('access_num') is not None: + application_access_token.access_num = instance.get("access_num") + if 'white_active' in instance and instance.get('white_active') is not None: + application_access_token.white_active = instance.get("white_active") + if 'white_list' in instance and instance.get('white_list') is not None: + application_access_token.white_list = instance.get('white_list') + if 'show_source' in instance and instance.get('show_source') is not None: + application_access_token.show_source = instance.get('show_source') + application_access_token.save() + application_setting_model = DBModelManage.get_model('application_setting') + xpack_cache = DBModelManage.get_model('xpack_cache') + X_PACK_LICENSE_IS_VALID = False if xpack_cache is None else xpack_cache.get("XPACK_LICENSE_IS_VALID", False) + if application_setting_model is not None and X_PACK_LICENSE_IS_VALID: + application_setting, _ = application_setting_model.objects.get_or_create( + application_id=self.data.get('application_id')) + if application_setting is not None and instance.get('authentication') is not None and instance.get( + 'authentication_value') is not None: + application_setting.authentication = instance.get('authentication') + application_setting.authentication_value = { + "type": "password", + "value": instance.get('authentication_value') + } + application_setting.save() + + get_application_access_token(application_access_token.access_token, False) + return self.one(with_valid=False) + + def one(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + application_id = self.data.get("application_id") + application_access_token = QuerySet(ApplicationAccessToken).filter( + application_id=application_id).first() + if application_access_token is None: + application_access_token = ApplicationAccessToken(application_id=application_id, + access_token=hashlib.md5( + 
str(uuid.uuid1()).encode()).hexdigest()[ + 8:24], is_active=True) + application_access_token.save() + return {'application_id': application_access_token.application_id, + 'access_token': application_access_token.access_token, + "is_active": application_access_token.is_active, + 'access_num': application_access_token.access_num, + 'white_active': application_access_token.white_active, + 'white_list': application_access_token.white_list, + 'show_source': application_access_token.show_source + } + + class Authentication(serializers.Serializer): + access_token = serializers.CharField(required=True, error_messages=ErrMessage.char("access_token")) + authentication_value = serializers.JSONField(required=False, allow_null=True, + error_messages=ErrMessage.char("认证信息")) + + def auth(self, request, with_valid=True): + token = request.META.get('HTTP_AUTHORIZATION') + token_details = None + try: + # 校验token + if token is not None: + token_details = signing.loads(token) + except Exception as e: + token = None + if with_valid: + self.is_valid(raise_exception=True) + access_token = self.data.get("access_token") + application_access_token = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first() + authentication_value = self.data.get('authentication_value', None) + authentication = {} + if application_access_token is not None and application_access_token.is_active: + if token_details is not None and 'client_id' in token_details and token_details.get( + 'client_id') is not None: + client_id = token_details.get('client_id') + authentication = token_details.get('authentication', {}) + else: + client_id = str(uuid.uuid1()) + if authentication_value is not None: + # 认证用户token + self.auth_authentication_value(authentication_value, str(application_access_token.application_id)) + authentication = {'type': authentication_value.get('type'), + 'value': password_encrypt(authentication_value.get('value'))} + token = signing.dumps({'application_id': 
str(application_access_token.application_id), + 'user_id': str(application_access_token.application.user.id), + 'access_token': application_access_token.access_token, + 'type': AuthenticationType.APPLICATION_ACCESS_TOKEN.value, + 'client_id': client_id, + 'authentication': authentication}) + return token + else: + raise NotFound404(404, "无效的access_token") + + def auth_authentication_value(self, authentication_value, application_id): + application_setting_model = DBModelManage.get_model('application_setting') + xpack_cache = DBModelManage.get_model('xpack_cache') + X_PACK_LICENSE_IS_VALID = False if xpack_cache is None else xpack_cache.get('XPACK_LICENSE_IS_VALID', False) + if application_setting_model is not None and X_PACK_LICENSE_IS_VALID: + application_setting = QuerySet(application_setting_model).filter(application_id=application_id).first() + if application_setting.authentication and authentication_value is not None: + if authentication_value.get('type') == 'password': + if not self.auth_password(authentication_value, application_setting.authentication_value): + raise AppApiException(1005, "密码错误") + return True + + @staticmethod + def auth_password(source_authentication_value, authentication_value): + return source_authentication_value.get('value') == authentication_value.get('value') + + class Edit(serializers.Serializer): + name = serializers.CharField(required=False, max_length=64, min_length=1, + error_messages=ErrMessage.char("应用名称")) + desc = serializers.CharField(required=False, max_length=256, min_length=1, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("应用描述")) + model_id = serializers.CharField(required=False, allow_blank=True, allow_null=True, + error_messages=ErrMessage.char("模型")) + dialogue_number = serializers.IntegerField(required=False, + min_value=0, + max_value=1024, + error_messages=ErrMessage.integer("历史聊天记录")) + prologue = serializers.CharField(required=False, allow_null=True, allow_blank=True, max_length=4096, + 
error_messages=ErrMessage.char("开场白")) + dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.list("关联知识库") + ) + # 数据集相关设置 + dataset_setting = DatasetSettingSerializer(required=False, allow_null=True, + error_messages=ErrMessage.json("数据集设置")) + # 模型相关设置 + model_setting = ModelSettingSerializer(required=False, allow_null=True, + error_messages=ErrMessage.json("模型设置")) + # 问题补全 + problem_optimization = serializers.BooleanField(required=False, allow_null=True, + error_messages=ErrMessage.boolean("问题补全")) + icon = serializers.CharField(required=False, allow_null=True, error_messages=ErrMessage.char("icon图标")) + + model_params_setting = serializers.DictField(required=False, error_messages=ErrMessage.dict('模型参数')) + + class Create(serializers.Serializer): + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + + @valid_license(model=Application, count=5, + message='社区版最多支持 5 个应用,如需拥有更多应用,请联系我们(https://fit2cloud.com/)。') + @transaction.atomic + def insert(self, application: Dict): + application_type = application.get('type') + if 'WORK_FLOW' == application_type: + return self.insert_workflow(application) + else: + return self.insert_simple(application) + + def insert_workflow(self, application: Dict): + self.is_valid(raise_exception=True) + user_id = self.data.get('user_id') + ApplicationWorkflowSerializer(data=application).is_valid(raise_exception=True) + application_model = ApplicationWorkflowSerializer.to_application_model(user_id, application) + application_model.save() + # 插入认证信息 + ApplicationAccessToken(application_id=application_model.id, + access_token=hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]).save() + return ApplicationSerializerModel(application_model).data + + def insert_simple(self, application: Dict): + self.is_valid(raise_exception=True) + user_id = self.data.get('user_id') + 
ApplicationSerializer(data=application).is_valid(user_id=user_id, raise_exception=True) + application_model = ApplicationSerializer.Create.to_application_model(user_id, application) + dataset_id_list = application.get('dataset_id_list', []) + application_dataset_mapping_model_list = [ + ApplicationSerializer.Create.to_application_dataset_mapping(application_model.id, dataset_id) for + dataset_id in dataset_id_list] + # 插入应用 + application_model.save() + # 插入认证信息 + ApplicationAccessToken(application_id=application_model.id, + access_token=hashlib.md5(str(uuid.uuid1()).encode()).hexdigest()[8:24]).save() + # 插入关联数据 + QuerySet(ApplicationDatasetMapping).bulk_create(application_dataset_mapping_model_list) + return ApplicationSerializerModel(application_model).data + + @staticmethod + def to_application_model(user_id: str, application: Dict): + return Application(id=uuid.uuid1(), name=application.get('name'), desc=application.get('desc'), + prologue=application.get('prologue'), + dialogue_number=application.get('dialogue_number', 0), + user_id=user_id, model_id=application.get('model_id'), + dataset_setting=application.get('dataset_setting'), + model_setting=application.get('model_setting'), + problem_optimization=application.get('problem_optimization'), + type=ApplicationTypeChoices.SIMPLE, + model_params_setting=application.get('model_params_setting', {}), + problem_optimization_prompt=application.get('problem_optimization_prompt', None), + work_flow={} + ) + + @staticmethod + def to_application_dataset_mapping(application_id: str, dataset_id: str): + return ApplicationDatasetMapping(id=uuid.uuid1(), application_id=application_id, dataset_id=dataset_id) + + class HitTest(serializers.Serializer): + id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("应用id")) + user_id = serializers.UUIDField(required=False, error_messages=ErrMessage.uuid("用户id")) + query_text = serializers.CharField(required=True, error_messages=ErrMessage.char("查询文本")) + top_number 
= serializers.IntegerField(required=True, max_value=100, min_value=1, + error_messages=ErrMessage.integer("topN")) + similarity = serializers.FloatField(required=True, max_value=2, min_value=0, + error_messages=ErrMessage.float("相关度")) + search_mode = serializers.CharField(required=True, validators=[ + validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"), + message="类型只支持register|reset_password", code=500) + ], error_messages=ErrMessage.char("检索模式")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + if not QuerySet(Application).filter(id=self.data.get('id')).exists(): + raise AppApiException(500, '不存在的应用id') + + def hit_test(self): + self.is_valid() + vector = VectorStore.get_embedding_vector() + dataset_id_list = [ad.dataset_id for ad in + QuerySet(ApplicationDatasetMapping).filter( + application_id=self.data.get('id'))] + if len(dataset_id_list) == 0: + return [] + exclude_document_id_list = [str(document.id) for document in + QuerySet(Document).filter( + dataset_id__in=dataset_id_list, + is_active=False)] + model = get_embedding_model_by_dataset_id_list(dataset_id_list) + # 向量库检索 + hit_list = vector.hit_test(self.data.get('query_text'), dataset_id_list, exclude_document_id_list, + self.data.get('top_number'), + self.data.get('similarity'), + SearchMode(self.data.get('search_mode')), + model) + hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) + p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) + return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'), + 'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')} for p in p_list] + + class Query(serializers.Serializer): + name = serializers.CharField(required=False, error_messages=ErrMessage.char("应用名称")) + + desc = serializers.CharField(required=False, error_messages=ErrMessage.char("应用描述")) + + user_id = serializers.UUIDField(required=True, 
error_messages=ErrMessage.uuid("用户id")) + + def get_query_set(self): + user_id = self.data.get("user_id") + query_set_dict = {} + query_set = QuerySet(model=get_dynamics_model( + {'temp_application.name': models.CharField(), 'temp_application.desc': models.CharField(), + 'temp_application.create_time': models.DateTimeField()})) + if "desc" in self.data and self.data.get('desc') is not None: + query_set = query_set.filter(**{'temp_application.desc__icontains': self.data.get("desc")}) + if "name" in self.data and self.data.get('name') is not None: + query_set = query_set.filter(**{'temp_application.name__icontains': self.data.get("name")}) + query_set = query_set.order_by("-temp_application.create_time") + query_set_dict['default_sql'] = query_set + + query_set_dict['application_custom_sql'] = QuerySet(model=get_dynamics_model( + {'application.user_id': models.CharField(), + })).filter( + **{'application.user_id': user_id} + ) + + query_set_dict['team_member_permission_custom_sql'] = QuerySet(model=get_dynamics_model( + {'user_id': models.CharField(), + 'team_member_permission.auth_target_type': models.CharField(), + 'team_member_permission.operate': ArrayField(verbose_name="权限操作列表", + base_field=models.CharField(max_length=256, + blank=True, + choices=AuthOperate.choices, + default=AuthOperate.USE) + )})).filter( + **{'user_id': user_id, 'team_member_permission.operate__contains': ['USE'], + 'team_member_permission.auth_target_type': 'APPLICATION'}) + + return query_set_dict + + def list(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + return [ApplicationSerializer.Query.reset_application(a) for a in + native_search(self.get_query_set(), select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application.sql')))] + + @staticmethod + def reset_application(application: Dict): + application['multiple_rounds_dialogue'] = True if application.get('dialogue_number') > 0 else False + + if 
'dataset_setting' in application: + application['dataset_setting'] = {'search_mode': 'embedding', 'no_references_setting': { + 'status': 'ai_questioning', + 'value': '{question}'}, **application['dataset_setting']} + return application + + def page(self, current_page: int, page_size: int, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application.sql')), + post_records_handler=ApplicationSerializer.Query.reset_application) + + class ApplicationModel(serializers.ModelSerializer): + class Meta: + model = Application + fields = ['id', 'name', 'desc', 'prologue', 'dialogue_number', 'icon', 'type'] + + class IconOperate(serializers.Serializer): + application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id")) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + image = UploadedImageField(required=True, error_messages=ErrMessage.image("图片")) + + def edit(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + application = QuerySet(Application).filter(id=self.data.get('application_id')).first() + if application is None: + raise AppApiException(500, '不存在的应用id') + image_id = uuid.uuid1() + image = Image(id=image_id, image=self.data.get('image').read(), image_name=self.data.get('image').name) + image.save() + application.icon = f'/api/image/{image_id}' + application.save() + application_access_token = QuerySet(ApplicationAccessToken).filter( + application_id=self.data.get('application_id')).first() + get_application_access_token(application_access_token.access_token, False) + return {**ApplicationSerializer.Query.reset_application(ApplicationSerializerModel(application).data)} + + class Operate(serializers.Serializer): + application_id = serializers.UUIDField(required=True, 
error_messages=ErrMessage.uuid("应用id")) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + if not QuerySet(Application).filter(id=self.data.get('application_id')).exists(): + raise AppApiException(500, '不存在的应用id') + + def list_model(self, model_type=None, with_valid=True): + if with_valid: + self.is_valid() + if model_type is None: + model_type = "LLM" + application = QuerySet(Application).filter(id=self.data.get("application_id")).first() + return ModelSerializer.Query( + data={'user_id': application.user_id, 'model_type': model_type}).list( + with_valid=True) + + def list_function_lib(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + application = QuerySet(Application).filter(id=self.data.get("application_id")).first() + return FunctionLibSerializer.Query(data={'user_id': application.user_id, 'is_active': True}).list( + with_valid=True) + + def get_function_lib(self, function_lib_id, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + application = QuerySet(Application).filter(id=self.data.get("application_id")).first() + return FunctionLibSerializer.Operate(data={'user_id': application.user_id, 'id': function_lib_id}).one( + with_valid=True) + + def get_model_params_form(self, model_id, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + application = QuerySet(Application).filter(id=self.data.get("application_id")).first() + return ModelSerializer.ModelParams( + data={'user_id': application.user_id, 'id': model_id}).get_model_params(with_valid=True) + + def delete(self, with_valid=True): + if with_valid: + self.is_valid() + QuerySet(Application).filter(id=self.data.get('application_id')).delete() + return True + + @transaction.atomic + def publish(self, instance, with_valid=True): + if with_valid: + self.is_valid() + user_id = self.data.get('user_id') + user 
= QuerySet(User).filter(id=user_id).first() + application = QuerySet(Application).filter(id=self.data.get("application_id")).first() + work_flow = instance.get('work_flow') + if work_flow is None: + raise AppApiException(500, "work_flow是必填字段") + Flow.new_instance(work_flow).is_valid() + base_node = get_base_node_work_flow(work_flow) + if base_node is not None: + node_data = base_node.get('properties').get('node_data') + if node_data is not None: + application.name = node_data.get('name') + application.desc = node_data.get('desc') + application.prologue = node_data.get('prologue') + dataset_list = self.list_dataset(with_valid=False) + application_dataset_id_list = [str(dataset.get('id')) for dataset in dataset_list] + dataset_id_list = self.update_reverse_search_node(work_flow, application_dataset_id_list) + application.work_flow = work_flow + application.save() + # 插入知识库关联关系 + self.save_application_mapping(application_dataset_id_list, dataset_id_list, application.id) + chat_cache.clear_by_application_id(str(application.id)) + work_flow_version = WorkFlowVersion(work_flow=work_flow, application=application, + name=datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'), + publish_user_id=user_id, + publish_user_name=user.username) + chat_cache.clear_by_application_id(str(application.id)) + work_flow_version.save() + return True + + def one(self, with_valid=True): + if with_valid: + self.is_valid() + application_id = self.data.get("application_id") + application = QuerySet(Application).get(id=application_id) + dataset_list = self.list_dataset(with_valid=False) + mapping_dataset_id_list = [adm.dataset_id for adm in + QuerySet(ApplicationDatasetMapping).filter(application_id=application_id)] + dataset_id_list = [d.get('id') for d in + list(filter(lambda row: mapping_dataset_id_list.__contains__(row.get('id')), + dataset_list))] + self.update_search_node(application.work_flow, [str(dataset.get('id')) for dataset in dataset_list]) + return 
def get_search_node(self, work_flow):
    """Return the workflow nodes whose ``type`` is ``'search-dataset-node'``.

    :param work_flow: workflow dict with an optional ``nodes`` list, or None.
    :return: list of node dicts; empty when ``work_flow`` is None or has no nodes.
    """
    if work_flow is None:
        return []
    return [node for node in work_flow.get('nodes', []) if node.get('type', '') == 'search-dataset-node']


def update_search_node(self, work_flow, user_dataset_id_list: List):
    """Hide dataset ids the user cannot access from every search node (in place).

    For each search node's ``node_data`` the full list is kept under
    ``source_dataset_id_list`` and ``dataset_id_list`` is narrowed to the ids that
    appear in ``user_dataset_id_list``.
    """
    for search_node in self.get_search_node(work_flow):
        node_data = search_node.get('properties', {}).get('node_data', {})
        dataset_id_list = node_data.get('dataset_id_list', [])
        node_data['source_dataset_id_list'] = dataset_id_list
        node_data['dataset_id_list'] = [dataset_id for dataset_id in dataset_id_list
                                        if dataset_id in user_dataset_id_list]


def update_reverse_search_node(self, work_flow, user_dataset_id_list: List):
    """Inverse of ``update_search_node``: restore hidden dataset ids (in place).

    Every id in a node's ``dataset_id_list`` must be in ``user_dataset_id_list``. Ids
    from ``source_dataset_id_list`` that are not in ``user_dataset_id_list`` are merged
    back in, and ``source_dataset_id_list`` is reset to an empty list.

    :return: de-duplicated list of all dataset ids referenced by the search nodes
        (order is unspecified because sets are used for de-duplication).
    :raises AppApiException: 500 when a node lists a dataset id outside
        ``user_dataset_id_list``.
    """
    result_dataset_ids = set()
    for search_node in self.get_search_node(work_flow):
        node_data = search_node.get('properties', {}).get('node_data', {})
        dataset_id_list = node_data.get('dataset_id_list', [])
        for dataset_id in dataset_id_list:
            if dataset_id not in user_dataset_id_list:
                # The original message carried a stray "$" left over from JS template syntax.
                raise AppApiException(500, f"未知的知识库id{dataset_id},无法关联")

        hidden_dataset_ids = [source_dataset_id for source_dataset_id in
                              node_data.get('source_dataset_id_list', [])
                              if source_dataset_id not in user_dataset_id_list]
        merged_dataset_ids = list({*hidden_dataset_ids, *dataset_id_list})
        node_data['source_dataset_id_list'] = []
        node_data['dataset_id_list'] = merged_dataset_ids
        result_dataset_ids.update(merged_dataset_ids)
    return list(result_dataset_ids)
application_access_token = QuerySet(ApplicationAccessToken).filter(application_id=application.id).first() + if application_access_token is None: + raise AppUnauthorizedFailed(500, "非法用户") + application_setting_model = DBModelManage.get_model('application_setting') + xpack_cache = DBModelManage.get_model('xpack_cache') + X_PACK_LICENSE_IS_VALID = False if xpack_cache is None else xpack_cache.get('XPACK_LICENSE_IS_VALID', False) + application_setting_dict = {} + if application_setting_model is not None and X_PACK_LICENSE_IS_VALID: + application_setting = QuerySet(application_setting_model).filter( + application_id=application_access_token.application_id).first() + if application_setting is not None: + custom_theme = getattr(application_setting, 'custom_theme', {}) + float_location = getattr(application_setting, 'float_location', {}) + if not custom_theme: + application_setting.custom_theme = { + 'theme_color': '', + 'header_font_color': '' + } + if not float_location: + application_setting.float_location = { + 'x': {'type': '', 'value': ''}, + 'y': {'type': '', 'value': ''} + } + application_setting_dict = {'show_source': application_access_token.show_source, + 'show_history': application_setting.show_history, + 'draggable': application_setting.draggable, + 'show_guide': application_setting.show_guide, + 'avatar': application_setting.avatar, + 'float_icon': application_setting.float_icon, + 'authentication': application_setting.authentication, + 'authentication_type': application_setting.authentication_value.get( + 'type', 'password'), + 'disclaimer': application_setting.disclaimer, + 'disclaimer_value': application_setting.disclaimer_value, + 'custom_theme': application_setting.custom_theme, + 'user_avatar': application_setting.user_avatar, + 'float_location': application_setting.float_location} + return ApplicationSerializer.Query.reset_application( + {**ApplicationSerializer.ApplicationModel(application).data, + 'stt_model_id': application.stt_model_id, + 
@transaction.atomic
def edit(self, instance: Dict, with_valid=True):
    """Update the application from ``instance`` and return its refreshed detail.

    Steps: validate the payload, check that each referenced model (LLM, STT, TTS)
    exists and is usable by the application owner, restore hidden dataset ids in
    the workflow, copy the whitelisted non-None fields onto the application, rebuild
    the dataset mappings when ``dataset_id_list`` is supplied, then refresh caches.

    :param instance: partial application payload.
    :return: the result of ``self.one(with_valid=False)``.
    :raises AppApiException: 500 for an unknown model, a model the owner may not use,
        or a dataset id outside the user's editable datasets.
    """
    if with_valid:
        self.is_valid()
        ApplicationSerializer.Edit(data=instance).is_valid(
            raise_exception=True)
    application_id = self.data.get("application_id")

    application = QuerySet(Application).get(id=application_id)

    # The same existence/permission check applies to all three model references.
    for model_key in ('model_id', 'stt_model_id', 'tts_model_id'):
        model_id = instance.get(model_key)
        if model_id is None or len(model_id) == 0:
            setattr(application, model_key, None)
            continue
        model = QuerySet(Model).filter(id=model_id).first()
        if model is None:
            raise AppApiException(500, "模型不存在")
        if not model.is_permission(application.user_id):
            raise AppApiException(500, f"沒有权限使用该模型:{model.name}")

    # A null work_flow used to crash inside get_work_flow_model; treat it as "not supplied".
    if instance.get('work_flow') is not None:
        # Datasets the current user is allowed to associate.
        application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in
                                       self.list_dataset(with_valid=False)]
        self.update_reverse_search_node(instance.get('work_flow'), application_dataset_id_list)
        # Lift the speech settings out of the workflow's base node.
        self.get_work_flow_model(instance)

    update_keys = ['name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'prologue', 'status',
                   'dataset_setting', 'model_setting', 'problem_optimization', 'dialogue_number',
                   'stt_model_id', 'tts_model_id', 'tts_model_enable', 'stt_model_enable', 'tts_type',
                   'api_key_is_active', 'icon', 'work_flow', 'model_params_setting', 'tts_model_params_setting',
                   'problem_optimization_prompt', 'clean_time']
    for update_key in update_keys:
        if update_key in instance and instance.get(update_key) is not None:
            setattr(application, update_key, instance.get(update_key))
    application.save()

    if 'dataset_id_list' in instance:
        dataset_id_list = instance.get('dataset_id_list')
        # Datasets the current user is allowed to associate.
        application_dataset_id_list = [str(dataset_dict.get('id')) for dataset_dict in
                                       self.list_dataset(with_valid=False)]
        for dataset_id in dataset_id_list:
            if dataset_id not in application_dataset_id_list:
                # The original message carried a stray "$" left over from JS template syntax.
                raise AppApiException(500, f"未知的知识库id{dataset_id},无法关联")

        self.save_application_mapping(application_dataset_id_list, dataset_id_list, application_id)
    if application.type == ApplicationTypeChoices.SIMPLE:
        chat_cache.clear_by_application_id(application_id)
    application_access_token = QuerySet(ApplicationAccessToken).filter(application_id=application_id).first()
    if application_access_token is not None:
        # Refresh the cached access-token entry; skipped when the application has no token row.
        get_application_access_token(application_access_token.access_token, False)
    return self.one(with_valid=False)
def list_dataset(self, with_valid=True):
    """Run ``list_application_dataset.sql`` for this application and user.

    The first SQL parameter is the requesting user id only when that user owns the
    application, otherwise None; the other two are the owner id and the user id.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    application = QuerySet(Application).get(id=self.data.get("application_id"))
    user_id = self.data.get('user_id')
    sql = get_file_content(
        os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_dataset.sql'))
    return select_list(sql, [user_id if user_id == str(application.user_id) else None,
                             application.user_id, user_id])


@staticmethod
def get_work_flow_model(instance):
    """Copy the speech settings of the workflow's ``base-node`` onto ``instance`` (in place).

    Only keys present in the node's ``node_data`` are copied. Nothing happens when
    ``instance`` has no workflow, the workflow has no ``nodes``, or there is no
    ``base-node``; a missing or null ``work_flow`` used to raise TypeError.
    """
    work_flow = instance.get('work_flow')
    if not work_flow or 'nodes' not in work_flow:
        return
    speech_keys = ('stt_model_id', 'tts_model_id', 'stt_model_enable', 'tts_model_enable',
                   'tts_type', 'tts_model_params_setting')
    for node in work_flow['nodes'] or []:
        if node.get('id') == 'base-node':
            node_data = (node.get('properties') or {}).get('node_data') or {}
            for key in speech_keys:
                if key in node_data:
                    instance[key] = node_data[key]
            break


def speech_to_text(self, file, with_valid=True):
    """Transcribe ``file`` with the application's STT model.

    :return: the transcribed text, or None when STT is disabled for the application
        (the original raised UnboundLocalError in that case).
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    application_id = self.data.get('application_id')
    application = QuerySet(Application).filter(id=application_id).first()
    if not application.stt_model_enable:
        return None
    model = get_model_instance_by_model_user_id(application.stt_model_id, application.user_id)
    return model.speech_to_text(file)


def text_to_speech(self, text, with_valid=True):
    """Synthesize ``text`` with the application's TTS model and stored TTS parameters.

    :return: whatever the model's ``text_to_speech`` returns, or None when TTS is
        disabled for the application (the original raised UnboundLocalError).
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    application_id = self.data.get('application_id')
    application = QuerySet(Application).filter(id=application_id).first()
    if not application.tts_model_enable:
        return None
    # The stored parameters may be unset; ``**None`` would raise TypeError.
    model = get_model_instance_by_model_user_id(application.tts_model_id, application.user_id,
                                                **(application.tts_model_params_setting or {}))
    return model.text_to_speech(text)


def play_demo_text(self, form_data, with_valid=True):
    """Synthesize a fixed demo sentence with the TTS model named in ``form_data``.

    ``form_data['tts_model_id']`` selects the model; the remaining entries are passed
    to the model factory as keyword arguments. ``form_data`` itself is not modified.

    :return: the synthesized result, or None when ``tts_model_id`` is absent
        (the original raised UnboundLocalError).
    """
    text = '你好,这里是语音播放测试'
    if with_valid:
        self.is_valid(raise_exception=True)
    application_id = self.data.get('application_id')
    application = QuerySet(Application).filter(id=application_id).first()
    if 'tts_model_id' not in form_data:
        return None
    model_params = dict(form_data)  # copy so the caller's dict keeps its tts_model_id
    tts_model_id = model_params.pop('tts_model_id')
    model = get_model_instance_by_model_user_id(tts_model_id, application.user_id, **model_params)
    return model.text_to_speech(text)
def delete(self, with_valid=True):
    """Delete one API key of the application and evict it from the key cache.

    :raises AppApiException: 500 when no key with ``api_key_id`` belongs to the
        application (the same error the sibling ``edit`` raises; the original
        dereferenced None here).
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    api_key_id = self.data.get("api_key_id")
    application_id = self.data.get('application_id')
    application_api_key = QuerySet(ApplicationApiKey).filter(id=api_key_id,
                                                             application_id=application_id).first()
    if application_api_key is None:
        raise AppApiException(500, '不存在')
    # Evict the cached entry before removing the row.
    del_application_api_key(application_api_key.secret_key)
    application_api_key.delete()
class ApplicationStatisticsSerializer(serializers.Serializer):
    """Usage statistics of one application between ``start_time`` and ``end_time``.

    Both bounds are dates; queries run from the start date up to the last instant
    of the end date.
    """
    # UUID field, so it uses the uuid error messages like the other serializers' id fields.
    application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
    start_time = serializers.DateField(format='%Y-%m-%d', error_messages=ErrMessage.date("开始时间"))
    end_time = serializers.DateField(format='%Y-%m-%d', error_messages=ErrMessage.date("结束时间"))

    @staticmethod
    def _load_sql(file_name):
        """Read a SQL template from ``apps/application/sql``."""
        return get_file_content(os.path.join(PROJECT_DIR, "apps", "application", 'sql', file_name))

    def _customer_query_set(self):
        """Public access clients of the application created inside the date range."""
        return QuerySet(ApplicationPublicAccessClient).filter(
            application_id=self.data.get('application_id'),
            create_time__gte=self.get_start_time(),
            create_time__lte=self.get_end_time())

    def _chat_record_query_set(self):
        """Filter on the joined chat / chat-record columns used by the chat-record SQL."""
        return QuerySet(model=get_dynamics_model(
            {'application_chat.application_id': models.UUIDField(),
             'application_chat_record.create_time': models.DateTimeField()})).filter(
            **{'application_chat.application_id': self.data.get('application_id'),
               'application_chat_record.create_time__gte': self.get_start_time(),
               'application_chat_record.create_time__lte': self.get_end_time()})

    def get_end_time(self):
        """Return the end date combined with the latest time of day."""
        return datetime.datetime.combine(
            datetime.datetime.strptime(self.data.get('end_time'), '%Y-%m-%d'),
            datetime.datetime.max.time())

    def get_start_time(self):
        """Return the start date as stored in the serializer data."""
        return self.data.get('start_time')

    def get_customer_count(self, with_valid=True):
        """Run ``customer_count.sql`` over the range and return its single row."""
        if with_valid:
            self.is_valid(raise_exception=True)
        return native_search(self._customer_query_set(),
                             select_string=self._load_sql('customer_count.sql'),
                             with_search_one=True)

    def get_customer_count_trend(self, with_valid=True):
        """Run ``customer_count_trend.sql`` over the range and return its rows."""
        if with_valid:
            self.is_valid(raise_exception=True)
        return native_search({'default_sql': self._customer_query_set()},
                             select_string=self._load_sql('customer_count_trend.sql'))

    def get_chat_record_aggregate(self, with_valid=True):
        """Return the ``chat_record_count.sql`` row merged with the customer-count row."""
        if with_valid:
            self.is_valid(raise_exception=True)
        chat_record_aggregate = native_search(self._chat_record_query_set(),
                                              select_string=self._load_sql('chat_record_count.sql'),
                                              with_search_one=True)
        customer = self.get_customer_count(with_valid=False)
        return {**chat_record_aggregate, **customer}

    def get_chat_record_aggregate_trend(self, with_valid=True):
        """Return one merged chat-record / customer row for every day of the range."""
        if with_valid:
            self.is_valid(raise_exception=True)
        chat_record_aggregate_trend = native_search(
            {'default_sql': self._chat_record_query_set()},
            select_string=self._load_sql('chat_record_count_trend.sql'))
        customer_count_trend = self.get_customer_count_trend(with_valid=False)
        return self.merge_customer_chat_record(chat_record_aggregate_trend, customer_count_trend)

    @staticmethod
    def _index_by_day(rows):
        """Map ``'%Y-%m-%d'`` -> first row for that day (``row['day']`` must support strftime)."""
        index = {}
        for row in rows:
            index.setdefault(row.get('day').strftime('%Y-%m-%d'), row)
        return index

    def merge_customer_chat_record(self, chat_record_aggregate_trend: List[Dict], customer_count_trend: List[Dict]):
        """Join both trend lists by day, filling days without data with zero rows.

        Rows are indexed by day once instead of rescanning both lists for every day.

        :return: one dict per day between ``start_time`` and ``end_time`` inclusive.
        """
        chat_by_day = self._index_by_day(chat_record_aggregate_trend)
        customer_by_day = self._index_by_day(customer_count_trend)
        return [{**chat_by_day.get(day, {'star_num': 0, 'trample_num': 0, 'tokens_num': 0,
                                         'chat_record_count': 0, 'customer_num': 0, 'day': day}),
                 **customer_by_day.get(day, {'customer_added_count': 0})}
                for day in self.get_days_between_dates(self.data.get('start_time'), self.data.get('end_time'))]

    @staticmethod
    def find(source_list, condition, default):
        """Return the first row of ``source_list`` satisfying ``condition``, else ``default``."""
        return next((row for row in source_list if condition(row)), default)

    @staticmethod
    def get_days_between_dates(start_date, end_date):
        """Return every ``'%Y-%m-%d'`` day from ``start_date`` to ``end_date`` inclusive.

        Both arguments are ``'%Y-%m-%d'`` strings; the result is empty when the start
        is after the end.
        """
        current_date = datetime.datetime.strptime(start_date, '%Y-%m-%d')
        last_date = datetime.datetime.strptime(end_date, '%Y-%m-%d')
        days = []
        while current_date <= last_date:
            days.append(current_date.strftime('%Y-%m-%d'))
            current_date += datetime.timedelta(days=1)
        return days
mode 100644 index 0000000..1fc701d --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/serializers/application_version_serializers.py @@ -0,0 +1,84 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: application_version_serializers.py + @date:2024/10/15 16:42 + @desc: +""" +from typing import Dict + +from django.db.models import QuerySet +from rest_framework import serializers + +from application.models import WorkFlowVersion +from common.db.search import page_search +from common.exception.app_exception import AppApiException +from common.util.field_message import ErrMessage + + +class ApplicationVersionModelSerializer(serializers.ModelSerializer): + class Meta: + model = WorkFlowVersion + fields = ['id', 'name', 'application_id', 'work_flow', 'publish_user_id', 'publish_user_name', 'create_time', + 'update_time'] + + +class ApplicationVersionEditSerializer(serializers.Serializer): + name = serializers.CharField(required=False, max_length=128, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("版本名称")) + + +class ApplicationVersionSerializer(serializers.Serializer): + class Query(serializers.Serializer): + application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("应用id")) + name = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("摘要")) + + def get_query_set(self): + query_set = QuerySet(WorkFlowVersion).filter(application_id=self.data.get('application_id')) + if 'name' in self.data and self.data.get('name') is not None: + query_set = query_set.filter(name__contains=self.data.get('name')) + return query_set.order_by("-create_time") + + def list(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + query_set = self.get_query_set() + return [ApplicationVersionModelSerializer(v).data for v in query_set] + + def page(self, current_page, page_size, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + return 
page_search(current_page, page_size, + self.get_query_set(), + post_records_handler=lambda v: ApplicationVersionModelSerializer(v).data) + + class Operate(serializers.Serializer): + application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("应用id")) + work_flow_version_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("工作流版本id")) + + def one(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=self.data.get('application_id'), + id=self.data.get('work_flow_version_id')).first() + if work_flow_version is not None: + return ApplicationVersionModelSerializer(work_flow_version).data + else: + raise AppApiException(500, '不存在的工作流版本') + + def edit(self, instance: Dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + ApplicationVersionEditSerializer(data=instance).is_valid(raise_exception=True) + work_flow_version = QuerySet(WorkFlowVersion).filter(application_id=self.data.get('application_id'), + id=self.data.get('work_flow_version_id')).first() + if work_flow_version is not None: + name = instance.get('name', None) + if name is not None and len(name) > 0: + work_flow_version.name = name + work_flow_version.save() + return ApplicationVersionModelSerializer(work_flow_version).data + else: + raise AppApiException(500, '不存在的工作流版本') diff --git a/src/MaxKB-1.7.2/apps/application/serializers/chat_message_serializers.py b/src/MaxKB-1.7.2/apps/application/serializers/chat_message_serializers.py new file mode 100644 index 0000000..488c244 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/serializers/chat_message_serializers.py @@ -0,0 +1,377 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: chat_message_serializers.py + @date:2023/11/14 13:51 + @desc: +""" +import uuid +from typing import List, Dict +from uuid import UUID + +from django.core.cache import caches +from django.db.models import 
class ChatInfo:
    """In-memory state of one conversation, cached between requests."""

    def __init__(self,
                 chat_id: str,
                 dataset_id_list: List[str],
                 exclude_document_id_list: list[str],
                 application: Application,
                 work_flow_version: WorkFlowVersion = None):
        """
        :param chat_id: conversation id
        :param dataset_id_list: ids of the datasets searched for this conversation
        :param exclude_document_id_list: ids of documents excluded from search
        :param application: the application being chatted with
        :param work_flow_version: published workflow version, when one is used
        """
        self.chat_id = chat_id
        self.application = application
        self.dataset_id_list = dataset_id_list
        self.exclude_document_id_list = exclude_document_id_list
        self.chat_record_list: List[ChatRecord] = []
        self.work_flow_version = work_flow_version

    @staticmethod
    def get_no_references_setting(dataset_setting, model_setting):
        """Return the "no references found" setting with its prompt resolved.

        When the status is ``'ai_questioning'`` the value is taken from
        ``model_setting['no_references_prompt']``, falling back to ``'{question}'``
        when that prompt is missing or empty. A copy is returned, so the stored
        ``dataset_setting`` is no longer modified as a side effect.
        """
        no_references_setting = dict(dataset_setting.get(
            'no_references_setting', {
                'status': 'ai_questioning',
                'value': '{question}'}))
        if no_references_setting.get('status') == 'ai_questioning':
            no_references_prompt = model_setting.get('no_references_prompt', '{question}')
            no_references_setting['value'] = no_references_prompt if len(no_references_prompt) > 0 else "{question}"
        return no_references_setting

    def to_base_pipeline_manage_params(self):
        """Build the pipeline parameters that depend only on the application settings."""
        application = self.application
        dataset_setting = application.dataset_setting
        model_setting = application.model_setting
        model_id = application.model.id if application.model is not None else None
        model_params_setting = None
        if model_id is not None:
            model = QuerySet(Model).filter(id=model_id).first()
            if model is not None:  # the row may have been removed since the application was loaded
                credential = get_model_credential(model.provider, model.model_type, model.model_name)
                model_params_setting = credential.get_model_params_setting_form(
                    model.model_name).get_default_form_data()

        problem_optimization_prompt = application.problem_optimization_prompt
        if problem_optimization_prompt is None or len(problem_optimization_prompt) == 0:
            problem_optimization_prompt = '()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中'

        if 'prompt' in model_setting and len(model_setting.get('prompt')) > 0:
            prompt = model_setting.get('prompt')
        else:
            prompt = Application.get_default_model_prompt()

        # The application's own parameter settings win over the model's form defaults.
        if application.model_params_setting is not None and len(application.model_params_setting.keys()) > 0:
            model_params_setting = application.model_params_setting

        return {
            'dataset_id_list': self.dataset_id_list,
            'exclude_document_id_list': self.exclude_document_id_list,
            'exclude_paragraph_id_list': [],
            'top_n': dataset_setting.get('top_n', 3),
            'similarity': dataset_setting.get('similarity', 0.6),
            'max_paragraph_char_number': dataset_setting.get('max_paragraph_char_number', 5000),
            'history_chat_record': self.chat_record_list,
            'chat_id': self.chat_id,
            'dialogue_number': application.dialogue_number,
            'problem_optimization_prompt': problem_optimization_prompt,
            'prompt': prompt,
            'system': model_setting.get('system', None),
            'model_id': model_id,
            'problem_optimization': application.problem_optimization,
            'stream': True,
            'model_params_setting': model_params_setting,
            'search_mode': dataset_setting.get('search_mode', 'embedding'),
            'no_references_setting': self.get_no_references_setting(dataset_setting, model_setting),
            'user_id': application.user_id
        }

    def to_pipeline_manage_params(self, problem_text: str, post_response_handler: PostResponseHandler,
                                  exclude_paragraph_id_list, client_id: str, client_type, stream=True):
        """Return the base parameters extended with the per-request values."""
        params = self.to_base_pipeline_manage_params()
        return {**params, 'problem_text': problem_text, 'post_response_handler': post_response_handler,
                'exclude_paragraph_id_list': exclude_paragraph_id_list, 'stream': stream, 'client_id': client_id,
                'client_type': client_type}

    def append_chat_record(self, chat_record: ChatRecord, client_id=None):
        """Record one question/answer pair in memory and, for saved applications, in the database.

        The question is truncated to 10240 characters and the answer to 40960; a None
        value becomes ``""``. The original tested ``problem_text`` when truncating the
        answer, so a None answer raised TypeError.
        """
        chat_record.problem_text = chat_record.problem_text[0:10240] if chat_record.problem_text is not None else ""
        chat_record.answer_text = chat_record.answer_text[0:40960] if chat_record.answer_text is not None else ""
        # Keep the record in the cached conversation.
        self.chat_record_list.append(chat_record)
        if self.application.id is not None:
            # Create the conversation row on its first record.
            if not QuerySet(Chat).filter(id=self.chat_id).exists():
                Chat(id=self.chat_id, application_id=self.application.id, abstract=chat_record.problem_text[0:1024],
                     client_id=client_id).save()
            # Persist the record itself.
            chat_record.save()
def chat(self, instance: Dict, with_valid=True):
    """Handle an OpenAI-style chat request by delegating to ChatMessageSerializer.

    The last entry of ``instance['messages']`` is used as the question, a chat row is
    created when needed through ``generate_chat``, and the response is rendered with
    ``OpenaiToResponse``. ``re_chat`` and ``stream`` default to False.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
        OpenAIInstanceSerializer(data=instance).is_valid(raise_exception=True)
    requested_chat_id = instance.get('chat_id')
    question = self.get_message(instance)
    payload = {
        'message': question,
        're_chat': instance.get('re_chat', False),
        'stream': instance.get('stream', False),
        'application_id': self.data.get('application_id'),
        'client_id': self.data.get('client_id'),
        'client_type': self.data.get('client_type'),
    }
    # Reuses the given chat id or creates a new chat row.
    payload['chat_id'] = self.generate_chat(requested_chat_id, payload['application_id'], question,
                                            payload['client_id'])
    payload['form_data'] = instance.get('form_data', {})
    message_serializer = ChatMessageSerializer(data=payload)
    return message_serializer.chat(base_to_response=OpenaiToResponse())
def is_valid_intraday_access_num(self):
    """Enforce the daily access quota for public access-token clients.

    Nothing is checked unless ``client_type`` equals
    ``AuthenticationType.APPLICATION_ACCESS_TOKEN``. A client seen for the
    first time gets a counter row with both counters at zero.

    :raises ChatException: the application has no access-token record
    :raises AppChatNumOutOfBoundsFailed: the client's ``intraday_access_num``
        has reached the token's ``access_num``
    """
    if self.data.get('client_type') != AuthenticationType.APPLICATION_ACCESS_TOKEN.value:
        return

    client_id = self.data.get('client_id')
    application_id = self.data.get('application_id')

    access_client = QuerySet(ApplicationPublicAccessClient).filter(id=client_id).first()
    if access_client is None:
        access_client = ApplicationPublicAccessClient(id=client_id,
                                                      application_id=application_id,
                                                      access_num=0,
                                                      intraday_access_num=0)
        access_client.save()

    application_access_token = QuerySet(ApplicationAccessToken).filter(
        application_id=application_id).first()
    if application_access_token is None:
        # ``first()`` yields None when no row matches; fail with a domain error
        # instead of an AttributeError on ``.access_num`` below.
        raise ChatException(500, "不存在的应用认证信息")
    if application_access_token.access_num <= access_client.intraday_access_num:
        raise AppChatNumOutOfBoundsFailed(1002, "访问次数超过今日访问量")
def chat_work_flow(self, chat_info: ChatInfo, base_to_response):
    """Run the session's workflow for the validated message.

    :param chat_info: session holding the workflow version, history and application
    :param base_to_response: response adapter handed to ``WorkflowManage``
    :return: the value returned by ``WorkflowManage.run()``
    """
    request = self.data
    message = request.get('message')
    re_chat = request.get('re_chat')
    stream = request.get('stream')
    client_id = request.get('client_id')
    client_type = request.get('client_type')
    form_data = request.get('form_data')
    user_id = chat_info.application.user_id

    flow = Flow.new_instance(chat_info.work_flow_version.work_flow)
    run_params = {
        'history_chat_record': chat_info.chat_record_list,
        'question': message,
        'chat_id': chat_info.chat_id,
        # Every run gets a freshly generated record id.
        'chat_record_id': str(uuid.uuid1()),
        'stream': stream,
        're_chat': re_chat,
        'user_id': user_id,
    }
    post_handler = WorkFlowPostHandler(chat_info, client_id, client_type)
    return WorkflowManage(flow, run_params, post_handler, base_to_response, form_data).run()
@staticmethod
def re_open_chat_work_flow(chat_id, application):
    """Rebuild the ``ChatInfo`` of a workflow application.

    Uses the application's most recently created workflow version and reloads
    up to five of the newest chat records, oldest first.

    :param chat_id: id of the conversation being re-opened
    :param application: the workflow application the conversation belongs to
    :raises ChatException: the application has no workflow version
    :return: the reconstructed ``ChatInfo``
    """
    latest_version = QuerySet(WorkFlowVersion).filter(application_id=application.id).order_by(
        '-create_time')[0:1].first()
    if latest_version is None:
        raise ChatException(500, "应用未发布,请发布后再使用")

    session = ChatInfo(chat_id, [], [], application, latest_version)
    newest_records = QuerySet(ChatRecord).filter(chat_id=chat_id).order_by('-create_time')[0:5]
    # Fetched newest-first to apply the limit, then restored to chronological order.
    for record in sorted(newest_records, key=lambda item: item.create_time):
        session.chat_record_list.append(record)
    return session
def valid_model_params_setting(model_id, model_params_setting):
    """Validate model parameter settings against the model's parameter form.

    Does nothing when ``model_id`` is None. When ``model_params_setting`` is
    None or empty, the form's default data is validated instead.

    :param model_id: id of the model whose parameter form is used, or None
    :param model_params_setting: dict of parameter values, or None
    :raises AppApiException: no model with ``model_id`` exists
    """
    if model_id is None:
        return
    model = QuerySet(Model).filter(id=model_id).first()
    if model is None:
        # ``first()`` returns None for an unknown id; report it instead of
        # failing with an AttributeError on ``model.provider``.
        raise AppApiException(500, "模型不存在")
    credential = get_model_credential(model.provider, model.model_type, model.model_name)
    # Build the form once and use it for both the defaults and the validation.
    model_params_setting_form = credential.get_model_params_setting_form(model.model_name)
    if model_params_setting is None or len(model_params_setting.keys()) == 0:
        model_params_setting = model_params_setting_form.get_default_form_data()
    model_params_setting_form.valid_form(model_params_setting)
class Query(serializers.Serializer):
    """Query parameters for listing, paging and exporting an application's chat log.

    Results are limited to ``application_id`` and to conversations created
    between ``start_time`` and ``end_time``. ``abstract``, ``min_star`` and
    ``min_trample`` narrow the result further; when both vote thresholds are
    given, ``comparer`` ("and" / "or") decides how they are combined.
    """
    abstract = serializers.CharField(required=False, error_messages=ErrMessage.char("摘要"))
    start_time = serializers.DateField(format='%Y-%m-%d', error_messages=ErrMessage.date("开始时间"))
    end_time = serializers.DateField(format='%Y-%m-%d', error_messages=ErrMessage.date("结束时间"))
    user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))
    application_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("应用id"))
    min_star = serializers.IntegerField(required=False, min_value=0,
                                        error_messages=ErrMessage.integer("最小点赞数"))
    min_trample = serializers.IntegerField(required=False, min_value=0,
                                           error_messages=ErrMessage.integer("最小点踩数"))
    # The alternation must be grouped: "^and|or$" parses as "(^and)|(or$)" and
    # would also accept values such as "android" or "xor".
    comparer = serializers.CharField(required=False, error_messages=ErrMessage.char("比较器"), validators=[
        validators.RegexValidator(regex=re.compile("^(and|or)$"),
                                  message="只支持and|or", code=500)
    ])

    def get_end_time(self):
        """Return ``end_time`` as a datetime at the latest time of that day, so the range is inclusive."""
        return datetime.datetime.combine(
            datetime.datetime.strptime(self.data.get('end_time'), '%Y-%m-%d'),
            datetime.datetime.max.time())

    def get_start_time(self):
        """Return ``start_time`` exactly as serialized."""
        return self.data.get('start_time')

    def get_query_set(self, select_ids=None):
        """Build the filter query set that is injected into the chat-log SQL.

        :param select_ids: optional chat ids; when non-empty only these chats are kept
        :return: query set filtered by the request and ordered by creation time, newest first
        """
        end_time = self.get_end_time()
        start_time = self.get_start_time()
        query_set = QuerySet(model=get_dynamics_model(
            {'application_chat.application_id': models.CharField(),
             'application_chat.abstract': models.CharField(),
             "star_num": models.IntegerField(),
             'trample_num': models.IntegerField(),
             'comparer': models.CharField(),
             'application_chat.create_time': models.DateTimeField(),
             'application_chat.id': models.UUIDField(), }))

        base_query_dict = {'application_chat.application_id': self.data.get("application_id"),
                           'application_chat.create_time__gte': start_time,
                           'application_chat.create_time__lte': end_time,
                           }
        if 'abstract' in self.data and self.data.get('abstract') is not None:
            base_query_dict['application_chat.abstract__icontains'] = self.data.get('abstract')

        if select_ids is not None and len(select_ids) > 0:
            base_query_dict['application_chat.id__in'] = select_ids
        base_condition = Q(**base_query_dict)

        min_star_query = None
        min_trample_query = None
        if 'min_star' in self.data and self.data.get('min_star') is not None:
            min_star_query = Q(star_num__gte=self.data.get('min_star'))
        if 'min_trample' in self.data and self.data.get('min_trample') is not None:
            min_trample_query = Q(trample_num__gte=self.data.get('min_trample'))

        if min_star_query is not None and min_trample_query is not None:
            # Both thresholds present: "or" relaxes them, anything else requires both.
            if self.data.get('comparer') == 'or':
                condition = base_condition & (min_star_query | min_trample_query)
            else:
                condition = base_condition & (min_star_query & min_trample_query)
        elif min_star_query is not None:
            condition = base_condition & min_star_query
        elif min_trample_query is not None:
            condition = base_condition & min_trample_query
        else:
            condition = base_condition
        return query_set.filter(condition).order_by("-application_chat.create_time")

    def list(self, with_valid=True):
        """Return every chat matching the query, using ``list_application_chat.sql``."""
        if with_valid:
            self.is_valid(raise_exception=True)
        return native_search(self.get_query_set(), select_string=get_file_content(
            os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_chat.sql')),
                             with_table_name=False)

    @staticmethod
    def paragraph_list_to_string(paragraph_list):
        """Render paragraphs as ``title:\\ncontent`` blocks separated by a line of asterisks.

        :param paragraph_list: list of paragraph dicts, or None
        :return: the joined text; an empty string for None or an empty list
        """
        if paragraph_list is None:
            return ''
        return "\n**********\n".join(
            [f"{paragraph.get('title')}:\n{paragraph.get('content')}" for paragraph in paragraph_list])

    @staticmethod
    def to_row(row: Dict):
        """Convert one exported chat record into the list of spreadsheet cell values.

        :param row: record dict produced by the export query
        :return: cell values in the column order used by ``export``
        """
        details = row.get('details')
        padding_problem_text = ""
        if 'problem_padding' in details and 'padding_problem_text' in details.get('problem_padding'):
            padding_problem_text = details.get('problem_padding').get('padding_problem_text')

        search_dataset_node_list = [(key, node) for key, node in details.items() if
                                    node.get("type") == 'search-dataset-node' or node.get(
                                        "step_type") == 'search_step']

        def paragraphs_of(node):
            # ``paragraph_list`` can be absent or null. Null was only guarded for
            # named nodes before; the 'search_step' branch raised on ``len(None)``.
            return node.get('paragraph_list') or []

        def with_node_name(key, node, separator, text):
            # Entries stored under the 'search_step' key are written without a
            # name prefix; every other search node is prefixed with its name.
            return text if key == 'search_step' else node.get('name') + separator + text

        reference_paragraph_len = '\n'.join(
            [with_node_name(key, node, ':', str(len(paragraphs_of(node))))
             for key, node in search_dataset_node_list])
        reference_paragraph = '\n----------\n'.join(
            [with_node_name(key, node, ':\n',
                            ChatSerializers.Query.paragraph_list_to_string(paragraphs_of(node)))
             for key, node in search_dataset_node_list])

        # Treat a missing annotation list as "no annotations".
        improve_paragraph_list = row.get('improve_paragraph_list') or []
        vote_status_map = {'-1': '未投票', '0': '赞同', '1': '反对'}
        return [str(row.get('chat_id')), row.get('abstract'), row.get('problem_text'), padding_problem_text,
                row.get('answer_text'), vote_status_map.get(row.get('vote_status')), reference_paragraph_len,
                reference_paragraph,
                "\n".join([f"{paragraph.get('title')}\n{paragraph.get('content')}"
                           for paragraph in improve_paragraph_list]),
                row.get('message_tokens') + row.get('answer_tokens'), row.get('run_time'),
                str(row.get('create_time').strftime('%Y-%m-%d %H:%M:%S'))]

    def export(self, data, with_valid=True):
        """Export the matching chat records as an ``.xlsx`` download.

        :param data: request body; ``select_ids`` optionally restricts the exported chats
        :param with_valid: validate the serializer first
        :return: ``StreamingHttpResponse`` carrying the workbook as an attachment
        """
        if with_valid:
            self.is_valid(raise_exception=True)

        data_list = native_search(self.get_query_set(data.get('select_ids')),
                                  select_string=get_file_content(
                                      os.path.join(PROJECT_DIR, "apps", "application", 'sql',
                                                   'export_application_chat.sql')),
                                  with_table_name=False)

        def stream_response():
            workbook = openpyxl.Workbook()
            worksheet = workbook.active
            worksheet.title = 'Sheet1'

            headers = ['会话ID', '摘要', '用户问题', '优化后问题', '回答', '用户反馈', '引用分段数',
                       '分段标题+内容',
                       '标注', '消耗tokens', '耗时(s)', '提问时间']
            for col_idx, header in enumerate(headers, 1):
                worksheet.cell(row=1, column=col_idx).value = header

            # Data rows start at row 2, directly below the header row.
            for row_idx, row in enumerate(data_list, start=2):
                for col_idx, value in enumerate(self.to_row(row), 1):
                    if isinstance(value, str):
                        # Strip the characters openpyxl refuses to write into a cell.
                        value = re.sub(ILLEGAL_CHARACTERS_RE, '', value)
                    worksheet.cell(row=row_idx, column=col_idx).value = value

            # The workbook is serialized in one piece and sent as a single chunk.
            output = BytesIO()
            workbook.save(output)
            output.seek(0)
            yield output.getvalue()
            output.close()
            workbook.close()

        # Registered media type for .xlsx; the previous value contained a stray
        # dot ("vnd.open.xmlformats") and was not a valid spreadsheet type.
        response = StreamingHttpResponse(
            stream_response(),
            content_type='application/vnd.openxmlformats-officedocument.spreadsheetml.sheet')
        response['Content-Disposition'] = 'attachment; filename="data.xlsx"'
        return response

    def page(self, current_page: int, page_size: int, with_valid=True):
        """Return one page of the chats matching the query, using ``list_application_chat.sql``."""
        if with_valid:
            self.is_valid(raise_exception=True)
        return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content(
            os.path.join(PROJECT_DIR, "apps", "application", 'sql', 'list_application_chat.sql')),
                                  with_table_name=False)
def open_simple(self, application):
    """Open a new cached chat session for a simple application.

    Collects the datasets mapped to the application and the inactive documents
    inside them (passed to ``ChatInfo`` as its third argument), then caches a
    fresh ``ChatInfo`` under a newly generated id.

    :param application: the application the session is opened for
    :return: the new chat id as a string
    """
    application_id = self.data.get('application_id')
    dataset_mappings = QuerySet(ApplicationDatasetMapping).filter(application_id=application_id)
    dataset_id_list = [str(mapping.dataset_id) for mapping in dataset_mappings]

    chat_id = str(uuid.uuid1())

    inactive_documents = QuerySet(Document).filter(dataset_id__in=dataset_id_list, is_active=False)
    inactive_document_id_list = [str(document.id) for document in inactive_documents]

    session = ChatInfo(chat_id, dataset_id_list, inactive_document_id_list, application)
    chat_cache.set(chat_id, session, timeout=60 * 30)
    return chat_id
class OpenTempChat(serializers.Serializer):
    """Open a temporary chat session from unsaved application settings.

    The settings arrive in the request; an in-memory ``Application`` (``id=None``)
    is built from them and cached inside a ``ChatInfo``.
    """
    user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id"))

    id = serializers.UUIDField(required=False, allow_null=True,
                               error_messages=ErrMessage.uuid("应用id"))
    model_id = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                     error_messages=ErrMessage.uuid("模型id"))

    multiple_rounds_dialogue = serializers.BooleanField(required=True,
                                                        error_messages=ErrMessage.boolean("多轮会话"))

    dataset_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True),
                                                 error_messages=ErrMessage.list("关联数据集"))
    # Dataset-related settings
    dataset_setting = DatasetSettingSerializer(required=True)
    # Model-related settings
    model_setting = ModelSettingSerializer(required=True)
    # Problem completion (question optimization)
    problem_optimization = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("问题补全"))
    # Model parameter settings
    model_params_setting = serializers.JSONField(required=False, error_messages=ErrMessage.dict("模型参数相关设置"))

    def is_valid(self, *, raise_exception=False):
        """Validate the request and the model/dataset association.

        Field validation always raises on failure, regardless of ``raise_exception``.

        :return: the user id resolved by ``get_user_id``
        """
        super().is_valid(raise_exception=True)
        user_id = self.get_user_id()
        ModelDatasetAssociation(
            data={'user_id': user_id, 'model_id': self.data.get('model_id'),
                  'dataset_id_list': self.data.get('dataset_id_list')}).is_valid()
        return user_id

    def get_user_id(self):
        """Return the owner of application ``id`` when one is given, else the request's ``user_id``.

        :raises AppApiException: ``id`` is given but no such application exists
        """
        if 'id' in self.data and self.data.get('id') is not None:
            application = QuerySet(Application).filter(id=self.data.get('id')).first()
            if application is None:
                raise AppApiException(500, "应用不存在")
            return application.user_id
        return self.data.get('user_id')

    def open(self):
        """Validate the settings, cache a temporary ``ChatInfo`` and return its chat id."""
        user_id = self.is_valid(raise_exception=True)
        chat_id = str(uuid.uuid1())
        model_id = self.data.get('model_id')
        # ``dataset_id_list`` is optional; fall back to an empty list so the
        # ``dataset_id__in`` lookup below never receives None, which Django rejects.
        dataset_id_list = self.data.get('dataset_id_list') or []
        # Multi-round dialogue keeps three rounds of history, otherwise none.
        dialogue_number = 3 if self.data.get('multiple_rounds_dialogue', False) else 0
        application = Application(id=None, dialogue_number=dialogue_number, model_id=model_id,
                                  dataset_setting=self.data.get('dataset_setting'),
                                  model_setting=self.data.get('model_setting'),
                                  problem_optimization=self.data.get('problem_optimization'),
                                  model_params_setting=self.data.get('model_params_setting'),
                                  user_id=user_id)
        inactive_document_id_list = [str(document.id) for document in
                                     QuerySet(Document).filter(
                                         dataset_id__in=dataset_id_list,
                                         is_active=False)]
        chat_cache.set(chat_id,
                       ChatInfo(chat_id, dataset_id_list, inactive_document_id_list, application),
                       timeout=60 * 30)
        return chat_id
class Query(serializers.Serializer):
    """List or page the chat records of one conversation."""
    application_id = serializers.UUIDField(required=True)
    chat_id = serializers.UUIDField(required=True)
    order_asc = serializers.BooleanField(required=False)

    def list(self, with_valid=True):
        """Return all records of ``chat_id``, oldest first unless ``order_asc`` is false."""
        if with_valid:
            self.is_valid(raise_exception=True)
        order_asc = self.data.get('order_asc')
        order_by = 'create_time' if order_asc is None or order_asc else '-create_time'
        return [ChatRecordSerializerModel(chat_record).data for chat_record in
                QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by(order_by)]

    @staticmethod
    def reset_chat_record(chat_record):
        """Serialize a chat record together with values derived from its ``details``.

        :param chat_record: the ``ChatRecord`` to serialize
        :return: the model serializer output plus ``padding_problem_text``,
            ``dataset_list``, ``paragraph_list`` and ``execution_details``
        """
        details = chat_record.details
        dataset_list = []
        paragraph_list = []
        if 'search_step' in details and details.get('search_step').get('paragraph_list') is not None:
            paragraph_list = details.get('search_step').get('paragraph_list')
            # One entry per distinct dataset id; a later paragraph's name wins,
            # exactly as the previous pairwise dict merge behaved.
            dataset_name_by_id = {row.get('dataset_id'): row.get("dataset_name") for row in paragraph_list}
            dataset_list = [{'id': dataset_id, 'name': name} for dataset_id, name in
                            dataset_name_by_id.items()]

        return {
            **ChatRecordSerializerModel(chat_record).data,
            'padding_problem_text': details.get('problem_padding').get(
                'padding_problem_text') if 'problem_padding' in details else None,
            'dataset_list': dataset_list,
            'paragraph_list': paragraph_list,
            'execution_details': [details[key] for key in details]
        }

    def page(self, current_page: int, page_size: int, with_valid=True):
        """Return one page of records of ``chat_id``, each passed through ``reset_chat_record``."""
        if with_valid:
            self.is_valid(raise_exception=True)
        # NOTE(review): unlike ``list``, a missing or true ``order_asc`` selects
        # newest-first here. Kept as is; confirm with the callers whether the
        # inversion is intended.
        order_by = '-create_time' if self.data.get('order_asc') is None or self.data.get(
            'order_asc') else 'create_time'
        page = page_search(current_page, page_size,
                           QuerySet(ChatRecord).filter(chat_id=self.data.get('chat_id')).order_by(order_by),
                           post_records_handler=lambda chat_record: self.reset_chat_record(chat_record))
        return page
class ChatRecordImprove(serializers.Serializer):
    """Read the annotation ("improve") paragraphs attached to a chat record."""
    chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))

    chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id"))

    def get(self, with_valid=True):
        """Return the serialized annotation paragraphs of the chat record.

        Ids whose paragraph no longer exists are pruned from the record's
        ``improve_paragraph_id_list`` and the record is saved.

        :param with_valid: validate the serializer first
        :raises AppApiException: the chat record does not exist
        :return: list of serialized paragraphs (empty when there are no annotations)
        """
        if with_valid:
            self.is_valid(raise_exception=True)
        chat_record_id = self.data.get('chat_record_id')
        chat_id = self.data.get('chat_id')
        chat_record = QuerySet(ChatRecord).filter(id=chat_record_id, chat_id=chat_id).first()
        if chat_record is None:
            raise AppApiException(500, '不存在的对话记录')
        if chat_record.improve_paragraph_id_list is None or len(chat_record.improve_paragraph_id_list) == 0:
            return []

        paragraph_model_list = QuerySet(Paragraph).filter(id__in=chat_record.improve_paragraph_id_list)
        if len(paragraph_model_list) < len(chat_record.improve_paragraph_id_list):
            # Compare ids as strings on both sides. Elsewhere in this file the
            # stored ids are handled as uuid.UUID objects; matching them against
            # string ids never succeeded, so the whole list was dropped.
            existing_ids = {str(paragraph.id) for paragraph in paragraph_model_list}
            chat_record.improve_paragraph_id_list = [p_id for p_id in chat_record.improve_paragraph_id_list
                                                     if str(p_id) in existing_ids]
            chat_record.save()
        return [ChatRecordSerializer.ParagraphModel(p).data for p in paragraph_model_list]
class Operate(serializers.Serializer):
    """Remove one annotation paragraph from a chat record."""
    chat_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话id"))

    chat_record_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("对话记录id"))

    dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("知识库id"))

    document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("文档id"))

    paragraph_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("段落id"))

    def delete(self, with_valid=True):
        """Detach ``paragraph_id`` from the chat record, then delete the paragraph.

        :param with_valid: validate the serializer first
        :raises AppApiException: the chat record does not exist, or the
            paragraph is not one of its annotations
        :return: the result of ``ParagraphSerializers.Operate.delete``
        """
        if with_valid:
            self.is_valid(raise_exception=True)

        params = self.data
        paragraph_id = params.get('paragraph_id')
        chat_record = QuerySet(ChatRecord).filter(id=params.get('chat_record_id'),
                                                  chat_id=params.get('chat_id')).first()
        if chat_record is None:
            raise AppApiException(500, '不存在的对话记录')
        if uuid.UUID(paragraph_id) not in chat_record.improve_paragraph_id_list:
            raise AppApiException(500, f'段落id错误,当前对话记录不存在【{paragraph_id}】段落id')

        # Keep every annotation id except the one being removed.
        remaining_ids = []
        for improve_id in chat_record.improve_paragraph_id_list:
            if str(improve_id) != paragraph_id:
                remaining_ids.append(improve_id)
        chat_record.improve_paragraph_id_list = remaining_ids
        chat_record.save()

        paragraph_operate = ParagraphSerializers.Operate(
            data={"dataset_id": params.get('dataset_id'), 'document_id': params.get('document_id'),
                  "paragraph_id": paragraph_id})
        paragraph_operate.is_valid(raise_exception=True)
        return paragraph_operate.delete()
application_chat."id" = application_chat_record.chat_id +${default_sql} +GROUP BY "day" \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/application/sql/customer_count.sql b/src/MaxKB-1.7.2/apps/application/sql/customer_count.sql new file mode 100644 index 0000000..0e40908 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/sql/customer_count.sql @@ -0,0 +1,5 @@ +SELECT + ( SUM ( CASE WHEN create_time :: DATE = CURRENT_DATE THEN 1 ELSE 0 END ) ) AS "customer_today_added_count", + COUNT ( "application_public_access_client"."id" ) AS "customer_added_count" +FROM + "application_public_access_client" \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/application/sql/customer_count_trend.sql b/src/MaxKB-1.7.2/apps/application/sql/customer_count_trend.sql new file mode 100644 index 0000000..159cd03 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/sql/customer_count_trend.sql @@ -0,0 +1,7 @@ +SELECT + COUNT ( "application_public_access_client"."id" ) AS "customer_added_count", + create_time :: DATE as "day" +FROM + "application_public_access_client" +${default_sql} +GROUP BY "day" \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/application/sql/export_application_chat.sql b/src/MaxKB-1.7.2/apps/application/sql/export_application_chat.sql new file mode 100644 index 0000000..dc58084 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/sql/export_application_chat.sql @@ -0,0 +1,37 @@ +SELECT + application_chat."id" as chat_id, + application_chat.abstract as abstract, + application_chat_record_temp.problem_text as problem_text, + application_chat_record_temp.answer_text as answer_text, + application_chat_record_temp.message_tokens as message_tokens, + application_chat_record_temp.answer_tokens as answer_tokens, + application_chat_record_temp.run_time as run_time, + application_chat_record_temp.details::JSON as details, + application_chat_record_temp."index" as "index", + application_chat_record_temp.improve_paragraph_list as 
improve_paragraph_list, + application_chat_record_temp.vote_status as vote_status, + application_chat_record_temp.create_time as create_time +FROM + application_chat application_chat + LEFT JOIN ( + SELECT COUNT + ( "id" ) AS chat_record_count, + SUM ( CASE WHEN "vote_status" = '0' THEN 1 ELSE 0 END ) AS star_num, + SUM ( CASE WHEN "vote_status" = '1' THEN 1 ELSE 0 END ) AS trample_num, + SUM ( CASE WHEN array_length( application_chat_record.improve_paragraph_id_list, 1 ) IS NULL THEN 0 ELSE array_length( application_chat_record.improve_paragraph_id_list, 1 ) END ) AS mark_sum, + chat_id + FROM + application_chat_record + GROUP BY + application_chat_record.chat_id + ) chat_record_temp ON application_chat."id" = chat_record_temp.chat_id + LEFT JOIN ( + SELECT + *, + CASE + WHEN array_length( application_chat_record.improve_paragraph_id_list, 1 ) IS NULL THEN + '{}' ELSE ( SELECT ARRAY_AGG ( row_to_json ( paragraph ) ) FROM paragraph WHERE "id" = ANY ( application_chat_record.improve_paragraph_id_list ) ) + END as improve_paragraph_list + FROM + application_chat_record application_chat_record + ) application_chat_record_temp ON application_chat_record_temp.chat_id = application_chat."id" \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/application/sql/list_application.sql b/src/MaxKB-1.7.2/apps/application/sql/list_application.sql new file mode 100644 index 0000000..4a4cde5 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/sql/list_application.sql @@ -0,0 +1,8 @@ +SELECT *,to_json(dataset_setting) as dataset_setting,to_json(model_setting) as model_setting,to_json(work_flow) as work_flow FROM ( SELECT * FROM application ${application_custom_sql} UNION + SELECT + * + FROM + application + WHERE + application."id" IN ( SELECT team_member_permission.target FROM team_member team_member LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id" ${team_member_permission_custom_sql}) + ) temp_application 
${default_sql} \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/application/sql/list_application_chat.sql b/src/MaxKB-1.7.2/apps/application/sql/list_application_chat.sql new file mode 100644 index 0000000..bf269d0 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/sql/list_application_chat.sql @@ -0,0 +1,16 @@ +SELECT + * +FROM + application_chat application_chat + LEFT JOIN ( + SELECT COUNT + ( "id" ) AS chat_record_count, + SUM ( CASE WHEN "vote_status" = '0' THEN 1 ELSE 0 END ) AS star_num, + SUM ( CASE WHEN "vote_status" = '1' THEN 1 ELSE 0 END ) AS trample_num, + SUM ( CASE WHEN array_length( application_chat_record.improve_paragraph_id_list, 1 ) IS NULL THEN 0 ELSE array_length( application_chat_record.improve_paragraph_id_list, 1 ) END ) AS mark_sum, + chat_id + FROM + application_chat_record + GROUP BY + application_chat_record.chat_id + ) chat_record_temp ON application_chat."id" = chat_record_temp.chat_id \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/application/sql/list_application_dataset.sql b/src/MaxKB-1.7.2/apps/application/sql/list_application_dataset.sql new file mode 100644 index 0000000..691036f --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/sql/list_application_dataset.sql @@ -0,0 +1,20 @@ +SELECT + * +FROM + dataset +WHERE + user_id = %s UNION +SELECT + * +FROM + dataset +WHERE + "id" IN ( + SELECT + team_member_permission.target + FROM + team_member team_member + LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id" + WHERE + ( "team_member_permission"."auth_target_type" = 'DATASET' AND "team_member_permission"."operate"::text[] @> ARRAY['USE'] AND team_member.team_id = %s AND team_member.user_id =%s ) + ) \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql b/src/MaxKB-1.7.2/apps/application/sql/list_dataset_paragraph_by_paragraph_id.sql new file mode 100644 index 0000000..2bacd53 
class ApplicationApi(ApiMixin):
    """Swagger (drf-yasg) request/response descriptions for the application endpoints."""

    class EditApplicationIcon(ApiMixin):
        """Form parameters for uploading an application icon."""

        @staticmethod
        def get_request_params_api():
            return [
                openapi.Parameter(name='file',
                                  in_=openapi.IN_FORM,
                                  type=openapi.TYPE_FILE,
                                  required=True,
                                  description='上传文件')
            ]

    class Authentication(ApiMixin):
        """Request body for authenticating with an application access token."""

        @staticmethod
        def get_request_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['access_token', ],
                properties={
                    'access_token': openapi.Schema(type=openapi.TYPE_STRING, title="应用认证token",
                                                   description="应用认证token"),

                }
            )

    @staticmethod
    def get_response_body_api():
        """Schema of a single application object as returned by the API."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['id', 'name', 'desc', 'model_id', 'dialogue_number', 'user_id', 'status', 'create_time',
                      'update_time'],
            properties={
                # Title was an empty string; aligned with the description.
                'id': openapi.Schema(type=openapi.TYPE_STRING, title="主键id", description="主键id"),
                'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
                'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
                'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
                "dialogue_number": openapi.Schema(type=openapi.TYPE_NUMBER, title="多轮对话次数",
                                                  description="多轮对话次数"),
                'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
                'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING),
                                          title="示例列表", description="示例列表"),
                'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户", description="所属用户"),

                'status': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否发布", description='是否发布'),

                'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'),

                'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间'),

                'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
                                                  items=openapi.Schema(type=openapi.TYPE_STRING),
                                                  title="关联知识库Id列表",
                                                  description="关联知识库Id列表(查询详情的时候返回)")
            }
        )

    class Model(ApiMixin):
        """Parameters for listing the models usable by an application."""

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='application_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='应用id'),
                    openapi.Parameter(name='model_type', in_=openapi.IN_QUERY,
                                      type=openapi.TYPE_STRING,
                                      required=False,
                                      description='模型类型'),
                    ]

    class ApiKey(ApiMixin):
        """Parameters for the application API-key collection."""

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='application_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='应用id')

                    ]

        # NOTE(review): nesting under ApiKey is inferred (a second class named Operate
        # is defined later at ApplicationApi level) -- confirm against the views.
        class Operate(ApiMixin):
            """Parameters and body for operating on a single API key."""

            @staticmethod
            def get_request_params_api():
                return [openapi.Parameter(name='application_id',
                                          in_=openapi.IN_PATH,
                                          type=openapi.TYPE_STRING,
                                          required=True,
                                          description='应用id'),
                        openapi.Parameter(name='api_key_id',
                                          in_=openapi.IN_PATH,
                                          type=openapi.TYPE_STRING,
                                          required=True,
                                          description='应用api_key id')
                        ]

            @staticmethod
            def get_request_body_api():
                return openapi.Schema(
                    type=openapi.TYPE_OBJECT,
                    required=[],
                    properties={
                        'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否激活",
                                                    description="是否激活"),
                        'allow_cross_domain': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否允许跨域",
                                                             description="是否允许跨域"),
                        'cross_domain_list': openapi.Schema(type=openapi.TYPE_ARRAY, title='跨域列表',
                                                            items=openapi.Schema(type=openapi.TYPE_STRING))
                    }
                )

    class AccessToken(ApiMixin):
        """Parameters and body for the application access-token settings."""

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='application_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='应用id')

                    ]

        @staticmethod
        def get_request_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=[],
                properties={
                    'access_token_reset': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重置Token",
                                                         description="重置Token"),

                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否激活", description="是否激活"),
                    'access_num': openapi.Schema(type=openapi.TYPE_NUMBER, title="访问次数", description="访问次数"),
                    'white_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启白名单",
                                                   description="是否开启白名单"),
                    'white_list': openapi.Schema(type=openapi.TYPE_ARRAY,
                                                 items=openapi.Schema(type=openapi.TYPE_STRING), title="白名单列表",
                                                 description="白名单列表"),
                    'show_source': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否显示知识来源",
                                                  description="是否显示知识来源"),
                }
            )

    class Edit(ApiMixin):
        """Request body for editing an application (all fields optional)."""

        @staticmethod
        def get_request_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=[],
                properties={
                    'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
                    'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
                    'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
                    "dialogue_number": openapi.Schema(type=openapi.TYPE_NUMBER, title="多轮对话次数",
                                                      description="多轮对话次数"),
                    'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
                    'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
                                                      items=openapi.Schema(type=openapi.TYPE_STRING),
                                                      title="关联知识库Id列表", description="关联知识库Id列表"),
                    'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
                    'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
                    'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
                                                           description="是否开启问题优化", default=True),
                    'icon': openapi.Schema(type=openapi.TYPE_STRING, title="icon",
                                           description="icon", default="/ui/favicon.ico"),
                    'type': openapi.Schema(type=openapi.TYPE_STRING, title="应用类型",
                                           description="应用类型 简易:SIMPLE|工作流:WORK_FLOW"),
                    'work_flow': ApplicationApi.WorkFlow.get_request_body_api(),
                    'problem_optimization_prompt': openapi.Schema(type=openapi.TYPE_STRING, title='问题优化提示词',
                                                                  description="问题优化提示词",
                                                                  default="()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中"),
                    'tts_model_id': openapi.Schema(type=openapi.TYPE_STRING, title="文字转语音模型ID",
                                                   description="文字转语音模型ID"),
                    'stt_model_id': openapi.Schema(type=openapi.TYPE_STRING, title="语音转文字模型id",
                                                   description="语音转文字模型id"),
                    'stt_model_enable': openapi.Schema(type=openapi.TYPE_STRING, title="语音转文字是否开启",
                                                       description="语音转文字是否开启"),
                    # tts = text-to-speech; the label was copy-pasted from the stt field.
                    'tts_model_enable': openapi.Schema(type=openapi.TYPE_STRING, title="文字转语音是否开启",
                                                       description="文字转语音是否开启"),
                    'tts_type': openapi.Schema(type=openapi.TYPE_STRING, title="文字转语音类型",
                                               description="文字转语音类型")

                }
            )

    class WorkFlow(ApiMixin):
        """Schema of a workflow definition (node list and edge list)."""

        @staticmethod
        def get_request_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                # Was [''], i.e. a required property with an empty name.
                required=[],
                properties={
                    'nodes': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_OBJECT),
                                            title="节点列表", description="节点列表",
                                            default=[]),
                    # Array schema: the default must be a list, not a dict.
                    'edges': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_OBJECT),
                                            title='连线列表', description="连线列表",
                                            default=[]),

                }
            )

    class DatasetSetting(ApiMixin):
        """Schema of the knowledge-base retrieval settings of an application."""

        @staticmethod
        def get_request_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                # Was [''], i.e. a required property with an empty name.
                required=[],
                properties={
                    'top_n': openapi.Schema(type=openapi.TYPE_NUMBER, title="引用分段数", description="引用分段数",
                                            default=5),
                    'similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title='相似度', description="相似度",
                                                 default=0.6),
                    'max_paragraph_char_number': openapi.Schema(type=openapi.TYPE_NUMBER, title='最多引用字符数',
                                                                description="最多引用字符数", default=3000),
                    'search_mode': openapi.Schema(type=openapi.TYPE_STRING, title='检索模式',
                                                  description="embedding|keywords|blend", default='embedding'),
                    # Title previously duplicated search_mode's title.
                    'no_references_setting': openapi.Schema(type=openapi.TYPE_OBJECT, title='无引用分段设置',
                                                            required=['status', 'value'],
                                                            properties={
                                                                'status': openapi.Schema(type=openapi.TYPE_STRING,
                                                                                         title="状态",
                                                                                         description="ai作答:ai_questioning,指定回答:designated_answer",
                                                                                         default='ai_questioning'),
                                                                'value': openapi.Schema(type=openapi.TYPE_STRING,
                                                                                        title="值",
                                                                                        description="ai作答:就是题词,指定回答:就是指定回答内容",
                                                                                        default='{question}'),
                                                            }),
                }
            )

    class ModelSetting(ApiMixin):
        """Schema of the prompt settings of an application."""

        @staticmethod
        def get_request_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['prompt'],
                properties={
                    'prompt': openapi.Schema(type=openapi.TYPE_STRING, title="提示词", description="提示词",
                                             default=('已知信息:'
                                                      '\n{data}'
                                                      '\n回答要求:'
                                                      '\n- 如果你不知道答案或者没有从获取答案,请回答“没有在知识库中查找到相关信息,建议咨询相关技术支持或参考官方文档进行操作”。'
                                                      '\n- 避免提及你是从中获得的知识。'
                                                      '\n- 请保持答案与中描述的一致。'
                                                      '\n- 请使用markdown 语法优化答案的格式。'
                                                      '\n- 中的图片链接、链接地址和脚本语言请完整返回。'
                                                      '\n- 请使用与问题相同的语言来回答。'
                                                      '\n问题:'
                                                      '\n{question}')),

                    'system': openapi.Schema(type=openapi.TYPE_STRING, title="系统提示词(角色)",
                                             description="系统提示词(角色)"),
                    'no_references_prompt': openapi.Schema(type=openapi.TYPE_STRING, title="无引用分段提示词",
                                                           default="{question}", description="无引用分段提示词")

                }
            )

    class Publish(ApiMixin):
        """Request body for publishing an application workflow."""

        @staticmethod
        def get_request_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=[],
                properties={
                    'work_flow': ApplicationApi.WorkFlow.get_request_body_api()
                }
            )

    class Create(ApiMixin):
        """Request body for creating an application."""

        @staticmethod
        def get_request_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                # 'stt_model_enable' was listed twice; the second entry is taken to be
                # 'tts_model_enable' (OpenAPI requires unique names in `required`).
                required=['name', 'desc', 'model_id', 'dialogue_number', 'dataset_setting', 'model_setting',
                          'problem_optimization', 'stt_model_enable', 'tts_model_enable', 'tts_type'],
                properties={
                    'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"),
                    'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"),
                    'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
                    "dialogue_number": openapi.Schema(type=openapi.TYPE_NUMBER, title="多轮对话次数",
                                                      description="多轮对话次数"),
                    'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"),
                    'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
                                                      items=openapi.Schema(type=openapi.TYPE_STRING),
                                                      title="关联知识库Id列表", description="关联知识库Id列表"),
                    'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
                    'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
                    'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
                                                           description="是否开启问题优化", default=True),
                    'type': openapi.Schema(type=openapi.TYPE_STRING, title="应用类型",
                                           description="应用类型 简易:SIMPLE|工作流:WORK_FLOW"),
                    'problem_optimization_prompt': openapi.Schema(type=openapi.TYPE_STRING, title='问题优化提示词',
                                                                  description="问题优化提示词",
                                                                  default="()里面是用户问题,根据上下文回答揣测用户问题({question}) 要求: 输出一个补全问题,并且放在标签中"),
                    'tts_model_id': openapi.Schema(type=openapi.TYPE_STRING, title="文字转语音模型ID",
                                                   description="文字转语音模型ID"),
                    'stt_model_id': openapi.Schema(type=openapi.TYPE_STRING, title="语音转文字模型id",
                                                   description="语音转文字模型id"),
                    'stt_model_enable': openapi.Schema(type=openapi.TYPE_STRING, title="语音转文字是否开启",
                                                       description="语音转文字是否开启"),
                    # tts = text-to-speech; the label was copy-pasted from the stt field.
                    'tts_model_enable': openapi.Schema(type=openapi.TYPE_STRING, title="文字转语音是否开启",
                                                       description="文字转语音是否开启"),
                    'tts_type': openapi.Schema(type=openapi.TYPE_STRING, title="文字转语音类型",
                                               description="文字转语音类型")
                }
            )

    class Query(ApiMixin):
        """Query parameters for searching applications."""

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='name',
                                      in_=openapi.IN_QUERY,
                                      type=openapi.TYPE_STRING,
                                      required=False,
                                      description='应用名称'),
                    openapi.Parameter(name='desc',
                                      in_=openapi.IN_QUERY,
                                      type=openapi.TYPE_STRING,
                                      required=False,
                                      description='应用描述')
                    ]

    class Operate(ApiMixin):
        """Path parameters for operating on a single application."""

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='application_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='应用id'),

                    ]
class ApplicationVersionApi(ApiMixin):
    """Swagger (drf-yasg) descriptions for the application version endpoints."""

    @staticmethod
    def get_response_body_api():
        """Schema of a single application version object."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['id', 'name', 'work_flow', 'create_time', 'update_time'],
            properties={
                # id/name were declared TYPE_NUMBER; the version id path parameter and the
                # editable/queryable name in this same class are strings.
                'id': openapi.Schema(type=openapi.TYPE_STRING, title="主键id",
                                     description="主键id"),
                'name': openapi.Schema(type=openapi.TYPE_STRING, title="版本名称",
                                       description="版本名称"),
                'work_flow': openapi.Schema(type=openapi.TYPE_STRING, title="工作流数据", description='工作流数据'),
                'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'),
                'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间')
            }
        )

    class Query(ApiMixin):
        """Parameters for listing the versions of an application."""

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='application_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='应用id'),
                    openapi.Parameter(name='name',
                                      in_=openapi.IN_QUERY,
                                      type=openapi.TYPE_STRING,
                                      required=False,
                                      description='版本名称')]

    class Operate(ApiMixin):
        """Path parameters for operating on a single application version."""

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='application_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='应用id'),
                    openapi.Parameter(name='work_flow_version_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='应用版本id'), ]

    class Edit(ApiMixin):
        """Request body for renaming an application version."""

        @staticmethod
        def get_request_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=[],
                properties={
                    'name': openapi.Schema(type=openapi.TYPE_STRING, title="版本名称",
                                           description="版本名称")
                }
            )
class ChatClientHistoryApi(ApiMixin):
    """Path parameters for the client chat-history endpoints."""

    @staticmethod
    def get_request_params_api():
        return [openapi.Parameter(name='application_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='应用id')
                ]


class OpenAIChatApi(ApiMixin):
    """Request body of the OpenAI-style chat endpoint."""

    @staticmethod
    def get_request_body_api():
        return openapi.Schema(type=openapi.TYPE_OBJECT,
                              # Was ['message'], which names no declared property.
                              required=['messages'],
                              properties={
                                  'messages': openapi.Schema(type=openapi.TYPE_ARRAY, title="问题", description="问题",
                                                             items=openapi.Schema(type=openapi.TYPE_OBJECT,
                                                                                  required=['role', 'content'],
                                                                                  properties={
                                                                                      'content': openapi.Schema(
                                                                                          type=openapi.TYPE_STRING,
                                                                                          title="问题内容", default=''),
                                                                                      'role': openapi.Schema(
                                                                                          type=openapi.TYPE_STRING,
                                                                                          title='角色', default="user")
                                                                                  }
                                                                                  )),
                                  'chat_id': openapi.Schema(type=openapi.TYPE_STRING, title="对话id"),
                                  're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default=False),
                                  'stream': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="流式输出", default=True)

                              })


class ChatApi(ApiMixin):
    """Swagger (drf-yasg) descriptions for the chat endpoints."""

    @staticmethod
    def get_request_body_api():
        """Body of a chat message request."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['message'],
            properties={
                'message': openapi.Schema(type=openapi.TYPE_STRING, title="问题", description="问题"),
                're_chat': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="重新生成", default=False),
                # Title was copy-pasted from re_chat; OpenAIChatApi labels the same flag "流式输出".
                'stream': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="流式输出", default=True)

            }
        )

    @staticmethod
    def get_response_body_api():
        """Schema of a chat (conversation) list entry."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            # 'application' named no declared property; the property is 'application_id'.
            required=['id', 'application_id', 'abstract', 'chat_record_count', 'mark_sum', 'star_num',
                      'trample_num', 'update_time', 'create_time'],
            properties={
                'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                     description="id", default="xx"),
                'application_id': openapi.Schema(type=openapi.TYPE_STRING, title="应用id",
                                                 description="应用id", default='应用id'),
                'abstract': openapi.Schema(type=openapi.TYPE_STRING, title="摘要",
                                           description="摘要", default='摘要'),
                'chat_id': openapi.Schema(type=openapi.TYPE_STRING, title="对话id",
                                          description="对话id", default="对话id"),
                # The three counters below are COUNT/SUM aggregates in the chat list SQL,
                # so they are integers rather than strings.
                'chat_record_count': openapi.Schema(type=openapi.TYPE_INTEGER, title="对话提问数量",
                                                    description="对话提问数量",
                                                    default=1),
                'mark_sum': openapi.Schema(type=openapi.TYPE_INTEGER, title="标记数量",
                                           description="标记数量", default=1),
                'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
                                           description="点赞数量", default=1),
                'trample_num': openapi.Schema(type=openapi.TYPE_NUMBER, title="点踩数量",
                                              description="点踩数量", default=1),
                'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                              description="修改时间",
                                              default="1970-01-01 00:00:00"),
                'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                              description="创建时间",
                                              default="1970-01-01 00:00:00"
                                              )
            }
        )

    class OpenChat(ApiMixin):
        """Path parameters for opening a chat on an application."""

        @staticmethod
        def get_request_params_api():
            return [openapi.Parameter(name='application_id',
                                      in_=openapi.IN_PATH,
                                      type=openapi.TYPE_STRING,
                                      required=True,
                                      description='应用id'),

                    ]

    class OpenWorkFlowTemp(ApiMixin):
        """Request body for opening a temporary workflow chat."""

        @staticmethod
        def get_request_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=[],
                properties={
                    'work_flow': ApplicationApi.WorkFlow.get_request_body_api()
                }
            )

    class OpenTempChat(ApiMixin):
        """Request body for opening a temporary (unsaved application) chat."""

        @staticmethod
        def get_request_body_api():
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['model_id', 'multiple_rounds_dialogue', 'dataset_setting', 'model_setting',
                          'problem_optimization'],
                properties={
                    'id': openapi.Schema(type=openapi.TYPE_STRING, title="应用id",
                                         description="应用id,修改的时候传,创建的时候不传"),
                    'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"),
                    'dataset_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
                                                      items=openapi.Schema(type=openapi.TYPE_STRING),
                                                      title="关联知识库Id列表", description="关联知识库Id列表"),
                    'multiple_rounds_dialogue': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮会话",
                                                               description="是否开启多轮会话"),
                    'dataset_setting': ApplicationApi.DatasetSetting.get_request_body_api(),
                    'model_setting': ApplicationApi.ModelSetting.get_request_body_api(),
                    'problem_optimization': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="问题优化",
                                                           description="是否开启问题优化", default=True)
                }
            )

    # NOTE(review): placed at ChatApi level (indentation was lost in this view) -- confirm.
    @staticmethod
    def get_request_params_api():
        """Path and query parameters for listing the chats of an application."""
        return [openapi.Parameter(name='application_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='应用id'),
                openapi.Parameter(name='history_day',
                                  in_=openapi.IN_QUERY,
                                  type=openapi.TYPE_NUMBER,
                                  required=True,
                                  description='历史天数'),
                openapi.Parameter(name='abstract', in_=openapi.IN_QUERY, type=openapi.TYPE_STRING, required=False,
                                  description="摘要"),
                openapi.Parameter(name='min_star', in_=openapi.IN_QUERY, type=openapi.TYPE_INTEGER, required=False,
                                  description="最小点赞数"),
                openapi.Parameter(name='min_trample', in_=openapi.IN_QUERY, type=openapi.TYPE_INTEGER,
                                  required=False,
                                  description="最小点踩数"),
                openapi.Parameter(name='comparer', in_=openapi.IN_QUERY, type=openapi.TYPE_STRING, required=False,
                                  description="or|and 比较器")
                ]


class ChatRecordApi(ApiMixin):
    """Swagger (drf-yasg) descriptions for the chat record endpoints."""

    @staticmethod
    def get_request_params_api():
        return [openapi.Parameter(name='application_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='应用id'),
                openapi.Parameter(name='chat_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='对话id'),
                ]

    @staticmethod
    def get_response_body_api():
        """Schema of a single chat record."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['id', 'chat', 'vote_status', 'dataset', 'paragraph', 'source_id', 'source_type',
                      'message_tokens', 'answer_tokens',
                      'problem_text', 'answer_text', 'improve_paragraph_id_list'],
            properties={
                'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                     description="id", default="xx"),
                'chat': openapi.Schema(type=openapi.TYPE_STRING, title="会话日志id",
                                       description="会话日志id", default='会话日志id'),
                'vote_status': openapi.Schema(type=openapi.TYPE_STRING, title="投票状态",
                                              description="投票状态", default="投票状态"),
                'dataset': openapi.Schema(type=openapi.TYPE_STRING, title="数据集id", description="数据集id",
                                          default="数据集id"),
                'paragraph': openapi.Schema(type=openapi.TYPE_STRING, title="段落id",
                                            description="段落id", default=1),
                'source_id': openapi.Schema(type=openapi.TYPE_STRING, title="资源id",
                                            description="资源id", default=1),
                'source_type': openapi.Schema(type=openapi.TYPE_STRING, title="资源类型",
                                              description="资源类型", default='xxx'),
                'message_tokens': openapi.Schema(type=openapi.TYPE_INTEGER, title="问题消耗token数量",
                                                 description="问题消耗token数量", default=0),
                'answer_tokens': openapi.Schema(type=openapi.TYPE_INTEGER, title="答案消耗token数量",
                                                description="答案消耗token数量", default=0),
                # A list of paragraph ids (the default is a list), so declare an array of strings.
                'improve_paragraph_id_list': openapi.Schema(type=openapi.TYPE_ARRAY,
                                                            items=openapi.Schema(type=openapi.TYPE_STRING),
                                                            title="改进标注列表",
                                                            description="改进标注列表",
                                                            default=[]),
                'index': openapi.Schema(type=openapi.TYPE_STRING, title="对应会话 对应下标",
                                        description="对应会话id对应下标",
                                        default="对应会话id对应下标"
                                        ),
                'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                              description="修改时间",
                                              default="1970-01-01 00:00:00"),
                'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                              description="创建时间",
                                              default="1970-01-01 00:00:00"
                                              )
            }
        )


class ImproveApi(ApiMixin):
    """Swagger (drf-yasg) descriptions for annotating a chat record with a paragraph."""

    @staticmethod
    def get_request_params_api():
        return [openapi.Parameter(name='application_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='应用id'),
                openapi.Parameter(name='chat_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='会话id'),
                openapi.Parameter(name='chat_record_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='会话记录id'),
                openapi.Parameter(name='dataset_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='知识库id'),
                openapi.Parameter(name='document_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='文档id'),
                ]

    @staticmethod
    def get_request_body_api():
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['content'],
            properties={
                'title': openapi.Schema(type=openapi.TYPE_STRING, title="段落标题",
                                        description="段落标题"),
                'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容",
                                          description="段落内容"),
                # Optional: the improve serializer reads 'problem_text' and falls back to
                # the chat record's own question when it is absent.
                'problem_text': openapi.Schema(type=openapi.TYPE_STRING, title="问题",
                                               description="问题")

            }
        )


class VoteApi(ApiMixin):
    """Swagger (drf-yasg) descriptions for voting on a chat record."""

    @staticmethod
    def get_request_params_api():
        return [openapi.Parameter(name='application_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='应用id'),
                openapi.Parameter(name='chat_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='会话id'),
                openapi.Parameter(name='chat_record_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='会话记录id')
                ]

    @staticmethod
    def get_request_body_api():
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['vote_status'],
            properties={
                'vote_status': openapi.Schema(type=openapi.TYPE_STRING, title="投票状态",
                                              description="-1:取消投票|0:赞同|1:反对"),

            }
        )


class ChatRecordImproveApi(ApiMixin):
    """Swagger (drf-yasg) descriptions for listing the paragraphs annotated on a chat record."""

    @staticmethod
    def get_request_params_api():
        """Path parameters identifying the chat record."""
        return [openapi.Parameter(name='application_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='应用id'),
                openapi.Parameter(name='chat_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='会话id'),
                openapi.Parameter(name='chat_record_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='会话记录id')
                ]

    @staticmethod
    def get_request_body_api():
        """Backward-compatible alias.

        This method historically returned the path-parameter list rather than a
        body schema; it is kept so existing callers keep working.
        """
        return ChatRecordImproveApi.get_request_params_api()

    @staticmethod
    def get_response_body_api():
        """Schema of a single annotated paragraph."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'dataset_id',
                      'document_id', 'title',
                      'create_time', 'update_time'],
            properties={
                'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                     description="id", default="xx"),
                'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容",
                                          description="段落内容", default='段落内容'),
                'title': openapi.Schema(type=openapi.TYPE_STRING, title="标题",
                                        description="标题", default="xxx的描述"),
                'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量",
                                          default=1),
                'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量",
                                           description="点赞数量", default=1),
                'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量",
                                              description="点踩数", default=1),
                'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id",
                                             description="知识库id", default='xxx'),
                'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id",
                                              description="文档id", default='xxx'),
                'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
                                            description="是否可用", default=True),
                'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                              description="修改时间",
                                              default="1970-01-01 00:00:00"),
                'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                              description="创建时间",
                                              default="1970-01-01 00:00:00"
                                              )
            }
        )
+
+
+
+
+ + + +
+ +
🌟 遇见问题,不再有障碍!
+

你好,我是你的智能小助手。
+ 点我,开启高效解答模式,让问题变成过去式。

+
+ +
+ +
+` +const chatButtonHtml= +`
+ +
` + + + +const getChatContainerHtml=(protocol,host,token,query)=>{ + return `
+ +
+ +
+
+ + +
+
+ + +
+` +} +/** + * 初始化引导 + * @param {*} root + */ +const initGuide=(root)=>{ + root.insertAdjacentHTML("beforeend",guideHtml) + const button=root.querySelector(".maxkb-button") + const close_icon=root.querySelector('.maxkb-close') + const close_func=()=>{ + root.removeChild(root.querySelector('.maxkb-tips')) + root.removeChild(root.querySelector('.maxkb-mask')) + localStorage.setItem('maxkbMaskTip',true) + } + button.onclick=close_func + close_icon.onclick=close_func +} +const initChat=(root)=>{ + // 添加对话icon + root.insertAdjacentHTML("beforeend",chatButtonHtml) + // 添加对话框 + root.insertAdjacentHTML('beforeend',getChatContainerHtml('{{protocol}}','{{host}}','{{token}}','{{query}}')) + // 按钮元素 + const chat_button=root.querySelector('.maxkb-chat-button') + const chat_button_img=root.querySelector('.maxkb-chat-button > img') + // 对话框元素 + const chat_container=root.querySelector('#maxkb-chat-container') + + const viewport=root.querySelector('.maxkb-openviewport') + const closeviewport=root.querySelector('.maxkb-closeviewport') + const close_func=()=>{ + chat_container.style['display']=chat_container.style['display']=='block'?'none':'block' + chat_button.style['display']=chat_container.style['display']=='block'?'none':'block' + } + close_icon=chat_container.querySelector('.maxkb-chat-close') + chat_button.onclick = close_func + close_icon.onclick=close_func + const viewport_func=()=>{ + if(chat_container.classList.contains('maxkb-enlarge')){ + chat_container.classList.remove("maxkb-enlarge"); + viewport.classList.remove('maxkb-viewportnone') + closeviewport.classList.add('maxkb-viewportnone') + }else{ + chat_container.classList.add("maxkb-enlarge"); + viewport.classList.add('maxkb-viewportnone') + closeviewport.classList.remove('maxkb-viewportnone') + } + } + const drag=(e)=>{ + if (['touchmove','touchstart'].includes(e.type)) { + chat_button.style.top=(e.touches[0].clientY-25)+'px' + chat_button.style.left=(e.touches[0].clientX-25)+'px' + } else { + 
chat_button.style.top=(e.y-25)+'px' + chat_button.style.left=(e.x-25)+'px' + } + chat_button.style.width =chat_button_img.naturalWidth+'px' + chat_button.style.height =chat_button_img.naturalHeight+'px' + } + if({{is_draggable}}){ + console.dir(chat_button_img) + chat_button.addEventListener("drag",drag) + chat_button.addEventListener("dragover",(e)=>{ + e.preventDefault() + }) + chat_button.addEventListener("dragend",drag) + chat_button.addEventListener("touchstart",drag) + chat_button.addEventListener("touchmove",drag) + } + viewport.onclick=viewport_func + closeviewport.onclick=viewport_func +} +/** + * 第一次进来的引导提示 + */ +function initMaxkb(){ + const maxkb=document.createElement('div') + const root=document.createElement('div') + root.id="maxkb" + initMaxkbStyle(maxkb) + maxkb.appendChild(root) + document.body.appendChild(maxkb) + const maxkbMaskTip=localStorage.getItem('maxkbMaskTip') + if(maxkbMaskTip==null && {{show_guide}}){ + initGuide(root) + } + initChat(root) +} + + +// 初始化全局样式 +function initMaxkbStyle(root){ + style=document.createElement('style') + style.type='text/css' + style.innerText= ` + /* 放大 */ + #maxkb .maxkb-enlarge { + width: 50%!important; + height: 100%!important; + bottom: 0!important; + right: 0 !important; + } + @media only screen and (max-width: 768px){ + #maxkb .maxkb-enlarge { + width: 100%!important; + height: 100%!important; + right: 0 !important; + bottom: 0!important; + } + } + + /* 引导 */ + + #maxkb .maxkb-mask { + position: fixed; + z-index: 999; + background-color: transparent; + height: 100%; + width: 100%; + top: 0; + left: 0; + } + #maxkb .maxkb-mask .maxkb-content { + width: 64px; + height: 64px; + box-shadow: 1px 1px 1px 2000px rgba(0,0,0,.6); + position: absolute; + {{x_type}}: {{x_value}}px; + {{y_type}}: {{y_value}}px; + z-index: 1000; + } + #maxkb .maxkb-tips { + position: fixed; + {{x_type}}:calc({{x_value}}px + 75px); + {{y_type}}: calc({{y_value}}px + 0px); + padding: 22px 24px 24px; + border-radius: 6px; + color: 
#ffffff; + font-size: 14px; + background: #3370FF; + z-index: 1000; + } + #maxkb .maxkb-tips .maxkb-arrow { + position: absolute; + background: #3370FF; + width: 10px; + height: 10px; + pointer-events: none; + transform: rotate(45deg); + box-sizing: border-box; + /* left */ + {{x_type}}: -5px; + {{y_type}}: 33px; + border-left-color: transparent; + border-bottom-color: transparent + } + #maxkb .maxkb-tips .maxkb-title { + font-size: 20px; + font-weight: 500; + margin-bottom: 8px; + } + #maxkb .maxkb-tips .maxkb-button { + text-align: right; + margin-top: 24px; + } + #maxkb .maxkb-tips .maxkb-button button { + border-radius: 4px; + background: #FFF; + padding: 3px 12px; + color: #3370FF; + cursor: pointer; + outline: none; + border: none; + } + #maxkb .maxkb-tips .maxkb-button button::after{ + border: none; + } + #maxkb .maxkb-tips .maxkb-close { + position: absolute; + right: 20px; + top: 20px; + cursor: pointer; + + } + #maxkb-chat-container { + width: 450px; + height: 600px; + display:none; + } + @media only screen and (max-width: 768px) { + #maxkb-chat-container { + width: 100%; + height: 70%; + right: 0 !important; + } + } + + #maxkb .maxkb-chat-button{ + position: fixed; + {{x_type}}: {{x_value}}px; + {{y_type}}: {{y_value}}px; + cursor: pointer; + max-height:500px; + max-width:500px; + } + #maxkb #maxkb-chat-container{ + z-index:10000;position: relative; + border-radius: 8px; + border: 1px solid #ffffff; + background: linear-gradient(188deg, rgba(235, 241, 255, 0.20) 39.6%, rgba(231, 249, 255, 0.20) 94.3%), #EFF0F1; + box-shadow: 0px 4px 8px 0px rgba(31, 35, 41, 0.10); + position: fixed;bottom: 16px;right: 16px;overflow: hidden; + } + + #maxkb #maxkb-chat-container .maxkb-operate{ + top: 18px; + right: 15px; + position: absolute; + display: flex; + align-items: center; + } + #maxkb #maxkb-chat-container .maxkb-operate .maxkb-chat-close{ + margin-left:15px; + cursor: pointer; + } + #maxkb #maxkb-chat-container .maxkb-operate .maxkb-openviewport{ + + cursor: 
pointer; + } + #maxkb #maxkb-chat-container .maxkb-operate .maxkb-closeviewport{ + + cursor: pointer; + } + #maxkb #maxkb-chat-container .maxkb-viewportnone{ + display:none; + } + #maxkb #maxkb-chat-container #maxkb-chat{ + height:100%; + width:100%; + border: none; +} + #maxkb #maxkb-chat-container { + animation: appear .4s ease-in-out; + } + @keyframes appear { + from { + height: 0;; + } + + to { + height: 600px; + } + }` + root.appendChild(style) +} + +function embedChatbot() { + white_list_str='{{white_list_str}}' + white_list=white_list_str.split(',') + + if ({{is_auth}}&&({{white_active}}?white_list.includes(window.location.origin):true)) { + // 初始化maxkb智能小助手 + initMaxkb() + } else console.error('invalid parameter') +} +window.onload = embedChatbot diff --git a/src/MaxKB-1.7.2/apps/application/tests.py b/src/MaxKB-1.7.2/apps/application/tests.py new file mode 100644 index 0000000..7ce503c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/src/MaxKB-1.7.2/apps/application/urls.py b/src/MaxKB-1.7.2/apps/application/urls.py new file mode 100644 index 0000000..b3df23d --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/urls.py @@ -0,0 +1,81 @@ +from django.urls import path + +from . 
import views + +app_name = "application" +urlpatterns = [ + path('application', views.Application.as_view(), name="application"), + path('application/profile', views.Application.Profile.as_view(), name='application/profile'), + path('application/embed', views.Application.Embed.as_view()), + path('application/authentication', views.Application.Authentication.as_view()), + path('application//publish', views.Application.Publish.as_view()), + path('application//edit_icon', views.Application.EditIcon.as_view()), + path('application//statistics/customer_count', + views.ApplicationStatistics.CustomerCount.as_view()), + path('application//statistics/customer_count_trend', + views.ApplicationStatistics.CustomerCountTrend.as_view()), + path('application//statistics/chat_record_aggregate', + views.ApplicationStatistics.ChatRecordAggregate.as_view()), + path('application//statistics/chat_record_aggregate_trend', + views.ApplicationStatistics.ChatRecordAggregateTrend.as_view()), + path('application//model', views.Application.Model.as_view()), + path('application//function_lib', views.Application.FunctionLib.as_view()), + path('application//function_lib/', + views.Application.FunctionLib.Operate.as_view()), + path('application//model_params_form/', + views.Application.ModelParamsForm.as_view()), + path('application//hit_test', views.Application.HitTest.as_view()), + path('application//api_key', views.Application.ApplicationKey.as_view()), + path("application//api_key/", + views.Application.ApplicationKey.Operate.as_view()), + path('application/', views.Application.Operate.as_view(), name='application/operate'), + path('application//list_dataset', views.Application.ListApplicationDataSet.as_view(), + name='application/dataset'), + path('application//access_token', views.Application.AccessToken.as_view(), + name='application/access_token'), + path('application//', views.Application.Page.as_view(), name='application_page'), + path('application//chat/open', 
views.ChatView.Open.as_view(), name='application/open'), + path("application/chat/open", views.ChatView.OpenTemp.as_view()), + path("application/chat_workflow/open", views.ChatView.OpenWorkFlowTemp.as_view()), + path("application//chat/client//", + views.ChatView.ClientChatHistoryPage.as_view()), + path("application//chat/client/", + views.ChatView.ClientChatHistoryPage.Operate.as_view()), + path('application//chat/export', views.ChatView.Export.as_view(), name='export'), + path('application//chat/completions', views.Openai.as_view(), + name='application/chat_completions'), + path('application//chat', views.ChatView.as_view(), name='chats'), + path('application//chat//', views.ChatView.Page.as_view()), + path('application//chat/', views.ChatView.Operate.as_view()), + path('application//chat//chat_record/', views.ChatView.ChatRecord.as_view()), + path('application//chat//chat_record//', + views.ChatView.ChatRecord.Page.as_view()), + path('application//chat//chat_record/', + views.ChatView.ChatRecord.Operate.as_view()), + path('application//chat//chat_record//vote', + views.ChatView.ChatRecord.Vote.as_view(), + name=''), + path( + 'application//chat//chat_record//dataset//document_id//improve', + views.ChatView.ChatRecord.Improve.as_view(), + name=''), + path('application//chat//chat_record//improve', + views.ChatView.ChatRecord.ChatRecordImprove.as_view()), + path('application/chat_message/', views.ChatView.Message.as_view(), name='application/message'), + path( + 'application//chat//chat_record//dataset//document_id//improve/', + views.ChatView.ChatRecord.Improve.Operate.as_view(), + name=''), + path('application//speech_to_text', views.Application.SpeechToText.as_view(), + name='application/audio'), + path('application//text_to_speech', views.Application.TextToSpeech.as_view(), + name='application/audio'), + path('application//work_flow_version', views.ApplicationVersionView.as_view()), + path('application//work_flow_version//', + 
views.ApplicationVersionView.Page.as_view()), + path('application//work_flow_version/', + views.ApplicationVersionView.Operate.as_view()), + path('application//play_demo_text', views.Application.PlayDemoText.as_view(), + name='application/audio') + +] diff --git a/src/MaxKB-1.7.2/apps/application/views/__init__.py b/src/MaxKB-1.7.2/apps/application/views/__init__.py new file mode 100644 index 0000000..24569c1 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/views/__init__.py @@ -0,0 +1,11 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2023/9/25 17:12 + @desc: +""" +from .application_views import * +from .chat_views import * +from .application_version_views import * diff --git a/src/MaxKB-1.7.2/apps/application/views/application_version_views.py b/src/MaxKB-1.7.2/apps/application/views/application_version_views.py new file mode 100644 index 0000000..105f280 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/views/application_version_views.py @@ -0,0 +1,89 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: application_version_views.py + @date:2024/10/15 16:49 + @desc: +""" +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.views import APIView + +from application.serializers.application_version_serializers import ApplicationVersionSerializer +from application.swagger_api.application_version_api import ApplicationVersionApi +from common.auth import has_permissions, TokenAuth +from common.constants.permission_constants import PermissionConstants, CompareConstants, ViewPermission, RoleConstants, \ + Permission, Group, Operate +from common.response import result + + +class ApplicationVersionView(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取应用列表", + operation_id="获取应用列表", + 
manual_parameters=ApplicationVersionApi.Query.get_request_params_api(), + responses=result.get_api_array_response(ApplicationVersionApi.get_response_body_api()), + tags=['应用/版本']) + @has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND) + def get(self, request: Request, application_id: str): + return result.success( + ApplicationVersionSerializer.Query( + data={'name': request.query_params.get('name'), 'user_id': request.user.id, + 'application_id': application_id}).list()) + + class Page(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="分页获取应用版本列表", + operation_id="分页获取应用版本列表", + manual_parameters=result.get_page_request_params( + ApplicationVersionApi.Query.get_request_params_api()), + responses=result.get_page_api_response(ApplicationVersionApi.get_response_body_api()), + tags=['应用/版本']) + @has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND) + def get(self, request: Request, application_id: str, current_page: int, page_size: int): + return result.success( + ApplicationVersionSerializer.Query( + data={'name': request.query_params.get('name'), 'user_id': request.user, + 'application_id': application_id}).page( + current_page, page_size)) + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取应用版本详情", + operation_id="获取应用版本详情", + manual_parameters=ApplicationVersionApi.Operate.get_request_params_api(), + responses=result.get_api_response(ApplicationVersionApi.get_response_body_api()), + tags=['应用/版本']) + @has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND) + def get(self, request: Request, application_id: str, work_flow_version_id: str): + return result.success( + ApplicationVersionSerializer.Operate( + data={'user_id': request.user, + 'application_id': application_id, 'work_flow_version_id': 
work_flow_version_id}).one()) + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改应用版本信息", + operation_id="修改应用版本信息", + manual_parameters=ApplicationVersionApi.Operate.get_request_params_api(), + request_body=ApplicationVersionApi.Edit.get_request_body_api(), + responses=result.get_api_response(ApplicationVersionApi.get_response_body_api()), + tags=['应用/版本']) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def put(self, request: Request, application_id: str, work_flow_version_id: str): + return result.success( + ApplicationVersionSerializer.Operate( + data={'application_id': application_id, 'work_flow_version_id': work_flow_version_id, + 'user_id': request.user.id}).edit( + request.data)) diff --git a/src/MaxKB-1.7.2/apps/application/views/application_views.py b/src/MaxKB-1.7.2/apps/application/views/application_views.py new file mode 100644 index 0000000..64b6c36 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/views/application_views.py @@ -0,0 +1,589 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: application_views.py + @date:2023/10/27 14:56 + @desc: +""" + +from django.core import cache +from django.http import HttpResponse +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser +from rest_framework.request import Request +from rest_framework.views import APIView + +from application.serializers.application_serializers import ApplicationSerializer +from application.serializers.application_statistics_serializers import ApplicationStatisticsSerializer +from application.swagger_api.application_api import ApplicationApi +from application.swagger_api.application_statistics_api import ApplicationStatisticsApi +from common.auth 
import TokenAuth, has_permissions +from common.constants.permission_constants import CompareConstants, PermissionConstants, Permission, Group, Operate, \ + ViewPermission, RoleConstants +from common.exception.app_exception import AppAuthenticationFailed +from common.response import result +from common.swagger_api.common_api import CommonApi +from common.util.common import query_params_to_single_dict +from dataset.serializers.dataset_serializers import DataSetSerializers +from setting.swagger_api.provide_api import ProvideApi + +chat_cache = cache.caches['chat_cache'] + + +class ApplicationStatistics(APIView): + class CustomerCount(APIView): + authentication_classes = [TokenAuth] + + @action(methods=["GET"], detail=False) + @swagger_auto_schema(operation_summary="用户统计", + operation_id="用户统计", + tags=["应用/统计"], + manual_parameters=ApplicationStatisticsApi.get_request_params_api(), + responses=result.get_api_response( + ApplicationStatisticsApi.CustomerCount.get_response_body_api()) + ) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success( + ApplicationStatisticsSerializer(data={'application_id': application_id, + 'start_time': request.query_params.get( + 'start_time'), + 'end_time': request.query_params.get( + 'end_time') + }).get_customer_count()) + + class CustomerCountTrend(APIView): + authentication_classes = [TokenAuth] + + @action(methods=["GET"], detail=False) + @swagger_auto_schema(operation_summary="用户统计趋势", + operation_id="用户统计趋势", + tags=["应用/统计"], + manual_parameters=ApplicationStatisticsApi.get_request_params_api(), + responses=result.get_api_array_response( + ApplicationStatisticsApi.CustomerCountTrend.get_response_body_api())) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, 
RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success( + ApplicationStatisticsSerializer(data={'application_id': application_id, + 'start_time': request.query_params.get( + 'start_time'), + 'end_time': request.query_params.get( + 'end_time') + }).get_customer_count_trend()) + + class ChatRecordAggregate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=["GET"], detail=False) + @swagger_auto_schema(operation_summary="对话相关统计", + operation_id="对话相关统计", + tags=["应用/统计"], + manual_parameters=ApplicationStatisticsApi.get_request_params_api(), + responses=result.get_api_response( + ApplicationStatisticsApi.ChatRecordAggregate.get_response_body_api()) + ) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success( + ApplicationStatisticsSerializer(data={'application_id': application_id, + 'start_time': request.query_params.get( + 'start_time'), + 'end_time': request.query_params.get( + 'end_time') + }).get_chat_record_aggregate()) + + class ChatRecordAggregateTrend(APIView): + authentication_classes = [TokenAuth] + + @action(methods=["GET"], detail=False) + @swagger_auto_schema(operation_summary="对话相关统计趋势", + operation_id="对话相关统计趋势", + tags=["应用/统计"], + manual_parameters=ApplicationStatisticsApi.get_request_params_api(), + responses=result.get_api_array_response( + ApplicationStatisticsApi.ChatRecordAggregate.get_response_body_api()) + ) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, 
operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success( + ApplicationStatisticsSerializer(data={'application_id': application_id, + 'start_time': request.query_params.get( + 'start_time'), + 'end_time': request.query_params.get( + 'end_time') + }).get_chat_record_aggregate_trend()) + + +class Application(APIView): + authentication_classes = [TokenAuth] + + class EditIcon(APIView): + authentication_classes = [TokenAuth] + parser_classes = [MultiPartParser] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改应用icon", + operation_id="修改应用icon", + tags=['应用'], + manual_parameters=ApplicationApi.EditApplicationIcon.get_request_params_api(), + request_body=ApplicationApi.Operate.get_request_body_api()) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND), PermissionConstants.APPLICATION_EDIT, + compare=CompareConstants.AND) + def put(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.IconOperate( + data={'application_id': application_id, 'user_id': request.user.id, + 'image': request.FILES.get('file')}).edit(request.data)) + + class Embed(APIView): + @action(methods=["GET"], detail=False) + @swagger_auto_schema(operation_summary="获取嵌入js", + operation_id="获取嵌入js", + tags=["应用"], + manual_parameters=ApplicationApi.ApiKey.get_request_params_api()) + def get(self, request: Request): + return ApplicationSerializer.Embed( + data={'protocol': request.query_params.get('protocol'), 'token': request.query_params.get('token'), + 'host': request.query_params.get('host'), }).get_embed(params=request.query_params) + + class Model(APIView): + authentication_classes = [TokenAuth] + + 
@action(methods=["GET"], detail=False) + @swagger_auto_schema(operation_summary="获取模型列表", + operation_id="获取模型列表", + tags=["应用"], + manual_parameters=ApplicationApi.Model.get_request_params_api()) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.Operate( + data={'application_id': application_id, + 'user_id': request.user.id}).list_model(request.query_params.get('model_type'))) + + class ModelParamsForm(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取模型参数表单", + operation_id="获取模型参数表单", + tags=["模型"]) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str, model_id: str): + return result.success( + ApplicationSerializer.Operate( + data={'application_id': application_id, + 'user_id': request.user.id}).get_model_params_form(model_id)) + + class FunctionLib(APIView): + authentication_classes = [TokenAuth] + + @action(methods=["GET"], detail=False) + @swagger_auto_schema(operation_summary="获取函数库列表", + operation_id="获取函数库列表", + tags=["应用"]) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.Operate( + data={'application_id': application_id, + 'user_id': 
request.user.id}).list_function_lib()) + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=["GET"], detail=False) + @swagger_auto_schema(operation_summary="获取函数库列表", + operation_id="获取函数库列表", + tags=["应用"], + ) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str, function_lib_id: str): + return result.success( + ApplicationSerializer.Operate( + data={'application_id': application_id, + 'user_id': request.user.id}).get_function_lib(function_lib_id)) + + class Profile(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取应用相关信息", + operation_id="获取应用相关信息", + tags=["应用/会话"]) + def get(self, request: Request): + if 'application_id' in request.auth.keywords: + return result.success(ApplicationSerializer.Operate( + data={'application_id': request.auth.keywords.get('application_id'), + 'user_id': request.user.id}).profile()) + raise AppAuthenticationFailed(401, "身份异常") + + class ApplicationKey(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="新增ApiKey", + operation_id="新增ApiKey", + tags=['应用/API_KEY'], + manual_parameters=ApplicationApi.ApiKey.get_request_params_api()) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def post(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.ApplicationKeySerializer( + data={'application_id': application_id, 'user_id': request.user.id}).generate()) + + 
@action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取应用API_KEY列表", + operation_id="获取应用API_KEY列表", + tags=['应用/API_KEY'], + manual_parameters=ApplicationApi.ApiKey.get_request_params_api() + ) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success(ApplicationSerializer.ApplicationKeySerializer( + data={'application_id': application_id, 'user_id': request.user.id}).list()) + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改应用API_KEY", + operation_id="修改应用API_KEY", + tags=['应用/API_KEY'], + manual_parameters=ApplicationApi.ApiKey.Operate.get_request_params_api(), + request_body=ApplicationApi.ApiKey.Operate.get_request_body_api()) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND), PermissionConstants.APPLICATION_EDIT, + compare=CompareConstants.AND) + def put(self, request: Request, application_id: str, api_key_id: str): + return result.success( + ApplicationSerializer.ApplicationKeySerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id, + 'api_key_id': api_key_id}).edit(request.data)) + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="删除应用API_KEY", + operation_id="删除应用API_KEY", + tags=['应用/API_KEY'], + manual_parameters=ApplicationApi.ApiKey.Operate.get_request_params_api()) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, 
operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND), PermissionConstants.APPLICATION_DELETE, + compare=CompareConstants.AND) + def delete(self, request: Request, application_id: str, api_key_id: str): + return result.success( + ApplicationSerializer.ApplicationKeySerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id, + 'api_key_id': api_key_id}).delete()) + + class AccessToken(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改 应用AccessToken", + operation_id="修改 应用AccessToken", + tags=['应用/公开访问'], + manual_parameters=ApplicationApi.AccessToken.get_request_params_api(), + request_body=ApplicationApi.AccessToken.get_request_body_api()) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def put(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).edit(request.data)) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取应用 AccessToken信息", + operation_id="获取应用 AccessToken信息", + manual_parameters=ApplicationApi.AccessToken.get_request_params_api(), + tags=['应用/公开访问'], + ) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.AccessTokenSerializer(data={'application_id': application_id}).one()) + + class Authentication(APIView): + @action(methods=['OPTIONS'], detail=False) + def 
options(self, request, *args, **kwargs): + return HttpResponse(headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", + "Access-Control-Allow-Methods": "POST", + "Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"}, ) + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="应用认证", + operation_id="应用认证", + request_body=ApplicationApi.Authentication.get_request_body_api(), + tags=["应用/认证"], + security=[]) + def post(self, request: Request): + return result.success( + ApplicationSerializer.Authentication(data={'access_token': request.data.get("access_token"), + 'authentication_value': request.data.get( + 'authentication_value')}).auth( + request), + headers={"Access-Control-Allow-Origin": "*", "Access-Control-Allow-Credentials": "true", + "Access-Control-Allow-Methods": "POST", + "Access-Control-Allow-Headers": "Origin,Content-Type,Cookie,Accept,Token"} + ) + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建应用", + operation_id="创建应用", + request_body=ApplicationApi.Create.get_request_body_api(), + tags=['应用']) + @has_permissions(PermissionConstants.APPLICATION_CREATE, compare=CompareConstants.AND) + def post(self, request: Request): + return result.success(ApplicationSerializer.Create(data={'user_id': request.user.id}).insert(request.data)) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取应用列表", + operation_id="获取应用列表", + manual_parameters=ApplicationApi.Query.get_request_params_api(), + responses=result.get_api_array_response(ApplicationApi.get_response_body_api()), + tags=['应用']) + @has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND) + def get(self, request: Request): + return result.success( + ApplicationSerializer.Query( + data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list()) + + class HitTest(APIView): + authentication_classes = 
@action(methods=["GET"], detail=False)  # a list, as in every other view; a bare string is iterated per character
@swagger_auto_schema(operation_summary="命中测试列表", operation_id="命中测试列表",
                     manual_parameters=CommonApi.HitTestApi.get_request_params_api(),
                     responses=result.get_api_array_response(CommonApi.HitTestApi.get_response_body_api()),
                     tags=["应用"])
@has_permissions(ViewPermission(
    [RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN,
     RoleConstants.APPLICATION_KEY],
    [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                    dynamic_tag=keywords.get('application_id'))],
    compare=CompareConstants.AND))
def get(self, request: Request, application_id: str):
    """
    Run a hit test for one application.

    The search arguments (``query_text``, ``top_number``, ``similarity``,
    ``search_mode``) are read from the query string and passed, together with
    the application id and the current user id, to
    ``ApplicationSerializer.HitTest``.

    :param request: current request
    :param application_id: application id taken from the URL
    :return: ``result.success`` wrapping the serializer's ``hit_test()`` output
    """
    return result.success(
        ApplicationSerializer.HitTest(data={'id': application_id, 'user_id': request.user.id,
                                            "query_text": request.query_params.get("query_text"),
                                            "top_number": request.query_params.get("top_number"),
                                            'similarity': request.query_params.get('similarity'),
                                            'search_mode': request.query_params.get('search_mode')}).hit_test(
        ))
@swagger_auto_schema(operation_summary="删除应用", + operation_id="删除应用", + manual_parameters=ApplicationApi.Operate.get_request_params_api(), + responses=result.get_default_response(), + tags=['应用']) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND), + lambda r, k: Permission(group=Group.APPLICATION, operate=Operate.DELETE, + dynamic_tag=k.get('application_id')), compare=CompareConstants.AND) + def delete(self, request: Request, application_id: str): + return result.success(ApplicationSerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id}).delete( + with_valid=True)) + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改应用", + operation_id="修改应用", + manual_parameters=ApplicationApi.Operate.get_request_params_api(), + request_body=ApplicationApi.Edit.get_request_body_api(), + responses=result.get_api_array_response(ApplicationApi.get_response_body_api()), + tags=['应用']) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def put(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.Operate(data={'application_id': application_id, 'user_id': request.user.id}).edit( + request.data)) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取应用详情", + operation_id="获取应用详情", + manual_parameters=ApplicationApi.Operate.get_request_params_api(), + responses=result.get_api_array_response(ApplicationApi.get_response_body_api()), + tags=['应用']) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN, + 
RoleConstants.APPLICATION_KEY], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success(ApplicationSerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id}).one()) + + class ListApplicationDataSet(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取当前应用可使用的知识库", + operation_id="获取当前应用可使用的知识库", + manual_parameters=ApplicationApi.Operate.get_request_params_api(), + responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()), + tags=['应用']) + @has_permissions(ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, application_id: str): + return result.success(ApplicationSerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id}).list_dataset()) + + class Page(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="分页获取应用列表", + operation_id="分页获取应用列表", + manual_parameters=result.get_page_request_params( + ApplicationApi.Query.get_request_params_api()), + responses=result.get_page_api_response(ApplicationApi.get_response_body_api()), + tags=['应用']) + @has_permissions(PermissionConstants.APPLICATION_READ, compare=CompareConstants.AND) + def get(self, request: Request, current_page: int, page_size: int): + return result.success( + ApplicationSerializer.Query( + data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).page( + current_page, page_size)) + + class SpeechToText(APIView): + authentication_classes = 
[TokenAuth] + + @action(methods=['POST'], detail=False) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN], + [lambda r, keywords: Permission(group=Group.APPLICATION, + operate=Operate.USE, + dynamic_tag=keywords.get( + 'application_id'))], + compare=CompareConstants.AND)) + def post(self, request: Request, application_id: str): + return result.success( + ApplicationSerializer.Operate(data={'application_id': application_id, 'user_id': request.user.id}) + .speech_to_text(request.FILES.getlist('file')[0])) + + class TextToSpeech(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN], + [lambda r, keywords: Permission(group=Group.APPLICATION, + operate=Operate.USE, + dynamic_tag=keywords.get( + 'application_id'))], + compare=CompareConstants.AND)) + def post(self, request: Request, application_id: str): + byte_data = ApplicationSerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id}).text_to_speech( + request.data.get('text')) + return HttpResponse(byte_data, status=200, headers={'Content-Type': 'audio/mp3', + 'Content-Disposition': 'attachment; filename="abc.mp3"'}) + + class PlayDemoText(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_ACCESS_TOKEN], + [lambda r, keywords: Permission(group=Group.APPLICATION, + operate=Operate.USE, + dynamic_tag=keywords.get( + 'application_id'))], + compare=CompareConstants.AND)) + def post(self, request: Request, application_id: str): + byte_data = ApplicationSerializer.Operate( + data={'application_id': application_id, 'user_id': request.user.id}).play_demo_text(request.data) + return HttpResponse(byte_data, status=200, 
headers={'Content-Type': 'audio/mp3', + 'Content-Disposition': 'attachment; filename="abc.mp3"'}) diff --git a/src/MaxKB-1.7.2/apps/application/views/chat_views.py b/src/MaxKB-1.7.2/apps/application/views/chat_views.py new file mode 100644 index 0000000..922bbfc --- /dev/null +++ b/src/MaxKB-1.7.2/apps/application/views/chat_views.py @@ -0,0 +1,393 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: chat_views.py + @date:2023/11/14 9:53 + @desc: +""" + +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.views import APIView + +from application.serializers.chat_message_serializers import ChatMessageSerializer, OpenAIChatSerializer +from application.serializers.chat_serializers import ChatSerializers, ChatRecordSerializer +from application.swagger_api.chat_api import ChatApi, VoteApi, ChatRecordApi, ImproveApi, ChatRecordImproveApi, \ + ChatClientHistoryApi, OpenAIChatApi +from common.auth import TokenAuth, has_permissions, OpenAIKeyAuth +from common.constants.authentication_type import AuthenticationType +from common.constants.permission_constants import Permission, Group, Operate, \ + RoleConstants, ViewPermission, CompareConstants +from common.response import result +from common.util.common import query_params_to_single_dict + + +class Openai(APIView): + authentication_classes = [OpenAIKeyAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="openai接口对话", + operation_id="openai接口对话", + request_body=OpenAIChatApi.get_request_body_api(), + tags=["openai对话"]) + def post(self, request: Request, application_id: str): + return OpenAIChatSerializer(data={'application_id': application_id, 'client_id': request.auth.client_id, + 'client_type': request.auth.client_type}).chat(request.data) + + +class ChatView(APIView): + authentication_classes = [TokenAuth] + + class Export(APIView): + authentication_classes = [TokenAuth] 
@action(methods=['POST'], detail=False)  # the handler below is ``post`` and reads the request body
@swagger_auto_schema(operation_summary="导出对话",
                     operation_id="导出对话",
                     manual_parameters=ChatApi.get_request_params_api(),
                     tags=["应用/对话日志"]
                     )
@has_permissions(
    ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY],
                   [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE,
                                                   dynamic_tag=keywords.get('application_id'))])
)
def post(self, request: Request, application_id: str):
    """
    Export the conversations of one application.

    The query-string filters are merged with the application id and the
    current user id to build ``ChatSerializers.Query``; the request body is
    handed to its ``export`` method, whose return value is returned as-is
    (it is not wrapped in ``result.success``).

    :param request: current request
    :param application_id: application id taken from the URL
    :return: whatever ``ChatSerializers.Query.export`` returns
    """
    return ChatSerializers.Query(
        data={**query_params_to_single_dict(request.query_params), 'application_id': application_id,
              'user_id': request.user.id}).export(request.data)
detail=False) + @swagger_auto_schema(operation_summary="获取会话id(根据模型id,知识库列表,是否多轮会话)", + operation_id="获取会话id", + request_body=ChatApi.OpenTempChat.get_request_body_api(), + tags=["应用/会话"]) + @has_permissions(RoleConstants.ADMIN, RoleConstants.USER) + def post(self, request: Request): + return result.success(ChatSerializers.OpenTempChat( + data={**request.data, 'user_id': request.user.id}).open()) + + class Message(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="对话", + operation_id="对话", + request_body=ChatApi.get_request_body_api(), + tags=["应用/会话"]) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, + RoleConstants.APPLICATION_ACCESS_TOKEN], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def post(self, request: Request, chat_id: str): + return ChatMessageSerializer(data={'chat_id': chat_id, 'message': request.data.get('message'), + 're_chat': (request.data.get( + 're_chat') if 're_chat' in request.data else False), + 'stream': (request.data.get( + 'stream') if 'stream' in request.data else True), + 'application_id': (request.auth.keywords.get( + 'application_id') if request.auth.client_type == AuthenticationType.APPLICATION_ACCESS_TOKEN.value else None), + 'client_id': request.auth.client_id, + 'form_data': (request.data.get( + 'form_data') if 'form_data' in request.data else {}), + 'client_type': request.auth.client_type}).chat() + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取对话列表", + operation_id="获取对话列表", + manual_parameters=ChatApi.get_request_params_api(), + responses=result.get_api_array_response(ChatApi.get_response_body_api()), + tags=["应用/对话日志"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], + [lambda r, keywords: 
Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def get(self, request: Request, application_id: str): + return result.success(ChatSerializers.Query( + data={**query_params_to_single_dict(request.query_params), 'application_id': application_id, + 'user_id': request.user.id}).list()) + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="删除对话", + operation_id="删除对话", + tags=["应用/对话日志"]) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.MANAGE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND), + compare=CompareConstants.AND) + def delete(self, request: Request, application_id: str, chat_id: str): + return result.success( + ChatSerializers.Operate( + data={'application_id': application_id, 'user_id': request.user.id, + 'chat_id': chat_id}).delete()) + + class ClientChatHistoryPage(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="分页获取客户端对话列表", + operation_id="分页获取客户端对话列表", + manual_parameters=result.get_page_request_params( + ChatClientHistoryApi.get_request_params_api()), + responses=result.get_page_api_response(ChatApi.get_response_body_api()), + tags=["应用/对话日志"] + ) + @has_permissions( + ViewPermission([RoleConstants.APPLICATION_ACCESS_TOKEN], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def get(self, request: Request, application_id: str, current_page: int, page_size: int): + return result.success(ChatSerializers.ClientChatHistory( + data={'client_id': request.auth.client_id, 'application_id': application_id}).page( + current_page=current_page, + page_size=page_size)) + + class Operate(APIView): + 
authentication_classes = [TokenAuth] + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="客户端删除对话", + operation_id="客户端删除对话", + tags=["应用/对话日志"]) + @has_permissions(ViewPermission( + [RoleConstants.APPLICATION_ACCESS_TOKEN], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + compare=CompareConstants.AND), + compare=CompareConstants.AND) + def delete(self, request: Request, application_id: str, chat_id: str): + return result.success( + ChatSerializers.Operate( + data={'application_id': application_id, 'user_id': request.user.id, + 'chat_id': chat_id}).logic_delete()) + + class Page(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="分页获取对话列表", + operation_id="分页获取对话列表", + manual_parameters=result.get_page_request_params(ChatApi.get_request_params_api()), + responses=result.get_page_api_response(ChatApi.get_response_body_api()), + tags=["应用/对话日志"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def get(self, request: Request, application_id: str, current_page: int, page_size: int): + return result.success(ChatSerializers.Query( + data={**query_params_to_single_dict(request.query_params), 'application_id': application_id, + 'user_id': request.user.id}).page(current_page=current_page, + page_size=page_size)) + + class ChatRecord(APIView): + authentication_classes = [TokenAuth] + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取对话记录详情", + operation_id="获取对话记录详情", + manual_parameters=ChatRecordApi.get_request_params_api(), + 
responses=result.get_api_array_response(ChatRecordApi.get_response_body_api()), + tags=["应用/对话日志"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, + RoleConstants.APPLICATION_ACCESS_TOKEN], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str): + return result.success(ChatRecordSerializer.Operate( + data={'application_id': application_id, + 'chat_id': chat_id, + 'chat_record_id': chat_record_id}).one(request.auth.current_role)) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取对话记录列表", + operation_id="获取对话记录列表", + manual_parameters=ChatRecordApi.get_request_params_api(), + responses=result.get_api_array_response(ChatRecordApi.get_response_body_api()), + tags=["应用/对话日志"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def get(self, request: Request, application_id: str, chat_id: str): + return result.success(ChatRecordSerializer.Query( + data={'application_id': application_id, + 'chat_id': chat_id, 'order_asc': request.query_params.get('order_asc')}).list()) + + class Page(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取对话记录列表", + operation_id="获取对话记录列表", + manual_parameters=result.get_page_request_params( + ChatRecordApi.get_request_params_api()), + responses=result.get_page_api_response(ChatRecordApi.get_response_body_api()), + tags=["应用/对话日志"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY], + [lambda r, keywords: Permission(group=Group.APPLICATION, 
operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def get(self, request: Request, application_id: str, chat_id: str, current_page: int, page_size: int): + return result.success(ChatRecordSerializer.Query( + data={'application_id': application_id, + 'chat_id': chat_id, 'order_asc': request.query_params.get('order_asc')}).page(current_page, + page_size)) + + class Vote(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="点赞,点踩", + operation_id="点赞,点踩", + manual_parameters=VoteApi.get_request_params_api(), + request_body=VoteApi.get_request_body_api(), + responses=result.get_default_response(), + tags=["应用/会话"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER, RoleConstants.APPLICATION_KEY, + RoleConstants.APPLICATION_ACCESS_TOKEN], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))]) + ) + def put(self, request: Request, application_id: str, chat_id: str, chat_record_id: str): + return result.success(ChatRecordSerializer.Vote( + data={'vote_status': request.data.get('vote_status'), 'chat_id': chat_id, + 'chat_record_id': chat_record_id}).vote()) + + class ChatRecordImprove(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取标注段落列表信息", + operation_id="获取标注段落列表信息", + manual_parameters=ChatRecordImproveApi.get_request_params_api(), + responses=result.get_api_response(ChatRecordImproveApi.get_response_body_api()), + tags=["应用/对话日志/标注"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))] + )) + def get(self, request: Request, application_id: str, chat_id: str, chat_record_id: str): + return 
result.success(ChatRecordSerializer.ChatRecordImprove( + data={'chat_id': chat_id, 'chat_record_id': chat_record_id}).get()) + + class Improve(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="标注", + operation_id="标注", + manual_parameters=ImproveApi.get_request_params_api(), + request_body=ImproveApi.get_request_body_api(), + responses=result.get_api_response(ChatRecordApi.get_response_body_api()), + tags=["应用/对话日志/标注"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + + ), ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.DATASET, + operate=Operate.MANAGE, + dynamic_tag=keywords.get( + 'dataset_id'))], + compare=CompareConstants.AND + ), compare=CompareConstants.AND) + def put(self, request: Request, application_id: str, chat_id: str, chat_record_id: str, dataset_id: str, + document_id: str): + return result.success(ChatRecordSerializer.Improve( + data={'chat_id': chat_id, 'chat_record_id': chat_record_id, + 'dataset_id': dataset_id, 'document_id': document_id}).improve(request.data)) + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="标注", + operation_id="标注", + manual_parameters=ImproveApi.get_request_params_api(), + responses=result.get_api_response(ChatRecordApi.get_response_body_api()), + tags=["应用/对话日志/标注"] + ) + @has_permissions( + ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.APPLICATION, operate=Operate.USE, + dynamic_tag=keywords.get('application_id'))], + + ), ViewPermission([RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.DATASET, + operate=Operate.MANAGE, + 
dynamic_tag=keywords.get( + 'dataset_id'))], + compare=CompareConstants.AND + ), compare=CompareConstants.AND) + def delete(self, request: Request, application_id: str, chat_id: str, chat_record_id: str, + dataset_id: str, + document_id: str, paragraph_id: str): + return result.success(ChatRecordSerializer.Improve.Operate( + data={'chat_id': chat_id, 'chat_record_id': chat_record_id, + 'dataset_id': dataset_id, 'document_id': document_id, + 'paragraph_id': paragraph_id}).delete()) diff --git a/src/MaxKB-1.7.2/apps/common/__init__.py b/src/MaxKB-1.7.2/apps/common/__init__.py new file mode 100644 index 0000000..75ce08f --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: smart-doc + @Author:虎 + @file: __init__.py + @date:2023/9/14 16:22 + @desc: +""" diff --git a/src/MaxKB-1.7.2/apps/common/auth/__init__.py b/src/MaxKB-1.7.2/apps/common/auth/__init__.py new file mode 100644 index 0000000..ca866ce --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/auth/__init__.py @@ -0,0 +1,10 @@ +# coding=utf-8 +""" + @project: smart-doc + @Author:虎 + @file: __init__.py + @date:2023/9/14 19:44 + @desc: +""" +from .authenticate import * +from .authentication import * diff --git a/src/MaxKB-1.7.2/apps/common/auth/authenticate.py b/src/MaxKB-1.7.2/apps/common/auth/authenticate.py new file mode 100644 index 0000000..3d54d47 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/auth/authenticate.py @@ -0,0 +1,95 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: authenticate.py + @date:2023/9/4 11:16 + @desc: 认证类 +""" +import traceback +from importlib import import_module + +from django.conf import settings +from django.core import cache +from django.core import signing +from rest_framework.authentication import TokenAuthentication + +from common.exception.app_exception import AppAuthenticationFailed, AppEmbedIdentityFailed, AppChatNumOutOfBoundsFailed, \ + ChatException, AppApiException + +token_cache = 
class AnonymousAuthentication(TokenAuthentication):
    """Authentication class that never identifies a caller."""

    def authenticate(self, request):
        """Return an empty ``(user, auth)`` pair for every request."""
        return None, None


def new_instance_by_class_path(class_path: str):
    """
    Import the class named by a dotted path and return a new instance of it.

    :param class_path: dotted path; everything before the last dot is the
                       module, the last component is the class name
    :return: the class called with no arguments
    """
    package_path, _, class_name = class_path.rpartition('.')
    module = import_module(package_path)
    handler_class = getattr(module, class_name)
    return handler_class()


# One handler instance per entry of settings.AUTH_HANDLES, created at import time.
handles = [new_instance_by_class_path(class_path) for class_path in settings.AUTH_HANDLES]


class TokenDetails:
    """Lazily decoded payload of a signed token."""
    # Decoded payload; stays None until a decode succeeds.
    token_details = None
    # Set to True once a decode attempt has failed, so it is not retried.
    is_load = False

    def __init__(self, token: str):
        self.token = token

    def get_token_details(self):
        """
        Decode the token with ``signing.loads`` on first use and cache the result.

        :return: the decoded payload, or None when decoding failed
        """
        if self.token_details is None and not self.is_load:
            try:
                self.token_details = signing.loads(self.token)
            except Exception:
                # Remember the failure; callers only see None.
                self.is_load = True
        return self.token_details


def _authenticate_with_handles(request, auth):
    """
    Shared authentication flow used by OpenAIKeyAuth and TokenAuth.

    :param request: current request
    :param auth: credential string taken from the Authorization header, or None
    :return: the result of the first handler whose ``support`` accepts the request
    :raises AppAuthenticationFailed: code 1003 when ``auth`` is None; code 1002
            when no handler accepts the credential or a handler fails with an
            error other than the three re-raised below
    """
    # Not authenticated: no credential was sent at all.
    if auth is None:
        raise AppAuthenticationFailed(1003, '未登录,请先登录')
    try:
        token_details = TokenDetails(auth)
        for handle in handles:
            if handle.support(request, auth, token_details.get_token_details):
                return handle.handle(request, auth, token_details.get_token_details)
        raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")
    except Exception as e:
        # These three error types are passed through unchanged; every other
        # failure is reported to the caller as a generic 1002 failure.
        if isinstance(e, (AppEmbedIdentityFailed, AppChatNumOutOfBoundsFailed, AppApiException)):
            raise
        raise AppAuthenticationFailed(1002, "身份验证信息不正确!非法用户")


class OpenAIKeyAuth(TokenAuthentication):
    """Authentication for the OpenAI-style endpoint; accepts ``Bearer <key>`` headers."""

    def authenticate(self, request):
        """
        Authenticate using the Authorization header with ``'Bearer '`` removed.

        A missing header now produces the 1003 authentication failure; before,
        the ``replace`` call ran ahead of the None check and raised
        AttributeError.
        """
        auth = request.META.get('HTTP_AUTHORIZATION')
        if auth is not None:
            auth = auth.replace('Bearer ', '')
        return _authenticate_with_handles(request, auth)


class TokenAuth(TokenAuthentication):
    """Authentication that uses the raw Authorization header value as the credential."""

    # Override authenticate to apply the custom authentication rules.
    def authenticate(self, request):
        """Authenticate using the Authorization header exactly as sent."""
        auth = request.META.get('HTTP_AUTHORIZATION')
        return _authenticate_with_handles(request, auth)
def exist_permissions_by_permission_constants(user_permission: List[PermissionConstants],
                                              permission_list: List[PermissionConstants]):
    """
    Check whether the user holds at least one of the required permissions.

    :param user_permission: permissions held by the user
    :param permission_list: permissions that grant access
    :return: True when any entry of ``user_permission`` is contained in
        ``permission_list``, otherwise False
    """
    # A generator lets any() stop at the first match instead of building the
    # whole list of booleans first.
    return any(held in permission_list for held in user_permission)
def has_permissions(*permission, compare=CompareConstants.OR):
    """
    Build a decorator that guards a view method with role/permission checks.

    :param permission: one or more requirements; each one is evaluated by
        ``exist`` against ``request.auth.role_list`` and
        ``request.auth.permission_list``
    :param compare: ``CompareConstants.OR`` (default) grants access when any
        requirement is met; any other value requires all of them
    :return: decorator; the decorated method raises ``AppUnauthorizedFailed``
        (403) when the check fails
    """
    from functools import wraps

    def inner(func):
        # wraps() keeps the view method's __name__/__doc__/attributes visible on
        # the wrapper, so tooling that introspects the view still sees them.
        @wraps(func)
        def run(view, request, **kwargs):
            # Every requirement is evaluated (no short-circuit), as before.
            exit_list = [exist(request.auth.role_list, request.auth.permission_list, p, request, **kwargs)
                         for p in permission]
            # Decide whether access is granted.
            if any(exit_list) if compare == CompareConstants.OR else all(exit_list):
                return func(view, request, **kwargs)
            raise AppUnauthorizedFailed(403, "没有权限访问")

        return run

    return inner
class AuthBaseHandle(ABC):
    """Contract implemented by every authentication handler.

    Both methods receive the request, the raw token string and
    ``get_token_details``, a callable that yields the decoded token payload.
    """

    @abstractmethod
    def support(self, request, token: str, get_token_details):
        """Tell whether this handler is responsible for the given token."""

    @abstractmethod
    def handle(self, request, token: str, get_token_details):
        """Authenticate the request with the given token."""
class PublicAccessToken(AuthBaseHandle):
    """Authentication handler for public application access links."""

    def support(self, request, token: str, get_token_details):
        """Accept decoded tokens that carry an application id, an access token
        and the ``APPLICATION_ACCESS_TOKEN`` type marker."""
        token_details = get_token_details()
        if token_details is None:
            return False
        return (
                'application_id' in token_details and
                'access_token' in token_details and
                token_details.get('type') == AuthenticationType.APPLICATION_ACCESS_TOKEN.value)

    def handle(self, request, token: str, get_token_details):
        """Validate the access token and build the caller's identity.

        :return: ``(user, Auth)`` where ``user`` owns the application and
            ``Auth`` grants the ``USE`` permission on that application
        :raises AppAuthenticationFailed: unknown application, inactive token or
            access-token mismatch
        :raises ChatException: the application requires extra authentication
            and the token payload does not match it
        """
        auth_details = get_token_details()
        application_access_token = QuerySet(ApplicationAccessToken).filter(
            application_id=auth_details.get('application_id')).first()
        # Reject unknown applications before anything dereferences the row; the
        # password check below used to hit AttributeError on None.
        if application_access_token is None:
            raise AppAuthenticationFailed(1002, "身份验证信息不正确")
        if request.path != '/api/application/profile':
            application_setting_model = DBModelManage.get_model('application_setting')
            xpack_cache = DBModelManage.get_model('xpack_cache')
            license_is_valid = False if xpack_cache is None else xpack_cache.get('XPACK_LICENSE_IS_VALID', False)
            if application_setting_model is not None and license_is_valid:
                application_setting = QuerySet(application_setting_model).filter(application_id=str(
                    application_access_token.application_id)).first()
                # NOTE(review): application_setting is None when no settings row
                # exists, and the attribute access below then fails. Decide
                # explicitly whether that case should pass or be rejected.
                if application_setting.authentication:
                    authentication = auth_details.get('authentication', {})
                    if authentication is None:
                        authentication = {}
                    expected = application_setting.authentication_value
                    if expected.get('type') != authentication.get('type') or password_encrypt(
                            expected.get('value')) != authentication.get('value'):
                        raise ChatException(1002, "身份验证信息不正确")
        if not application_access_token.is_active:
            raise AppAuthenticationFailed(1002, "身份验证信息不正确")
        if application_access_token.access_token != auth_details.get('access_token'):
            raise AppAuthenticationFailed(1002, "身份验证信息不正确")

        return application_access_token.application.user, Auth(
            role_list=[RoleConstants.APPLICATION_ACCESS_TOKEN],
            permission_list=[
                Permission(group=Group.APPLICATION,
                           operate=Operate.USE,
                           dynamic_tag=str(
                               application_access_token.application_id))],
            application_id=application_access_token.application_id,
            client_id=auth_details.get('client_id'),
            client_type=AuthenticationType.APPLICATION_ACCESS_TOKEN.value,
            current_role=RoleConstants.APPLICATION_ACCESS_TOKEN
        )
class UserToken(AuthBaseHandle):
    """Authentication handler for tokens issued to logged-in users."""

    def support(self, request, token: str, get_token_details):
        """Accept decoded tokens that carry an ``id`` and the ``USER`` type marker."""
        details = get_token_details()
        return (details is not None
                and 'id' in details
                and details.get('type') == AuthenticationType.USER.value)

    def handle(self, request, token: str, get_token_details):
        """Resolve the user behind a cached token and assemble its permissions.

        :return: ``(user, Auth)`` for the user whose id is stored in the token
        :raises AppAuthenticationFailed: 1002 when the token is not in the token cache
        """
        if token_cache.get(token) is None:
            raise AppAuthenticationFailed(1002, "登录过期")
        details = get_token_details()
        user = QuerySet(User).get(id=details['id'])
        # Renew the token's lifetime in the cache.
        token_cache.touch(token, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA'].total_seconds())
        role = RoleConstants[user.role]
        permissions = get_permission_list_by_role(role)
        # Add the user's per-application and per-dataset permissions.
        permissions.extend(get_user_dynamics_permission(str(user.id)))
        return user, Auth(role_list=[role],
                          permission_list=permissions,
                          client_id=str(user.id),
                          client_type=AuthenticationType.USER.value,
                          current_role=role)
class FileCache(BaseCache):
    """Cache backend that stores entries on disk through ``diskcache.Cache``.

    Keys are stored as ``"<version>:<key>"``, or ``"default:<key>"`` when no
    version is given (see :meth:`get_key`).
    """

    def __init__(self, dir, params):
        super().__init__(params)
        self._dir = os.path.abspath(dir)
        self._createdir()
        self.cache = Cache(self._dir)

    def _createdir(self):
        # Create the directory with owner-only access; the umask is tightened
        # for the call and always restored afterwards.
        old_umask = os.umask(0o077)
        try:
            os.makedirs(self._dir, 0o700, exist_ok=True)
        finally:
            os.umask(old_umask)

    @staticmethod
    def _to_expire(timeout):
        """Convert a timeout to seconds.

        ``None`` and plain numbers are passed through; anything else is
        expected to offer ``total_seconds()`` (e.g. ``datetime.timedelta``).
        """
        if timeout is None or isinstance(timeout, (int, float)):
            return timeout
        return timeout.total_seconds()

    def add(self, key, value, timeout=None, version=None):
        return self.cache.add(self.get_key(key, version), value=value, expire=self._to_expire(timeout))

    def set(self, key, value, timeout=None, version=None):
        return self.cache.set(self.get_key(key, version), value=value, expire=self._to_expire(timeout))

    def get(self, key, default=None, version=None):
        return self.cache.get(self.get_key(key, version), default=default)

    @staticmethod
    def get_key(key, version):
        """Build the storage key: ``"<version>:<key>"`` or ``"default:<key>"``."""
        if version is None:
            return f"default:{key}"
        return f"{version}:{key}"

    def delete(self, key, version=None):
        return self.cache.delete(self.get_key(key, version))

    def touch(self, key, timeout=None, version=None):
        # timeout=None is the declared default; it used to crash on
        # None.total_seconds(). It is now forwarded as expire=None.
        return self.cache.touch(self.get_key(key, version), expire=self._to_expire(timeout))

    def ttl(self, key, version=None):
        """
        Return the remaining lifetime of a key.

        :param key: key
        :param version: key version / namespace
        :return: remaining time as ``datetime.timedelta``; ``None`` when no
            value is stored or the entry has no expiry time
        """
        value, expire_time = self.cache.get(self.get_key(key, version), expire_time=True)
        # An entry stored without expiry reports expire_time=None, which used to
        # raise TypeError in the subtraction below.
        if value is None or expire_time is None:
            return None
        return datetime.timedelta(seconds=math.ceil(expire_time - time.time()))

    def clear_by_application_id(self, application_id):
        """Delete every entry whose value has ``application.id`` equal to ``application_id``."""
        delete_keys = []
        for key in self.cache.iterkeys():
            value = self.cache.get(key)
            application = getattr(value, 'application', None)
            if (application is not None and application.id is not None
                    and str(application.id) == application_id):
                delete_keys.append(key)
        # Deletion happens after the scan so the key iterator is not disturbed.
        for key in delete_keys:
            self.cache.delete(key)

    def clear_timeout_data(self):
        """Remove expired entries from the disk cache."""
        # The former loop called self.get() with keys that were already
        # prefixed, so every lookup missed and nothing was ever evicted.
        self.cache.expire()
@get_cache(cache_key=lambda access_token, use_get_data: access_token,
           use_get_data=lambda access_token, use_get_data: use_get_data,
           version=CacheCodeConstants.APPLICATION_ACCESS_TOKEN_CACHE.value)
def get_application_access_token(access_token, use_get_data):
    """Look up the whitelist settings and display data for an access token.

    :param access_token: access token value to look up
    :param use_get_data: not read here; it is handed to the ``get_cache``
        decorator through the lambdas above
    :return: dict with ``white_active``, ``white_list``, ``application_icon``
        and ``application_name``, or ``None`` when the token is unknown
    """
    record = QuerySet(ApplicationAccessToken).filter(access_token=access_token).first()
    if record is None:
        return None
    result = {'white_active': record.white_active,
              'white_list': record.white_list}
    application = record.application
    result['application_icon'] = application.icon
    result['application_name'] = application.name
    return result
@get_cache(cache_key=lambda secret_key, use_get_data: secret_key,
           use_get_data=lambda secret_key, use_get_data: use_get_data,
           version=CacheCodeConstants.APPLICATION_API_KEY_CACHE.value)
def get_application_api_key(secret_key, use_get_data):
    """Look up the cross-domain settings of an application API key.

    :param secret_key: secret key value to look up
    :param use_get_data: not read here; it is handed to the ``get_cache``
        decorator through the lambdas above
    :return: dict with ``allow_cross_domain`` and ``cross_domain_list``, or
        ``None`` when no key with this secret exists
    """
    application_api_key = QuerySet(ApplicationApiKey).filter(secret_key=secret_key).first()
    # An unknown key used to raise AttributeError on None; return None instead,
    # as the access-token lookup does.
    if application_api_key is None:
        return None
    return {'allow_cross_domain': application_api_key.allow_cross_domain,
            'cross_domain_list': application_api_key.cross_domain_list}
class IChunkHandle(ABC):
    """Interface of one text-chunking step."""

    @abstractmethod
    def handle(self, chunk_list: List[str]):
        """Process ``chunk_list`` and return the resulting chunk list."""
class ModelManage:
    """Process-wide cache of model instances, keyed by model id.

    Entries are stored for 30 minutes and renewed on every cache hit. Access
    through :meth:`get_model` is serialised by the module-level ``lock``.
    """
    cache = MemCache('model', {})
    # Time of the last expired-entry sweep.
    up_clear_time = time.time()

    @staticmethod
    def get_model(_id, get_model):
        """Return the cached model for ``_id``, creating it when necessary.

        :param _id: cache key of the model
        :param get_model: factory called as ``get_model(_id)`` when nothing
            usable is cached
        :return: the model instance
        """
        with lock:
            model_instance = ModelManage.cache.get(_id)
            if model_instance is None or not model_instance.is_cache_model():
                model_instance = get_model(_id)
                ModelManage.cache.set(_id, model_instance, timeout=60 * 30)
                return model_instance
            # Renew the entry's lifetime.
            ModelManage.cache.touch(_id, timeout=60 * 30)
            ModelManage.clear_timeout_cache()
            return model_instance

    @staticmethod
    def clear_timeout_cache():
        """Sweep expired entries, at most once every 60 seconds."""
        now = time.time()
        if now - ModelManage.up_clear_time > 60:
            # Record the sweep time. It was never updated before, so once the
            # first minute had passed every cache hit ran a full sweep.
            ModelManage.up_clear_time = now
            ModelManage.cache.clear_timeout_data()

    @staticmethod
    def delete_key(_id):
        """Drop the cached model for ``_id`` if one is present."""
        if ModelManage.cache.has_key(_id):
            ModelManage.cache.delete(_id)
class CustomSwaggerAutoSchema(SwaggerAutoSchema):
    """Schema inspector used to group swagger operations."""

    def get_tags(self, operation_keys=None):
        """Return the tag list for an operation.

        When the inherited tags contain ``"api"``, the operation is tagged with
        its second operation key, translated through ``tags_dict`` when a
        mapping exists.
        """
        tags = super().get_tags(operation_keys)
        # Guard the index: with fewer than two keys there is nothing to group
        # by, so the inherited tags are kept instead of raising IndexError.
        if "api" in tags and operation_keys and len(operation_keys) > 1:
            group_key = operation_keys[1]
            return [tags_dict.get(group_key) if group_key in tags_dict else group_key]
        return tags

    def get_schema(self, request=None, public=False):
        """Return the schema with ``schemes`` matching the request's protocol."""
        # NOTE(review): confirm that the parent class really provides
        # get_schema(); this method relies on it.
        schema = super().get_schema(request, public)
        # ``request`` defaults to None; previously that default crashed on
        # request.is_secure(). Without a request the schemes are left untouched.
        if request is not None:
            schema.schemes = ['https'] if request.is_secure() else ['http']
        return schema
class ExceptionCodeConstantsValue:
    """An error code together with its message."""

    def __init__(self, code, message):
        self.code, self.message = code, message

    def get_message(self):
        """Return the message text."""
        return self.message

    def get_code(self):
        """Return the error code."""
        return self.code

    def to_app_api_exception(self):
        """Build an ``AppApiException`` carrying this code and message."""
        return AppApiException(code=self.code, message=self.message)
class RoleConstants(Enum):
    """Roles known to the permission system; every value is a ``Role``."""
    ADMIN = Role("管理员", "管理员,预制目前不会使用", RoleGroup.USER)
    USER = Role("用户", "用户所有权限", RoleGroup.USER)
    # A trailing comma after this entry used to turn its value into a 1-tuple
    # (Role,), so .value was not a Role like the other members.
    APPLICATION_ACCESS_TOKEN = Role("会话", "只拥有应用会话框接口权限", RoleGroup.APPLICATION_ACCESS_TOKEN)
    APPLICATION_KEY = Role("应用私钥", "应用私钥", RoleGroup.APPLICATION_KEY)
def get_permission_list_by_role(role: RoleConstants):
    """
    Collect the permissions granted to a role.

    :param role: role to look up
    :return: list of ``PermissionConstants`` members whose ``roleList``
        contains ``role``, in declaration order
    """
    granted = []
    for member in PermissionConstants.__members__.values():
        if role in member.value.roleList:
            granted.append(member)
    return granted
a/src/MaxKB-1.7.2/apps/common/db/compiler.py b/src/MaxKB-1.7.2/apps/common/db/compiler.py new file mode 100644 index 0000000..69640c8 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/db/compiler.py @@ -0,0 +1,217 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: compiler.py + @date:2023/10/7 10:53 + @desc: +""" + +from django.core.exceptions import EmptyResultSet, FullResultSet +from django.db import NotSupportedError +from django.db.models.sql.compiler import SQLCompiler +from django.db.transaction import TransactionManagementError + + +class AppSQLCompiler(SQLCompiler): + def __init__(self, query, connection, using, elide_empty=True, field_replace_dict=None): + super().__init__(query, connection, using, elide_empty) + if field_replace_dict is None: + field_replace_dict = {} + self.field_replace_dict = field_replace_dict + + def get_query_str(self, with_limits=True, with_table_name=False, with_col_aliases=False): + refcounts_before = self.query.alias_refcount.copy() + try: + combinator = self.query.combinator + extra_select, order_by, group_by = self.pre_sql_setup( + with_col_aliases=with_col_aliases or bool(combinator), + ) + for_update_part = None + # Is a LIMIT/OFFSET clause needed? + with_limit_offset = with_limits and self.query.is_sliced + combinator = self.query.combinator + features = self.connection.features + if combinator: + if not getattr(features, "supports_select_{}".format(combinator)): + raise NotSupportedError( + "{} is not supported on this database backend.".format( + combinator + ) + ) + result, params = self.get_combinator_sql( + combinator, self.query.combinator_all + ) + elif self.qualify: + result, params = self.get_qualify_sql() + order_by = None + else: + distinct_fields, distinct_params = self.get_distinct() + # This must come after 'select', 'ordering', and 'distinct' + # (see docstring of get_from_clause() for details). 
+ from_, f_params = self.get_from_clause() + try: + where, w_params = ( + self.compile(self.where) if self.where is not None else ("", []) + ) + except EmptyResultSet: + if self.elide_empty: + raise + # Use a predicate that's always False. + where, w_params = "0 = 1", [] + except FullResultSet: + where, w_params = "", [] + try: + having, h_params = ( + self.compile(self.having) + if self.having is not None + else ("", []) + ) + except FullResultSet: + having, h_params = "", [] + result = [] + params = [] + + if self.query.distinct: + distinct_result, distinct_params = self.connection.ops.distinct_sql( + distinct_fields, + distinct_params, + ) + result += distinct_result + params += distinct_params + + out_cols = [] + for _, (s_sql, s_params), alias in self.select + extra_select: + if alias: + s_sql = "%s AS %s" % ( + s_sql, + self.connection.ops.quote_name(alias), + ) + params.extend(s_params) + out_cols.append(s_sql) + + params.extend(f_params) + + if self.query.select_for_update and features.has_select_for_update: + if ( + self.connection.get_autocommit() + # Don't raise an exception when database doesn't + # support transactions, as it's a noop. + and features.supports_transactions + ): + raise TransactionManagementError( + "select_for_update cannot be used outside of a transaction." + ) + + if ( + with_limit_offset + and not features.supports_select_for_update_with_limit + ): + raise NotSupportedError( + "LIMIT/OFFSET is not supported with " + "select_for_update on this database backend." + ) + nowait = self.query.select_for_update_nowait + skip_locked = self.query.select_for_update_skip_locked + of = self.query.select_for_update_of + no_key = self.query.select_for_no_key_update + # If it's a NOWAIT/SKIP LOCKED/OF/NO KEY query but the + # backend doesn't support it, raise NotSupportedError to + # prevent a possible deadlock. 
+ if nowait and not features.has_select_for_update_nowait: + raise NotSupportedError( + "NOWAIT is not supported on this database backend." + ) + elif skip_locked and not features.has_select_for_update_skip_locked: + raise NotSupportedError( + "SKIP LOCKED is not supported on this database backend." + ) + elif of and not features.has_select_for_update_of: + raise NotSupportedError( + "FOR UPDATE OF is not supported on this database backend." + ) + elif no_key and not features.has_select_for_no_key_update: + raise NotSupportedError( + "FOR NO KEY UPDATE is not supported on this " + "database backend." + ) + for_update_part = self.connection.ops.for_update_sql( + nowait=nowait, + skip_locked=skip_locked, + of=self.get_select_for_update_of_arguments(), + no_key=no_key, + ) + + if for_update_part and features.for_update_after_from: + result.append(for_update_part) + + if where: + result.append("WHERE %s" % where) + params.extend(w_params) + + grouping = [] + for g_sql, g_params in group_by: + grouping.append(g_sql) + params.extend(g_params) + if grouping: + if distinct_fields: + raise NotImplementedError( + "annotate() + distinct(fields) is not implemented." 
+ ) + order_by = order_by or self.connection.ops.force_no_ordering() + result.append("GROUP BY %s" % ", ".join(grouping)) + if self._meta_ordering: + order_by = None + if having: + result.append("HAVING %s" % having) + params.extend(h_params) + + if self.query.explain_info: + result.insert( + 0, + self.connection.ops.explain_query_prefix( + self.query.explain_info.format, + **self.query.explain_info.options, + ), + ) + + if order_by: + ordering = [] + for _, (o_sql, o_params, _) in order_by: + ordering.append(o_sql) + params.extend(o_params) + order_by_sql = "ORDER BY %s" % ", ".join(ordering) + if combinator and features.requires_compound_order_by_subquery: + result = ["SELECT * FROM (", *result, ")", order_by_sql] + else: + result.append(order_by_sql) + + if with_limit_offset: + result.append( + self.connection.ops.limit_offset_sql( + self.query.low_mark, self.query.high_mark + ) + ) + + if for_update_part and not features.for_update_after_from: + result.append(for_update_part) + + from_, f_params = self.get_from_clause() + sql = " ".join(result) + if not with_table_name: + for table_name in from_: + sql = sql.replace(table_name + ".", "") + for key in self.field_replace_dict.keys(): + value = self.field_replace_dict.get(key) + sql = sql.replace(key, value) + return sql, tuple(params) + finally: + # Finally do cleanup - get rid of the joins we created above. 
+ self.query.reset_refcounts(refcounts_before) + + def as_sql(self, with_limits=True, with_col_aliases=False, select_string=None): + if select_string is None: + return super().as_sql(with_limits, with_col_aliases) + else: + sql, params = self.get_query_str(with_table_name=False) + return (select_string + " " + sql), params diff --git a/src/MaxKB-1.7.2/apps/common/db/search.py b/src/MaxKB-1.7.2/apps/common/db/search.py new file mode 100644 index 0000000..7636671 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/db/search.py @@ -0,0 +1,176 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: search.py + @date:2023/10/7 18:20 + @desc: +""" +from typing import Dict, Any + +from django.db import DEFAULT_DB_ALIAS, models, connections +from django.db.models import QuerySet + +from common.db.compiler import AppSQLCompiler +from common.db.sql_execute import select_one, select_list +from common.response.result import Page + + +def get_dynamics_model(attr: dict, table_name='dynamics'): + """ + 获取一个动态的django模型 + :param attr: 模型字段 + :param table_name: 表名 + :return: django 模型 + """ + attributes = { + "__module__": "dataset.models", + "Meta": type("Meta", (), {'db_table': table_name}), + **attr + } + return type('Dynamics', (models.Model,), attributes) + + +def generate_sql_by_query_dict(queryset_dict: Dict[str, QuerySet], select_string: str, + field_replace_dict: None | Dict[str, Dict[str, str]] = None, with_table_name=False): + """ + 生成 查询sql + :param with_table_name: + :param queryset_dict: 多条件 查询条件 + :param select_string: 查询sql + :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入 + :return: sql:需要查询的sql params: sql 参数 + """ + + params_dict: Dict[int, Any] = {} + result_params = [] + for key in queryset_dict.keys(): + value = queryset_dict.get(key) + sql, params = compiler_queryset(value, None if field_replace_dict is None else field_replace_dict.get(key), + with_table_name) + params_dict = {**params_dict, select_string.index("${" + key + "}"): params} + 
select_string = select_string.replace("${" + key + "}", sql) + + for key in sorted(list(params_dict.keys())): + result_params = [*result_params, *params_dict.get(key)] + return select_string, result_params + + +def generate_sql_by_query(queryset: QuerySet, select_string: str, + field_replace_dict: None | Dict[str, str] = None, with_table_name=False): + """ + 生成 查询sql + :param queryset: 查询条件 + :param select_string: 原始sql + :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入 + :return: sql:需要查询的sql params: sql 参数 + """ + sql, params = compiler_queryset(queryset, field_replace_dict, with_table_name) + return select_string + " " + sql, params + + +def compiler_queryset(queryset: QuerySet, field_replace_dict: None | Dict[str, str] = None, with_table_name=False): + """ + 解析 queryset查询对象 + :param with_table_name: + :param queryset: 查询对象 + :param field_replace_dict: 需要替换的查询字段,一般不需要传入如果有特殊的需要传入 + :return: sql:需要查询的sql params: sql 参数 + """ + q = queryset.query + compiler = q.get_compiler(DEFAULT_DB_ALIAS) + if field_replace_dict is None: + field_replace_dict = get_field_replace_dict(queryset) + app_sql_compiler = AppSQLCompiler(q, using=DEFAULT_DB_ALIAS, connection=compiler.connection, + field_replace_dict=field_replace_dict) + sql, params = app_sql_compiler.get_query_str(with_table_name=with_table_name) + return sql, params + + +def native_search(queryset: QuerySet | Dict[str, QuerySet], select_string: str, + field_replace_dict: None | Dict[str, Dict[str, str]] | Dict[str, str] = None, + with_search_one=False, with_table_name=False): + """ + 复杂查询 + :param with_table_name: 生成sql是否包含表名 + :param queryset: 查询条件构造器 + :param select_string: 查询前缀 不包括 where limit 等信息 + :param field_replace_dict: 需要替换的字段 + :param with_search_one: 查询 + :return: 查询结果 + """ + if isinstance(queryset, Dict): + exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name) + else: + exec_sql, exec_params = generate_sql_by_query(queryset, select_string, 
field_replace_dict, with_table_name) + if with_search_one: + return select_one(exec_sql, exec_params) + else: + return select_list(exec_sql, exec_params) + + +def page_search(current_page: int, page_size: int, queryset: QuerySet, post_records_handler): + """ + 分页查询 + :param current_page: 当前页 + :param page_size: 每页大小 + :param queryset: 查询条件 + :param post_records_handler: 数据处理器 + :return: 分页结果 + """ + total = QuerySet(query=queryset.query.clone(), model=queryset.model).count() + result = queryset.all()[((current_page - 1) * page_size):(current_page * page_size)] + return Page(total, list(map(post_records_handler, result)), current_page, page_size) + + +def native_page_search(current_page: int, page_size: int, queryset: QuerySet | Dict[str, QuerySet], select_string: str, + field_replace_dict=None, + post_records_handler=lambda r: r, + with_table_name=False): + """ + 复杂分页查询 + :param with_table_name: + :param current_page: 当前页 + :param page_size: 每页大小 + :param queryset: 查询条件 + :param select_string: 查询 + :param field_replace_dict: 特殊字段替换 + :param post_records_handler: 数据row处理器 + :return: 分页结果 + """ + if isinstance(queryset, Dict): + exec_sql, exec_params = generate_sql_by_query_dict(queryset, select_string, field_replace_dict, with_table_name) + else: + exec_sql, exec_params = generate_sql_by_query(queryset, select_string, field_replace_dict, with_table_name) + total_sql = "SELECT \"count\"(*) FROM (%s) temp" % exec_sql + total = select_one(total_sql, exec_params) + limit_sql = connections[DEFAULT_DB_ALIAS].ops.limit_offset_sql( + ((current_page - 1) * page_size), (current_page * page_size) + ) + page_sql = exec_sql + " " + limit_sql + result = select_list(page_sql, exec_params) + return Page(total.get("count"), list(map(post_records_handler, result)), current_page, page_size) + + +def get_field_replace_dict(queryset: QuerySet): + """ + 获取需要替换的字段 默认 “xxx.xxx”需要被替换成 “xxx”."xxx" + :param queryset: 查询对象 + :return: 需要替换的字典 + """ + result = {} + for field in 
queryset.model._meta.local_fields: + if field.attname.__contains__("."): + replace_field = to_replace_field(field.attname) + result.__setitem__('"' + field.attname + '"', replace_field) + return result + + +def to_replace_field(field: str): + """ + 将field 转换为 需要替换的field “xxx.xxx”需要被替换成 “xxx”."xxx" 只替换 field包含.的字段 + :param field: django field字段 + :return: 替换字段 + """ + split_field = field.split(".") + return ".".join(list(map(lambda sf: '"' + sf + '"', split_field))) diff --git a/src/MaxKB-1.7.2/apps/common/db/sql_execute.py b/src/MaxKB-1.7.2/apps/common/db/sql_execute.py new file mode 100644 index 0000000..79e7de4 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/db/sql_execute.py @@ -0,0 +1,66 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: sql_execute.py + @date:2023/9/25 20:05 + @desc: +""" +from typing import List + +from django.db import connection + + +def sql_execute(sql: str, params): + """ + 执行一条sql + :param sql: 需要执行的sql + :param params: sql参数 + :return: 执行结果 + """ + with connection.cursor() as cursor: + cursor.execute(sql, params) + columns = list(map(lambda d: d.name, cursor.description)) + res = cursor.fetchall() + result = list(map(lambda row: dict(list(zip(columns, row))), res)) + cursor.close() + return result + + +def update_execute(sql: str, params): + """ + 执行一条sql + :param sql: 需要执行的sql + :param params: sql参数 + :return: 执行结果 + """ + with connection.cursor() as cursor: + cursor.execute(sql, params) + cursor.close() + return None + + +def select_list(sql: str, params: List): + """ + 执行sql 查询列表数据 + :param sql: 需要执行的sql + :param params: sql的参数 + :return: 查询结果 + """ + result_list = sql_execute(sql, params) + if result_list is None: + return [] + return result_list + + +def select_one(sql: str, params: List): + """ + 执行sql 查询一条数据 + :param sql: 需要执行的sql + :param params: 参数 + :return: 查询结果 + """ + result_list = sql_execute(sql, params) + if result_list is None or len(result_list) == 0: + return None + return result_list[0] diff --git 
a/src/MaxKB-1.7.2/apps/common/event/__init__.py b/src/MaxKB-1.7.2/apps/common/event/__init__.py new file mode 100644 index 0000000..6b6d054 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/event/__init__.py @@ -0,0 +1,17 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2023/11/10 10:43 + @desc: +""" +import setting.models +from setting.models import Model +from .listener_manage import * + + +def run(): + # QuerySet(Document).filter(status__in=[Status.embedding, Status.queue_up]).update(**{'status': Status.error}) + QuerySet(Model).filter(status=setting.models.Status.DOWNLOAD).update(status=setting.models.Status.ERROR, + meta={'message': "下载程序被中断,请重试"}) diff --git a/src/MaxKB-1.7.2/apps/common/event/common.py b/src/MaxKB-1.7.2/apps/common/event/common.py new file mode 100644 index 0000000..a54d24d --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/event/common.py @@ -0,0 +1,50 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: common.py + @date:2023/11/10 10:41 + @desc: +""" +from concurrent.futures import ThreadPoolExecutor + +from django.core.cache.backends.locmem import LocMemCache + +work_thread_pool = ThreadPoolExecutor(5) + +embedding_thread_pool = ThreadPoolExecutor(3) + +memory_cache = LocMemCache('task', {"OPTIONS": {"MAX_ENTRIES": 1000}}) + + +def poxy(poxy_function): + def inner(args, **keywords): + work_thread_pool.submit(poxy_function, args, **keywords) + + return inner + + +def get_cache_key(poxy_function, args): + return poxy_function.__name__ + str(args) + + +def get_cache_poxy_function(poxy_function, cache_key): + def fun(args, **keywords): + try: + poxy_function(args, **keywords) + finally: + memory_cache.delete(cache_key) + + return fun + + +def embedding_poxy(poxy_function): + def inner(*args, **keywords): + key = get_cache_key(poxy_function, args) + if memory_cache.has_key(key): + return + memory_cache.add(key, None) + f = get_cache_poxy_function(poxy_function, key) + embedding_thread_pool.submit(f, 
args, **keywords) + + return inner diff --git a/src/MaxKB-1.7.2/apps/common/event/listener_manage.py b/src/MaxKB-1.7.2/apps/common/event/listener_manage.py new file mode 100644 index 0000000..40ac488 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/event/listener_manage.py @@ -0,0 +1,274 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: listener_manage.py + @date:2023/10/20 14:01 + @desc: +""" +import datetime +import logging +import os +import traceback +from typing import List + +import django.db.models +from django.db.models import QuerySet +from langchain_core.embeddings import Embeddings + +from common.config.embedding_config import VectorStore +from common.db.search import native_search, get_dynamics_model +from common.event.common import embedding_poxy +from common.util.file_util import get_file_content +from common.util.lock import try_lock, un_lock +from dataset.models import Paragraph, Status, Document, ProblemParagraphMapping +from embedding.models import SourceType, SearchMode +from smartdoc.conf import PROJECT_DIR + +max_kb_error = logging.getLogger(__file__) +max_kb = logging.getLogger(__file__) + + +class SyncWebDatasetArgs: + def __init__(self, lock_key: str, url: str, selector: str, handler): + self.lock_key = lock_key + self.url = url + self.selector = selector + self.handler = handler + + +class SyncWebDocumentArgs: + def __init__(self, source_url_list: List[str], selector: str, handler): + self.source_url_list = source_url_list + self.selector = selector + self.handler = handler + + +class UpdateProblemArgs: + def __init__(self, problem_id: str, problem_content: str, embedding_model: Embeddings): + self.problem_id = problem_id + self.problem_content = problem_content + self.embedding_model = embedding_model + + +class UpdateEmbeddingDatasetIdArgs: + def __init__(self, paragraph_id_list: List[str], target_dataset_id: str): + self.paragraph_id_list = paragraph_id_list + self.target_dataset_id = target_dataset_id + + +class 
UpdateEmbeddingDocumentIdArgs: + def __init__(self, paragraph_id_list: List[str], target_document_id: str, target_dataset_id: str, + target_embedding_model: Embeddings = None): + self.paragraph_id_list = paragraph_id_list + self.target_document_id = target_document_id + self.target_dataset_id = target_dataset_id + self.target_embedding_model = target_embedding_model + + +class ListenerManagement: + + @staticmethod + def embedding_by_problem(args, embedding_model: Embeddings): + VectorStore.get_embedding_vector().save(**args, embedding=embedding_model) + + @staticmethod + def embedding_by_paragraph_list(paragraph_id_list, embedding_model: Embeddings): + try: + data_list = native_search( + {'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter( + **{'paragraph.id__in': paragraph_id_list}), + 'paragraph': QuerySet(Paragraph).filter(id__in=paragraph_id_list)}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql'))) + ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list=paragraph_id_list, + embedding_model=embedding_model) + except Exception as e: + max_kb_error.error(f'查询向量数据:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}') + + @staticmethod + def embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model: Embeddings): + max_kb.info(f'开始--->向量化段落:{paragraph_id_list}') + status = Status.success + try: + # 删除段落 + VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_id_list) + + def is_save_function(): + return QuerySet(Paragraph).filter(id__in=paragraph_id_list).exists() + + # 批量向量化 + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function) + except Exception as e: + max_kb_error.error(f'向量化段落:{paragraph_id_list}出现错误{str(e)}{traceback.format_exc()}') + status = Status.error + finally: + QuerySet(Paragraph).filter(id__in=paragraph_id_list).update(**{'status': 
status}) + max_kb.info(f'结束--->向量化段落:{paragraph_id_list}') + + @staticmethod + def embedding_by_paragraph(paragraph_id, embedding_model: Embeddings): + """ + 向量化段落 根据段落id + @param paragraph_id: 段落id + @param embedding_model: 向量模型 + """ + max_kb.info(f"开始--->向量化段落:{paragraph_id}") + status = Status.success + try: + data_list = native_search( + {'problem': QuerySet(get_dynamics_model({'paragraph.id': django.db.models.CharField()})).filter( + **{'paragraph.id': paragraph_id}), + 'paragraph': QuerySet(Paragraph).filter(id=paragraph_id)}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql'))) + # 删除段落 + VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id) + + def is_save_function(): + return QuerySet(Paragraph).filter(id=paragraph_id).exists() + + # 批量向量化 + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function) + except Exception as e: + max_kb_error.error(f'向量化段落:{paragraph_id}出现错误{str(e)}{traceback.format_exc()}') + status = Status.error + finally: + QuerySet(Paragraph).filter(id=paragraph_id).update(**{'status': status}) + max_kb.info(f'结束--->向量化段落:{paragraph_id}') + + @staticmethod + def embedding_by_data_list(data_list: List, embedding_model: Embeddings): + # 批量向量化 + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, lambda: True) + + @staticmethod + def embedding_by_document(document_id, embedding_model: Embeddings): + """ + 向量化文档 + @param document_id: 文档id + @param embedding_model 向量模型 + :return: None + """ + if not try_lock('embedding' + str(document_id)): + return + max_kb.info(f"开始--->向量化文档:{document_id}") + QuerySet(Document).filter(id=document_id).update(**{'status': Status.embedding}) + QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': Status.embedding}) + status = Status.success + try: + data_list = native_search( + {'problem': QuerySet( + get_dynamics_model({'paragraph.document_id': 
django.db.models.CharField()})).filter( + **{'paragraph.document_id': document_id}), + 'paragraph': QuerySet(Paragraph).filter(document_id=document_id)}, + select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "common", 'sql', 'list_embedding_text.sql'))) + # 删除文档向量数据 + VectorStore.get_embedding_vector().delete_by_document_id(document_id) + + def is_save_function(): + return QuerySet(Document).filter(id=document_id).exists() + + # 批量向量化 + VectorStore.get_embedding_vector().batch_save(data_list, embedding_model, is_save_function) + except Exception as e: + max_kb_error.error(f'向量化文档:{document_id}出现错误{str(e)}{traceback.format_exc()}') + status = Status.error + finally: + # 修改状态 + QuerySet(Document).filter(id=document_id).update( + **{'status': status, 'update_time': datetime.datetime.now()}) + QuerySet(Paragraph).filter(document_id=document_id).update(**{'status': status}) + max_kb.info(f"结束--->向量化文档:{document_id}") + un_lock('embedding' + str(document_id)) + + @staticmethod + def embedding_by_dataset(dataset_id, embedding_model: Embeddings): + """ + 向量化知识库 + @param dataset_id: 知识库id + @param embedding_model 向量模型 + :return: None + """ + max_kb.info(f"开始--->向量化数据集:{dataset_id}") + try: + ListenerManagement.delete_embedding_by_dataset(dataset_id) + document_list = QuerySet(Document).filter(dataset_id=dataset_id) + max_kb.info(f"数据集文档:{[d.name for d in document_list]}") + for document in document_list: + ListenerManagement.embedding_by_document(document.id, embedding_model=embedding_model) + except Exception as e: + max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}') + finally: + max_kb.info(f"结束--->向量化数据集:{dataset_id}") + + @staticmethod + def delete_embedding_by_document(document_id): + VectorStore.get_embedding_vector().delete_by_document_id(document_id) + + @staticmethod + def delete_embedding_by_document_list(document_id_list: List[str]): + VectorStore.get_embedding_vector().delete_by_document_id_list(document_id_list) + + 
@staticmethod + def delete_embedding_by_dataset(dataset_id): + VectorStore.get_embedding_vector().delete_by_dataset_id(dataset_id) + + @staticmethod + def delete_embedding_by_paragraph(paragraph_id): + VectorStore.get_embedding_vector().delete_by_paragraph_id(paragraph_id) + + @staticmethod + def delete_embedding_by_source(source_id): + VectorStore.get_embedding_vector().delete_by_source_id(source_id, SourceType.PROBLEM) + + @staticmethod + def disable_embedding_by_paragraph(paragraph_id): + VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': False}) + + @staticmethod + def enable_embedding_by_paragraph(paragraph_id): + VectorStore.get_embedding_vector().update_by_paragraph_id(paragraph_id, {'is_active': True}) + + @staticmethod + def update_problem(args: UpdateProblemArgs): + problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(problem_id=args.problem_id) + embed_value = args.embedding_model.embed_query(args.problem_content) + VectorStore.get_embedding_vector().update_by_source_ids([v.id for v in problem_paragraph_mapping_list], + {'embedding': embed_value}) + + @staticmethod + def update_embedding_dataset_id(args: UpdateEmbeddingDatasetIdArgs): + VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, + {'dataset_id': args.target_dataset_id}) + + @staticmethod + def update_embedding_document_id(args: UpdateEmbeddingDocumentIdArgs): + if args.target_embedding_model is None: + VectorStore.get_embedding_vector().update_by_paragraph_ids(args.paragraph_id_list, + {'document_id': args.target_document_id, + 'dataset_id': args.target_dataset_id}) + else: + ListenerManagement.embedding_by_paragraph_list(args.paragraph_id_list, + embedding_model=args.target_embedding_model) + + @staticmethod + def delete_embedding_by_source_ids(source_ids: List[str]): + VectorStore.get_embedding_vector().delete_by_source_ids(source_ids, SourceType.PROBLEM) + + @staticmethod + def 
delete_embedding_by_paragraph_ids(paragraph_ids: List[str]): + VectorStore.get_embedding_vector().delete_by_paragraph_ids(paragraph_ids) + + @staticmethod + def delete_embedding_by_dataset_id_list(source_ids: List[str]): + VectorStore.get_embedding_vector().delete_by_dataset_id_list(source_ids) + + @staticmethod + def hit_test(query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int, + similarity: float, + search_mode: SearchMode, + embedding: Embeddings): + return VectorStore.get_embedding_vector().hit_test(query_text, dataset_id, exclude_document_id_list, top_number, + similarity, search_mode, embedding) diff --git a/src/MaxKB-1.7.2/apps/common/exception/app_exception.py b/src/MaxKB-1.7.2/apps/common/exception/app_exception.py new file mode 100644 index 0000000..b8f5602 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/exception/app_exception.py @@ -0,0 +1,83 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: app_exception.py + @date:2023/9/4 14:04 + @desc: +""" +from rest_framework import status + + +class AppApiException(Exception): + """ + 项目内异常 + """ + status_code = status.HTTP_200_OK + + def __init__(self, code, message): + self.code = code + self.message = message + + +class NotFound404(AppApiException): + """ + 未认证(未登录)异常 + """ + status_code = status.HTTP_404_NOT_FOUND + + def __init__(self, code, message): + self.code = code + self.message = message + + +class AppAuthenticationFailed(AppApiException): + """ + 未认证(未登录)异常 + """ + status_code = status.HTTP_401_UNAUTHORIZED + + def __init__(self, code, message): + self.code = code + self.message = message + + +class AppUnauthorizedFailed(AppApiException): + """ + 未授权(没有权限)异常 + """ + status_code = status.HTTP_403_FORBIDDEN + + def __init__(self, code, message): + self.code = code + self.message = message + + +class AppEmbedIdentityFailed(AppApiException): + """ + 嵌入cookie异常 + """ + status_code = 460 + + def __init__(self, code, message): + self.code = code + 
self.message = message + + +class AppChatNumOutOfBoundsFailed(AppApiException): + """ + 访问次数超过今日访问量 + """ + status_code = 461 + + def __init__(self, code, message): + self.code = code + self.message = message + + +class ChatException(AppApiException): + status_code = 500 + + def __init__(self, code, message): + self.code = code + self.message = message diff --git a/src/MaxKB-1.7.2/apps/common/field/__init__.py b/src/MaxKB-1.7.2/apps/common/field/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MaxKB-1.7.2/apps/common/field/common.py b/src/MaxKB-1.7.2/apps/common/field/common.py new file mode 100644 index 0000000..3025ec5 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/field/common.py @@ -0,0 +1,65 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: common.py + @date:2024/1/11 18:44 + @desc: +""" +from rest_framework import serializers + + +class ObjectField(serializers.Field): + def __init__(self, model_type_list, **kwargs): + self.model_type_list = model_type_list + super().__init__(**kwargs) + + def to_internal_value(self, data): + for model_type in self.model_type_list: + if isinstance(data, model_type): + return data + self.fail('message类型错误', value=data) + + def to_representation(self, value): + return value + + +class InstanceField(serializers.Field): + def __init__(self, model_type, **kwargs): + self.model_type = model_type + super().__init__(**kwargs) + + def to_internal_value(self, data): + if not isinstance(data, self.model_type): + self.fail('message类型错误', value=data) + return data + + def to_representation(self, value): + return value + + +class FunctionField(serializers.Field): + + def to_internal_value(self, data): + if not callable(data): + self.fail('不是一个函數', value=data) + return data + + def to_representation(self, value): + return value + + +class UploadedImageField(serializers.ImageField): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def to_representation(self, value): + return value + + 
# ---- apps/common/field/common.py (continued) ----
class UploadedFileField(serializers.FileField):
    """File field whose ``to_representation`` hands the value back untouched."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def to_representation(self, value):
        return value


# ---- apps/common/field/vector_field.py ----
from django.db import models


class VectorField(models.Field):
    """Model field whose database column type is ``vector``."""

    def db_type(self, connection):
        return 'vector'


class TsVectorField(models.Field):
    """Model field whose database column type is ``tsvector``."""

    def db_type(self, connection):
        return 'tsvector'


# ---- apps/common/forms/__init__.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: __init__.py
    @date:2023/10/31 17:56
    @desc:
"""
from .array_object_card import *
from .base_field import *
from .base_form import *
from .multi_select import *
from .object_card import *
from .password_input import *
from .radio_field import *
from .single_select_field import *
from .tab_card import *
from .table_radio import *
from .text_input_field import *
from .radio_button_field import *
from .table_checkbox import *
from .radio_card_field import *
from .label import *
from .slider_field import *


# ---- apps/common/forms/array_object_card.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: array_object_card.py
    @date:2023/10/31 18:03
    @desc:
"""
from typing import Dict

from common.forms.base_field import BaseExecField, TriggerType


class ArrayCard(BaseExecField):
    """
    Collects a List[Object]; serialized with input type ``ArrayObjectCard``.
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("ArrayObjectCard", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)


# ---- apps/common/forms/base_field.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: base_field.py
    @date:2023/10/31 18:07
    @desc:
"""
from enum import Enum
from typing import List, Dict, Union

from common.exception.app_exception import AppApiException
from common.forms.label.base_label import BaseLabel


class TriggerType(Enum):
    """Executor type attached to a field (serialized through ``.value``)."""
    # Run the function to fetch option-list data
    OPTION_LIST = 'OPTION_LIST'
    # Run the function to fetch a child form
    CHILD_FORMS = 'CHILD_FORMS'


class BaseField:
    """Common state, validation and serialization shared by all form fields."""

    def __init__(self,
                 input_type: str,
                 label: Union[str, BaseLabel],
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        """

        :param input_type: field (input) type
        :param label: prompt; either a plain string or a label object
        :param required: whether a value is mandatory
        :param default_value: default value
        :param relation_show_field_dict: {field: field_value_list}; show this field only when
                                         `field` has a value contained in field_value_list
        :param relation_trigger_field_dict: {field: field_value_list}; run the function to fetch
                                            data only when `field` has a value in field_value_list
        :param trigger_type: executor type; OPTION_LIST requests option-list data,
                             CHILD_FORMS requests a child form
        :param attrs: front-end attr data
        :param props_info: other extra information
        """
        if props_info is None:
            props_info = {}
        if attrs is None:
            attrs = {}
        self.label = label
        self.attrs = attrs
        self.props_info = props_info
        self.default_value = default_value
        self.input_type = input_type
        self.relation_show_field_dict = {} if relation_show_field_dict is None else relation_show_field_dict
        # NOTE(review): the fallback here is a list although the parameter is documented as a
        # dict and subclasses pass `{}`; kept as-is because it is emitted verbatim by to_dict().
        self.relation_trigger_field_dict = [] if relation_trigger_field_dict is None else relation_trigger_field_dict
        self.required = required
        self.trigger_type = trigger_type

    def is_valid(self, value):
        """
        Raise AppApiException(500, ...) when the field is required and `value` is None.

        :param value: submitted value for this field
        """
        # A label object (anything exposing `to_dict`) carries its display text in `.label`.
        field_label = self.label.label if hasattr(self.label, 'to_dict') else self.label
        if self.required and value is None:
            raise AppApiException(500,
                                  f"{field_label} 为必填参数")

    def to_dict(self, **kwargs):
        """
        Serialize the field description.

        :param kwargs: forwarded to the label's ``to_dict`` (when the label has one) and then
                       merged into the result last, so they override same-named keys
        :return: dict describing the field
        """
        return {
            'input_type': self.input_type,
            'label': self.label.to_dict(**kwargs) if hasattr(self.label, 'to_dict') else self.label,
            'required': self.required,
            'default_value': self.default_value,
            'relation_show_field_dict': self.relation_show_field_dict,
            'relation_trigger_field_dict': self.relation_trigger_field_dict,
            'trigger_type': self.trigger_type.value,
            'attrs': self.attrs,
            'props_info': self.props_info,
            **kwargs
        }


class BaseDefaultOptionField(BaseField):
    """Field that carries a fixed option list instead of fetching it through a provider."""

    def __init__(self, input_type: str,
                 label: Union[str, BaseLabel],
                 text_field: str,
                 value_field: str,
                 option_list: List[dict],
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict[str, object] = None,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        """

        :param input_type: field (input) type
        :param label: label
        :param text_field: text field
        :param value_field: value field
        :param option_list: selectable options
        :param required: whether a value is mandatory
        :param default_value: default value
        :param relation_show_field_dict: {field: field_value_list}; show this field only when
                                         `field` has a value contained in field_value_list
        :param attrs: front-end attr data
        :param props_info: other extra information
        """
        super().__init__(input_type, label, required, default_value, relation_show_field_dict,
                         {}, TriggerType.OPTION_LIST, attrs, props_info)
        self.text_field = text_field
        self.value_field = value_field
        self.option_list = option_list

    def to_dict(self, **kwargs):
        """Base serialization plus ``text_field``, ``value_field`` and ``option_list``."""
        return {**super().to_dict(**kwargs), 'text_field': self.text_field, 'value_field': self.value_field,
                'option_list': self.option_list}


class BaseExecField(BaseField):
    """Field that names a provider and a provider method alongside its text/value fields."""

    def __init__(self,
                 input_type: str,
                 label: Union[str, BaseLabel],
                 text_field: str,
                 value_field: str,
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        """

        :param input_type: field (input) type
        :param label: prompt
        :param text_field: text field
        :param value_field: value field
        :param provider: the designated provider
        :param method: provider function to execute
        :param required: whether a value is mandatory
        :param default_value: default value
        :param relation_show_field_dict: {field: field_value_list}; show this field only when
                                         `field` has a value contained in field_value_list
        :param relation_trigger_field_dict: {field: field_value_list}; run the function to fetch
                                            data only when `field` has a value in field_value_list
        :param trigger_type: executor type; OPTION_LIST requests option-list data,
                             CHILD_FORMS requests a child form
        :param attrs: front-end attr data
        :param props_info: other extra information
        """
        super().__init__(input_type, label, required, default_value, relation_show_field_dict,
                         relation_trigger_field_dict,
                         trigger_type, attrs, props_info)
        self.text_field = text_field
        self.value_field = value_field
        self.provider = provider
        self.method = method

    def to_dict(self, **kwargs):
        """Base serialization plus ``text_field``, ``value_field``, ``provider`` and ``method``."""
        return {**super().to_dict(**kwargs), 'text_field': self.text_field, 'value_field': self.value_field,
                'provider': self.provider, 'method': self.method}
# ---- apps/common/forms/base_form.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: base_form.py
    @date:2023/11/1 16:04
    @desc:
"""
from typing import Dict

from common.forms import BaseField


class BaseForm:
    """Form whose fields are the ``BaseField`` attributes declared directly on the concrete class."""

    def _field_keys(self):
        """Names of attributes in ``vars(type(self))`` (dunder names skipped) that hold a BaseField."""
        return [attr for attr in vars(self.__class__)
                if not attr.startswith("__") and isinstance(getattr(self, attr), BaseField)]

    def to_form_list(self, **kwargs):
        """
        Serialize every field.

        :param kwargs: forwarded to each field's ``to_dict``
        :return: list of field dicts, each extended with ``'field': <attribute name>``
        """
        return [{**getattr(self, key).to_dict(**kwargs), 'field': key} for key in self._field_keys()]

    def valid_form(self, form_data):
        """
        Run ``is_valid`` of every field against ``form_data.get(<attribute name>)``.

        :param form_data: mapping of attribute name to submitted value
        """
        for field_key in self._field_keys():
            getattr(self, field_key).is_valid(form_data.get(field_key))

    def get_default_form_data(self):
        """Return ``{attribute name: default_value}`` for fields whose default is not None."""
        return {key: getattr(self, key).default_value for key in self._field_keys()
                if getattr(self, key).default_value is not None}


# ---- apps/common/forms/label/__init__.py ----
# coding=utf-8
"""
    @project: MaxKB
    @Author:虎
    @file: __init__.py
    @date:2024/8/22 17:19
    @desc:
"""
from .base_label import *
from .tooltip_label import *


# ---- apps/common/forms/label/base_label.py ----
# coding=utf-8
"""
    @project: MaxKB
    @Author:虎
    @file: base_label.py
    @date:2024/8/22 17:11
    @desc:
"""


class BaseLabel:
    """Structured label: an input type, the label text, and optional attrs / props_info."""

    def __init__(self,
                 input_type: str,
                 label: str,
                 attrs=None,
                 props_info=None):
        self.input_type = input_type
        self.label = label
        self.attrs = attrs
        self.props_info = props_info

    def to_dict(self, **kwargs):
        """
        Serialize the label; ``attrs`` / ``props_info`` left as None are emitted as ``{}``.

        :param kwargs: accepted for call compatibility with field serialization; not used
        """
        return {
            'input_type': self.input_type,
            'label': self.label,
            'attrs': {} if self.attrs is None else self.attrs,
            'props_info': {} if self.props_info is None else self.props_info,
        }


# ---- apps/common/forms/label/tooltip_label.py ----
# coding=utf-8
"""
    @project: MaxKB
    @Author:虎
    @file: tooltip_label.py
    @date:2024/8/22 17:19
    @desc:
"""
from common.forms.label.base_label import BaseLabel


class TooltipLabel(BaseLabel):
    """Label of input type ``TooltipLabel`` that stores the tooltip under ``attrs['tooltip']``."""

    def __init__(self, label, tooltip):
        super().__init__('TooltipLabel', label, attrs={'tooltip': tooltip}, props_info={})


# ---- apps/common/forms/multi_select.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: multi_select.py
    @date:2023/10/31 18:00
    @desc:
"""
from typing import List, Dict

from common.forms.base_field import BaseExecField, TriggerType


class MultiSelect(BaseExecField):
    """
    Field serialized with input type ``MultiSelect``; carries its own option list.
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 option_list: List[Dict[str, object]],
                 provider: str = None,
                 method: str = None,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("MultiSelect", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
        self.option_list = option_list

    def to_dict(self, **kwargs):
        """Parent serialization plus ``option_list``; kwargs are forwarded to the parent."""
        return {**super().to_dict(**kwargs), 'option_list': self.option_list}


# ---- apps/common/forms/object_card.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: object_card.py
    @date:2023/10/31 18:02
    @desc:
"""
from typing import Dict

from common.forms.base_field import BaseExecField, TriggerType


class ObjectCard(BaseExecField):
    """
    Collects an object through a child-form card; serialized with input type ``ObjectCard``.
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("ObjectCard", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)


# ---- apps/common/forms/password_input.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: password_input.py
    @date:2023/11/1 14:48
    @desc:
"""
from typing import Dict

from common.forms import BaseField, TriggerType


class PasswordInputField(BaseField):
    """
    Password input box; serialized with input type ``PasswordInput``.
    """

    def __init__(self, label: str,
                 required: bool = False,
                 default_value=None,
                 relation_show_field_dict: Dict = None,
                 attrs=None, props_info=None):
        super().__init__('PasswordInput', label, required, default_value, relation_show_field_dict,
                         {},
                         TriggerType.OPTION_LIST, attrs, props_info)


# ---- apps/common/forms/radio_button_field.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: radio_button_field.py
    @date:2023/10/31 17:59
    @desc:
"""
from typing import List, Dict

from common.forms.base_field import BaseExecField, TriggerType


class Radio(BaseExecField):
    """
    Single-choice field serialized with input type ``RadioButton``.
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 option_list: List[Dict[str, object]],
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("RadioButton", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
        self.option_list = option_list

    def to_dict(self, **kwargs):
        """Parent serialization plus ``option_list``; kwargs are forwarded to the parent."""
        return {**super().to_dict(**kwargs), 'option_list': self.option_list}


# ---- apps/common/forms/radio_card_field.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: radio_card_field.py
    @date:2023/10/31 17:59
    @desc:
"""
from typing import List, Dict

from common.forms.base_field import BaseExecField, TriggerType


class Radio(BaseExecField):
    """
    Single-choice field serialized with input type ``RadioCard``.
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 option_list: List[Dict[str, object]],
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("RadioCard", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
        self.option_list = option_list

    def to_dict(self, **kwargs):
        """Parent serialization plus ``option_list``; kwargs are forwarded to the parent."""
        return {**super().to_dict(**kwargs), 'option_list': self.option_list}


# ---- apps/common/forms/radio_field.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: radio_field.py
    @date:2023/10/31 17:59
    @desc:
"""
from typing import List, Dict

from common.forms.base_field import BaseExecField, TriggerType


class Radio(BaseExecField):
    """
    Single-choice field serialized with input type ``Radio``.
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 option_list: List[Dict[str, object]],
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("Radio", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
        self.option_list = option_list

    def to_dict(self, **kwargs):
        """Parent serialization plus ``option_list``; kwargs are forwarded to the parent."""
        return {**super().to_dict(**kwargs), 'option_list': self.option_list}
# ---- apps/common/forms/single_select_field.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: single_select_field.py
    @date:2023/10/31 18:00
    @desc:
"""
from typing import List, Dict, Union

from common.forms import BaseLabel
from common.forms.base_field import TriggerType, BaseExecField


class SingleSelect(BaseExecField):
    """
    Dropdown single-select; serialized with input type ``SingleSelect`` and its own option list.
    """

    def __init__(self,
                 label: Union[str, BaseLabel],
                 text_field: str,
                 value_field: str,
                 option_list: List[Dict[str, object]],
                 provider: str = None,
                 method: str = None,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__("SingleSelect", label, text_field, value_field, provider, method, required, default_value,
                         relation_show_field_dict, relation_trigger_field_dict, trigger_type, attrs, props_info)
        self.option_list = option_list

    def to_dict(self, **kwargs):
        """Parent serialization plus ``option_list``; kwargs are forwarded to the parent."""
        return {**super().to_dict(**kwargs), 'option_list': self.option_list}


# ---- apps/common/forms/slider_field.py ----
# coding=utf-8
"""
    @project: MaxKB
    @Author:虎
    @file: slider_field.py
    @date:2024/8/22 17:06
    @desc:
"""
from typing import Dict, Union

from common.exception.app_exception import AppApiException
from common.forms import BaseField, TriggerType, BaseLabel


class SliderField(BaseField):
    """
    Slider input; serialized with input type ``Slider``.
    """

    def __init__(self, label: Union[str, BaseLabel],
                 _min,
                 _max,
                 _step,
                 precision,
                 required: bool = False,
                 default_value=None,
                 relation_show_field_dict: Dict = None,
                 attrs=None, props_info=None):
        """
        @param label: prompt
        @param _min: minimum value
        @param _max: maximum value
        @param _step: step size
        @param precision: number of decimal places to keep
        @param required: whether a value is mandatory
        @param default_value: default value
        @param relation_show_field_dict:
        @param attrs: extra attrs; merged over the slider defaults, so they win on key clashes
        @param props_info:
        """
        _attrs = {'min': _min, 'max': _max, 'step': _step,
                  'precision': precision, 'show-input-controls': False, 'show-input': True}
        if attrs is not None:
            _attrs.update(attrs)
        super().__init__('Slider', label, required, default_value, relation_show_field_dict,
                         {},
                         TriggerType.OPTION_LIST, _attrs, props_info)

    def is_valid(self, value):
        """
        Apply the base required-check, then bound-check a non-None value against
        ``attrs['min']`` / ``attrs['max']``.

        Raises AppApiException(500, ...) when the value lies outside the bounds.
        """
        super().is_valid(value)
        field_label = self.label.label if hasattr(self.label, 'to_dict') else self.label
        if value is not None:
            if value < self.attrs.get('min'):
                raise AppApiException(500,
                                      f"{field_label} 不能小于{self.attrs.get('min')}")
            if value > self.attrs.get('max'):
                raise AppApiException(500,
                                      f"{field_label} 不能大于{self.attrs.get('max')}")


# ---- apps/common/forms/switch_field.py ----
"""
    @project: MaxKB
    @Author:虎
    @file: switch_field.py
    @date:2024/10/13 19:43
    @desc:
"""
from typing import Dict, Union
from common.forms import BaseField, TriggerType, BaseLabel


class SwitchField(BaseField):
    """
    Switch input; serialized with input type ``Switch``.
    """

    def __init__(self, label: Union[str, BaseLabel],
                 required: bool = False,
                 default_value=None,
                 relation_show_field_dict: Dict = None,

                 attrs=None, props_info=None):
        """
        @param label: prompt
        @param required: whether a value is mandatory
        @param default_value: default value
        @param relation_show_field_dict:
        @param attrs:
        @param props_info:
        """

        super().__init__('Switch', label, required, default_value, relation_show_field_dict,
                         {},
                         TriggerType.OPTION_LIST, attrs, props_info)
# ---- apps/common/forms/tab_card.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: tab_card.py
    @date:2023/10/31 18:03
    @desc:
"""
from typing import Dict

from common.forms.base_field import BaseExecField, TriggerType


class TabCard(BaseExecField):
    """
    Collects tab-shaped data, e.g. tab1: {}, tab2: {}.

    Every argument is handed to BaseExecField unchanged; only the input type
    ("TabCard") is fixed here.
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__(input_type="TabCard",
                         label=label,
                         text_field=text_field,
                         value_field=value_field,
                         provider=provider,
                         method=method,
                         required=required,
                         default_value=default_value,
                         relation_show_field_dict=relation_show_field_dict,
                         relation_trigger_field_dict=relation_trigger_field_dict,
                         trigger_type=trigger_type,
                         attrs=attrs,
                         props_info=props_info)


# ---- apps/common/forms/table_checkbox.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: table_radio.py
    @date:2023/10/31 18:01
    @desc:
"""
from typing import Dict

from common.forms.base_field import TriggerType, BaseExecField


# NOTE(review): this class shares the name `TableRadio` with the one in table_radio.py while
# using input type "TableCheckbox"; with wildcard imports of both modules the later import
# presumably shadows the earlier one — confirm which class callers expect.
class TableRadio(BaseExecField):
    """
    Table selection field registered under the input type "TableCheckbox".
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__(input_type="TableCheckbox",
                         label=label,
                         text_field=text_field,
                         value_field=value_field,
                         provider=provider,
                         method=method,
                         required=required,
                         default_value=default_value,
                         relation_show_field_dict=relation_show_field_dict,
                         relation_trigger_field_dict=relation_trigger_field_dict,
                         trigger_type=trigger_type,
                         attrs=attrs,
                         props_info=props_info)


# ---- apps/common/forms/table_radio.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: table_radio.py
    @date:2023/10/31 18:01
    @desc:
"""
from typing import Dict

from common.forms.base_field import TriggerType, BaseExecField


class TableRadio(BaseExecField):
    """
    Table single-select field registered under the input type "TableRadio".
    """

    def __init__(self,
                 label: str,
                 text_field: str,
                 value_field: str,
                 provider: str,
                 method: str,
                 required: bool = False,
                 default_value: object = None,
                 relation_show_field_dict: Dict = None,
                 relation_trigger_field_dict: Dict = None,
                 trigger_type: TriggerType = TriggerType.OPTION_LIST,
                 attrs: Dict[str, object] = None,
                 props_info: Dict[str, object] = None):
        super().__init__(input_type="TableRadio",
                         label=label,
                         text_field=text_field,
                         value_field=value_field,
                         provider=provider,
                         method=method,
                         required=required,
                         default_value=default_value,
                         relation_show_field_dict=relation_show_field_dict,
                         relation_trigger_field_dict=relation_trigger_field_dict,
                         trigger_type=trigger_type,
                         attrs=attrs,
                         props_info=props_info)


# ---- apps/common/forms/text_input_field.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: text_input_field.py
    @date:2023/10/31 17:58
    @desc:
"""
from typing import Dict

from common.forms.base_field import BaseField, TriggerType


class TextInputField(BaseField):
    """
    Text input box (input type "TextInput").

    The trigger mapping is always an empty dict and the trigger type is always
    OPTION_LIST; callers cannot override either.
    """

    def __init__(self, label: str,
                 required: bool = False,
                 default_value=None,
                 relation_show_field_dict: Dict = None,

                 attrs=None, props_info=None):
        super().__init__(input_type='TextInput',
                         label=label,
                         required=required,
                         default_value=default_value,
                         relation_show_field_dict=relation_show_field_dict,
                         relation_trigger_field_dict={},
                         trigger_type=TriggerType.OPTION_LIST,
                         attrs=attrs,
                         props_info=props_info)
# ---- apps/common/handle/__init__.py ----
# coding=utf-8
"""
    @project: qabot
    @Author:虎
    @file: __init__.py
    @date:2023/9/6 10:09
    @desc:
"""

# ---- apps/common/handle/base_parse_qa_handle.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: base_parse_qa_handle.py
    @date:2024/5/21 14:56
    @desc:
"""
from abc import ABC, abstractmethod


def get_row_value(row, title_row_index_dict, field):
    """
    Look up the cell of `row` that belongs to `field`.

    :param row: sequence of cell values
    :param title_row_index_dict: mapping of field name -> column index
    :param field: field name to read
    :return: the cell value, or None when the field has no column or the row is too short
    """
    index = title_row_index_dict.get(field)
    if index is None:
        return None
    if (len(row) - 1) >= index:
        return row[index]
    return None


def get_title_row_index_dict(title_row_list):
    """
    Map the fields 'title', 'content' and 'problem_list' to column indexes.

    Positional defaults depend on the number of header cells:
    one cell -> content only; two cells -> title, content; anything else ->
    title, content, problem_list. A header cell whose text starts with
    '分段标题' / '分段内容' / '问题' then overrides the index of title / content /
    problem_list respectively.

    :param title_row_list: header row cells (None cells are treated as empty text)
    :return: dict of field name -> column index
    """
    title_row_index_dict = {}
    column_count = len(title_row_list)
    if column_count == 1:
        title_row_index_dict['content'] = 0
    elif column_count == 2:
        # Previously this branch repeated `== 1` and could never run, so a two-column
        # header was given a problem_list index pointing past the end of the row.
        title_row_index_dict['title'] = 0
        title_row_index_dict['content'] = 1
    else:
        title_row_index_dict['title'] = 0
        title_row_index_dict['content'] = 1
        title_row_index_dict['problem_list'] = 2
    for index, title_row in enumerate(title_row_list):
        # Coerce to text so None or non-string cells (e.g. numbers) cannot break startswith.
        title_text = '' if title_row is None else str(title_row)
        if title_text.startswith('分段标题'):
            title_row_index_dict['title'] = index
        if title_text.startswith('分段内容'):
            title_row_index_dict['content'] = index
        if title_text.startswith('问题'):
            title_row_index_dict['problem_list'] = index
    return title_row_index_dict


class BaseParseQAHandle(ABC):
    """Abstract handler exposing a ``support`` check and a ``handle`` step for a file."""

    @abstractmethod
    def support(self, file, get_buffer):
        pass

    @abstractmethod
    def handle(self, file, get_buffer, save_image):
        pass
# ---- apps/common/handle/base_parse_table_handle.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: base_parse_table_handle.py
    @date:2024/5/21 14:56
    @desc:
"""
from abc import ABC, abstractmethod


class BaseParseTableHandle(ABC):
    """Abstract handler exposing a ``support`` check and a ``handle`` step for a file."""

    @abstractmethod
    def support(self, file, get_buffer):
        pass

    @abstractmethod
    def handle(self, file, get_buffer, save_image):
        pass


# ---- apps/common/handle/base_split_handle.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: base_split_handle.py
    @date:2024/3/27 18:13
    @desc:
"""
from abc import ABC, abstractmethod
from typing import List


class BaseSplitHandle(ABC):
    """Abstract handler exposing a ``support`` check and a pattern-driven ``handle`` step."""

    @abstractmethod
    def support(self, file, get_buffer):
        pass

    @abstractmethod
    def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
        pass


# ---- apps/common/handle/base_to_response.py ----
# coding=utf-8
"""
    @project: MaxKB
    @Author:虎
    @file: base_to_response.py
    @date:2024/9/6 16:04
    @desc:
"""
from abc import ABC, abstractmethod

from rest_framework import status


class BaseToResponse(ABC):
    """Abstract response builder with a blocking variant and a streamed-chunk variant."""

    @abstractmethod
    def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
                          _status=status.HTTP_200_OK):
        pass

    @abstractmethod
    def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens):
        pass

    @staticmethod
    def format_stream_chunk(response_str):
        """Wrap `response_str` as ``'data: <response_str>\\n\\n'``."""
        return 'data: ' + response_str + '\n\n'


# ---- apps/common/handle/handle_exception.py ----
# coding=utf-8
"""
    @project: qabot
    @Author:虎
    @file: handle_exception.py
    @date:2023/9/5 19:29
    @desc:
"""
import logging
import traceback

from rest_framework.exceptions import ValidationError, ErrorDetail, APIException
from rest_framework.views import exception_handler

from common.exception.app_exception import AppApiException
from common.response import result


def to_result(key, args, parent_key=None):
    """
    Convert validation-exception args into the unified response object.

    :param key: key being validated
    :param args: validation exception arguments; ``args[0]`` is read as a mapping
    :param parent_key: parent key, joined to `key` with '.' in the message
    :return: API response object
    """
    # First entry under `key` that is an ErrorDetail or a non-empty dict; when `args` is
    # empty a synthetic "unknown" detail is used instead.
    error_detail = list(filter(
        lambda d: True if isinstance(d, ErrorDetail) else True if isinstance(d, dict) and len(
            d.keys()) > 0 else False,
        (args[0] if len(args) > 0 else {key: [ErrorDetail('未知异常', code='unknown')]}).get(key)))[0]

    if isinstance(error_detail, dict):
        # Nested errors: recurse into the nested mapping and report its first key.
        return list(map(lambda k: to_result(k, args=[error_detail],
                                            parent_key=key if parent_key is None else parent_key + '.' + key),
                        error_detail.keys() if len(error_detail) > 0 else []))[0]

    return result.Result(500 if isinstance(error_detail.code, str) else error_detail.code,
                         message=f"【{key if parent_key is None else parent_key + '.' + key}】为必填参数" if str(
                             error_detail) == "This field is required." else error_detail)


def validation_error_to_result(exc: ValidationError):
    """
    Convert a validation exception into a response object.

    :param exc: validation exception
    :return: API response object built from the first error detail found, or from
             ``str(exc.detail)`` when none is found or the lookup itself fails
    """
    try:
        v = find_err_detail(exc.detail)
        if v is None:
            return result.error(str(exc.detail))
        return result.error(str(v))
    except Exception:
        # Deliberate best effort: never let error formatting raise a second error.
        return result.error(str(exc.detail))


def find_err_detail(exc_detail):
    """
    Depth-first search for the first ``ErrorDetail`` inside nested dicts / lists.

    :param exc_detail: an ErrorDetail, or a dict / list nesting them
    :return: the first ErrorDetail found, or None
    """
    if isinstance(exc_detail, ErrorDetail):
        return exc_detail
    if isinstance(exc_detail, dict):
        children = exc_detail.values()
    elif isinstance(exc_detail, list):
        children = exc_detail
    else:
        return None
    for child in children:
        found = find_err_detail(child)
        # Keep scanning when a branch holds no detail (e.g. an empty list under the first
        # key) instead of giving up on the remaining keys.
        if found is not None:
            return found
    return None


def handle_exception(exc, context):
    """
    Exception handler: maps exceptions to the project's unified result objects.

    :param exc: the raised exception
    :param context: handler context, passed through to REST framework's default handler
    :return: response object
    """
    exception_class = exc.__class__
    # First call REST framework's default exception handler to get the standard error response
    response = exception_handler(exc, context)
    # Custom exception handling is added here
    if issubclass(exception_class, ValidationError):
        return validation_error_to_result(exc)
    if issubclass(exception_class, AppApiException):
        return result.Result(exc.code, exc.message, response_status=exc.status_code)
    if issubclass(exception_class, APIException):
        return result.error(exc.detail)
    if response is None:
        # Not handled by REST framework: log with traceback and return a generic error result.
        logging.getLogger("max_kb_error").error(f'{str(exc)}:{traceback.format_exc()}')
        return result.error(str(exc))
    return response


# ---- apps/common/handle/impl/doc_split_handle.py ----
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: text_split_handle.py
    @date:2024/3/27 18:19
    @desc:
"""
import io
import re
import traceback
import uuid
from functools import reduce
from typing import List
List + +from docx import Document, ImagePart +from docx.oxml import ns +from docx.table import Table +from docx.text.paragraph import Paragraph + +from common.handle.base_split_handle import BaseSplitHandle +from common.util.split_model import SplitModel +from dataset.models import Image + +default_pattern_list = [re.compile('(?<=^)# .*|(?<=\\n)# .*'), + re.compile('(?<=\\n)(? 0: + for image in _images: + images.append({'image': image, 'get_image_id_handle': get_image_id_handle}) + except Exception as e: + pass + return images + + +def images_to_string(images, doc: Document, images_list, get_image_id): + return "".join( + [item for item in [image_to_mode(image, doc, images_list, get_image_id) for image in images] if + item is not None]) + + +def get_paragraph_element_txt(paragraph_element, doc: Document, images_list, get_image_id): + try: + images = get_paragraph_element_images(paragraph_element, doc, images_list, get_image_id) + if len(images) > 0: + return images_to_string(images, doc, images_list, get_image_id) + elif paragraph_element.text is not None: + return paragraph_element.text + return "" + except Exception as e: + print(e) + return "" + + +def get_paragraph_txt(paragraph: Paragraph, doc: Document, images_list, get_image_id): + try: + return "".join([get_paragraph_element_txt(e, doc, images_list, get_image_id) for e in paragraph._element]) + except Exception as e: + return "" + + +def get_cell_text(cell, doc: Document, images_list, get_image_id): + try: + return "".join( + [get_paragraph_txt(paragraph, doc, images_list, get_image_id) for paragraph in cell.paragraphs]).replace( + "\n", '
def get_image_id_func():
    """
    Create an id allocator that hands out one stable UUID per source image id.

    :return: a function ``get_image_id(image_id)`` that returns the same
             ``uuid.uuid1()`` value every time it is called with the same key
    """
    allocated = {}

    def get_image_id(image_id):
        known = allocated.get(image_id)
        if known is not None:
            return known
        # First request for this key: mint a new UUID and remember it
        allocated[image_id] = uuid.uuid1()
        return allocated.get(image_id)

    return get_image_id
def support(self, file, get_buffer):
    """
    Tell whether this handler accepts the uploaded file, judged by its extension.

    :param file: uploaded file object; only ``file.name`` is read
    :param get_buffer: unused here, kept for the common handler interface
    :return: True when the name ends with ``.docx`` or ``.doc`` (case-insensitive)
    """
    # The name is lower-cased first, so only lowercase suffixes can ever match.
    return file.name.lower().endswith((".docx", ".doc"))
class HTMLSplitHandle(BaseSplitHandle):
    """
    Split handler for HTML uploads: decodes the bytes, converts the markup to
    text with ``html2text`` and splits the result with ``SplitModel``.
    """

    def support(self, file, get_buffer):
        """
        :param file: uploaded file object; only ``file.name`` is read
        :param get_buffer: unused here, kept for the common handler interface
        :return: True when the name ends with ``.html`` (case-insensitive)
        """
        # The name is lower-cased first, so a separate ".HTML" check can never match.
        return file.name.lower().endswith(".html")

    def handle(self, file, pattern_list: List, with_filter: bool, limit: int, get_buffer, save_image):
        """
        Convert the uploaded HTML to text and split it into paragraphs.

        :param file: uploaded file object
        :param pattern_list: split patterns; the module defaults are used when empty or None
        :param with_filter: forwarded to ``SplitModel``
        :param limit: forwarded to ``SplitModel``
        :param get_buffer: callable returning the raw bytes of ``file``
        :param save_image: unused by this handler
        :return: ``{'name': ..., 'content': [...]}``; ``content`` is empty when the
                 file cannot be decoded or converted
        """
        import logging

        buffer = get_buffer(file)

        if pattern_list is not None and len(pattern_list) > 0:
            split_model = SplitModel(pattern_list, with_filter, limit)
        else:
            split_model = SplitModel(default_pattern_list, with_filter=with_filter, limit=limit)
        try:
            encoding = get_encoding(buffer)
            content = buffer.decode(encoding)
            content = html2text(content)
        except Exception as e:
            # Still best-effort (an unreadable file yields no paragraphs), but
            # the reason is logged, and KeyboardInterrupt/SystemExit are no
            # longer swallowed.
            logging.getLogger("max_kb").error(f'html split handle error, file: {file.name}, error: {e}')
            return {'name': file.name,
                    'content': []}
        return {'name': file.name,
                'content': split_model.parse(content)
                }
@staticmethod
def handle_pdf_content(file, pdf_document):
    """
    Extract the text of every page of a PDF.

    Pages with a text layer are read directly. Pages without text are saved
    as a one-page temporary PDF and passed to ``PyPDFLoader`` with
    ``extract_images=True``.

    :param file: uploaded file object; ``file.name`` is only used in log messages
    :param pdf_document: opened document offering ``len()`` and ``load_page()``
    :return: the concatenated text of all pages that could be processed
    :raises NotImplementedError: re-raised unchanged when the loader reports an
                                 unsupported format
    """
    content = ""
    for page_num in range(len(pdf_document)):
        start_time = time.time()
        page = pdf_document.load_page(page_num)
        text = page.get_text()

        if text and text.strip():  # the page has a text layer
            page_content = text
        else:
            # Reset per page so the cleanup below never sees an unbound name
            # or the already-deleted path of a previous page.
            page_num_pdf = None
            try:
                # A unique temp file avoids collisions between concurrent uploads
                # with the same name and does not depend on the upload name
                # being a valid path component.
                fd, page_num_pdf = tempfile.mkstemp(suffix=f"_{page_num}.pdf")
                os.close(fd)
                new_doc = fitz.open()
                try:
                    new_doc.insert_pdf(pdf_document, from_page=page_num, to_page=page_num)
                    new_doc.save(page_num_pdf)
                finally:
                    new_doc.close()

                loader = PyPDFLoader(page_num_pdf, extract_images=True)
                page_content = "\n" + loader.load()[0].page_content
            except NotImplementedError as e:
                # Unsupported file format: abort the whole file
                raise e
            except BaseException as e:
                # A failing page is skipped so one bad page does not fail the whole file
                max_kb.error(f"File: {file.name}, Page: {page_num + 1}, error: {e}")
                continue
            finally:
                if page_num_pdf is not None and os.path.exists(page_num_pdf):
                    os.remove(page_num_pdf)

        content += page_content

        elapsed_time = time.time() - start_time
        max_kb.debug(
            f"File: {file.name}, Page: {page_num + 1}, Time : {elapsed_time: .3f}s, content-length: {len(page_content)}")

    return content
1 开始 + chapter_title = title + # 确定结束页码,如果是最后一个章节则到文档末尾 + if i + 1 < len(toc): + end_page = toc[i + 1][2] - 1 + else: + end_page = doc.page_count - 1 + + # 去掉标题中的符号 + title = PdfSplitHandle.handle_chapter_title(title) + + # 提取该章节的文本内容 + chapter_text = "" + for page_num in range(start_page, end_page + 1): + page = doc.load_page(page_num) # 加载页面 + text = page.get_text("text") + text = re.sub(r'(? -1: + text = text[idx + len(title):] + + if i + 1 < len(toc): + l, next_title, next_start_page = toc[i + 1] + next_title = PdfSplitHandle.handle_chapter_title(next_title) + # print(f'next_title: {next_title}') + idx = text.find(next_title) + if idx > -1: + text = text[:idx] + + chapter_text += text # 提取文本 + # 限制章节内容长度 + if 0 < limit < len(chapter_text): + split_text = PdfSplitHandle.split_text(chapter_text, limit) + for text in split_text: + chapters.append({"title": chapter_title, "content": text}) + else: + chapters.append({"title": chapter_title, "content": chapter_text if chapter_text else chapter_title}) + # 保存章节内容和章节标题 + return chapters + + @staticmethod + def handle_links(doc, pattern_list, with_filter, limit): + # 创建存储章节内容的数组 + chapters = [] + toc_start_page = -1 + page_content = "" + handle_pre_toc = True + # 遍历 PDF 的每一页,查找带有目录链接的页 + for page_num in range(doc.page_count): + page = doc.load_page(page_num) + links = page.get_links() + # 如果目录开始页码未设置,则设置为当前页码 + if len(links) > 0: + toc_start_page = page_num + if toc_start_page < 0: + page_content += page.get_text('text') + # 检查该页是否包含内部链接(即指向文档内部的页面) + for num in range(len(links)): + link = links[num] + if link['kind'] == 1: # 'kind' 为 1 表示内部链接 + # 获取链接目标的页面 + dest_page = link['page'] + rect = link['from'] # 获取链接的矩形区域 + # 如果目录开始页码包括前言部分,则不处理前言部分 + if dest_page < toc_start_page: + handle_pre_toc = False + + # 提取链接区域的文本作为标题 + link_title = page.get_text("text", clip=rect).strip().split("\n")[0].replace('.', '').strip() + # print(f'link_title: {link_title}') + # 提取目标页面内容作为章节开始 + start_page = dest_page + end_page = dest_page 
+ # 下一个link + next_link = links[num + 1] if num + 1 < len(links) else None + next_link_title = None + if next_link is not None and next_link['kind'] == 1: + rect = next_link['from'] + next_link_title = page.get_text("text", clip=rect).strip() \ + .split("\n")[0].replace('.', '').strip() + # print(f'next_link_title: {next_link_title}') + end_page = next_link['page'] + + # 提取章节内容 + chapter_text = "" + for p_num in range(start_page, end_page + 1): + p = doc.load_page(p_num) + text = p.get_text("text") + text = re.sub(r'(? -1: + text = text[idx + len(link_title):] + + if next_link_title is not None: + idx = text.find(next_link_title) + if idx > -1: + text = text[:idx] + chapter_text += text + + # 限制章节内容长度 + if 0 < limit < len(chapter_text): + split_text = PdfSplitHandle.split_text(chapter_text, limit) + for text in split_text: + chapters.append({"title": link_title, "content": text}) + else: + # 保存章节信息 + chapters.append({"title": link_title, "content": chapter_text}) + + # 目录中没有前言部分,手动处理 + if handle_pre_toc: + pre_toc = [] + lines = page_content.strip().split('\n') + try: + for line in lines: + if re.match(r'^前\s*言', line): + pre_toc.append({'title': line, 'content': ''}) + else: + pre_toc[-1]['content'] += line + for i in range(len(pre_toc)): + pre_toc[i]['content'] = re.sub(r'(? 
@staticmethod
def handle_chapter_title(title):
    """
    Strip Chinese ordinal prefixes from a chapter title.

    Removes list-style prefixes such as "一、" or "十一、" and chapter
    prefixes such as "第三章" or "第十二章", with any whitespace after them.

    :param title: raw chapter title
    :return: the title without those prefixes
    """
    # Any leading run of numerals is consumed first, so multi-character
    # ordinals ("十一、") are removed whole instead of leaving "十" behind.
    # The final character class is kept exactly as before, so everything the
    # old pattern matched still matches.
    title = re.sub(r'[一二三四五六七八九十]*[一二三四五六七八九十\s*]、\s*', '', title)
    # "+" lets chapter numbers above ten ("第十二章") match as well.
    title = re.sub(r'第[一二三四五六七八九十]+章\s*', '', title)
    return title
csv.reader(io.TextIOWrapper(io.BytesIO(buffer), encoding=detect(buffer)['encoding'])) + try: + title_row_list = reader.__next__() + except Exception as e: + return [{'name': file.name, 'paragraphs': []}] + if len(title_row_list) == 0: + return [{'name': file.name, 'paragraphs': []}] + title_row_index_dict = get_title_row_index_dict(title_row_list) + paragraph_list = [] + for row in reader: + content = get_row_value(row, title_row_index_dict, 'content') + if content is None: + continue + problem = get_row_value(row, title_row_index_dict, 'problem_list') + problem = str(problem) if problem is not None else '' + problem_list = [{'content': p[0:255]} for p in problem.split('\n') if len(p.strip()) > 0] + title = get_row_value(row, title_row_index_dict, 'title') + title = str(title) if title is not None else '' + paragraph_list.append({'title': title[0:255], + 'content': content[0:102400], + 'problem_list': problem_list}) + return [{'name': file.name, 'paragraphs': paragraph_list}] + except Exception as e: + return [{'name': file.name, 'paragraphs': []}] diff --git a/src/MaxKB-1.7.2/apps/common/handle/impl/qa/xls_parse_qa_handle.py b/src/MaxKB-1.7.2/apps/common/handle/impl/qa/xls_parse_qa_handle.py new file mode 100644 index 0000000..06edb1f --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/handle/impl/qa/xls_parse_qa_handle.py @@ -0,0 +1,61 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: xls_parse_qa_handle.py + @date:2024/5/21 14:59 + @desc: +""" + +import xlrd + +from common.handle.base_parse_qa_handle import BaseParseQAHandle, get_title_row_index_dict, get_row_value + + +def handle_sheet(file_name, sheet): + rows = iter([sheet.row_values(i) for i in range(sheet.nrows)]) + try: + title_row_list = next(rows) + except Exception as e: + return {'name': file_name, 'paragraphs': []} + if len(title_row_list) == 0: + return {'name': file_name, 'paragraphs': []} + title_row_index_dict = get_title_row_index_dict(title_row_list) + paragraph_list = [] + for row in 
def support(self, file, get_buffer):
    """
    Tell whether the upload is a legacy ``.xls`` workbook that xlrd recognises.

    :param file: uploaded file object
    :param get_buffer: callable returning the raw bytes of ``file``; only called
                       when the extension already matches
    :return: True when the name ends with ``.xls`` (case-insensitive) and
             ``xlrd.inspect_format`` returns a truthy value for the content
    """
    file_name: str = file.name.lower()
    # Check the cheap extension test first so the buffer is not read
    # for files this handler would reject anyway.
    if not file_name.endswith(".xls"):
        return False
    return bool(xlrd.inspect_format(content=get_buffer(file)))
def support(self, file, get_buffer):
    """
    Tell whether the upload is an ``.xlsx`` workbook, judged by its extension.

    :param file: uploaded file object; only ``file.name`` is read
    :param get_buffer: unused here, kept for the common handler interface
    :return: True when the lower-cased name ends with ``.xlsx``, else False
    """
    lowered_name: str = file.name.lower()
    return lowered_name.endswith(".xlsx")
class OpenaiToResponse(BaseToResponse):
    """
    Formats chat answers as OpenAI-compatible ``chat.completion`` /
    ``chat.completion.chunk`` payloads.
    """

    def to_block_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens,
                          _status=status.HTTP_200_OK):
        """
        Build a non-streaming JSON response.

        :param chat_id: conversation id, attached to the choice as an extra field
        :param chat_record_id: used as the completion ``id``
        :param content: assistant message text
        :param is_end: unused here; the block response always finishes with 'stop'
        :param completion_tokens: tokens generated
        :param prompt_tokens: tokens in the prompt
        :param _status: HTTP status of the response
        :return: JsonResponse carrying the serialised ChatCompletion
        """
        data = ChatCompletion(id=chat_record_id, choices=[
            BlockChoice(finish_reason='stop', index=0, chat_id=chat_id,
                        message=ChatCompletionMessage(role='assistant', content=content))],
                              # `created` is a Unix timestamp in seconds, not the
                              # seconds component (0-59) of the current time
                              created=int(datetime.datetime.now().timestamp()),
                              model='', object='chat.completion',
                              usage=CompletionUsage(completion_tokens=completion_tokens,
                                                    prompt_tokens=prompt_tokens,
                                                    total_tokens=completion_tokens + prompt_tokens)
                              ).dict()
        return JsonResponse(data=data, status=_status)

    def to_stream_chunk_response(self, chat_id, chat_record_id, content, is_end, completion_tokens, prompt_tokens):
        """
        Build one streaming chunk.

        :param chat_id: conversation id, attached to the delta as an extra field
        :param chat_record_id: used as the chunk ``id``
        :param content: text of this chunk
        :param is_end: when true the chunk carries ``finish_reason='stop'``
        :param completion_tokens: tokens generated
        :param prompt_tokens: tokens in the prompt
        :return: the JSON chunk passed through ``format_stream_chunk`` of the base class
        """
        chunk = ChatCompletionChunk(id=chat_record_id, model='', object='chat.completion.chunk',
                                    # Unix timestamp in seconds (see to_block_response)
                                    created=int(datetime.datetime.now().timestamp()),
                                    choices=[
                                        Choice(delta=ChoiceDelta(content=content, chat_id=chat_id),
                                               finish_reason='stop' if is_end else None,
                                               index=0)],
                                    usage=CompletionUsage(completion_tokens=completion_tokens,
                                                          prompt_tokens=prompt_tokens,
                                                          total_tokens=completion_tokens + prompt_tokens)).json()
        return super().format_stream_chunk(chunk)
def handle(self, file, get_buffer, save_image):
    """
    Turn a CSV upload into one paragraph per data row.

    The first row is the header. Every later non-empty row becomes
    ``"header1:value1; header2:value2; ..."``.

    :param file: uploaded file object
    :param get_buffer: callable returning the raw bytes of ``file``
    :param save_image: unused by this handler
    :return: a single-element list ``[{'name': file.name, 'paragraphs': [...]}]``;
             ``paragraphs`` is empty when the content cannot be decoded or parsed
    """
    import csv
    import io

    buffer = get_buffer(file)
    try:
        content = buffer.decode(detect(buffer)['encoding'])
    except BaseException as e:
        max_kb.error(f'csv split handle error: {e}')
        return [{'name': file.name, 'paragraphs': []}]

    try:
        # The csv module handles quoted fields (embedded commas, quotes and
        # line breaks) and CRLF line endings, which plain str.split cannot.
        rows = list(csv.reader(io.StringIO(content, newline='')))
    except csv.Error as e:
        max_kb.error(f'csv split handle error: {e}')
        return [{'name': file.name, 'paragraphs': []}]

    paragraphs = []
    # The first row is the header
    title = rows[0] if rows else []
    for row in rows[1:]:
        if not row:
            continue
        line = '; '.join([f'{key}:{value}' for key, value in zip(title, row)])
        paragraphs.append({'title': '', 'content': line})

    return [{'name': file.name, 'paragraphs': paragraphs}]
def fill_merged_cells(self, sheet, image_dict):
    """
    Read a worksheet into a list of ``{header: value}`` dicts, one per data row.

    Row 1 supplies the headers. An empty cell covered by a merged range takes
    the value read from that range's ``min_row`` / ``min_col`` cell. A value
    that is a key of ``image_dict`` is replaced by a markdown image link built
    from the mapped object's ``id``.

    :param sheet: worksheet offering ``sheet[row]``, ``iter_rows`` and ``merged_cells.ranges``
    :param image_dict: mapping from cell value to an object with an ``id`` attribute
    :return: list of row dicts in sheet order
    """
    # Row 1 holds the column headers
    header_row = [header_cell.value for header_cell in sheet[1]]
    records = []

    # Data starts on row 2
    for cells in sheet.iter_rows(min_row=2, values_only=False):
        record = {}
        for position, current in enumerate(cells):
            value = current.value

            if value is None:
                # Look for the first merged range covering this empty cell
                covering = next(
                    (merged for merged in sheet.merged_cells.ranges if current.coordinate in merged),
                    None)
                if covering is not None:
                    value = sheet[covering.min_row][covering.min_col - 1].value

            embedded = image_dict.get(value, None)
            if embedded is not None:
                value = f'![](/api/image/{embedded.id})'

            record[header_row[position]] = value
        records.append(record)

    return records
def parse_element(element) -> {}:
    """
    Map picture names to relationship ids in a parsed drawing XML tree.

    For every node found by ``level_order_traversal`` that has an ``nvPicPr``
    child, the name is read from the first child of ``nvPicPr`` and the embed
    id from the first child of ``blipFill``.

    :param element: root element of the parsed XML
    :return: dict of picture name -> embed relationship id; pictures with an
             empty name are skipped
    """
    nv_pic_tag = "{%s}nvPicPr" % SHEET_DRAWING_NS
    blip_fill_tag = "{%s}blipFill" % SHEET_DRAWING_NS
    name_to_rel = {}
    for picture in level_order_traversal(element, nv_pic_tag):
        picture_name = ""
        rel_id = ""
        for node in picture:
            if node.tag == nv_pic_tag:
                picture_name = node[0].attrib["name"]
            elif node.tag == blip_fill_tag:
                rel_id = node[0].attrib["{%s}embed" % REL_NS]
        if picture_name:
            name_to_rel[picture_name] = rel_id
    return name_to_rel
def level_order_traversal(root, flag: str) -> list:
    """
    Breadth-first search for the shallowest nodes that have a direct child
    whose tag equals ``flag``.

    A matching node is collected and its subtree is not searched further.

    :param root: element to start from; iterating a node must yield its children
    :param flag: fully-qualified tag name to look for among direct children
    :return: matching nodes in breadth-first order
    """
    from collections import deque

    # deque gives O(1) pops from the left; list.pop(0) made the walk quadratic
    pending = deque([root])
    targets = []
    while pending:
        node = pending.popleft()
        if any(child.tag == flag for child in node):
            targets.append(node)
            continue
        pending.extend(node)
    return targets
= io.BytesIO() + im = PILImage.open(f).convert('RGB') + im.save(img_byte, format='JPEG') + image = Image(id=uuid.uuid1(), image=img_byte.getvalue(), image_name=img.path) + result['=' + image_excel_id] = image + archive.close() + return result + + diff --git a/src/MaxKB-1.7.2/apps/common/init/init_doc.py b/src/MaxKB-1.7.2/apps/common/init/init_doc.py new file mode 100644 index 0000000..5a60e55 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/init/init_doc.py @@ -0,0 +1,92 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: init_doc.py + @date:2024/5/24 14:11 + @desc: +""" +import hashlib + +from django.urls import re_path, path, URLPattern +from drf_yasg import openapi +from drf_yasg.views import get_schema_view +from rest_framework import permissions + +from common.auth import AnonymousAuthentication +from smartdoc.const import CONFIG + + +def init_app_doc(application_urlpatterns): + schema_view = get_schema_view( + openapi.Info( + title="Python API", + default_version='v1', + description="智能客服平台", + ), + public=True, + permission_classes=[permissions.AllowAny], + authentication_classes=[AnonymousAuthentication] + ) + application_urlpatterns += [ + re_path(r'^doc(?P\.json|\.yaml)$', schema_view.without_ui(cache_timeout=0), + name='schema-json'), # 导出 + path('doc/', schema_view.with_ui('swagger', cache_timeout=0), name='schema-swagger-ui'), + path('redoc/', schema_view.with_ui('redoc', cache_timeout=0), name='schema-redoc'), + ] + + +def init_chat_doc(application_urlpatterns, patterns): + chat_schema_view = get_schema_view( + openapi.Info( + title="Python API", + default_version='/chat', + description="智能客服平台", + ), + public=True, + permission_classes=[permissions.AllowAny], + authentication_classes=[AnonymousAuthentication], + patterns=[ + URLPattern(pattern='api/' + str(url.pattern), callback=url.callback, default_args=url.default_args, + name=url.name) + for url in patterns if + url.name is not None and ['application/message', 'application/open', + 
'application/profile'].__contains__( + url.name)] + ) + + application_urlpatterns += [ + path('doc/chat/', chat_schema_view.with_ui('swagger', cache_timeout=0), name='schema-swagger-ui'), + path('redoc/chat/', chat_schema_view.with_ui('redoc', cache_timeout=0), name='schema-redoc'), + ] + + +def encrypt(text): + md5 = hashlib.md5() + md5.update(text.encode()) + result = md5.hexdigest() + return result + + +def get_call(application_urlpatterns, patterns, params, func): + def run(): + if params['valid'](): + func(*params['get_params'](application_urlpatterns, patterns)) + + return run + + +init_list = [(init_app_doc, {'valid': lambda: CONFIG.get('DOC_PASSWORD') is not None and encrypt( + CONFIG.get('DOC_PASSWORD')) == 'd4fc097197b4b90a122b92cbd5bbe867', + 'get_call': get_call, + 'get_params': lambda application_urlpatterns, patterns: (application_urlpatterns,)}), + (init_chat_doc, {'valid': lambda: CONFIG.get('DOC_PASSWORD') is not None and encrypt( + CONFIG.get('DOC_PASSWORD')) == 'd4fc097197b4b90a122b92cbd5bbe867' or True, 'get_call': get_call, + 'get_params': lambda application_urlpatterns, patterns: ( + application_urlpatterns, patterns)})] + + +def init_doc(application_urlpatterns, patterns): + for init, params in init_list: + if params['valid'](): + get_call(application_urlpatterns, patterns, params, init)() diff --git a/src/MaxKB-1.7.2/apps/common/job/__init__.py b/src/MaxKB-1.7.2/apps/common/job/__init__.py new file mode 100644 index 0000000..2f4ef26 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/job/__init__.py @@ -0,0 +1,15 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2024/3/14 11:54 + @desc: +""" +from .client_access_num_job import * +from .clean_chat_job import * + + +def run(): + client_access_num_job.run() + clean_chat_job.run() diff --git a/src/MaxKB-1.7.2/apps/common/job/clean_chat_job.py b/src/MaxKB-1.7.2/apps/common/job/clean_chat_job.py new file mode 100644 index 0000000..23ff2c8 --- /dev/null +++ 
# --- src/MaxKB-1.7.2/apps/common/job/clean_chat_job.py (new file, @@ -0,0 +1,56 @@) ---
# coding=utf-8

import logging
import datetime

from django.db import transaction
from django.utils import timezone
from apscheduler.schedulers.background import BackgroundScheduler
from django_apscheduler.jobstores import DjangoJobStore
from application.models import Application, Chat
from django.db.models import Q
from common.lock.impl.file_lock import FileLock

scheduler = BackgroundScheduler()
scheduler.add_jobstore(DjangoJobStore(), "default")
lock = FileLock()


def clean_chat_log_job():
    """Delete chat rows older than each application's retention window.

    The retention period is read from ``Application.clean_time`` (days); a
    falsy value falls back to 180 days. Matching ``Chat`` rows are removed in
    batches of 500, each batch inside its own transaction.
    """
    logging.getLogger("max_kb").info('开始清理对话记录')
    now = timezone.now()

    applications = Application.objects.all().values('id', 'clean_time')
    cutoff_dates = {
        app['id']: now - datetime.timedelta(days=app['clean_time'] or 180)
        for app in applications
    }

    # One OR-ed condition per application: "belongs to app AND older than its cutoff".
    query_conditions = Q()
    for app_id, cutoff_date in cutoff_dates.items():
        query_conditions |= Q(application_id=app_id, create_time__lt=cutoff_date)

    batch_size = 500
    # With no applications the Q object stays empty, and filter(Q()) applies no
    # restriction at all -- the loop would then delete every chat regardless of
    # age. Only run the purge when at least one per-application cutoff exists.
    if cutoff_dates:
        while True:
            with transaction.atomic():
                logs_to_delete = Chat.objects.filter(query_conditions).values_list('id', flat=True)[:batch_size]
                count = logs_to_delete.count()
                if count == 0:
                    break
                deleted_count, _ = Chat.objects.filter(id__in=logs_to_delete).delete()
                # A short batch means nothing is left to delete.
                if deleted_count < batch_size:
                    break

    logging.getLogger("max_kb").info(f'结束清理对话记录')


def run():
    """Register the daily 00:05 cleanup cron job with the background scheduler.

    The file lock (timeout ``30 * 30``) is taken around registration and
    released in ``finally``; an existing job with the same id is replaced.
    """
    if lock.try_lock('clean_chat_log_job', 30 * 30):
        try:
            scheduler.start()
            existing_job = scheduler.get_job(job_id='clean_chat_log')
            if existing_job is not None:
                existing_job.remove()
            scheduler.add_job(clean_chat_log_job, 'cron', hour='0', minute='5', id='clean_chat_log')
        finally:
            lock.un_lock('clean_chat_log_job')

# --- next diff entry: src/MaxKB-1.7.2/apps/common/job/client_access_num_job.py (new file, index 0000000..9d91054, -0,0 +1,39) ---
# --- src/MaxKB-1.7.2/apps/common/job/client_access_num_job.py ---
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: client_access_num_job.py
    @date:2024/3/14 11:56
    @desc:
"""
import logging

from apscheduler.schedulers.background import BackgroundScheduler
from django.db.models import QuerySet
from django_apscheduler.jobstores import DjangoJobStore

from application.models.api_key_model import ApplicationPublicAccessClient
from common.lock.impl.file_lock import FileLock

scheduler = BackgroundScheduler()
scheduler.add_jobstore(DjangoJobStore(), "default")
lock = FileLock()


def client_access_num_reset_job():
    """Reset ``intraday_access_num`` to 0 on every public access client row."""
    logging.getLogger("max_kb").info('开始重置access_num')
    QuerySet(ApplicationPublicAccessClient).update(intraday_access_num=0)
    logging.getLogger("max_kb").info('结束重置access_num')


def run():
    """Register the midnight reset cron job, replacing any job with the same id."""
    if lock.try_lock('client_access_num_reset_job', 30 * 30):
        try:
            scheduler.start()
            access_num_reset = scheduler.get_job(job_id='access_num_reset')
            if access_num_reset is not None:
                access_num_reset.remove()
            scheduler.add_job(client_access_num_reset_job, 'cron', hour='0', minute='0', second='0',
                              id='access_num_reset')
        finally:
            lock.un_lock('client_access_num_reset_job')


# --- src/MaxKB-1.7.2/apps/common/lock/base_lock.py (new file, index 0000000..2ca5b21, @@ -0,0 +1,20 @@) ---
# coding=utf-8
"""
    @project: MaxKB
    @Author:虎
    @file: base_lock.py
    @date:2024/8/20 10:33
    @desc:
"""

from abc import ABC, abstractmethod


class BaseLock(ABC):
    """Interface for key-based locks with a non-blocking acquire."""

    @abstractmethod
    def try_lock(self, key, timeout):
        pass

    @abstractmethod
    def un_lock(self, key):
        pass


# --- src/MaxKB-1.7.2/apps/common/lock/impl/file_lock.py (new file, index 0000000..f8ea639, @@ -0,0 +1,77 @@) ---
# coding=utf-8
"""
    @project: MaxKB
    @Author:虎
    @file: file_lock.py
    @date:2024/8/20 10:48
    @desc:
"""
import errno
import hashlib
import os
import time

import six

from common.lock.base_lock import BaseLock
from smartdoc.const import PROJECT_DIR


def key_to_lock_name(key):
    """
    Combine part of a key with its hash to prevent very long filenames
    """
    MAX_LENGTH = 50
    key_hash = hashlib.md5(six.b(key)).hexdigest()
    # Keep a readable prefix of the key and append the full md5 hex digest.
    lock_name = key[:MAX_LENGTH - len(key_hash) - 1] + '_' + key_hash
    return lock_name


class FileLock(BaseLock):
    """
    File locking backend.

    A lock is held while its lock file exists in ``self.location``. A lock
    file whose modification time is older than ``timeout`` seconds is treated
    as stale and may be taken over.
    """

    def __init__(self, settings=None):
        if settings is None:
            settings = {}
        self.location = settings.get('location')
        if self.location is None:
            self.location = os.path.join(PROJECT_DIR, 'data', 'lock')
        try:
            os.makedirs(self.location)
        except OSError as error:
            # Directory exists?
            if error.errno != errno.EEXIST:
                # Re-raise unexpected OSError
                raise

    def _get_lock_path(self, key):
        lock_name = key_to_lock_name(key)
        return os.path.join(self.location, lock_name)

    def try_lock(self, key, timeout):
        """Try to take the lock for ``key`` without blocking.

        :param key: lock name
        :param timeout: seconds after which an existing lock file counts as stale
        :return: True when the lock was acquired, False otherwise
        """
        lock_path = self._get_lock_path(key)
        try:
            # Create the lock file exclusively; failing to create it means the lock is held.
            fd = os.open(lock_path, os.O_CREAT | os.O_EXCL)
        except OSError as error:
            if error.errno != errno.EEXIST:
                return False
            try:
                # File already exists, check its modification time
                mtime = os.path.getmtime(lock_path)
                ttl = mtime + timeout - time.time()
                if ttl > 0:
                    return False
                # The timeout has elapsed: refresh the timestamp and take the lock over.
                os.utime(lock_path, None)
            except OSError:
                # The holder removed the file between the failed create and the
                # stat/utime above; report "not acquired" instead of raising.
                return False
            return True
        else:
            os.close(fd)
            return True

    def un_lock(self, key):
        """Release the lock for ``key``; a lock file that is already gone is ignored."""
        lock_path = self._get_lock_path(key)
        try:
            os.remove(lock_path)
        except FileNotFoundError:
            # Already released (e.g. after a stale-lock takeover by another process).
            pass

# --- src/MaxKB-1.7.2/apps/common/management/__init__.py (new empty file, index 0000000..e69de29) ---
# --- src/MaxKB-1.7.2/apps/common/management/commands/__init__.py (new empty file, index 0000000..e69de29) ---
# --- src/MaxKB-1.7.2/apps/common/management/commands/celery.py (new file, index 0000000..a26b435, @@ -0,0 +1,46 @@) ---
# coding=utf-8
"""
    @project: MaxKB
    @Author:虎
    @file: celery.py
    @date:2024/8/19 11:57
    @desc:
"""
import os
import subprocess

from django.core.management.base import BaseCommand

from smartdoc.const import BASE_DIR


class Command(BaseCommand):
    """Management command that runs a celery worker for the given queue names."""
    help = 'celery'

    def add_arguments(self, parser):
        parser.add_argument(
            'service', nargs='+', type=str, choices=("celery", "model"), help='Service',
        )

    def handle(self, *args, **options):
        """Build the ``celery worker`` command line and run it from BASE_DIR."""
        service = options.get('service')
        os.environ.setdefault('CELERY_NAME', ','.join(service))
        server_hostname = os.environ.get("SERVER_HOSTNAME")
        # os.getuid does not exist on every platform, hence the hasattr guard.
        if hasattr(os, 'getuid') and os.getuid() == 0:
            os.environ.setdefault('C_FORCE_ROOT', '1')
        if not server_hostname:
            server_hostname = '%h'
        cmd = [
            'celery',
            '-A', 'ops',
            'worker',
            '-P', 'threads',
            '-l', 'info',
            '-c', '10',
            '-Q', ','.join(service),
            '--heartbeat-interval', '10',
            '-n', f'{",".join(service)}@{server_hostname}',
            '--without-mingle',
        ]
        kwargs = {'cwd': BASE_DIR}
        subprocess.run(cmd, **kwargs)


# --- src/MaxKB-1.7.2/apps/common/management/commands/restart.py (new file, index 0000000..57285f9, @@ -0,0 +1,6 @@) ---
from .services.command import BaseActionCommand, Action


class Command(BaseActionCommand):
    help = 'Restart services'
    action = Action.restart.value

# --- src/MaxKB-1.7.2/apps/common/management/commands/services/__init__.py (new empty file, index 0000000..e69de29) ---
# --- src/MaxKB-1.7.2/apps/common/management/commands/services/command.py (new file, index 0000000..c5b7192, @@ -0,0 +1,131 @@) ---
from django.core.management.base import BaseCommand
from django.db.models import TextChoices

from .hands import *
from .utils import ServicesUtil
import os


class Services(TextChoices):
    """Service names accepted on the command line, plus the group aliases web/task/all."""
    gunicorn = 'gunicorn', 'gunicorn'
    celery_default = 'celery_default', 'celery_default'
    local_model = 'local_model', 'local_model'
    web = 'web', 'web'
    celery = 'celery', 'celery'
    celery_model = 'celery_model', 'celery_model'
    task = 'task', 'task'
    all = 'all', 'all'

    @classmethod
    def get_service_object_class(cls, name):
        """Return the service class registered for ``name``, or None if there is none."""
        from . import services
        services_map = {
            cls.gunicorn.value: services.GunicornService,
            cls.celery_default: services.CeleryDefaultService,
            cls.local_model: services.GunicornLocalModelService
        }
        return services_map.get(name)

    @classmethod
    def web_services(cls):
        return [cls.gunicorn, cls.local_model]

    @classmethod
    def celery_services(cls):
        return [cls.celery_default, cls.celery_model]

    @classmethod
    def task_services(cls):
        return cls.celery_services()

    @classmethod
    def all_services(cls):
        return cls.web_services() + cls.task_services()

    @classmethod
    def export_services_values(cls):
        return [cls.all.value, cls.web.value, cls.task.value] + [s.value for s in cls.all_services()]

    @classmethod
    def get_service_objects(cls, service_names, **kwargs):
        """Expand names/aliases into instantiated service objects.

        A name with a matching ``<name>_services`` classmethod is a group alias;
        otherwise it is looked up as a single member. Unknown names and members
        without a registered class are skipped.
        """
        services = set()
        for name in service_names:
            method_name = f'{name}_services'
            if hasattr(cls, method_name):
                _services = getattr(cls, method_name)()
            elif hasattr(cls, name):
                _services = [getattr(cls, name)]
            else:
                continue
            services.update(set(_services))

        service_objects = []
        for s in services:
            service_class = cls.get_service_object_class(s.value)
            if not service_class:
                continue
            kwargs.update({
                'name': s.value
            })
            service_object = service_class(**kwargs)
            service_objects.append(service_object)
        return service_objects


class Action(TextChoices):
    start = 'start', 'start'
    status = 'status', 'status'
    stop = 'stop', 'stop'
    restart = 'restart', 'restart'


class BaseActionCommand(BaseCommand):
    """Base for the start/stop/restart/status commands; subclasses set ``action``."""
    help = 'Service Base Command'

    action = None
    util = None

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def add_arguments(self, parser):
        parser.add_argument(
            'services', nargs='+', choices=Services.export_services_values(), help='Service',
        )
        parser.add_argument('-d', '--daemon', nargs="?", const=True)
        # os.cpu_count() may return None when the count cannot be determined;
        # comparing None with an int would raise TypeError, so fall back to 1.
        # The result is still "at most 3 workers, never more than the CPU count".
        parser.add_argument('-w', '--worker', type=int, nargs="?", default=min(3, os.cpu_count() or 1))
        parser.add_argument('-f', '--force', nargs="?", const=True)

    def initial_util(self, *args, **options):
        """Build the ServicesUtil for the services named on the command line."""
        service_names = options.get('services')
        service_kwargs = {
            'worker_gunicorn': options.get('worker')
        }
        services = Services.get_service_objects(service_names=service_names, **service_kwargs)

        kwargs = {
            'services': services,
            'run_daemon': options.get('daemon', False),
            'stop_daemon': self.action == Action.stop.value and Services.all.value in service_names,
            'force_stop': options.get('force') or False,
        }
        self.util = ServicesUtil(**kwargs)

    def handle(self, *args, **options):
        self.initial_util(*args, **options)
        assert self.action in Action.values, f'The action {self.action} is not in the optional list'
        # Dispatch to _handle_<action>; a missing handler is a no-op.
        _handle = getattr(self, f'_handle_{self.action}', lambda: None)
        _handle()

    def _handle_start(self):
        self.util.start_and_watch()
        os._exit(0)

    def _handle_stop(self):
        self.util.stop()

    def _handle_restart(self):
        self.util.restart()

    def _handle_status(self):
        self.util.show_status()

# --- next diff entry: src/MaxKB-1.7.2/apps/common/management/commands/services/hands.py ---
# --- src/MaxKB-1.7.2/apps/common/management/commands/services/hands.py (new file, index 0000000..8244702, @@ -0,0 +1,26 @@) ---
import logging
import os
import sys

from smartdoc.const import CONFIG, PROJECT_DIR

try:
    from apps.smartdoc import const

    __version__ = const.VERSION
except ImportError as e:
    print("Not found __version__: {}".format(e))
    print("Python is: ")
    logging.info(sys.executable)
    __version__ = 'Unknown'
    sys.exit(1)

HTTP_HOST = '0.0.0.0'
HTTP_PORT = CONFIG.HTTP_LISTEN_PORT or 8080
DEBUG = CONFIG.DEBUG or False

LOG_DIR = os.path.join(PROJECT_DIR, 'data', 'logs')
APPS_DIR = os.path.join(PROJECT_DIR, 'apps')
TMP_DIR = os.path.join(PROJECT_DIR, 'tmp')
if not os.path.exists(TMP_DIR):
    os.makedirs(TMP_DIR)

# --- src/MaxKB-1.7.2/apps/common/management/commands/services/services/__init__.py (new file, index 0000000..1027392, @@ -0,0 +1,3 @@) ---
from .celery_default import *
from .gunicorn import *
from .local_model import *

# --- src/MaxKB-1.7.2/apps/common/management/commands/services/services/base.py (new file, index 0000000..ddcb4fe, @@ -0,0 +1,207 @@) ---
import abc
import time
import shutil
import psutil
import datetime
import threading
import subprocess
from ..hands import *


class BaseService(object):
    """A supervised child process with a pid file and a log file.

    Subclasses provide ``cmd`` (argv list) and ``cwd`` (working directory).
    """

    def __init__(self, **kwargs):
        self.name = kwargs['name']
        self._process = None
        self.STOP_TIMEOUT = 10
        # NOTE(review): with retry starting above max_retry, _restart() always takes
        # its "give up" branch on the first call -- confirm this is intended.
        self.max_retry = 0
        self.retry = 3
        self.LOG_KEEP_DAYS = 7
        self.EXIT_EVENT = threading.Event()

    @property
    @abc.abstractmethod
    def cmd(self):
        return []

    @property
    @abc.abstractmethod
    def cwd(self):
        return ''

    @property
    def is_running(self):
        """True when the pid from the pid file answers signal 0."""
        if self.pid == 0:
            return False
        try:
            os.kill(self.pid, 0)
        except (OSError, ProcessLookupError):
            return False
        else:
            return True

    def show_status(self):
        if self.is_running:
            msg = f'{self.name} is running: {self.pid}.'
        else:
            msg = f'{self.name} is stopped.'
            if DEBUG:
                msg = '\033[31m{} is stopped.\033[0m\nYou can manual start it to find the error: \n' \
                      ' $ cd {}\n' \
                      ' $ {}'.format(self.name, self.cwd, ' '.join(self.cmd))

        print(msg)

    # -- log --
    @property
    def log_filename(self):
        return f'{self.name}.log'

    @property
    def log_filepath(self):
        return os.path.join(LOG_DIR, self.log_filename)

    @property
    def log_file(self):
        # Opens a new append-mode handle on every access.
        return open(self.log_filepath, 'a')

    @property
    def log_dir(self):
        return os.path.dirname(self.log_filepath)
    # -- end log --

    # -- pid --
    @property
    def pid_filepath(self):
        return os.path.join(TMP_DIR, f'{self.name}.pid')

    @property
    def pid(self):
        """Pid read from the pid file; 0 when the file is missing or not an integer."""
        if not os.path.isfile(self.pid_filepath):
            return 0
        with open(self.pid_filepath) as f:
            try:
                pid = int(f.read().strip())
            except ValueError:
                pid = 0
        return pid

    def write_pid(self):
        with open(self.pid_filepath, 'w') as f:
            f.write(str(self.process.pid))

    def remove_pid(self):
        if os.path.isfile(self.pid_filepath):
            os.unlink(self.pid_filepath)
    # -- end pid --

    # -- process --
    @property
    def process(self):
        if not self._process:
            try:
                self._process = psutil.Process(self.pid)
            except Exception:
                # Best effort: no usable pid means no process handle.
                pass
        return self._process

    # -- end process --

    # -- action --
    def open_subprocess(self):
        # Share one handle for stdout and stderr instead of opening the log twice.
        log_file = self.log_file
        kwargs = {'cwd': self.cwd, 'stderr': log_file, 'stdout': log_file}
        self._process = subprocess.Popen(self.cmd, **kwargs)

    def start(self):
        if self.is_running:
            self.show_status()
            return
        self.remove_pid()
        self.open_subprocess()
        self.write_pid()
        self.start_other()

    def start_other(self):
        pass

    def stop(self, force=False):
        """Send SIGTERM (or SIGKILL when ``force``) and report whether the process ended."""
        if not self.is_running:
            self.show_status()
            # self.remove_pid()
            return

        print(f'Stop service: {self.name}', end='')
        sig = 9 if force else 15
        os.kill(self.pid, sig)

        if self.process is None:
            print("\033[31m No process found\033[0m")
            return
        try:
            self.process.wait(1)
        except Exception:
            # A wait timeout is expected here; the loop below re-checks liveness.
            pass

        # NOTE(review): this loop does not sleep between checks, so STOP_TIMEOUT
        # counts iterations rather than seconds -- confirm.
        for i in range(self.STOP_TIMEOUT):
            if i == self.STOP_TIMEOUT - 1:
                print("\033[31m Error\033[0m")
            if not self.is_running:
                print("\033[32m Ok\033[0m")
                self.remove_pid()
                break
            else:
                continue

    def watch(self):
        self._check()
        if not self.is_running:
            self._restart()
        self._rotate_log()

    def _check(self):
        now = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        print(f"{now} Check service status: {self.name} -> ", end='')
        if self.process:
            try:
                self.process.wait(1)  # without this wait the child process may never be reaped
            except Exception:
                # Catch only Exception so KeyboardInterrupt still propagates to the caller.
                pass

        if self.is_running:
            print(f'running at {self.pid}')
        else:
            print(f'stopped at {self.pid}')

    def _restart(self):
        if self.retry > self.max_retry:
            logging.info("Service start failed, exit: {}".format(self.name))
            self.EXIT_EVENT.set()
            return
        self.retry += 1
        logging.info(f'> Find {self.name} stopped, retry {self.retry}, {self.pid}')
        self.start()

    def _rotate_log(self):
        """At 23:59 copy the log into a dated directory, truncate it, and drop the expired copy."""
        now = datetime.datetime.now()
        _time = now.strftime('%H:%M')
        if _time != '23:59':
            return

        backup_date = now.strftime('%Y-%m-%d')
        backup_log_dir = os.path.join(self.log_dir, backup_date)
        if not os.path.exists(backup_log_dir):
            os.mkdir(backup_log_dir)

        backup_log_path = os.path.join(backup_log_dir, self.log_filename)
        if os.path.isfile(self.log_filepath) and not os.path.isfile(backup_log_path):
            logging.info(f'Rotate log file: {self.log_filepath} => {backup_log_path}')
            shutil.copy(self.log_filepath, backup_log_path)
            # Opening in 'w' mode truncates the live log in place.
            with open(self.log_filepath, 'w') as f:
                pass

        to_delete_date = now - datetime.timedelta(days=self.LOG_KEEP_DAYS)
        to_delete_dir = os.path.join(LOG_DIR, to_delete_date.strftime('%Y-%m-%d'))
        if os.path.exists(to_delete_dir):
            logging.info(f'Remove old log: {to_delete_dir}')
            shutil.rmtree(to_delete_dir, ignore_errors=True)
    # -- end action --


# --- src/MaxKB-1.7.2/apps/common/management/commands/services/services/celery_base.py (new file, index 0000000..0ae219b, @@ -0,0 +1,45 @@) ---
from django.conf import settings

from .base import BaseService
from ..hands import *


class CeleryBaseService(BaseService):
    """Service wrapper around a celery worker bound to one queue."""

    def __init__(self, queue, num=10, **kwargs):
        super().__init__(**kwargs)
        self.queue = queue
        self.num = num

    @property
    def cmd(self):
        """Prepare environment defaults and return the ``celery worker`` argv."""
        print('\n- Start Celery as Distributed Task Queue: {}'.format(self.queue.capitalize()))

        os.environ.setdefault('LC_ALL', 'C.UTF-8')
        os.environ.setdefault('PYTHONOPTIMIZE', '1')
        os.environ.setdefault('ANSIBLE_FORCE_COLOR', 'True')
        os.environ.setdefault('PYTHONPATH', settings.APPS_DIR)

        # os.getuid is not available on every platform; guard it the same way
        # the celery management command does.
        if hasattr(os, 'getuid') and os.getuid() == 0:
            os.environ.setdefault('C_FORCE_ROOT', '1')
        server_hostname = os.environ.get("SERVER_HOSTNAME")
        if not server_hostname:
            server_hostname = '%h'

        cmd = [
            'celery',
            '-A', 'ops',
            'worker',
            '-P', 'threads',
            '-l', 'error',
            '-c', str(self.num),
            '-Q', self.queue,
            '--heartbeat-interval', '10',
            '-n', f'{self.queue}@{server_hostname}',
            '--without-mingle',
        ]
        return cmd

    @property
    def cwd(self):
        return APPS_DIR

# --- next diff entry: src/MaxKB-1.7.2/apps/common/management/commands/services/services/celery_default.py (new file, index 0000000..5d3e6d7) ---
# --- src/MaxKB-1.7.2/apps/common/management/commands/services/services/celery_default.py (-0,0 +1,10) ---
from .celery_base import CeleryBaseService

__all__ = ['CeleryDefaultService']


class CeleryDefaultService(CeleryBaseService):
    """Celery worker service fixed to the ``celery`` queue."""

    def __init__(self, **kwargs):
        kwargs['queue'] = 'celery'
        super().__init__(**kwargs)


# --- src/MaxKB-1.7.2/apps/common/management/commands/services/services/gunicorn.py (new file, index 0000000..cc42c4f, @@ -0,0 +1,36 @@) ---
from .base import BaseService
from ..hands import *

__all__ = ['GunicornService']


class GunicornService(BaseService):
    """Gunicorn server for ``smartdoc.wsgi:application`` on HTTP_HOST:HTTP_PORT."""

    def __init__(self, **kwargs):
        self.worker = kwargs['worker_gunicorn']
        super().__init__(**kwargs)

    @property
    def cmd(self):
        print("\n- Start Gunicorn WSGI HTTP Server")

        log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
        bind = f'{HTTP_HOST}:{HTTP_PORT}'
        cmd = [
            'gunicorn', 'smartdoc.wsgi:application',
            '-b', bind,
            '-k', 'gthread',
            '--threads', '200',
            '-w', str(self.worker),
            '--max-requests', '10240',
            '--max-requests-jitter', '2048',
            '--access-logformat', log_format,
            '--access-logfile', '-'
        ]
        if DEBUG:
            cmd.append('--reload')
        return cmd

    @property
    def cwd(self):
        return APPS_DIR


# --- src/MaxKB-1.7.2/apps/common/management/commands/services/services/local_model.py (new file, index 0000000..4511f8f, @@ -0,0 +1,44 @@) ---
# coding=utf-8
"""
    @project: MaxKB
    @Author:虎
    @file: local_model.py
    @date:2024/8/21 13:28
    @desc:
"""
from .base import BaseService
from ..hands import *

__all__ = ['GunicornLocalModelService']


class GunicornLocalModelService(BaseService):
    """Gunicorn server bound to LOCAL_MODEL_HOST:LOCAL_MODEL_PORT with a single worker."""

    def __init__(self, **kwargs):
        # Stored for parity with GunicornService; cmd below always passes "-w 1".
        self.worker = kwargs['worker_gunicorn']
        super().__init__(**kwargs)

    @property
    def cmd(self):
        print("\n- Start Gunicorn Local Model WSGI HTTP Server")
        os.environ.setdefault('SERVER_NAME', 'local_model')
        log_format = '%(h)s %(t)s %(L)ss "%(r)s" %(s)s %(b)s '
        bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}'
        cmd = [
            'gunicorn', 'smartdoc.wsgi:application',
            '-b', bind,
            '-k', 'gthread',
            '--threads', '200',
            '-w', "1",
            '--max-requests', '10240',
            '--max-requests-jitter', '2048',
            '--access-logformat', log_format,
            '--access-logfile', '-'
        ]
        if DEBUG:
            cmd.append('--reload')
        return cmd

    @property
    def cwd(self):
        return APPS_DIR


# --- src/MaxKB-1.7.2/apps/common/management/commands/services/utils.py (new file, index 0000000..2426758, @@ -0,0 +1,140 @@) ---
import threading
import signal
import time
import daemon
from daemon import pidfile
from .hands import *
from .hands import __version__
from .services.base import BaseService


class ServicesUtil(object):
    """Starts, stops and watches a list of service objects, optionally as a daemon."""

    def __init__(self, services, run_daemon=False, force_stop=False, stop_daemon=False):
        self._services = services
        self.run_daemon = run_daemon
        self.force_stop = force_stop
        self.stop_daemon = stop_daemon
        self.EXIT_EVENT = threading.Event()
        # Seconds slept between two watch passes.
        self.check_interval = 30
        # Service name -> log file handle, passed to the daemon context as files_preserve.
        self.files_preserve_map = {}

    def restart(self):
        self.stop()
        time.sleep(5)
        self.start_and_watch()

    def start_and_watch(self):
        logging.info(time.ctime())
        logging.info(f'MaxKB version {__version__}, more see https://www.jumpserver.org')
        self.start()
        if self.run_daemon:
            self.show_status()
            with self.daemon_context:
                self.watch()
        else:
            self.watch()

    def start(self):
        for service in self._services:
            service: BaseService
            service.start()
            self.files_preserve_map[service.name] = service.log_file

        time.sleep(1)

    def stop(self):
        for service in self._services:
            service: BaseService
            service.stop(force=self.force_stop)

        if self.stop_daemon:
            self._stop_daemon()

    # -- watch --
    def watch(self):
        """Poll the services until one requests exit or Ctrl-C arrives, then clean up."""
        while not self.EXIT_EVENT.is_set():
            try:
                _exit = self._watch()
                if _exit:
                    break
                time.sleep(self.check_interval)
            except KeyboardInterrupt:
                print('Start stop services')
                break
        self.clean_up()

    def _watch(self):
        """Run one watch pass; return True as soon as a service has set its exit event."""
        for service in self._services:
            service: BaseService
            service.watch()
            if service.EXIT_EVENT.is_set():
                self.EXIT_EVENT.set()
                return True
        return False
    # -- end watch --

    def clean_up(self):
        if not self.EXIT_EVENT.is_set():
            self.EXIT_EVENT.set()
        self.stop()

    def show_status(self):
        for service in self._services:
            service: BaseService
            service.show_status()

    # -- daemon --
    def _stop_daemon(self):
        if self.daemon_pid and self.daemon_is_running:
            os.kill(self.daemon_pid, 15)
        self.remove_daemon_pid()

    def remove_daemon_pid(self):
        if os.path.isfile(self.daemon_pid_filepath):
            os.unlink(self.daemon_pid_filepath)

    @property
    def daemon_pid(self):
        """Pid read from the daemon pid file; 0 when missing or not an integer."""
        if not os.path.isfile(self.daemon_pid_filepath):
            return 0
        with open(self.daemon_pid_filepath) as f:
            try:
                pid = int(f.read().strip())
            except ValueError:
                pid = 0
        return pid

    @property
    def daemon_is_running(self):
        try:
            os.kill(self.daemon_pid, 0)
        except (OSError, ProcessLookupError):
            return False
        else:
            return True

    @property
    def daemon_pid_filepath(self):
        return os.path.join(TMP_DIR, 'mk.pid')

    @property
    def daemon_log_filepath(self):
        return os.path.join(LOG_DIR, 'mk.log')

    @property
    def daemon_context(self):
        """Build the DaemonContext: pid lock file, signal handlers, log redirection."""
        daemon_log_file = open(self.daemon_log_filepath, 'a')
        context = daemon.DaemonContext(
            pidfile=pidfile.TimeoutPIDLockFile(self.daemon_pid_filepath),
            signal_map={
                signal.SIGTERM: lambda x, y: self.clean_up(),
                signal.SIGHUP: 'terminate',
            },
            stdout=daemon_log_file,
            stderr=daemon_log_file,
            files_preserve=list(self.files_preserve_map.values()),
            detach_process=True,
        )
        return context
    # -- end daemon --


# --- src/MaxKB-1.7.2/apps/common/management/commands/start.py (new file, index 0000000..4c078a8, @@ -0,0 +1,6 @@) ---
from .services.command import BaseActionCommand, Action


class Command(BaseActionCommand):
    help = 'Start services'
    action = Action.start.value


# --- src/MaxKB-1.7.2/apps/common/management/commands/status.py (new file, index 0000000..36f0d36, @@ -0,0 +1,6 @@) ---
from .services.command import BaseActionCommand, Action


class Command(BaseActionCommand):
    help = 'Show services status'
    action = Action.status.value


# --- src/MaxKB-1.7.2/apps/common/management/commands/stop.py (new file, index 0000000..a79a533, @@ -0,0 +1,6 @@) ---
from .services.command import BaseActionCommand, Action


class Command(BaseActionCommand):
    help = 'Stop services'
    action = Action.stop.value


# --- src/MaxKB-1.7.2/apps/common/middleware/cross_domain_middleware.py (new file, index 0000000..06c0a6a, @@ -0,0 +1,40 @@) ---
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: cross_domain_middleware.py
    @date:2024/5/8 13:36
    @desc:
"""
from django.http import HttpResponse
from django.utils.deprecation import MiddlewareMixin

from common.cache_data.application_api_key_cache import get_application_api_key


class CrossDomainMiddleware(MiddlewareMixin):
    """Answers CORS preflight requests and adds CORS headers for application API keys."""

    def process_request(self, request):
        # Preflight requests are answered directly with a permissive header set.
        if request.method == 'OPTIONS':
            return HttpResponse(status=200,
                                headers={
                                    "Access-Control-Allow-Origin": "*",
                                    "Access-Control-Allow-Methods": "GET,POST,DELETE,PUT",
                                    "Access-Control-Allow-Headers": "Origin,X-Requested-With,Content-Type,Accept,Authorization,token"})

    def process_response(self, request, response):
        """Add CORS headers when the request carries an ``application-`` key and an Origin."""
        auth = request.META.get('HTTP_AUTHORIZATION')
        origin = request.META.get('HTTP_ORIGIN')
        if auth is not None and str(auth).startswith("application-") and origin is not None:
            application_api_key = get_application_api_key(str(auth), True)
            # NOTE(review): the lookup is assumed to be able to return None for an
            # unknown key, as the access-token lookup is treated in
            # StaticHeadersMiddleware -- confirm against the cache helper.
            if application_api_key is None:
                return response
            cross_domain_list = application_api_key.get('cross_domain_list', [])
            allow_cross_domain = application_api_key.get('allow_cross_domain', False)
            if allow_cross_domain:
                response['Access-Control-Allow-Methods'] = 'GET,POST,DELETE,PUT'
                response[
                    'Access-Control-Allow-Headers'] = "Origin,X-Requested-With,Content-Type,Accept,Authorization,token"
                # An empty allow-list means any origin; otherwise echo only listed origins.
                if cross_domain_list is None or len(cross_domain_list) == 0:
                    response['Access-Control-Allow-Origin'] = "*"
                elif cross_domain_list.__contains__(origin):
                    response['Access-Control-Allow-Origin'] = origin
        return response


# --- src/MaxKB-1.7.2/apps/common/middleware/static_headers_middleware.py (new file, index 0000000..f5afcfb, @@ -0,0 +1,33 @@) ---
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: static_headers_middleware.py
    @date:2024/3/13 18:26
    @desc:
"""
from django.utils.deprecation import MiddlewareMixin

from common.cache_data.application_access_token_cache import get_application_access_token


class StaticHeadersMiddleware(MiddlewareMixin):
    """Post-processes responses for ``/ui/chat/<access_token>`` pages."""

    def process_response(self, request, response):
        if request.path.startswith('/ui/chat/'):
            access_token = request.path.replace('/ui/chat/', '')
            application_access_token = get_application_access_token(access_token, True)
            if application_access_token is not None:
                white_active = application_access_token.get('white_active', False)
                white_list = application_access_token.get('white_list', [])
                application_icon = application_access_token.get('application_icon')
                application_name = application_access_token.get('application_name')
                if white_active:
                    # Add the custom response header
                    response[
                        'Content-Security-Policy'] = f'frame-ancestors {" ".join(white_list)}'
                # NOTE(review): the first replace() arguments are empty as received;
                # they look like markup lost in extraction -- restore from upstream.
                response.content = (response.content.decode('utf-8').replace(
                    '',
                    f'')
                                    .replace('MaxKB', f'{application_name}').encode(
                    "utf-8"))
        return response


# --- src/MaxKB-1.7.2/apps/common/mixins/api_mixin.py (new file, index 0000000..d2625a0, @@ -0,0 +1,24 @@) ---
# coding=utf-8
"""
    @project: smart-doc
    @Author:虎
    @file: api_mixin.py
    @date:2023/9/14 17:50
    @desc:
"""
from rest_framework import serializers


class ApiMixin(serializers.Serializer):
    """Serializer mixin exposing three overridable API-description hooks."""

    @staticmethod
    def get_request_params_api():
        pass

    @staticmethod
    def get_request_body_api():
        pass

    @staticmethod
    def get_response_body_api():
        pass


# --- src/MaxKB-1.7.2/apps/common/mixins/app_model_mixin.py (new file, index 0000000..412dbae, @@ -0,0 +1,18 @@) ---
# coding=utf-8
"""
    @project: maxkb
    @Author:虎
    @file: app_model_mixin.py
    @date:2023/9/21 9:41
    @desc:
"""
from django.db import models


class AppModelMixin(models.Model):
    """Abstract model adding create/update timestamps, ordered by creation time."""
    create_time = models.DateTimeField(verbose_name="创建时间", auto_now_add=True)
    update_time = models.DateTimeField(verbose_name="修改时间", auto_now=True)

    class Meta:
        abstract = True
        ordering = ['create_time']
a/src/MaxKB-1.7.2/apps/common/models/db_model_manage.py b/src/MaxKB-1.7.2/apps/common/models/db_model_manage.py new file mode 100644 index 0000000..80ce0f5 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/models/db_model_manage.py @@ -0,0 +1,35 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: db_model_manage.py + @date:2024/7/22 17:00 + @desc: +""" +from importlib import import_module +from django.conf import settings + + +def new_instance_by_class_path(class_path: str): + parts = class_path.rpartition('.') + package_path = parts[0] + class_name = parts[2] + module = import_module(package_path) + HandlerClass = getattr(module, class_name) + return HandlerClass() + + +class DBModelManage: + model_dict = {} + + @staticmethod + def get_model(model_name): + return DBModelManage.model_dict.get(model_name) + + @staticmethod + def init(): + handles = [new_instance_by_class_path(class_path) for class_path in + (settings.MODEL_HANDLES if hasattr(settings, 'MODEL_HANDLES') else [])] + for h in handles: + model_dict = h.get_model_dict() + DBModelManage.model_dict = {**DBModelManage.model_dict, **model_dict} diff --git a/src/MaxKB-1.7.2/apps/common/models/handle/base_handle.py b/src/MaxKB-1.7.2/apps/common/models/handle/base_handle.py new file mode 100644 index 0000000..1738967 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/models/handle/base_handle.py @@ -0,0 +1,15 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: base_handle.py + @date:2024/7/22 17:02 + @desc: +""" +from abc import ABC, abstractmethod + + +class IBaseModelHandle(ABC): + @abstractmethod + def get_model_dict(self): + pass diff --git a/src/MaxKB-1.7.2/apps/common/models/handle/impl/default_base_model_handle.py b/src/MaxKB-1.7.2/apps/common/models/handle/impl/default_base_model_handle.py new file mode 100644 index 0000000..b1ed705 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/models/handle/impl/default_base_model_handle.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: MaxKB + 
@Author:虎 + @file: default_base_model_handle.py + @date:2024/7/22 17:06 + @desc: +""" +from common.models.handle.base_handle import IBaseModelHandle + + +class DefaultBaseModelHandle(IBaseModelHandle): + def get_model_dict(self): + return {} diff --git a/src/MaxKB-1.7.2/apps/common/response/result.py b/src/MaxKB-1.7.2/apps/common/response/result.py new file mode 100644 index 0000000..bb2ba0f --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/response/result.py @@ -0,0 +1,166 @@ +from typing import List + +from django.http import JsonResponse +from drf_yasg import openapi +from rest_framework import status + + +class Page(dict): + """ + 分页对象 + """ + + def __init__(self, total: int, records: List, current_page: int, page_size: int, **kwargs): + super().__init__(**{'total': total, 'records': records, 'current': current_page, 'size': page_size}) + + +class Result(JsonResponse): + charset = 'utf-8' + """ + 接口统一返回对象 + """ + + def __init__(self, code=200, message="成功", data=None, response_status=status.HTTP_200_OK, **kwargs): + back_info_dict = {"code": code, "message": message, 'data': data} + super().__init__(data=back_info_dict, status=response_status, **kwargs) + + +def get_page_request_params(other_request_params=None): + if other_request_params is None: + other_request_params = [] + current_page = openapi.Parameter(name='current_page', + in_=openapi.IN_PATH, + type=openapi.TYPE_INTEGER, + required=True, + description='当前页') + + page_size = openapi.Parameter(name='page_size', + in_=openapi.IN_PATH, + type=openapi.TYPE_INTEGER, + required=True, + description='每页大小') + result = [current_page, page_size] + for other_request_param in other_request_params: + result.append(other_request_param) + return result + + +def get_page_api_response(response_data_schema: openapi.Schema): + """ + 获取统一返回 响应Api + """ + return openapi.Responses(responses={200: openapi.Response(description="响应参数", + schema=openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'code': openapi.Schema( 
+ type=openapi.TYPE_INTEGER, + title="响应码", + default=200, + description="成功:200 失败:其他"), + "message": openapi.Schema( + type=openapi.TYPE_STRING, + title="提示", + default='成功', + description="错误提示"), + "data": openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'total': openapi.Schema( + type=openapi.TYPE_INTEGER, + title="总条数", + default=1, + description="数据总条数"), + "records": openapi.Schema( + type=openapi.TYPE_ARRAY, + items=response_data_schema), + "current": openapi.Schema( + type=openapi.TYPE_INTEGER, + title="当前页", + default=1, + description="当前页"), + "size": openapi.Schema( + type=openapi.TYPE_INTEGER, + title="每页大小", + default=10, + description="每页大小") + + } + ) + + } + ), + )}) + + +def get_api_response(response_data_schema: openapi.Schema): + """ + 获取统一返回 响应Api + """ + return openapi.Responses(responses={200: openapi.Response(description="响应参数", + schema=openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'code': openapi.Schema( + type=openapi.TYPE_INTEGER, + title="响应码", + default=200, + description="成功:200 失败:其他"), + "message": openapi.Schema( + type=openapi.TYPE_STRING, + title="提示", + default='成功', + description="错误提示"), + "data": response_data_schema + + } + ), + )}) + + +def get_default_response(): + return get_api_response(openapi.Schema(type=openapi.TYPE_BOOLEAN)) + + +def get_api_array_response(response_data_schema: openapi.Schema): + """ + 获取统一返回 响应Api + """ + return openapi.Responses(responses={200: openapi.Response(description="响应参数", + schema=openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'code': openapi.Schema( + type=openapi.TYPE_INTEGER, + title="响应码", + default=200, + description="成功:200 失败:其他"), + "message": openapi.Schema( + type=openapi.TYPE_STRING, + title="提示", + default='成功', + description="错误提示"), + "data": openapi.Schema(type=openapi.TYPE_ARRAY, + items=response_data_schema) + + } + ), + )}) + + +def success(data, **kwargs): + """ + 获取一个成功的响应对象 + :param data: 接口响应数据 + :return: 请求响应对象 + """ + return 
Result(data=data, **kwargs) + + +def error(message): + """ + 获取一个失败的响应对象 + :param message: 错误提示 + :return: 接口响应对象 + """ + return Result(code=500, message=message) diff --git a/src/MaxKB-1.7.2/apps/common/sql/list_embedding_text.sql b/src/MaxKB-1.7.2/apps/common/sql/list_embedding_text.sql new file mode 100644 index 0000000..ac0dc7b --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/sql/list_embedding_text.sql @@ -0,0 +1,27 @@ +SELECT + problem_paragraph_mapping."id" AS "source_id", + paragraph.document_id AS document_id, + paragraph."id" AS paragraph_id, + problem.dataset_id AS dataset_id, + 0 AS source_type, + problem."content" AS "text", + paragraph.is_active AS is_active +FROM + problem problem + LEFT JOIN problem_paragraph_mapping problem_paragraph_mapping ON problem_paragraph_mapping.problem_id=problem."id" + LEFT JOIN paragraph paragraph ON paragraph."id" = problem_paragraph_mapping.paragraph_id + ${problem} + +UNION +SELECT + paragraph."id" AS "source_id", + paragraph.document_id AS document_id, + paragraph."id" AS paragraph_id, + paragraph.dataset_id AS dataset_id, + 1 AS source_type, + concat_ws(E'\n',paragraph.title,paragraph."content") AS "text", + paragraph.is_active AS is_active +FROM + paragraph paragraph + + ${paragraph} \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/common/swagger_api/common_api.py b/src/MaxKB-1.7.2/apps/common/swagger_api/common_api.py new file mode 100644 index 0000000..c3d8be6 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/swagger_api/common_api.py @@ -0,0 +1,85 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: common.py + @date:2023/12/25 16:17 + @desc: +""" +from drf_yasg import openapi + +from common.mixins.api_mixin import ApiMixin + + +class CommonApi: + class HitTestApi(ApiMixin): + @staticmethod + def get_request_params_api(): + return [ + openapi.Parameter(name='query_text', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=True, + description='问题文本'), + 
openapi.Parameter(name='top_number', + in_=openapi.IN_QUERY, + type=openapi.TYPE_NUMBER, + default=10, + required=True, + description='topN'), + openapi.Parameter(name='similarity', + in_=openapi.IN_QUERY, + type=openapi.TYPE_NUMBER, + default=0.6, + required=True, + description='相关性'), + openapi.Parameter(name='search_mode', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + default="embedding", + required=True, + description='检索模式embedding|keywords|blend' + ) + ] + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'content', 'hit_num', 'star_num', 'trample_num', 'is_active', 'dataset_id', + 'document_id', 'title', + 'similarity', 'comprehensive_score', + 'create_time', 'update_time'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", + description="id", default="xx"), + 'content': openapi.Schema(type=openapi.TYPE_STRING, title="段落内容", + description="段落内容", default='段落内容'), + 'title': openapi.Schema(type=openapi.TYPE_STRING, title="标题", + description="标题", default="xxx的描述"), + 'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量", + default=1), + 'star_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点赞数量", + description="点赞数量", default=1), + 'trample_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="点踩数量", + description="点踩数", default=1), + 'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id", + description="知识库id", default='xxx'), + 'document_id': openapi.Schema(type=openapi.TYPE_STRING, title="文档id", + description="文档id", default='xxx'), + 'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", + description="是否可用", default=True), + 'similarity': openapi.Schema(type=openapi.TYPE_NUMBER, title="相关性得分", + description="相关性得分", default=True), + 'comprehensive_score': openapi.Schema(type=openapi.TYPE_NUMBER, title="综合得分,用于排序", + description="综合得分,用于排序", default=True), + 'update_time': 
openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ), + + } + ) diff --git a/src/MaxKB-1.7.2/apps/common/task/__init__.py b/src/MaxKB-1.7.2/apps/common/task/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MaxKB-1.7.2/apps/common/template/email_template.html b/src/MaxKB-1.7.2/apps/common/template/email_template.html new file mode 100644 index 0000000..dff0ab3 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/template/email_template.html @@ -0,0 +1,122 @@ + + + + + + + + + + +
+ + + + + + + + + +
+
+ 智能知识库问答系统 +
+
+
+

+ + + 尊敬的用户: + +

+ +

+ ${code}  为您的动态验证码,请于30分钟内填写,为保障帐户安全,请勿向任何人提供此验证码。 +

+
+ +
+
+

智能知识库项目组

+
+

+ 此为系统邮件,请勿回复
+ Please do not reply to this system email +

+ +
+
+
+
+
+ + diff --git a/src/MaxKB-1.7.2/apps/common/util/cache_util.py b/src/MaxKB-1.7.2/apps/common/util/cache_util.py new file mode 100644 index 0000000..3d97a47 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/util/cache_util.py @@ -0,0 +1,68 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: cache_util.py + @date:2024/7/24 19:23 + @desc: +""" +from django.core.cache import caches + +cache = caches['default_file'] + + +def get_data_by_default_cache(key: str, get_data, cache_instance=cache, version=None, kwargs=None): + """ + 获取数据, 先从缓存中获取,如果获取不到再调用get_data 获取数据 + @param kwargs: get_data所需参数 + @param key: key + @param get_data: 获取数据函数 + @param cache_instance: cache实例 + @param version: 版本用于隔离 + @return: + """ + if kwargs is None: + kwargs = {} + if cache_instance.has_key(key, version=version): + return cache_instance.get(key, version=version) + data = get_data(**kwargs) + cache_instance.add(key, data, version=version) + return data + + +def set_data_by_default_cache(key: str, get_data, cache_instance=cache, version=None): + data = get_data() + cache_instance.set(key, data, version=version) + return data + + +def get_cache(cache_key, use_get_data: any = True, cache_instance=cache, version=None): + def inner(get_data): + def run(*args, **kwargs): + key = cache_key(*args, **kwargs) if callable(cache_key) else cache_key + is_use_get_data = use_get_data(*args, **kwargs) if callable(use_get_data) else use_get_data + if is_use_get_data: + if cache_instance.has_key(key, version=version): + return cache_instance.get(key, version=version) + data = get_data(*args, **kwargs) + cache_instance.add(key, data, timeout=None, version=version) + return data + data = get_data(*args, **kwargs) + cache_instance.set(key, data, timeout=None, version=version) + return data + + return run + + return inner + + +def del_cache(cache_key, cache_instance=cache, version=None): + def inner(func): + def run(*args, **kwargs): + key = cache_key(*args, **kwargs) if callable(cache_key) else 
cache_key + func(*args, **kwargs) + cache_instance.delete(key, version=version) + + return run + + return inner diff --git a/src/MaxKB-1.7.2/apps/common/util/common.py b/src/MaxKB-1.7.2/apps/common/util/common.py new file mode 100644 index 0000000..cbf6b00 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/util/common.py @@ -0,0 +1,104 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: common.py + @date:2023/10/16 16:42 + @desc: +""" +import hashlib +import importlib +from functools import reduce +from typing import Dict, List + +from django.db.models import QuerySet + +from ..exception.app_exception import AppApiException +from ..models.db_model_manage import DBModelManage + + +def sub_array(array: List, item_num=10): + result = [] + temp = [] + for item in array: + temp.append(item) + if len(temp) >= item_num: + result.append(temp) + temp = [] + if len(temp) > 0: + result.append(temp) + return result + + +def query_params_to_single_dict(query_params: Dict): + return reduce(lambda x, y: {**x, **y}, list( + filter(lambda item: item is not None, [({key: value} if value is not None and len(value) > 0 else None) for + key, value in + query_params.items()])), {}) + + +def get_exec_method(clazz_: str, method_: str): + """ + 根据 class 和method函数 获取执行函数 + :param clazz_: class 字符串 + :param method_: 执行函数 + :return: 执行函数 + """ + clazz_split = clazz_.split('.') + clazz_name = clazz_split[-1] + package = ".".join([clazz_split[index] for index in range(len(clazz_split) - 1)]) + package_model = importlib.import_module(package) + return getattr(getattr(package_model, clazz_name), method_) + + +def flat_map(array: List[List]): + """ + 将二位数组转为一维数组 + :param array: 二维数组 + :return: 一维数组 + """ + result = [] + for e in array: + result += e + return result + + +def password_encrypt(raw_password): + """ + 密码 md5加密 + :param raw_password: 密码 + :return: 加密后密码 + """ + md5 = hashlib.md5() # 2,实例化md5() 方法 + md5.update(raw_password.encode()) # 3,对字符串的字节类型加密 + result = md5.hexdigest() # 
4,加密 + return result + + +def post(post_function): + def inner(func): + def run(*args, **kwargs): + result = func(*args, **kwargs) + return post_function(*result) + + return run + + return inner + + +def valid_license(model=None, count=None, message=None): + def inner(func): + def run(*args, **kwargs): + xpack_cache = DBModelManage.get_model('xpack_cache') + is_license_valid = xpack_cache.get('XPACK_LICENSE_IS_VALID', False) if xpack_cache is not None else False + record_count = QuerySet(model).count() + + if not is_license_valid and record_count >= count: + error_message = message or f'超出限制{count}, 请联系我们(https://fit2cloud.com/)。' + raise AppApiException(400, error_message) + + return func(*args, **kwargs) + + return run + + return inner diff --git a/src/MaxKB-1.7.2/apps/common/util/field_message.py b/src/MaxKB-1.7.2/apps/common/util/field_message.py new file mode 100644 index 0000000..61eca2a --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/util/field_message.py @@ -0,0 +1,117 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: field_message.py + @date:2024/3/1 14:30 + @desc: +""" +from django.utils.translation import gettext_lazy + + +class ErrMessage: + @staticmethod + def char(field: str): + return { + 'invalid': gettext_lazy("【%s】不是有效的字符串。" % field), + 'blank': gettext_lazy("【%s】此字段不能为空字符串。" % field), + 'max_length': gettext_lazy("【%s】请确保此字段的字符数不超过 {max_length} 个。" % field), + 'min_length': gettext_lazy("【%s】请确保此字段至少包含 {min_length} 个字符。" % field), + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field) + } + + @staticmethod + def uuid(field: str): + return {'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + 'invalid': gettext_lazy("【%s】必须是有效的UUID。" % field), + } + + @staticmethod + def integer(field: str): + return {'invalid': gettext_lazy('【%s】必须是有效的integer。' % field), + 'max_value': gettext_lazy('【%s】请确保此值小于或等于 {max_value} 。' % field), + 'min_value': 
gettext_lazy('【%s】请确保此值大于或等于 {min_value} 。' % field), + 'max_string_length': gettext_lazy('【%s】字符串值太大。') % field, + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + } + + @staticmethod + def list(field: str): + return {'not_a_list': gettext_lazy('【%s】应为列表,但得到的类型为 "{input_type}".' % field), + 'empty': gettext_lazy('【%s】此列表不能为空。' % field), + 'min_length': gettext_lazy('【%s】请确保此字段至少包含 {min_length} 个元素。' % field), + 'max_length': gettext_lazy('【%s】请确保此字段的元素不超过 {max_length} 个。' % field), + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + } + + @staticmethod + def boolean(field: str): + return {'invalid': gettext_lazy('【%s】必须是有效的布尔值。' % field), + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field)} + + @staticmethod + def dict(field: str): + return {'not_a_dict': gettext_lazy('【%s】应为字典,但得到的类型为 "{input_type}' % field), + 'empty': gettext_lazy('【%s】能是空的。' % field), + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + } + + @staticmethod + def float(field: str): + return {'invalid': gettext_lazy('【%s】需要一个有效的数字。' % field), + 'max_value': gettext_lazy('【%s】请确保此值小于或等于 {max_value}。' % field), + 'min_value': gettext_lazy('【%s】请确保此值大于或等于 {min_value}。' % field), + 'max_string_length': gettext_lazy('【%s】字符串值太大。' % field), + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + } + + @staticmethod + def json(field: str): + return { + 'invalid': gettext_lazy('【%s】值必须是有效的JSON。' % field), + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + } + + @staticmethod + def base(field: str): + return { + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + } + + @staticmethod + def date(field: str): + return { + 'required': gettext_lazy('【%s】此字段必填。' % 
field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + 'invalid': gettext_lazy('【%s】日期格式错误,请改用以下格式之一: {format}。'), + 'datetime': gettext_lazy('【%s】应为日期,但得到的是日期时间。') + } + + @staticmethod + def image(field: str): + return { + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'null': gettext_lazy('【%s】此字段不能为null。' % field), + 'invalid_image': gettext_lazy('您上载的【%s】文件不是图像或图像已损坏,请上载有效的图像。' % field), + 'max_length': gettext_lazy('【%s】请确保此文件名最多包含 {max_length} 个字符(长度为 {length})。' % field), + 'invalid': gettext_lazy('【%s】提交的数据不是文件,请检查表单上的编码类型。' % field) + } + + @staticmethod + def file(field: str): + return { + 'required': gettext_lazy('【%s】此字段必填。' % field), + 'empty': gettext_lazy('【%s】提交的文件为空。' % field), + 'invalid': gettext_lazy('【%s】提交的数据不是文件,请检查表单上的编码类型。' % field), + 'no_name': gettext_lazy('【%s】无法确定任何文件名。' % field), + 'max_length': gettext_lazy('【%s】请确保此文件名最多包含 {max_length} 个字符(长度为 {length})。' % field) + } diff --git a/src/MaxKB-1.7.2/apps/common/util/file_util.py b/src/MaxKB-1.7.2/apps/common/util/file_util.py new file mode 100644 index 0000000..447b007 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/util/file_util.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: file_util.py + @date:2023/9/25 21:06 + @desc: +""" + + +def get_file_content(path): + with open(path, "r", encoding='utf-8') as file: + content = file.read() + return content diff --git a/src/MaxKB-1.7.2/apps/common/util/fork.py b/src/MaxKB-1.7.2/apps/common/util/fork.py new file mode 100644 index 0000000..ee30f69 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/util/fork.py @@ -0,0 +1,175 @@ +import copy +import logging +import re +import traceback +from functools import reduce +from typing import List, Set +from urllib.parse import urljoin, urlparse, ParseResult, urlsplit, urlunparse + +import html2text as ht +import requests +from bs4 import BeautifulSoup + +requests.packages.urllib3.disable_warnings() + + +class ChildLink: + def __init__(self, url, tag): + self.url = 
url + self.tag = copy.deepcopy(tag) + + +class ForkManage: + def __init__(self, base_url: str, selector_list: List[str]): + self.base_url = base_url + self.selector_list = selector_list + + def fork(self, level: int, exclude_link_url: Set[str], fork_handler): + self.fork_child(ChildLink(self.base_url, None), self.selector_list, level, exclude_link_url, fork_handler) + + @staticmethod + def fork_child(child_link: ChildLink, selector_list: List[str], level: int, exclude_link_url: Set[str], + fork_handler): + if level < 0: + return + else: + child_link.url = remove_fragment(child_link.url) + child_url = child_link.url[:-1] if child_link.url.endswith('/') else child_link.url + if not exclude_link_url.__contains__(child_url): + exclude_link_url.add(child_url) + response = Fork(child_link.url, selector_list).fork() + fork_handler(child_link, response) + for child_link in response.child_link_list: + child_url = child_link.url[:-1] if child_link.url.endswith('/') else child_link.url + if not exclude_link_url.__contains__(child_url): + ForkManage.fork_child(child_link, selector_list, level - 1, exclude_link_url, fork_handler) + + +def remove_fragment(url: str) -> str: + parsed_url = urlparse(url) + modified_url = ParseResult(scheme=parsed_url.scheme, netloc=parsed_url.netloc, path=parsed_url.path, + params=parsed_url.params, query=parsed_url.query, fragment=None) + return urlunparse(modified_url) + + +class Fork: + class Response: + def __init__(self, content: str, child_link_list: List[ChildLink], status, message: str): + self.content = content + self.child_link_list = child_link_list + self.status = status + self.message = message + + @staticmethod + def success(html_content: str, child_link_list: List[ChildLink]): + return Fork.Response(html_content, child_link_list, 200, '') + + @staticmethod + def error(message: str): + return Fork.Response('', [], 500, message) + + def __init__(self, base_fork_url: str, selector_list: List[str]): + base_fork_url = 
remove_fragment(base_fork_url) + self.base_fork_url = urljoin(base_fork_url if base_fork_url.endswith("/") else base_fork_url + '/', '.') + parsed = urlsplit(base_fork_url) + query = parsed.query + self.base_fork_url = self.base_fork_url[:-1] + if query is not None and len(query) > 0: + self.base_fork_url = self.base_fork_url + '?' + query + self.selector_list = [selector for selector in selector_list if selector is not None and len(selector) > 0] + self.urlparse = urlparse(self.base_fork_url) + self.base_url = ParseResult(scheme=self.urlparse.scheme, netloc=self.urlparse.netloc, path='', params='', + query='', + fragment='').geturl() + + def get_child_link_list(self, bf: BeautifulSoup): + pattern = "^((?!(http:|https:|tel:/|#|mailto:|javascript:))|" + self.base_fork_url + "|/).*" + link_list = bf.find_all(name='a', href=re.compile(pattern)) + result = [ChildLink(link.get('href'), link) if link.get('href').startswith(self.base_url) else ChildLink( + self.base_url + link.get('href'), link) for link in link_list] + result = [row for row in result if row.url.startswith(self.base_fork_url)] + return result + + def get_content_html(self, bf: BeautifulSoup): + if self.selector_list is None or len(self.selector_list) == 0: + return str(bf) + params = reduce(lambda x, y: {**x, **y}, + [{'class_': selector.replace('.', '')} if selector.startswith('.') else + {'id': selector.replace("#", "")} if selector.startswith("#") else {'name': selector} for + selector in + self.selector_list], {}) + f = bf.find_all(**params) + return "\n".join([str(row) for row in f]) + + @staticmethod + def reset_url(tag, field, base_fork_url): + field_value: str = tag[field] + if field_value.startswith("/"): + result = urlparse(base_fork_url) + result_url = ParseResult(scheme=result.scheme, netloc=result.netloc, path=field_value, params='', query='', + fragment='').geturl() + else: + result_url = urljoin( + base_fork_url + '/' + (field_value if field_value.endswith('/') else field_value + '/'), + 
".") + result_url = result_url[:-1] if result_url.endswith('/') else result_url + tag[field] = result_url + + def reset_beautiful_soup(self, bf: BeautifulSoup): + reset_config_list = [ + { + 'field': 'href', + }, + { + 'field': 'src', + } + ] + for reset_config in reset_config_list: + field = reset_config.get('field') + tag_list = bf.find_all(**{field: re.compile('^(?!(http:|https:|tel:/|#|mailto:|javascript:)).*')}) + for tag in tag_list: + self.reset_url(tag, field, self.base_fork_url) + return bf + + @staticmethod + def get_beautiful_soup(response): + encoding = response.encoding if response.encoding is not None and response.encoding != 'ISO-8859-1' else response.apparent_encoding + html_content = response.content.decode(encoding) + beautiful_soup = BeautifulSoup(html_content, "html.parser") + meta_list = beautiful_soup.find_all('meta') + charset_list = [meta.attrs.get('charset') for meta in meta_list if + meta.attrs is not None and 'charset' in meta.attrs] + if len(charset_list) > 0: + charset = charset_list[0] + if charset != encoding: + html_content = response.content.decode(charset) + return BeautifulSoup(html_content, "html.parser") + return beautiful_soup + + def fork(self): + try: + + headers = { + 'user-agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/99.0.4844.51 Safari/537.36' + } + + logging.getLogger("max_kb").info(f'fork:{self.base_fork_url}') + response = requests.get(self.base_fork_url, verify=False, headers=headers) + if response.status_code != 200: + logging.getLogger("max_kb").error(f"url: {self.base_fork_url} code:{response.status_code}") + return Fork.Response.error(f"url: {self.base_fork_url} code:{response.status_code}") + bf = self.get_beautiful_soup(response) + except Exception as e: + logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') + return Fork.Response.error(str(e)) + bf = self.reset_beautiful_soup(bf) + link_list = self.get_child_link_list(bf) + content = 
self.get_content_html(bf) + r = ht.html2text(content) + return Fork.Response.success(r, link_list) + + +def handler(base_url, response: Fork.Response): + print(base_url.url, base_url.tag.text if base_url.tag else None, response.content) + +# ForkManage('https://bbs.fit2cloud.com/c/de/6', ['.md-content']).fork(3, set(), handler) diff --git a/src/MaxKB-1.7.2/apps/common/util/function_code.py b/src/MaxKB-1.7.2/apps/common/util/function_code.py new file mode 100644 index 0000000..fa3dc50 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/util/function_code.py @@ -0,0 +1,99 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: function_code.py + @date:2024/8/7 16:11 + @desc: +""" +import os +import subprocess +import sys +import uuid +from textwrap import dedent + +from diskcache import Cache + +from smartdoc.const import BASE_DIR +from smartdoc.const import PROJECT_DIR + +python_directory = sys.executable + + +class FunctionExecutor: + def __init__(self, sandbox=False): + self.sandbox = sandbox + if sandbox: + self.sandbox_path = '/opt/maxkb/app/sandbox' + self.user = 'sandbox' + else: + self.sandbox_path = os.path.join(PROJECT_DIR, 'data', 'sandbox') + self.user = None + self._createdir() + if self.sandbox: + os.system(f"chown -R {self.user}:{self.user} {self.sandbox_path}") + + def _createdir(self): + old_mask = os.umask(0o077) + try: + os.makedirs(self.sandbox_path, 0o700, exist_ok=True) + finally: + os.umask(old_mask) + + def exec_code(self, code_str, keywords): + _id = str(uuid.uuid1()) + success = '{"code":200,"msg":"成功","data":exec_result}' + err = '{"code":500,"msg":str(e),"data":None}' + path = r'' + self.sandbox_path + '' + _exec_code = f""" +try: + import os + env = dict(os.environ) + for key in list(env.keys()): + if key in os.environ and (key.startswith('MAXKB') or key.startswith('POSTGRES') or key.startswith('PG')): + del os.environ[key] + locals_v={'{}'} + keywords={keywords} + globals_v=globals() + exec({dedent(code_str)!a}, globals_v, locals_v) + 
f_name, f = locals_v.popitem() + for local in locals_v: + globals_v[local] = locals_v[local] + exec_result=f(**keywords) + from diskcache import Cache + cache = Cache({path!a}) + cache.set({_id!a},{success}) +except Exception as e: + from diskcache import Cache + cache = Cache({path!a}) + cache.set({_id!a},{err}) +""" + if self.sandbox: + subprocess_result = self._exec_sandbox(_exec_code, _id) + else: + subprocess_result = self._exec(_exec_code) + if subprocess_result.returncode == 1: + raise Exception(subprocess_result.stderr) + cache = Cache(self.sandbox_path) + result = cache.get(_id) + cache.delete(_id) + if result.get('code') == 200: + return result.get('data') + raise Exception(result.get('msg')) + + def _exec_sandbox(self, _code, _id): + exec_python_file = f'{self.sandbox_path}/{_id}.py' + with open(exec_python_file, 'w') as file: + file.write(_code) + os.system(f"chown {self.user}:{self.user} {exec_python_file}") + kwargs = {'cwd': BASE_DIR} + subprocess_result = subprocess.run( + ['su', '-c', python_directory + ' ' + exec_python_file, self.user], + text=True, + capture_output=True, **kwargs) + os.remove(exec_python_file) + return subprocess_result + + @staticmethod + def _exec(_code): + return subprocess.run([python_directory, '-c', _code], text=True, capture_output=True) diff --git a/src/MaxKB-1.7.2/apps/common/util/lock.py b/src/MaxKB-1.7.2/apps/common/util/lock.py new file mode 100644 index 0000000..4276f1c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/util/lock.py @@ -0,0 +1,53 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: lock.py + @date:2023/9/11 11:45 + @desc: +""" +from datetime import timedelta + +from django.core.cache import caches + +memory_cache = caches['default'] + + +def try_lock(key: str, timeout=None): + """ + 获取锁 + :param key: 获取锁 key + :param timeout 超时时间 + :return: 是否获取到锁 + """ + return memory_cache.add(key, 'lock', timeout=timedelta(hours=1).total_seconds() if timeout is not None else timeout) + + +def un_lock(key: 
str): + """ + 解锁 + :param key: 解锁 key + :return: 是否解锁成功 + """ + return memory_cache.delete(key) + + +def lock(lock_key): + """ + 给一个函数上锁 + :param lock_key: 上锁key 字符串|函数 函数返回值为字符串 + :return: 装饰器函数 当前装饰器主要限制一个key只能一个线程去调用 相同key只能阻塞等待上一个任务执行完毕 不同key不需要等待 + """ + + def inner(func): + def run(*args, **kwargs): + key = lock_key(*args, **kwargs) if callable(lock_key) else lock_key + try: + if try_lock(key=key): + return func(*args, **kwargs) + finally: + un_lock(key=key) + + return run + + return inner diff --git a/src/MaxKB-1.7.2/apps/common/util/rsa_util.py b/src/MaxKB-1.7.2/apps/common/util/rsa_util.py new file mode 100644 index 0000000..0030186 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/util/rsa_util.py @@ -0,0 +1,140 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: rsa_util.py + @date:2023/11/3 11:13 + @desc: +""" +import base64 +import threading + +from Crypto.Cipher import PKCS1_v1_5 as PKCS1_cipher +from Crypto.PublicKey import RSA +from django.core import cache +from django.db.models import QuerySet + +from setting.models import SystemSetting, SettingType + +lock = threading.Lock() +rsa_cache = cache.caches['default'] +cache_key = "rsa_key" +# 对密钥加密的密码 +secret_code = "mac_kb_password" + + +def generate(): + """ + 生成 私钥秘钥对 + :return:{key:'公钥',value:'私钥'} + """ + # 生成一个 2048 位的密钥 + key = RSA.generate(2048) + + # 获取私钥 + encrypted_key = key.export_key(passphrase=secret_code, pkcs=8, + protection="scryptAndAES128-CBC") + return {'key': key.publickey().export_key(), 'value': encrypted_key} + + +def get_key_pair(): + rsa_value = rsa_cache.get(cache_key) + if rsa_value is None: + lock.acquire() + rsa_value = rsa_cache.get(cache_key) + if rsa_value is not None: + return rsa_value + try: + rsa_value = get_key_pair_by_sql() + rsa_cache.set(cache_key, rsa_value) + finally: + lock.release() + return rsa_value + + +def get_key_pair_by_sql(): + system_setting = QuerySet(SystemSetting).filter(type=SettingType.RSA.value).first() + if system_setting is None: + 
kv = generate() + system_setting = SystemSetting(type=SettingType.RSA.value, + meta={'key': kv.get('key').decode(), 'value': kv.get('value').decode()}) + system_setting.save() + return system_setting.meta + + +def encrypt(msg, public_key: str | None = None): + """ + 加密 + :param msg: 加密数据 + :param public_key: 公钥 + :return: 加密后的数据 + """ + if public_key is None: + public_key = get_key_pair().get('key') + cipher = PKCS1_cipher.new(RSA.importKey(public_key)) + encrypt_msg = cipher.encrypt(msg.encode("utf-8")) + return base64.b64encode(encrypt_msg).decode() + + +def decrypt(msg, pri_key: str | None = None): + """ + 解密 + :param msg: 需要解密的数据 + :param pri_key: 私钥 + :return: 解密后数据 + """ + if pri_key is None: + pri_key = get_key_pair().get('value') + cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + decrypt_data = cipher.decrypt(base64.b64decode(msg), 0) + return decrypt_data.decode("utf-8") + + +def rsa_long_encrypt(message, public_key: str | None = None, length=200): + """ + 超长文本加密 + + :param message: 需要加密的字符串 + :param public_key 公钥 + :param length: 1024bit的证书用100, 2048bit的证书用 200 + :return: 加密后的数据 + """ + # 读取公钥 + if public_key is None: + public_key = get_key_pair().get('key') + cipher = PKCS1_cipher.new(RSA.importKey(extern_key=public_key, + passphrase=secret_code)) + # 处理:Plaintext is too long. 
分段加密 + if len(message) <= length: + # 对编码的数据进行加密,并通过base64进行编码 + result = base64.b64encode(cipher.encrypt(message.encode('utf-8'))) + else: + rsa_text = [] + # 对编码后的数据进行切片,原因:加密长度不能过长 + for i in range(0, len(message), length): + cont = message[i:i + length] + # 对切片后的数据进行加密,并新增到text后面 + rsa_text.append(cipher.encrypt(cont.encode('utf-8'))) + # 加密完进行拼接 + cipher_text = b''.join(rsa_text) + # base64进行编码 + result = base64.b64encode(cipher_text) + return result.decode() + + +def rsa_long_decrypt(message, pri_key: str | None = None, length=256): + """ + 超长文本解密,默认不加密 + :param message: 需要解密的数据 + :param pri_key: 秘钥 + :param length : 1024bit的证书用128,2048bit证书用256位 + :return: 解密后的数据 + """ + if pri_key is None: + pri_key = get_key_pair().get('value') + cipher = PKCS1_cipher.new(RSA.importKey(pri_key, passphrase=secret_code)) + base64_de = base64.b64decode(message) + res = [] + for i in range(0, len(base64_de), length): + res.append(cipher.decrypt(base64_de[i:i + length], 0)) + return b"".join(res).decode() diff --git a/src/MaxKB-1.7.2/apps/common/util/split_model.py b/src/MaxKB-1.7.2/apps/common/util/split_model.py new file mode 100644 index 0000000..0e7bcd5 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/common/util/split_model.py @@ -0,0 +1,413 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: split_model.py + @date:2023/9/1 15:12 + @desc: +""" +import re +from functools import reduce +from typing import List, Dict + +import jieba + + +def get_level_block(text, level_content_list, level_content_index, cursor): + """ + 从文本中获取块数据 + :param text: 文本 + :param level_content_list: 拆分的title数组 + :param level_content_index: 指定的下标 + :param cursor: 开始的下标位置 + :return: 拆分后的文本数据 + """ + start_content: str = level_content_list[level_content_index].get('content') + next_content = level_content_list[level_content_index + 1].get("content") if level_content_index + 1 < len( + level_content_list) else None + start_index = text.index(start_content, cursor) + end_index = 
text.index(next_content, start_index + 1) if next_content is not None else len(text) + return text[start_index + len(start_content):end_index], end_index + + +def to_tree_obj(content, state='title'): + """ + 转换为树形对象 + :param content: 文本数据 + :param state: 状态: title block + :return: 转换后的数据 + """ + return {'content': content, 'state': state} + + +def remove_special_symbol(str_source: str): + """ + 删除特殊字符 + :param str_source: 需要删除的文本数据 + :return: 删除后的数据 + """ + return str_source + + +def filter_special_symbol(content: dict): + """ + 过滤文本中的特殊字符 + :param content: 需要过滤的对象 + :return: 过滤后返回 + """ + content['content'] = remove_special_symbol(content['content']) + return content + + +def flat(tree_data_list: List[dict], parent_chain: List[dict], result: List[dict]): + """ + 扁平化树形结构数据 + :param tree_data_list: 树形接口数据 + :param parent_chain: 父级数据 传[] 用于递归存储数据 + :param result: 响应数据 传[] 用于递归存放数据 + :return: result 扁平化后的数据 + """ + if parent_chain is None: + parent_chain = [] + if result is None: + result = [] + for tree_data in tree_data_list: + p = parent_chain.copy() + p.append(tree_data) + result.append(to_flat_obj(parent_chain, content=tree_data["content"], state=tree_data["state"])) + children = tree_data.get('children') + if children is not None and len(children) > 0: + flat(children, p, result) + return result + + +def to_paragraph(obj: dict): + """ + 转换为段落 + :param obj: 需要转换的对象 + :return: 段落对象 + """ + content = obj['content'] + return {"keywords": get_keyword(content), + 'parent_chain': list(map(lambda p: p['content'], obj['parent_chain'])), + 'content': ",".join(list(map(lambda p: p['content'], obj['parent_chain']))) + content} + + +def get_keyword(content: str): + """ + 获取content中的关键词 + :param content: 文本 + :return: 关键词数组 + """ + stopwords = [':', '“', '!', '”', '\n', '\\s'] + cutworms = jieba.lcut(content) + return list(set(list(filter(lambda k: (k not in stopwords) | len(k) > 1, cutworms)))) + + +def titles_to_paragraph(list_title: List[dict]): + """ + 将同一父级的title转换为块段落 + 
:param list_title: 同父级title + :return: 块段落 + """ + if len(list_title) > 0: + content = "\n,".join( + list(map(lambda d: d['content'].strip("\r\n").strip("\n").strip("\\s"), list_title))) + + return {'keywords': '', + 'parent_chain': list( + map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"), list_title[0]['parent_chain'])), + 'content': ",".join(list( + map(lambda p: p['content'].strip("\r\n").strip("\n").strip("\\s"), + list_title[0]['parent_chain']))) + content} + return None + + +def parse_group_key(level_list: List[dict]): + """ + 将同级别同父级的title生成段落,加上本身的段落数据形成新的数据 + :param level_list: title n 级数据 + :return: 根据title生成的数据 + 段落数据 + """ + result = [] + group_data = group_by(list(filter(lambda f: f['state'] == 'title' and len(f['parent_chain']) > 0, level_list)), + key=lambda d: ",".join(list(map(lambda p: p['content'], d['parent_chain'])))) + result += list(map(lambda group_data_key: titles_to_paragraph(group_data[group_data_key]), group_data)) + result += list(map(to_paragraph, list(filter(lambda f: f['state'] == 'block', level_list)))) + return result + + +def to_block_paragraph(tree_data_list: List[dict]): + """ + 转换为块段落对象 + :param tree_data_list: 树数据 + :return: 块段落 + """ + flat_list = flat(tree_data_list, [], []) + level_group_dict: dict = group_by(flat_list, key=lambda f: f['level']) + return list(map(lambda level: parse_group_key(level_group_dict[level]), level_group_dict)) + + +def parse_title_level(text, content_level_pattern: List, index): + if index >= len(content_level_pattern): + return [] + result = parse_level(text, content_level_pattern[index]) + if len(result) == 0 and len(content_level_pattern) > index: + return parse_title_level(text, content_level_pattern, index + 1) + return result + + +def parse_level(text, pattern: str): + """ + 获取正则匹配到的文本 + :param text: 需要匹配的文本 + :param pattern: 正则 + :return: 符合正则的文本 + """ + level_content_list = list(map(to_tree_obj, [r[0:255] for r in re_findall(pattern, text) if r is not None])) + return 
list(map(filter_special_symbol, level_content_list)) + + +def re_findall(pattern, text): + result = re.findall(pattern, text, flags=0) + return list(filter(lambda r: r is not None and len(r) > 0, reduce(lambda x, y: [*x, *y], list( + map(lambda row: [*(row if isinstance(row, tuple) else [row])], result)), + []))) + + +def to_flat_obj(parent_chain: List[dict], content: str, state: str): + """ + 将树形属性转换为扁平对象 + :param parent_chain: + :param content: + :param state: + :return: + """ + return {'parent_chain': parent_chain, 'level': len(parent_chain), "content": content, 'state': state} + + +def flat_map(array: List[List]): + """ + 将二位数组转为一维数组 + :param array: 二维数组 + :return: 一维数组 + """ + result = [] + for e in array: + result += e + return result + + +def group_by(list_source: List, key): + """ + 將數組分組 + :param list_source: 需要分組的數組 + :param key: 分組函數 + :return: key->[] + """ + result = {} + for e in list_source: + k = key(e) + array = result.get(k) if k in result else [] + array.append(e) + result[k] = array + return result + + +def result_tree_to_paragraph(result_tree: List[dict], result, parent_chain, with_filter: bool): + """ + 转换为分段对象 + :param result_tree: 解析文本的树 + :param result: 传[] 用于递归 + :param parent_chain: 传[] 用户递归存储数据 + :param with_filter: 是否过滤block + :return: List[{'problem':'xx','content':'xx'}] + """ + for item in result_tree: + if item.get('state') == 'block': + result.append({'title': " ".join(parent_chain), + 'content': filter_special_char(item.get("content")) if with_filter else item.get("content")}) + children = item.get("children") + if children is not None and len(children) > 0: + result_tree_to_paragraph(children, result, + [*parent_chain, remove_special_symbol(item.get('content'))], with_filter) + return result + + +def post_handler_paragraph(content: str, limit: int): + """ + 根据文本的最大字符分段 + :param content: 需要分段的文本字段 + :param limit: 最大分段字符 + :return: 分段后数据 + """ + result = [] + temp_char, start = '', 0 + while (pos := content.find("\n", start)) != 
-1: + split, start = content[start:pos + 1], pos + 1 + if len(temp_char + split) > limit: + if len(temp_char) > 4096: + pass + result.append(temp_char) + temp_char = '' + temp_char = temp_char + split + temp_char = temp_char + content[start:] + if len(temp_char) > 0: + if len(temp_char) > 4096: + pass + result.append(temp_char) + + pattern = "[\\S\\s]{1," + str(limit) + '}' + # 如果\n 单段超过限制,则继续拆分 + return reduce(lambda x, y: [*x, *y], map(lambda row: re.findall(pattern, row), result), []) + + +replace_map = { + re.compile('\n+'): '\n', + re.compile(' +'): ' ', + re.compile('#+'): "", + re.compile("\t+"): '' +} + + +def filter_special_char(content: str): + """ + 过滤特殊字段 + :param content: 文本 + :return: 过滤后字段 + """ + items = replace_map.items() + for key, value in items: + content = re.sub(key, value, content) + return content + + +class SplitModel: + + def __init__(self, content_level_pattern, with_filter=True, limit=100000): + self.content_level_pattern = content_level_pattern + self.with_filter = with_filter + if limit is None or limit > 100000: + limit = 100000 + if limit < 50: + limit = 50 + self.limit = limit + + def parse_to_tree(self, text: str, index=0): + """ + 解析文本 + :param text: 需要解析的文本 + :param index: 从那个正则开始解析 + :return: 解析后的树形结果数据 + """ + level_content_list = parse_title_level(text, self.content_level_pattern, index) + if len(level_content_list) == 0: + return [to_tree_obj(row, 'block') for row in post_handler_paragraph(text, limit=self.limit)] + if index == 0 and text.lstrip().index(level_content_list[0]["content"].lstrip()) != 0: + level_content_list.insert(0, to_tree_obj("")) + + cursor = 0 + level_title_content_list = [item for item in level_content_list if item.get('state') == 'title'] + for i in range(len(level_title_content_list)): + start_content: str = level_title_content_list[i].get('content') + if cursor < text.index(start_content, cursor): + for row in post_handler_paragraph(text[cursor: text.index(start_content, cursor)], limit=self.limit): + 
level_content_list.insert(0, to_tree_obj(row, 'block')) + + block, cursor = get_level_block(text, level_title_content_list, i, cursor) + if len(block) == 0: + continue + children = self.parse_to_tree(text=block, index=index + 1) + level_title_content_list[i]['children'] = children + first_child_idx_in_block = block.lstrip().index(children[0]["content"].lstrip()) + if first_child_idx_in_block != 0: + inner_children = self.parse_to_tree(block[:first_child_idx_in_block], index + 1) + level_title_content_list[i]['children'].extend(inner_children) + return level_content_list + + def parse(self, text: str): + """ + 解析文本 + :param text: 文本数据 + :return: 解析后数据 {content:段落数据,keywords:[‘段落关键词’],parent_chain:['段落父级链路']} + """ + text = text.replace('\r\n', '\n') + text = text.replace('\r', '\n') + text = text.replace("\0", '') + result_tree = self.parse_to_tree(text, 0) + result = result_tree_to_paragraph(result_tree, [], [], self.with_filter) + for e in result: + if len(e['content']) > 4096: + pass + return [item for item in [self.post_reset_paragraph(row) for row in result] if + 'content' in item and len(item.get('content').strip()) > 0] + + def post_reset_paragraph(self, paragraph: Dict): + result = self.filter_title_special_characters(paragraph) + result = self.sub_title(result) + result = self.content_is_null(result) + return result + + @staticmethod + def sub_title(paragraph: Dict): + if 'title' in paragraph: + title = paragraph.get('title') + if len(title) > 255: + return {**paragraph, 'title': title[0:255], 'content': title[255:len(title)] + paragraph.get('content')} + return paragraph + + @staticmethod + def content_is_null(paragraph: Dict): + if 'title' in paragraph: + title = paragraph.get('title') + content = paragraph.get('content') + if (content is None or len(content.strip()) == 0) and (title is not None and len(title) > 0): + return {'title': '', 'content': title} + return paragraph + + @staticmethod + def filter_title_special_characters(paragraph: Dict): + title 
= paragraph.get('title') if 'title' in paragraph else '' + for title_special_characters in title_special_characters_list: + title = title.replace(title_special_characters, '') + return {**paragraph, + 'title': title} + + +title_special_characters_list = ['#', '\n', '\r', '\\s'] + +default_split_pattern = { + 'md': [re.compile('(?<=^)# .*|(?<=\\n)# .*'), + re.compile('(?<=\\n)(?!@#¥%……&*()!@#$%^&*(): ;,/"./' + +jieba_remove_flag_list = ['x', 'w'] + + +def get_word_list(text: str): + result = [] + for pattern in word_pattern_list: + word_list = re.findall(pattern, text) + for child_list in word_list: + for word in child_list if isinstance(child_list, tuple) else [child_list]: + # 不能有: 所以再使用: 进行分割 + if word.__contains__(':'): + item_list = word.split(":") + for w in item_list: + result.append(w) + else: + result.append(word) + return result + + +def replace_word(word_dict, text: str): + for key in word_dict: + pattern = '(?= 0]) + + +def to_query(text: str): + # 获取不分词的数据 + word_list = get_word_list(text) + # 获取关键词关系 + word_dict = to_word_dict(word_list, text) + # 替换字符串 + text = replace_word(word_dict, text) + extract_tags = analyse.extract_tags(text, topK=5, withWeight=True, allowPOS=('ns', 'n', 'vn', 'v', 'eng')) + result = " ".join([get_key_by_word_dict(word, word_dict) for word, score in extract_tags if + not remove_chars.__contains__(word)]) + # 删除词库 + for word in word_list: + jieba.del_word(word) + return result diff --git a/src/MaxKB-1.7.2/apps/dataset/__init__.py b/src/MaxKB-1.7.2/apps/dataset/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MaxKB-1.7.2/apps/dataset/apps.py b/src/MaxKB-1.7.2/apps/dataset/apps.py new file mode 100644 index 0000000..166bedb --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class DatasetConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'dataset' diff --git a/src/MaxKB-1.7.2/apps/dataset/migrations/0001_initial.py 
b/src/MaxKB-1.7.2/apps/dataset/migrations/0001_initial.py new file mode 100644 index 0000000..e19fc6b --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/migrations/0001_initial.py @@ -0,0 +1,98 @@ +# Generated by Django 4.1.10 on 2024-03-18 16:02 + +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('users', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='DataSet', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('name', models.CharField(max_length=150, verbose_name='数据集名称')), + ('desc', models.CharField(max_length=256, verbose_name='数据库描述')), + ('type', models.CharField(choices=[('0', '通用类型'), ('1', 'web站点类型')], default='0', max_length=1, verbose_name='类型')), + ('meta', models.JSONField(default=dict, verbose_name='元数据')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', verbose_name='所属用户')), + ], + options={ + 'db_table': 'dataset', + }, + ), + migrations.CreateModel( + name='Document', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('name', models.CharField(max_length=150, verbose_name='文档名称')), + ('char_length', models.IntegerField(verbose_name='文档字符数 冗余字段')), + ('status', models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败')], default='0', max_length=1, verbose_name='状态')), + ('is_active', models.BooleanField(default=True)), + ('type', models.CharField(choices=[('0', '通用类型'), 
('1', 'web站点类型')], default='0', max_length=1, verbose_name='类型')), + ('meta', models.JSONField(default=dict, verbose_name='元数据')), + ('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')), + ], + options={ + 'db_table': 'document', + }, + ), + migrations.CreateModel( + name='Paragraph', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('content', models.CharField(max_length=4096, verbose_name='段落内容')), + ('title', models.CharField(default='', max_length=256, verbose_name='标题')), + ('status', models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败')], default='0', max_length=1, verbose_name='状态')), + ('hit_num', models.IntegerField(default=0, verbose_name='命中次数')), + ('is_active', models.BooleanField(default=True)), + ('dataset', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')), + ('document', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')), + ], + options={ + 'db_table': 'paragraph', + }, + ), + migrations.CreateModel( + name='Problem', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('content', models.CharField(max_length=256, verbose_name='问题内容')), + ('hit_num', models.IntegerField(default=0, verbose_name='命中次数')), + ('dataset', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')), + ], + options={ + 'db_table': 'problem', + }, + ), + migrations.CreateModel( + 
name='ProblemParagraphMapping', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('dataset', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset')), + ('document', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document')), + ('paragraph', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph')), + ('problem', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.problem')), + ], + options={ + 'db_table': 'problem_paragraph_mapping', + }, + ), + ] diff --git a/src/MaxKB-1.7.2/apps/dataset/migrations/0002_image.py b/src/MaxKB-1.7.2/apps/dataset/migrations/0002_image.py new file mode 100644 index 0000000..a5fb59e --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/migrations/0002_image.py @@ -0,0 +1,27 @@ +# Generated by Django 4.1.13 on 2024-04-22 19:31 + +from django.db import migrations, models +import uuid + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataset', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='Image', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')), + ('image', models.BinaryField(verbose_name='图片数据')), + ('image_name', models.CharField(default='', max_length=256, verbose_name='图片名称')), + ], + options={ + 'db_table': 'image', + }, + ), + ] diff --git a/src/MaxKB-1.7.2/apps/dataset/migrations/0003_document_hit_handling_method.py 
b/src/MaxKB-1.7.2/apps/dataset/migrations/0003_document_hit_handling_method.py new file mode 100644 index 0000000..e1746d6 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/migrations/0003_document_hit_handling_method.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.13 on 2024-04-24 15:36 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataset', '0002_image'), + ] + + operations = [ + migrations.AddField( + model_name='document', + name='hit_handling_method', + field=models.CharField(choices=[('optimization', '模型优化'), ('directly_return', '直接返回')], default='optimization', max_length=20, verbose_name='命中处理方式'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/dataset/migrations/0004_document_directly_return_similarity.py b/src/MaxKB-1.7.2/apps/dataset/migrations/0004_document_directly_return_similarity.py new file mode 100644 index 0000000..cddf38c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/migrations/0004_document_directly_return_similarity.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.13 on 2024-05-08 16:43 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataset', '0003_document_hit_handling_method'), + ] + + operations = [ + migrations.AddField( + model_name='document', + name='directly_return_similarity', + field=models.FloatField(default=0.9, verbose_name='直接回答相似度'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/dataset/migrations/0005_file.py b/src/MaxKB-1.7.2/apps/dataset/migrations/0005_file.py new file mode 100644 index 0000000..3c74fc8 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/migrations/0005_file.py @@ -0,0 +1,30 @@ +# Generated by Django 4.2.13 on 2024-07-05 18:59 + +from django.db import migrations, models +import uuid + +from smartdoc.const import CONFIG + + +class Migration(migrations.Migration): + dependencies = [ + ('dataset', '0004_document_directly_return_similarity'), + ] + + operations = [ + migrations.RunSQL(f"grant 
execute on function lo_from_bytea to {CONFIG.get('DB_USER')}"), + migrations.CreateModel( + name='File', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, + verbose_name='主键id')), + ('file_name', models.CharField(default='', max_length=256, verbose_name='文件名称')), + ('loid', models.IntegerField(verbose_name='loid')), + ], + options={ + 'db_table': 'file', + }, + ), + ] diff --git a/src/MaxKB-1.7.2/apps/dataset/migrations/0006_dataset_embedding_mode.py b/src/MaxKB-1.7.2/apps/dataset/migrations/0006_dataset_embedding_mode.py new file mode 100644 index 0000000..2248d8e --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/migrations/0006_dataset_embedding_mode.py @@ -0,0 +1,21 @@ +# Generated by Django 4.2.13 on 2024-07-17 13:56 + +import dataset.models.data_set +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0005_model_permission_type'), + ('dataset', '0005_file'), + ] + + operations = [ + migrations.AddField( + model_name='dataset', + name='embedding_mode', + field=models.ForeignKey(default=dataset.models.data_set.default_model, on_delete=django.db.models.deletion.DO_NOTHING, to='setting.model', verbose_name='向量模型'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/dataset/migrations/0007_alter_paragraph_content.py b/src/MaxKB-1.7.2/apps/dataset/migrations/0007_alter_paragraph_content.py new file mode 100644 index 0000000..ab654b1 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/migrations/0007_alter_paragraph_content.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.14 on 2024-07-24 14:35 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataset', '0006_dataset_embedding_mode'), + ] + + operations = 
[ + migrations.AlterField( + model_name='paragraph', + name='content', + field=models.CharField(max_length=102400, verbose_name='段落内容'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/dataset/migrations/0008_alter_document_status_alter_paragraph_status.py b/src/MaxKB-1.7.2/apps/dataset/migrations/0008_alter_document_status_alter_paragraph_status.py new file mode 100644 index 0000000..3380d7b --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/migrations/0008_alter_document_status_alter_paragraph_status.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.14 on 2024-07-29 15:37 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataset', '0007_alter_paragraph_content'), + ] + + operations = [ + migrations.AlterField( + model_name='document', + name='status', + field=models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败'), ('3', '排队中')], default='3', max_length=1, verbose_name='状态'), + ), + migrations.AlterField( + model_name='paragraph', + name='status', + field=models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败'), ('3', '排队中')], default='0', max_length=1, verbose_name='状态'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/dataset/migrations/0009_alter_document_status_alter_paragraph_status.py b/src/MaxKB-1.7.2/apps/dataset/migrations/0009_alter_document_status_alter_paragraph_status.py new file mode 100644 index 0000000..7c138a6 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/migrations/0009_alter_document_status_alter_paragraph_status.py @@ -0,0 +1,23 @@ +# Generated by Django 4.2.15 on 2024-10-15 14:49 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('dataset', '0008_alter_document_status_alter_paragraph_status'), + ] + + operations = [ + migrations.AlterField( + model_name='document', + name='status', + field=models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败'), ('3', '排队中'), ('4', '生成问题中')], default='3', 
max_length=1, verbose_name='状态'), + ), + migrations.AlterField( + model_name='paragraph', + name='status', + field=models.CharField(choices=[('0', '导入中'), ('1', '已完成'), ('2', '导入失败'), ('3', '排队中'), ('4', '生成问题中')], default='0', max_length=1, verbose_name='状态'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/dataset/migrations/__init__.py b/src/MaxKB-1.7.2/apps/dataset/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MaxKB-1.7.2/apps/dataset/models/__init__.py b/src/MaxKB-1.7.2/apps/dataset/models/__init__.py new file mode 100644 index 0000000..fdee77b --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/models/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2023/9/21 9:32 + @desc: +""" +from .data_set import * diff --git a/src/MaxKB-1.7.2/apps/dataset/models/data_set.py b/src/MaxKB-1.7.2/apps/dataset/models/data_set.py new file mode 100644 index 0000000..9fcb0d6 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/models/data_set.py @@ -0,0 +1,157 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: data_set.py + @date:2023/9/21 9:35 + @desc: 数据集 +""" +import uuid + +from django.db import models + +from common.db.sql_execute import select_one +from common.mixins.app_model_mixin import AppModelMixin +from setting.models import Model +from users.models import User + + +class Status(models.TextChoices): + """订单类型""" + embedding = 0, '导入中' + success = 1, '已完成' + error = 2, '导入失败' + queue_up = 3, '排队中' + generating = 4, '生成问题中' + + +class Type(models.TextChoices): + base = 0, '通用类型' + + web = 1, 'web站点类型' + + +class HitHandlingMethod(models.TextChoices): + optimization = 'optimization', '模型优化' + directly_return = 'directly_return', '直接返回' + + +def default_model(): + return uuid.UUID('42f63a3d-427e-11ef-b3ec-a8a1595801ab') + + +class DataSet(AppModelMixin): + """ + 数据集表 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, 
verbose_name="主键id") + name = models.CharField(max_length=150, verbose_name="数据集名称") + desc = models.CharField(max_length=256, verbose_name="数据库描述") + user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="所属用户") + type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices, + default=Type.base) + embedding_mode = models.ForeignKey(Model, on_delete=models.DO_NOTHING, verbose_name="向量模型", + default=default_model) + meta = models.JSONField(verbose_name="元数据", default=dict) + + class Meta: + db_table = "dataset" + + +class Document(AppModelMixin): + """ + 文档表 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING) + name = models.CharField(max_length=150, verbose_name="文档名称") + char_length = models.IntegerField(verbose_name="文档字符数 冗余字段") + status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices, + default=Status.queue_up) + is_active = models.BooleanField(default=True) + + type = models.CharField(verbose_name='类型', max_length=1, choices=Type.choices, + default=Type.base) + hit_handling_method = models.CharField(verbose_name='命中处理方式', max_length=20, + choices=HitHandlingMethod.choices, + default=HitHandlingMethod.optimization) + directly_return_similarity = models.FloatField(verbose_name='直接回答相似度', default=0.9) + + meta = models.JSONField(verbose_name="元数据", default=dict) + + class Meta: + db_table = "document" + + +class Paragraph(AppModelMixin): + """ + 段落表 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, db_constraint=False) + dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING) + content = models.CharField(max_length=102400, verbose_name="段落内容") + title = models.CharField(max_length=256, verbose_name="标题", default="") + 
status = models.CharField(verbose_name='状态', max_length=1, choices=Status.choices, + default=Status.embedding) + hit_num = models.IntegerField(verbose_name="命中次数", default=0) + is_active = models.BooleanField(default=True) + + class Meta: + db_table = "paragraph" + + +class Problem(AppModelMixin): + """ + 问题表 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False) + content = models.CharField(max_length=256, verbose_name="问题内容") + hit_num = models.IntegerField(verbose_name="命中次数", default=0) + + class Meta: + db_table = "problem" + + +class ProblemParagraphMapping(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, db_constraint=False) + document = models.ForeignKey(Document, on_delete=models.DO_NOTHING) + problem = models.ForeignKey(Problem, on_delete=models.DO_NOTHING, db_constraint=False) + paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, db_constraint=False) + + class Meta: + db_table = "problem_paragraph_mapping" + + +class Image(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + image = models.BinaryField(verbose_name="图片数据") + image_name = models.CharField(max_length=256, verbose_name="图片名称", default="") + + class Meta: + db_table = "image" + + +class File(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + + file_name = models.CharField(max_length=256, verbose_name="文件名称", default="") + + loid = models.IntegerField(verbose_name="loid") + + class Meta: + db_table = "file" + + def save( + self, bytea=None, force_insert=False, force_update=False, using=None, update_fields=None + ): + 
result = select_one("SELECT lo_from_bytea(%s, %s::bytea) as loid", [0, bytea]) + self.loid = result['loid'] + self.file_name = 'speech.mp3' + super().save() + + def get_byte(self): + result = select_one(f'SELECT lo_get({self.loid}) as "data"', []) + return result['data'] diff --git a/src/MaxKB-1.7.2/apps/dataset/serializers/common_serializers.py b/src/MaxKB-1.7.2/apps/dataset/serializers/common_serializers.py new file mode 100644 index 0000000..8f08a26 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/serializers/common_serializers.py @@ -0,0 +1,167 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: common_serializers.py + @date:2023/11/17 11:00 + @desc: +""" +import os +import uuid +from typing import List + +from django.db.models import QuerySet +from drf_yasg import openapi +from rest_framework import serializers + +from common.config.embedding_config import ModelManage +from common.db.search import native_search +from common.db.sql_execute import update_execute +from common.exception.app_exception import AppApiException +from common.mixins.api_mixin import ApiMixin +from common.util.field_message import ErrMessage +from common.util.file_util import get_file_content +from common.util.fork import Fork +from dataset.models import Paragraph, Problem, ProblemParagraphMapping, DataSet +from setting.models_provider import get_model +from smartdoc.conf import PROJECT_DIR + + +def update_document_char_length(document_id: str): + update_execute(get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'update_document_char_length.sql')), + (document_id, document_id)) + + +def list_paragraph(paragraph_list: List[str]): + if paragraph_list is None or len(paragraph_list) == 0: + return [] + return native_search(QuerySet(Paragraph).filter(id__in=paragraph_list), get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_paragraph.sql'))) + + +class MetaSerializer(serializers.Serializer): + class WebMeta(serializers.Serializer): + 
source_url = serializers.CharField(required=True, error_messages=ErrMessage.char("文档地址")) + selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("选择器")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + source_url = self.data.get('source_url') + response = Fork(source_url, []).fork() + if response.status == 500: + raise AppApiException(500, f"url错误,无法解析【{source_url}】") + + class BaseMeta(serializers.Serializer): + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + + +class BatchSerializer(ApiMixin, serializers.Serializer): + id_list = serializers.ListField(required=True, child=serializers.UUIDField(required=True), + error_messages=ErrMessage.char("id列表")) + + def is_valid(self, *, model=None, raise_exception=False): + super().is_valid(raise_exception=True) + if model is not None: + id_list = self.data.get('id_list') + model_list = QuerySet(model).filter(id__in=id_list) + if len(model_list) != len(id_list): + model_id_list = [str(m.id) for m in model_list] + error_id_list = list(filter(lambda row_id: not model_id_list.__contains__(row_id), id_list)) + raise AppApiException(500, f"id不正确:{error_id_list}") + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + properties={ + 'id_list': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), + title="主键id列表", + description="主键id列表") + } + ) + + +class ProblemParagraphObject: + def __init__(self, dataset_id: str, document_id: str, paragraph_id: str, problem_content: str): + self.dataset_id = dataset_id + self.document_id = document_id + self.paragraph_id = paragraph_id + self.problem_content = problem_content + + +def or_get(exists_problem_list, content, dataset_id, document_id, paragraph_id, problem_content_dict): + if content in problem_content_dict: + return problem_content_dict.get(content)[0], 
document_id, paragraph_id + exists = [row for row in exists_problem_list if row.content == content] + if len(exists) > 0: + problem_content_dict[content] = exists[0], False + return exists[0], document_id, paragraph_id + else: + problem = Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id) + problem_content_dict[content] = problem, True + return problem, document_id, paragraph_id + + +class ProblemParagraphManage: + def __init__(self, problemParagraphObjectList: [ProblemParagraphObject], dataset_id): + self.dataset_id = dataset_id + self.problemParagraphObjectList = problemParagraphObjectList + + def to_problem_model_list(self): + problem_list = [item.problem_content for item in self.problemParagraphObjectList] + exists_problem_list = [] + if len(self.problemParagraphObjectList) > 0: + # 查询到已存在的问题列表 + exists_problem_list = QuerySet(Problem).filter(dataset_id=self.dataset_id, + content__in=problem_list).all() + problem_content_dict = {} + problem_model_list = [ + or_get(exists_problem_list, problemParagraphObject.problem_content, problemParagraphObject.dataset_id, + problemParagraphObject.document_id, problemParagraphObject.paragraph_id, problem_content_dict) for + problemParagraphObject in self.problemParagraphObjectList] + + problem_paragraph_mapping_list = [ + ProblemParagraphMapping(id=uuid.uuid1(), document_id=document_id, problem_id=problem_model.id, + paragraph_id=paragraph_id, + dataset_id=self.dataset_id) for + problem_model, document_id, paragraph_id in problem_model_list] + + result = [problem_model for problem_model, is_create in problem_content_dict.values() if + is_create], problem_paragraph_mapping_list + return result + + +def get_embedding_model_by_dataset_id_list(dataset_id_list: List): + dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list) + if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1: + raise Exception("知识库未向量模型不一致") + if len(dataset_list) == 0: + raise Exception("知识库设置错误,请重新设置知识库") + return 
ModelManage.get_model(str(dataset_list[0].embedding_mode_id), + lambda _id: get_model(dataset_list[0].embedding_mode)) + + +def get_embedding_model_by_dataset_id(dataset_id: str): + dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first() + return ModelManage.get_model(str(dataset.embedding_mode_id), lambda _id: get_model(dataset.embedding_mode)) + + +def get_embedding_model_by_dataset(dataset): + return ModelManage.get_model(str(dataset.embedding_mode_id), lambda _id: get_model(dataset.embedding_mode)) + + +def get_embedding_model_id_by_dataset_id(dataset_id): + dataset = QuerySet(DataSet).select_related('embedding_mode').filter(id=dataset_id).first() + return str(dataset.embedding_mode_id) + + +def get_embedding_model_id_by_dataset_id_list(dataset_id_list: List): + dataset_list = QuerySet(DataSet).filter(id__in=dataset_id_list) + if len(set([dataset.embedding_mode_id for dataset in dataset_list])) > 1: + raise Exception("知识库未向量模型不一致") + if len(dataset_list) == 0: + raise Exception("知识库设置错误,请重新设置知识库") + return str(dataset_list[0].embedding_mode_id) diff --git a/src/MaxKB-1.7.2/apps/dataset/serializers/dataset_serializers.py b/src/MaxKB-1.7.2/apps/dataset/serializers/dataset_serializers.py new file mode 100644 index 0000000..7250bea --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/serializers/dataset_serializers.py @@ -0,0 +1,872 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: dataset_serializers.py + @date:2023/9/21 16:14 + @desc: +""" +import logging +import os.path +import re +import traceback +import uuid +from functools import reduce +from typing import Dict, List +from urllib.parse import urlparse + +from django.contrib.postgres.fields import ArrayField +from django.core import validators +from django.db import transaction, models +from django.db.models import QuerySet +from django.http import HttpResponse +from drf_yasg import openapi +from rest_framework import serializers + +from application.models import 
ApplicationDatasetMapping +from common.config.embedding_config import VectorStore +from common.db.search import get_dynamics_model, native_page_search, native_search +from common.db.sql_execute import select_list +from common.exception.app_exception import AppApiException +from common.mixins.api_mixin import ApiMixin +from common.util.common import post, flat_map, valid_license +from common.util.field_message import ErrMessage +from common.util.file_util import get_file_content +from common.util.fork import ChildLink, Fork +from common.util.split_model import get_split_model +from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, ProblemParagraphMapping, Status +from dataset.serializers.common_serializers import list_paragraph, MetaSerializer, ProblemParagraphManage, \ + get_embedding_model_by_dataset_id, get_embedding_model_id_by_dataset_id +from dataset.serializers.document_serializers import DocumentSerializers, DocumentInstanceSerializer +from dataset.task import sync_web_dataset, sync_replace_web_dataset +from embedding.models import SearchMode +from embedding.task import embedding_by_dataset, delete_embedding_by_dataset +from setting.models import AuthOperate +from smartdoc.conf import PROJECT_DIR + +""" +# __exact 精确等于 like ‘aaa’ +# __iexact 精确等于 忽略大小写 ilike 'aaa' +# __contains 包含like '%aaa%' +# __icontains 包含 忽略大小写 ilike ‘%aaa%’,但是对于sqlite来说,contains的作用效果等同于icontains。 +# __gt 大于 +# __gte 大于等于 +# __lt 小于 +# __lte 小于等于 +# __in 存在于一个list范围内 +# __startswith 以…开头 +# __istartswith 以…开头 忽略大小写 +# __endswith 以…结尾 +# __iendswith 以…结尾,忽略大小写 +# __range 在…范围内 +# __year 日期字段的年份 +# __month 日期字段的月份 +# __day 日期字段的日 +# __isnull=True/False +""" + + +class DataSetSerializers(serializers.ModelSerializer): + class Meta: + model = DataSet + fields = ['id', 'name', 'desc', 'meta', 'create_time', 'update_time'] + + class Application(ApiMixin, serializers.Serializer): + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id")) + 
+ dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("数据集id")) + + @staticmethod + def get_request_params_api(): + return [ + openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id') + ] + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'name', 'desc', 'model_id', 'multiple_rounds_dialogue', 'user_id', 'status', + 'create_time', + 'update_time'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="", description="主键id"), + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="应用名称", description="应用名称"), + 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="应用描述", description="应用描述"), + 'model_id': openapi.Schema(type=openapi.TYPE_STRING, title="模型id", description="模型id"), + "multiple_rounds_dialogue": openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否开启多轮对话", + description="是否开启多轮对话"), + 'prologue': openapi.Schema(type=openapi.TYPE_STRING, title="开场白", description="开场白"), + 'example': openapi.Schema(type=openapi.TYPE_ARRAY, items=openapi.Schema(type=openapi.TYPE_STRING), + title="示例列表", description="示例列表"), + 'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户", description="所属用户"), + + 'status': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否发布", description='是否发布'), + + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description='创建时间'), + + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description='修改时间') + } + ) + + class Query(ApiMixin, serializers.Serializer): + """ + 查询对象 + """ + name = serializers.CharField(required=False, + error_messages=ErrMessage.char("知识库名称"), + max_length=64, + min_length=1) + + desc = serializers.CharField(required=False, + error_messages=ErrMessage.char("知识库描述"), + max_length=256, + min_length=1, + ) + + user_id = serializers.CharField(required=True) + + def get_query_set(self): 
+ user_id = self.data.get("user_id") + query_set_dict = {} + query_set = QuerySet(model=get_dynamics_model( + {'temp.name': models.CharField(), 'temp.desc': models.CharField(), + "document_temp.char_length": models.IntegerField(), 'temp.create_time': models.DateTimeField()})) + if "desc" in self.data and self.data.get('desc') is not None: + query_set = query_set.filter(**{'temp.desc__icontains': self.data.get("desc")}) + if "name" in self.data and self.data.get('name') is not None: + query_set = query_set.filter(**{'temp.name__icontains': self.data.get("name")}) + query_set = query_set.order_by("-temp.create_time") + query_set_dict['default_sql'] = query_set + + query_set_dict['dataset_custom_sql'] = QuerySet(model=get_dynamics_model( + {'dataset.user_id': models.CharField(), + })).filter( + **{'dataset.user_id': user_id} + ) + + query_set_dict['team_member_permission_custom_sql'] = QuerySet(model=get_dynamics_model( + {'user_id': models.CharField(), + 'team_member_permission.auth_target_type': models.CharField(), + 'team_member_permission.operate': ArrayField(verbose_name="权限操作列表", + base_field=models.CharField(max_length=256, + blank=True, + choices=AuthOperate.choices, + default=AuthOperate.USE) + )})).filter( + **{'user_id': user_id, 'team_member_permission.operate__contains': ['USE'], + 'team_member_permission.auth_target_type': 'DATASET'}) + + return query_set_dict + + def page(self, current_page: int, page_size: int): + return native_page_search(current_page, page_size, self.get_query_set(), select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql')), + post_records_handler=lambda r: r) + + def list(self): + return native_search(self.get_query_set(), select_string=get_file_content( + os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_dataset.sql'))) + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='name', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + 
required=False, + description='知识库名称'), + openapi.Parameter(name='desc', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='知识库描述') + ] + + @staticmethod + def get_response_body_api(): + return DataSetSerializers.Operate.get_response_body_api() + + class Create(ApiMixin, serializers.Serializer): + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id"), ) + + class CreateBaseSerializers(ApiMixin, serializers.Serializer): + """ + 创建通用数据集序列化对象 + """ + name = serializers.CharField(required=True, + error_messages=ErrMessage.char("知识库名称"), + max_length=64, + min_length=1) + + desc = serializers.CharField(required=True, + error_messages=ErrMessage.char("知识库描述"), + max_length=256, + min_length=1) + + embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型")) + + documents = DocumentInstanceSerializer(required=False, many=True) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + return True + + class CreateQASerializers(serializers.Serializer): + """ + 创建web站点序列化对象 + """ + name = serializers.CharField(required=True, + error_messages=ErrMessage.char("知识库名称"), + max_length=64, + min_length=1) + + desc = serializers.CharField(required=True, + error_messages=ErrMessage.char("知识库描述"), + max_length=256, + min_length=1) + + embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型")) + + file_list = serializers.ListSerializer(required=True, + error_messages=ErrMessage.list("文件列表"), + child=serializers.FileField(required=True, + error_messages=ErrMessage.file("文件"))) + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='file', + in_=openapi.IN_FORM, + type=openapi.TYPE_ARRAY, + items=openapi.Items(type=openapi.TYPE_FILE), + required=True, + description='上传文件'), + openapi.Parameter(name='name', + in_=openapi.IN_FORM, + required=True, + type=openapi.TYPE_STRING, title="知识库名称", 
description="知识库名称"), + openapi.Parameter(name='desc', + in_=openapi.IN_FORM, + required=True, + type=openapi.TYPE_STRING, title="知识库描述", description="知识库描述"), + ] + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count', + 'update_time', 'create_time', 'document_list'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", + description="id", default="xx"), + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称", + description="名称", default="测试知识库"), + 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述", + description="描述", default="测试知识库描述"), + 'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id", + description="所属用户id", default="user_xxxx"), + 'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数", + description="字符数", default=10), + 'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量", + description="文档数量", default=1), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ), + 'document_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表", + description="文档列表", + items=DocumentSerializers.Operate.get_response_body_api()) + } + ) + + class CreateWebSerializers(serializers.Serializer): + """ + 创建web站点序列化对象 + """ + name = serializers.CharField(required=True, + error_messages=ErrMessage.char("知识库名称"), + max_length=64, + min_length=1) + + desc = serializers.CharField(required=True, + error_messages=ErrMessage.char("知识库描述"), + max_length=256, + min_length=1) + source_url = serializers.CharField(required=True, error_messages=ErrMessage.char("Web 根地址"), ) + + embedding_mode_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("向量模型")) + + 
selector = serializers.CharField(required=False, allow_null=True, allow_blank=True, + error_messages=ErrMessage.char("选择器")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + source_url = self.data.get('source_url') + response = Fork(source_url, []).fork() + if response.status == 500: + raise AppApiException(500, f"url错误,无法解析【{source_url}】") + return True + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count', + 'update_time', 'create_time', 'document_list'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", + description="id", default="xx"), + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称", + description="名称", default="测试知识库"), + 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述", + description="描述", default="测试知识库描述"), + 'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id", + description="所属用户id", default="user_xxxx"), + 'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数", + description="字符数", default=10), + 'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量", + description="文档数量", default=1), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ), + 'document_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表", + description="文档列表", + items=DocumentSerializers.Operate.get_response_body_api()) + } + ) + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['name', 'desc', 'url'], + properties={ + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"), + 'desc': openapi.Schema(type=openapi.TYPE_STRING, 
title="知识库描述", description="知识库描述"), + 'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title="向量模型id", + description="向量模型id"), + 'source_url': openapi.Schema(type=openapi.TYPE_STRING, title="web站点url", + description="web站点url"), + 'selector': openapi.Schema(type=openapi.TYPE_STRING, title="选择器", description="选择器") + } + ) + + @staticmethod + def post_embedding_dataset(document_list, dataset_id): + model_id = get_embedding_model_id_by_dataset_id(dataset_id) + # 发送向量化事件 + embedding_by_dataset.delay(dataset_id, model_id) + return document_list + + def save_qa(self, instance: Dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + self.CreateQASerializers(data=instance).is_valid() + file_list = instance.get('file_list') + document_list = flat_map([DocumentSerializers.Create.parse_qa_file(file) for file in file_list]) + dataset_instance = {'name': instance.get('name'), 'desc': instance.get('desc'), 'documents': document_list, + 'embedding_mode_id': instance.get('embedding_mode_id')} + return self.save(dataset_instance, with_valid=True) + + @valid_license(model=DataSet, count=50, + message='社区版最多支持 50 个知识库,如需拥有更多知识库,请联系我们(https://fit2cloud.com/)。') + @post(post_function=post_embedding_dataset) + @transaction.atomic + def save(self, instance: Dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + self.CreateBaseSerializers(data=instance).is_valid() + dataset_id = uuid.uuid1() + user_id = self.data.get('user_id') + if QuerySet(DataSet).filter(user_id=user_id, name=instance.get('name')).exists(): + raise AppApiException(500, "知识库名称重复!") + dataset = DataSet( + **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id, + 'embedding_mode_id': instance.get('embedding_mode_id')}) + + document_model_list = [] + paragraph_model_list = [] + problem_paragraph_object_list = [] + # 插入文档 + for document in instance.get('documents') if 'documents' in instance else []: + 
document_paragraph_dict_model = DocumentSerializers.Create.get_document_paragraph_model(dataset_id, + document) + document_model_list.append(document_paragraph_dict_model.get('document')) + for paragraph in document_paragraph_dict_model.get('paragraph_model_list'): + paragraph_model_list.append(paragraph) + for problem_paragraph_object in document_paragraph_dict_model.get('problem_paragraph_object_list'): + problem_paragraph_object_list.append(problem_paragraph_object) + + problem_model_list, problem_paragraph_mapping_list = (ProblemParagraphManage(problem_paragraph_object_list, + dataset_id) + .to_problem_model_list()) + # 插入知识库 + dataset.save() + # 插入文档 + QuerySet(Document).bulk_create(document_model_list) if len(document_model_list) > 0 else None + # 批量插入段落 + QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None + # 批量插入问题 + QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None + # 批量插入关联问题 + QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len( + problem_paragraph_mapping_list) > 0 else None + + # 响应数据 + return {**DataSetSerializers(dataset).data, + 'document_list': DocumentSerializers.Query(data={'dataset_id': dataset_id}).list( + with_valid=True)}, dataset_id + + @staticmethod + def get_last_url_path(url): + parsed_url = urlparse(url) + if parsed_url.path is None or len(parsed_url.path) == 0: + return url + else: + return parsed_url.path.split("/")[-1] + + def save_web(self, instance: Dict, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + self.CreateWebSerializers(data=instance).is_valid(raise_exception=True) + user_id = self.data.get('user_id') + if QuerySet(DataSet).filter(user_id=user_id, name=instance.get('name')).exists(): + raise AppApiException(500, "知识库名称重复!") + dataset_id = uuid.uuid1() + dataset = DataSet( + **{'id': dataset_id, 'name': instance.get("name"), 'desc': instance.get('desc'), 'user_id': user_id, + 
'type': Type.web, + 'embedding_mode_id': instance.get('embedding_mode_id'), + 'meta': {'source_url': instance.get('source_url'), 'selector': instance.get('selector'), + 'embedding_mode_id': instance.get('embedding_mode_id')}}) + dataset.save() + sync_web_dataset.delay(str(dataset_id), instance.get('source_url'), instance.get('selector')) + return {**DataSetSerializers(dataset).data, + 'document_list': []} + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count', + 'update_time', 'create_time', 'document_list'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", + description="id", default="xx"), + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称", + description="名称", default="测试知识库"), + 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述", + description="描述", default="测试知识库描述"), + 'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id", + description="所属用户id", default="user_xxxx"), + 'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数", + description="字符数", default=10), + 'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量", + description="文档数量", default=1), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ), + 'document_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档列表", + description="文档列表", + items=DocumentSerializers.Operate.get_response_body_api()) + } + ) + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['name', 'desc'], + properties={ + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="知识库名称", description="知识库名称"), + 'desc': openapi.Schema(type=openapi.TYPE_STRING, 
title="知识库描述", description="知识库描述"), + 'embedding_mode_id': openapi.Schema(type=openapi.TYPE_STRING, title='向量模型', + description='向量模型'), + 'documents': openapi.Schema(type=openapi.TYPE_ARRAY, title="文档数据", description="文档数据", + items=DocumentSerializers().Create.get_request_body_api() + ) + } + ) + + class Edit(serializers.Serializer): + name = serializers.CharField(required=False, max_length=64, min_length=1, + error_messages=ErrMessage.char("知识库名称")) + desc = serializers.CharField(required=False, max_length=256, min_length=1, + error_messages=ErrMessage.char("知识库描述")) + meta = serializers.DictField(required=False) + application_id_list = serializers.ListSerializer(required=False, child=serializers.UUIDField(required=True, + error_messages=ErrMessage.char( + "应用id")), + error_messages=ErrMessage.char("应用列表")) + + @staticmethod + def get_dataset_meta_valid_map(): + dataset_meta_valid_map = { + Type.base: MetaSerializer.BaseMeta, + Type.web: MetaSerializer.WebMeta + } + return dataset_meta_valid_map + + def is_valid(self, *, dataset: DataSet = None): + super().is_valid(raise_exception=True) + if 'meta' in self.data and self.data.get('meta') is not None: + dataset_meta_valid_map = self.get_dataset_meta_valid_map() + valid_class = dataset_meta_valid_map.get(dataset.type) + valid_class(data=self.data.get('meta')).is_valid(raise_exception=True) + + class HitTest(ApiMixin, serializers.Serializer): + id = serializers.CharField(required=True, error_messages=ErrMessage.char("id")) + user_id = serializers.UUIDField(required=False, error_messages=ErrMessage.char("用户id")) + query_text = serializers.CharField(required=True, error_messages=ErrMessage.char("查询文本")) + top_number = serializers.IntegerField(required=True, max_value=100, min_value=1, + error_messages=ErrMessage.char("响应Top")) + similarity = serializers.FloatField(required=True, max_value=2, min_value=0, + error_messages=ErrMessage.char("相似度")) + search_mode = serializers.CharField(required=True, validators=[ + 
validators.RegexValidator(regex=re.compile("^embedding|keywords|blend$"), + message="类型只支持register|reset_password", code=500) + ], error_messages=ErrMessage.char("检索模式")) + + def is_valid(self, *, raise_exception=True): + super().is_valid(raise_exception=True) + if not QuerySet(DataSet).filter(id=self.data.get("id")).exists(): + raise AppApiException(300, "id不存在") + + def hit_test(self): + self.is_valid() + vector = VectorStore.get_embedding_vector() + exclude_document_id_list = [str(document.id) for document in + QuerySet(Document).filter( + dataset_id=self.data.get('id'), + is_active=False)] + model = get_embedding_model_by_dataset_id(self.data.get('id')) + # 向量库检索 + hit_list = vector.hit_test(self.data.get('query_text'), [self.data.get('id')], exclude_document_id_list, + self.data.get('top_number'), + self.data.get('similarity'), + SearchMode(self.data.get('search_mode')), + model) + hit_dict = reduce(lambda x, y: {**x, **y}, [{hit.get('paragraph_id'): hit} for hit in hit_list], {}) + p_list = list_paragraph([h.get('paragraph_id') for h in hit_list]) + return [{**p, 'similarity': hit_dict.get(p.get('id')).get('similarity'), + 'comprehensive_score': hit_dict.get(p.get('id')).get('comprehensive_score')} for p in p_list] + + class SyncWeb(ApiMixin, serializers.Serializer): + id = serializers.CharField(required=True, error_messages=ErrMessage.char( + "知识库id")) + user_id = serializers.UUIDField(required=False, error_messages=ErrMessage.char( + "用户id")) + sync_type = serializers.CharField(required=True, error_messages=ErrMessage.char( + "同步类型"), validators=[ + validators.RegexValidator(regex=re.compile("^replace|complete$"), + message="同步类型只支持:replace|complete", code=500) + ]) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + first = QuerySet(DataSet).filter(id=self.data.get("id")).first() + if first is None: + raise AppApiException(300, "id不存在") + if first.type != Type.web: + raise AppApiException(500, "只有web站点类型才支持同步") + + 
def sync(self, with_valid=True):
    """
    Trigger a synchronisation of this web dataset.

    :param with_valid: validate the serializer data before syncing
    :return: True once the selected sync routine has been invoked
    :raises AppApiException: if ``sync_type`` is not one of the supported routines
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    sync_type = self.data.get('sync_type')
    dataset_id = self.data.get('id')
    dataset = QuerySet(DataSet).get(id=dataset_id)
    # Dispatch through an explicit table instead of building an attribute name
    # from request data: the field's regex is only loosely anchored, so values
    # such as "replaceX" pass validation and would otherwise surface as an
    # AttributeError.
    sync_handlers = {
        'replace': self.replace_sync,
        'complete': self.complete_sync,
    }
    handler = sync_handlers.get(sync_type)
    if handler is None:
        raise AppApiException(500, "同步类型只支持:replace|complete")
    handler(dataset)
    return True
@staticmethod
def merge_problem(paragraph_list: List[Dict], problem_mapping_list: List[Dict]):
    """
    Group paragraphs into per-document sheets and derive a sheet name per document.

    :param paragraph_list: paragraph dicts; the keys read are ``id``, ``document_id``,
                           ``document_name``, ``title`` and ``content``
    :param problem_mapping_list: problem mapping dicts; the keys read are
                                 ``paragraph_id`` and ``content``
    :return: tuple ``(result, result_document_dict)`` where ``result`` maps a
             document id to its rows ``[title, content, problems joined by newline]``
             and ``result_document_dict`` maps a document id to its sheet name.
             Documents sharing a name get the bare name for the first document
             seen and ``name + index`` for the following ones.
    """
    # Index problem texts by paragraph once instead of rescanning the whole
    # mapping list for every paragraph.
    problems_by_paragraph = {}
    for problem_mapping in problem_mapping_list:
        problems_by_paragraph.setdefault(problem_mapping.get('paragraph_id'), []).append(
            problem_mapping.get('content'))

    result = {}
    # document name -> {document id: None}. A dict is used as an insertion-ordered
    # set so the numbering of duplicate names follows first appearance and is
    # reproducible between runs (a plain set has no defined iteration order).
    document_dict = {}
    for paragraph in paragraph_list:
        document_id = paragraph.get('document_id')
        document_dict.setdefault(paragraph.get('document_name'), {})[document_id] = None
        problem_list = problems_by_paragraph.get(paragraph.get('id'), [])
        result.setdefault(document_id, []).append(
            [paragraph.get('title'), paragraph.get('content'), '\n'.join(problem_list)])

    result_document_dict = {}
    for d_name, d_ids in document_dict.items():
        for index, d_id in enumerate(d_ids):
            result_document_dict[d_id] = d_name if index == 0 else d_name + str(index)
    return result, result_document_dict
@transaction.atomic
def edit(self, dataset: Dict, user_id: str):
    """
    Modify a knowledge base.

    :param dataset: fields to change; the keys read are ``embedding_mode_id``,
                    ``name``, ``desc``, ``meta`` and ``application_id_list``
    :param user_id: id of the user performing the edit
    :return: the refreshed knowledge base record (see ``one``)
    :raises AppApiException: if another knowledge base of this user already has
                             the requested name, or if ``application_id_list``
                             contains an application the user may not link
    """
    self.is_valid()
    current_dataset_id = self.data.get('id')
    if QuerySet(DataSet).filter(user_id=user_id, name=dataset.get('name')).exclude(
            id=current_dataset_id).exists():
        raise AppApiException(500, "知识库名称重复!")
    _dataset = QuerySet(DataSet).get(id=current_dataset_id)
    DataSetSerializers.Edit(data=dataset).is_valid(dataset=_dataset)
    if 'embedding_mode_id' in dataset:
        _dataset.embedding_mode_id = dataset.get('embedding_mode_id')
    if "name" in dataset:
        _dataset.name = dataset.get("name")
    if 'desc' in dataset:
        _dataset.desc = dataset.get("desc")
    if 'meta' in dataset:
        _dataset.meta = dataset.get('meta')
    if 'application_id_list' in dataset and dataset.get('application_id_list') is not None:
        application_id_list = dataset.get('application_id_list')
        # Applications the current user is allowed to link to this knowledge base
        editable_application_id_list = [str(application.get('id')) for application in
                                        self.list_application(with_valid=False)]
        for application_id in application_id_list:
            if application_id not in editable_application_id_list:
                # The original message carried a stray "$" left over from
                # JavaScript template-literal syntax.
                raise AppApiException(500, f"未知的应用id{application_id},无法关联")

        # Remove the existing links among the applications the user may edit
        QuerySet(ApplicationDatasetMapping).filter(application_id__in=editable_application_id_list,
                                                   dataset_id=current_dataset_id).delete()
        # Insert the requested links
        if len(application_id_list) > 0:
            QuerySet(ApplicationDatasetMapping).bulk_create(
                [ApplicationDatasetMapping(application_id=application_id, dataset_id=current_dataset_id)
                 for application_id in application_id_list])

    _dataset.save()
    return self.one(with_valid=False, user_id=user_id)
+ required=['id', 'name', 'desc', 'user_id', 'char_length', 'document_count', + 'update_time', 'create_time'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", + description="id", default="xx"), + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称", + description="名称", default="测试知识库"), + 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="描述", + description="描述", default="测试知识库描述"), + 'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="所属用户id", + description="所属用户id", default="user_xxxx"), + 'char_length': openapi.Schema(type=openapi.TYPE_STRING, title="字符数", + description="字符数", default=10), + 'document_count': openapi.Schema(type=openapi.TYPE_STRING, title="文档数量", + description="文档数量", default=1), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ) + } + ) + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id') + ] diff --git a/src/MaxKB-1.7.2/apps/dataset/serializers/document_serializers.py b/src/MaxKB-1.7.2/apps/dataset/serializers/document_serializers.py new file mode 100644 index 0000000..61a6b02 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/serializers/document_serializers.py @@ -0,0 +1,1025 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: document_serializers.py + @date:2023/9/22 13:43 + @desc: +""" +import logging +import os +import re +import traceback +import uuid +from functools import reduce +from typing import List, Dict + +import openpyxl +from celery_once import AlreadyQueued +from django.core import validators +from django.db import transaction +from django.db.models import QuerySet +from django.http import HttpResponse +from drf_yasg import openapi 
+from openpyxl.cell.cell import ILLEGAL_CHARACTERS_RE +from rest_framework import serializers +from xlwt import Utils + +from common.db.search import native_search, native_page_search +from common.event.common import work_thread_pool +from common.exception.app_exception import AppApiException +from common.handle.impl.doc_split_handle import DocSplitHandle +from common.handle.impl.html_split_handle import HTMLSplitHandle +from common.handle.impl.pdf_split_handle import PdfSplitHandle +from common.handle.impl.qa.csv_parse_qa_handle import CsvParseQAHandle +from common.handle.impl.qa.xls_parse_qa_handle import XlsParseQAHandle +from common.handle.impl.qa.xlsx_parse_qa_handle import XlsxParseQAHandle +from common.handle.impl.table.csv_parse_table_handle import CsvSplitHandle +from common.handle.impl.table.xls_parse_table_handle import XlsSplitHandle +from common.handle.impl.table.xlsx_parse_table_handle import XlsxSplitHandle +from common.handle.impl.text_split_handle import TextSplitHandle +from common.mixins.api_mixin import ApiMixin +from common.util.common import post, flat_map +from common.util.field_message import ErrMessage +from common.util.file_util import get_file_content +from common.util.fork import Fork +from common.util.split_model import get_split_model +from dataset.models.data_set import DataSet, Document, Paragraph, Problem, Type, Status, ProblemParagraphMapping, Image +from dataset.serializers.common_serializers import BatchSerializer, MetaSerializer, ProblemParagraphManage, \ + get_embedding_model_id_by_dataset_id +from dataset.serializers.paragraph_serializers import ParagraphSerializers, ParagraphInstanceSerializer +from dataset.task import sync_web_document, generate_related_by_document_id +from embedding.task.embedding import embedding_by_document, delete_embedding_by_document_list, \ + delete_embedding_by_document, update_embedding_dataset_id, delete_embedding_by_paragraph_ids, \ + embedding_by_document_list +from smartdoc.conf import 
class DocumentEditInstanceSerializer(ApiMixin, serializers.Serializer):
    """
    Validates the payload used to edit a document.

    All fields are optional; ``meta`` is additionally validated against the
    meta serializer that matches the document's type.
    """
    meta = serializers.DictField(required=False)
    name = serializers.CharField(required=False, max_length=128, min_length=1,
                                 error_messages=ErrMessage.char(
                                     "文档名称"))
    # The alternation is grouped so that both anchors apply to every option;
    # "^optimization|directly_return$" accepted any value that merely starts
    # with "optimization" or ends with "directly_return".
    hit_handling_method = serializers.CharField(required=False, validators=[
        validators.RegexValidator(regex=re.compile("^(optimization|directly_return)$"),
                                  message="类型只支持optimization|directly_return",
                                  code=500)
    ], error_messages=ErrMessage.char("命中处理方式"))

    directly_return_similarity = serializers.FloatField(required=False,
                                                        max_value=2,
                                                        min_value=0,
                                                        error_messages=ErrMessage.float(
                                                            "直接返回分数"))

    is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean(
        "文档是否可用"))

    @staticmethod
    def get_meta_valid_map():
        """Return the mapping from document type to the serializer that validates its ``meta``."""
        dataset_meta_valid_map = {
            Type.base: MetaSerializer.BaseMeta,
            Type.web: MetaSerializer.WebMeta
        }
        return dataset_meta_valid_map

    def is_valid(self, *, document: Document = None):
        """
        Validate the edit payload, always raising on invalid data.

        :param document: the document being edited; its ``type`` selects the
                         serializer used to validate ``meta``
                         (NOTE(review): must be supplied whenever ``meta`` is
                         present, since ``document.type`` is read unconditionally)
        """
        super().is_valid(raise_exception=True)
        if 'meta' in self.data and self.data.get('meta') is not None:
            dataset_meta_valid_map = self.get_meta_valid_map()
            valid_class = dataset_meta_valid_map.get(document.type)
            valid_class(data=self.data.get('meta')).is_valid(raise_exception=True)
class Export(ApiMixin, serializers.Serializer):
    """Serves the bundled import templates (QA templates and table templates) as downloads."""
    # Grouped alternation so that both anchors apply to each option; the previous
    # "^csv|excel$" accepted any value starting with "csv" or ending with "excel".
    type = serializers.CharField(required=True, validators=[
        validators.RegexValidator(regex=re.compile("^(csv|excel)$"),
                                  message="模版类型只支持excel|csv",
                                  code=500)
    ], error_messages=ErrMessage.char("模版类型"))

    @staticmethod
    def get_request_params_api():
        """Describe the ``type`` query parameter for the API documentation."""
        return [openapi.Parameter(name='type',
                                  in_=openapi.IN_QUERY,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='导出模板类型csv|excel'),

                ]

    @staticmethod
    def _template_response(template_name, content_type, download_name):
        """Read a file from the dataset template directory and wrap it in a download response."""
        template_path = os.path.join(PROJECT_DIR, "apps", "dataset", 'template', template_name)
        # The context manager closes the handle even if reading fails.
        with open(template_path, "rb") as file:
            content = file.read()
        return HttpResponse(content, status=200,
                            headers={'Content-Type': content_type,
                                     'Content-Disposition': f'attachment; filename="{download_name}"'})

    def export(self, with_valid=True):
        """
        Return the QA import template matching the requested ``type``.

        :param with_valid: validate the serializer data first
        :return: HttpResponse carrying the csv or xlsx template
        """
        if with_valid:
            self.is_valid(raise_exception=True)

        if self.data.get('type') == 'csv':
            # 'text/csv' replaces the misspelled 'text/cxv' media type.
            return self._template_response('csv_template.csv', 'text/csv', 'csv_template.csv')
        elif self.data.get('type') == 'excel':
            return self._template_response('excel_template.xlsx', 'application/vnd.ms-excel',
                                           'excel_template.xlsx')

    def table_export(self, with_valid=True):
        """
        Return the table import template matching the requested ``type``.

        :param with_valid: validate the serializer data first
        :return: HttpResponse carrying the csv or xlsx table template
        """
        if with_valid:
            self.is_valid(raise_exception=True)

        if self.data.get('type') == 'csv':
            return self._template_response('MaxKB表格模板.csv', 'text/csv', 'csv_template.csv')
        elif self.data.get('type') == 'excel':
            return self._template_response('MaxKB表格模板.xlsx', 'application/vnd.ms-excel',
                                           'excel_template.xlsx')
error_messages=ErrMessage.char("文档列表"), + child=serializers.UUIDField(required=True, + error_messages=ErrMessage.uuid("文档id"))) + + @transaction.atomic + def migrate(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + dataset_id = self.data.get('dataset_id') + target_dataset_id = self.data.get('target_dataset_id') + dataset = QuerySet(DataSet).filter(id=dataset_id).first() + target_dataset = QuerySet(DataSet).filter(id=target_dataset_id).first() + document_id_list = self.data.get('document_id_list') + document_list = QuerySet(Document).filter(dataset_id=dataset_id, id__in=document_id_list) + paragraph_list = QuerySet(Paragraph).filter(dataset_id=dataset_id, document_id__in=document_id_list) + + problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(paragraph__in=paragraph_list) + problem_list = QuerySet(Problem).filter( + id__in=[problem_paragraph_mapping.problem_id for problem_paragraph_mapping in + problem_paragraph_mapping_list]) + target_problem_list = list( + QuerySet(Problem).filter(content__in=[problem.content for problem in problem_list], + dataset_id=target_dataset_id)) + target_handle_problem_list = [ + self.get_target_dataset_problem(target_dataset_id, problem_paragraph_mapping, + problem_list, target_problem_list) for + problem_paragraph_mapping + in + problem_paragraph_mapping_list] + + create_problem_list = [problem for problem, is_create in target_handle_problem_list if + is_create is not None and is_create] + # 插入问题 + QuerySet(Problem).bulk_create(create_problem_list) + # 修改mapping + QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list, ['problem_id', 'dataset_id']) + # 修改文档 + if dataset.type == Type.base.value and target_dataset.type == Type.web.value: + document_list.update(dataset_id=target_dataset_id, type=Type.web, + meta={'source_url': '', 'selector': ''}) + elif target_dataset.type == Type.base.value and dataset.type == Type.web.value: + 
@staticmethod
def get_target_dataset_problem(target_dataset_id: str,
                               problem_paragraph_mapping,
                               source_problem_list,
                               target_problem_list):
    """
    Re-point one problem/paragraph mapping at the target dataset.

    The mapping's ``dataset_id`` is always set to ``target_dataset_id``. Its
    ``problem_id`` is switched to a problem of the target dataset that has the
    same content as the source problem; when none exists a new ``Problem`` is
    built and appended to ``target_problem_list`` so later mappings reuse it.

    :param target_dataset_id: id of the dataset the documents are moved to
    :param problem_paragraph_mapping: mapping object to update in place
    :param source_problem_list: problems referenced by the migrated paragraphs
    :param target_problem_list: problems known in the target dataset (mutated)
    :return: ``(problem, is_create)``; ``is_create`` is True when ``problem``
             still has to be inserted, False when it already exists, and
             ``(None, None)`` when the mapping's source problem is not in
             ``source_problem_list``
    """
    source_problem_list = [source_problem for source_problem in source_problem_list if
                           source_problem.id == problem_paragraph_mapping.problem_id]
    problem_paragraph_mapping.dataset_id = target_dataset_id
    if len(source_problem_list) == 0:
        # Always hand back a pair: the caller unpacks every result as
        # (problem, is_create) and already filters on "is_create is not None";
        # a bare None made that unpacking raise TypeError.
        return None, None
    problem_content = source_problem_list[-1].content
    problem_list = [problem for problem in target_problem_list if problem.content == problem_content]
    if len(problem_list) > 0:
        problem = problem_list[-1]
        problem_paragraph_mapping.problem_id = problem.id
        return problem, False
    problem = Problem(id=uuid.uuid1(), dataset_id=target_dataset_id, content=problem_content)
    target_problem_list.append(problem)
    problem_paragraph_mapping.problem_id = problem.id
    return problem, True
def get_query_set(self):
    """
    Build the document queryset for this query.

    Starts from the documents of ``dataset_id`` and narrows by each optional
    request field that was supplied with a non-None value, then orders the
    result by creation time, newest first.

    :return: the filtered and ordered ``Document`` queryset
    """
    documents = QuerySet(model=Document).filter(**{'dataset_id': self.data.get("dataset_id")})
    # (request field, ORM lookup) pairs, applied in this order when a value is present
    optional_lookups = (
        ('name', 'name__icontains'),
        ('hit_handling_method', 'hit_handling_method'),
        ('is_active', 'is_active'),
        ('status', 'status'),
    )
    for field_name, lookup in optional_lookups:
        value = self.data.get(field_name)
        if value is None:
            continue
        documents = documents.filter(**{lookup: value})
    return documents.order_by('-create_time')
os.path.join(PROJECT_DIR, "apps", "dataset", 'sql', 'list_document.sql'))) + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='name', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='文档名称'), + openapi.Parameter(name='hit_handling_method', in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='文档命中处理方式')] + + @staticmethod + def get_response_body_api(): + return openapi.Schema(type=openapi.TYPE_ARRAY, + title="文档列表", description="文档列表", + items=DocumentSerializers.Operate.get_response_body_api()) + + class Sync(ApiMixin, serializers.Serializer): + document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char( + "文档id")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + document_id = self.data.get('document_id') + first = QuerySet(Document).filter(id=document_id).first() + if first is None: + raise AppApiException(500, "文档id不存在") + if first.type != Type.web: + raise AppApiException(500, "只有web站点类型才支持同步") + + def sync(self, with_valid=True, with_embedding=True): + if with_valid: + self.is_valid(raise_exception=True) + document_id = self.data.get('document_id') + document = QuerySet(Document).filter(id=document_id).first() + if document.type != Type.web: + return True + try: + document.status = Status.queue_up + document.save() + source_url = document.meta.get('source_url') + selector_list = document.meta.get('selector').split( + " ") if 'selector' in document.meta and document.meta.get('selector') is not None else [] + result = Fork(source_url, selector_list).fork() + if result.status == 200: + # 删除段落 + QuerySet(model=Paragraph).filter(document_id=document_id).delete() + # 删除问题 + QuerySet(model=ProblemParagraphMapping).filter(document_id=document_id).delete() + # 删除向量库 + delete_embedding_by_document(document_id) + paragraphs = get_split_model('web.md').parse(result.content) + document.char_length = reduce(lambda 
x, y: x + y, + [len(p.get('content')) for p in paragraphs], + 0) + document.save() + document_paragraph_model = DocumentSerializers.Create.get_paragraph_model(document, paragraphs) + + paragraph_model_list = document_paragraph_model.get('paragraph_model_list') + problem_paragraph_object_list = document_paragraph_model.get('problem_paragraph_object_list') + problem_model_list, problem_paragraph_mapping_list = ProblemParagraphManage( + problem_paragraph_object_list, document.dataset_id).to_problem_model_list() + # 批量插入段落 + QuerySet(Paragraph).bulk_create(paragraph_model_list) if len(paragraph_model_list) > 0 else None + # 批量插入问题 + QuerySet(Problem).bulk_create(problem_model_list) if len(problem_model_list) > 0 else None + # 插入关联问题 + QuerySet(ProblemParagraphMapping).bulk_create(problem_paragraph_mapping_list) if len( + problem_paragraph_mapping_list) > 0 else None + # 向量化 + if with_embedding: + embedding_model_id = get_embedding_model_id_by_dataset_id(document.dataset_id) + embedding_by_document.delay(document_id, embedding_model_id) + else: + document.status = Status.error + document.save() + except Exception as e: + logging.getLogger("max_kb_error").error(f'{str(e)}:{traceback.format_exc()}') + document.status = Status.error + document.save() + return True + + class Operate(ApiMixin, serializers.Serializer): + document_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char( + "文档id")) + dataset_id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("数据集id")) + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id'), + openapi.Parameter(name='document_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='文档id') + ] + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + document_id = self.data.get('document_id') + if not 
@staticmethod
def get_workbook(data_dict, document_dict):
    """
    Build an Excel workbook with one worksheet per document.

    :param data_dict: maps a sheet/document id to its list of rows
    :param document_dict: maps a sheet/document id to the worksheet title
    :return: the populated ``openpyxl`` workbook; when ``data_dict`` is empty the
             workbook holds a single sheet containing only the header row
    """
    # Create the workbook and drop the sheet openpyxl creates by default.
    workbook = openpyxl.Workbook()
    # Workbook.remove is the supported replacement for the deprecated remove_sheet.
    workbook.remove(workbook.active)
    # Fall back to one placeholder sheet without writing into the caller's dict.
    sheets = data_dict if len(data_dict.keys()) > 0 else {'sheet': []}
    header = ['分段标题(选填)', '分段内容(必填,问题答案,最长不超过4096个字符)', '问题(选填,单元格内一行一个)']
    for sheet_id, rows in sheets.items():
        # Add the worksheet
        worksheet = workbook.create_sheet(document_dict.get(sheet_id))
        # Write the header followed by the data rows (cells are 1-indexed)
        for row_idx, row in enumerate([header, *rows], start=1):
            for col_idx, col in enumerate(row, start=1):
                cell = worksheet.cell(row=row_idx, column=col_idx)
                if isinstance(col, str):
                    # Strip the characters matched by openpyxl's ILLEGAL_CHARACTERS_RE
                    col = ILLEGAL_CHARACTERS_RE.sub('', col)
                cell.value = col
    return workbook
@staticmethod
def reset_document_name(document_name):
    """
    Turn a document name into a usable worksheet name.

    :param document_name: the document's name, may be None
    :return: the stripped name when ``Utils.valid_sheet_name`` accepts it,
             otherwise the fallback ``"Sheet"``
    """
    if document_name is None:
        return "Sheet"
    # Validate the value that is actually returned: checking the un-stripped
    # name let whitespace-only names through and judged padded names on
    # characters that are removed afterwards.
    stripped_name = document_name.strip()
    if not Utils.valid_sheet_name(stripped_name):
        return "Sheet"
    return stripped_name
@staticmethod
def get_response_body_api():
    """
    Describe the document object returned by the API as an OpenAPI schema.

    :return: ``openapi.Schema`` of type object listing the document fields
    """
    return openapi.Schema(
        type=openapi.TYPE_OBJECT,
        # A comma was missing between 'is_active' and 'update_time'; Python's implicit
        # string concatenation turned them into the bogus key 'is_activeupdate_time'.
        required=['id', 'name', 'char_length', 'user_id', 'paragraph_count', 'is_active',
                  'update_time', 'create_time'],
        properties={
            'id': openapi.Schema(type=openapi.TYPE_STRING, title="id",
                                 description="id", default="xx"),
            'name': openapi.Schema(type=openapi.TYPE_STRING, title="名称",
                                   description="名称", default="测试知识库"),
            'char_length': openapi.Schema(type=openapi.TYPE_INTEGER, title="字符数",
                                          description="字符数", default=10),
            'user_id': openapi.Schema(type=openapi.TYPE_STRING, title="用户id", description="用户id"),
            # NOTE(review): the title says "document count" although the field is
            # paragraph_count - confirm the intended wording before changing it.
            'paragraph_count': openapi.Schema(type=openapi.TYPE_INTEGER, title="文档数量",
                                              description="文档数量", default=1),
            'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用",
                                        description="是否可用", default=True),
            'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间",
                                          description="修改时间",
                                          default="1970-01-01 00:00:00"),
            'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间",
                                          description="创建时间",
                                          default="1970-01-01 00:00:00"
                                          )
        }
    )
@staticmethod
def parse_table_file(file):
    """
    Parse an uploaded table file with the first registered handler that supports it.

    :param file: uploaded file object
    :return: whatever the selected handler's ``handle`` returns
    :raises AppApiException: when no handler in ``parse_table_handle_list`` supports the file
    """
    # One buffer handle per file so every handler sees the same cached bytes.
    read_buffer = FileBufferHandle().get_buffer
    chosen = next((candidate for candidate in parse_table_handle_list
                   if candidate.support(file, read_buffer)), None)
    if chosen is None:
        raise AppApiException(500, '不支持的文件格式')
    return chosen.handle(file, read_buffer, save_image)
@staticmethod
def get_document_paragraph_model(dataset_id, instance: Dict):
    """
    Build the unsaved Document model for ``instance`` together with its paragraph models.

    :param dataset_id: id of the dataset the document belongs to
    :param instance: document payload; reads ``name``, ``paragraphs``, ``meta`` and ``type``
    :return: the dictionary produced by ``DocumentSerializers.Create.get_paragraph_model``
    """
    # A missing key and an explicit null both mean "no paragraphs".
    paragraphs = instance.get('paragraphs') or []
    # Paragraph content is nullable, so count a missing body as zero characters
    # instead of failing on len(None).
    char_length = sum(len(paragraph.get('content') or '') for paragraph in paragraphs)
    meta = instance.get('meta')
    document_type = instance.get('type')
    document_model = Document(
        dataset_id=dataset_id,
        id=uuid.uuid1(),
        name=instance.get('name'),
        char_length=char_length,
        meta=meta if meta is not None else {},
        type=document_type if document_type is not None else Type.base)

    return DocumentSerializers.Create.get_paragraph_model(document_model, paragraphs)
def parse(self):
    """
    Split every uploaded file into paragraphs.

    :return: list with one ``file_to_paragraph`` result per uploaded file, in upload order
    """
    options = self.data
    patterns = options.get("patterns", None)
    with_filter = options.get("with_filter", None)
    # Segment length falls back to 4096 when the request does not provide one.
    limit = options.get("limit", 4096)
    return [file_to_paragraph(uploaded, patterns, with_filter, limit)
            for uploaded in options.get("file")]
def batch_edit_hit_handling(self, instance: Dict, with_valid=True):
    """
    Set the hit handling method (and optionally the direct-return similarity) on many documents.

    :param instance: payload; reads ``id_list``, ``hit_handling_method`` and
                     ``directly_return_similarity``
    :param with_valid: validate the payload and this serializer before updating
    :raises AppApiException: when validating and the method is missing or not one of
                             ``optimization`` / ``directly_return``
    """
    if with_valid:
        BatchSerializer(data=instance).is_valid(model=Document, raise_exception=True)
        requested_method = instance.get('hit_handling_method')
        if requested_method is None:
            raise AppApiException(500, '命中处理方式必填')
        if requested_method not in ('optimization', 'directly_return'):
            raise AppApiException(500, '命中处理方式必须为directly_return|optimization')
        self.is_valid(raise_exception=True)

    changes = {'hit_handling_method': instance.get('hit_handling_method')}
    similarity = instance.get('directly_return_similarity')
    # The similarity threshold is only overwritten when the caller supplied one.
    if similarity is not None:
        changes['directly_return_similarity'] = similarity
    QuerySet(Document).filter(id__in=instance.get("id_list")).update(**changes)
class FileBufferHandle:
    """
    Read a file once and hand out the same bytes on every later request.

    The first ``get_buffer`` call reads the given file; subsequent calls return the
    cached content without touching the file argument again.
    """
    # Cached content; None until the first read.
    buffer = None

    def get_buffer(self, file):
        """
        :param file: object exposing ``read()``
        :return: the content read on the first call
        """
        cached = self.buffer
        if cached is not None:
            return cached
        self.buffer = file.read()
        return self.buffer
with_filter, limit, get_buffer, save_image) + return default_split_handle.handle(file, pattern_list, with_filter, limit, get_buffer, save_image) diff --git a/src/MaxKB-1.7.2/apps/dataset/serializers/file_serializers.py b/src/MaxKB-1.7.2/apps/dataset/serializers/file_serializers.py new file mode 100644 index 0000000..894f149 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/serializers/file_serializers.py @@ -0,0 +1,79 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: image_serializers.py + @date:2024/4/22 16:36 + @desc: +""" +import uuid + +from django.db.models import QuerySet +from django.http import HttpResponse +from rest_framework import serializers + +from common.exception.app_exception import NotFound404 +from common.field.common import UploadedFileField +from common.util.field_message import ErrMessage +from dataset.models import File + +mime_types = {"html": "text/html", "htm": "text/html", "shtml": "text/html", "css": "text/css", "xml": "text/xml", + "gif": "image/gif", "jpeg": "image/jpeg", "jpg": "image/jpeg", "js": "application/javascript", + "atom": "application/atom+xml", "rss": "application/rss+xml", "mml": "text/mathml", "txt": "text/plain", + "jad": "text/vnd.sun.j2me.app-descriptor", "wml": "text/vnd.wap.wml", "htc": "text/x-component", + "avif": "image/avif", "png": "image/png", "svg": "image/svg+xml", "svgz": "image/svg+xml", + "tif": "image/tiff", "tiff": "image/tiff", "wbmp": "image/vnd.wap.wbmp", "webp": "image/webp", + "ico": "image/x-icon", "jng": "image/x-jng", "bmp": "image/x-ms-bmp", "woff": "font/woff", + "woff2": "font/woff2", "jar": "application/java-archive", "war": "application/java-archive", + "ear": "application/java-archive", "json": "application/json", "hqx": "application/mac-binhex40", + "doc": "application/msword", "pdf": "application/pdf", "ps": "application/postscript", + "eps": "application/postscript", "ai": "application/postscript", "rtf": "application/rtf", + "m3u8": "application/vnd.apple.mpegurl", "kml": 
"application/vnd.google-earth.kml+xml", + "kmz": "application/vnd.google-earth.kmz", "xls": "application/vnd.ms-excel", + "eot": "application/vnd.ms-fontobject", "ppt": "application/vnd.ms-powerpoint", + "odg": "application/vnd.oasis.opendocument.graphics", + "odp": "application/vnd.oasis.opendocument.presentation", + "ods": "application/vnd.oasis.opendocument.spreadsheet", "odt": "application/vnd.oasis.opendocument.text", + "wmlc": "application/vnd.wap.wmlc", "wasm": "application/wasm", "7z": "application/x-7z-compressed", + "cco": "application/x-cocoa", "jardiff": "application/x-java-archive-diff", + "jnlp": "application/x-java-jnlp-file", "run": "application/x-makeself", "pl": "application/x-perl", + "pm": "application/x-perl", "prc": "application/x-pilot", "pdb": "application/x-pilot", + "rar": "application/x-rar-compressed", "rpm": "application/x-redhat-package-manager", + "sea": "application/x-sea", "swf": "application/x-shockwave-flash", "sit": "application/x-stuffit", + "tcl": "application/x-tcl", "tk": "application/x-tcl", "der": "application/x-x509-ca-cert", + "pem": "application/x-x509-ca-cert", "crt": "application/x-x509-ca-cert", + "xpi": "application/x-xpinstall", "xhtml": "application/xhtml+xml", "xspf": "application/xspf+xml", + "zip": "application/zip", "bin": "application/octet-stream", "exe": "application/octet-stream", + "dll": "application/octet-stream", "deb": "application/octet-stream", "dmg": "application/octet-stream", + "iso": "application/octet-stream", "img": "application/octet-stream", "msi": "application/octet-stream", + "msp": "application/octet-stream", "msm": "application/octet-stream", "mid": "audio/midi", + "midi": "audio/midi", "kar": "audio/midi", "mp3": "audio/mpeg", "ogg": "audio/ogg", "m4a": "audio/x-m4a", + "ra": "audio/x-realaudio", "3gpp": "video/3gpp", "3gp": "video/3gpp", "ts": "video/mp2t", + "mp4": "video/mp4", "mpeg": "video/mpeg", "mpg": "video/mpeg", "mov": "video/quicktime", + "webm": "video/webm", "flv": 
"video/x-flv", "m4v": "video/x-m4v", "mng": "video/x-mng", + "asx": "video/x-ms-asf", "asf": "video/x-ms-asf", "wmv": "video/x-ms-wmv", "avi": "video/x-msvideo"} + + +class FileSerializer(serializers.Serializer): + file = UploadedFileField(required=True, error_messages=ErrMessage.image("文件")) + + def upload(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + file_id = uuid.uuid1() + file = File(id=file_id, file_name=self.data.get('file').name) + file.save(self.data.get('file').read()) + return f'/api/file/{file_id}' + + class Operate(serializers.Serializer): + id = serializers.UUIDField(required=True) + + def get(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + file_id = self.data.get('id') + file = QuerySet(File).filter(id=file_id).first() + if file is None: + raise NotFound404(404, "不存在的文件") + return HttpResponse(file.get_byte(), status=200, + headers={'Content-Type': mime_types.get(file.file_name.split(".")[-1], 'text/plain')}) diff --git a/src/MaxKB-1.7.2/apps/dataset/serializers/image_serializers.py b/src/MaxKB-1.7.2/apps/dataset/serializers/image_serializers.py new file mode 100644 index 0000000..3ee477f --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/serializers/image_serializers.py @@ -0,0 +1,47 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: image_serializers.py + @date:2024/4/22 16:36 + @desc: +""" +import uuid + +from django.db.models import QuerySet +from django.http import HttpResponse +from rest_framework import serializers + +from common.exception.app_exception import NotFound404 +from common.field.common import UploadedImageField +from common.util.field_message import ErrMessage +from dataset.models import Image + + +class ImageSerializer(serializers.Serializer): + image = UploadedImageField(required=True, error_messages=ErrMessage.image("图片")) + + def upload(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + image_id = uuid.uuid1() + image = 
Image(id=image_id, image=self.data.get('image').read(), image_name=self.data.get('image').name) + image.save() + return f'/api/image/{image_id}' + + class Operate(serializers.Serializer): + id = serializers.UUIDField(required=True) + + def get(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + image_id = self.data.get('id') + image = QuerySet(Image).filter(id=image_id).first() + if image is None: + raise NotFound404(404, "不存在的图片") + if image.image_name.endswith('.svg'): + return HttpResponse(image.image, status=200, headers={'Content-Type': 'image/svg+xml'}) + # gif + elif image.image_name.endswith('.gif'): + return HttpResponse(image.image, status=200, headers={'Content-Type': 'image/gif'}) + return HttpResponse(image.image, status=200, headers={'Content-Type': 'image/png'}) diff --git a/src/MaxKB-1.7.2/apps/dataset/serializers/paragraph_serializers.py b/src/MaxKB-1.7.2/apps/dataset/serializers/paragraph_serializers.py new file mode 100644 index 0000000..6614d71 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/serializers/paragraph_serializers.py @@ -0,0 +1,743 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: paragraph_serializers.py + @date:2023/10/16 15:51 + @desc: +""" +import uuid +from typing import Dict + +from celery_once import AlreadyQueued +from django.db import transaction +from django.db.models import QuerySet +from drf_yasg import openapi +from rest_framework import serializers + +from common.db.search import page_search +from common.exception.app_exception import AppApiException +from common.mixins.api_mixin import ApiMixin +from common.util.common import post +from common.util.field_message import ErrMessage +from dataset.models import Paragraph, Problem, Document, ProblemParagraphMapping, DataSet +from dataset.serializers.common_serializers import update_document_char_length, BatchSerializer, ProblemParagraphObject, \ + ProblemParagraphManage, get_embedding_model_id_by_dataset_id +from 
class ParagraphInstanceSerializer(ApiMixin, serializers.Serializer):
    """
    Request payload describing a single paragraph and its optional related problems.
    """
    # Paragraph body; null and blank values are accepted, at most 102400 characters.
    content = serializers.CharField(required=True, error_messages=ErrMessage.char("段落内容"),
                                    max_length=102400,
                                    min_length=1,
                                    allow_null=True, allow_blank=True)

    # Optional paragraph title.
    title = serializers.CharField(required=False, max_length=256, error_messages=ErrMessage.char("段落标题"),
                                  allow_null=True, allow_blank=True)

    # Optional problems to associate with the paragraph.
    problem_list = ProblemInstanceSerializer(required=False, many=True)

    # Boolean field, so it uses the boolean error messages rather than the char ones.
    is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.boolean("段落是否可用"))

    @staticmethod
    def get_request_body_api():
        """
        :return: OpenAPI schema of the paragraph request body
        """
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['content'],
            properties={
                # Kept in sync with the max_length enforced by the ``content`` field above.
                'content': openapi.Schema(type=openapi.TYPE_STRING, max_length=102400, title="分段内容",
                                          description="分段内容"),

                'title': openapi.Schema(type=openapi.TYPE_STRING, max_length=256, title="分段标题",
                                        description="分段标题"),

                'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),

                'problem_list': openapi.Schema(type=openapi.TYPE_ARRAY, title='问题列表',
                                               description="问题列表",
                                               items=ProblemInstanceSerializer.get_request_body_api())
            }
        )
@transaction.atomic
def save(self, instance: Dict, with_valid=True, with_embedding=True, embedding_by_problem=None):
    """
    Attach a problem to the paragraph, creating the problem when the dataset has none
    with the same content.

    :param instance: payload; ``content`` holds the problem text
    :param with_valid: validate this serializer and the payload first
    :param with_embedding: also run the problem embedding for the new mapping
    :param embedding_by_problem: unused; kept so existing callers keep working
    :return: the problem as returned by ``ProblemSerializers.Operate.one``
    :raises AppApiException: when the problem is already linked to the paragraph
    """
    if with_valid:
        self.is_valid()
        ProblemInstanceSerializer(data=instance).is_valid(raise_exception=True)
    dataset_id = self.data.get('dataset_id')
    document_id = self.data.get('document_id')
    paragraph_id = self.data.get('paragraph_id')
    content = instance.get('content')

    # Reuse a problem with identical content inside the dataset, otherwise create it.
    problem = QuerySet(Problem).filter(dataset_id=dataset_id, content=content).first()
    if problem is None:
        problem = Problem(id=uuid.uuid1(), dataset_id=dataset_id, content=content)
        problem.save()
    if QuerySet(ProblemParagraphMapping).filter(dataset_id=dataset_id, problem_id=problem.id,
                                                paragraph_id=paragraph_id).exists():
        raise AppApiException(500, "已经关联,请勿重复关联")
    problem_paragraph_mapping = ProblemParagraphMapping(id=uuid.uuid1(),
                                                        problem_id=problem.id,
                                                        document_id=document_id,
                                                        paragraph_id=paragraph_id,
                                                        dataset_id=dataset_id)
    problem_paragraph_mapping.save()
    if with_embedding:
        # Resolve the embedding model only when it is needed, matching
        # Association.association.
        model_id = get_embedding_model_id_by_dataset_id(dataset_id)
        embedding_by_problem_task({'text': problem.content,
                                   'is_active': True,
                                   'source_type': SourceType.PROBLEM,
                                   'source_id': problem_paragraph_mapping.id,
                                   'document_id': document_id,
                                   'paragraph_id': paragraph_id,
                                   'dataset_id': dataset_id,
                                   }, model_id)

    return ProblemSerializers.Operate(
        data={'dataset_id': dataset_id,
              'problem_id': problem.id}).one(with_valid=True)
def is_valid(self, *, raise_exception=True):
    """
    Validate the fields, then check that the paragraph and the problem both exist in
    the dataset.

    :raises AppApiException: when the paragraph or the problem cannot be found
    """
    super().is_valid(raise_exception=True)
    dataset_id = self.data.get('dataset_id')
    # Checked in this order: paragraph first, then problem.
    existence_checks = (
        (Paragraph, self.data.get('paragraph_id'), "段落不存在"),
        (Problem, self.data.get("problem_id"), "问题不存在"),
    )
    for model, record_id, message in existence_checks:
        if not QuerySet(model).filter(dataset_id=dataset_id, id=record_id).exists():
            raise AppApiException(500, message)
def un_association(self, with_valid=True):
    """
    Remove the link between the problem and the paragraph and delete its embedding.

    :param with_valid: validate the serializer before removing the link
    :return: True when the link has been removed
    :raises AppApiException: when the problem is not linked to the paragraph
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    problem_paragraph_mapping = QuerySet(ProblemParagraphMapping).filter(
        paragraph_id=self.data.get('paragraph_id'),
        dataset_id=self.data.get('dataset_id'),
        problem_id=self.data.get('problem_id')).first()
    # first() yields None when no such link exists; report that explicitly instead
    # of failing with an AttributeError on ``.id``.
    if problem_paragraph_mapping is None:
        raise AppApiException(500, "问题与段落尚未关联")
    # Remember the id before deleting the row; the embedding is keyed by it.
    problem_paragraph_mapping_id = problem_paragraph_mapping.id
    problem_paragraph_mapping.delete()
    delete_embedding_by_source(problem_paragraph_mapping_id)
    return True
def is_valid(self, *, raise_exception=False):
    """Validate a paragraph-migration request.

    Field-level validation is always run with ``raise_exception=True``; the
    ``raise_exception`` argument is kept only for signature compatibility.

    :return: ``True`` when every check passed (previously ``None``).
    :raises AppApiException: code 5000 when source and target document are the
        same, or when either document id does not exist.
    """
    super().is_valid(raise_exception=True)
    document_id = self.data.get('document_id')
    target_document_id = self.data.get('target_document_id')
    if document_id == target_document_id:
        raise AppApiException(5000, "需要迁移的文档和目标文档一致")
    # One query for both documents; ids are compared as strings because
    # self.data holds the serialized (string) form of the UUIDs.
    existing_ids = {str(document.id) for document in
                    QuerySet(Document).filter(id__in=[document_id, target_document_id])}
    if document_id not in existing_ids:
        raise AppApiException(5000, f"文档id不存在【{document_id}】")
    if target_document_id not in existing_ids:
        raise AppApiException(5000, f"目标文档id不存在【{target_document_id}】")
    return True
dataset_id = self.data.get('dataset_id') + target_dataset_id = self.data.get('target_dataset_id') + document_id = self.data.get('document_id') + target_document_id = self.data.get('target_document_id') + paragraph_id_list = self.data.get('paragraph_id_list') + paragraph_list = QuerySet(Paragraph).filter(dataset_id=dataset_id, document_id=document_id, + id__in=paragraph_id_list) + problem_paragraph_mapping_list = QuerySet(ProblemParagraphMapping).filter(paragraph__in=paragraph_list) + # 同数据集迁移 + if target_dataset_id == dataset_id: + if len(problem_paragraph_mapping_list): + problem_paragraph_mapping_list = [ + self.update_problem_paragraph_mapping(target_document_id, + problem_paragraph_mapping) for problem_paragraph_mapping + in + problem_paragraph_mapping_list] + # 修改mapping + QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list, + ['document_id']) + update_embedding_document_id([paragraph.id for paragraph in paragraph_list], + target_document_id, target_dataset_id, None) + # 修改段落信息 + paragraph_list.update(document_id=target_document_id) + # 不同数据集迁移 + else: + problem_list = QuerySet(Problem).filter( + id__in=[problem_paragraph_mapping.problem_id for problem_paragraph_mapping in + problem_paragraph_mapping_list]) + # 目标数据集问题 + target_problem_list = list( + QuerySet(Problem).filter(content__in=[problem.content for problem in problem_list], + dataset_id=target_dataset_id)) + + target_handle_problem_list = [ + self.get_target_dataset_problem(target_dataset_id, target_document_id, problem_paragraph_mapping, + problem_list, target_problem_list) for + problem_paragraph_mapping + in + problem_paragraph_mapping_list] + + create_problem_list = [problem for problem, is_create in target_handle_problem_list if + is_create is not None and is_create] + # 插入问题 + QuerySet(Problem).bulk_create(create_problem_list) + # 修改mapping + QuerySet(ProblemParagraphMapping).bulk_update(problem_paragraph_mapping_list, + ['problem_id', 'dataset_id', 'document_id']) + 
def get_target_dataset_problem(target_dataset_id: str,
                               target_document_id: str,
                               problem_paragraph_mapping,
                               source_problem_list,
                               target_problem_list):
    """Re-point one problem/paragraph mapping at the target dataset.

    The mapping is mutated in place: its dataset and document ids are switched
    to the target, and its ``problem_id`` is switched to a problem of the
    target dataset that has the same content as the source problem.

    :param target_dataset_id: id of the dataset the paragraph is moved to.
    :param target_document_id: id of the document the paragraph is moved to.
    :param problem_paragraph_mapping: mapping row to update (mutated).
    :param source_problem_list: problems of the source dataset.
    :param target_problem_list: problems already known in the target dataset;
        a newly built problem is appended so later mappings can reuse it.
    :return: ``(problem, is_create)``. ``is_create`` is ``True`` when a new,
        unsaved ``Problem`` was built and ``False`` when an existing one was
        reused. ``(None, None)`` when the source problem could not be found;
        the old bare ``None`` broke callers that unpack the result.
    """
    problem_paragraph_mapping.dataset_id = target_dataset_id
    problem_paragraph_mapping.document_id = target_document_id

    matching_sources = [source_problem for source_problem in source_problem_list if
                        source_problem.id == problem_paragraph_mapping.problem_id]
    if not matching_sources:
        # Always hand back a 2-tuple so `for problem, is_create in ...` works.
        return None, None

    problem_content = matching_sources[-1].content
    same_content = [problem for problem in target_problem_list if problem.content == problem_content]
    if same_content:
        problem = same_content[-1]
        problem_paragraph_mapping.problem_id = problem.id
        return problem, False

    problem = Problem(id=uuid.uuid1(), dataset_id=target_dataset_id, content=problem_content)
    target_problem_list.append(problem)
    problem_paragraph_mapping.problem_id = problem.id
    return problem, True
def post_embedding(paragraph, instance, dataset_id):
    """Bring the paragraph's embedding in line with the edit just applied.

    When the edit payload carries a non-null ``is_active`` flag the embedding
    is only switched on or off; otherwise the paragraph is embedded again with
    the dataset's embedding model.

    :param paragraph: paragraph data as a mapping; only ``id`` is read.
    :param instance: the edit payload.
    :param dataset_id: dataset used to look up the embedding model.
    :return: ``paragraph``, unchanged.
    """
    active_flag = instance.get('is_active') if 'is_active' in instance else None
    if active_flag is None:
        model_id = get_embedding_model_id_by_dataset_id(dataset_id)
        embedding_by_paragraph(paragraph.get('id'), model_id)
    elif active_flag:
        enable_embedding_by_paragraph(paragraph.get('id'))
    else:
        disable_embedding_by_paragraph(paragraph.get('id'))
    return paragraph
def get_problem_list(self):
    """Return the serialized problems linked to this paragraph.

    The problem ids are read from the problem/paragraph mapping rows of
    ``paragraph_id`` and the matching problems are serialized with
    ``ProblemSerializer``.

    :return: list of serialized problems; empty when nothing is linked.
    """
    # The former stray `ProblemParagraphMapping(ProblemParagraphMapping)` call
    # built a throw-away model instance and has been removed.
    problem_ids = [mapping.problem_id for mapping in
                   QuerySet(ProblemParagraphMapping).filter(paragraph_id=self.data.get("paragraph_id"))]
    if not problem_ids:
        return []
    return [ProblemSerializer(problem).data for problem in
            QuerySet(Problem).filter(id__in=problem_ids)]
def or_get(exists_problem_list, content, dataset_id):
    """Return the first known problem with this content, or build a new one.

    :param exists_problem_list: problems to search, in priority order.
    :param content: problem text to look for.
    :param dataset_id: dataset assigned to a newly built problem.
    :return: the first element whose ``content`` equals ``content``; otherwise
        a new, unsaved ``Problem`` with a fresh uuid1 id.
    """
    for candidate in exists_problem_list:
        if candidate.content == content:
            return candidate
    return Problem(id=uuid.uuid1(), content=content, dataset_id=dataset_id)
def get_query_set(self):
    """Build the paragraph query for the current dataset and document.

    The query is always restricted to ``dataset_id`` and ``document_id``.
    ``title`` and ``content`` add a case-insensitive "contains" filter each,
    but only when the key is present in the validated data.

    :return: the filtered paragraph query set.
    """
    data = self.data
    paragraphs = QuerySet(model=Paragraph).filter(dataset_id=data.get('dataset_id'),
                                                  document_id=data.get("document_id"))
    # Optional text filters, applied in the same order as before: title, content.
    for field_name in ('title', 'content'):
        if field_name in data:
            paragraphs = paragraphs.filter(**{f'{field_name}__icontains': data.get(field_name)})
    return paragraphs
def batch_generate_related(self, instance: Dict, with_valid=True):
    """Queue the background job that generates related problems for paragraphs.

    :param instance: payload; ``paragraph_id_list``, ``model_id`` and
        ``prompt`` are read from it and handed to the task unchanged.
    :param with_valid: when true, validate this serializer first.
    :raises AppApiException: code 500 when the task layer reports that the
        same job is already queued.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    task_args = (instance.get("paragraph_id_list"), instance.get("model_id"), instance.get("prompt"))
    try:
        generate_related_by_paragraph_id_list.delay(*task_args)
    except AlreadyQueued:
        raise AppApiException(500, "任务正在执行中,请勿重复下发")
class ProblemInstanceSerializer(ApiMixin, serializers.Serializer):
    """Payload of a single problem: an optional id plus the problem text."""

    # Optional: present when an existing problem is referenced.
    id = serializers.CharField(required=False, error_messages=ErrMessage.char("问题id"))

    # Required problem text, at most 256 characters.
    content = serializers.CharField(required=True, max_length=256, error_messages=ErrMessage.char("问题内容"))

    @staticmethod
    def get_request_body_api():
        """Describe the request body for the generated API documentation."""
        id_schema = openapi.Schema(type=openapi.TYPE_STRING,
                                   title="问题id,修改的时候传递,创建的时候不传")
        content_schema = openapi.Schema(type=openapi.TYPE_STRING, title="内容")
        return openapi.Schema(type=openapi.TYPE_OBJECT,
                              required=["content"],
                              properties={'id': id_schema, 'content': content_schema})
def is_exits(exits_problem_paragraph_mapping_list, new_paragraph_mapping):
    """Tell whether an equivalent problem/paragraph mapping already exists.

    Two mappings are equivalent when paragraph id, problem id and dataset id
    all match. Both sides are compared in string form, so it does not matter
    whether an id is held as a ``UUID`` or as its string representation.
    Previously only the existing side was converted, so a new mapping carrying
    ``UUID`` objects never matched.

    :param exits_problem_paragraph_mapping_list: mappings already stored.
    :param new_paragraph_mapping: candidate mapping about to be created.
    :return: ``True`` when a matching mapping is found, else ``False``.
    """
    wanted = (str(new_paragraph_mapping.paragraph_id),
              str(new_paragraph_mapping.problem_id),
              str(new_paragraph_mapping.dataset_id))
    return any((str(existing.paragraph_id), str(existing.problem_id), str(existing.dataset_id)) == wanted
               for existing in exits_problem_paragraph_mapping_list)
def delete(self, problem_id_list: List, with_valid=True):
    """Delete several problems of this dataset with their links and embeddings.

    :param problem_id_list: ids of the problems to remove.
    :param with_valid: when true, validate this serializer first.
    :return: ``True`` after the rows were deleted and the embedding clean-up
        was requested.
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    dataset_id = self.data.get('dataset_id')
    # Mapping rows and problems are removed together or not at all, matching
    # the transactional single-problem delete.
    with transaction.atomic():
        mapping_query_set = QuerySet(ProblemParagraphMapping).filter(
            dataset_id=dataset_id,
            problem_id__in=problem_id_list)
        # Collect the ids first: the embeddings are keyed by mapping id.
        source_ids = [row.id for row in mapping_query_set]
        mapping_query_set.delete()
        # Scope to the dataset so ids belonging to another dataset are never
        # deleted through this endpoint.
        QuerySet(Problem).filter(id__in=problem_id_list, dataset_id=dataset_id).delete()
    delete_embedding_by_source_ids(source_ids)
    return True
def edit(self, instance: Dict, with_valid=True):
    """Change the text of a problem and refresh its embedding.

    :param instance: payload; only ``content`` is read.
    :param with_valid: when true, validate this serializer first.
    :raises serializers.ValidationError: when no problem with the given id
        exists in the dataset (previously an ``AttributeError`` on ``None``).
    """
    if with_valid:
        self.is_valid(raise_exception=True)
    problem_id = self.data.get('problem_id')
    dataset_id = self.data.get('dataset_id')
    content = instance.get('content')
    problem = QuerySet(Problem).filter(id=problem_id,
                                       dataset_id=dataset_id).first()
    if problem is None:
        # first() yields None for an empty result; fail with a clear error.
        raise serializers.ValidationError("问题不存在")
    # The former `QuerySet(DataSet).filter(id=dataset_id)` statement built a
    # query that was never evaluated and has been dropped.
    problem.content = content
    problem.save()
    model_id = get_embedding_model_id_by_dataset_id(dataset_id)
    update_problem_embedding(problem_id, content, model_id)
a/src/MaxKB-1.7.2/apps/dataset/sql/list_dataset.sql b/src/MaxKB-1.7.2/apps/dataset/sql/list_dataset.sql new file mode 100644 index 0000000..8f62034 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/sql/list_dataset.sql @@ -0,0 +1,35 @@ +SELECT + *, + to_json(meta) as meta +FROM + ( + SELECT + "temp_dataset".*, + "document_temp"."char_length", + CASE + WHEN + "app_dataset_temp"."count" IS NULL THEN 0 ELSE "app_dataset_temp"."count" END AS application_mapping_count, + "document_temp".document_count FROM ( + SELECT dataset.* + FROM + dataset dataset + ${dataset_custom_sql} + UNION + SELECT + * + FROM + dataset + WHERE + dataset."id" IN ( + SELECT + team_member_permission.target + FROM + team_member team_member + LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id" + ${team_member_permission_custom_sql} + ) + ) temp_dataset + LEFT JOIN ( SELECT "count" ( "id" ) AS document_count, "sum" ( "char_length" ) "char_length", dataset_id FROM "document" GROUP BY dataset_id ) "document_temp" ON temp_dataset."id" = "document_temp".dataset_id + LEFT JOIN (SELECT "count"("id"),dataset_id FROM application_dataset_mapping GROUP BY dataset_id) app_dataset_temp ON temp_dataset."id" = "app_dataset_temp".dataset_id + ) temp + ${default_sql} \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/dataset/sql/list_dataset_application.sql b/src/MaxKB-1.7.2/apps/dataset/sql/list_dataset_application.sql new file mode 100644 index 0000000..9da36a3 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/sql/list_dataset_application.sql @@ -0,0 +1,20 @@ +SELECT + * +FROM + application +WHERE + user_id = %s UNION +SELECT + * +FROM + application +WHERE + "id" IN ( + SELECT + team_member_permission.target + FROM + team_member team_member + LEFT JOIN team_member_permission team_member_permission ON team_member_permission.member_id = team_member."id" + WHERE + ( "team_member_permission"."auth_target_type" = 'APPLICATION' AND 
"team_member_permission"."operate"::text[] @> ARRAY['USE'] AND team_member.team_id = %s AND team_member.user_id =%s ) + ) \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/dataset/sql/list_document.sql b/src/MaxKB-1.7.2/apps/dataset/sql/list_document.sql new file mode 100644 index 0000000..818d783 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/sql/list_document.sql @@ -0,0 +1,6 @@ +SELECT + "document".* , + to_json("document"."meta") as meta, + (SELECT "count"("id") FROM "paragraph" WHERE document_id="document"."id") as "paragraph_count" +FROM + "document" "document" diff --git a/src/MaxKB-1.7.2/apps/dataset/sql/list_paragraph.sql b/src/MaxKB-1.7.2/apps/dataset/sql/list_paragraph.sql new file mode 100644 index 0000000..2256f3f --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/sql/list_paragraph.sql @@ -0,0 +1,6 @@ +SELECT + (SELECT "name" FROM "document" WHERE "id"=document_id) as document_name, + (SELECT "name" FROM "dataset" WHERE "id"=dataset_id) as dataset_name, + * +FROM + "paragraph" diff --git a/src/MaxKB-1.7.2/apps/dataset/sql/list_paragraph_document_name.sql b/src/MaxKB-1.7.2/apps/dataset/sql/list_paragraph_document_name.sql new file mode 100644 index 0000000..a95209b --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/sql/list_paragraph_document_name.sql @@ -0,0 +1,5 @@ +SELECT + (SELECT "name" FROM "document" WHERE "id"=document_id) as document_name, + * +FROM + "paragraph" diff --git a/src/MaxKB-1.7.2/apps/dataset/sql/list_problem.sql b/src/MaxKB-1.7.2/apps/dataset/sql/list_problem.sql new file mode 100644 index 0000000..affb513 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/sql/list_problem.sql @@ -0,0 +1,5 @@ +SELECT + problem.*, + (SELECT "count"("id") FROM "problem_paragraph_mapping" WHERE problem_id="problem"."id") as "paragraph_count" + FROM + problem problem diff --git a/src/MaxKB-1.7.2/apps/dataset/sql/list_problem_mapping.sql b/src/MaxKB-1.7.2/apps/dataset/sql/list_problem_mapping.sql new file mode 100644 index 0000000..8c8ac3c --- 
class DocumentApi(ApiMixin):
    """API-documentation schemas for document endpoints."""

    class BatchEditHitHandlingApi(ApiMixin):
        """Schema of the request that edits hit handling for several documents."""

        @staticmethod
        def get_request_body_api():
            """Describe the request body: id list, handling method, similarity."""
            id_list_schema = openapi.Schema(type=openapi.TYPE_ARRAY,
                                            items=openapi.Schema(type=openapi.TYPE_STRING),
                                            title="主键id列表",
                                            description="主键id列表")
            method_schema = openapi.Schema(type=openapi.TYPE_STRING, title="命中处理方式",
                                           description="directly_return|optimization")
            similarity_schema = openapi.Schema(type=openapi.TYPE_NUMBER, title="直接返回相似度")
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                properties={
                    'id_list': id_list_schema,
                    'hit_handling_method': method_schema,
                    'directly_return_similarity': similarity_schema,
                }
            )
b/src/MaxKB-1.7.2/apps/dataset/swagger_api/image_api.py new file mode 100644 index 0000000..f69b947 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/swagger_api/image_api.py @@ -0,0 +1,22 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: image_api.py + @date:2024/4/23 11:23 + @desc: +""" +from drf_yasg import openapi + +from common.mixins.api_mixin import ApiMixin + + +class ImageApi(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='file', + in_=openapi.IN_FORM, + type=openapi.TYPE_FILE, + required=True, + description='上传图片文件') + ] diff --git a/src/MaxKB-1.7.2/apps/dataset/swagger_api/problem_api.py b/src/MaxKB-1.7.2/apps/dataset/swagger_api/problem_api.py new file mode 100644 index 0000000..7932e0c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/swagger_api/problem_api.py @@ -0,0 +1,176 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: problem_api.py + @date:2024/3/11 10:49 + @desc: +""" +from drf_yasg import openapi + +from common.mixins.api_mixin import ApiMixin + + +class ProblemApi(ApiMixin): + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'content', 'hit_num', 'dataset_id', 'create_time', 'update_time'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="id", + description="id", default="xx"), + 'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容", + description="问题内容", default='问题内容'), + 'hit_num': openapi.Schema(type=openapi.TYPE_INTEGER, title="命中数量", description="命中数量", + default=1), + 'dataset_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库id", + description="知识库id", default='xxx'), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ) + } + ) + + class 
BatchAssociation(ApiMixin): + @staticmethod + def get_request_params_api(): + return ProblemApi.BatchOperate.get_request_params_api() + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['problem_id_list'], + properties={ + 'problem_id_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="问题id列表", + description="问题id列表", + items=openapi.Schema(type=openapi.TYPE_STRING)), + 'paragraph_list': openapi.Schema(type=openapi.TYPE_ARRAY, title="关联段落信息列表", + description="关联段落信息列表", + items=openapi.Schema(type=openapi.TYPE_OBJECT, + required=['paragraph_id', 'document_id'], + properties={ + 'paragraph_id': openapi.Schema( + type=openapi.TYPE_STRING, + title="段落id"), + 'document_id': openapi.Schema( + type=openapi.TYPE_STRING, + title="文档id") + })) + + } + ) + + class BatchOperate(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id'), + ] + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + title="问题id列表", + description="问题id列表", + type=openapi.TYPE_ARRAY, + items=openapi.Schema(type=openapi.TYPE_STRING) + ) + + class Operate(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id'), + openapi.Parameter(name='problem_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='问题id')] + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['content'], + properties={ + 'content': openapi.Schema(type=openapi.TYPE_STRING, title="问题内容", + description="问题内容"), + + } + ) + + class Paragraph(ApiMixin): + @staticmethod + def get_request_params_api(): + return ProblemApi.Operate.get_request_params_api() + + @staticmethod + def 
get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['content'], + properties={ + 'content': openapi.Schema(type=openapi.TYPE_STRING, max_length=4096, title="分段内容", + description="分段内容"), + 'title': openapi.Schema(type=openapi.TYPE_STRING, max_length=256, title="分段标题", + description="分段标题"), + 'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"), + 'hit_num': openapi.Schema(type=openapi.TYPE_NUMBER, title="命中次数", description="命中次数"), + 'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", + description="修改时间", + default="1970-01-01 00:00:00"), + 'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", + description="创建时间", + default="1970-01-01 00:00:00" + ), + } + ) + + class Query(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id'), + openapi.Parameter(name='content', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='问题')] + + class BatchCreate(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema(type=openapi.TYPE_ARRAY, + items=ProblemApi.Create.get_request_body_api()) + + @staticmethod + def get_request_params_api(): + return ProblemApi.Create.get_request_params_api() + + class Create(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema(type=openapi.TYPE_STRING, description="问题文本") + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='dataset_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='知识库id')] diff --git a/src/MaxKB-1.7.2/apps/dataset/task/__init__.py b/src/MaxKB-1.7.2/apps/dataset/task/__init__.py new file mode 100644 index 0000000..7bb1839 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/task/__init__.py @@ -0,0 +1,10 @@ +# coding=utf-8 +""" + @project: 
MaxKB
+    @Author:虎
+    @file: __init__.py
+    @date:2024/8/21 9:57
+    @desc:
+"""
+from .sync import *
+from .generate import *
diff --git a/src/MaxKB-1.7.2/apps/dataset/task/generate.py b/src/MaxKB-1.7.2/apps/dataset/task/generate.py
new file mode 100644
index 0000000..8604259
--- /dev/null
+++ b/src/MaxKB-1.7.2/apps/dataset/task/generate.py
@@ -0,0 +1,82 @@
+import logging
+import traceback
+from math import ceil
+
+from celery_once import QueueOnce
+from django.db.models import QuerySet
+from langchain_core.messages import HumanMessage
+
+from common.config.embedding_config import ModelManage
+from dataset.models import Paragraph, Document, Status
+from dataset.task.tools import save_problem
+from ops import celery_app
+from setting.models import Model
+from setting.models_provider import get_model
+
+max_kb_error = logging.getLogger("max_kb_error")
+max_kb = logging.getLogger("max_kb")
+
+# Number of paragraphs read from the database per page.
+PAGE_SIZE = 10
+
+
+def get_llm_model(model_id):
+    """Return the model instance that ``ModelManage.get_model`` yields for ``model_id``.
+
+    The ``Model`` row is fetched first (``None`` when no row has this id) and
+    handed to ``get_model`` inside the factory callback given to ``ModelManage``.
+    """
+    model = QuerySet(Model).filter(id=model_id).first()
+    return ModelManage.get_model(model_id, lambda _id: get_model(model))
+
+
+def _generate_problems(paragraph_query_set, llm_model, prompt):
+    """Invoke the LLM once per paragraph and save every returned line as a problem.
+
+    :param paragraph_query_set: ``Paragraph`` queryset selecting the paragraphs to process
+    :param llm_model: object whose ``invoke`` takes a message list and returns a result with ``content``
+    :param prompt: prompt text; each ``{data}`` occurrence is replaced by the paragraph content
+    """
+    # Offset slicing is only repeatable over an explicitly ordered queryset;
+    # without an ORDER BY, pages may overlap or skip rows.
+    paragraph_query_set = paragraph_query_set.order_by('id')
+    count = paragraph_query_set.count()
+    for page in range(ceil(count / PAGE_SIZE)):
+        offset = page * PAGE_SIZE
+        for paragraph in paragraph_query_set[offset:offset + PAGE_SIZE]:
+            res = llm_model.invoke([HumanMessage(content=prompt.replace('{data}', paragraph.content))])
+            if (res.content is None) or (len(res.content) == 0):
+                continue
+            # One candidate problem per output line; save_problem filters unusable lines.
+            for problem in res.content.split('\n'):
+                save_problem(paragraph.dataset_id, paragraph.document_id, paragraph.id, problem)
+
+
+@celery_app.task(base=QueueOnce, once={'keys': ['document_id']},
+                 name='celery:generate_related_by_document')
+def generate_related_by_document_id(document_id, model_id, prompt):
+    """Generate related problems for every paragraph of one document.
+
+    The document is marked ``Status.generating`` while the task runs and
+    ``Status.success`` when it finishes. If generation raises, the document is
+    marked ``Status.error`` (instead of staying in ``generating``) and the
+    exception is re-raised.
+    """
+    llm_model = get_llm_model(model_id)
+    QuerySet(Document).filter(id=document_id).update(status=Status.generating)
+    try:
+        _generate_problems(QuerySet(Paragraph).filter(document_id=document_id), llm_model, prompt)
+    except Exception as e:
+        max_kb_error.error(f'generate related problems for document {document_id} failed: '
+                           f'{str(e)}{traceback.format_exc()}')
+        QuerySet(Document).filter(id=document_id).update(status=Status.error)
+        raise
+    QuerySet(Document).filter(id=document_id).update(status=Status.success)
+
+
+@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']},
+                 name='celery:generate_related_by_paragraph_list')
+def generate_related_by_paragraph_id_list(paragraph_id_list, model_id, prompt):
+    """Generate related problems for the paragraphs whose ids are in ``paragraph_id_list``."""
+    llm_model = get_llm_model(model_id)
+    _generate_problems(QuerySet(Paragraph).filter(id__in=paragraph_id_list), llm_model, prompt)
diff --git a/src/MaxKB-1.7.2/apps/dataset/task/sync.py b/src/MaxKB-1.7.2/apps/dataset/task/sync.py
new file mode 100644
index 0000000..47c72d1
--- /dev/null
+++ b/src/MaxKB-1.7.2/apps/dataset/task/sync.py
@@ -0,0 +1,71 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: sync.py
+    @date:2024/8/20 21:37
+    @desc:
+"""
+
+import logging
+import traceback
+from typing import List
+
+from celery_once import QueueOnce
+
+from common.util.fork import ForkManage, Fork
+from dataset.task.tools import get_save_handler, get_sync_web_document_handler, get_sync_handler
+
+from ops import celery_app
+
+max_kb_error = logging.getLogger("max_kb_error")
+max_kb = logging.getLogger("max_kb")
+
+
+def _split_selector(selector):
+    """Split a space-separated selector string; ``None`` yields an empty list."""
+    return selector.split(" ") if selector is not None else []
+
+
+@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']}, name='celery:sync_web_dataset')
+def sync_web_dataset(dataset_id: str, url: str, selector: str):
+    """Fork ``url`` and hand every result to the handler built by ``get_save_handler``.
+
+    Exceptions are logged to ``max_kb_error`` and not re-raised.
+    """
+    try:
+        max_kb.info(f"开始--->开始同步web知识库:{dataset_id}")
+        # NOTE(review): 2 and set() are presumably the crawl depth and the set of
+        # already-visited links - confirm against common.util.fork.ForkManage.
+        ForkManage(url, _split_selector(selector)).fork(2, set(), get_save_handler(dataset_id, selector))
+        max_kb.info(f"结束--->结束同步web知识库:{dataset_id}")
+    except Exception as e:
+        max_kb_error.error(f'同步web知识库:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
+
+
+# Registered under its own name: this task previously reused 'celery:sync_web_dataset',
+# and Celery requires every task name to be unique.
+@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']}, name='celery:sync_replace_web_dataset')
+def sync_replace_web_dataset(dataset_id: str, url: str, selector: str):
+    """Fork ``url`` and hand every result to the handler built by ``get_sync_handler``.
+
+    Exceptions are logged to ``max_kb_error`` and not re-raised.
+    """
+    try:
+        max_kb.info(f"开始--->开始同步web知识库:{dataset_id}")
+        ForkManage(url, _split_selector(selector)).fork(2, set(), get_sync_handler(dataset_id))
+        max_kb.info(f"结束--->结束同步web知识库:{dataset_id}")
+    except Exception as e:
+        max_kb_error.error(f'同步web知识库:{dataset_id}出现错误{str(e)}{traceback.format_exc()}')
+
+
+@celery_app.task(name='celery:sync_web_document')
+def sync_web_document(dataset_id, source_url_list: List[str], selector: str):
+    """Run ``Fork`` on each URL in ``source_url_list`` and pass the response to the web-document handler.
+
+    A ``None`` selector is treated as "no selectors", matching the dataset-level tasks.
+    """
+    handler = get_sync_web_document_handler(dataset_id)
+    for source_url in source_url_list:
+        result = Fork(base_fork_url=source_url, selector_list=_split_selector(selector)).fork()
+        handler(source_url, selector, result)
diff --git a/src/MaxKB-1.7.2/apps/dataset/task/tools.py b/src/MaxKB-1.7.2/apps/dataset/task/tools.py
new file mode 100644
index 0000000..9838a75
--- /dev/null
+++ b/src/MaxKB-1.7.2/apps/dataset/task/tools.py
@@ -0,0 +1,136 @@
+# coding=utf-8
+"""
+    @project: MaxKB
+    @Author:虎
+    @file: tools.py
+    @date:2024/8/20 21:48
+    @desc:
+"""
+
+import logging
+import re
+import traceback
+
+from django.db.models import QuerySet
+
+from common.util.fork import ChildLink, Fork
+from common.util.split_model import get_split_model
+from dataset.models import Type, Document, DataSet, Status
+
+max_kb_error = logging.getLogger("max_kb_error")
+max_kb = logging.getLogger("max_kb")
+
+# Leading list numbering such as "1. " at the start of a generated line.
+_NUMBERING_PATTERN = re.compile(r"^\d+\.\s*")
+# NOTE(review): the pattern previously read r"(.*?)", which always matches the empty
+# string, so save_problem discarded every line. Its delimiters look stripped, like the
+# <...> converters in urls.py; <question>...</question> is assumed - confirm against
+# the prompt template that produces the LLM output.
+_QUESTION_PATTERN = re.compile(r"<question>(.*?)</question>")
+
+
+def _document_name(child_link: ChildLink):
+    """Return the link's tag text when it is non-blank, otherwise the link URL."""
+    if child_link.tag is not None and len(child_link.tag.text.strip()) > 0:
+        return child_link.tag.text
+    return child_link.url
+
+
+def get_save_handler(dataset_id, selector):
+    """Build a callback that saves each status-200 response as a new web document."""
+    from dataset.serializers.document_serializers import DocumentSerializers
+
+    def handler(child_link: ChildLink, response: Fork.Response):
+        if response.status == 200:
+            try:
+                document_name = _document_name(child_link)
+                paragraphs = get_split_model('web.md').parse(response.content)
+                DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
+                    {'name': document_name, 'paragraphs': paragraphs,
+                     'meta': {'source_url': child_link.url, 'selector': selector},
+                     'type': Type.web}, with_valid=True)
+            except Exception as e:
+                max_kb_error.error(f'{str(e)}:{traceback.format_exc()}')
+
+    return handler
+
+
+def get_sync_handler(dataset_id):
+    """Build a callback that re-syncs the document stored for a link's URL, or creates one.
+
+    The dataset row is loaded once, when the handler is built.
+    """
+    from dataset.serializers.document_serializers import DocumentSerializers
+    dataset = QuerySet(DataSet).filter(id=dataset_id).first()
+
+    def handler(child_link: ChildLink, response: Fork.Response):
+        if response.status == 200:
+            try:
+                document_name = _document_name(child_link)
+                paragraphs = get_split_model('web.md').parse(response.content)
+                first = QuerySet(Document).filter(meta__source_url=child_link.url.strip(),
+                                                  dataset=dataset).first()
+                if first is not None:
+                    # A document with this source URL already exists: use document sync.
+                    DocumentSerializers.Sync(data={'document_id': first.id}).sync()
+                else:
+                    # No document for this URL yet: insert a new one.
+                    DocumentSerializers.Create(data={'dataset_id': dataset.id}).save(
+                        {'name': document_name, 'paragraphs': paragraphs,
+                         'meta': {'source_url': child_link.url.strip(), 'selector': dataset.meta.get('selector')},
+                         'type': Type.web}, with_valid=True)
+            except Exception as e:
+                max_kb_error.error(f'{str(e)}:{traceback.format_exc()}')
+
+    return handler
+
+
+def get_sync_web_document_handler(dataset_id):
+    """Build a callback that saves a status-200 response as a new web document.
+
+    Any other status is recorded as a placeholder document with ``Status.error``.
+    """
+    from dataset.serializers.document_serializers import DocumentSerializers
+
+    def handler(source_url: str, selector, response: Fork.Response):
+        if response.status == 200:
+            try:
+                paragraphs = get_split_model('web.md').parse(response.content)
+                # Insert the page as a new document.
+                DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(
+                    {'name': source_url[0:128], 'paragraphs': paragraphs,
+                     'meta': {'source_url': source_url, 'selector': selector},
+                     'type': Type.web}, with_valid=True)
+            except Exception as e:
+                max_kb_error.error(f'{str(e)}:{traceback.format_exc()}')
+        else:
+            # The fetch did not return 200: keep a record of the URL in error state.
+            Document(name=source_url[0:128],
+                     dataset_id=dataset_id,
+                     meta={'source_url': source_url, 'selector': selector},
+                     type=Type.web,
+                     char_length=0,
+                     status=Status.error).save()
+
+    return handler
+
+
+def save_problem(dataset_id, document_id, paragraph_id, problem):
+    """Clean one line of LLM output and attach it to the paragraph as a problem.
+
+    Leading "N. " numbering is removed and only the text inside the question
+    delimiters is kept; lines without such text are ignored. Save failures are
+    logged to ``max_kb_error`` and not raised.
+    """
+    from dataset.serializers.paragraph_serializers import ParagraphSerializers
+    problem = _NUMBERING_PATTERN.sub("", problem)
+    match = _QUESTION_PATTERN.search(problem)
+    problem = match.group(1) if match else None
+    if problem is None or len(problem) == 0:
+        return
+    try:
+        ParagraphSerializers.Problem(
+            data={"dataset_id": dataset_id, 'document_id': document_id,
+                  'paragraph_id': paragraph_id}).save(instance={"content": problem}, with_valid=True)
+    except Exception as e:
+        max_kb_error.error(f'关联问题失败: {e}')
diff --git a/src/MaxKB-1.7.2/apps/dataset/template/MaxKB表格模板.csv b/src/MaxKB-1.7.2/apps/dataset/template/MaxKB表格模板.csv
new file mode 100644
index 0000000..7cf0f63
--- /dev/null
+++ b/src/MaxKB-1.7.2/apps/dataset/template/MaxKB表格模板.csv
@@ -0,0 +1,13 @@
+职务,报销类型,一线城市报销标准(元),二线城市报销标准(元),三线城市报销标准(元)
+普通员工,住宿费,500,400,300
+部门主管,住宿费,600,500,400
+部门总监,住宿费,700,600,500
+区域总经理,住宿费,800,700,600
+普通员工,伙食费,50,40,30
+部门主管,伙食费,50,40,30
+部门总监,伙食费,50,40,30
+区域总经理,伙食费,50,40,30
+普通员工,交通费,50,40,30
+部门主管,交通费,50,40,30
+部门总监,交通费,50,40,30
+区域总经理,交通费,50,40,30
diff --git a/src/MaxKB-1.7.2/apps/dataset/template/MaxKB表格模板.xlsx b/src/MaxKB-1.7.2/apps/dataset/template/MaxKB表格模板.xlsx
new file mode 100644
index 0000000..2bc94a5
Binary files /dev/null and b/src/MaxKB-1.7.2/apps/dataset/template/MaxKB表格模板.xlsx differ
diff --git 
a/src/MaxKB-1.7.2/apps/dataset/template/csv_template.csv b/src/MaxKB-1.7.2/apps/dataset/template/csv_template.csv new file mode 100644 index 0000000..b306a9c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/template/csv_template.csv @@ -0,0 +1,8 @@ +分段标题(选填),分段内容(必填,问题答案,最长不超过4096个字符)),问题(选填,单元格内一行一个) +MaxKB产品介绍,"MaxKB 是一款基于 LLM 大语言模型的知识库问答系统。MaxKB = Max Knowledge Base,旨在成为企业的最强大脑。 +开箱即用:支持直接上传文档、自动爬取在线文档,支持文本自动拆分、向量化,智能问答交互体验好; +无缝嵌入:支持零编码快速嵌入到第三方业务系统; +多模型支持:支持对接主流的大模型,包括 Ollama 本地私有大模型(如 Llama 2、Llama 3、qwen)、通义千问、OpenAI、Azure OpenAI、Kimi、智谱 AI、讯飞星火和百度千帆大模型等。","MaxKB是什么? +MaxKB产品介绍 +MaxKB支持的大语言模型 +MaxKB优势" diff --git a/src/MaxKB-1.7.2/apps/dataset/template/excel_template.xlsx b/src/MaxKB-1.7.2/apps/dataset/template/excel_template.xlsx new file mode 100644 index 0000000..6517b15 Binary files /dev/null and b/src/MaxKB-1.7.2/apps/dataset/template/excel_template.xlsx differ diff --git a/src/MaxKB-1.7.2/apps/dataset/tests.py b/src/MaxKB-1.7.2/apps/dataset/tests.py new file mode 100644 index 0000000..7ce503c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/src/MaxKB-1.7.2/apps/dataset/urls.py b/src/MaxKB-1.7.2/apps/dataset/urls.py new file mode 100644 index 0000000..b224635 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/urls.py @@ -0,0 +1,68 @@ +from django.urls import path + +from . 
import views + +app_name = "dataset" +urlpatterns = [ + path('dataset', views.Dataset.as_view(), name="dataset"), + path('dataset/web', views.Dataset.CreateWebDataset.as_view()), + path('dataset/qa', views.Dataset.CreateQADataset.as_view()), + path('dataset/', views.Dataset.Operate.as_view(), name="dataset_key"), + path('dataset//export', views.Dataset.Export.as_view(), name="export"), + path('dataset//re_embedding', views.Dataset.Embedding.as_view(), name="dataset_key"), + path('dataset//application', views.Dataset.Application.as_view()), + path('dataset//', views.Dataset.Page.as_view(), name="dataset"), + path('dataset//sync_web', views.Dataset.SyncWeb.as_view()), + path('dataset//hit_test', views.Dataset.HitTest.as_view()), + path('dataset//document', views.Document.as_view(), name='document'), + path('dataset//model', views.Dataset.Model.as_view()), + path('dataset/document/template/export', views.Template.as_view()), + path('dataset/document/table_template/export', views.TableTemplate.as_view()), + path('dataset//document/web', views.WebDocument.as_view()), + path('dataset//document/qa', views.QaDocument.as_view()), + path('dataset//document/table', views.TableDocument.as_view()), + path('dataset//document/_bach', views.Document.Batch.as_view()), + path('dataset//document/batch_hit_handling', views.Document.BatchEditHitHandling.as_view()), + path('dataset//document//', views.Document.Page.as_view()), + path('dataset//document/batch_refresh', views.Document.BatchRefresh.as_view()), + path('dataset//document/batch_generate_related', views.Document.BatchGenerateRelated.as_view()), + path('dataset//document/', views.Document.Operate.as_view(), + name="document_operate"), + path('dataset/document/split', views.Document.Split.as_view(), + name="document_operate"), + path('dataset/document/split_pattern', views.Document.SplitPattern.as_view(), + name="document_operate"), + path('dataset//document/migrate/', views.Document.Migrate.as_view()), + 
path('dataset//document//export', views.Document.Export.as_view(), + name="document_export"), + path('dataset//document//sync', views.Document.SyncWeb.as_view()), + path('dataset//document//refresh', views.Document.Refresh.as_view()), + path('dataset//document//paragraph', views.Paragraph.as_view()), + path('dataset//document/batch_generate_related', views.Document.BatchGenerateRelated.as_view()), + path( + 'dataset//document//paragraph/migrate/dataset//document/', + views.Paragraph.BatchMigrate.as_view()), + path('dataset//document//paragraph/_batch', views.Paragraph.Batch.as_view()), + path('dataset//document//paragraph//', + views.Paragraph.Page.as_view(), name='paragraph_page'), + path('dataset//document//paragraph/batch_generate_related', views.Paragraph.BatchGenerateRelated.as_view()), + path('dataset//document//paragraph/', + views.Paragraph.Operate.as_view()), + path('dataset//document//paragraph//problem', + views.Paragraph.Problem.as_view()), + path( + 'dataset//document//paragraph//problem//un_association', + views.Paragraph.Problem.UnAssociation.as_view()), + path( + 'dataset//document//paragraph//problem//association', + views.Paragraph.Problem.Association.as_view()), + path('dataset//problem', views.Problem.as_view()), + path('dataset//problem/_batch', views.Problem.OperateBatch.as_view()), + path('dataset//problem//', views.Problem.Page.as_view()), + path('dataset//problem/', views.Problem.Operate.as_view()), + path('dataset//problem//paragraph', views.Problem.Paragraph.as_view()), + path('image/', views.Image.Operate.as_view()), + path('image', views.Image.as_view()), + path('file/', views.FileView.Operate.as_view()), + path('file', views.FileView.as_view()) +] diff --git a/src/MaxKB-1.7.2/apps/dataset/views/__init__.py b/src/MaxKB-1.7.2/apps/dataset/views/__init__.py new file mode 100644 index 0000000..e434cec --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/views/__init__.py @@ -0,0 +1,14 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + 
@file: __init__.py + @date:2023/9/21 9:32 + @desc: +""" +from .dataset import * +from .document import * +from .paragraph import * +from .problem import * +from .image import * +from .file import * diff --git a/src/MaxKB-1.7.2/apps/dataset/views/dataset.py b/src/MaxKB-1.7.2/apps/dataset/views/dataset.py new file mode 100644 index 0000000..4bd9e1f --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/views/dataset.py @@ -0,0 +1,242 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: dataset.py + @date:2023/9/21 15:52 + @desc: +""" + +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser +from rest_framework.views import APIView +from rest_framework.views import Request + +from common.auth import TokenAuth, has_permissions +from common.constants.permission_constants import PermissionConstants, CompareConstants, Permission, Group, Operate, \ + ViewPermission, RoleConstants +from common.response import result +from common.response.result import get_page_request_params, get_page_api_response, get_api_response +from common.swagger_api.common_api import CommonApi +from dataset.serializers.dataset_serializers import DataSetSerializers +from setting.serializers.provider_serializers import ModelSerializer + + +class Dataset(APIView): + authentication_classes = [TokenAuth] + + class SyncWeb(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="同步Web站点知识库", + operation_id="同步Web站点知识库", + manual_parameters=DataSetSerializers.SyncWeb.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库"]) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=keywords.get('dataset_id'))], + compare=CompareConstants.AND), PermissionConstants.DATASET_EDIT, + 
compare=CompareConstants.AND) + def put(self, request: Request, dataset_id: str): + return result.success(DataSetSerializers.SyncWeb( + data={'sync_type': request.query_params.get('sync_type'), 'id': dataset_id, + 'user_id': str(request.user.id)}).sync()) + + class CreateQADataset(APIView): + authentication_classes = [TokenAuth] + parser_classes = [MultiPartParser] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建QA知识库", + operation_id="创建QA知识库", + manual_parameters=DataSetSerializers.Create.CreateQASerializers.get_request_params_api(), + responses=get_api_response( + DataSetSerializers.Create.CreateQASerializers.get_response_body_api()), + tags=["知识库"] + ) + @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND) + def post(self, request: Request): + return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save_qa({ + 'file_list': request.FILES.getlist('file'), + 'name': request.data.get('name'), + 'desc': request.data.get('desc') + })) + + class CreateWebDataset(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建web站点知识库", + operation_id="创建web站点知识库", + request_body=DataSetSerializers.Create.CreateWebSerializers.get_request_body_api(), + responses=get_api_response( + DataSetSerializers.Create.CreateWebSerializers.get_response_body_api()), + tags=["知识库"] + ) + @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND) + def post(self, request: Request): + return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save_web(request.data)) + + class Application(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取知识库可用应用列表", + operation_id="获取知识库可用应用列表", + manual_parameters=DataSetSerializers.Application.get_request_params_api(), + 
responses=result.get_api_array_response( + DataSetSerializers.Application.get_response_body_api()), + tags=["知识库"]) + def get(self, request: Request, dataset_id: str): + return result.success(DataSetSerializers.Operate( + data={'id': dataset_id, 'user_id': str(request.user.id)}).list_application()) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取知识库列表", + operation_id="获取知识库列表", + manual_parameters=DataSetSerializers.Query.get_request_params_api(), + responses=result.get_api_array_response(DataSetSerializers.Query.get_response_body_api()), + tags=["知识库"]) + @has_permissions(PermissionConstants.DATASET_READ, compare=CompareConstants.AND) + def get(self, request: Request): + d = DataSetSerializers.Query(data={**request.query_params, 'user_id': str(request.user.id)}) + d.is_valid() + return result.success(d.list()) + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建知识库", + operation_id="创建知识库", + request_body=DataSetSerializers.Create.get_request_body_api(), + responses=get_api_response(DataSetSerializers.Create.get_response_body_api()), + tags=["知识库"] + ) + @has_permissions(PermissionConstants.DATASET_CREATE, compare=CompareConstants.AND) + def post(self, request: Request): + return result.success(DataSetSerializers.Create(data={'user_id': request.user.id}).save(request.data)) + + class HitTest(APIView): + authentication_classes = [TokenAuth] + + @action(methods="GET", detail=False) + @swagger_auto_schema(operation_summary="命中测试列表", operation_id="命中测试列表", + manual_parameters=CommonApi.HitTestApi.get_request_params_api(), + responses=result.get_api_array_response(CommonApi.HitTestApi.get_response_body_api()), + tags=["知识库"]) + @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=keywords.get('dataset_id'))) + def get(self, request: Request, dataset_id: str): + return result.success( + DataSetSerializers.HitTest(data={'id': dataset_id, 
'user_id': request.user.id, + "query_text": request.query_params.get("query_text"), + "top_number": request.query_params.get("top_number"), + 'similarity': request.query_params.get('similarity'), + 'search_mode': request.query_params.get('search_mode')}).hit_test( + )) + + class Embedding(APIView): + authentication_classes = [TokenAuth] + + @action(methods="PUT", detail=False) + @swagger_auto_schema(operation_summary="重新向量化", operation_id="重新向量化", + manual_parameters=DataSetSerializers.Operate.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库"] + ) + @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=keywords.get('dataset_id'))) + def put(self, request: Request, dataset_id: str): + return result.success( + DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).re_embedding()) + + class Export(APIView): + authentication_classes = [TokenAuth] + + @action(methods="GET", detail=False) + @swagger_auto_schema(operation_summary="导出知识库", operation_id="导出知识库", + manual_parameters=DataSetSerializers.Operate.get_request_params_api(), + tags=["知识库"] + ) + @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=keywords.get('dataset_id'))) + def get(self, request: Request, dataset_id: str): + return DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).export_excel() + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods="DELETE", detail=False) + @swagger_auto_schema(operation_summary="删除知识库", operation_id="删除知识库", + manual_parameters=DataSetSerializers.Operate.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库"]) + @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=keywords.get('dataset_id')), + lambda r, k: Permission(group=Group.DATASET, operate=Operate.DELETE, + 
dynamic_tag=k.get('dataset_id')), compare=CompareConstants.AND) + def delete(self, request: Request, dataset_id: str): + operate = DataSetSerializers.Operate(data={'id': dataset_id}) + return result.success(operate.delete()) + + @action(methods="GET", detail=False) + @swagger_auto_schema(operation_summary="查询知识库详情根据知识库id", operation_id="查询知识库详情根据知识库id", + manual_parameters=DataSetSerializers.Operate.get_request_params_api(), + responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()), + tags=["知识库"]) + @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=keywords.get('dataset_id'))) + def get(self, request: Request, dataset_id: str): + return result.success(DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).one( + user_id=request.user.id)) + + @action(methods="PUT", detail=False) + @swagger_auto_schema(operation_summary="修改知识库信息", operation_id="修改知识库信息", + manual_parameters=DataSetSerializers.Operate.get_request_params_api(), + request_body=DataSetSerializers.Operate.get_request_body_api(), + responses=get_api_response(DataSetSerializers.Operate.get_response_body_api()), + tags=["知识库"] + ) + @has_permissions(lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=keywords.get('dataset_id'))) + def put(self, request: Request, dataset_id: str): + return result.success( + DataSetSerializers.Operate(data={'id': dataset_id, 'user_id': request.user.id}).edit(request.data, + user_id=request.user.id)) + + class Page(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取知识库分页列表", + operation_id="获取知识库分页列表", + manual_parameters=get_page_request_params( + DataSetSerializers.Query.get_request_params_api()), + responses=get_page_api_response(DataSetSerializers.Query.get_response_body_api()), + tags=["知识库"] + ) + @has_permissions(PermissionConstants.DATASET_READ, 
compare=CompareConstants.AND) + def get(self, request: Request, current_page, page_size): + d = DataSetSerializers.Query( + data={'name': request.query_params.get('name', None), 'desc': request.query_params.get("desc", None), + 'user_id': str(request.user.id)}) + d.is_valid() + return result.success(d.page(current_page, page_size)) + + class Model(APIView): + authentication_classes = [TokenAuth] + + @action(methods=["GET"], detail=False) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN, RoleConstants.USER], + [lambda r, keywords: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=keywords.get('dataset_id'))], + compare=CompareConstants.AND)) + def get(self, request: Request, dataset_id: str): + return result.success( + ModelSerializer.Query( + data={'user_id': request.user.id, 'model_type': 'LLM'}).list( + with_valid=True) + ) diff --git a/src/MaxKB-1.7.2/apps/dataset/views/document.py b/src/MaxKB-1.7.2/apps/dataset/views/document.py new file mode 100644 index 0000000..d911d0d --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/views/document.py @@ -0,0 +1,406 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: document.py + @date:2023/9/22 11:32 + @desc: +""" + +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser +from rest_framework.views import APIView +from rest_framework.views import Request + +from common.auth import TokenAuth, has_permissions +from common.constants.permission_constants import Permission, Group, Operate, CompareConstants +from common.response import result +from common.util.common import query_params_to_single_dict +from dataset.serializers.common_serializers import BatchSerializer +from dataset.serializers.document_serializers import DocumentSerializers, DocumentWebInstanceSerializer +from dataset.swagger_api.document_api import DocumentApi + + +class Template(APIView): + authentication_classes = [TokenAuth] + + 
@action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取QA模版", + operation_id="获取QA模版", + manual_parameters=DocumentSerializers.Export.get_request_params_api(), + tags=["知识库/文档"]) + def get(self, request: Request): + return DocumentSerializers.Export(data={'type': request.query_params.get('type')}).export(with_valid=True) + + +class TableTemplate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取表格模版", + operation_id="获取表格模版", + manual_parameters=DocumentSerializers.Export.get_request_params_api(), + tags=["知识库/文档"]) + def get(self, request: Request): + return DocumentSerializers.Export(data={'type': request.query_params.get('type')}).table_export(with_valid=True) + + +class WebDocument(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建Web站点文档", + operation_id="创建Web站点文档", + request_body=DocumentWebInstanceSerializer.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def post(self, request: Request, dataset_id: str): + return result.success( + DocumentSerializers.Create(data={'dataset_id': dataset_id}).save_web(request.data, with_valid=True)) + + +class QaDocument(APIView): + authentication_classes = [TokenAuth] + parser_classes = [MultiPartParser] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="导入QA并创建文档", + operation_id="导入QA并创建文档", + manual_parameters=DocumentWebInstanceSerializer.get_request_params_api(), + responses=result.get_api_response(DocumentSerializers.Create.get_response_body_api()), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: 
Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def post(self, request: Request, dataset_id: str): + return result.success( + DocumentSerializers.Create(data={'dataset_id': dataset_id}).save_qa( + {'file_list': request.FILES.getlist('file')}, + with_valid=True)) + + +class TableDocument(APIView): + authentication_classes = [TokenAuth] + parser_classes = [MultiPartParser] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="导入表格并创建文档", + operation_id="导入表格并创建文档", + manual_parameters=DocumentWebInstanceSerializer.get_request_params_api(), + responses=result.get_api_response(DocumentSerializers.Create.get_response_body_api()), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def post(self, request: Request, dataset_id: str): + return result.success( + DocumentSerializers.Create(data={'dataset_id': dataset_id}).save_table( + {'file_list': request.FILES.getlist('file')}, + with_valid=True)) + + +class Document(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建文档", + operation_id="创建文档", + request_body=DocumentSerializers.Create.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def post(self, request: Request, dataset_id: str): + return result.success( + DocumentSerializers.Create(data={'dataset_id': dataset_id}).save(request.data, with_valid=True)) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="文档列表", + operation_id="文档列表", + 
manual_parameters=DocumentSerializers.Query.get_request_params_api(), + responses=result.get_api_response(DocumentSerializers.Query.get_response_body_api()), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str): + d = DocumentSerializers.Query( + data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id}) + d.is_valid(raise_exception=True) + return result.success(d.list()) + + class BatchEditHitHandling(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="批量修改文档命中处理方式", + operation_id="批量修改文档命中处理方式", + request_body= + DocumentApi.BatchEditHitHandlingApi.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str): + return result.success( + DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_edit_hit_handling(request.data)) + + class Batch(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="批量创建文档", + operation_id="批量创建文档", + request_body= + DocumentSerializers.Batch.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_api_array_response( + DocumentSerializers.Operate.get_response_body_api()), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def post(self, request: Request, dataset_id: str): + return result.success(DocumentSerializers.Batch(data={'dataset_id': 
dataset_id}).batch_save(request.data)) + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="批量同步文档", + operation_id="批量同步文档", + request_body= + BatchSerializer.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str): + return result.success(DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_sync(request.data)) + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="批量删除文档", + operation_id="批量删除文档", + request_body= + BatchSerializer.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def delete(self, request: Request, dataset_id: str): + return result.success(DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_delete(request.data)) + + class SyncWeb(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="同步web站点类型", + operation_id="同步web站点类型", + manual_parameters=DocumentSerializers.Operate.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档"] + ) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str, document_id: str): + return result.success( + DocumentSerializers.Sync(data={'document_id': document_id, 'dataset_id': dataset_id}).sync( + )) + + class Refresh(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], 
detail=False) + @swagger_auto_schema(operation_summary="刷新文档向量库", + operation_id="刷新文档向量库", + manual_parameters=DocumentSerializers.Operate.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档"] + ) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str, document_id: str): + return result.success( + DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).refresh( + )) + + class BatchRefresh(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="批量刷新文档向量库", + operation_id="批量刷新文档向量库", + request_body= + DocumentApi.BatchEditHitHandlingApi.get_request_body_api(), + manual_parameters=DocumentSerializers.Create.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str): + return result.success( + DocumentSerializers.Batch(data={'dataset_id': dataset_id}).batch_refresh(request.data)) + + class Migrate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="批量迁移文档", + operation_id="批量迁移文档", + manual_parameters=DocumentSerializers.Migrate.get_request_params_api(), + request_body=DocumentSerializers.Migrate.get_request_body_api(), + responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()), + tags=["知识库/文档"] + ) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id')), + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('target_dataset_id')), + compare=CompareConstants.AND + ) + def put(self, 
request: Request, dataset_id: str, target_dataset_id: str): + return result.success( + DocumentSerializers.Migrate( + data={'dataset_id': dataset_id, 'target_dataset_id': target_dataset_id, + 'document_id_list': request.data}).migrate( + + )) + + class Export(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="导出文档", + operation_id="导出文档", + manual_parameters=DocumentSerializers.Operate.get_request_params_api(), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str, document_id: str): + return DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).export() + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取文档详情", + operation_id="获取文档详情", + manual_parameters=DocumentSerializers.Operate.get_request_params_api(), + responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str, document_id: str): + operate = DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}) + operate.is_valid(raise_exception=True) + return result.success(operate.one()) + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改文档", + operation_id="修改文档", + manual_parameters=DocumentSerializers.Operate.get_request_params_api(), + request_body=DocumentSerializers.Operate.get_request_body_api(), + responses=result.get_api_response(DocumentSerializers.Operate.get_response_body_api()), + tags=["知识库/文档"] + ) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, 
operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str, document_id: str): + return result.success( + DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}).edit( + request.data, + with_valid=True)) + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="删除文档", + operation_id="删除文档", + manual_parameters=DocumentSerializers.Operate.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def delete(self, request: Request, dataset_id: str, document_id: str): + operate = DocumentSerializers.Operate(data={'document_id': document_id, 'dataset_id': dataset_id}) + operate.is_valid(raise_exception=True) + return result.success(operate.delete()) + + class SplitPattern(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取分段标识列表", + operation_id="获取分段标识列表", + tags=["知识库/文档"]) + def get(self, request: Request): + return result.success(DocumentSerializers.SplitPattern.list()) + + class Split(APIView): + authentication_classes = [TokenAuth] + parser_classes = [MultiPartParser] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="分段文档", + operation_id="分段文档", + manual_parameters=DocumentSerializers.Split.get_request_params_api(), + tags=["知识库/文档"]) + def post(self, request: Request): + split_data = {'file': request.FILES.getlist('file')} + request_data = request.data + if 'patterns' in request.data and request.data.get('patterns') is not None and len( + request.data.get('patterns')) > 0: + split_data.__setitem__('patterns', request_data.getlist('patterns')) + if 'limit' in request.data: + split_data.__setitem__('limit', request_data.get('limit')) + if 'with_filter' in request.data: + 
split_data.__setitem__('with_filter', request_data.get('with_filter')) + ds = DocumentSerializers.Split( + data=split_data) + ds.is_valid(raise_exception=True) + return result.success(ds.parse()) + + class Page(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取知识库分页列表", + operation_id="获取知识库分页列表", + manual_parameters=DocumentSerializers.Query.get_request_params_api(), + responses=result.get_page_api_response(DocumentSerializers.Query.get_response_body_api()), + tags=["知识库/文档"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str, current_page, page_size): + d = DocumentSerializers.Query( + data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id}) + d.is_valid(raise_exception=True) + return result.success(d.page(current_page, page_size)) + + class BatchGenerateRelated(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str): + return result.success(DocumentSerializers.BatchGenerateRelated(data={'dataset_id': dataset_id}) + .batch_generate_related(request.data)) diff --git a/src/MaxKB-1.7.2/apps/dataset/views/file.py b/src/MaxKB-1.7.2/apps/dataset/views/file.py new file mode 100644 index 0000000..7ec437d --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/views/file.py @@ -0,0 +1,43 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: image.py + @date:2024/4/22 16:23 + @desc: +""" +from drf_yasg import openapi +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser +from rest_framework.views import APIView +from 
rest_framework.views import Request + +from common.auth import TokenAuth +from common.response import result +from dataset.serializers.file_serializers import FileSerializer + + +class FileView(APIView): + authentication_classes = [TokenAuth] + parser_classes = [MultiPartParser] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="上传文件", + operation_id="上传文件", + manual_parameters=[openapi.Parameter(name='file', + in_=openapi.IN_FORM, + type=openapi.TYPE_FILE, + required=True, + description='上传文件')], + tags=["文件"]) + def post(self, request: Request): + return result.success(FileSerializer(data={'file': request.FILES.get('file')}).upload()) + + class Operate(APIView): + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取图片", + operation_id="获取图片", + tags=["文件"]) + def get(self, request: Request, file_id: str): + return FileSerializer.Operate(data={'id': file_id}).get() diff --git a/src/MaxKB-1.7.2/apps/dataset/views/image.py b/src/MaxKB-1.7.2/apps/dataset/views/image.py new file mode 100644 index 0000000..124336f --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/views/image.py @@ -0,0 +1,43 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: image.py + @date:2024/4/22 16:23 + @desc: +""" +from drf_yasg import openapi +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.parsers import MultiPartParser +from rest_framework.views import APIView +from rest_framework.views import Request + +from common.auth import TokenAuth +from common.response import result +from dataset.serializers.image_serializers import ImageSerializer + + +class Image(APIView): + authentication_classes = [TokenAuth] + parser_classes = [MultiPartParser] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="上传图片", + operation_id="上传图片", + manual_parameters=[openapi.Parameter(name='file', + in_=openapi.IN_FORM, + type=openapi.TYPE_FILE, 
+ required=True, + description='上传文件')], + tags=["图片"]) + def post(self, request: Request): + return result.success(ImageSerializer(data={'image': request.FILES.get('file')}).upload()) + + class Operate(APIView): + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取图片", + operation_id="获取图片", + tags=["图片"]) + def get(self, request: Request, image_id: str): + return ImageSerializer.Operate(data={'id': image_id}).get() diff --git a/src/MaxKB-1.7.2/apps/dataset/views/paragraph.py b/src/MaxKB-1.7.2/apps/dataset/views/paragraph.py new file mode 100644 index 0000000..c1286c0 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/views/paragraph.py @@ -0,0 +1,246 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: paragraph_serializers.py + @date:2023/10/16 15:51 + @desc: +""" +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.views import APIView +from rest_framework.views import Request + +from common.auth import TokenAuth, has_permissions +from common.constants.permission_constants import Permission, Group, Operate, CompareConstants +from common.response import result +from common.util.common import query_params_to_single_dict +from dataset.serializers.common_serializers import BatchSerializer +from dataset.serializers.paragraph_serializers import ParagraphSerializers + + +class Paragraph(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="段落列表", + operation_id="段落列表", + manual_parameters=ParagraphSerializers.Query.get_request_params_api(), + responses=result.get_api_array_response(ParagraphSerializers.Query.get_response_body_api()), + tags=["知识库/文档/段落"] + ) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str, document_id: str): + q = ParagraphSerializers.Query( + 
data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id, + 'document_id': document_id}) + q.is_valid(raise_exception=True) + return result.success(q.list()) + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建段落", + operation_id="创建段落", + manual_parameters=ParagraphSerializers.Create.get_request_params_api(), + request_body=ParagraphSerializers.Create.get_request_body_api(), + responses=result.get_api_response(ParagraphSerializers.Query.get_response_body_api()), + tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def post(self, request: Request, dataset_id: str, document_id: str): + return result.success( + ParagraphSerializers.Create(data={'dataset_id': dataset_id, 'document_id': document_id}).save(request.data)) + + class Problem(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="添加关联问题", + operation_id="添加段落关联问题", + manual_parameters=ParagraphSerializers.Problem.get_request_params_api(), + request_body=ParagraphSerializers.Problem.get_request_body_api(), + responses=result.get_api_response(ParagraphSerializers.Problem.get_response_body_api()), + tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def post(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str): + return result.success(ParagraphSerializers.Problem( + data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).save( + request.data, with_valid=True)) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取段落问题列表", + operation_id="获取段落问题列表", + manual_parameters=ParagraphSerializers.Problem.get_request_params_api(), + responses=result.get_api_array_response( + 
ParagraphSerializers.Problem.get_response_body_api()), + tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str): + return result.success(ParagraphSerializers.Problem( + data={"dataset_id": dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id}).list( + with_valid=True)) + + class UnAssociation(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="解除关联问题", + operation_id="解除关联问题", + manual_parameters=ParagraphSerializers.Association.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str): + return result.success(ParagraphSerializers.Association( + data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id, + 'problem_id': problem_id}).un_association()) + + class Association(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="关联问题", + operation_id="关联问题", + manual_parameters=ParagraphSerializers.Association.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str, problem_id: str): + return result.success(ParagraphSerializers.Association( + data={'dataset_id': dataset_id, 'document_id': document_id, 'paragraph_id': paragraph_id, + 'problem_id': problem_id}).association()) + + class 
Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['UPDATE'], detail=False) + @swagger_auto_schema(operation_summary="修改段落数据", + operation_id="修改段落数据", + manual_parameters=ParagraphSerializers.Operate.get_request_params_api(), + request_body=ParagraphSerializers.Operate.get_request_body_api(), + responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()) + , tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str): + o = ParagraphSerializers.Operate( + data={"paragraph_id": paragraph_id, 'dataset_id': dataset_id, 'document_id': document_id}) + o.is_valid(raise_exception=True) + return result.success(o.edit(request.data)) + + @action(methods=['UPDATE'], detail=False) + @swagger_auto_schema(operation_summary="获取段落详情", + operation_id="获取段落详情", + manual_parameters=ParagraphSerializers.Operate.get_request_params_api(), + responses=result.get_api_response(ParagraphSerializers.Operate.get_response_body_api()), + tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str, document_id: str, paragraph_id: str): + o = ParagraphSerializers.Operate( + data={"dataset_id": dataset_id, 'document_id': document_id, "paragraph_id": paragraph_id}) + o.is_valid(raise_exception=True) + return result.success(o.one()) + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="删除段落", + operation_id="删除段落", + manual_parameters=ParagraphSerializers.Operate.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def delete(self, request: 
Request, dataset_id: str, document_id: str, paragraph_id: str): + o = ParagraphSerializers.Operate( + data={"dataset_id": dataset_id, 'document_id': document_id, "paragraph_id": paragraph_id}) + o.is_valid(raise_exception=True) + return result.success(o.delete()) + + class Batch(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="批量删除段落", + operation_id="批量删除段落", + request_body= + BatchSerializer.get_request_body_api(), + manual_parameters=ParagraphSerializers.Create.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def delete(self, request: Request, dataset_id: str, document_id: str): + return result.success(ParagraphSerializers.Batch( + data={"dataset_id": dataset_id, 'document_id': document_id}).batch_delete(request.data)) + + class BatchMigrate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="批量迁移段落", + operation_id="批量迁移段落", + manual_parameters=ParagraphSerializers.Migrate.get_request_params_api(), + request_body=ParagraphSerializers.Migrate.get_request_body_api(), + responses=result.get_default_response(), + tags=["知识库/文档/段落"] + ) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id')), + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('target_dataset_id')), + compare=CompareConstants.AND + ) + def put(self, request: Request, dataset_id: str, target_dataset_id: str, document_id: str, target_document_id): + return result.success( + ParagraphSerializers.Migrate( + data={'dataset_id': dataset_id, 'target_dataset_id': target_dataset_id, + 'document_id': document_id, + 'target_document_id': target_document_id, + 
'paragraph_id_list': request.data}).migrate()) + + class Page(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="分页获取段落列表", + operation_id="分页获取段落列表", + manual_parameters=result.get_page_request_params( + ParagraphSerializers.Query.get_request_params_api()), + responses=result.get_page_api_response(ParagraphSerializers.Query.get_response_body_api()), + tags=["知识库/文档/段落"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str, document_id: str, current_page, page_size): + d = ParagraphSerializers.Query( + data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id, + 'document_id': document_id}) + d.is_valid(raise_exception=True) + return result.success(d.page(current_page, page_size)) + + class BatchGenerateRelated(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str, document_id: str): + return result.success( + ParagraphSerializers.BatchGenerateRelated(data={'dataset_id': dataset_id, 'document_id': document_id}) + .batch_generate_related(request.data)) diff --git a/src/MaxKB-1.7.2/apps/dataset/views/problem.py b/src/MaxKB-1.7.2/apps/dataset/views/problem.py new file mode 100644 index 0000000..1d0ccb5 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/dataset/views/problem.py @@ -0,0 +1,154 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: problem.py + @date:2023/10/23 13:54 + @desc: +""" +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.views import APIView +from rest_framework.views import Request + +from common.auth import TokenAuth, has_permissions +from 
common.constants.permission_constants import Permission, Group, Operate +from common.response import result +from common.util.common import query_params_to_single_dict +from dataset.serializers.problem_serializers import ProblemSerializers +from dataset.swagger_api.problem_api import ProblemApi + + +class Problem(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="问题列表", + operation_id="问题列表", + manual_parameters=ProblemApi.Query.get_request_params_api(), + responses=result.get_api_array_response(ProblemApi.get_response_body_api()), + tags=["知识库/文档/段落/问题"] + ) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str): + q = ProblemSerializers.Query( + data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id}) + q.is_valid(raise_exception=True) + return result.success(q.list()) + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建问题", + operation_id="创建问题", + manual_parameters=ProblemApi.BatchCreate.get_request_params_api(), + request_body=ProblemApi.BatchCreate.get_request_body_api(), + responses=result.get_api_response(ProblemApi.Query.get_response_body_api()), + tags=["知识库/文档/段落/问题"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def post(self, request: Request, dataset_id: str): + return result.success( + ProblemSerializers.Create( + data={'dataset_id': dataset_id, 'problem_list': request.data}).batch()) + + class Paragraph(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取关联段落列表", + operation_id="获取关联段落列表", + manual_parameters=ProblemApi.Paragraph.get_request_params_api(), + 
responses=result.get_api_array_response(ProblemApi.Paragraph.get_response_body_api()), + tags=["知识库/文档/段落/问题"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str, problem_id: str): + return result.success(ProblemSerializers.Operate( + data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id, + 'problem_id': problem_id}).list_paragraph()) + + class OperateBatch(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="批量删除问题", + operation_id="批量删除问题", + request_body= + ProblemApi.BatchOperate.get_request_body_api(), + manual_parameters=ProblemApi.BatchOperate.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档/段落/问题"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def delete(self, request: Request, dataset_id: str): + return result.success( + ProblemSerializers.BatchOperate(data={'dataset_id': dataset_id}).delete(request.data)) + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="批量关联段落", + operation_id="批量关联段落", + request_body=ProblemApi.BatchAssociation.get_request_body_api(), + manual_parameters=ProblemApi.BatchOperate.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档/段落/问题"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def post(self, request: Request, dataset_id: str): + return result.success( + ProblemSerializers.BatchOperate(data={'dataset_id': dataset_id}).association(request.data)) + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="删除问题", + operation_id="删除问题", + 
manual_parameters=ProblemApi.Operate.get_request_params_api(), + responses=result.get_default_response(), + tags=["知识库/文档/段落/问题"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def delete(self, request: Request, dataset_id: str, problem_id: str): + return result.success(ProblemSerializers.Operate( + data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id, + 'problem_id': problem_id}).delete()) + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改问题", + operation_id="修改问题", + manual_parameters=ProblemApi.Operate.get_request_params_api(), + request_body=ProblemApi.Operate.get_request_body_api(), + responses=result.get_api_response(ProblemApi.get_response_body_api()), + tags=["知识库/文档/段落/问题"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.MANAGE, + dynamic_tag=k.get('dataset_id'))) + def put(self, request: Request, dataset_id: str, problem_id: str): + return result.success(ProblemSerializers.Operate( + data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id, + 'problem_id': problem_id}).edit(request.data)) + + class Page(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="分页获取问题列表", + operation_id="分页获取问题列表", + manual_parameters=result.get_page_request_params( + ProblemApi.Query.get_request_params_api()), + responses=result.get_page_api_response(ProblemApi.get_response_body_api()), + tags=["知识库/文档/段落/问题"]) + @has_permissions( + lambda r, k: Permission(group=Group.DATASET, operate=Operate.USE, + dynamic_tag=k.get('dataset_id'))) + def get(self, request: Request, dataset_id: str, current_page, page_size): + d = ProblemSerializers.Query( + data={**query_params_to_single_dict(request.query_params), 'dataset_id': dataset_id}) + d.is_valid(raise_exception=True) + return 
result.success(d.page(current_page, page_size)) diff --git a/src/MaxKB-1.7.2/apps/embedding/__init__.py b/src/MaxKB-1.7.2/apps/embedding/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MaxKB-1.7.2/apps/embedding/admin.py b/src/MaxKB-1.7.2/apps/embedding/admin.py new file mode 100644 index 0000000..8c38f3f --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. diff --git a/src/MaxKB-1.7.2/apps/embedding/apps.py b/src/MaxKB-1.7.2/apps/embedding/apps.py new file mode 100644 index 0000000..45a5d88 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class EmbeddingConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'embedding' diff --git a/src/MaxKB-1.7.2/apps/embedding/migrations/0001_initial.py b/src/MaxKB-1.7.2/apps/embedding/migrations/0001_initial.py new file mode 100644 index 0000000..82e850e --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/migrations/0001_initial.py @@ -0,0 +1,35 @@ +# Generated by Django 4.1.10 on 2024-03-18 17:48 + +import common.field.vector_field +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + initial = True + + dependencies = [ + ('dataset', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='Embedding', + fields=[ + ('id', models.CharField(max_length=128, primary_key=True, serialize=False, verbose_name='主键id')), + ('source_id', models.CharField(max_length=128, verbose_name='资源id')), + ('source_type', models.CharField(choices=[('0', '问题'), ('1', '段落'), ('2', '标题')], default='0', max_length=5, verbose_name='资源类型')), + ('is_active', models.BooleanField(default=True, max_length=1, verbose_name='是否可用')), + ('embedding', common.field.vector_field.VectorField(verbose_name='向量')), + ('meta', models.JSONField(default=dict, verbose_name='元数据')), + 
('dataset', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.dataset', verbose_name='文档关联')), + ('document', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.document', verbose_name='文档关联')), + ('paragraph', models.ForeignKey(db_constraint=False, on_delete=django.db.models.deletion.DO_NOTHING, to='dataset.paragraph', verbose_name='段落关联')), + ], + options={ + 'db_table': 'embedding', + 'unique_together': {('source_id', 'source_type')}, + }, + ), + ] diff --git a/src/MaxKB-1.7.2/apps/embedding/migrations/0002_embedding_search_vector.py b/src/MaxKB-1.7.2/apps/embedding/migrations/0002_embedding_search_vector.py new file mode 100644 index 0000000..c73a5a0 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/migrations/0002_embedding_search_vector.py @@ -0,0 +1,65 @@ +# Generated by Django 4.1.13 on 2024-04-16 11:43 +import threading + +import django.contrib.postgres.search +from django.db import migrations + +from common.util.common import sub_array +from common.util.ts_vecto_util import to_ts_vector +from dataset.models import Status +from embedding.models import Embedding + + +def update_embedding_search_vector(embedding, paragraph_list): + paragraphs = [paragraph for paragraph in paragraph_list if paragraph.id == embedding.get('paragraph')] + if len(paragraphs) > 0: + content = paragraphs[0].title + paragraphs[0].content + return Embedding(id=embedding.get('id'), search_vector=to_ts_vector(content)) + return Embedding(id=embedding.get('id'), search_vector="") + + +def save_keywords(apps, schema_editor): + try: + document = apps.get_model("dataset", "Document") + embedding = apps.get_model("embedding", "Embedding") + paragraph = apps.get_model('dataset', 'Paragraph') + db_alias = schema_editor.connection.alias + document_list = document.objects.using(db_alias).all() + for document in document_list: + document.status = Status.embedding + document.save() + paragraph_list = 
paragraph.objects.using(db_alias).filter(document=document).all() + embedding_list = embedding.objects.using(db_alias).filter(document=document).values('id', 'search_vector', + 'paragraph') + embedding_update_list = [update_embedding_search_vector(embedding, paragraph_list) for embedding + in embedding_list] + child_array = sub_array(embedding_update_list, 50) + for c in child_array: + try: + embedding.objects.using(db_alias).bulk_update(c, ['search_vector']) + except Exception as e: + print(e) + document.status = Status.success + document.save() + except Exception as e: + print(e) + + +def async_save_keywords(apps, schema_editor): + thread = threading.Thread(target=save_keywords, args=(apps, schema_editor)) + thread.start() + + +class Migration(migrations.Migration): + dependencies = [ + ('embedding', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='embedding', + name='search_vector', + field=django.contrib.postgres.search.SearchVectorField(default='', verbose_name='分词'), + ), + migrations.RunPython(async_save_keywords) + ] diff --git a/src/MaxKB-1.7.2/apps/embedding/migrations/0003_alter_embedding_unique_together.py b/src/MaxKB-1.7.2/apps/embedding/migrations/0003_alter_embedding_unique_together.py new file mode 100644 index 0000000..9cb4506 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/migrations/0003_alter_embedding_unique_together.py @@ -0,0 +1,17 @@ +# Generated by Django 4.2.14 on 2024-07-23 18:14 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('embedding', '0002_embedding_search_vector'), + ] + + operations = [ + migrations.AlterUniqueTogether( + name='embedding', + unique_together=set(), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/embedding/migrations/__init__.py b/src/MaxKB-1.7.2/apps/embedding/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MaxKB-1.7.2/apps/embedding/models/__init__.py 
b/src/MaxKB-1.7.2/apps/embedding/models/__init__.py new file mode 100644 index 0000000..b5dcf44 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/models/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2023/9/21 14:53 + @desc: +""" +from .embedding import * diff --git a/src/MaxKB-1.7.2/apps/embedding/models/embedding.py b/src/MaxKB-1.7.2/apps/embedding/models/embedding.py new file mode 100644 index 0000000..5f954e3 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/models/embedding.py @@ -0,0 +1,52 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: embedding.py + @date:2023/9/21 15:46 + @desc: +""" +from django.db import models + +from common.field.vector_field import VectorField +from dataset.models.data_set import Document, Paragraph, DataSet +from django.contrib.postgres.search import SearchVectorField + + +class SourceType(models.TextChoices): + """订单类型""" + PROBLEM = 0, '问题' + PARAGRAPH = 1, '段落' + TITLE = 2, '标题' + + +class SearchMode(models.TextChoices): + embedding = 'embedding' + keywords = 'keywords' + blend = 'blend' + + +class Embedding(models.Model): + id = models.CharField(max_length=128, primary_key=True, verbose_name="主键id") + + source_id = models.CharField(max_length=128, verbose_name="资源id") + + source_type = models.CharField(verbose_name='资源类型', max_length=5, choices=SourceType.choices, + default=SourceType.PROBLEM) + + is_active = models.BooleanField(verbose_name="是否可用", max_length=1, default=True) + + dataset = models.ForeignKey(DataSet, on_delete=models.DO_NOTHING, verbose_name="文档关联", db_constraint=False) + + document = models.ForeignKey(Document, on_delete=models.DO_NOTHING, verbose_name="文档关联", db_constraint=False) + + paragraph = models.ForeignKey(Paragraph, on_delete=models.DO_NOTHING, verbose_name="段落关联", db_constraint=False) + + embedding = VectorField(verbose_name="向量") + + search_vector = SearchVectorField(verbose_name="分词", default="") + + meta = 
models.JSONField(verbose_name="元数据", default=dict) + + class Meta: + db_table = "embedding" diff --git a/src/MaxKB-1.7.2/apps/embedding/sql/blend_search.sql b/src/MaxKB-1.7.2/apps/embedding/sql/blend_search.sql new file mode 100644 index 0000000..afb1f00 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/sql/blend_search.sql @@ -0,0 +1,26 @@ +SELECT + paragraph_id, + comprehensive_score, + comprehensive_score AS similarity +FROM + ( + SELECT DISTINCT ON + ( "paragraph_id" ) ( similarity ),* , + similarity AS comprehensive_score + FROM + ( + SELECT + *, + (( 1 - ( embedding.embedding <=> %s ) )+ts_rank_cd( embedding.search_vector, websearch_to_tsquery('simple', %s ), 32 )) AS similarity + FROM + embedding ${embedding_query} + ) TEMP + ORDER BY + paragraph_id, + similarity DESC + ) DISTINCT_TEMP +WHERE + comprehensive_score >%s +ORDER BY + comprehensive_score DESC + LIMIT %s \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/embedding/sql/embedding_search.sql b/src/MaxKB-1.7.2/apps/embedding/sql/embedding_search.sql new file mode 100644 index 0000000..ce3d4a5 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/sql/embedding_search.sql @@ -0,0 +1,17 @@ +SELECT + paragraph_id, + comprehensive_score, + comprehensive_score as similarity +FROM + ( + SELECT DISTINCT ON + ("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score + FROM + ( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query}) TEMP + ORDER BY + paragraph_id, + similarity DESC + ) DISTINCT_TEMP +WHERE comprehensive_score>%s +ORDER BY comprehensive_score DESC +LIMIT %s \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/embedding/sql/hit_test.sql b/src/MaxKB-1.7.2/apps/embedding/sql/hit_test.sql new file mode 100644 index 0000000..8feffc8 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/sql/hit_test.sql @@ -0,0 +1,17 @@ +SELECT + paragraph_id, + comprehensive_score, + comprehensive_score as similarity +FROM + ( + SELECT DISTINCT ON + 
("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score + FROM + ( SELECT *, ( 1 - ( embedding.embedding <=> %s ) ) AS similarity FROM embedding ${embedding_query} ) TEMP + ORDER BY + paragraph_id, + similarity DESC + ) DISTINCT_TEMP +WHERE comprehensive_score>%s +ORDER BY comprehensive_score DESC +LIMIT %s \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/embedding/sql/keywords_search.sql b/src/MaxKB-1.7.2/apps/embedding/sql/keywords_search.sql new file mode 100644 index 0000000..a27d0a6 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/sql/keywords_search.sql @@ -0,0 +1,17 @@ +SELECT + paragraph_id, + comprehensive_score, + comprehensive_score as similarity +FROM + ( + SELECT DISTINCT ON + ("paragraph_id") ( similarity ),* ,similarity AS comprehensive_score + FROM + ( SELECT *,ts_rank_cd(embedding.search_vector,websearch_to_tsquery('simple',%s),32) AS similarity FROM embedding ${keywords_query}) TEMP + ORDER BY + paragraph_id, + similarity DESC + ) DISTINCT_TEMP +WHERE comprehensive_score>%s +ORDER BY comprehensive_score DESC +LIMIT %s \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/embedding/task/__init__.py b/src/MaxKB-1.7.2/apps/embedding/task/__init__.py new file mode 100644 index 0000000..e5e7dd3 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/task/__init__.py @@ -0,0 +1 @@ +from .embedding import * diff --git a/src/MaxKB-1.7.2/apps/embedding/task/embedding.py b/src/MaxKB-1.7.2/apps/embedding/task/embedding.py new file mode 100644 index 0000000..b6d5dfb --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/task/embedding.py @@ -0,0 +1,245 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/8/19 14:13 + @desc: +""" +import datetime +import logging +import traceback +from typing import List + +from celery_once import QueueOnce +from django.db.models import QuerySet + +from common.config.embedding_config import ModelManage +from common.event import ListenerManagement, 
UpdateProblemArgs, UpdateEmbeddingDatasetIdArgs, \ + UpdateEmbeddingDocumentIdArgs +from dataset.models import Document, Status +from ops import celery_app +from setting.models import Model +from setting.models_provider import get_model + +max_kb_error = logging.getLogger("max_kb_error") +max_kb = logging.getLogger("max_kb") + + +def get_embedding_model(model_id, exception_handler=lambda e: max_kb_error.error( + f'获取向量模型失败:{str(e)}{traceback.format_exc()}')): + try: + model = QuerySet(Model).filter(id=model_id).first() + embedding_model = ModelManage.get_model(model_id, + lambda _id: get_model(model)) + except Exception as e: + exception_handler(e) + raise e + return embedding_model + + +@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id']}, name='celery:embedding_by_paragraph') +def embedding_by_paragraph(paragraph_id, model_id): + embedding_model = get_embedding_model(model_id) + ListenerManagement.embedding_by_paragraph(paragraph_id, embedding_model) + + +@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:embedding_by_paragraph_data_list') +def embedding_by_paragraph_data_list(data_list, paragraph_id_list, model_id): + embedding_model = get_embedding_model(model_id) + ListenerManagement.embedding_by_paragraph_data_list(data_list, paragraph_id_list, embedding_model) + + +@celery_app.task(base=QueueOnce, once={'keys': ['paragraph_id_list']}, name='celery:embedding_by_paragraph_list') +def embedding_by_paragraph_list(paragraph_id_list, model_id): + embedding_model = get_embedding_model(model_id) + ListenerManagement.embedding_by_paragraph_list(paragraph_id_list, embedding_model) + + +@celery_app.task(base=QueueOnce, once={'keys': ['document_id']}, name='celery:embedding_by_document') +def embedding_by_document(document_id, model_id): + """ + 向量化文档 + @param document_id: 文档id + @param model_id 向量模型 + :return: None + """ + + def exception_handler(e): + QuerySet(Document).filter(id=document_id).update( + **{'status': 
Status.error, 'update_time': datetime.datetime.now()}) + max_kb_error.error( + f'获取向量模型失败:{str(e)}{traceback.format_exc()}') + + embedding_model = get_embedding_model(model_id, exception_handler) + ListenerManagement.embedding_by_document(document_id, embedding_model) + + +@celery_app.task(name='celery:embedding_by_document_list') +def embedding_by_document_list(document_id_list, model_id): + """ + 向量化文档 + @param document_id_list: 文档id列表 + @param model_id 向量模型 + :return: None + """ + for document_id in document_id_list: + embedding_by_document.delay(document_id, model_id) + + +@celery_app.task(base=QueueOnce, once={'keys': ['dataset_id']}, name='celery:embedding_by_dataset') +def embedding_by_dataset(dataset_id, model_id): + """ + 向量化知识库 + @param dataset_id: 知识库id + @param model_id 向量模型 + :return: None + """ + max_kb.info(f"开始--->向量化数据集:{dataset_id}") + try: + ListenerManagement.delete_embedding_by_dataset(dataset_id) + document_list = QuerySet(Document).filter(dataset_id=dataset_id) + max_kb.info(f"数据集文档:{[d.name for d in document_list]}") + for document in document_list: + try: + embedding_by_document.delay(document.id, model_id) + except Exception as e: + pass + except Exception as e: + max_kb_error.error(f'向量化数据集:{dataset_id}出现错误{str(e)}{traceback.format_exc()}') + finally: + max_kb.info(f"结束--->向量化数据集:{dataset_id}") + + +def embedding_by_problem(args, model_id): + """ + 向量话问题 + @param args: 问题对象 + @param model_id: 模型id + @return: + """ + embedding_model = get_embedding_model(model_id) + ListenerManagement.embedding_by_problem(args, embedding_model) + + +def embedding_by_data_list(args: List, model_id): + embedding_model = get_embedding_model(model_id) + ListenerManagement.embedding_by_data_list(args, embedding_model) + + +def delete_embedding_by_document(document_id): + """ + 删除指定文档id的向量 + @param document_id: 文档id + @return: None + """ + + ListenerManagement.delete_embedding_by_document(document_id) + + +def delete_embedding_by_document_list(document_id_list: 
List[str]): + """ + 删除指定文档列表的向量数据 + @param document_id_list: 文档id列表 + @return: None + """ + ListenerManagement.delete_embedding_by_document_list(document_id_list) + + +def delete_embedding_by_dataset(dataset_id): + """ + 删除指定数据集向量数据 + @param dataset_id: 数据集id + @return: None + """ + ListenerManagement.delete_embedding_by_dataset(dataset_id) + + +def delete_embedding_by_paragraph(paragraph_id): + """ + 删除指定段落的向量数据 + @param paragraph_id: 段落id + @return: None + """ + ListenerManagement.delete_embedding_by_paragraph(paragraph_id) + + +def delete_embedding_by_source(source_id): + """ + 删除指定资源id的向量数据 + @param source_id: 资源id + @return: None + """ + ListenerManagement.delete_embedding_by_source(source_id) + + +def disable_embedding_by_paragraph(paragraph_id): + """ + 禁用某个段落id的向量 + @param paragraph_id: 段落id + @return: None + """ + ListenerManagement.disable_embedding_by_paragraph(paragraph_id) + + +def enable_embedding_by_paragraph(paragraph_id): + """ + 开启某个段落id的向量数据 + @param paragraph_id: 段落id + @return: None + """ + ListenerManagement.enable_embedding_by_paragraph(paragraph_id) + + +def delete_embedding_by_source_ids(source_ids: List[str]): + """ + 删除向量根据source_id_list + @param source_ids: + @return: + """ + ListenerManagement.delete_embedding_by_source_ids(source_ids) + + +def update_problem_embedding(problem_id: str, problem_content: str, model_id): + """ + 更新问题 + @param problem_id: + @param problem_content: + @param model_id: + @return: + """ + model = get_embedding_model(model_id) + ListenerManagement.update_problem(UpdateProblemArgs(problem_id, problem_content, model)) + + +def update_embedding_dataset_id(paragraph_id_list, target_dataset_id): + """ + 修改向量数据到指定知识库 + @param paragraph_id_list: 指定段落的向量数据 + @param target_dataset_id: 知识库id + @return: + """ + + ListenerManagement.update_embedding_dataset_id( + UpdateEmbeddingDatasetIdArgs(paragraph_id_list, target_dataset_id)) + + +def delete_embedding_by_paragraph_ids(paragraph_ids: List[str]): + """ + 删除指定段落列表的向量数据 + 
@param paragraph_ids: 段落列表 + @return: None + """ + ListenerManagement.delete_embedding_by_paragraph_ids(paragraph_ids) + + +def update_embedding_document_id(paragraph_id_list, target_document_id, target_dataset_id, + target_embedding_model_id=None): + target_embedding_model = get_embedding_model( + target_embedding_model_id) if target_embedding_model_id is not None else None + ListenerManagement.update_embedding_document_id( + UpdateEmbeddingDocumentIdArgs(paragraph_id_list, target_document_id, target_dataset_id, target_embedding_model)) + + +def delete_embedding_by_dataset_id_list(dataset_id_list): + ListenerManagement.delete_embedding_by_dataset_id_list(dataset_id_list) diff --git a/src/MaxKB-1.7.2/apps/embedding/tests.py b/src/MaxKB-1.7.2/apps/embedding/tests.py new file mode 100644 index 0000000..7ce503c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/src/MaxKB-1.7.2/apps/embedding/vector/base_vector.py b/src/MaxKB-1.7.2/apps/embedding/vector/base_vector.py new file mode 100644 index 0000000..ab5ab41 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/vector/base_vector.py @@ -0,0 +1,187 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_vector.py + @date:2023/10/18 19:16 + @desc: +""" +import threading +from abc import ABC, abstractmethod +from functools import reduce +from typing import List, Dict + +from langchain_core.embeddings import Embeddings + +from common.chunk import text_to_chunk +from common.util.common import sub_array +from embedding.models import SourceType, SearchMode + +lock = threading.Lock() + + +def chunk_data(data: Dict): + if str(data.get('source_type')) == SourceType.PARAGRAPH.value: + text = data.get('text') + chunk_list = text_to_chunk(text) + return [{**data, 'text': chunk} for chunk in chunk_list] + return [data] + + +def chunk_data_list(data_list: List[Dict]): + result = [chunk_data(data) for data in data_list] + 
return reduce(lambda x, y: [*x, *y], result, []) + + +class BaseVectorStore(ABC): + vector_exists = False + + @abstractmethod + def vector_is_create(self) -> bool: + """ + 判断向量库是否创建 + :return: 是否创建向量库 + """ + pass + + @abstractmethod + def vector_create(self): + """ + 创建 向量库 + :return: + """ + pass + + def save_pre_handler(self): + """ + 插入前置处理器 主要是判断向量库是否创建 + :return: True + """ + if not BaseVectorStore.vector_exists: + if not self.vector_is_create(): + self.vector_create() + BaseVectorStore.vector_exists = True + return True + + def save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, + is_active: bool, + embedding: Embeddings): + """ + 插入向量数据 + :param source_id: 资源id + :param dataset_id: 知识库id + :param text: 文本 + :param source_type: 资源类型 + :param document_id: 文档id + :param is_active: 是否禁用 + :param embedding: 向量化处理器 + :param paragraph_id 段落id + :return: bool + """ + self.save_pre_handler() + data = {'document_id': document_id, 'paragraph_id': paragraph_id, 'dataset_id': dataset_id, + 'is_active': is_active, 'source_id': source_id, 'source_type': source_type, 'text': text} + chunk_list = chunk_data(data) + result = sub_array(chunk_list) + for child_array in result: + self._batch_save(child_array, embedding, lambda: True) + + def batch_save(self, data_list: List[Dict], embedding: Embeddings, is_save_function): + """ + 批量插入 + @param data_list: 数据列表 + @param embedding: 向量化处理器 + @param is_save_function: + :return: bool + """ + self.save_pre_handler() + chunk_list = chunk_data_list(data_list) + result = sub_array(chunk_list) + for child_array in result: + if is_save_function(): + self._batch_save(child_array, embedding, is_save_function) + else: + break + + @abstractmethod + def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, + is_active: bool, + embedding: Embeddings): + pass + + @abstractmethod + def _batch_save(self, text_list: List[Dict], 
embedding: Embeddings, is_save_function): + pass + + def search(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], + exclude_paragraph_list: list[str], + is_active: bool, + embedding: Embeddings): + if dataset_id_list is None or len(dataset_id_list) == 0: + return [] + embedding_query = embedding.embed_query(query_text) + result = self.query(embedding_query, dataset_id_list, exclude_document_id_list, exclude_paragraph_list, + is_active, 1, 3, 0.65) + return result[0] + + @abstractmethod + def query(self, query_text: str, query_embedding: List[float], dataset_id_list: list[str], + exclude_document_id_list: list[str], + exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float, + search_mode: SearchMode): + pass + + @abstractmethod + def hit_test(self, query_text, dataset_id: list[str], exclude_document_id_list: list[str], top_number: int, + similarity: float, + search_mode: SearchMode, + embedding: Embeddings): + pass + + @abstractmethod + def update_by_paragraph_id(self, paragraph_id: str, instance: Dict): + pass + + @abstractmethod + def update_by_paragraph_ids(self, paragraph_ids: str, instance: Dict): + pass + + @abstractmethod + def update_by_source_id(self, source_id: str, instance: Dict): + pass + + @abstractmethod + def update_by_source_ids(self, source_ids: List[str], instance: Dict): + pass + + @abstractmethod + def delete_by_dataset_id(self, dataset_id: str): + pass + + @abstractmethod + def delete_by_document_id(self, document_id: str): + pass + + @abstractmethod + def delete_by_document_id_list(self, document_id_list: List[str]): + pass + + @abstractmethod + def delete_by_dataset_id_list(self, dataset_id_list: List[str]): + pass + + @abstractmethod + def delete_by_source_id(self, source_id: str, source_type: str): + pass + + @abstractmethod + def delete_by_source_ids(self, source_ids: List[str], source_type: str): + pass + + @abstractmethod + def delete_by_paragraph_id(self, paragraph_id: str): + 
pass + + @abstractmethod + def delete_by_paragraph_ids(self, paragraph_ids: List[str]): + pass diff --git a/src/MaxKB-1.7.2/apps/embedding/vector/pg_vector.py b/src/MaxKB-1.7.2/apps/embedding/vector/pg_vector.py new file mode 100644 index 0000000..8cd2146 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/embedding/vector/pg_vector.py @@ -0,0 +1,220 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: pg_vector.py + @date:2023/10/19 15:28 + @desc: +""" +import json +import os +import uuid +from abc import ABC, abstractmethod +from typing import Dict, List + +from django.db.models import QuerySet +from langchain_core.embeddings import Embeddings + +from common.db.search import generate_sql_by_query_dict +from common.db.sql_execute import select_list +from common.util.file_util import get_file_content +from common.util.ts_vecto_util import to_ts_vector, to_query +from embedding.models import Embedding, SourceType, SearchMode +from embedding.vector.base_vector import BaseVectorStore +from smartdoc.conf import PROJECT_DIR + + +class PGVector(BaseVectorStore): + + def delete_by_source_ids(self, source_ids: List[str], source_type: str): + if len(source_ids) == 0: + return + QuerySet(Embedding).filter(source_id__in=source_ids, source_type=source_type).delete() + + def update_by_source_ids(self, source_ids: List[str], instance: Dict): + QuerySet(Embedding).filter(source_id__in=source_ids).update(**instance) + + def vector_is_create(self) -> bool: + # 项目启动默认是创建好的 不需要再创建 + return True + + def vector_create(self): + return True + + def _save(self, text, source_type: SourceType, dataset_id: str, document_id: str, paragraph_id: str, source_id: str, + is_active: bool, + embedding: Embeddings): + text_embedding = embedding.embed_query(text) + embedding = Embedding(id=uuid.uuid1(), + dataset_id=dataset_id, + document_id=document_id, + is_active=is_active, + paragraph_id=paragraph_id, + source_id=source_id, + embedding=text_embedding, + source_type=source_type, + 
search_vector=to_ts_vector(text)) + embedding.save() + return True + + def _batch_save(self, text_list: List[Dict], embedding: Embeddings, is_save_function): + texts = [row.get('text') for row in text_list] + embeddings = embedding.embed_documents(texts) + embedding_list = [Embedding(id=uuid.uuid1(), + document_id=text_list[index].get('document_id'), + paragraph_id=text_list[index].get('paragraph_id'), + dataset_id=text_list[index].get('dataset_id'), + is_active=text_list[index].get('is_active', True), + source_id=text_list[index].get('source_id'), + source_type=text_list[index].get('source_type'), + embedding=embeddings[index], + search_vector=to_ts_vector(text_list[index]['text'])) for index in + range(0, len(texts))] + if is_save_function(): + QuerySet(Embedding).bulk_create(embedding_list) if len(embedding_list) > 0 else None + return True + + def hit_test(self, query_text, dataset_id_list: list[str], exclude_document_id_list: list[str], top_number: int, + similarity: float, + search_mode: SearchMode, + embedding: Embeddings): + if dataset_id_list is None or len(dataset_id_list) == 0: + return [] + exclude_dict = {} + embedding_query = embedding.embed_query(query_text) + query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=True) + if exclude_document_id_list is not None and len(exclude_document_id_list) > 0: + exclude_dict.__setitem__('document_id__in', exclude_document_id_list) + query_set = query_set.exclude(**exclude_dict) + for search_handle in search_handle_list: + if search_handle.support(search_mode): + return search_handle.handle(query_set, query_text, embedding_query, top_number, similarity, search_mode) + + def query(self, query_text: str, query_embedding: List[float], dataset_id_list: list[str], + exclude_document_id_list: list[str], + exclude_paragraph_list: list[str], is_active: bool, top_n: int, similarity: float, + search_mode: SearchMode): + exclude_dict = {} + if dataset_id_list is None or len(dataset_id_list) == 0: 
+ return [] + query_set = QuerySet(Embedding).filter(dataset_id__in=dataset_id_list, is_active=is_active) + if exclude_document_id_list is not None and len(exclude_document_id_list) > 0: + query_set = query_set.exclude(document_id__in=exclude_document_id_list) + if exclude_paragraph_list is not None and len(exclude_paragraph_list) > 0: + query_set = query_set.exclude(paragraph_id__in=exclude_paragraph_list) + query_set = query_set.exclude(**exclude_dict) + for search_handle in search_handle_list: + if search_handle.support(search_mode): + return search_handle.handle(query_set, query_text, query_embedding, top_n, similarity, search_mode) + + def update_by_source_id(self, source_id: str, instance: Dict): + QuerySet(Embedding).filter(source_id=source_id).update(**instance) + + def update_by_paragraph_id(self, paragraph_id: str, instance: Dict): + QuerySet(Embedding).filter(paragraph_id=paragraph_id).update(**instance) + + def update_by_paragraph_ids(self, paragraph_id: str, instance: Dict): + QuerySet(Embedding).filter(paragraph_id__in=paragraph_id).update(**instance) + + def delete_by_dataset_id(self, dataset_id: str): + QuerySet(Embedding).filter(dataset_id=dataset_id).delete() + + def delete_by_dataset_id_list(self, dataset_id_list: List[str]): + QuerySet(Embedding).filter(dataset_id__in=dataset_id_list).delete() + + def delete_by_document_id(self, document_id: str): + QuerySet(Embedding).filter(document_id=document_id).delete() + return True + + def delete_by_document_id_list(self, document_id_list: List[str]): + if len(document_id_list) == 0: + return True + return QuerySet(Embedding).filter(document_id__in=document_id_list).delete() + + def delete_by_source_id(self, source_id: str, source_type: str): + QuerySet(Embedding).filter(source_id=source_id, source_type=source_type).delete() + return True + + def delete_by_paragraph_id(self, paragraph_id: str): + QuerySet(Embedding).filter(paragraph_id=paragraph_id).delete() + + def delete_by_paragraph_ids(self, 
class ISearch(ABC):
    """
    Strategy interface for one retrieval mode.

    ``support`` tells whether the strategy handles a given search mode and ``handle`` runs
    the query; the module-level ``search_handle_list`` is scanned for the first strategy
    whose ``support`` returns a truthy value.
    """

    @abstractmethod
    def support(self, search_mode: SearchMode):
        pass

    @abstractmethod
    def handle(self, query_set, query_text, query_embedding, top_number: int,
               similarity: float, search_mode: SearchMode):
        pass


class EmbeddingSearch(ISearch):
    """Search driven by the query embedding (``embedding_search.sql``)."""

    def support(self, search_mode: SearchMode):
        return SearchMode.embedding.value == search_mode.value

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        sql_file = os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'embedding_search.sql')
        sql_template = get_file_content(sql_file)
        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                           select_string=sql_template,
                                                           with_table_name=True)
        # Parameter order: serialized embedding, then the filter parameters, then the two limits.
        sql_args = [json.dumps(query_embedding), *exec_params, similarity, top_number]
        return select_list(exec_sql, sql_args)


class KeywordsSearch(ISearch):
    """Search driven by the query text (``keywords_search.sql``)."""

    def support(self, search_mode: SearchMode):
        return SearchMode.keywords.value == search_mode.value

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        sql_file = os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'keywords_search.sql')
        sql_template = get_file_content(sql_file)
        exec_sql, exec_params = generate_sql_by_query_dict({'keywords_query': query_set},
                                                           select_string=sql_template,
                                                           with_table_name=True)
        # Parameter order: converted query text, then the filter parameters, then the two limits.
        sql_args = [to_query(query_text), *exec_params, similarity, top_number]
        return select_list(exec_sql, sql_args)


class BlendSearch(ISearch):
    """Search that passes both the embedding and the query text (``blend_search.sql``)."""

    def support(self, search_mode: SearchMode):
        return SearchMode.blend.value == search_mode.value

    def handle(self,
               query_set,
               query_text,
               query_embedding,
               top_number: int,
               similarity: float,
               search_mode: SearchMode):
        sql_file = os.path.join(PROJECT_DIR, "apps", "embedding", 'sql', 'blend_search.sql')
        sql_template = get_file_content(sql_file)
        exec_sql, exec_params = generate_sql_by_query_dict({'embedding_query': query_set},
                                                           select_string=sql_template,
                                                           with_table_name=True)
        # Parameter order: embedding, query text, filter parameters, then the two limits.
        sql_args = [json.dumps(query_embedding), to_query(query_text), *exec_params, similarity,
                    top_number]
        return select_list(exec_sql, sql_args)


# Strategies are probed in this order; the first one whose support() matches is used.
search_handle_list = [EmbeddingSearch(), KeywordsSearch(), BlendSearch()]
class Migration(migrations.Migration):
    # Initial schema migration of the function_lib app: creates the `function_lib` table
    # that backs the FunctionLib model.

    initial = True

    # FunctionLib.user is a foreign key to users.user, so this users migration has to be
    # applied first.
    dependencies = [
        ('users', '0004_alter_user_email'),
    ]

    operations = [
        migrations.CreateModel(
            name='FunctionLib',
            fields=[
                # Timestamps maintained automatically by Django (auto_now_add / auto_now).
                ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')),
                ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')),
                # UUID primary key; the default generator is uuid1.
                ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, verbose_name='主键id')),
                ('name', models.CharField(max_length=64, verbose_name='函数名称')),
                ('desc', models.CharField(max_length=128, verbose_name='描述')),
                # Python source of the function, stored as a CharField capped at 102400 characters.
                ('code', models.CharField(max_length=102400, verbose_name='python代码')),
                # PostgreSQL array of JSON objects, one per declared input field.
                ('input_field_list', django.contrib.postgres.fields.ArrayField(base_field=models.JSONField(default=dict, verbose_name='输入字段'), default=list, size=None, verbose_name='输入字段列表')),
                # Owner of the function; rows are removed together with the user (CASCADE).
                ('user', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='users.user', verbose_name='用户id')),
            ],
            options={
                'db_table': 'function_lib',
            },
        ),
    ]
class Migration(migrations.Migration):
    # Adds the `is_active` flag and the `permission_type` column to the function_lib table.

    # Builds on the table created by the initial migration of this app.
    dependencies = [
        ('function_lib', '0001_initial'),
    ]

    operations = [
        # Existing rows receive is_active=True through the column default.
        migrations.AddField(
            model_name='functionlib',
            name='is_active',
            field=models.BooleanField(default=True),
        ),
        # Existing rows receive permission_type='PRIVATE' through the column default.
        migrations.AddField(
            model_name='functionlib',
            name='permission_type',
            field=models.CharField(choices=[('PUBLIC', '公开'), ('PRIVATE', '私有')], default='PRIVATE', max_length=20, verbose_name='权限类型'),
        ),
    ]
class FunctionLib(AppModelMixin):
    # A user-defined Python function stored in the `function_lib` table.
    # NOTE(review): create_time / update_time appear in the migration for this table but are
    # not declared here, so they presumably come from AppModelMixin — confirm.

    # UUID primary key generated with uuid1.
    # NOTE(review): max_length is passed to a UUIDField here; it looks superfluous — confirm.
    id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id")
    # Owner of the function; deleting the user deletes the user's functions (CASCADE).
    user = models.ForeignKey(User, on_delete=models.CASCADE, verbose_name="用户id")
    # Function name (verbose name: "function name").
    name = models.CharField(max_length=64, verbose_name="函数名称")
    # Description; the column has no null=True, so it is NOT NULL.
    desc = models.CharField(max_length=128, verbose_name="描述")
    # Python source code of the function, limited to 102400 characters.
    code = models.CharField(max_length=102400, verbose_name="python代码")
    # List of input-field declarations; each element is a JSON object.
    input_field_list = ArrayField(verbose_name="输入字段列表",
                                  base_field=models.JSONField(verbose_name="输入字段", default=dict)
                                  , default=list)
    # Whether the function is enabled; defaults to True.
    is_active = models.BooleanField(default=True)
    # Visibility of the function, one of PermissionType (PUBLIC / PRIVATE); defaults to PRIVATE.
    permission_type = models.CharField(max_length=20, verbose_name='权限类型', choices=PermissionType.choices,
                                       default=PermissionType.PRIVATE)

    class Meta:
        db_table = "function_lib"
class FunctionLibInputField(serializers.Serializer):
    # Declaration of a single input parameter of a function.
    name = serializers.CharField(required=True, error_messages=ErrMessage.char('变量名'))
    is_required = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("是否必填"))
    # The alternation is wrapped in a group so that ^ and $ apply to every option. Without the
    # group "^string|int|...|float$" means "starts with string OR contains int OR ... OR ends
    # with float", which lets values such as "print" through.
    type = serializers.CharField(required=True, error_messages=ErrMessage.char("类型"), validators=[
        validators.RegexValidator(regex=re.compile("^(string|int|dict|array|float)$"),
                                  message="字段只支持string|int|dict|array|float", code=500)
    ])
    source = serializers.CharField(required=True, error_messages=ErrMessage.char("来源"), validators=[
        validators.RegexValidator(regex=re.compile("^(custom|reference)$"),
                                  message="字段只支持custom|reference", code=500)
    ])


class DebugField(serializers.Serializer):
    # One name/value pair supplied for a debug run; the value is optional and may be blank or null.
    name = serializers.CharField(required=True, error_messages=ErrMessage.char('变量名'))
    value = serializers.CharField(required=False, allow_blank=True, allow_null=True,
                                  error_messages=ErrMessage.char("变量值"))


class DebugInstance(serializers.Serializer):
    # Payload of a debug request: the code, its declared inputs and the values to run it with.
    debug_field_list = DebugField(required=True, many=True)
    input_field_list = FunctionLibInputField(required=True, many=True)
    code = serializers.CharField(required=True, error_messages=ErrMessage.char("函数内容"))


class EditFunctionLib(serializers.Serializer):
    # Payload of an edit request; every field is optional.
    name = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                 error_messages=ErrMessage.char("函数名称"))

    desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                 error_messages=ErrMessage.char("函数描述"))

    code = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                 error_messages=ErrMessage.char("函数内容"))

    input_field_list = FunctionLibInputField(required=False, many=True)

    # The edit path writes permission_type to the database, so it is validated here with the
    # same rule as on creation instead of being accepted unchecked.
    permission_type = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                            error_messages=ErrMessage.char("权限"), validators=[
            validators.RegexValidator(regex=re.compile("^(PUBLIC|PRIVATE)$"),
                                      message="权限只支持PUBLIC|PRIVATE", code=500)
        ])

    is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char('是否可用'))


class CreateFunctionLib(serializers.Serializer):
    # Payload of a create request.
    name = serializers.CharField(required=True, error_messages=ErrMessage.char("函数名称"))

    desc = serializers.CharField(required=False, allow_null=True, allow_blank=True,
                                 error_messages=ErrMessage.char("函数描述"))

    code = serializers.CharField(required=True, error_messages=ErrMessage.char("函数内容"))

    input_field_list = FunctionLibInputField(required=True, many=True)

    # Grouped alternation: only the exact values PUBLIC or PRIVATE are accepted.
    permission_type = serializers.CharField(required=True, error_messages=ErrMessage.char("权限"), validators=[
        validators.RegexValidator(regex=re.compile("^(PUBLIC|PRIVATE)$"),
                                  message="权限只支持PUBLIC|PRIVATE", code=500)
    ])
    is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char('是否可用'))
page_search(current_page, page_size, self.get_query_set(), + post_records_handler=lambda row: FunctionLibModelSerializer(row).data) + + class Create(serializers.Serializer): + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + + def insert(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + CreateFunctionLib(data=instance).is_valid(raise_exception=True) + function_lib = FunctionLib(id=uuid.uuid1(), name=instance.get('name'), desc=instance.get('desc'), + code=instance.get('code'), + user_id=self.data.get('user_id'), + input_field_list=instance.get('input_field_list'), + permission_type=instance.get('permission_type'), + is_active=instance.get('is_active', True)) + function_lib.save() + return FunctionLibModelSerializer(function_lib).data + + class Debug(serializers.Serializer): + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + + def debug(self, debug_instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + DebugInstance(data=debug_instance).is_valid(raise_exception=True) + input_field_list = debug_instance.get('input_field_list') + code = debug_instance.get('code') + debug_field_list = debug_instance.get('debug_field_list') + params = {field.get('name'): self.convert_value(field.get('name'), field.get('value'), field.get('type'), + field.get('is_required')) + for field in + [{'value': self.get_field_value(debug_field_list, field.get('name'), field.get('is_required')), + **field} for field in + input_field_list]} + return function_executor.exec_code(code, params) + + @staticmethod + def get_field_value(debug_field_list, name, is_required): + result = [field for field in debug_field_list if field.get('name') == name] + if len(result) > 0: + return result[-1].get('value') + if is_required: + raise AppApiException(500, f"{name}字段未设置值") + return None + + @staticmethod + def convert_value(name: str, value: str, _type: str, 
is_required: bool): + if not is_required and value is None: + return None + try: + if _type == 'int': + return int(value) + if _type == 'float': + return float(value) + if _type == 'dict': + v = json.loads(value) + if isinstance(v, dict): + return v + raise Exception("类型错误") + if _type == 'array': + v = json.loads(value) + if isinstance(v, list): + return v + raise Exception("类型错误") + return value + except Exception as e: + raise AppApiException(500, f'字段:{name}类型:{_type}值:{value}类型转换错误') + + class Operate(serializers.Serializer): + id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("函数id")) + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + if not QuerySet(FunctionLib).filter(id=self.data.get('id'), user_id=self.data.get('user_id')).exists(): + raise AppApiException(500, '函数不存在') + + def delete(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + QuerySet(FunctionLib).filter(id=self.data.get('id')).delete() + return True + + def edit(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + EditFunctionLib(data=instance).is_valid(raise_exception=True) + edit_field_list = ['name', 'desc', 'code', 'input_field_list', 'permission_type', 'is_active'] + edit_dict = {field: instance.get(field) for field in edit_field_list if ( + field in instance and instance.get(field) is not None)} + QuerySet(FunctionLib).filter(id=self.data.get('id')).update(**edit_dict) + return self.one(False) + + def one(self, with_valid=True): + if with_valid: + super().is_valid(raise_exception=True) + if not QuerySet(FunctionLib).filter(id=self.data.get('id')).filter( + Q(user_id=self.data.get('user_id')) | Q(permission_type='PUBLIC')).exists(): + raise AppApiException(500, '函数不存在') + function_lib = QuerySet(FunctionLib).filter(id=self.data.get('id')).first() + return 
class PyLintSerializer(serializers.Serializer):

    def pylint(self, instance, is_valid=True):
        """
        Lint the submitted code with pylint and return the reported messages.

        The code is written to a uniquely named temporary file, pylint is run on it with
        `line-too-long` disabled, and the file is removed again.

        :param instance: request payload; only the `code` key is read
        :param is_valid: when True, validate this serializer and the payload first
        :return: list of dicts built by `to_dict`, one per pylint message
        """
        if is_valid:
            self.is_valid(raise_exception=True)
            PyLintInstance(data=instance).is_valid(raise_exception=True)
        # PyLintInstance allows a null `code`; lint an empty module instead of crashing on write(None).
        code = instance.get('code') or ''
        file_name = get_file_name()
        reporter = JSON2Reporter()
        try:
            # Explicit UTF-8 so that non-ASCII source does not depend on the process locale.
            with open(file_name, 'w', encoding='utf-8') as file:
                file.write(code)
            Run([file_name,
                 "--disable=line-too-long",
                 '--module-rgx=[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}'],
                reporter=reporter, exit=False)
        finally:
            # Always clean up the temporary file, even when writing or linting raised.
            if os.path.exists(file_name):
                os.remove(file_name)
        return [to_dict(m, os.path.basename(file_name)) for m in reporter.messages]
def _input_field_list_schema(required):
    """
    Build the array schema that describes a function's input-field declarations.

    The Debug, Edit and Create request bodies share this structure and differ only in
    which item properties are marked as required.
    """
    item_schema = openapi.Schema(
        type=openapi.TYPE_OBJECT,
        required=required,
        properties={
            'name': openapi.Schema(type=openapi.TYPE_STRING, title="变量名", description="变量名"),
            'is_required': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否必填", description="是否必填"),
            'type': openapi.Schema(type=openapi.TYPE_STRING, title="字段类型",
                                   description="字段类型 string|int|dict|array|float"),
            'source': openapi.Schema(type=openapi.TYPE_STRING, title="来源",
                                     description="来源只支持custom|reference"),
        })
    return openapi.Schema(type=openapi.TYPE_ARRAY, description="输入变量列表", items=item_schema)


class FunctionLibApi(ApiMixin):
    """Swagger (drf_yasg) schema definitions for the function library endpoints."""

    @staticmethod
    def get_response_body_api():
        """Schema of a single function-library record in a response."""
        properties = {
            'id': openapi.Schema(type=openapi.TYPE_STRING, title="", description="主键id"),
            'name': openapi.Schema(type=openapi.TYPE_STRING, title="函数名称", description="函数名称"),
            'desc': openapi.Schema(type=openapi.TYPE_STRING, title="函数描述", description="函数描述"),
            'code': openapi.Schema(type=openapi.TYPE_STRING, title="函数内容", description="函数内容"),
            'input_field_list': openapi.Schema(type=openapi.TYPE_STRING, title="输入字段", description="输入字段"),
            'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description="创建时间"),
            'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description="修改时间"),
        }
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['id', 'name', 'desc', 'code', 'input_field_list', 'create_time',
                      'update_time'],
            properties=properties)

    class Query(ApiMixin):
        @staticmethod
        def get_request_params_api():
            """Optional query-string filters of the list endpoints."""
            name_param = openapi.Parameter(name='name',
                                           in_=openapi.IN_QUERY,
                                           type=openapi.TYPE_STRING,
                                           required=False,
                                           description='函数名称')
            desc_param = openapi.Parameter(name='desc',
                                           in_=openapi.IN_QUERY,
                                           type=openapi.TYPE_STRING,
                                           required=False,
                                           description='函数描述')
            return [name_param, desc_param]

    class Debug(ApiMixin):
        @staticmethod
        def get_request_body_api():
            """Request body of the debug endpoint."""
            debug_item = openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=[],
                properties={
                    'name': openapi.Schema(type=openapi.TYPE_STRING, title="变量名", description="变量名"),
                    'value': openapi.Schema(type=openapi.TYPE_STRING, title="变量值", description="变量值"),
                })
            properties = {
                'debug_field_list': openapi.Schema(type=openapi.TYPE_ARRAY,
                                                   description="输入变量列表",
                                                   items=debug_item),
                'code': openapi.Schema(type=openapi.TYPE_STRING, title="函数内容", description="函数内容"),
                'input_field_list': _input_field_list_schema(['name', 'is_required', 'source']),
            }
            return openapi.Schema(type=openapi.TYPE_OBJECT, required=[], properties=properties)

    class Edit(ApiMixin):
        @staticmethod
        def get_request_body_api():
            """Request body of the edit endpoint; nothing is marked as required."""
            properties = {
                'name': openapi.Schema(type=openapi.TYPE_STRING, title="函数名称", description="函数名称"),
                'desc': openapi.Schema(type=openapi.TYPE_STRING, title="函数描述", description="函数描述"),
                'code': openapi.Schema(type=openapi.TYPE_STRING, title="函数内容", description="函数内容"),
                'permission_type': openapi.Schema(type=openapi.TYPE_STRING, title="权限", description="权限"),
                'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
                'input_field_list': _input_field_list_schema([]),
            }
            return openapi.Schema(type=openapi.TYPE_OBJECT, required=[], properties=properties)

    class Create(ApiMixin):
        @staticmethod
        def get_request_body_api():
            """Request body of the create endpoint."""
            properties = {
                'name': openapi.Schema(type=openapi.TYPE_STRING, title="函数名称", description="函数名称"),
                'desc': openapi.Schema(type=openapi.TYPE_STRING, title="函数描述", description="函数描述"),
                'code': openapi.Schema(type=openapi.TYPE_STRING, title="函数内容", description="函数内容"),
                'permission_type': openapi.Schema(type=openapi.TYPE_STRING, title="权限", description="权限"),
                'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
                'input_field_list': _input_field_list_schema(['name', 'is_required', 'source']),
            }
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['name', 'code', 'input_field_list', 'permission_type'],
                properties=properties)
app_name = "function_lib"
urlpatterns = [
    path('function_lib', views.FunctionLibView.as_view()),
    path('function_lib/debug', views.FunctionLibView.Debug.as_view()),
    path('function_lib/pylint', views.PyLintView.as_view()),
    # Operate.put/delete/get take a `function_lib_id` keyword argument, so the route has to
    # capture it. The literal routes above are listed first so that "debug" and "pylint"
    # are not swallowed by this converter.
    path('function_lib/<str:function_lib_id>', views.FunctionLibView.Operate.as_view()),
    # Page.get takes `current_page` and `page_size` as integers.
    path("function_lib/<int:current_page>/<int:page_size>", views.FunctionLibView.Page.as_view(),
         name="function_lib_page")
]
class FunctionLibView(APIView):
    # Collection endpoints of the function library: list (GET) and create (POST).
    # Comments are used instead of docstrings so the generated swagger descriptions stay unchanged.
    authentication_classes = [TokenAuth]

    @action(methods=["GET"], detail=False)
    @swagger_auto_schema(operation_summary="获取函数列表",
                         operation_id="获取函数列表",
                         tags=["函数库"],
                         manual_parameters=FunctionLibApi.Query.get_request_params_api())
    @has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
    def get(self, request: Request):
        # List functions, optionally filtered by name/desc.
        filters = request.query_params
        query = FunctionLibSerializer.Query(data={'name': filters.get('name'),
                                                  'desc': filters.get('desc'),
                                                  'user_id': request.user.id})
        return result.success(query.list())

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="创建函数",
                         operation_id="创建函数",
                         request_body=FunctionLibApi.Create.get_request_body_api(),
                         tags=['函数库'])
    @has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
    def post(self, request: Request):
        # Create a function owned by the requesting user.
        creator = FunctionLibSerializer.Create(data={'user_id': request.user.id})
        return result.success(creator.insert(request.data))

    class Debug(APIView):
        authentication_classes = [TokenAuth]

        @action(methods=['POST'], detail=False)
        @swagger_auto_schema(operation_summary="调试函数",
                             operation_id="调试函数",
                             request_body=FunctionLibApi.Debug.get_request_body_api(),
                             tags=['函数库'])
        @has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
        def post(self, request: Request):
            # Run the submitted code with the submitted debug values.
            debugger = FunctionLibSerializer.Debug(data={'user_id': request.user.id})
            return result.success(debugger.debug(request.data))

    class Operate(APIView):
        authentication_classes = [TokenAuth]

        @action(methods=['PUT'], detail=False)
        @swagger_auto_schema(operation_summary="修改函数",
                             operation_id="修改函数",
                             request_body=FunctionLibApi.Edit.get_request_body_api(),
                             tags=['函数库'])
        @has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
        def put(self, request: Request, function_lib_id: str):
            # Edit one function.
            operate = FunctionLibSerializer.Operate(data={'user_id': request.user.id, 'id': function_lib_id})
            return result.success(operate.edit(request.data))

        @action(methods=['DELETE'], detail=False)
        @swagger_auto_schema(operation_summary="删除函数",
                             operation_id="删除函数",
                             tags=['函数库'])
        @has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
        def delete(self, request: Request, function_lib_id: str):
            # Delete one function.
            operate = FunctionLibSerializer.Operate(data={'user_id': request.user.id, 'id': function_lib_id})
            return result.success(operate.delete())

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="获取函数详情",
                             operation_id="获取函数详情",
                             tags=['函数库'])
        @has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
        def get(self, request: Request, function_lib_id: str):
            # Fetch the details of one function.
            operate = FunctionLibSerializer.Operate(data={'user_id': request.user.id, 'id': function_lib_id})
            return result.success(operate.one())

    class Page(APIView):
        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="分页获取函数列表",
                             operation_id="分页获取函数列表",
                             manual_parameters=result.get_page_request_params(
                                 FunctionLibApi.Query.get_request_params_api()),
                             responses=result.get_page_api_response(FunctionLibApi.get_response_body_api()),
                             tags=['函数库'])
        @has_permissions(RoleConstants.ADMIN, RoleConstants.USER)
        def get(self, request: Request, current_page: int, page_size: int):
            # Paged variant of the list endpoint.
            filters = request.query_params
            query = FunctionLibSerializer.Query(data={'name': filters.get('name'),
                                                      'desc': filters.get('desc'),
                                                      'user_id': request.user.id})
            return result.success(query.page(current_page, page_size))
Did you " + "forget to activate a virtual environment?" + ) from exc + execute_from_command_line(sys.argv) + + +if __name__ == '__main__': + main() diff --git a/src/MaxKB-1.7.2/apps/ops/__init__.py b/src/MaxKB-1.7.2/apps/ops/__init__.py new file mode 100644 index 0000000..a02f13a --- /dev/null +++ b/src/MaxKB-1.7.2/apps/ops/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py.py + @date:2024/8/16 14:47 + @desc: +""" +from .celery import app as celery_app diff --git a/src/MaxKB-1.7.2/apps/ops/celery/__init__.py b/src/MaxKB-1.7.2/apps/ops/celery/__init__.py new file mode 100644 index 0000000..55e727b --- /dev/null +++ b/src/MaxKB-1.7.2/apps/ops/celery/__init__.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- + +import os + +from celery import Celery +from celery.schedules import crontab +from kombu import Exchange, Queue +from smartdoc import settings +from .heatbeat import * + +# set the default Django settings module for the 'celery' program. +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'smartdoc.settings') + +app = Celery('MaxKB') + +configs = {k: v for k, v in settings.__dict__.items() if k.startswith('CELERY')} +configs['worker_concurrency'] = 5 +# Using a string here means the worker will not have to +# pickle the object when using Windows. 
+# app.config_from_object('django.conf:settings', namespace='CELERY') + +configs["task_queues"] = [ + Queue("celery", Exchange("celery"), routing_key="celery"), + Queue("model", Exchange("model"), routing_key="model") +] +app.namespace = 'CELERY' +app.conf.update( + {key.replace('CELERY_', '') if key.replace('CELERY_', '').lower() == key.replace('CELERY_', + '') else key: configs.get( + key) for + key + in configs.keys()}) +app.autodiscover_tasks(lambda: [app_config.split('.')[0] for app_config in settings.INSTALLED_APPS]) diff --git a/src/MaxKB-1.7.2/apps/ops/celery/const.py b/src/MaxKB-1.7.2/apps/ops/celery/const.py new file mode 100644 index 0000000..2f88702 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/ops/celery/const.py @@ -0,0 +1,4 @@ +# -*- coding: utf-8 -*- +# + +CELERY_LOG_MAGIC_MARK = b'\x00\x00\x00\x00\x00' \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/ops/celery/decorator.py b/src/MaxKB-1.7.2/apps/ops/celery/decorator.py new file mode 100644 index 0000000..317a7f7 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/ops/celery/decorator.py @@ -0,0 +1,104 @@ +# -*- coding: utf-8 -*- +# +from functools import wraps + +_need_registered_period_tasks = [] +_after_app_ready_start_tasks = [] +_after_app_shutdown_clean_periodic_tasks = [] + + +def add_register_period_task(task): + _need_registered_period_tasks.append(task) + + +def get_register_period_tasks(): + return _need_registered_period_tasks + + +def add_after_app_shutdown_clean_task(name): + _after_app_shutdown_clean_periodic_tasks.append(name) + + +def get_after_app_shutdown_clean_tasks(): + return _after_app_shutdown_clean_periodic_tasks + + +def add_after_app_ready_task(name): + _after_app_ready_start_tasks.append(name) + + +def get_after_app_ready_tasks(): + return _after_app_ready_start_tasks + + +def register_as_period_task( + crontab=None, interval=None, name=None, + args=(), kwargs=None, + description=''): + """ + Warning: Task must have not any args and kwargs + :param crontab: "* * * * *" + 
:param interval: 60*60*60 + :param args: () + :param kwargs: {} + :param description: " + :param name: "" + :return: + """ + if crontab is None and interval is None: + raise SyntaxError("Must set crontab or interval one") + + def decorate(func): + if crontab is None and interval is None: + raise SyntaxError("Interval and crontab must set one") + + # Because when this decorator run, the task was not created, + # So we can't use func.name + task = '{func.__module__}.{func.__name__}'.format(func=func) + _name = name if name else task + add_register_period_task({ + _name: { + 'task': task, + 'interval': interval, + 'crontab': crontab, + 'args': args, + 'kwargs': kwargs if kwargs else {}, + 'description': description + } + }) + + @wraps(func) + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + return decorate + + +def after_app_ready_start(func): + # Because when this decorator run, the task was not created, + # So we can't use func.name + name = '{func.__module__}.{func.__name__}'.format(func=func) + if name not in _after_app_ready_start_tasks: + add_after_app_ready_task(name) + + @wraps(func) + def decorate(*args, **kwargs): + return func(*args, **kwargs) + + return decorate + + +def after_app_shutdown_clean_periodic(func): + # Because when this decorator run, the task was not created, + # So we can't use func.name + name = '{func.__module__}.{func.__name__}'.format(func=func) + if name not in _after_app_shutdown_clean_periodic_tasks: + add_after_app_shutdown_clean_task(name) + + @wraps(func) + def decorate(*args, **kwargs): + return func(*args, **kwargs) + + return decorate diff --git a/src/MaxKB-1.7.2/apps/ops/celery/heatbeat.py b/src/MaxKB-1.7.2/apps/ops/celery/heatbeat.py new file mode 100644 index 0000000..339a3c6 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/ops/celery/heatbeat.py @@ -0,0 +1,25 @@ +from pathlib import Path + +from celery.signals import heartbeat_sent, worker_ready, worker_shutdown + + +@heartbeat_sent.connect +def 
heartbeat(sender, **kwargs): + worker_name = sender.eventer.hostname.split('@')[0] + heartbeat_path = Path('/tmp/worker_heartbeat_{}'.format(worker_name)) + heartbeat_path.touch() + + +@worker_ready.connect +def worker_ready(sender, **kwargs): + worker_name = sender.hostname.split('@')[0] + ready_path = Path('/tmp/worker_ready_{}'.format(worker_name)) + ready_path.touch() + + +@worker_shutdown.connect +def worker_shutdown(sender, **kwargs): + worker_name = sender.hostname.split('@')[0] + for signal in ['ready', 'heartbeat']: + path = Path('/tmp/worker_{}_{}'.format(signal, worker_name)) + path.unlink(missing_ok=True) diff --git a/src/MaxKB-1.7.2/apps/ops/celery/logger.py b/src/MaxKB-1.7.2/apps/ops/celery/logger.py new file mode 100644 index 0000000..bdadc56 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/ops/celery/logger.py @@ -0,0 +1,223 @@ +from logging import StreamHandler +from threading import get_ident + +from celery import current_task +from celery.signals import task_prerun, task_postrun +from django.conf import settings +from kombu import Connection, Exchange, Queue, Producer +from kombu.mixins import ConsumerMixin + +from .utils import get_celery_task_log_path +from .const import CELERY_LOG_MAGIC_MARK + +routing_key = 'celery_log' +celery_log_exchange = Exchange('celery_log_exchange', type='direct') +celery_log_queue = [Queue('celery_log', celery_log_exchange, routing_key=routing_key)] + + +class CeleryLoggerConsumer(ConsumerMixin): + def __init__(self): + self.connection = Connection(settings.CELERY_LOG_BROKER_URL) + + def get_consumers(self, Consumer, channel): + return [Consumer(queues=celery_log_queue, + accept=['pickle', 'json'], + callbacks=[self.process_task]) + ] + + def handle_task_start(self, task_id, message): + pass + + def handle_task_end(self, task_id, message): + pass + + def handle_task_log(self, task_id, msg, message): + pass + + def process_task(self, body, message): + action = body.get('action') + task_id = body.get('task_id') + msg = 
body.get('msg') + if action == CeleryLoggerProducer.ACTION_TASK_LOG: + self.handle_task_log(task_id, msg, message) + elif action == CeleryLoggerProducer.ACTION_TASK_START: + self.handle_task_start(task_id, message) + elif action == CeleryLoggerProducer.ACTION_TASK_END: + self.handle_task_end(task_id, message) + + +class CeleryLoggerProducer: + ACTION_TASK_START, ACTION_TASK_LOG, ACTION_TASK_END = range(3) + + def __init__(self): + self.connection = Connection(settings.CELERY_LOG_BROKER_URL) + + @property + def producer(self): + return Producer(self.connection) + + def publish(self, payload): + self.producer.publish( + payload, serializer='json', exchange=celery_log_exchange, + declare=[celery_log_exchange], routing_key=routing_key + ) + + def log(self, task_id, msg): + payload = {'task_id': task_id, 'msg': msg, 'action': self.ACTION_TASK_LOG} + return self.publish(payload) + + def read(self): + pass + + def flush(self): + pass + + def task_end(self, task_id): + payload = {'task_id': task_id, 'action': self.ACTION_TASK_END} + return self.publish(payload) + + def task_start(self, task_id): + payload = {'task_id': task_id, 'action': self.ACTION_TASK_START} + return self.publish(payload) + + +class CeleryTaskLoggerHandler(StreamHandler): + terminator = '\r\n' + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + task_prerun.connect(self.on_task_start) + task_postrun.connect(self.on_start_end) + + @staticmethod + def get_current_task_id(): + if not current_task: + return + task_id = current_task.request.root_id + return task_id + + def on_task_start(self, sender, task_id, **kwargs): + return self.handle_task_start(task_id) + + def on_start_end(self, sender, task_id, **kwargs): + return self.handle_task_end(task_id) + + def after_task_publish(self, sender, body, **kwargs): + pass + + def emit(self, record): + task_id = self.get_current_task_id() + if not task_id: + return + try: + self.write_task_log(task_id, record) + self.flush() + except 
Exception: + self.handleError(record) + + def write_task_log(self, task_id, msg): + pass + + def handle_task_start(self, task_id): + pass + + def handle_task_end(self, task_id): + pass + + +class CeleryThreadingLoggerHandler(CeleryTaskLoggerHandler): + @staticmethod + def get_current_thread_id(): + return str(get_ident()) + + def emit(self, record): + thread_id = self.get_current_thread_id() + try: + self.write_thread_task_log(thread_id, record) + self.flush() + except ValueError: + self.handleError(record) + + def write_thread_task_log(self, thread_id, msg): + pass + + def handle_task_start(self, task_id): + pass + + def handle_task_end(self, task_id): + pass + + def handleError(self, record) -> None: + pass + + +class CeleryTaskMQLoggerHandler(CeleryTaskLoggerHandler): + def __init__(self): + self.producer = CeleryLoggerProducer() + super().__init__(stream=None) + + def write_task_log(self, task_id, record): + msg = self.format(record) + self.producer.log(task_id, msg) + + def flush(self): + self.producer.flush() + + +class CeleryTaskFileHandler(CeleryTaskLoggerHandler): + def __init__(self, *args, **kwargs): + self.f = None + super().__init__(*args, **kwargs) + + def emit(self, record): + msg = self.format(record) + if not self.f or self.f.closed: + return + self.f.write(msg) + self.f.write(self.terminator) + self.flush() + + def flush(self): + self.f and self.f.flush() + + def handle_task_start(self, task_id): + log_path = get_celery_task_log_path(task_id) + self.f = open(log_path, 'a') + + def handle_task_end(self, task_id): + self.f and self.f.close() + + +class CeleryThreadTaskFileHandler(CeleryThreadingLoggerHandler): + def __init__(self, *args, **kwargs): + self.thread_id_fd_mapper = {} + self.task_id_thread_id_mapper = {} + super().__init__(*args, **kwargs) + + def write_thread_task_log(self, thread_id, record): + f = self.thread_id_fd_mapper.get(thread_id, None) + if not f: + raise ValueError('Not found thread task file') + msg = self.format(record) + 
f.write(msg.encode()) + f.write(self.terminator.encode()) + f.flush() + + def flush(self): + for f in self.thread_id_fd_mapper.values(): + f.flush() + + def handle_task_start(self, task_id): + log_path = get_celery_task_log_path(task_id) + thread_id = self.get_current_thread_id() + self.task_id_thread_id_mapper[task_id] = thread_id + f = open(log_path, 'ab') + self.thread_id_fd_mapper[thread_id] = f + + def handle_task_end(self, task_id): + ident_id = self.task_id_thread_id_mapper.get(task_id, '') + f = self.thread_id_fd_mapper.pop(ident_id, None) + if f and not f.closed: + f.write(CELERY_LOG_MAGIC_MARK) + f.close() + self.task_id_thread_id_mapper.pop(task_id, None) diff --git a/src/MaxKB-1.7.2/apps/ops/celery/signal_handler.py b/src/MaxKB-1.7.2/apps/ops/celery/signal_handler.py new file mode 100644 index 0000000..90ed624 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/ops/celery/signal_handler.py @@ -0,0 +1,63 @@ +# -*- coding: utf-8 -*- +# +import logging +import os + +from celery import subtask +from celery.signals import ( + worker_ready, worker_shutdown, after_setup_logger +) +from django.core.cache import cache +from django_celery_beat.models import PeriodicTask + +from .decorator import get_after_app_ready_tasks, get_after_app_shutdown_clean_tasks +from .logger import CeleryThreadTaskFileHandler + +logger = logging.getLogger(__file__) +safe_str = lambda x: x + + +@worker_ready.connect +def on_app_ready(sender=None, headers=None, **kwargs): + if cache.get("CELERY_APP_READY", 0) == 1: + return + cache.set("CELERY_APP_READY", 1, 10) + tasks = get_after_app_ready_tasks() + logger.debug("Work ready signal recv") + logger.debug("Start need start task: [{}]".format(", ".join(tasks))) + for task in tasks: + periodic_task = PeriodicTask.objects.filter(task=task).first() + if periodic_task and not periodic_task.enabled: + logger.debug("Periodic task [{}] is disabled!".format(task)) + continue + subtask(task).delay() + + +def delete_files(directory): + if 
os.path.isdir(directory): + for filename in os.listdir(directory): + file_path = os.path.join(directory, filename) + if os.path.isfile(file_path): + os.remove(file_path) + + +@worker_shutdown.connect +def after_app_shutdown_periodic_tasks(sender=None, **kwargs): + if cache.get("CELERY_APP_SHUTDOWN", 0) == 1: + return + cache.set("CELERY_APP_SHUTDOWN", 1, 10) + tasks = get_after_app_shutdown_clean_tasks() + logger.debug("Worker shutdown signal recv") + logger.debug("Clean period tasks: [{}]".format(', '.join(tasks))) + PeriodicTask.objects.filter(name__in=tasks).delete() + + +@after_setup_logger.connect +def add_celery_logger_handler(sender=None, logger=None, loglevel=None, format=None, **kwargs): + if not logger: + return + task_handler = CeleryThreadTaskFileHandler() + task_handler.setLevel(loglevel) + formatter = logging.Formatter(format) + task_handler.setFormatter(formatter) + logger.addHandler(task_handler) diff --git a/src/MaxKB-1.7.2/apps/ops/celery/utils.py b/src/MaxKB-1.7.2/apps/ops/celery/utils.py new file mode 100644 index 0000000..288089f --- /dev/null +++ b/src/MaxKB-1.7.2/apps/ops/celery/utils.py @@ -0,0 +1,68 @@ +# -*- coding: utf-8 -*- +# +import logging +import os +import uuid + +from django.conf import settings +from django_celery_beat.models import ( + PeriodicTasks +) + +from smartdoc.const import PROJECT_DIR + +logger = logging.getLogger(__file__) + + +def disable_celery_periodic_task(task_name): + from django_celery_beat.models import PeriodicTask + PeriodicTask.objects.filter(name=task_name).update(enabled=False) + PeriodicTasks.update_changed() + + +def delete_celery_periodic_task(task_name): + from django_celery_beat.models import PeriodicTask + PeriodicTask.objects.filter(name=task_name).delete() + PeriodicTasks.update_changed() + + +def get_celery_periodic_task(task_name): + from django_celery_beat.models import PeriodicTask + task = PeriodicTask.objects.filter(name=task_name).first() + return task + + +def make_dirs(name, mode=0o755, 
exist_ok=False): + """ 默认权限设置为 0o755 """ + return os.makedirs(name, mode=mode, exist_ok=exist_ok) + + +def get_task_log_path(base_path, task_id, level=2): + task_id = str(task_id) + try: + uuid.UUID(task_id) + except: + return os.path.join(PROJECT_DIR, 'data', 'caution.txt') + + rel_path = os.path.join(*task_id[:level], task_id + '.log') + path = os.path.join(base_path, rel_path) + make_dirs(os.path.dirname(path), exist_ok=True) + return path + + +def get_celery_task_log_path(task_id): + return get_task_log_path(settings.CELERY_LOG_DIR, task_id) + + +def get_celery_status(): + from . import app + i = app.control.inspect() + ping_data = i.ping() or {} + active_nodes = [k for k, v in ping_data.items() if v.get('ok') == 'pong'] + active_queue_worker = set([n.split('@')[0] for n in active_nodes if n]) + # Celery Worker 数量: 2 + if len(active_queue_worker) < 2: + print("Not all celery worker worked") + return False + else: + return True diff --git a/src/MaxKB-1.7.2/apps/setting/__init__.py b/src/MaxKB-1.7.2/apps/setting/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MaxKB-1.7.2/apps/setting/admin.py b/src/MaxKB-1.7.2/apps/setting/admin.py new file mode 100644 index 0000000..8c38f3f --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/admin.py @@ -0,0 +1,3 @@ +from django.contrib import admin + +# Register your models here. 
diff --git a/src/MaxKB-1.7.2/apps/setting/apps.py b/src/MaxKB-1.7.2/apps/setting/apps.py new file mode 100644 index 0000000..57d346a --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/apps.py @@ -0,0 +1,6 @@ +from django.apps import AppConfig + + +class SettingConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'setting' diff --git a/src/MaxKB-1.7.2/apps/setting/migrations/0001_initial.py b/src/MaxKB-1.7.2/apps/setting/migrations/0001_initial.py new file mode 100644 index 0000000..f6900dc --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/migrations/0001_initial.py @@ -0,0 +1,95 @@ +# Generated by Django 4.1.10 on 2024-03-18 16:02 + +import django.contrib.postgres.fields +from django.db import migrations, models +import django.db.models.deletion +import uuid + + +def insert_default_data(apps, schema_editor): + TeamModel = apps.get_model('setting', 'Team') + TeamModel.objects.create(user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab', name='admin的团队') + + +class Migration(migrations.Migration): + initial = True + + dependencies = [ + ('users', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='Team', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('user', + models.OneToOneField(on_delete=django.db.models.deletion.DO_NOTHING, primary_key=True, serialize=False, + to='users.user', verbose_name='团队所有者')), + ('name', models.CharField(max_length=128, verbose_name='团队名称')), + ], + options={ + 'db_table': 'team', + }, + ), + migrations.RunPython(insert_default_data), + migrations.CreateModel( + name='TeamMember', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, + verbose_name='主键id')), + 
('team', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='setting.team', + verbose_name='团队id')), + ('user', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', + verbose_name='成员用户id')), + ], + options={ + 'db_table': 'team_member', + }, + ), + migrations.CreateModel( + name='TeamMemberPermission', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, + verbose_name='主键id')), + ('auth_target_type', + models.CharField(choices=[('DATASET', '数据集'), ('APPLICATION', '应用')], default='DATASET', + max_length=128, verbose_name='授权目标')), + ('target', models.UUIDField(verbose_name='数据集/应用id')), + ('operate', django.contrib.postgres.fields.ArrayField( + base_field=models.CharField(blank=True, choices=[('MANAGE', '管理'), ('USE', '使用')], + default='USE', max_length=256), size=None, + verbose_name='权限操作列表')), + ('member', models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='setting.teammember', + verbose_name='团队成员')), + ], + options={ + 'db_table': 'team_member_permission', + }, + ), + migrations.CreateModel( + name='Model', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, + verbose_name='主键id')), + ('name', models.CharField(max_length=128, verbose_name='名称')), + ('model_type', models.CharField(max_length=128, verbose_name='模型类型')), + ('model_name', models.CharField(max_length=128, verbose_name='模型名称')), + ('provider', models.CharField(max_length=128, verbose_name='供应商')), + ('credential', models.CharField(max_length=5120, verbose_name='模型认证信息')), + ('user', 
models.ForeignKey(on_delete=django.db.models.deletion.DO_NOTHING, to='users.user', + verbose_name='成员用户id')), + ], + options={ + 'db_table': 'model', + 'unique_together': {('name', 'user_id')}, + }, + ), + ] diff --git a/src/MaxKB-1.7.2/apps/setting/migrations/0002_systemsetting.py b/src/MaxKB-1.7.2/apps/setting/migrations/0002_systemsetting.py new file mode 100644 index 0000000..5c2972f --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/migrations/0002_systemsetting.py @@ -0,0 +1,24 @@ +# Generated by Django 4.1.10 on 2024-03-19 16:51 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ('setting', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='SystemSetting', + fields=[ + ('create_time', models.DateTimeField(auto_now_add=True, verbose_name='创建时间')), + ('update_time', models.DateTimeField(auto_now=True, verbose_name='修改时间')), + ('type', models.IntegerField(choices=[(0, '邮箱'), (1, '私钥秘钥')], default=0, primary_key=True, serialize=False, verbose_name='设置类型')), + ('meta', models.JSONField(default=dict, verbose_name='配置数据')), + ], + options={ + 'db_table': 'system_setting', + }, + ), + ] diff --git a/src/MaxKB-1.7.2/apps/setting/migrations/0003_model_meta_model_status.py b/src/MaxKB-1.7.2/apps/setting/migrations/0003_model_meta_model_status.py new file mode 100644 index 0000000..f4956e8 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/migrations/0003_model_meta_model_status.py @@ -0,0 +1,23 @@ +# Generated by Django 4.1.13 on 2024-03-22 17:51 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0002_systemsetting'), + ] + + operations = [ + migrations.AddField( + model_name='model', + name='meta', + field=models.JSONField(default=dict, verbose_name='模型元数据,用于存储下载,或者错误信息'), + ), + migrations.AddField( + model_name='model', + name='status', + field=models.CharField(choices=[('SUCCESS', '成功'), ('ERROR', '失败'), ('DOWNLOAD', 
'下载中')], default='SUCCESS', max_length=20, verbose_name='设置类型'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/setting/migrations/0004_alter_model_credential.py b/src/MaxKB-1.7.2/apps/setting/migrations/0004_alter_model_credential.py new file mode 100644 index 0000000..4b5e488 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/migrations/0004_alter_model_credential.py @@ -0,0 +1,18 @@ +# Generated by Django 4.1.13 on 2024-04-28 18:06 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0003_model_meta_model_status'), + ] + + operations = [ + migrations.AlterField( + model_name='model', + name='credential', + field=models.CharField(max_length=102400, verbose_name='模型认证信息'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/setting/migrations/0005_model_permission_type.py b/src/MaxKB-1.7.2/apps/setting/migrations/0005_model_permission_type.py new file mode 100644 index 0000000..dba081a --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/migrations/0005_model_permission_type.py @@ -0,0 +1,46 @@ +# Generated by Django 4.2.13 on 2024-07-15 15:23 +import json + +from django.db import migrations, models +from django.db.models import QuerySet + +from common.util.rsa_util import rsa_long_encrypt +from setting.models import Status, PermissionType +from smartdoc.const import CONFIG + +default_embedding_model_id = '42f63a3d-427e-11ef-b3ec-a8a1595801ab' + + +def save_default_embedding_model(apps, schema_editor): + ModelModel = apps.get_model('setting', 'Model') + cache_folder = CONFIG.get('EMBEDDING_MODEL_PATH') + model_name = CONFIG.get('EMBEDDING_MODEL_NAME') + credential = {'cache_folder': cache_folder} + model_credential_str = json.dumps(credential) + model = ModelModel(id=default_embedding_model_id, name='maxkb-embedding', status=Status.SUCCESS, + model_type="EMBEDDING", model_name=model_name, user_id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab', + provider='model_local_provider', + 
credential=rsa_long_encrypt(model_credential_str), meta={}, + permission_type=PermissionType.PUBLIC) + model.save() + + +def reverse_code_embedding_model(apps, schema_editor): + ModelModel = apps.get_model('setting', 'Model') + QuerySet(ModelModel).filter(id=default_embedding_model_id).delete() + + +class Migration(migrations.Migration): + dependencies = [ + ('setting', '0004_alter_model_credential'), + ] + + operations = [ + migrations.AddField( + model_name='model', + name='permission_type', + field=models.CharField(choices=[('PUBLIC', '公开'), ('PRIVATE', '私有')], default='PRIVATE', max_length=20, + verbose_name='权限类型'), + ), + migrations.RunPython(save_default_embedding_model, reverse_code_embedding_model) + ] diff --git a/src/MaxKB-1.7.2/apps/setting/migrations/0006_alter_model_status.py b/src/MaxKB-1.7.2/apps/setting/migrations/0006_alter_model_status.py new file mode 100644 index 0000000..209f57c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/migrations/0006_alter_model_status.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.14 on 2024-07-23 18:14 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0005_model_permission_type'), + ] + + operations = [ + migrations.AlterField( + model_name='model', + name='status', + field=models.CharField(choices=[('SUCCESS', '成功'), ('ERROR', '失败'), ('DOWNLOAD', '下载中'), ('PAUSE_DOWNLOAD', '暂停下载')], default='SUCCESS', max_length=20, verbose_name='设置类型'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/setting/migrations/0007_model_model_params_form.py b/src/MaxKB-1.7.2/apps/setting/migrations/0007_model_model_params_form.py new file mode 100644 index 0000000..fa40b66 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/migrations/0007_model_model_params_form.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.15 on 2024-10-15 14:49 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', 
'0006_alter_model_status'), + ] + + operations = [ + migrations.AddField( + model_name='model', + name='model_params_form', + field=models.JSONField(default=list, verbose_name='模型参数配置'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/setting/migrations/0008_modelparam.py b/src/MaxKB-1.7.2/apps/setting/migrations/0008_modelparam.py new file mode 100644 index 0000000..8be3892 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/migrations/0008_modelparam.py @@ -0,0 +1,25 @@ +# Generated by Django 4.2.15 on 2024-10-16 13:10 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('setting', '0007_model_model_params_form'), + ] + + operations = [ + migrations.CreateModel( + name='ModelParam', + fields=[ + ('id', models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('label', models.CharField(max_length=128, verbose_name='参数')), + ('field', models.CharField(max_length=256, verbose_name='显示名称')), + ('default_value', models.CharField(max_length=1000, verbose_name='默认值')), + ('input_type', models.CharField(max_length=32, verbose_name='组件类型')), + ('attrs', models.JSONField(verbose_name='属性')), + ('required', models.BooleanField(verbose_name='必填')), + ], + ), + ] diff --git a/src/MaxKB-1.7.2/apps/setting/migrations/__init__.py b/src/MaxKB-1.7.2/apps/setting/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MaxKB-1.7.2/apps/setting/models/__init__.py b/src/MaxKB-1.7.2/apps/setting/models/__init__.py new file mode 100644 index 0000000..155129e --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models/__init__.py @@ -0,0 +1,11 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py + @date:2023/9/25 15:04 + @desc: +""" +from .team_management import * +from .model_management import * +from .system_management import * diff --git a/src/MaxKB-1.7.2/apps/setting/models/model_management.py 
b/src/MaxKB-1.7.2/apps/setting/models/model_management.py new file mode 100644 index 0000000..638161e --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models/model_management.py @@ -0,0 +1,76 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: model_management.py + @date:2023/10/31 15:11 + @desc: +""" +import uuid + +from django.db import models + +from common.mixins.app_model_mixin import AppModelMixin +from users.models import User + + +class Status(models.TextChoices): + """系统设置类型""" + SUCCESS = "SUCCESS", '成功' + + ERROR = "ERROR", "失败" + + DOWNLOAD = "DOWNLOAD", '下载中' + + PAUSE_DOWNLOAD = "PAUSE_DOWNLOAD", '暂停下载' + + +class PermissionType(models.TextChoices): + PUBLIC = "PUBLIC", '公开' + PRIVATE = "PRIVATE", "私有" + +class ModelParam(models.Model): + label = models.CharField(max_length=128, verbose_name="参数") + field = models.CharField(max_length=256, verbose_name="显示名称") + default_value = models.CharField(max_length=1000, verbose_name="默认值") + input_type = models.CharField(max_length=32, verbose_name="组件类型") + attrs = models.JSONField(verbose_name="属性") + required = models.BooleanField(verbose_name="必填") + +class Model(AppModelMixin): + """ + 模型数据 + """ + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + + name = models.CharField(max_length=128, verbose_name="名称") + + status = models.CharField(max_length=20, verbose_name='设置类型', choices=Status.choices, + default=Status.SUCCESS) + + model_type = models.CharField(max_length=128, verbose_name="模型类型") + + model_name = models.CharField(max_length=128, verbose_name="模型名称") + + user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="成员用户id") + + provider = models.CharField(max_length=128, verbose_name='供应商') + + credential = models.CharField(max_length=102400, verbose_name="模型认证信息") + + meta = models.JSONField(verbose_name="模型元数据,用于存储下载,或者错误信息", default=dict) + + permission_type = models.CharField(max_length=20, 
verbose_name='权限类型', choices=PermissionType.choices, + default=PermissionType.PRIVATE) + + model_params_form = models.JSONField(verbose_name="模型参数配置", default=list) + + + def is_permission(self, user_id): + if self.permission_type == PermissionType.PUBLIC or str(user_id) == str(self.user_id): + return True + return False + + class Meta: + db_table = "model" + unique_together = ['name', 'user_id'] diff --git a/src/MaxKB-1.7.2/apps/setting/models/system_management.py b/src/MaxKB-1.7.2/apps/setting/models/system_management.py new file mode 100644 index 0000000..8dea895 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models/system_management.py @@ -0,0 +1,32 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: system_management.py + @date:2024/3/19 13:47 + @desc: 邮箱管理 +""" + +from django.db import models + +from common.mixins.app_model_mixin import AppModelMixin + + +class SettingType(models.IntegerChoices): + """系统设置类型""" + EMAIL = 0, '邮箱' + + RSA = 1, "私钥秘钥" + + +class SystemSetting(AppModelMixin): + """ + 系统设置 + """ + type = models.IntegerField(primary_key=True, verbose_name='设置类型', choices=SettingType.choices, + default=SettingType.EMAIL) + + meta = models.JSONField(verbose_name="配置数据", default=dict) + + class Meta: + db_table = "system_setting" diff --git a/src/MaxKB-1.7.2/apps/setting/models/team_management.py b/src/MaxKB-1.7.2/apps/setting/models/team_management.py new file mode 100644 index 0000000..3e480d8 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models/team_management.py @@ -0,0 +1,73 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: team_management.py + @date:2023/9/25 15:04 + @desc: +""" +import uuid + +from django.contrib.postgres.fields import ArrayField +from django.db import models + +from common.constants.permission_constants import Group, Operate +from common.mixins.app_model_mixin import AppModelMixin +from users.models import User + + +class AuthTargetType(models.TextChoices): + """授权目标""" + DATASET = 
Group.DATASET.value, '数据集' + APPLICATION = Group.APPLICATION.value, '应用' + + +class AuthOperate(models.TextChoices): + """授权权限""" + MANAGE = Operate.MANAGE.value, '管理' + + USE = Operate.USE.value, "使用" + + +class Team(AppModelMixin): + """ + 团队表 + """ + user = models.OneToOneField(User, primary_key=True, on_delete=models.DO_NOTHING, verbose_name="团队所有者") + + name = models.CharField(max_length=128, verbose_name="团队名称") + + class Meta: + db_table = "team" + + +class TeamMember(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + team = models.ForeignKey(Team, on_delete=models.DO_NOTHING, verbose_name="团队id") + user = models.ForeignKey(User, on_delete=models.DO_NOTHING, verbose_name="成员用户id") + + class Meta: + db_table = "team_member" + + +class TeamMemberPermission(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + """ + 团队成员权限 + """ + member = models.ForeignKey(TeamMember, on_delete=models.DO_NOTHING, verbose_name="团队成员") + + auth_target_type = models.CharField(verbose_name='授权目标', max_length=128, choices=AuthTargetType.choices, + default=AuthTargetType.DATASET) + + target = models.UUIDField(max_length=128, verbose_name="数据集/应用id") + + operate = ArrayField(verbose_name="权限操作列表", + base_field=models.CharField(max_length=256, + blank=True, + choices=AuthOperate.choices, + default=AuthOperate.USE), + ) + + class Meta: + db_table = "team_member_permission" diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/__init__.py b/src/MaxKB-1.7.2/apps/setting/models_provider/__init__.py new file mode 100644 index 0000000..7f573ec --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/__init__.py @@ -0,0 +1,94 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2023/10/31 17:16 + @desc: +""" +import json +from typing import Dict + +from common.util.rsa_util import 
rsa_long_decrypt +from setting.models_provider.constants.model_provider_constants import ModelProvideConstants + + +def get_model_(provider, model_type, model_name, credential, model_id, use_local=False, **kwargs): + """ + 获取模型实例 + @param provider: 供应商 + @param model_type: 模型类型 + @param model_name: 模型名称 + @param credential: 认证信息 + @param model_id: 模型id + @param use_local: 是否调用本地模型 只适用于本地供应商 + @return: 模型实例 + """ + model = get_provider(provider).get_model(model_type, model_name, + json.loads( + rsa_long_decrypt(credential)), + model_id=model_id, + use_local=use_local, + streaming=True, **kwargs) + return model + + +def get_model(model, **kwargs): + """ + 获取模型实例 + @param model: model 数据库Model实例对象 + @return: 模型实例 + """ + return get_model_(model.provider, model.model_type, model.model_name, model.credential, str(model.id), **kwargs) + + +def get_provider(provider): + """ + 获取供应商实例 + @param provider: 供应商字符串 + @return: 供应商实例 + """ + return ModelProvideConstants[provider].value + + +def get_model_list(provider, model_type): + """ + 获取模型列表 + @param provider: 供应商字符串 + @param model_type: 模型类型 + @return: 模型列表 + """ + return get_provider(provider).get_model_list(model_type) + + +def get_model_credential(provider, model_type, model_name): + """ + 获取模型认证实例 + @param provider: 供应商字符串 + @param model_type: 模型类型 + @param model_name: 模型名称 + @return: 认证实例对象 + """ + return get_provider(provider).get_model_credential(model_type, model_name) + + +def get_model_type_list(provider): + """ + 获取模型类型列表 + @param provider: 供应商字符串 + @return: 模型类型列表 + """ + return get_provider(provider).get_model_type_list() + + +def is_valid_credential(provider, model_type, model_name, model_credential: Dict[str, object], raise_exception=False): + """ + 校验模型认证参数 + @param provider: 供应商字符串 + @param model_type: 模型类型 + @param model_name: 模型名称 + @param model_credential: 模型认证数据 + @param raise_exception: 是否抛出错误 + @return: True|False + """ + return get_provider(provider).is_valid_credential(model_type, model_name, 
model_credential, raise_exception) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/base_model_provider.py b/src/MaxKB-1.7.2/apps/setting/models_provider/base_model_provider.py new file mode 100644 index 0000000..c4722c9 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/base_model_provider.py @@ -0,0 +1,261 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: base_model_provider.py + @date:2023/10/31 16:19 + @desc: +""" +from abc import ABC, abstractmethod +from enum import Enum +from functools import reduce +from typing import Dict, Iterator, Type, List + +from pydantic.v1 import BaseModel + +from common.exception.app_exception import AppApiException + + +class DownModelChunkStatus(Enum): + success = "success" + error = "error" + pulling = "pulling" + unknown = 'unknown' + + +class ValidCode(Enum): + valid_error = 500 + model_not_fount = 404 + + +class DownModelChunk: + def __init__(self, status: DownModelChunkStatus, digest: str, progress: int, details: str, index: int): + self.details = details + self.status = status + self.digest = digest + self.progress = progress + self.index = index + + def to_dict(self): + return { + "details": self.details, + "status": self.status.value, + "digest": self.digest, + "progress": self.progress, + "index": self.index + } + + +class IModelProvider(ABC): + @abstractmethod + def get_model_info_manage(self): + pass + + @abstractmethod + def get_model_provide_info(self): + pass + + def get_model_type_list(self): + return self.get_model_info_manage().get_model_type_list() + + def get_model_list(self, model_type): + if model_type is None: + raise AppApiException(500, '模型类型不能为空') + return self.get_model_info_manage().get_model_list_by_model_type(model_type) + + def get_model_credential(self, model_type, model_name): + model_info = self.get_model_info_manage().get_model_info(model_type, model_name) + return model_info.model_credential + + def is_valid_credential(self, model_type, model_name, 
model_credential: Dict[str, object], raise_exception=False): + model_info = self.get_model_info_manage().get_model_info(model_type, model_name) + return model_info.model_credential.is_valid(model_type, model_name, model_credential, self, + raise_exception=raise_exception) + + def get_model(self, model_type, model_name, model_credential: Dict[str, object], **model_kwargs) -> BaseModel: + model_info = self.get_model_info_manage().get_model_info(model_type, model_name) + return model_info.model_class.new_instance(model_type, model_name, model_credential, **model_kwargs) + + def get_dialogue_number(self): + return 3 + + def down_model(self, model_type: str, model_name, model_credential: Dict[str, object]) -> Iterator[DownModelChunk]: + raise AppApiException(500, "当前平台不支持下载模型") + + +class MaxKBBaseModel(ABC): + @staticmethod + @abstractmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + pass + + @staticmethod + def is_cache_model(): + return True + + @staticmethod + def filter_optional_params(model_kwargs): + optional_params = {} + for key, value in model_kwargs.items(): + if key not in ['model_id', 'use_local', 'streaming']: + optional_params[key] = value + return optional_params + + +class BaseModelCredential(ABC): + + @abstractmethod + def is_valid(self, model_type: str, model_name, model: Dict[str, object], provider, raise_exception=True): + pass + + @abstractmethod + def encryption_dict(self, model_info: Dict[str, object]): + """ + :param model_info: 模型数据 + :return: 加密后数据 + """ + pass + + def get_model_params_setting_form(self, model_name): + """ + 模型参数设置表单 + :return: + """ + pass + + @staticmethod + def encryption(message: str): + """ + 加密敏感字段数据 加密方式是 如果密码是 1234567890 那么给前端则是 123******890 + :param message: + :return: + """ + max_pre_len = 8 + max_post_len = 4 + message_len = len(message) + pre_len = int(message_len / 5 * 2) + post_len = int(message_len / 5 * 1) + pre_str = "".join([message[index] for index in + 
range(0, max_pre_len if pre_len > max_pre_len else 1 if pre_len <= 0 else int(pre_len))]) + end_str = "".join( + [message[index] for index in + range(message_len - (int(post_len) if pre_len < max_post_len else max_post_len), message_len)]) + content = "***************" + return pre_str + content + end_str + + +class ModelTypeConst(Enum): + LLM = {'code': 'LLM', 'message': '大语言模型'} + EMBEDDING = {'code': 'EMBEDDING', 'message': '向量模型'} + STT = {'code': 'STT', 'message': '语音识别'} + TTS = {'code': 'TTS', 'message': '语音合成'} + RERANKER = {'code': 'RERANKER', 'message': '重排模型'} + + +class ModelInfo: + def __init__(self, name: str, desc: str, model_type: ModelTypeConst, model_credential: BaseModelCredential, + model_class: Type[MaxKBBaseModel], + **keywords): + self.name = name + self.desc = desc + self.model_type = model_type.name + self.model_credential = model_credential + self.model_class = model_class + if keywords is not None: + for key in keywords.keys(): + self.__setattr__(key, keywords.get(key)) + + def get_name(self): + """ + 获取模型名称 + :return: 模型名称 + """ + return self.name + + def get_desc(self): + """ + 获取模型描述 + :return: 模型描述 + """ + return self.desc + + def get_model_type(self): + return self.model_type + + def get_model_class(self): + return self.model_class + + def to_dict(self): + return reduce(lambda x, y: {**x, **y}, + [{attr: self.__getattribute__(attr)} for attr in vars(self) if + not attr.startswith("__") and not attr == 'model_credential' and not attr == 'model_class'], {}) + + +class ModelInfoManage: + def __init__(self): + self.model_dict = {} + self.model_list = [] + self.default_model_list = [] + self.default_model_dict = {} + + def append_model_info(self, model_info: ModelInfo): + self.model_list.append(model_info) + model_type_dict = self.model_dict.get(model_info.model_type) + if model_type_dict is None: + self.model_dict[model_info.model_type] = {model_info.name: model_info} + else: + model_type_dict[model_info.name] = model_info + + def 
append_default_model_info(self, model_info: ModelInfo): + self.default_model_list.append(model_info) + self.default_model_dict[model_info.model_type] = model_info + + def get_model_list(self): + return [model.to_dict() for model in self.model_list] + + def get_model_list_by_model_type(self, model_type): + return [model.to_dict() for model in self.model_list if model.model_type == model_type] + + def get_model_type_list(self): + return [{'key': _type.value.get('message'), 'value': _type.value.get('code')} for _type in ModelTypeConst if + len([model for model in self.model_list if model.model_type == _type.name]) > 0] + + def get_model_info(self, model_type, model_name) -> ModelInfo: + model_info = self.model_dict.get(model_type, {}).get(model_name, self.default_model_dict.get(model_type)) + if model_info is None: + raise AppApiException(500, '模型不支持') + return model_info + + class builder: + def __init__(self): + self.modelInfoManage = ModelInfoManage() + + def append_model_info(self, model_info: ModelInfo): + self.modelInfoManage.append_model_info(model_info) + return self + + def append_model_info_list(self, model_info_list: List[ModelInfo]): + for model_info in model_info_list: + self.modelInfoManage.append_model_info(model_info) + return self + + def append_default_model_info(self, model_info: ModelInfo): + self.modelInfoManage.append_default_model_info(model_info) + return self + + def build(self): + return self.modelInfoManage + + +class ModelProvideInfo: + def __init__(self, provider: str, name: str, icon: str): + self.provider = provider + + self.name = name + + self.icon = icon + + def to_dict(self): + return reduce(lambda x, y: {**x, **y}, + [{attr: self.__getattribute__(attr)} for attr in vars(self) if + not attr.startswith("__")], {}) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/constants/model_provider_constants.py b/src/MaxKB-1.7.2/apps/setting/models_provider/constants/model_provider_constants.py new file mode 100644 index 0000000..c471cea 
--- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/constants/model_provider_constants.py @@ -0,0 +1,49 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: model_provider_constants.py + @date:2023/11/2 14:55 + @desc: +""" +from enum import Enum + +from setting.models_provider.impl.aliyun_bai_lian_model_provider.aliyun_bai_lian_model_provider import \ + AliyunBaiLianModelProvider +from setting.models_provider.impl.aws_bedrock_model_provider.aws_bedrock_model_provider import BedrockModelProvider +from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider +from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider +from setting.models_provider.impl.gemini_model_provider.gemini_model_provider import GeminiModelProvider +from setting.models_provider.impl.kimi_model_provider.kimi_model_provider import KimiModelProvider +from setting.models_provider.impl.ollama_model_provider.ollama_model_provider import OllamaModelProvider +from setting.models_provider.impl.openai_model_provider.openai_model_provider import OpenAIModelProvider +from setting.models_provider.impl.qwen_model_provider.qwen_model_provider import QwenModelProvider +from setting.models_provider.impl.tencent_model_provider.tencent_model_provider import TencentModelProvider +from setting.models_provider.impl.vllm_model_provider.vllm_model_provider import VllmModelProvider +from setting.models_provider.impl.volcanic_engine_model_provider.volcanic_engine_model_provider import \ + VolcanicEngineModelProvider +from setting.models_provider.impl.wenxin_model_provider.wenxin_model_provider import WenxinModelProvider +from setting.models_provider.impl.xf_model_provider.xf_model_provider import XunFeiModelProvider +from setting.models_provider.impl.xinference_model_provider.xinference_model_provider import XinferenceModelProvider +from setting.models_provider.impl.zhipu_model_provider.zhipu_model_provider 
import ZhiPuModelProvider +from setting.models_provider.impl.local_model_provider.local_model_provider import LocalModelProvider + + +class ModelProvideConstants(Enum): + model_azure_provider = AzureModelProvider() + model_wenxin_provider = WenxinModelProvider() + model_ollama_provider = OllamaModelProvider() + model_openai_provider = OpenAIModelProvider() + model_kimi_provider = KimiModelProvider() + model_qwen_provider = QwenModelProvider() + model_zhipu_provider = ZhiPuModelProvider() + model_xf_provider = XunFeiModelProvider() + model_deepseek_provider = DeepSeekModelProvider() + model_gemini_provider = GeminiModelProvider() + model_volcanic_engine_provider = VolcanicEngineModelProvider() + model_tencent_provider = TencentModelProvider() + model_aws_bedrock_provider = BedrockModelProvider() + model_local_provider = LocalModelProvider() + model_xinference_provider = XinferenceModelProvider() + model_vllm_provider = VllmModelProvider() + aliyun_bai_lian_model_provider = AliyunBaiLianModelProvider() diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/__init__.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/__init__.py new file mode 100644 index 0000000..3c10c55 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/9/9 17:42 + @desc: +""" diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py new file mode 100644 index 0000000..f3fd75a --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/aliyun_bai_lian_model_provider.py @@ -0,0 +1,60 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: 
aliyun_bai_lian_model_provider.py + @date:2024/9/9 17:43 + @desc: +""" +import os + +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ + ModelInfoManage +from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.embedding import \ + AliyunBaiLianEmbeddingCredential +from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.reranker import \ + AliyunBaiLianRerankerCredential +from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.stt import AliyunBaiLianSTTModelCredential +from setting.models_provider.impl.aliyun_bai_lian_model_provider.credential.tts import AliyunBaiLianTTSModelCredential +from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding +from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker +from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.stt import AliyunBaiLianSpeechToText +from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.tts import AliyunBaiLianTextToSpeech +from smartdoc.conf import PROJECT_DIR + +aliyun_bai_lian_model_credential = AliyunBaiLianRerankerCredential() +aliyun_bai_lian_tts_model_credential = AliyunBaiLianTTSModelCredential() +aliyun_bai_lian_stt_model_credential = AliyunBaiLianSTTModelCredential() +aliyun_bai_lian_embedding_model_credential = AliyunBaiLianEmbeddingCredential() + +model_info_list = [ModelInfo('gte-rerank', + '阿里巴巴通义实验室开发的GTE-Rerank文本排序系列模型,开发者可以通过LlamaIndex框架进行集成高质量文本检索、排序。', + ModelTypeConst.RERANKER, aliyun_bai_lian_model_credential, AliyunBaiLianReranker), + ModelInfo('paraformer-realtime-v2', + '中文(含粤语等各种方言)、英文、日语、韩语支持多个语种自由切换', + ModelTypeConst.STT, aliyun_bai_lian_stt_model_credential, AliyunBaiLianSpeechToText), + ModelInfo('cosyvoice-v1', + 
'CosyVoice基于新一代生成式语音大模型,能根据上下文预测情绪、语调、韵律等,具有更好的拟人效果', + ModelTypeConst.TTS, aliyun_bai_lian_tts_model_credential, AliyunBaiLianTextToSpeech), + ModelInfo('text-embedding-v1', + '通用文本向量,是通义实验室基于LLM底座的多语言文本统一向量模型,面向全球多个主流语种,提供高水准的向量服务,帮助开发者将文本数据快速转换为高质量的向量数据。', + ModelTypeConst.EMBEDDING, aliyun_bai_lian_embedding_model_credential, + AliyunBaiLianEmbedding), + ] + +model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( + model_info_list[1]).append_default_model_info(model_info_list[2]).append_default_model_info( + model_info_list[3]).build() + + +class AliyunBaiLianModelProvider(IModelProvider): + + def get_model_info_manage(self): + return model_info_manage + + def get_model_provide_info(self): + return ModelProvideInfo(provider='aliyun_bai_lian_model_provider', name='阿里云百炼', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'aliyun_bai_lian_model_provider', + 'icon', + 'aliyun_bai_lian_icon_svg'))) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py new file mode 100644 index 0000000..7884e51 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/embedding.py @@ -0,0 +1,46 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/10/16 17:01 + @desc: +""" +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import ValidCode, BaseModelCredential +from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.embedding import AliyunBaiLianEmbedding + + +class AliyunBaiLianEmbeddingCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: 
str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['dashscope_api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model: AliyunBaiLianEmbedding = provider.get_model(model_type, model_name, model_credential) + model.embed_query('你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'dashscope_api_key': super().encryption(model.get('dashscope_api_key', ''))} + + dashscope_api_key = forms.PasswordInputField('API Key', required=True) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py new file mode 100644 index 0000000..d8d2f3c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/reranker.py @@ -0,0 +1,47 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: reranker.py + @date:2024/9/9 17:51 + @desc: +""" +from typing import Dict + +from langchain_core.documents import Document + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.aliyun_bai_lian_model_provider.model.reranker import AliyunBaiLianReranker + + +class 
AliyunBaiLianRerankerCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + if not model_type == 'RERANKER': + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['dashscope_api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model: AliyunBaiLianReranker = provider.get_model(model_type, model_name, model_credential) + model.compress_documents([Document(page_content='你好')], '你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'dashscope_api_key': super().encryption(model.get('dashscope_api_key', ''))} + + dashscope_api_key = forms.PasswordInputField('API Key', required=True) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py new file mode 100644 index 0000000..5c9290b --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/stt.py @@ -0,0 +1,42 @@ +# coding=utf-8 + +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class AliyunBaiLianSTTModelCredential(BaseForm, BaseModelCredential): + api_key = forms.PasswordInputField("API Key", required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + 
raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.check_auth() + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + pass diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py new file mode 100644 index 0000000..640ba7a --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/credential/tts.py @@ -0,0 +1,76 @@ +# coding=utf-8 + +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class AliyunBaiLianTTSModelGeneralParams(BaseForm): + voice = forms.SingleSelect( + TooltipLabel('音色', '中文音色可支持中英文混合场景'), + required=True, default_value='longxiaochun', + text_field='value', + value_field='value', + option_list=[ + {'text': '龙小淳', 'value': 'longxiaochun'}, + {'text': '龙小夏', 'value': 'longxiaoxia'}, + {'text': '龙小诚', 'value': 'longxiaocheng'}, + {'text': '龙小白', 'value': 'longxiaobai'}, + 
{'text': '龙老铁', 'value': 'longlaotie'}, + {'text': '龙书', 'value': 'longshu'}, + {'text': '龙硕', 'value': 'longshuo'}, + {'text': '龙婧', 'value': 'longjing'}, + {'text': '龙妙', 'value': 'longmiao'}, + {'text': '龙悦', 'value': 'longyue'}, + {'text': '龙媛', 'value': 'longyuan'}, + {'text': '龙飞', 'value': 'longfei'}, + {'text': '龙杰力豆', 'value': 'longjielidou'}, + {'text': '龙彤', 'value': 'longtong'}, + {'text': '龙祥', 'value': 'longxiang'}, + {'text': 'Stella', 'value': 'loongstella'}, + {'text': 'Bella', 'value': 'loongbella'}, + ]) + speech_rate = forms.SliderField( + TooltipLabel('语速', '[0.5,2],默认为1,通常保留一位小数即可'), + required=True, default_value=1, + _min=0.5, + _max=2, + _step=0.1, + precision=1) + + +class AliyunBaiLianTTSModelCredential(BaseForm, BaseModelCredential): + api_key = forms.PasswordInputField("API Key", required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.check_auth() + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + return AliyunBaiLianTTSModelGeneralParams() diff --git 
a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py new file mode 100644 index 0000000..e209e77 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/embedding.py @@ -0,0 +1,54 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/10/16 16:34 + @desc: +""" +from typing import Dict, List + +from langchain_community.embeddings import DashScopeEmbeddings +from langchain_community.embeddings.dashscope import embed_with_retry + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class AliyunBaiLianEmbedding(MaxKBBaseModel, DashScopeEmbeddings): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return AliyunBaiLianEmbedding( + model=model_name, + dashscope_api_key=model_credential.get('dashscope_api_key') + ) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Call out to DashScope's embedding endpoint for embedding search docs. + + Args: + texts: The list of texts to embed. + chunk_size: The chunk size of embeddings. If None, will use the chunk size + specified by the class. + + Returns: + List of embeddings, one for each text. + """ + embeddings = embed_with_retry( + self, input=texts, text_type="document", model=self.model + ) + embedding_list = [item["embedding"] for item in embeddings] + return embedding_list + + def embed_query(self, text: str) -> List[float]: + """Call out to DashScope's embedding endpoint for embedding query text. + + Args: + text: The text to embed. + + Returns: + Embedding for the text. 
+ """ + embedding = embed_with_retry( + self, input=[text], text_type="document", model=self.model + )[0]["embedding"] + return embedding diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/iat_mp3_16k.mp3 b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/iat_mp3_16k.mp3 new file mode 100644 index 0000000..75e744c Binary files /dev/null and b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/iat_mp3_16k.mp3 differ diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/reranker.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/reranker.py new file mode 100644 index 0000000..5c9bea4 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/reranker.py @@ -0,0 +1,20 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: reranker.py.py + @date:2024/9/2 16:42 + @desc: +""" +from typing import Dict + +from langchain_community.document_compressors import DashScopeRerank + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class AliyunBaiLianReranker(MaxKBBaseModel, DashScopeRerank): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return AliyunBaiLianReranker(model=model_name, dashscope_api_key=model_credential.get('dashscope_api_key'), + top_n=model_kwargs.get('top_n', 3)) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py new file mode 100644 index 0000000..89ebd50 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/stt.py @@ -0,0 +1,63 @@ +import os +import tempfile +from typing import Dict + +import dashscope 
+from dashscope.audio.asr import (Recognition) + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_stt import BaseSpeechToText + + +class AliyunBaiLianSpeechToText(MaxKBBaseModel, BaseSpeechToText): + api_key: str + model: str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.api_key = kwargs.get('api_key') + self.model = kwargs.get('model') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + return AliyunBaiLianSpeechToText( + model=model_name, + api_key=model_credential.get('api_key'), + **optional_params, + ) + + def check_auth(self): + cwd = os.path.dirname(os.path.abspath(__file__)) + with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f: + self.speech_to_text(f) + + def speech_to_text(self, audio_file): + dashscope.api_key = self.api_key + recognition = Recognition(model=self.model, + format='mp3', + sample_rate=16000, + callback=None) + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + # 将上传的文件保存到临时文件中 + temp_file.write(audio_file.read()) + # 获取临时文件的路径 + temp_file_path = temp_file.name + + try: + # 识别临时文件 + result = recognition.call(temp_file_path) + text = '' + if result.status_code == 200: + for sentence in result.get_sentence(): + text += sentence['text'] + return text + else: + raise Exception('Error: ', result.message) + finally: + # 删除临时文件 + os.remove(temp_file_path) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/tts.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/tts.py new file mode 100644 index 
0000000..1dbee97 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aliyun_bai_lian_model_provider/model/tts.py @@ -0,0 +1,47 @@ +from typing import Dict + +import dashscope +from dashscope.audio.tts_v2 import * + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_tts import BaseTextToSpeech + + +class AliyunBaiLianTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): + api_key: str + model: str + params: dict + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.api_key = kwargs.get('api_key') + self.model = kwargs.get('model') + self.params = kwargs.get('params') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {'params': {'voice': 'longxiaochun', 'speech_rate': 1.0}} + for key, value in model_kwargs.items(): + if key not in ['model_id', 'use_local', 'streaming']: + optional_params['params'][key] = value + + return AliyunBaiLianTextToSpeech( + model=model_name, + api_key=model_credential.get('api_key'), + **optional_params, + ) + + def check_auth(self): + self.text_to_speech('你好') + + def text_to_speech(self, text): + dashscope.api_key = self.api_key + synthesizer = SpeechSynthesizer(model=self.model, **self.params) + audio = synthesizer.call(text) + if type(audio) == str: + print(audio) + raise Exception(audio) + return audio + + def is_cache_model(self): + return False diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aws_bedrock_model_provider/__init__.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aws_bedrock_model_provider/__init__.py new file mode 100644 index 0000000..8cb7f45 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aws_bedrock_model_provider/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py 
def _initialize_model_info():
    """Assemble the ``ModelInfoManage`` catalogue for the Amazon Bedrock provider.

    Chat models and embedding models are declared as ``(model id, description)``
    tables; each row is turned into a ``ModelInfo`` through ``_create_model_info``.
    The first row of each table is registered as that table's default model.

    :return: the built ``ModelInfoManage`` instance
    """
    # Chat models: all share the LLM type, credential form and model class.
    llm_catalog = [
        ('anthropic.claude-v2:1',
         'Claude 2 的更新,采用双倍的上下文窗口,并在长文档和 RAG 上下文中提高可靠性、幻觉率和循证准确性。'),
        ('anthropic.claude-v2',
         'Anthropic 功能强大的模型,可处理各种任务,从复杂的对话和创意内容生成到详细的指令服从。'),
        ('anthropic.claude-3-haiku-20240307-v1:0',
         'Claude 3 Haiku 是 Anthropic 最快速、最紧凑的模型,具有近乎即时的响应能力。该模型可以快速回答简单的查询和请求。客户将能够构建模仿人类交互的无缝人工智能体验。 Claude 3 Haiku 可以处理图像和返回文本输出,并且提供 200K 上下文窗口。'),
        ('anthropic.claude-3-sonnet-20240229-v1:0',
         'Anthropic 推出的 Claude 3 Sonnet 模型在智能和速度之间取得理想的平衡,尤其是在处理企业工作负载方面。该模型提供最大的效用,同时价格低于竞争产品,并且其经过精心设计,是大规模部署人工智能的可靠选择。'),
        ('anthropic.claude-3-5-sonnet-20240620-v1:0',
         'Claude 3.5 Sonnet提高了智能的行业标准,在广泛的评估中超越了竞争对手的型号和Claude 3 Opus,具有我们中端型号的速度和成本效益。'),
        ('anthropic.claude-instant-v1',
         '一种更快速、更实惠但仍然非常强大的模型,它可以处理一系列任务,包括随意对话、文本分析、摘要和文档问题回答。'),
        ('amazon.titan-text-premier-v1:0',
         'Titan Text Premier 是 Titan Text 系列中功能强大且先进的型号,旨在为各种企业应用程序提供卓越的性能。凭借其尖端功能,它提供了更高的准确性和出色的结果,使其成为寻求一流文本处理解决方案的组织的绝佳选择。'),
        ('amazon.titan-text-lite-v1',
         'Amazon Titan Text Lite 是一种轻量级的高效模型,非常适合英语任务的微调,包括摘要和文案写作等,在这种场景下,客户需要更小、更经济高效且高度可定制的模型'),
        ('amazon.titan-text-express-v1',
         'Amazon Titan Text Express 的上下文长度长达 8000 个令牌,因而非常适合各种高级常规语言任务,例如开放式文本生成和对话式聊天,以及检索增强生成(RAG)中的支持。在发布时,该模型针对英语进行了优化,但也支持其他语言。'),
        ('mistral.mistral-7b-instruct-v0:2',
         '7B 密集型转换器,可快速部署,易于定制。体积虽小,但功能强大,适用于各种用例。支持英语和代码,以及 32k 的上下文窗口。'),
        ('mistral.mistral-large-2402-v1:0',
         '先进的 Mistral AI 大型语言模型,能够处理任何语言任务,包括复杂的多语言推理、文本理解、转换和代码生成。'),
        ('meta.llama3-70b-instruct-v1:0',
         '非常适合内容创作、会话式人工智能、语言理解、研发和企业应用'),
        ('meta.llama3-8b-instruct-v1:0',
         '非常适合有限的计算能力和资源、边缘设备和更快的训练时间。'),
    ]
    # Embedding models.
    embedding_catalog = [
        ('amazon.titan-embed-text-v1',
         'Titan Embed Text 是 Amazon Titan Embed 系列中最大的嵌入模型,可以处理各种文本嵌入任务,如文本分类、文本相似度计算等。'),
    ]

    model_info_list = [
        _create_model_info(model_id, description, ModelTypeConst.LLM,
                           BedrockLLMModelCredential, BedrockModel)
        for model_id, description in llm_catalog
    ]
    embedded_model_info_list = [
        _create_model_info(model_id, description, ModelTypeConst.EMBEDDING,
                           BedrockEmbeddingCredential, BedrockEmbeddingModel)
        for model_id, description in embedding_catalog
    ]

    # Same registration order as before: LLM list, LLM default, embedding list, embedding default.
    builder = ModelInfoManage.builder()
    builder = builder.append_model_info_list(model_info_list)
    builder = builder.append_default_model_info(model_info_list[0])
    builder = builder.append_model_info_list(embedded_model_info_list)
    builder = builder.append_default_model_info(embedded_model_info_list[0])
    return builder.build()
setting.models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel + + +class BedrockEmbeddingCredential(BaseForm, BaseModelCredential): + + @staticmethod + def _update_aws_credentials(profile_name, access_key_id, secret_access_key): + credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials") + os.makedirs(os.path.dirname(credentials_path), exist_ok=True) + + content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else '' + pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*' + content = re.sub(pattern, '', content, flags=re.DOTALL) + + if not re.search(rf'\[{profile_name}\]', content): + content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n" + + with open(credentials_path, 'w') as file: + file.write(content) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(mt.get('value') == model_type for mt in model_type_list): + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + return False + + required_keys = ['region_name', 'access_key_id', 'secret_access_key'] + if not all(key in model_credential for key in required_keys): + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'以下字段为必填字段: {", ".join(required_keys)}') + return False + + try: + self._update_aws_credentials('aws-profile', model_credential['access_key_id'], + model_credential['secret_access_key']) + model_credential['credentials_profile_name'] = 'aws-profile' + model: BedrockEmbeddingModel = provider.get_model(model_type, model_name, model_credential) + aa = model.embed_query('你好') + print(aa) + except AppApiException: + raise + except Exception as e: + if raise_exception: + raise 
class BedrockLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Amazon Bedrock chat models.

    Validation writes the supplied key pair into the ``aws-profile`` section of
    ``~/.aws/credentials`` and then sends a one-message prompt to the model.
    """

    @staticmethod
    def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
        """Create or replace one profile section in ``~/.aws/credentials``.

        Any existing ``[profile_name]`` section is removed in full and a fresh
        section with the given key pair is appended; other sections are kept.
        """
        credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
        os.makedirs(os.path.dirname(credentials_path), exist_ok=True)

        content = ''
        if os.path.exists(credentials_path):
            # Context manager so the handle is closed even if the read fails.
            with open(credentials_path, 'r') as file:
                content = file.read()

        # Drop only the target section, which runs from its header to the next
        # "[...]" header or the end of file. The previous greedy DOTALL regex
        # removed every later profile as well.
        header = f'[{profile_name}]'
        kept_lines = []
        inside_target = False
        for line in content.splitlines():
            stripped = line.strip()
            if stripped.startswith('[') and stripped.endswith(']'):
                inside_target = stripped == header
            if not inside_target:
                kept_lines.append(line)

        remaining = '\n'.join(kept_lines).strip('\n')
        if remaining:
            remaining += '\n'
        content = (f"{remaining}\n[{profile_name}]\n"
                   f"aws_access_key_id = {access_key_id}\n"
                   f"aws_secret_access_key = {secret_access_key}\n")

        with open(credentials_path, 'w') as file:
            file.write(content)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that the credential can actually drive the named model.

        :param model_type: must be one of the values in ``provider.get_model_type_list()``
        :param model_name: Bedrock model id handed to ``provider.get_model``
        :param model_credential: must contain ``region_name``, ``access_key_id`` and
            ``secret_access_key``; ``credentials_profile_name`` is added to it in place
        :param provider: object supplying ``get_model_type_list`` and ``get_model``
        :param raise_exception: raise ``AppApiException`` instead of returning ``False``
        :return: ``True`` when the test prompt succeeds, otherwise ``False``
        :raises AppApiException: on failure when ``raise_exception`` is true, or when
            the provider itself raises one
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False

        required_keys = ['region_name', 'access_key_id', 'secret_access_key']
        if not all(key in model_credential for key in required_keys):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'以下字段为必填字段: {", ".join(required_keys)}')
            return False

        try:
            # NOTE(review): every Bedrock model shares the single 'aws-profile'
            # section, so validating a second key pair replaces the first. Confirm
            # this is intended.
            self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
                                         model_credential['secret_access_key'])
            # The model classes read this key, so it is written back into the caller's dict.
            model_credential['credentials_profile_name'] = 'aws-profile'
            model = provider.get_model(model_type, model_name, model_credential)
            model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``secret_access_key`` passed through ``encryption``."""
        return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}

    region_name = forms.TextInputField('Region Name', required=True)
    access_key_id = forms.TextInputField('Access Key ID', required=True)
    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form (temperature, max tokens) for any Bedrock chat model."""
        return BedrockLLMModelParams()
a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py new file mode 100644 index 0000000..d08f62c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py @@ -0,0 +1,56 @@ +from langchain_community.embeddings import BedrockEmbeddings + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from typing import Dict, List + + +class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings): + def __init__(self, model_id: str, region_name: str, credentials_profile_name: str, + **kwargs): + super().__init__(model_id=model_id, region_name=region_name, + credentials_profile_name=credentials_profile_name, **kwargs) + + @classmethod + def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str], + **model_kwargs) -> 'BedrockModel': + return cls( + model_id=model_name, + region_name=model_credential['region_name'], + credentials_profile_name=model_credential['credentials_profile_name'], + ) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Compute doc embeddings using a Bedrock model. + + Args: + texts: The list of texts to embed + + Returns: + List of embeddings, one for each text. + """ + results = [] + for text in texts: + response = self._embedding_func(text) + + if self.normalize: + response = self._normalize_vector(response) + + results.append(response) + + return results + + def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using a Bedrock model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. 
def get_max_tokens_keyword(model_name):
    """
    Return the keyword under which a Bedrock model expects its output-token limit.

    :param model_name: model id string
    :return: ``'maxTokens'``, ``'maxTokenCount'`` or ``'max_new_tokens'`` for the
        model ids listed below, ``'max_tokens'`` for every other value
    """
    # (keyword, model ids that use it); checked in order, first match wins.
    keyword_groups = (
        ('maxTokens', ("ai21.j2-ultra-v1", "ai21.j2-mid-v1")),
        ('maxTokenCount', ("amazon.titan-text-lite-v1", "amazon.titan-text-express-v1")),
        ('max_new_tokens', (
            "us.meta.llama3-2-1b-instruct-v1:0", "us.meta.llama3-2-3b-instruct-v1:0",
            "us.meta.llama3-2-11b-instruct-v1:0", "us.meta.llama3-2-90b-instruct-v1:0")),
    )
    for keyword, model_ids in keyword_groups:
        if model_name in model_ids:
            return keyword
    # Everything else, including the older Anthropic ids, uses the default keyword.
    return 'max_tokens'
MaxKBBaseModel.filter_optional_params(model_kwargs) + + return cls( + model_id=model_name, + region_name=model_credential['region_name'], + credentials_profile_name=model_credential['credentials_profile_name'], + streaming=model_kwargs.pop('streaming', True), + model_kwargs=optional_params + ) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/azure_model_provider/__init__.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/azure_model_provider/__init__.py new file mode 100644 index 0000000..53b7001 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/azure_model_provider/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2023/10/31 17:16 + @desc: +""" diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py new file mode 100644 index 0000000..8b95dfe --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/azure_model_provider/azure_model_provider.py @@ -0,0 +1,36 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: azure_model_provider.py + @date:2023/10/31 16:19 + @desc: +""" +import os + +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \ + ModelTypeConst, ModelInfoManage +from setting.models_provider.impl.azure_model_provider.credential.llm import AzureLLMModelCredential +from setting.models_provider.impl.azure_model_provider.model.azure_chat_model import AzureChatModel +from smartdoc.conf import PROJECT_DIR + +base_azure_llm_model_credential = AzureLLMModelCredential() + +default_model_info = ModelInfo('Azure OpenAI', '具体的基础模型由部署名决定', ModelTypeConst.LLM, + base_azure_llm_model_credential, AzureChatModel, api_version='2024-02-15-preview' + ) + +model_info_manage = 
class AzureLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Azure OpenAI chat deployments."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that the credential can drive the deployment by sending one test prompt.

        :param model_type: must be one of the values in ``provider.get_model_type_list()``
        :param model_name: handed to ``provider.get_model``
        :param model_credential: must contain ``api_base``, ``api_key``,
            ``deployment_name`` and ``api_version``
        :param provider: object supplying ``get_model_type_list`` and ``get_model``
        :param raise_exception: raise ``AppApiException`` instead of returning ``False``
        :return: ``True`` when the test prompt succeeds, otherwise ``False``
        :raises AppApiException: on failure when ``raise_exception`` is true, or when
            the provider itself raises one
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            # Honour raise_exception here as every other check in this method does.
            # Previously this branch raised unconditionally.
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False

        for key in ['api_base', 'api_key', 'deployment_name', 'api_version']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            # Application-level errors pass through regardless of raise_exception.
            raise
        except Exception:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, '校验失败,请检查参数是否正确')
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``api_key`` passed through ``encryption``."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_version = forms.TextInputField("API 版本 (api_version)", required=True)

    api_base = forms.TextInputField('API 域名 (azure_endpoint)', required=True)

    api_key = forms.PasswordInputField("API Key (api_key)", required=True)

    deployment_name = forms.TextInputField("部署名 (deployment_name)", required=True)

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form (temperature, max tokens) for any Azure deployment."""
        return AzureLLMModelParams()
class AzureChatModel(MaxKBBaseModel, AzureChatOpenAI):
    """Azure OpenAI chat model with a tokenizer fallback for token counting."""

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create a streaming chat model from the stored Azure credential.

        ``model_kwargs`` is reduced through ``MaxKBBaseModel.filter_optional_params``
        and the result is forwarded to the constructor; ``api_version`` falls back
        to ``'2024-02-15-preview'`` when the credential does not carry one.
        """
        extra_params = MaxKBBaseModel.filter_optional_params(model_kwargs)

        endpoint = model_credential.get('api_base')
        api_version = model_credential.get('api_version', '2024-02-15-preview')
        deployment = model_credential.get('deployment_name')
        api_key = model_credential.get('api_key')

        return AzureChatModel(
            azure_endpoint=endpoint,
            openai_api_version=api_version,
            deployment_name=deployment,
            openai_api_key=api_key,
            openai_api_type="azure",
            **extra_params,
            streaming=True,
        )

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Count tokens for ``messages``, using the shared tokenizer if the parent call fails."""
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception:
            fallback_tokenizer = TokenizerManage.get_tokenizer()
            return sum(
                len(fallback_tokenizer.encode(get_buffer_string([message])))
                for message in messages
            )

    def get_num_tokens(self, text: str) -> int:
        """Count tokens for ``text``, using the shared tokenizer if the parent call fails."""
        try:
            return super().get_num_tokens(text)
        except Exception:
            fallback_tokenizer = TokenizerManage.get_tokenizer()
            return len(fallback_tokenizer.encode(text))
class BaseChatOpenAI(ChatOpenAI):
    """``ChatOpenAI`` variant that remembers the usage data of its most recent call.

    ``_stream`` and ``invoke`` store usage information on the instance; the token
    counting methods return the stored figures when any are present and only
    fall back to real counting when nothing has been recorded yet.
    """
    # Usage data captured by the latest ``_stream`` / ``invoke`` call; empty until then.
    usage_metadata: dict = {}

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        """Return the usage data recorded by the most recent call."""
        return self.usage_metadata

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the input token count.

        With recorded usage data this is its ``input_tokens`` value (0 when the key
        is absent) and ``messages`` is not inspected. Without it, the parent
        implementation is tried first and the shared tokenizer is the fallback.
        """
        if self.usage_metadata is None or self.usage_metadata == {}:
            try:
                return super().get_num_tokens_from_messages(messages)
            except Exception as e:
                tokenizer = TokenizerManage.get_tokenizer()
                return sum([len(tokenizer.encode(get_buffer_string([m]))) for m in messages])
        return self.usage_metadata.get('input_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        """Return the output token count.

        With recorded usage data this is its ``output_tokens`` value (0 when the key
        is absent) and ``text`` is not inspected. Without it, the parent
        implementation is tried first and the shared tokenizer is the fallback.
        """
        if self.usage_metadata is None or self.usage_metadata == {}:
            try:
                return super().get_num_tokens(text)
            except Exception as e:
                tokenizer = TokenizerManage.get_tokenizer()
                return len(tokenizer.encode(text))
        return self.get_last_generation_info().get('output_tokens', 0)

    def _stream(
            self, *args: Any, stream_usage: Optional[bool] = None, **kwargs: Any
    ) -> Iterator[ChatGenerationChunk]:
        """Stream chunks from the parent, capturing usage data from any chunk that carries it."""
        # Always request streaming with usage included, whatever the caller passed.
        kwargs["stream"] = True
        kwargs["stream_options"] = {"include_usage": True}
        for chunk in super()._stream(*args, stream_usage=stream_usage, **kwargs):
            if chunk.message.usage_metadata is not None:
                self.usage_metadata = chunk.message.usage_metadata
            yield chunk

    def invoke(
            self,
            input: LanguageModelInput,
            config: Optional[RunnableConfig] = None,
            *,
            stop: Optional[List[str]] = None,
            **kwargs: Any,
    ) -> BaseMessage:
        """Run one generation via ``generate_prompt`` and record its usage data.

        The stored value is ``response_metadata['token_usage']`` when that key
        exists, otherwise the message's ``usage_metadata``.
        """
        config = ensure_config(config)
        chat_result = cast(
            ChatGeneration,
            self.generate_prompt(
                [self._convert_input(input)],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                run_id=config.pop("run_id", None),
                **kwargs,
            ).generations[0][0],
        ).message
        # NOTE(review): the token counters above read 'input_tokens'/'output_tokens';
        # confirm 'token_usage' uses those key names, otherwise they will report 0.
        self.usage_metadata = chat_result.response_metadata[
            'token_usage'] if 'token_usage' in chat_result.response_metadata else chat_result.usage_metadata
        return chat_result
class DeepSeekLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for DeepSeek chat models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that the credential can drive the model by sending one test prompt.

        :param model_type: must be one of the values in ``provider.get_model_type_list()``
        :param model_name: handed to ``provider.get_model``
        :param model_credential: must contain ``api_key``
        :param provider: object supplying ``get_model_type_list`` and ``get_model``
        :param raise_exception: raise ``AppApiException`` instead of returning ``False``
        :return: ``True`` when the test prompt succeeds, otherwise ``False``
        :raises AppApiException: on failure when ``raise_exception`` is true, or when
            the provider itself raises one
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            # Honour raise_exception here as every other check in this method does.
            # Previously this branch raised unconditionally.
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False

        for key in ['api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                return False

        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            # Application-level errors pass through regardless of raise_exception.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with ``api_key`` passed through ``encryption``."""
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    api_key = forms.PasswordInputField('API Key', required=True)

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form (temperature, max tokens) for any DeepSeek model."""
        return DeepSeekLLMModelParams()
class DeepSeekChatModel(MaxKBBaseModel, BaseChatOpenAI):
    """DeepSeek chat model, built on ``BaseChatOpenAI`` with the DeepSeek endpoint as its API base."""

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create a chat model bound to ``https://api.deepseek.com``.

        ``model_kwargs`` is reduced through ``MaxKBBaseModel.filter_optional_params``
        and the result is forwarded to the constructor; the API key comes from
        ``model_credential['api_key']``.
        """
        extra_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        api_key = model_credential.get('api_key')

        return DeepSeekChatModel(
            model=model_name,
            openai_api_base='https://api.deepseek.com',
            openai_api_key=api_key,
            **extra_params
        )
/dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/gemini_model_provider/credential/llm.py @@ -0,0 +1,69 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/11 17:57 + @desc: +""" +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class GeminiLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + +class GeminiLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + res = model.invoke([HumanMessage(content='你好')]) + print(res) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} 
+ + api_key = forms.PasswordInputField('API Key', required=True) + + def get_model_params_setting_form(self, model_name): + return GeminiLLMModelParams() diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/gemini_model_provider/gemini_model_provider.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/gemini_model_provider/gemini_model_provider.py new file mode 100644 index 0000000..b6dd442 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/gemini_model_provider/gemini_model_provider.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :MaxKB +@File :gemini_model_provider.py +@Author :Brian Yang +@Date :5/13/24 7:47 AM +""" +import os + +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ + ModelInfoManage +from setting.models_provider.impl.gemini_model_provider.credential.llm import GeminiLLMModelCredential +from setting.models_provider.impl.gemini_model_provider.model.llm import GeminiChatModel +from smartdoc.conf import PROJECT_DIR + +gemini_llm_model_credential = GeminiLLMModelCredential() + +gemini_1_pro = ModelInfo('gemini-1.0-pro', '最新的Gemini 1.0 Pro模型,随Google更新而更新', + ModelTypeConst.LLM, + gemini_llm_model_credential, + GeminiChatModel) + +gemini_1_pro_vision = ModelInfo('gemini-1.0-pro-vision', '最新的Gemini 1.0 Pro Vision模型,随Google更新而更新', + ModelTypeConst.LLM, + gemini_llm_model_credential, + GeminiChatModel) + +model_info_manage = ModelInfoManage.builder().append_model_info(gemini_1_pro).append_model_info( + gemini_1_pro_vision).append_default_model_info(gemini_1_pro).build() + + +class GeminiModelProvider(IModelProvider): + + def get_model_info_manage(self): + return model_info_manage + + def get_model_provide_info(self): + return ModelProvideInfo(provider='model_gemini_provider', name='Gemini', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 
'models_provider', 'impl', 'gemini_model_provider', 'icon', + 'gemini_icon_svg'))) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py new file mode 100644 index 0000000..68d5e11 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/gemini_model_provider/model/llm.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- +""" +@Project :MaxKB +@File :llm.py +@Author :Brian Yang +@Date :5/13/24 7:40 AM +""" +from typing import List, Dict, Optional, Sequence, Union, Any, Iterator, cast + +from google.ai.generativelanguage_v1 import GenerateContentResponse +from google.generativeai.responder import ToolDict +from google.generativeai.types import FunctionDeclarationType, SafetySettingDict +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.messages import BaseMessage, get_buffer_string +from langchain_core.outputs import ChatGenerationChunk +from langchain_google_genai import ChatGoogleGenerativeAI +from langchain_google_genai._function_utils import _ToolConfigDict +from langchain_google_genai.chat_models import _chat_with_retry, _response_to_result +from google.generativeai.types import Tool as GoogleTool +from common.config.tokenizer_manage_config import TokenizerManage +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class GeminiChatModel(MaxKBBaseModel, ChatGoogleGenerativeAI): + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + + gemini_chat = GeminiChatModel( + model=model_name, + google_api_key=model_credential.get('api_key'), + **optional_params + ) + return gemini_chat + + def get_last_generation_info(self) -> Optional[Dict[str, Any]]: + return 
self.__dict__.get('_last_generation_info') + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + return self.get_last_generation_info().get('input_tokens', 0) + + def get_num_tokens(self, text: str) -> int: + return self.get_last_generation_info().get('output_tokens', 0) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + *, + tools: Optional[Sequence[Union[ToolDict, GoogleTool]]] = None, + functions: Optional[Sequence[FunctionDeclarationType]] = None, + safety_settings: Optional[SafetySettingDict] = None, + tool_config: Optional[Union[Dict, _ToolConfigDict]] = None, + generation_config: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + request = self._prepare_request( + messages, + stop=stop, + tools=tools, + functions=functions, + safety_settings=safety_settings, + tool_config=tool_config, + generation_config=generation_config, + ) + response: GenerateContentResponse = _chat_with_retry( + request=request, + generation_method=self.client.stream_generate_content, + **kwargs, + metadata=self.default_metadata, + ) + for chunk in response: + _chat_result = _response_to_result(chunk, stream=True) + gen = cast(ChatGenerationChunk, _chat_result.generations[0]) + if gen.message: + token_usage = gen.message.usage_metadata + self.__dict__.setdefault('_last_generation_info', {}).update(token_usage) + if run_manager: + run_manager.on_llm_new_token(gen.text) + yield gen diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/kimi_model_provider/__init__.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/kimi_model_provider/__init__.py new file mode 100644 index 0000000..53b7001 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/kimi_model_provider/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2023/10/31 17:16 + @desc: 
+""" diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py new file mode 100644 index 0000000..a6d06a8 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/kimi_model_provider/credential/llm.py @@ -0,0 +1,69 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/11 18:06 + @desc: +""" +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class KimiLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.3, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=1024, + _min=1, + _max=100000, + _step=1, + precision=0) + + +class KimiLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise 
AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def get_model_params_setting_form(self, model_name): + return KimiLLMModelParams() diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/kimi_model_provider/kimi_model_provider.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/kimi_model_provider/kimi_model_provider.py new file mode 100644 index 0000000..1347df4 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/kimi_model_provider/kimi_model_provider.py @@ -0,0 +1,42 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: kimi_model_provider.py + @date:2024/3/28 16:26 + @desc: +""" +import os + +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \ + ModelTypeConst, ModelInfoManage +from setting.models_provider.impl.kimi_model_provider.credential.llm import KimiLLMModelCredential +from setting.models_provider.impl.kimi_model_provider.model.llm import KimiChatModel +from smartdoc.conf import PROJECT_DIR + +kimi_llm_model_credential = KimiLLMModelCredential() + +moonshot_v1_8k = ModelInfo('moonshot-v1-8k', '', ModelTypeConst.LLM, kimi_llm_model_credential, + KimiChatModel) +moonshot_v1_32k = ModelInfo('moonshot-v1-32k', '', ModelTypeConst.LLM, kimi_llm_model_credential, + KimiChatModel) +moonshot_v1_128k = ModelInfo('moonshot-v1-128k', '', ModelTypeConst.LLM, kimi_llm_model_credential, + KimiChatModel) + +model_info_manage = ModelInfoManage.builder().append_model_info(moonshot_v1_8k).append_model_info( + moonshot_v1_32k).append_default_model_info(moonshot_v1_128k).append_default_model_info(moonshot_v1_8k).build() 
+ + +class KimiModelProvider(IModelProvider): + + def get_model_info_manage(self): + return model_info_manage + + def get_dialogue_number(self): + return 3 + + def get_model_provide_info(self): + return ModelProvideInfo(provider='model_kimi_provider', name='Kimi', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'kimi_model_provider', 'icon', + 'kimi_icon_svg'))) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py new file mode 100644 index 0000000..c5f7b62 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/kimi_model_provider/model/llm.py @@ -0,0 +1,31 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: llm.py + @date:2023/11/10 17:45 + @desc: +""" +from typing import List, Dict + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI + + +class KimiChatModel(MaxKBBaseModel, BaseChatOpenAI): + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + + kimi_chat_open_ai = KimiChatModel( + openai_api_base=model_credential['api_base'], + openai_api_key=model_credential['api_key'], + model_name=model_name, + **optional_params + ) + return kimi_chat_open_ai diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/__init__.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/__init__.py new file mode 100644 index 0000000..90a8d72 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: __init__.py + @date:2024/7/10 17:48 + 
@desc: +""" diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/credential/embedding.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/credential/embedding.py new file mode 100644 index 0000000..a631196 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/credential/embedding.py @@ -0,0 +1,45 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/11 11:06 + @desc: +""" +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding + + +class LocalEmbeddingCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + if not model_type == 'EMBEDDING': + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['cache_folder']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential) + model.embed_query('你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return model + + cache_folder = forms.TextInputField('模型目录', required=True) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/credential/reranker.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/credential/reranker.py new 
file mode 100644 index 0000000..0048fce --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/credential/reranker.py @@ -0,0 +1,47 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: reranker.py + @date:2024/9/3 14:33 + @desc: +""" +from typing import Dict + +from langchain_core.documents import Document + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.local_model_provider.model.reranker import LocalBaseReranker + + +class LocalRerankerCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + if not model_type == 'RERANKER': + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['cache_dir']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model: LocalBaseReranker = provider.get_model(model_type, model_name, model_credential) + model.compress_documents([Document(page_content='你好')], '你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return model + + cache_dir = forms.TextInputField('模型目录', required=True) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py new file mode 100644 index 0000000..2c92bbb --- /dev/null +++ 
b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/local_model_provider.py @@ -0,0 +1,44 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: zhipu_model_provider.py + @date:2024/04/19 13:5 + @desc: +""" +import os +from typing import Dict + +from pydantic import BaseModel + +from common.exception.app_exception import AppApiException +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ + ModelInfoManage +from setting.models_provider.impl.local_model_provider.credential.embedding import LocalEmbeddingCredential +from setting.models_provider.impl.local_model_provider.credential.reranker import LocalRerankerCredential +from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding +from setting.models_provider.impl.local_model_provider.model.reranker import LocalReranker +from smartdoc.conf import PROJECT_DIR + +embedding_text2vec_base_chinese = ModelInfo('shibing624/text2vec-base-chinese', '', ModelTypeConst.EMBEDDING, + LocalEmbeddingCredential(), LocalEmbedding) +bge_reranker_v2_m3 = ModelInfo('BAAI/bge-reranker-v2-m3', '', ModelTypeConst.RERANKER, + LocalRerankerCredential(), LocalReranker) + +model_info_manage = (ModelInfoManage.builder().append_model_info(embedding_text2vec_base_chinese) + .append_default_model_info(embedding_text2vec_base_chinese) + .append_model_info(bge_reranker_v2_m3) + .append_default_model_info(bge_reranker_v2_m3) + .build()) + + +class LocalModelProvider(IModelProvider): + + def get_model_info_manage(self): + return model_info_manage + + def get_model_provide_info(self): + return ModelProvideInfo(provider='model_local_provider', name='本地模型', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'local_model_provider', 'icon', + 'local_icon_svg'))) diff --git 
a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/model/embedding.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/model/embedding.py new file mode 100644 index 0000000..820b93e --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/model/embedding.py @@ -0,0 +1,62 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/11 14:06 + @desc: +""" +from typing import Dict, List + +import requests +from langchain_core.embeddings import Embeddings +from langchain_core.pydantic_v1 import BaseModel +from langchain_huggingface import HuggingFaceEmbeddings + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from smartdoc.const import CONFIG + + +class WebLocalEmbedding(MaxKBBaseModel, BaseModel, Embeddings): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + pass + + model_id: str = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model_id = kwargs.get('model_id', None) + + def embed_query(self, text: str) -> List[float]: + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + res = requests.post(f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/api/model/{self.model_id}/embed_query', + {'text': text}) + result = res.json() + if result.get('code', 500) == 200: + return result.get('data') + raise Exception(result.get('message')) + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + res = requests.post(f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/api/model/{self.model_id}/embed_documents', + {'texts': texts}) + result = res.json() + if result.get('code', 500) == 200: + return result.get('data') + raise Exception(result.get('message')) + + +class LocalEmbedding(MaxKBBaseModel, HuggingFaceEmbeddings): + + @staticmethod + def 
new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + if model_kwargs.get('use_local', True): + return LocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), + model_kwargs={'device': model_credential.get('device')}, + encode_kwargs={'normalize_embeddings': True} + ) + return WebLocalEmbedding(model_name=model_name, cache_folder=model_credential.get('cache_folder'), + model_kwargs={'device': model_credential.get('device')}, + encode_kwargs={'normalize_embeddings': True}, + **model_kwargs) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/model/reranker.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/model/reranker.py new file mode 100644 index 0000000..f5056b2 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/local_model_provider/model/reranker.py @@ -0,0 +1,101 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: reranker.py.py + @date:2024/9/2 16:42 + @desc: +""" +from typing import Sequence, Optional, Dict, Any + +import requests +import torch +from langchain_core.callbacks import Callbacks +from langchain_core.documents import BaseDocumentCompressor, Document +from transformers import AutoModelForSequenceClassification, AutoTokenizer + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from smartdoc.const import CONFIG + + +class LocalReranker(MaxKBBaseModel): + def __init__(self, model_name, top_n=3, cache_dir=None): + super().__init__() + self.model_name = model_name + self.cache_dir = cache_dir + self.top_n = top_n + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + if model_kwargs.get('use_local', True): + return LocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'), + model_kwargs={'device': model_credential.get('device', 'cpu')} + + ) + return 
WebLocalBaseReranker(model_name=model_name, cache_dir=model_credential.get('cache_dir'), + model_kwargs={'device': model_credential.get('device')}, + **model_kwargs) + + +class WebLocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + pass + + model_id: str = None + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.model_id = kwargs.get('model_id', None) + + def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ + Sequence[Document]: + if documents is None or len(documents) == 0: + return [] + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + res = requests.post( + f'{CONFIG.get("LOCAL_MODEL_PROTOCOL")}://{bind}/api/model/{self.model_id}/compress_documents', + json={'documents': [{'page_content': document.page_content, 'metadata': document.metadata} for document in + documents], 'query': query}, headers={'Content-Type': 'application/json'}) + result = res.json() + if result.get('code', 500) == 200: + return [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document + in result.get('data')] + raise Exception(result.get('message')) + + +class LocalBaseReranker(MaxKBBaseModel, BaseDocumentCompressor): + client: Any = None + tokenizer: Any = None + model: Optional[str] = None + cache_dir: Optional[str] = None + model_kwargs = {} + + def __init__(self, model_name, cache_dir=None, **model_kwargs): + super().__init__() + self.model = model_name + self.cache_dir = cache_dir + self.model_kwargs = model_kwargs + self.client = AutoModelForSequenceClassification.from_pretrained(self.model, cache_dir=self.cache_dir) + self.tokenizer = AutoTokenizer.from_pretrained(self.model, cache_dir=self.cache_dir) + self.client = self.client.to(self.model_kwargs.get('device', 'cpu')) + self.client.eval() + + @staticmethod + def 
new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return LocalBaseReranker(model_name, cache_dir=model_credential.get('cache_dir'), **model_kwargs) + + def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \ + Sequence[Document]: + if documents is None or len(documents) == 0: + return [] + with torch.no_grad(): + inputs = self.tokenizer([[query, document.page_content] for document in documents], padding=True, + truncation=True, return_tensors='pt', max_length=512) + scores = [torch.sigmoid(s).float().item() for s in + self.client(**inputs, return_dict=True).logits.view(-1, ).float()] + result = [Document(page_content=documents[index].page_content, metadata={'relevance_score': scores[index]}) + for index + in range(len(documents))] + result.sort(key=lambda row: row.metadata.get('relevance_score'), reverse=True) + return result diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/ollama_model_provider/__init__.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/ollama_model_provider/__init__.py new file mode 100644 index 0000000..6da6cdb --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/ollama_model_provider/__init__.py @@ -0,0 +1,8 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2024/3/5 17:20 + @desc: +""" diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/ollama_model_provider/credential/embedding.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/ollama_model_provider/credential/embedding.py new file mode 100644 index 0000000..e0eeabe --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/ollama_model_provider/credential/embedding.py @@ -0,0 +1,45 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 15:10 + @desc: +""" +from typing import Dict + +from common import forms +from common.exception.app_exception import 
AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding + + +class OllamaEmbeddingModelCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + try: + model_list = provider.get_base_model_list(model_credential.get('api_base')) + except Exception as e: + raise AppApiException(ValidCode.valid_error.value, "API 域名无效") + exist = [model for model in (model_list.get('models') if model_list.get('models') is not None else []) if + model.get('model') == model_name or model.get('model').replace(":latest", "") == model_name] + if len(exist) == 0: + raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型") + model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential) + model.embed_query('你好') + return True + + def encryption_dict(self, model_info: Dict[str, object]): + return model_info + + def build_model(self, model_info: Dict[str, object]): + for key in ['model']: + if key not in model_info: + raise AppApiException(500, f'{key} 字段为必填字段') + return self + + api_base = forms.TextInputField('API 域名', required=True) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py new file mode 100644 index 0000000..33f6d8c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/ollama_model_provider/credential/llm.py @@ -0,0 +1,64 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + 
class OllamaLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for chat models served by an Ollama endpoint."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate that the model type is supported and the model is installed on the endpoint.

        :param model_type: model type value; must be one of ``provider.get_model_type_list()``.
        :param model_name: Ollama model tag (an implicit ``:latest`` suffix is tolerated).
        :param model_credential: credential dict; ``api_base`` is read from it.
        :param provider: provider exposing ``get_model_type_list`` and ``get_base_model_list``.
        :param raise_exception: accepted for interface compatibility; this implementation
            always raises on failure and never returns ``False``.
        :return: ``True`` when every check passed.
        :raises AppApiException: unsupported type, unreachable endpoint or model not installed.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        try:
            model_list = provider.get_base_model_list(model_credential.get('api_base'))
        except Exception as e:
            # Keep the original failure as __cause__ so the real reason stays in the traceback.
            raise AppApiException(ValidCode.valid_error.value, "API 域名无效") from e
        # An entry without a 'model' key used to crash on None.replace(); fall back to 'name'.
        # NOTE(review): assumes 'name' holds the same tag as 'model' - confirm against the Ollama API.
        installed = [entry.get('model') or entry.get('name') or ''
                     for entry in (model_list.get('models') or [])]
        if not any(tag == model_name or tag.replace(":latest", "") == model_name for tag in installed):
            # NOTE(review): the enum member (not ``.value``) is passed here, unlike valid_error
            # above - presumably callers compare ``e.code`` with the member; confirm before changing.
            raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型")
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """Return a copy of ``model_info`` whose ``api_key`` went through ``encryption``."""
        return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))}

    def build_model(self, model_info: Dict[str, object]):
        """Require ``api_key`` and ``model`` in ``model_info``, store the key and return ``self``.

        :raises AppApiException: (code 500) when either key is missing.
        """
        for key in ['api_key', 'model']:
            if key not in model_info:
                raise AppApiException(500, f'{key} 字段为必填字段')
        self.api_key = model_info.get('api_key')
        return self

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def get_model_params_setting_form(self, model_name):
        """Return the tunable-parameter form (temperature / max_tokens) for any model name."""
        return OllamaLLMModelParams()
def get_base_url(url: str):
    """Normalise an API base URL to ``scheme://netloc/path`` with no trailing slash.

    Params, query string and fragment are dropped. Every trailing ``/`` is removed
    (previously only one was), so callers appending ``/v1`` or ``/api/...`` never
    produce a double slash.

    :param url: user-supplied base URL.
    :return: the normalised URL string.
    """
    parse = urlparse(url)
    result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='',
                             query='',
                             fragment='').geturl()
    return result_url.rstrip('/')
ollama_llm_model_credential = OllamaLLMModelCredential()


def _ollama_llm_info(name, desc=''):
    """Build the ``ModelInfo`` of an Ollama chat model that uses the shared LLM credential."""
    return ModelInfo(name, desc, ModelTypeConst.LLM, ollama_llm_model_credential, OllamaChatModel)


# (model tag, size label shown in the description)
_LLAMA2_VARIANTS = [('llama2', '7B'), ('llama2:13b', '13B'), ('llama2:70b', '70B')]
# (size suffix of the ``qwen:`` tag, parameter count in units of 亿)
_QWEN15_VARIANTS = [('0.5b', '5'), ('1.8b', '18'), ('4b', '40'), ('7b', '70'),
                    ('14b', '140'), ('32b', '320'), ('72b', '720'), ('110b', '1100')]
# Tags listed without a description.
_QWEN2_TAGS = ['qwen2-72b-instruct', 'qwen2-57b-a14b-instruct', 'qwen2-7b-instruct',
               'qwen2.5-72b-instruct', 'qwen2.5-32b-instruct', 'qwen2.5-14b-instruct',
               'qwen2.5-7b-instruct', 'qwen2.5-1.5b-instruct', 'qwen2.5-0.5b-instruct',
               'qwen2.5-3b-instruct']

# Catalogue of selectable Ollama chat models, in display order.
# Generating the qwen 1.5 entries from one template also removes the stray "1" that the
# hand-written 'qwen:7b' description contained ("多语1言").
model_info_list = [
    *[_ollama_llm_info(
        tag,
        f'Llama 2 是一组经过预训练和微调的生成文本模型,其规模从 70 亿到 700 亿个不等。这是 {size} 预训练模型的存储库。其他模型的链接可以在底部的索引中找到。')
        for tag, size in _LLAMA2_VARIANTS],
    _ollama_llm_info(
        'llama2-chinese:13b',
        '由于Llama2本身的中文对齐较弱,我们采用中文指令集,对meta-llama/Llama-2-13b-chat-hf进行LoRA微调,使其具备较强的中文对话能力。'),
    _ollama_llm_info('llama3:8b', 'Meta Llama 3:迄今为止最有能力的公开产品LLM。80亿参数。'),
    _ollama_llm_info('llama3:70b', 'Meta Llama 3:迄今为止最有能力的公开产品LLM。700亿参数。'),
    *[_ollama_llm_info(
        f'qwen:{size}',
        f'qwen 1.5 {size} 相较于以往版本,模型与人类偏好的对齐程度以及多语言处理能力上有显著增强。所有规模的模型都支持32768个tokens的上下文长度。{count}亿参数。')
        for size, count in _QWEN15_VARIANTS],
    *[_ollama_llm_info(tag) for tag in _QWEN2_TAGS],
    _ollama_llm_info('phi3', 'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。'),
]
ollama_embedding_model_credential = OllamaEmbeddingModelCredential()
# Catalogue of selectable Ollama embedding models.
embedding_model_info = [
    ModelInfo(
        'nomic-embed-text',
        '一个具有大令牌上下文窗口的高性能开放嵌入模型。',
        ModelTypeConst.EMBEDDING, ollama_embedding_model_credential, OllamaEmbedding),
]
def get_base_url(url: str):
    """Normalise an API base URL to ``scheme://netloc/path`` with no trailing slash.

    Params, query string and fragment are dropped; every trailing ``/`` is removed so
    that appending ``/api/...`` never yields a double slash.
    """
    parse = urlparse(url)
    result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='',
                             query='',
                             fragment='').geturl()
    return result_url.rstrip('/')


def convert_to_down_model_chunk(row_str: str, chunk_index: int):
    """Turn one JSON progress row of a model pull into a ``DownModelChunk``.

    :param row_str: a single JSON object serialised as text.
    :param chunk_index: index stored on the resulting chunk.
    :return: ``DownModelChunk`` with status, digest (the status or error text),
        progress in percent and the raw row as ``details``.
    """
    row = json.loads(row_str)
    status = DownModelChunkStatus.unknown
    digest = ""
    progress = 100
    if 'status' in row:
        digest = row.get('status')
        if digest == 'success':
            status = DownModelChunkStatus.success
        # Guard the substring test: a non-string status used to raise here.
        if isinstance(digest, str) and "pulling" in digest:
            progress = 0
            status = DownModelChunkStatus.pulling
            total, completed = row.get('total'), row.get('completed')
            # total == 0 (or missing) previously caused ZeroDivisionError / TypeError.
            if total and completed is not None:
                progress = completed / total * 100
    elif 'error' in row:
        status = DownModelChunkStatus.error
        digest = row.get('error')
    return DownModelChunk(status=status, digest=digest, progress=progress, details=row_str, index=chunk_index)


def _is_complete_json(text: str) -> bool:
    """Return True when ``text`` parses as a JSON document."""
    try:
        json.loads(text)
    except ValueError:
        return False
    return True


def convert(response_stream) -> "Iterator[DownModelChunk]":
    """Decode a streamed pull response into ``DownModelChunk`` objects.

    ``response_stream`` is iterated for byte chunks whose boundaries need not match row
    boundaries. Bytes are decoded incrementally, so a multi-byte character split across
    two chunks no longer raises. Complete newline-terminated rows are emitted as soon as
    they are available; an unterminated tail is emitted once it parses as JSON, which
    keeps streams without newline separators working. Each emitted chunk carries the
    1-based number of the network chunk that completed it.
    """
    import codecs  # local import: keeps this block self-contained

    decoder = codecs.getincrementaldecoder('utf-8')(errors='replace')
    pending = ""
    index = 0
    for raw in response_stream:
        index += 1
        pending += decoder.decode(raw)
        # Everything before the last newline is complete; the tail may be partial.
        *rows, pending = pending.split("\n")
        if pending.endswith('}') and _is_complete_json(pending):
            rows.append(pending)
            pending = ""
        for row in rows:
            if row.strip():
                yield convert_to_down_model_chunk(row, index)

    # Flush whatever is left once the stream ends.
    pending += decoder.decode(b"", final=True)
    for row in pending.split("\n"):
        if row.strip():
            yield convert_to_down_model_chunk(row, index)
class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for OpenAI-compatible embedding models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=True):
        """Validate the credential by building the model and embedding a probe text.

        Returns ``True`` on success. On a missing field or a failing probe it raises
        ``AppApiException`` when ``raise_exception`` is true and returns ``False``
        otherwise. An unsupported model type, and any ``AppApiException`` raised by
        the probe itself, always propagate.
        """
        matching_types = [mt for mt in provider.get_model_type_list() if mt.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        missing_key = next((key for key in ('api_base', 'api_key') if key not in model_credential), None)
        if missing_key is not None:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{missing_key} 字段为必填字段')

        try:
            embedder = provider.get_model(model_type, model_name, model_credential)
            embedder.embed_query('你好')
        except AppApiException:
            # Application errors already carry a user-facing message.
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through ``encryption``."""
        masked_key = super().encryption(model.get('api_key', ''))
        return {**model, 'api_key': masked_key}

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for OpenAI-compatible chat models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by building the model and sending one probe message.

        Returns ``True`` on success. On a missing field or a failing probe it raises
        ``AppApiException`` when ``raise_exception`` is true and returns ``False``
        otherwise. An unsupported model type, and any ``AppApiException`` raised by
        the probe itself, always propagate.
        """
        matching_types = [mt for mt in provider.get_model_type_list() if mt.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        missing_key = next((key for key in ('api_base', 'api_key') if key not in model_credential), None)
        if missing_key is not None:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{missing_key} 字段为必填字段')

        try:
            chat_model = provider.get_model(model_type, model_name, model_credential)
            chat_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            # Application errors already carry a user-facing message.
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through ``encryption``."""
        masked_key = super().encryption(model.get('api_key', ''))
        return {**model, 'api_key': masked_key}

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def get_model_params_setting_form(self, model_name):
        """Return the tunable-parameter form; the same form is used for every model name."""
        params_form = OpenAILLMModelParams()
        return params_form
class OpenAISTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for OpenAI-compatible speech-to-text models."""

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by building the model and calling its ``check_auth``.

        Returns ``True`` on success. On a missing field or a failing check it raises
        ``AppApiException`` when ``raise_exception`` is true and returns ``False``
        otherwise. An unsupported model type, and any ``AppApiException`` raised by
        the check itself, always propagate.
        """
        matching_types = [mt for mt in provider.get_model_type_list() if mt.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        missing_key = next((key for key in ('api_base', 'api_key') if key not in model_credential), None)
        if missing_key is not None:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{missing_key} 字段为必填字段')

        try:
            stt_model = provider.get_model(model_type, model_name, model_credential)
            stt_model.check_auth()
        except AppApiException:
            # Application errors already carry a user-facing message.
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through ``encryption``."""
        masked_key = super().encryption(model.get('api_key', ''))
        return {**model, 'api_key': masked_key}

    def get_model_params_setting_form(self, model_name):
        """No tunable parameters are offered for speech-to-text; returns ``None``."""
        return None
class OpenAITTSModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for OpenAI-compatible text-to-speech models."""

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credential by building the model and calling its ``check_auth``.

        Returns ``True`` on success. On a missing field or a failing check it raises
        ``AppApiException`` when ``raise_exception`` is true and returns ``False``
        otherwise. An unsupported model type, and any ``AppApiException`` raised by
        the check itself, always propagate.
        """
        matching_types = [mt for mt in provider.get_model_type_list() if mt.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        missing_key = next((key for key in ('api_base', 'api_key') if key not in model_credential), None)
        if missing_key is not None:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{missing_key} 字段为必填字段')

        try:
            tts_model = provider.get_model(model_type, model_name, model_credential)
            tts_model.check_auth()
        except AppApiException:
            # Application errors already carry a user-facing message.
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` went through ``encryption``."""
        masked_key = super().encryption(model.get('api_key', ''))
        return {**model, 'api_key': masked_key}

    def get_model_params_setting_form(self, model_name):
        """Return the voice-selection form; the same form is used for every model name."""
        params_form = OpenAITTSModelGeneralParams()
        return params_form
class OpenAIChatModel(MaxKBBaseModel, ChatOpenAI):
    """``ChatOpenAI`` wrapper whose token counting falls back to the project tokenizer."""

    @staticmethod
    def is_cache_model():
        """Report ``False`` for this model type."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create a chat model from the stored credential and the filtered optional params.

        ``api_base`` and ``api_key`` are read from ``model_credential``; ``model_type``
        is not used here.
        """
        extra_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return OpenAIChatModel(
            model=model_name,
            openai_api_base=model_credential.get('api_base'),
            openai_api_key=model_credential.get('api_key'),
            custom_get_token_ids=custom_get_token_ids,
            **extra_params
        )

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Count tokens of ``messages``; use the project tokenizer if the base counter fails."""
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception:
            fallback = TokenizerManage.get_tokenizer()
            total = 0
            for message in messages:
                total += len(fallback.encode(get_buffer_string([message])))
            return total

    def get_num_tokens(self, text: str) -> int:
        """Count tokens of ``text``; use the project tokenizer if the base counter fails."""
        try:
            return super().get_num_tokens(text)
        except Exception:
            fallback = TokenizerManage.get_tokenizer()
            return len(fallback.encode(text))
class OpenAITextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Text-to-speech model backed by an OpenAI-compatible ``audio.speech`` endpoint."""
    api_base: str
    api_key: str
    model: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')
        self.model = kwargs.get('model')
        # Default to an empty dict: ``**self.params`` below would fail on None.
        self.params = kwargs.get('params') or {}

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Create an instance from the stored credential.

        Every entry of ``model_kwargs`` except ``model_id``, ``use_local`` and
        ``streaming`` is forwarded as a speech parameter; ``voice`` defaults to
        ``'alloy'``. ``model_type`` is not used here.
        """
        optional_params = {'params': {'voice': 'alloy'}}
        for key, value in model_kwargs.items():
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value
        return OpenAITextToSpeech(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def check_auth(self):
        """Verify the credential by listing models; any failure propagates as an exception."""
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        client.models.with_raw_response.list()

    def text_to_speech(self, text):
        """Synthesise ``text`` with the configured model and params; return the audio bytes."""
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        with client.audio.speech.with_streaming_response.create(
                model=self.model,
                input=text,
                **self.params
        ) as response:
            return response.read()

    @staticmethod
    def is_cache_model():
        """Report ``False`` for this model type.

        Declared static like the chat models in this code base, so it can be called
        on the class as well as on an instance.
        """
        return False
OpenAILLMModelCredential +from setting.models_provider.impl.openai_model_provider.credential.stt import OpenAISTTModelCredential +from setting.models_provider.impl.openai_model_provider.credential.tts import OpenAITTSModelCredential +from setting.models_provider.impl.openai_model_provider.model.embedding import OpenAIEmbeddingModel +from setting.models_provider.impl.openai_model_provider.model.llm import OpenAIChatModel +from setting.models_provider.impl.openai_model_provider.model.stt import OpenAISpeechToText +from setting.models_provider.impl.openai_model_provider.model.tts import OpenAITextToSpeech +from smartdoc.conf import PROJECT_DIR + +openai_llm_model_credential = OpenAILLMModelCredential() +openai_stt_model_credential = OpenAISTTModelCredential() +openai_tts_model_credential = OpenAITTSModelCredential() +model_info_list = [ + ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, + openai_llm_model_credential, OpenAIChatModel + ), + ModelInfo('gpt-4', '最新的gpt-4,随OpenAI调整而更新', ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-4o', '最新的GPT-4o,比gpt-4-turbo更便宜、更快,随OpenAI调整而更新', + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-4o-mini', '最新的gpt-4o-mini,比gpt-4o更便宜、更快,随OpenAI调整而更新', + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-4-turbo', '最新的gpt-4-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, + openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-4-turbo-preview', '最新的gpt-4-turbo-preview,随OpenAI调整而更新', + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-3.5-turbo-0125', + '2024年1月25日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM, + openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-3.5-turbo-1106', + '2023年11月6日的gpt-3.5-turbo快照,支持上下文长度16,385 tokens', ModelTypeConst.LLM, + openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-3.5-turbo-0613', 
+ '[Legacy] 2023年6月13日的gpt-3.5-turbo快照,将于2024年6月13日弃用', + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-4o-2024-05-13', + '2024年5月13日的gpt-4o快照,支持上下文长度128,000 tokens', + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-4-turbo-2024-04-09', + '2024年4月9日的gpt-4-turbo快照,支持上下文长度128,000 tokens', + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-4-0125-preview', '2024年1月25日的gpt-4-turbo快照,支持上下文长度128,000 tokens', + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('gpt-4-1106-preview', '2023年11月6日的gpt-4-turbo快照,支持上下文长度128,000 tokens', + ModelTypeConst.LLM, openai_llm_model_credential, + OpenAIChatModel), + ModelInfo('whisper-1', '', + ModelTypeConst.STT, openai_stt_model_credential, + OpenAISpeechToText), + ModelInfo('tts-1', '', + ModelTypeConst.TTS, openai_tts_model_credential, + OpenAITextToSpeech) +] +open_ai_embedding_credential = OpenAIEmbeddingCredential() +model_info_embedding_list = [ + ModelInfo('text-embedding-ada-002', '', + ModelTypeConst.EMBEDDING, open_ai_embedding_credential, + OpenAIEmbeddingModel), + ModelInfo('text-embedding-3-small', '', + ModelTypeConst.EMBEDDING, open_ai_embedding_credential, + OpenAIEmbeddingModel), + ModelInfo('text-embedding-3-large', '', + ModelTypeConst.EMBEDDING, open_ai_embedding_credential, + OpenAIEmbeddingModel) +] + +model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( + ModelInfo('gpt-3.5-turbo', '最新的gpt-3.5-turbo,随OpenAI调整而更新', ModelTypeConst.LLM, + openai_llm_model_credential, OpenAIChatModel + )).append_model_info_list(model_info_embedding_list).append_default_model_info( + model_info_embedding_list[0]).build() + + +class OpenAIModelProvider(IModelProvider): + + def get_model_info_manage(self): + return model_info_manage + + def get_model_provide_info(self): + return ModelProvideInfo(provider='model_openai_provider', 
class OpenAILLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator used by the Tongyi Qianwen (Qwen) provider.

    The form exposes a single required ``api_key`` password field.
    """

    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check the credential by building the model and sending one test message.

        :param model_type: model type value that must appear in the provider's type list
        :param model_name: name of the model to instantiate
        :param model_credential: credential mapping; must contain ``api_key``
        :param provider: provider object offering ``get_model_type_list`` and ``get_model``
        :param raise_exception: raise ``AppApiException`` instead of returning ``False``
        :return: ``True`` when the test call succeeds, otherwise ``False``
        :raises AppApiException: for an unsupported model type (always), or for a
            missing key / failed test call when ``raise_exception`` is true
        """
        supported_types = provider.get_model_type_list()
        matching_types = [mt for mt in supported_types if mt.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        # Report the first required key that is absent from the credential.
        missing_keys = [key for key in ('api_key',) if key not in model_credential]
        if missing_keys:
            if not raise_exception:
                return False
            key = missing_keys[0]
            raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')

        try:
            chat_model = provider.get_model(model_type, model_name, model_credential)
            chat_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            # Application-level errors are already user-facing; pass them through.
            raise
        except Exception as e:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``api_key`` has been passed through ``encryption``."""
        encrypted = dict(model)
        encrypted['api_key'] = super().encryption(model.get('api_key', ''))
        return encrypted

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form (``QwenModelParams``) for any model name."""
        params_form = QwenModelParams()
        return params_form
class QwenChatModel(MaxKBBaseModel, ChatTongyi):
    """Tongyi (Qwen) chat model that remembers the token usage of its last call.

    Token counts are not computed locally: ``_stream`` and ``invoke`` store the
    usage block returned by the service in ``usage_metadata``, and the
    ``get_num_tokens*`` methods simply read from it.
    """

    @staticmethod
    def is_cache_model():
        """Return ``False``."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a ``QwenChatModel`` from a credential mapping.

        ``model_kwargs`` is filtered through ``MaxKBBaseModel.filter_optional_params``
        and the result is handed over as ``model_kwargs``; the API key is read from
        ``model_credential['api_key']``. ``model_type`` is not used here.
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        chat_tong_yi = QwenChatModel(
            model_name=model_name,
            dashscope_api_key=model_credential.get('api_key'),
            model_kwargs=optional_params,
        )
        return chat_tong_yi

    # Usage information captured from the most recent response (see _stream / invoke).
    usage_metadata: dict = {}

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        """Return the usage mapping recorded by the last ``_stream``/``invoke`` call."""
        return self.usage_metadata

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the recorded ``input_tokens`` (0 if absent); ``messages`` is not inspected."""
        return self.usage_metadata.get('input_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        """Return the recorded ``output_tokens`` (0 if absent); ``text`` is not inspected."""
        return self.usage_metadata.get('output_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream generation chunks and capture the usage block of the final response.

        :param messages: conversation to send
        :param stop: optional stop sequences, forwarded to ``_invocation_params``
        :param run_manager: optional callback manager notified of every yielded chunk
        :return: iterator of ``ChatGenerationChunk``
        """
        params: Dict[str, Any] = self._invocation_params(
            messages=messages, stop=stop, stream=True, **kwargs
        )

        for stream_resp, is_last_chunk in generate_with_last_element_mark(
                self.stream_completion_with_retry(**params)
        ):
            choice = stream_resp["output"]["choices"][0]
            message = choice["message"]
            # Record usage on an empty "stop" response or on any "length" response.
            if (
                    choice["finish_reason"] == "stop"
                    and message["content"] == ""
            ) or (choice["finish_reason"] == "length"):
                token_usage = stream_resp["usage"]
                self.usage_metadata = token_usage
            # Skip unfinished responses that carry neither text nor tool calls.
            if (
                    choice["finish_reason"] == "null"
                    and message["content"] == ""
                    and "tool_calls" not in message
            ):
                continue

            chunk = ChatGenerationChunk(
                **self._chat_generation_from_qwen_resp(
                    stream_resp, is_chunk=True, is_last_chunk=is_last_chunk
                )
            )
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk

    def invoke(
            self,
            input: LanguageModelInput,
            config: Optional[RunnableConfig] = None,
            *,
            stop: Optional[List[str]] = None,
            **kwargs: Any,
    ) -> BaseMessage:
        """Run one generation and return its message, recording token usage.

        :param input: prompt input, converted with ``_convert_input``
        :param config: optional runnable config; ``run_id`` is popped from it
        :param stop: optional stop sequences
        :return: the message of the first generation
        :raises KeyError: if the message's ``response_metadata`` has no ``token_usage``
        """
        config = ensure_config(config)
        # Despite its name, chat_result holds the generation's .message, not the generation.
        chat_result = cast(
            ChatGeneration,
            self.generate_prompt(
                [self._convert_input(input)],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                run_id=config.pop("run_id", None),
                **kwargs,
            ).generations[0][0],
        ).message
        self.usage_metadata = chat_result.response_metadata['token_usage']
        return chat_result
class QwenModelProvider(IModelProvider):
    """Model provider entry for Tongyi Qianwen (Qwen)."""

    def get_model_info_manage(self):
        """Return the module-level ``model_info_manage`` registry."""
        return model_info_manage

    def get_model_provide_info(self):
        """Return provider metadata: identifier, display name and icon file content."""
        icon_location = os.path.join(
            PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'qwen_model_provider', 'icon',
            'qwen_icon_svg')
        icon_content = get_file_content(icon_location)
        return ModelProvideInfo(provider='model_qwen_provider', name='通义千问', icon=icon_content)
class TencentLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Tencent Hunyuan chat models.

    The form collects an app id, a secret id and a secret key; all three are
    required (see ``REQUIRED_FIELDS``).
    """
    REQUIRED_FIELDS = ['hunyuan_app_id', 'hunyuan_secret_id', 'hunyuan_secret_key']

    @classmethod
    def _validate_model_type(cls, model_type, provider, raise_exception=False):
        """Return whether ``model_type`` is one of the provider's model type values.

        :raises AppApiException: if unsupported and ``raise_exception`` is true
        """
        if not any(mt['value'] == model_type for mt in provider.get_model_type_list()):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False
        return True

    @classmethod
    def _validate_credential_fields(cls, model_credential, raise_exception=False):
        """Return whether every key in ``REQUIRED_FIELDS`` is present in ``model_credential``.

        :raises AppApiException: naming all missing keys, if ``raise_exception`` is true
        """
        missing_keys = [key for key in cls.REQUIRED_FIELDS if key not in model_credential]
        if missing_keys:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{", ".join(missing_keys)} 字段为必填字段')
            return False
        return True

    def is_valid(self, model_type, model_name, model_credential, provider, raise_exception=False):
        """Validate the credential by building the model and sending one test message.

        :param model_type: model type value to check against the provider
        :param model_name: name of the model to instantiate
        :param model_credential: credential mapping containing the required fields
        :param provider: provider object offering ``get_model_type_list`` and ``get_model``
        :param raise_exception: raise ``AppApiException`` instead of returning ``False``
        :return: ``True`` when the test call succeeds, otherwise ``False``
        :raises AppApiException: on validation failure when ``raise_exception`` is
            true, or whenever the provider/model itself raises an ``AppApiException``
        """
        if not (self._validate_model_type(model_type, provider, raise_exception) and
                self._validate_credential_fields(model_credential, raise_exception)):
            return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            # Already an application-level error: propagate it unchanged instead of
            # swallowing it or wrapping it in a second AppApiException, matching the
            # other provider credential classes.
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') from e
            return False
        return True

    def encryption_dict(self, model):
        """Return a copy of ``model`` with ``hunyuan_secret_key`` passed through ``encryption``."""
        return {**model, 'hunyuan_secret_key': super().encryption(model.get('hunyuan_secret_key', ''))}

    hunyuan_app_id = forms.TextInputField('APP ID', required=True)
    hunyuan_secret_id = forms.PasswordInputField('SecretId', required=True)
    hunyuan_secret_key = forms.PasswordInputField('SecretKey', required=True)

    def get_model_params_setting_form(self, model_name):
        """Return the parameter form (``TencentLLMModelParams``) for any model name."""
        return TencentLLMModelParams()
class TencentEmbeddingModel(MaxKBBaseModel, Embeddings):
    """Embeddings implementation backed by the Tencent Hunyuan ``GetEmbedding`` call."""

    def __init__(self, secret_id: str, secret_key: str, model_name: str):
        """Store the credentials and create the Hunyuan client.

        The client is created with an empty string as its second positional argument.
        """
        self.secret_id, self.secret_key, self.model_name = secret_id, secret_key, model_name
        self.client = HunyuanClient(credential.Credential(secret_id, secret_key), "")

    @staticmethod
    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs):
        """Build an instance from the ``SecretId`` / ``SecretKey`` entries of the credential."""
        secret_id = model_credential.get('SecretId')
        secret_key = model_credential.get('SecretKey')
        return TencentEmbeddingModel(secret_id=secret_id, secret_key=secret_key, model_name=model_name)

    def embed_query(self, text: str) -> List[float]:
        """Embed one text and return the ``Embedding`` of the first item in the response data."""
        embedding_request = GetEmbeddingRequest()
        embedding_request.Input = text
        response = self.client.GetEmbedding(embedding_request)
        return response.Data[0].Embedding

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text in order, issuing one ``embed_query`` call per text."""
        vectors = []
        for text in texts:
            vectors.append(self.embed_query(text))
        return vectors

    def _generate_auth_token(self):
        # Example method to generate an authentication token for the model API
        return f"{self.secret_id}:{self.secret_key}"
def _convert_message_to_dict(message: BaseMessage) -> dict:
    """Convert a message object into a ``{"Role": ..., "Content": ...}`` dict.

    ``ChatMessage`` keeps its own ``role``; human, AI and system messages map to
    ``"user"``, ``"assistant"`` and ``"system"`` respectively.

    :raises TypeError: if ``message`` is none of the handled message types
    """
    if isinstance(message, ChatMessage):
        return {"Role": message.role, "Content": message.content}

    # Checked in this order; the first matching type decides the role.
    fixed_roles = (
        (HumanMessage, "user"),
        (AIMessage, "assistant"),
        (SystemMessage, "system"),
    )
    for message_type, role in fixed_roles:
        if isinstance(message, message_type):
            return {"Role": role, "Content": message.content}

    raise TypeError(f"Got unknown type {message}")
class ChatHunyuan(BaseChatModel):
    """Tencent Hunyuan chat models API by Tencent.

    Requests are sent through the ``tencentcloud`` SDK (imported lazily in
    ``_chat``). The usage block of the last streamed response is kept in
    ``usage_metadata``.

    For more information, see https://cloud.tencent.com/document/product/1729
    """

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map the three credential field names to their ``HUNYUAN_*`` identifiers."""
        return {
            "hunyuan_app_id": "HUNYUAN_APP_ID",
            "hunyuan_secret_id": "HUNYUAN_SECRET_ID",
            "hunyuan_secret_key": "HUNYUAN_SECRET_KEY",
        }

    @property
    def lc_serializable(self) -> bool:
        """Return ``True``."""
        return True

    # NOTE(review): hunyuan_app_id is validated in validate_environment but is not
    # referenced when the SDK credential/request is built in _chat.
    hunyuan_app_id: Optional[int] = None
    """Hunyuan App ID"""
    hunyuan_secret_id: Optional[str] = None
    """Hunyuan Secret ID"""
    hunyuan_secret_key: Optional[SecretStr] = None
    """Hunyuan Secret Key"""
    streaming: bool = False
    """Whether to stream the results or not."""
    # NOTE(review): request_timeout is declared but not referenced anywhere in this class.
    request_timeout: int = 60
    """Timeout for requests to Hunyuan API. Default is 60 seconds."""
    temperature: float = 1.0
    """What sampling temperature to use."""
    top_p: float = 1.0
    """What probability mass to use."""
    model: str = "hunyuan-lite"
    """What Model to use.
    Optional model:
    - hunyuan-lite、
    - hunyuan-standard
    - hunyuan-standard-256K
    - hunyuan-pro
    - hunyuan-code
    - hunyuan-role
    - hunyuan-functioncall
    - hunyuan-vision
    """
    stream_moderation: bool = False
    """Whether to review the results or not when streaming is true."""
    enable_enhancement: bool = True
    """Whether to enhancement the results or not."""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for API call not explicitly specified."""

    class Config:
        """Configuration for this pydantic object."""

        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Any key in ``values`` that is not a declared field is moved into
        ``model_kwargs`` with a warning.

        :raises ValueError: if a key is supplied both directly and inside
            ``model_kwargs``, or if ``model_kwargs`` contains a declared field name
        """
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        # Iterate over a snapshot of the keys because values is mutated in the loop.
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                logger.warning(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        # Declared fields must be passed explicitly, never through model_kwargs.
        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @pre_init
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the three credentials via ``get_from_dict_or_env``.

        Each credential is looked up under its field name with the matching
        ``HUNYUAN_*`` name as the environment key; the secret key is wrapped with
        ``convert_to_secret_str``.
        """
        values["hunyuan_app_id"] = get_from_dict_or_env(
            values,
            "hunyuan_app_id",
            "HUNYUAN_APP_ID",
        )
        values["hunyuan_secret_id"] = get_from_dict_or_env(
            values,
            "hunyuan_secret_id",
            "HUNYUAN_SECRET_ID",
        )
        values["hunyuan_secret_key"] = convert_to_secret_str(
            get_from_dict_or_env(
                values,
                "hunyuan_secret_key",
                "HUNYUAN_SECRET_KEY",
            )
        )
        return values

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Hunyuan API.

        Entries in ``model_kwargs`` override the named parameters on key collision.
        """
        normal_params = {
            "Temperature": self.temperature,
            "TopP": self.top_p,
            "Model": self.model,
            "Stream": self.streaming,
            "StreamModeration": self.stream_moderation,
            "EnableEnhancement": self.enable_enhancement,
        }
        return {**normal_params, **self.model_kwargs}

    def _generate(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> ChatResult:
        """Produce a ``ChatResult`` for ``messages``.

        When ``streaming`` is set, the chunks from ``_stream`` are aggregated with
        ``generate_from_stream``; otherwise the SDK response is serialised to JSON
        and parsed by ``_create_chat_result``.
        """
        if self.streaming:
            stream_iter = self._stream(
                messages=messages, stop=stop, run_manager=run_manager, **kwargs
            )
            return generate_from_stream(stream_iter)

        res = self._chat(messages, **kwargs)
        return _create_chat_result(json.loads(res.to_json_string()))

    # Usage block of the last streamed response whose FinishReason was "stop".
    usage_metadata: dict = {}

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Yield one ``ChatGenerationChunk`` per choice of every streamed event.

        ``stop`` is accepted but not forwarded to ``_chat``.

        :raises ValueError: if a decoded event contains an ``"error"`` key
        """
        res = self._chat(messages, **kwargs)

        default_chunk_class = AIMessageChunk
        for chunk in res:
            # Each event carries its JSON payload under "data"; empty payloads are skipped.
            chunk = chunk.get("data", "")
            if len(chunk) == 0:
                continue
            response = json.loads(chunk)
            if "error" in response:
                raise ValueError(f"Error from Hunyuan api response: {response}")

            for choice in response["Choices"]:
                # `chunk` is rebound here from the raw payload to the message chunk.
                chunk = _convert_delta_to_message_chunk(
                    choice["Delta"], default_chunk_class
                )
                default_chunk_class = chunk.__class__
                # Record the usage block when this choice reports FinishReason "stop".
                if choice.get("FinishReason") == "stop":
                    self.usage_metadata = response.get("Usage", {})
                cg_chunk = ChatGenerationChunk(message=chunk)
                if run_manager:
                    run_manager.on_llm_new_token(chunk.content, chunk=cg_chunk)
                yield cg_chunk

    def _chat(self, messages: List[BaseMessage], **kwargs: Any) -> Any:
        """Send ``messages`` through the SDK's ``ChatCompletions`` call and return its response.

        Request parameters are ``_default_params`` overridden by ``kwargs``.

        :raises ValueError: if no secret key is set
        :raises ImportError: if the ``tencentcloud`` package cannot be imported
        """
        if self.hunyuan_secret_key is None:
            raise ValueError("Hunyuan secret key is not set.")

        try:
            from tencentcloud.common import credential
            from tencentcloud.hunyuan.v20230901 import hunyuan_client, models
        except ImportError:
            raise ImportError(
                "Could not import tencentcloud python package. "
                "Please install it with `pip install tencentcloud-sdk-python`."
            )

        parameters = {**self._default_params, **kwargs}
        cred = credential.Credential(
            self.hunyuan_secret_id, str(self.hunyuan_secret_key.get_secret_value())
        )
        # The client's second positional argument is passed as an empty string.
        client = hunyuan_client.HunyuanClient(cred, "")
        req = models.ChatCompletionsRequest()
        params = {
            "Messages": [_convert_message_to_dict(m) for m in messages],
            **parameters,
        }
        req.from_json_string(json.dumps(params))
        resp = client.ChatCompletions(req)
        return resp

    @property
    def _llm_type(self) -> str:
        """Return the identifier ``"hunyuan-chat"``."""
        return "hunyuan-chat"
class TencentModel(MaxKBBaseModel, ChatHunyuan):
    """MaxKB wrapper around ``ChatHunyuan`` that reports token usage from ``usage_metadata``."""

    @staticmethod
    def is_cache_model():
        """Return ``False``."""
        return False

    def __init__(self, model_name: str, credentials: Dict[str, str], streaming: bool = False, **kwargs):
        """Create the model from a credential mapping.

        :param model_name: passed to the base class as ``model``
        :param credentials: must hold non-empty ``hunyuan_app_id``,
            ``hunyuan_secret_id`` and ``hunyuan_secret_key``
        :param streaming: passed through to the base class
        :param kwargs: filtered with ``filter_optional_params``; only ``temperature``
            (default 1.0) is forwarded
        :raises ValueError: if any of the three credentials is missing or empty
        """
        app_id = credentials.get('hunyuan_app_id')
        secret_id = credentials.get('hunyuan_secret_id')
        secret_key = credentials.get('hunyuan_secret_key')

        tuning = MaxKBBaseModel.filter_optional_params(kwargs)

        if not (app_id and secret_id and secret_key):
            raise ValueError(
                "All of 'hunyuan_app_id', 'hunyuan_secret_id', and 'hunyuan_secret_key' must be provided in credentials.")

        init_arguments = {
            'model': model_name,
            'hunyuan_app_id': app_id,
            'hunyuan_secret_id': secret_id,
            'hunyuan_secret_key': secret_key,
            'streaming': streaming,
            'temperature': tuning.get('temperature', 1.0),
        }
        super().__init__(**init_arguments)

    @staticmethod
    def new_instance(model_type: str, model_name: str, model_credential: Dict[str, object],
                     **model_kwargs) -> 'TencentModel':
        """Build a ``TencentModel``; ``streaming`` is taken out of ``model_kwargs`` (default ``False``)."""
        use_streaming = model_kwargs.pop('streaming', False)
        return TencentModel(model_name=model_name, credentials=model_credential, streaming=use_streaming,
                            **model_kwargs)

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        """Return the stored ``usage_metadata`` mapping."""
        return self.usage_metadata

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the stored ``PromptTokens`` (0 if absent); ``messages`` is not inspected."""
        usage = self.usage_metadata
        return usage.get('PromptTokens', 0)

    def get_num_tokens(self, text: str) -> int:
        """Return the stored ``CompletionTokens`` (0 if absent); ``text`` is not inspected."""
        usage = self.usage_metadata
        return usage.get('CompletionTokens', 0)
def _initialize_model_info():
    """Build the ``ModelInfoManage`` registry for the Tencent Hunyuan provider.

    Registers the chat models listed below plus the ``hunyuan-embedding`` model,
    and sets the first chat model (``hunyuan-pro``) as the default.

    :return: the built ``ModelInfoManage`` instance
    """
    # (model name, description) pairs for the chat models. The model name is sent
    # to the API as-is, so it must not carry stray whitespace
    # ('hunyuan-functioncall' previously had a trailing space).
    llm_models = (
        ('hunyuan-pro',
         '当前混元模型中效果最优版本,万亿级参数规模 MOE-32K 长文模型。在各种 benchmark 上达到绝对领先的水平,复杂指令和推理,具备复杂数学能力,支持 functioncall,在多语言翻译、金融法律医疗等领域应用重点优化'),
        ('hunyuan-standard',
         '采用更优的路由策略,同时缓解了负载均衡和专家趋同的问题。长文方面,大海捞针指标达到99.9%'),
        ('hunyuan-lite',
         '升级为 MOE 结构,上下文窗口为 256k ,在 NLP,代码,数学,行业等多项评测集上领先众多开源模型'),
        ('hunyuan-role',
         '混元最新版角色扮演模型,混元官方精调训练推出的角色扮演模型,基于混元模型结合角色扮演场景数据集进行增训,在角色扮演场景具有更好的基础效果'),
        ('hunyuan-functioncall',
         '混元最新 MOE 架构 FunctionCall 模型,经过高质量的 FunctionCall 数据训练,上下文窗口达 32K,在多个维度的评测指标上处于领先。'),
        ('hunyuan-code',
         '混元最新代码生成模型,经过 200B 高质量代码数据增训基座模型,迭代半年高质量 SFT 数据训练,上下文长窗口长度增大到 8K,五大语言代码生成自动评测指标上位居前列;五大语言10项考量各方面综合代码任务人工高质量评测上,性能处于第一梯队'),
    )
    model_info_list = [
        _create_model_info(name, description, ModelTypeConst.LLM, TencentLLMModelCredential, TencentModel)
        for name, description in llm_models
    ]

    tencent_embedding_model_info = _create_model_info(
        'hunyuan-embedding',
        '腾讯混元 Embedding 接口,可以将文本转化为高质量的向量数据。向量维度为1024维。',
        ModelTypeConst.EMBEDDING,
        TencentEmbeddingCredential,
        TencentEmbeddingModel
    )

    model_info_embedding_list = [tencent_embedding_model_info]

    model_info_manage = ModelInfoManage.builder() \
        .append_model_info_list(model_info_list) \
        .append_model_info_list(model_info_embedding_list) \
        .append_default_model_info(model_info_list[0]) \
        .build()

    return model_info_manage
+from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class VLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + +class VLLMModelCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + try: + model_list = provider.get_base_model_list(model_credential.get('api_base')) + except Exception as e: + raise AppApiException(ValidCode.valid_error.value, "API 域名无效") + exist = provider.get_model_info_by_name(model_list, model_name) + if len(exist) == 0: + raise AppApiException(ValidCode.valid_error.value, "模型不存在,请先下载模型") + model = provider.get_model(model_type, model_name, model_credential) + try: + res = model.invoke([HumanMessage(content='你好')]) + print(res) + except Exception as e: + print(e) + return True + + def encryption_dict(self, model_info: Dict[str, object]): + return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))} + + def build_model(self, model_info: Dict[str, object]): + for key in ['api_key', 'model']: + if key not in model_info: + raise AppApiException(500, f'{key} 字段为必填字段') + self.api_key = model_info.get('api_key') + return self + + api_base = 
forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def get_model_params_setting_form(self, model_name): + return VLLMModelParams() diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py new file mode 100644 index 0000000..d03eb72 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/vllm_model_provider/model/llm.py @@ -0,0 +1,34 @@ +# coding=utf-8 + +from typing import List, Dict +from urllib.parse import urlparse, ParseResult +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI + + +def get_base_url(url: str): + parse = urlparse(url) + result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='', + query='', + fragment='').geturl() + return result_url[:-1] if result_url.endswith("/") else result_url + + +class VllmChatModel(MaxKBBaseModel, BaseChatOpenAI): + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + vllm_chat_open_ai = VllmChatModel( + model=model_name, + openai_api_base=model_credential.get('api_base'), + openai_api_key=model_credential.get('api_key'), + **optional_params, + streaming=True, + stream_usage=True, + ) + return vllm_chat_open_ai diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py new file mode 100644 index 0000000..42ba361 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/vllm_model_provider/vllm_model_provider.py @@ -0,0 +1,59 @@ +# coding=utf-8 +import os 
+from urllib.parse import urlparse, ParseResult + +import requests + +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ + ModelInfoManage +from setting.models_provider.impl.vllm_model_provider.credential.llm import VLLMModelCredential +from setting.models_provider.impl.vllm_model_provider.model.llm import VllmChatModel +from smartdoc.conf import PROJECT_DIR + +v_llm_model_credential = VLLMModelCredential() +model_info_list = [ + ModelInfo('facebook/opt-125m', 'Facebook的125M参数模型', ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), + ModelInfo('BAAI/Aquila-7B', 'BAAI的7B参数模型', ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), + ModelInfo('BAAI/AquilaChat-7B', 'BAAI的13B参数模型', ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel), + +] + +model_info_manage = (ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( + ModelInfo( + 'facebook/opt-125m', + 'Facebook的125M参数模型', + ModelTypeConst.LLM, v_llm_model_credential, VllmChatModel)) + .build()) + + +def get_base_url(url: str): + parse = urlparse(url) + result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='', + query='', + fragment='').geturl() + return result_url[:-1] if result_url.endswith("/") else result_url + + +class VllmModelProvider(IModelProvider): + def get_model_info_manage(self): + return model_info_manage + + def get_model_provide_info(self): + return ModelProvideInfo(provider='model_vllm_provider', name='vLLM', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'vllm_model_provider', 'icon', + 'vllm_icon_svg'))) + + @staticmethod + def get_base_model_list(api_base): + base_url = get_base_url(api_base) + base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1') + r = requests.request(method="GET", url=f"{base_url}/models", timeout=5) 
+ r.raise_for_status() + return r.json().get('data') + + @staticmethod + def get_model_info_by_name(model_list, model_name): + if model_list is None: + return [] + return [model for model in model_list if model.get('id') == model_name] diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/__init__.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/__init__.py new file mode 100644 index 0000000..8cb7f45 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/__init__.py @@ -0,0 +1,2 @@ +#!/usr/bin/env python +# -*- coding: UTF-8 -*- diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py new file mode 100644 index 0000000..d49d22e --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/embedding.py @@ -0,0 +1,46 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: embedding.py + @date:2024/7/12 16:45 + @desc: +""" +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class OpenAIEmbeddingCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=True): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + 
return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.embed_query('你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py new file mode 100644 index 0000000..48c434b --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/llm.py @@ -0,0 +1,70 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/11 17:57 + @desc: +""" +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class VolcanicEngineLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.3, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=1024, + _min=1, + _max=100000, + _step=1, + precision=0) + + +class VolcanicEngineLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + 
model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['access_key_id', 'secret_access_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + res = model.invoke([HumanMessage(content='你好')]) + print(res) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'access_key_id': super().encryption(model.get('access_key_id', ''))} + + access_key_id = forms.PasswordInputField('Access Key ID', required=True) + secret_access_key = forms.PasswordInputField('Secret Access Key', required=True) + + def get_model_params_setting_form(self, model_name): + return VolcanicEngineLLMModelParams() diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py new file mode 100644 index 0000000..d7607de --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/stt.py @@ -0,0 +1,45 @@ +# coding=utf-8 + +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class VolcanicEngineSTTModelCredential(BaseForm, BaseModelCredential): + volcanic_api_url = forms.TextInputField('API 域名', required=True, 
default_value='wss://openspeech.bytedance.com/api/v2/asr') + volcanic_app_id = forms.TextInputField('App ID', required=True) + volcanic_token = forms.PasswordInputField('Access Token', required=True) + volcanic_cluster = forms.TextInputField('Cluster ID', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 'volcanic_cluster']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.check_auth() + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))} + + def get_model_params_setting_form(self, model_name): + pass diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py new file mode 100644 index 0000000..b565b16 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/credential/tts.py @@ -0,0 +1,73 @@ +# coding=utf-8 + +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from 
setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class VolcanicEngineTTSModelGeneralParams(BaseForm): + voice_type = forms.SingleSelect( + TooltipLabel('音色', '中文音色可支持中英文混合场景'), + required=True, default_value='BV002_streaming', + text_field='value', + value_field='value', + option_list=[ + {'text': '灿灿 2.0', 'value': 'BV700_V2_streaming'}, + {'text': '炀炀', 'value': 'BV705_streaming'}, + {'text': '擎苍 2.0', 'value': 'BV701_V2_streaming'}, + {'text': '通用女声 2.0', 'value': 'BV001_V2_streaming'}, + {'text': '灿灿', 'value': 'BV700_streaming'}, + {'text': '超自然音色-梓梓2.0', 'value': 'BV406_V2_streaming'}, + {'text': '超自然音色-梓梓', 'value': 'BV406_streaming'}, + {'text': '超自然音色-燃燃2.0', 'value': 'BV407_V2_streaming'}, + {'text': '超自然音色-燃燃', 'value': 'BV407_streaming'}, + {'text': '通用女声', 'value': 'BV001_streaming'}, + {'text': '通用男声', 'value': 'BV002_streaming'}, + ]) + speed_ratio = forms.SliderField( + TooltipLabel('语速', '[0.2,3],默认为1,通常保留一位小数即可'), + required=True, default_value=1, + _min=0.2, + _max=3, + _step=0.1, + precision=1) + + +class VolcanicEngineTTSModelCredential(BaseForm, BaseModelCredential): + volcanic_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://openspeech.bytedance.com/api/v1/tts/ws_binary') + volcanic_app_id = forms.TextInputField('App ID', required=True) + volcanic_token = forms.PasswordInputField('Access Token', required=True) + volcanic_cluster = forms.TextInputField('Cluster ID', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['volcanic_api_url', 'volcanic_app_id', 'volcanic_token', 'volcanic_cluster']: + if key not in model_credential: + if raise_exception: + raise 
AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.check_auth() + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'volcanic_token': super().encryption(model.get('volcanic_token', ''))} + + def get_model_params_setting_form(self, model_name): + return VolcanicEngineTTSModelGeneralParams() diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py new file mode 100644 index 0000000..b7307a0 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/embedding.py @@ -0,0 +1,15 @@ +from typing import Dict + +from langchain_community.embeddings import VolcanoEmbeddings + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class VolcanicEngineEmbeddingModel(MaxKBBaseModel, VolcanoEmbeddings): + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + return VolcanicEngineEmbeddingModel( + api_key=model_credential.get('api_key'), + model=model_name, + openai_api_base=model_credential.get('api_base'), + ) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/iat_mp3_16k.mp3 b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/iat_mp3_16k.mp3 new file mode 100644 index 0000000..75e744c Binary files /dev/null and b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/iat_mp3_16k.mp3 differ diff --git 
a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py new file mode 100644 index 0000000..181ad29 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/llm.py @@ -0,0 +1,21 @@ +from typing import List, Dict + +from setting.models_provider.base_model_provider import MaxKBBaseModel + +from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI + + +class VolcanicEngineChatModel(MaxKBBaseModel, BaseChatOpenAI): + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + return VolcanicEngineChatModel( + model=model_name, + openai_api_base=model_credential.get('api_base'), + openai_api_key=model_credential.get('api_key'), + **optional_params + ) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py new file mode 100644 index 0000000..4d27a64 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/stt.py @@ -0,0 +1,342 @@ +# coding=utf-8 + +""" +requires Python 3.6 or later + +pip install asyncio +pip install websockets +""" +import asyncio +import base64 +import gzip +import hmac +import json +import os +import uuid +import wave +from enum import Enum +from hashlib import sha256 +from io import BytesIO +from typing import Dict +from urllib.parse import urlparse +import ssl +import websockets + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_stt import BaseSpeechToText + +audio_format = "mp3" # wav 或者 mp3,根据实际音频格式设置 + 
+PROTOCOL_VERSION = 0b0001 +DEFAULT_HEADER_SIZE = 0b0001 + +PROTOCOL_VERSION_BITS = 4 +HEADER_BITS = 4 +MESSAGE_TYPE_BITS = 4 +MESSAGE_TYPE_SPECIFIC_FLAGS_BITS = 4 +MESSAGE_SERIALIZATION_BITS = 4 +MESSAGE_COMPRESSION_BITS = 4 +RESERVED_BITS = 8 + +# Message Type: +CLIENT_FULL_REQUEST = 0b0001 +CLIENT_AUDIO_ONLY_REQUEST = 0b0010 +SERVER_FULL_RESPONSE = 0b1001 +SERVER_ACK = 0b1011 +SERVER_ERROR_RESPONSE = 0b1111 + +# Message Type Specific Flags +NO_SEQUENCE = 0b0000 # no check sequence +POS_SEQUENCE = 0b0001 +NEG_SEQUENCE = 0b0010 +NEG_SEQUENCE_1 = 0b0011 + +# Message Serialization +NO_SERIALIZATION = 0b0000 +JSON = 0b0001 +THRIFT = 0b0011 +CUSTOM_TYPE = 0b1111 + +# Message Compression +NO_COMPRESSION = 0b0000 +GZIP = 0b0001 +CUSTOM_COMPRESSION = 0b1111 + +ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +ssl_context.check_hostname = False +ssl_context.verify_mode = ssl.CERT_NONE + + +def generate_header( + version=PROTOCOL_VERSION, + message_type=CLIENT_FULL_REQUEST, + message_type_specific_flags=NO_SEQUENCE, + serial_method=JSON, + compression_type=GZIP, + reserved_data=0x00, + extension_header=bytes() +): + """ + protocol_version(4 bits), header_size(4 bits), + message_type(4 bits), message_type_specific_flags(4 bits) + serialization_method(4 bits) message_compression(4 bits) + reserved (8bits) 保留字段 + header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) ) + """ + header = bytearray() + header_size = int(len(extension_header) / 4) + 1 + header.append((version << 4) | header_size) + header.append((message_type << 4) | message_type_specific_flags) + header.append((serial_method << 4) | compression_type) + header.append(reserved_data) + header.extend(extension_header) + return header + + +def generate_full_default_header(): + return generate_header() + + +def generate_audio_default_header(): + return generate_header( + message_type=CLIENT_AUDIO_ONLY_REQUEST + ) + + +def generate_last_audio_default_header(): + return generate_header( + 
message_type=CLIENT_AUDIO_ONLY_REQUEST, + message_type_specific_flags=NEG_SEQUENCE + ) + + +def parse_response(res): + """ + protocol_version(4 bits), header_size(4 bits), + message_type(4 bits), message_type_specific_flags(4 bits) + serialization_method(4 bits) message_compression(4 bits) + reserved (8bits) 保留字段 + header_extensions 扩展头(大小等于 8 * 4 * (header_size - 1) ) + payload 类似与http 请求体 + """ + protocol_version = res[0] >> 4 + header_size = res[0] & 0x0f + message_type = res[1] >> 4 + message_type_specific_flags = res[1] & 0x0f + serialization_method = res[2] >> 4 + message_compression = res[2] & 0x0f + reserved = res[3] + header_extensions = res[4:header_size * 4] + payload = res[header_size * 4:] + result = {} + payload_msg = None + payload_size = 0 + if message_type == SERVER_FULL_RESPONSE: + payload_size = int.from_bytes(payload[:4], "big", signed=True) + payload_msg = payload[4:] + elif message_type == SERVER_ACK: + seq = int.from_bytes(payload[:4], "big", signed=True) + result['seq'] = seq + if len(payload) >= 8: + payload_size = int.from_bytes(payload[4:8], "big", signed=False) + payload_msg = payload[8:] + elif message_type == SERVER_ERROR_RESPONSE: + code = int.from_bytes(payload[:4], "big", signed=False) + result['code'] = code + payload_size = int.from_bytes(payload[4:8], "big", signed=False) + payload_msg = payload[8:] + print(f"Error code: {code}, message: {payload_msg}") + if payload_msg is None: + return result + if message_compression == GZIP: + payload_msg = gzip.decompress(payload_msg) + if serialization_method == JSON: + payload_msg = json.loads(str(payload_msg, "utf-8")) + elif serialization_method != NO_SERIALIZATION: + payload_msg = str(payload_msg, "utf-8") + result['payload_msg'] = payload_msg + result['payload_size'] = payload_size + return result + + +def read_wav_info(data: bytes = None) -> (int, int, int, int, int): + with BytesIO(data) as _f: + wave_fp = wave.open(_f, 'rb') + nchannels, sampwidth, framerate, nframes = 
wave_fp.getparams()[:4] + wave_bytes = wave_fp.readframes(nframes) + return nchannels, sampwidth, framerate, nframes, len(wave_bytes) + + +class VolcanicEngineSpeechToText(MaxKBBaseModel, BaseSpeechToText): + workflow: str = "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate" + show_language: bool = False + show_utterances: bool = False + result_type: str = "full" + format: str = "mp3" + rate: int = 16000 + language: str = "zh-CN" + bits: int = 16 + channel: int = 1 + codec: str = "raw" + audio_type: int = 1 + secret: str = "access_secret" + auth_method: str = "token" + mp3_seg_size: int = 10000 + success_code: int = 1000 # success code, default is 1000 + seg_duration: int = 15000 + nbest: int = 1 + + volcanic_app_id: str + volcanic_cluster: str + volcanic_api_url: str + volcanic_token: str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.volcanic_api_url = kwargs.get('volcanic_api_url') + self.volcanic_token = kwargs.get('volcanic_token') + self.volcanic_app_id = kwargs.get('volcanic_app_id') + self.volcanic_cluster = kwargs.get('volcanic_cluster') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + return VolcanicEngineSpeechToText( + volcanic_api_url=model_credential.get('volcanic_api_url'), + volcanic_token=model_credential.get('volcanic_token'), + volcanic_app_id=model_credential.get('volcanic_app_id'), + volcanic_cluster=model_credential.get('volcanic_cluster'), + **optional_params + ) + + def construct_request(self, reqid): + req = { + 'app': { + 'appid': self.volcanic_app_id, + 'cluster': self.volcanic_cluster, + 'token': self.volcanic_token, + }, + 'user': { + 
'uid': 'uid' + }, + 'request': { + 'reqid': reqid, + 'nbest': self.nbest, + 'workflow': self.workflow, + 'show_language': self.show_language, + 'show_utterances': self.show_utterances, + 'result_type': self.result_type, + "sequence": 1 + }, + 'audio': { + 'format': self.format, + 'rate': self.rate, + 'language': self.language, + 'bits': self.bits, + 'channel': self.channel, + 'codec': self.codec + } + } + return req + + @staticmethod + def slice_data(data: bytes, chunk_size: int) -> (list, bool): + """ + slice data + :param data: wav data + :param chunk_size: the segment size in one request + :return: segment data, last flag + """ + data_len = len(data) + offset = 0 + while offset + chunk_size < data_len: + yield data[offset: offset + chunk_size], False + offset += chunk_size + else: + yield data[offset: data_len], True + + def _real_processor(self, request_params: dict) -> dict: + pass + + def token_auth(self): + return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)} + + def signature_auth(self, data): + header_dicts = { + 'Custom': 'auth_custom', + } + + url_parse = urlparse(self.volcanic_api_url) + input_str = 'GET {} HTTP/1.1\n'.format(url_parse.path) + auth_headers = 'Custom' + for header in auth_headers.split(','): + input_str += '{}\n'.format(header_dicts[header]) + input_data = bytearray(input_str, 'utf-8') + input_data += data + mac = base64.urlsafe_b64encode( + hmac.new(self.secret.encode('utf-8'), input_data, digestmod=sha256).digest()) + header_dicts['Authorization'] = 'HMAC256; access_token="{}"; mac="{}"; h="{}"'.format(self.volcanic_token, + str(mac, 'utf-8'), + auth_headers) + return header_dicts + + async def segment_data_processor(self, wav_data: bytes, segment_size: int): + reqid = str(uuid.uuid4()) + # 构建 full client request,并序列化压缩 + request_params = self.construct_request(reqid) + payload_bytes = str.encode(json.dumps(request_params)) + payload_bytes = gzip.compress(payload_bytes) + full_client_request = 
bytearray(generate_full_default_header()) + full_client_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) + full_client_request.extend(payload_bytes) # payload + header = None + if self.auth_method == "token": + header = self.token_auth() + elif self.auth_method == "signature": + header = self.signature_auth(full_client_request) + async with websockets.connect(self.volcanic_api_url, extra_headers=header, max_size=1000000000, + ssl=ssl_context) as ws: + # 发送 full client request + await ws.send(full_client_request) + res = await ws.recv() + result = parse_response(res) + if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code: + raise Exception(f"Error code: {result['payload_msg']['code']}, message: {result['payload_msg']['message']}") + for seq, (chunk, last) in enumerate(VolcanicEngineSpeechToText.slice_data(wav_data, segment_size), 1): + # if no compression, comment this line + payload_bytes = gzip.compress(chunk) + audio_only_request = bytearray(generate_audio_default_header()) + if last: + audio_only_request = bytearray(generate_last_audio_default_header()) + audio_only_request.extend((len(payload_bytes)).to_bytes(4, 'big')) # payload size(4 bytes) + audio_only_request.extend(payload_bytes) # payload + # 发送 audio-only client request + await ws.send(audio_only_request) + res = await ws.recv() + result = parse_response(res) + if 'payload_msg' in result and result['payload_msg']['code'] != self.success_code: + return result + return result['payload_msg']['result'][0]['text'] + + def check_auth(self): + cwd = os.path.dirname(os.path.abspath(__file__)) + with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f: + self.speech_to_text(f) + + def speech_to_text(self, file): + data = file.read() + audio_data = bytes(data) + if self.format == "mp3": + segment_size = self.mp3_seg_size + return asyncio.run(self.segment_data_processor(audio_data, segment_size)) + if self.format != "wav": + raise Exception("format should in wav or 
mp3") + nchannels, sampwidth, framerate, nframes, wav_len = read_wav_info( + audio_data) + size_per_sec = nchannels * sampwidth * framerate + segment_size = int(size_per_sec * self.seg_duration / 1000) + return asyncio.run(self.segment_data_processor(audio_data, segment_size)) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py new file mode 100644 index 0000000..ec39f22 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/volcanic_engine_model_provider/model/tts.py @@ -0,0 +1,178 @@ +# coding=utf-8 + +''' +requires Python 3.6 or later + +pip install asyncio +pip install websockets + +''' + +import asyncio +import copy +import gzip +import json +import re +import uuid +from typing import Dict +import ssl +import websockets + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_tts import BaseTextToSpeech + +MESSAGE_TYPES = {11: "audio-only server response", 12: "frontend server response", 15: "error message from server"} +MESSAGE_TYPE_SPECIFIC_FLAGS = {0: "no sequence number", 1: "sequence number > 0", + 2: "last message from server (seq < 0)", 3: "sequence number < 0"} +MESSAGE_SERIALIZATION_METHODS = {0: "no serialization", 1: "JSON", 15: "custom type"} +MESSAGE_COMPRESSIONS = {0: "no compression", 1: "gzip", 15: "custom compression method"} + +# version: b0001 (4 bits) +# header size: b0001 (4 bits) +# message type: b0001 (Full client request) (4bits) +# message type specific flags: b0000 (none) (4bits) +# message serialization method: b0001 (JSON) (4 bits) +# message compression: b0001 (gzip) (4bits) +# reserved data: 0x00 (1 byte) +default_header = bytearray(b'\x11\x10\x11\x00') + +ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +ssl_context.check_hostname = False +ssl_context.verify_mode = ssl.CERT_NONE + + +class 
class VolcanicEngineTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Text-to-speech model backed by the Volcanic Engine websocket TTS endpoint.

    The input text is split on newlines and each usable line is sent as its own
    ``submit`` request over a single websocket connection; the audio payloads of
    all responses are concatenated and returned as ``bytes``.
    """
    volcanic_app_id: str
    volcanic_cluster: str
    volcanic_api_url: str
    volcanic_token: str
    params: dict

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.volcanic_api_url = kwargs.get('volcanic_api_url')
        self.volcanic_token = kwargs.get('volcanic_token')
        self.volcanic_app_id = kwargs.get('volcanic_app_id')
        self.volcanic_cluster = kwargs.get('volcanic_cluster')
        self.params = kwargs.get('params')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from stored credentials plus optional audio parameters.

        Every keyword in ``model_kwargs`` except ``model_id``, ``use_local`` and
        ``streaming`` is forwarded as an audio parameter, overriding the defaults
        (``voice_type='BV002_streaming'``, ``speed_ratio=1.0``).
        """
        optional_params = {'params': {'voice_type': 'BV002_streaming', 'speed_ratio': 1.0}}
        for key, value in model_kwargs.items():
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value
        return VolcanicEngineTextToSpeech(
            volcanic_api_url=model_credential.get('volcanic_api_url'),
            volcanic_token=model_credential.get('volcanic_token'),
            volcanic_app_id=model_credential.get('volcanic_app_id'),
            volcanic_cluster=model_credential.get('volcanic_cluster'),
            **optional_params
        )

    def check_auth(self):
        """Validate the credentials by synthesising a short sample; raises on failure."""
        self.text_to_speech('你好')

    def text_to_speech(self, text):
        """Synthesise ``text`` and return the resulting mp3 audio as ``bytes``.

        Raises:
            Exception: when the server answers with an error message frame.
        """
        request_json = {
            "app": {
                "appid": self.volcanic_app_id,
                # NOTE(review): literal placeholder; the real token travels in the
                # Authorization header (see token_auth) — confirm against the API docs.
                "token": "access_token",
                "cluster": self.volcanic_cluster
            },
            "user": {
                "uid": "uid"
            },
            # `params` may be None when the instance was built without it;
            # `dict | None` would raise TypeError.
            "audio": {
                "encoding": "mp3",
                "volume_ratio": 1.0,
                "pitch_ratio": 1.0,
            } | (self.params or {}),
            "request": {
                "reqid": str(uuid.uuid4()),
                # `text` and `operation` are filled in per line by submit().
                "text": '',
                "text_type": "plain",
                "operation": "xxx"
            }
        }

        return asyncio.run(self.submit(request_json, text))

    def is_cache_model(self):
        return False

    def token_auth(self):
        """Return the websocket handshake header carrying the access token."""
        return {'Authorization': 'Bearer; {}'.format(self.volcanic_token)}

    async def submit(self, request_json, text):
        """Send one ``submit`` request per usable line of ``text`` and collect the audio.

        Blank lines and markdown table separator lines (only ``|``, ``-`` and
        whitespace) are skipped: they carry nothing to speak, and an empty
        ``text`` field would be sent to the server otherwise.
        """
        submit_request_json = copy.deepcopy(request_json)
        submit_request_json["request"]["operation"] = "submit"
        header = self.token_auth()
        result = b''
        async with websockets.connect(self.volcanic_api_url, extra_headers=header, ping_interval=None,
                                      ssl=ssl_context) as ws:
            for line in text.split('\n'):
                if not line.strip() or self.is_table_format_chars_only(line):
                    continue
                submit_request_json["request"]["reqid"] = str(uuid.uuid4())
                submit_request_json["request"]["text"] = line
                payload_bytes = gzip.compress(str.encode(json.dumps(submit_request_json)))
                # Frame layout: fixed 4-byte header, 4-byte big-endian payload size, gzip payload.
                full_client_request = bytearray(default_header)
                full_client_request.extend(len(payload_bytes).to_bytes(4, 'big'))
                full_client_request.extend(payload_bytes)
                await ws.send(full_client_request)
                result += await self.parse_response(ws)
        return result

    @staticmethod
    def is_table_format_chars_only(s):
        """Return True when ``s`` is non-empty and consists only of ``|``, ``-`` and whitespace."""
        return bool(s) and bool(re.fullmatch(r'[|\-\s]+', s))

    @staticmethod
    async def parse_response(ws):
        """Read frames from ``ws`` until the final audio frame and return the audio bytes.

        Raises:
            Exception: when an error frame (message type 0xf) is received.
        """
        result = b''
        while True:
            res = await ws.recv()
            # Header: byte 0 low nibble is the header size in 4-byte words, byte 1 carries
            # message type / flags, byte 2 low nibble is the compression method.
            header_size = res[0] & 0x0f
            message_type = res[1] >> 4
            message_type_specific_flags = res[1] & 0x0f
            message_compression = res[2] & 0x0f
            payload = res[header_size * 4:]
            if message_type == 0xb:  # audio-only server response
                if message_type_specific_flags == 0:  # no sequence number: plain ACK
                    continue
                sequence_number = int.from_bytes(payload[:4], "big", signed=True)
                # payload[4:8] holds the audio size; the audio itself follows.
                result += payload[8:]
                if sequence_number < 0:  # negative sequence marks the last frame
                    break
            elif message_type == 0xf:  # error message from server
                code = int.from_bytes(payload[:4], "big", signed=False)
                error_msg = payload[8:]
                if message_compression == 1:
                    error_msg = gzip.decompress(error_msg)
                error_msg = str(error_msg, "utf-8")
                raise Exception(f"Error code: {code}, message: {error_msg}")
            elif message_type == 0xc:  # frontend server response: decoded, then ignored
                payload = payload[4:]
                if message_compression == 1:
                    payload = gzip.decompress(payload)
            else:
                # Unknown message type: stop reading.
                break
        return result
class VolcanicEngineModelProvider(IModelProvider):
    """Provider entry that exposes the Volcanic Engine models to the registry."""

    def get_model_info_manage(self):
        """Return the module-level model registry built for this provider."""
        return model_info_manage

    def get_model_provide_info(self):
        """Return the provider id, display name and icon file contents."""
        icon_file = os.path.join(
            PROJECT_DIR, "apps", "setting", 'models_provider', 'impl',
            'volcanic_engine_model_provider', 'icon', 'volcanic_engine_icon_svg')
        icon_content = get_file_content(icon_file)
        return ModelProvideInfo(provider='model_volcanic_engine_provider', name='火山引擎', icon=icon_content)
class QianfanEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Qianfan embedding models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check the credentials by embedding a short sample query.

        Returns True on success. On a remote failure returns False, or raises
        AppApiException when ``raise_exception`` is set. An unsupported model
        type always raises AppApiException.
        """
        matching_types = [entry for entry in provider.get_model_type_list() if entry.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        self.valid_form(model_credential)
        try:
            embedding_model = provider.get_model(model_type, model_name, model_credential)
            embedding_model.embed_query('你好')
        except AppApiException:
            # Already in the API error format: pass it through untouched.
            raise
        except Exception as err:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(err)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with the secret key passed through ``encryption``."""
        masked_secret = super().encryption(model.get('qianfan_sk', ''))
        return {**model, 'qianfan_sk': masked_secret}

    qianfan_ak = forms.PasswordInputField('API Key', required=True)

    qianfan_sk = forms.PasswordInputField("Secret Key", required=True)
class WenxinLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for Wenxin (Qianfan) chat models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credentials by listing the remote models and sending a test prompt.

        Returns True on success. When a required key is missing or the remote
        call fails, returns False, or raises AppApiException if
        ``raise_exception`` is set. An unsupported model type or model name
        always raises AppApiException.
        """
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
        # Check required keys before building the model, so a missing key is
        # reported as such rather than as an opaque client error.
        for key in ['api_key', 'secret_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            supported_names = [name.lower() for name in model.client.models()]
            if model_name.lower() not in supported_names:
                raise AppApiException(ValidCode.valid_error.value, f'{model_name} 模型不支持')
            model.invoke([HumanMessage(content='你好')])
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            # Honour raise_exception the same way the other credential classes do.
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model_info: Dict[str, object]):
        """Return a copy of ``model_info`` with the secret key passed through ``encryption``."""
        return {**model_info, 'secret_key': super().encryption(model_info.get('secret_key', ''))}

    def build_model(self, model_info: Dict[str, object]):
        """Copy ``api_key`` / ``secret_key`` from ``model_info`` onto this instance.

        Raises:
            AppApiException: when ``api_key``, ``secret_key`` or ``model`` is absent.
        """
        for key in ['api_key', 'secret_key', 'model']:
            if key not in model_info:
                raise AppApiException(500, f'{key} 字段为必填字段')
        self.api_key = model_info.get('api_key')
        self.secret_key = model_info.get('secret_key')
        return self

    api_key = forms.PasswordInputField('API Key', required=True)

    secret_key = forms.PasswordInputField("Secret Key", required=True)

    def get_model_params_setting_form(self, model_name):
        """Return the tunable-parameter form (temperature, max tokens) for any model name."""
        return WenxinLLMModelParams()
class QianfanChatModel(MaxKBBaseModel, QianfanChatEndpoint):
    """Qianfan chat endpoint that records token usage from streamed responses."""

    @staticmethod
    def is_cache_model():
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build the chat model from stored credentials and optional model parameters."""
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return QianfanChatModel(model=model_name,
                                qianfan_ak=model_credential.get('api_key'),
                                qianfan_sk=model_credential.get('secret_key'),
                                streaming=model_kwargs.get('streaming', False),
                                init_kwargs=optional_params)

    # Token usage reported by the most recent streamed response.
    usage_metadata: dict = {}

    def get_last_generation_info(self) -> Optional[Dict[str, Any]]:
        """Return the usage dict captured during the last stream."""
        return self.usage_metadata

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Return the prompt token count of the last stream; ``messages`` is not inspected."""
        return self.usage_metadata.get('prompt_tokens', 0)

    def get_num_tokens(self, text: str) -> int:
        """Return the completion token count of the last stream; ``text`` is not inspected."""
        return self.usage_metadata.get('completion_tokens', 0)

    def _stream(
            self,
            messages: List[BaseMessage],
            stop: Optional[List[str]] = None,
            run_manager: Optional[CallbackManagerForLLMRun] = None,
            **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream chat chunks, capturing token usage from the final or empty chunk."""
        kwargs = {**self.init_kwargs, **kwargs}
        params = self._convert_prompt_msg_params(messages, **kwargs)
        params["stop"] = stop
        params["stream"] = True
        for res in self.client.do(**params):
            if res:
                msg = _convert_dict_to_message(res)
                additional_kwargs = msg.additional_kwargs.get("function_call", {})
                if msg.content == "" or res.get("body").get("is_end"):
                    token_usage = res.get("body").get("usage")
                    # Keep the previous usage when this chunk carries none; storing
                    # None would make the get_num_tokens* helpers raise AttributeError.
                    if token_usage:
                        self.usage_metadata = token_usage
                chunk = ChatGenerationChunk(
                    text=res["result"],
                    message=AIMessageChunk(  # type: ignore[call-arg]
                        content=msg.content,
                        role="assistant",
                        additional_kwargs=additional_kwargs,
                    ),
                    generation_info=msg.additional_kwargs,
                )
                if run_manager:
                    run_manager.on_llm_new_token(chunk.text, chunk=chunk)
                yield chunk
class WenxinModelProvider(IModelProvider):
    """Provider entry that exposes the Qianfan (Wenxin) models to the registry."""

    def get_model_info_manage(self):
        """Return the module-level model registry built for this provider."""
        return model_info_manage

    def get_model_provide_info(self):
        """Return the provider id, display name and icon file contents."""
        icon_file = os.path.join(
            PROJECT_DIR, "apps", "setting", 'models_provider', 'impl',
            'wenxin_model_provider', 'icon', 'azure_icon_svg')
        icon_content = get_file_content(icon_file)
        return ModelProvideInfo(provider='model_wenxin_provider', name='千帆大模型', icon=icon_content)
class XunFeiLLMModelGeneralParams(BaseForm):
    """Tunable generation parameters shown for the 'general' and 'pro-128k' XunFei models."""

    # Sampling temperature slider: 0.1 to 1.0 in steps of 0.01, default 0.5.
    temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
                                    required=True, default_value=0.5,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    # Maximum output tokens slider: integers from 1 to 100000, default 4096.
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True, default_value=4096,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)
class XunFeiLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for XunFei Spark chat models."""

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credentials by sending a short test prompt.

        Returns True on success. When a required key is missing or the remote
        call fails, returns False, or raises AppApiException if
        ``raise_exception`` is set. An unsupported model type always raises.
        """
        matching_types = [entry for entry in provider.get_model_type_list() if entry.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
            if key in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
        try:
            chat_model = provider.get_model(model_type, model_name, model_credential)
            chat_model.invoke([HumanMessage(content='你好')])
        except AppApiException:
            # Already in the API error format: pass it through untouched.
            raise
        except Exception as err:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(err)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with the API secret passed through ``encryption``."""
        masked_secret = super().encryption(model.get('spark_api_secret', ''))
        return {**model, 'spark_api_secret': masked_secret}

    spark_api_url = forms.TextInputField('API 域名', required=True)
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def get_model_params_setting_form(self, model_name):
        """Return the general form for 'general' / 'pro-128k', the pro form otherwise."""
        if model_name in ('general', 'pro-128k'):
            return XunFeiLLMModelGeneralParams()
        return XunFeiLLMModelProParams()
class XunFeiSTTModelCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for the XunFei speech-to-text model."""
    spark_api_url = forms.TextInputField('API 域名', required=True, default_value='wss://iat-api.xfyun.cn/v2/iat')
    spark_app_id = forms.TextInputField('APP ID', required=True)
    spark_api_key = forms.PasswordInputField("API Key", required=True)
    spark_api_secret = forms.PasswordInputField('API Secret', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Validate the credentials through the model's ``check_auth``.

        Returns True on success. When a required key is missing or the check
        fails, returns False, or raises AppApiException if ``raise_exception``
        is set. An unsupported model type always raises.
        """
        matching_types = [entry for entry in provider.get_model_type_list() if entry.get('value') == model_type]
        if not any(matching_types):
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['spark_api_url', 'spark_app_id', 'spark_api_key', 'spark_api_secret']:
            if key in model_credential:
                continue
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
        try:
            stt_model = provider.get_model(model_type, model_name, model_credential)
            stt_model.check_auth()
        except AppApiException:
            # Already in the API error format: pass it through untouched.
            raise
        except Exception as err:
            if not raise_exception:
                return False
            raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(err)}')
        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` with the API secret passed through ``encryption``."""
        masked_secret = super().encryption(model.get('spark_api_secret', ''))
        return {**model, 'spark_api_secret': masked_secret}

    def get_model_params_setting_form(self, model_name):
        """This model has no tunable parameters, so there is no form to return."""
        return None
class XFEmbedding(MaxKBBaseModel, SparkLLMTextEmbeddings):
    """XunFei Spark text-embedding model wired into the MaxKB model registry."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build the embedding client from the stored app id, API key and API secret."""
        credential_fields = ('spark_app_id', 'spark_api_key', 'spark_api_secret')
        return XFEmbedding(**{field: model_credential.get(field) for field in credential_fields})

    @staticmethod
    def _parser_message(
            message: str,
    ) -> Optional[ndarray]:
        """Decode the embedding vector from a JSON response message.

        The vector arrives as base64-encoded little-endian float32 values and
        is truncated to at most 2560 entries.

        Raises:
            Exception: when the response header carries a non-zero code.
        """
        response = json.loads(message)
        status_code = response["header"]["code"]
        if status_code != 0:
            # XunFei enforces a QPS limit that surfaces as an error here, so its
            # embedding model is not recommended.
            raise Exception(f"Request error: {status_code}, {response}")
        raw_bytes = base64.b64decode(response["payload"]["feature"]["text"])
        little_endian_f32 = np.dtype(np.float32).newbyteorder("<")
        vector = np.frombuffer(raw_bytes, dtype=little_endian_f32)
        return vector[:2560] if len(vector) > 2560 else vector
**model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + return XFChatSparkLLM( + spark_app_id=model_credential.get('spark_app_id'), + spark_api_key=model_credential.get('spark_api_key'), + spark_api_secret=model_credential.get('spark_api_secret'), + spark_api_url=model_credential.get('spark_api_url'), + spark_llm_domain=model_name, + streaming=model_kwargs.get('streaming', False), + **optional_params + ) + + usage_metadata: dict = {} + + def get_last_generation_info(self) -> Optional[Dict[str, Any]]: + return self.usage_metadata + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + return self.usage_metadata.get('prompt_tokens', 0) + + def get_num_tokens(self, text: str) -> int: + return self.usage_metadata.get('completion_tokens', 0) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + default_chunk_class = AIMessageChunk + + self.client.arun( + [convert_message_to_dict(m) for m in messages], + self.spark_user_id, + self.model_kwargs, + True, + ) + for content in self.client.subscribe(timeout=self.request_timeout): + if "data" in content: + delta = content["data"] + chunk = _convert_delta_to_message_chunk(delta, default_chunk_class) + cg_chunk = ChatGenerationChunk(message=chunk) + elif "usage" in content: + generation_info = content["usage"] + self.usage_metadata = generation_info + continue + else: + continue + if cg_chunk is not None: + if run_manager: + run_manager.on_llm_new_token(str(cg_chunk.message.content), chunk=cg_chunk) + yield cg_chunk diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xf_model_provider/model/stt.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xf_model_provider/model/stt.py new file mode 100644 index 0000000..f400473 --- /dev/null +++ 
b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xf_model_provider/model/stt.py @@ -0,0 +1,169 @@ +# -*- coding:utf-8 -*- +# +# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看) +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +import asyncio +import base64 +import datetime +import hashlib +import hmac +import json +import logging +import os +from datetime import datetime +from typing import Dict +from urllib.parse import urlencode, urlparse +import ssl +import websockets + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_stt import BaseSpeechToText + +STATUS_FIRST_FRAME = 0 # 第一帧的标识 +STATUS_CONTINUE_FRAME = 1 # 中间帧标识 +STATUS_LAST_FRAME = 2 # 最后一帧的标识 + +ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +ssl_context.check_hostname = False +ssl_context.verify_mode = ssl.CERT_NONE + +max_kb = logging.getLogger("max_kb") + +class XFSparkSpeechToText(MaxKBBaseModel, BaseSpeechToText): + spark_app_id: str + spark_api_key: str + spark_api_secret: str + spark_api_url: str + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.spark_api_url = kwargs.get('spark_api_url') + self.spark_app_id = kwargs.get('spark_app_id') + self.spark_api_key = kwargs.get('spark_api_key') + self.spark_api_secret = kwargs.get('spark_api_secret') + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {} + if 'max_tokens' in model_kwargs and model_kwargs['max_tokens'] is not None: + optional_params['max_tokens'] = model_kwargs['max_tokens'] + if 'temperature' in model_kwargs and model_kwargs['temperature'] is not None: + optional_params['temperature'] = model_kwargs['temperature'] + return XFSparkSpeechToText( + spark_app_id=model_credential.get('spark_app_id'), + spark_api_key=model_credential.get('spark_api_key'), + 
spark_api_secret=model_credential.get('spark_api_secret'), + spark_api_url=model_credential.get('spark_api_url'), + **optional_params + ) + + # 生成url + def create_url(self): + url = self.spark_api_url + host = urlparse(url).hostname + # 生成RFC1123格式的时间戳 + gmt_format = '%a, %d %b %Y %H:%M:%S GMT' + date = datetime.utcnow().strftime(gmt_format) + + # 拼接字符串 + signature_origin = "host: " + host + "\n" + signature_origin += "date: " + date + "\n" + signature_origin += "GET " + "/v2/iat " + "HTTP/1.1" + # 进行hmac-sha256进行加密 + signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'), + digestmod=hashlib.sha256).digest() + signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8') + + authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( + self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha) + authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + # 将请求的鉴权参数组合为字典 + v = { + "authorization": authorization, + "date": date, + "host": host + } + # 拼接鉴权参数,生成url + url = url + '?' 
+ urlencode(v) + # print("date: ",date) + # print("v: ",v) + # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 + # print('websocket url :', url) + return url + + def check_auth(self): + cwd = os.path.dirname(os.path.abspath(__file__)) + with open(f'{cwd}/iat_mp3_16k.mp3', 'rb') as f: + self.speech_to_text(f) + + def speech_to_text(self, file): + async def handle(): + async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws: + # 发送 full client request + await self.send(ws, file) + return await self.handle_message(ws) + + return asyncio.run(handle()) + + @staticmethod + async def handle_message(ws): + res = await ws.recv() + message = json.loads(res) + code = message["code"] + sid = message["sid"] + if code != 0: + errMsg = message["message"] + raise Exception(f"sid: {sid} call error: {errMsg} code is: {code}") + else: + data = message["data"]["result"]["ws"] + result = "" + for i in data: + for w in i["cw"]: + result += w["w"] + # print("sid:%s call success!,data is:%s" % (sid, json.dumps(data, ensure_ascii=False))) + return result + + # 收到websocket连接建立的处理 + async def send(self, ws, file): + frameSize = 8000 # 每一帧的音频大小 + status = STATUS_FIRST_FRAME # 音频的状态信息,标识音频是第一帧,还是中间帧、最后一帧 + + while True: + buf = file.read(frameSize) + # 文件结束 + if not buf: + status = STATUS_LAST_FRAME + # 第一帧处理 + # 发送第一帧音频,带business 参数 + # appid 必须带上,只需第一帧发送 + if status == STATUS_FIRST_FRAME: + d = { + "common": {"app_id": self.spark_app_id}, + "business": { + "domain": "iat", + "language": "zh_cn", + "accent": "mandarin", + "vinfo": 1, + "vad_eos": 10000 + }, + "data": { + "status": 0, "format": "audio/L16;rate=16000", + "audio": str(base64.b64encode(buf), 'utf-8'), + "encoding": "lame"} + } + d = json.dumps(d) + await ws.send(d) + status = STATUS_CONTINUE_FRAME + # 中间帧处理 + elif status == STATUS_CONTINUE_FRAME: + d = {"data": {"status": 1, "format": "audio/L16;rate=16000", + "audio": str(base64.b64encode(buf), 'utf-8'), + "encoding": 
"lame"}} + await ws.send(json.dumps(d)) + # 最后一帧处理 + elif status == STATUS_LAST_FRAME: + d = {"data": {"status": 2, "format": "audio/L16;rate=16000", + "audio": str(base64.b64encode(buf), 'utf-8'), + "encoding": "lame"}} + await ws.send(json.dumps(d)) + break diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xf_model_provider/model/tts.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xf_model_provider/model/tts.py new file mode 100644 index 0000000..3a575ed --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xf_model_provider/model/tts.py @@ -0,0 +1,146 @@ +# -*- coding:utf-8 -*- +# +# author: iflytek +# +# 错误码链接:https://www.xfyun.cn/document/error-code (code返回错误码时必看) +# # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # +import asyncio +import base64 +import datetime +import hashlib +import hmac +import json +import logging +import os +from datetime import datetime +from typing import Dict +from urllib.parse import urlencode, urlparse +import ssl +import websockets + +from setting.models_provider.base_model_provider import MaxKBBaseModel +from setting.models_provider.impl.base_tts import BaseTextToSpeech + +max_kb = logging.getLogger("max_kb") + +STATUS_FIRST_FRAME = 0 # 第一帧的标识 +STATUS_CONTINUE_FRAME = 1 # 中间帧标识 +STATUS_LAST_FRAME = 2 # 最后一帧的标识 + +ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) +ssl_context.check_hostname = False +ssl_context.verify_mode = ssl.CERT_NONE + + +class XFSparkTextToSpeech(MaxKBBaseModel, BaseTextToSpeech): + spark_app_id: str + spark_api_key: str + spark_api_secret: str + spark_api_url: str + params: dict + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.spark_api_url = kwargs.get('spark_api_url') + self.spark_app_id = kwargs.get('spark_app_id') + self.spark_api_key = kwargs.get('spark_api_key') + self.spark_api_secret = kwargs.get('spark_api_secret') + self.params = kwargs.get('params') + + @staticmethod + 
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = {'params': {'vcn': 'xiaoyan', 'speed': 50}} + for key, value in model_kwargs.items(): + if key not in ['model_id', 'use_local', 'streaming']: + optional_params['params'][key] = value + return XFSparkTextToSpeech( + spark_app_id=model_credential.get('spark_app_id'), + spark_api_key=model_credential.get('spark_api_key'), + spark_api_secret=model_credential.get('spark_api_secret'), + spark_api_url=model_credential.get('spark_api_url'), + **optional_params + ) + + # 生成url + def create_url(self): + url = self.spark_api_url + host = urlparse(url).hostname + # 生成RFC1123格式的时间戳 + gmt_format = '%a, %d %b %Y %H:%M:%S GMT' + date = datetime.utcnow().strftime(gmt_format) + + # 拼接字符串 + signature_origin = "host: " + host + "\n" + signature_origin += "date: " + date + "\n" + signature_origin += "GET " + "/v2/tts " + "HTTP/1.1" + # 进行hmac-sha256进行加密 + signature_sha = hmac.new(self.spark_api_secret.encode('utf-8'), signature_origin.encode('utf-8'), + digestmod=hashlib.sha256).digest() + signature_sha = base64.b64encode(signature_sha).decode(encoding='utf-8') + + authorization_origin = "api_key=\"%s\", algorithm=\"%s\", headers=\"%s\", signature=\"%s\"" % ( + self.spark_api_key, "hmac-sha256", "host date request-line", signature_sha) + authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8') + # 将请求的鉴权参数组合为字典 + v = { + "authorization": authorization, + "date": date, + "host": host + } + # 拼接鉴权参数,生成url + url = url + '?' 
+ urlencode(v) + # print("date: ",date) + # print("v: ",v) + # 此处打印出建立连接时候的url,参考本demo的时候可取消上方打印的注释,比对相同参数时生成的url与自己代码生成的url是否一致 + # print('websocket url :', url) + return url + + def check_auth(self): + self.text_to_speech("你好") + + def text_to_speech(self, text): + + # 使用小语种须使用以下方式,此处的unicode指的是 utf16小端的编码方式,即"UTF-16LE"” + # self.Data = {"status": 2, "text": str(base64.b64encode(self.Text.encode('utf-16')), "UTF8")} + async def handle(): + async with websockets.connect(self.create_url(), max_size=1000000000, ssl=ssl_context) as ws: + # 发送 full client request + await self.send(ws, text) + return await self.handle_message(ws) + + return asyncio.run(handle()) + + def is_cache_model(self): + return False + + @staticmethod + async def handle_message(ws): + audio_bytes: bytes = b'' + while True: + res = await ws.recv() + message = json.loads(res) + # print(message) + code = message["code"] + sid = message["sid"] + + if code != 0: + errMsg = message["message"] + raise Exception(f"sid: {sid} call error: {errMsg} code is: {code}") + else: + audio = message["data"]["audio"] + audio = base64.b64decode(audio) + audio_bytes += audio + # 退出 + if message["data"]["status"] == 2: + break + return audio_bytes + + async def send(self, ws, text): + business = {"aue": "lame", "sfl": 1, "auf": "audio/L16;rate=16000", "tte": "utf8"} + d = { + "common": {"app_id": self.spark_app_id}, + "business": business | self.params, + "data": {"status": 2, "text": str(base64.b64encode(text.encode('utf-8')), "UTF8")}, + } + d = json.dumps(d) + await ws.send(d) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py new file mode 100644 index 0000000..04fd2d4 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xf_model_provider/xf_model_provider.py @@ -0,0 +1,52 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: xf_model_provider.py + 
@date:2024/04/19 14:47 + @desc: +""" +import os +import ssl + +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ + ModelInfoManage +from setting.models_provider.impl.xf_model_provider.credential.embedding import XFEmbeddingCredential +from setting.models_provider.impl.xf_model_provider.credential.llm import XunFeiLLMModelCredential +from setting.models_provider.impl.xf_model_provider.credential.stt import XunFeiSTTModelCredential +from setting.models_provider.impl.xf_model_provider.credential.tts import XunFeiTTSModelCredential +from setting.models_provider.impl.xf_model_provider.model.embedding import XFEmbedding +from setting.models_provider.impl.xf_model_provider.model.llm import XFChatSparkLLM +from setting.models_provider.impl.xf_model_provider.model.stt import XFSparkSpeechToText +from setting.models_provider.impl.xf_model_provider.model.tts import XFSparkTextToSpeech +from smartdoc.conf import PROJECT_DIR + +ssl._create_default_https_context = ssl.create_default_context() + +qwen_model_credential = XunFeiLLMModelCredential() +stt_model_credential = XunFeiSTTModelCredential() +tts_model_credential = XunFeiTTSModelCredential() +embedding_model_credential = XFEmbeddingCredential() +model_info_list = [ + ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), + ModelInfo('generalv3', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), + ModelInfo('generalv2', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM), + ModelInfo('iat', '中英文识别', ModelTypeConst.STT, stt_model_credential, XFSparkSpeechToText), + ModelInfo('tts', '', ModelTypeConst.TTS, tts_model_credential, XFSparkTextToSpeech), + ModelInfo('embedding', '', ModelTypeConst.EMBEDDING, embedding_model_credential, XFEmbedding) +] + +model_info_manage = 
ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( + ModelInfo('generalv3.5', '', ModelTypeConst.LLM, qwen_model_credential, XFChatSparkLLM)).build() + + +class XunFeiModelProvider(IModelProvider): + + def get_model_info_manage(self): + return model_info_manage + + def get_model_provide_info(self): + return ModelProvideInfo(provider='model_xf_provider', name='讯飞星火', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'xf_model_provider', 'icon', + 'xf_icon_svg'))) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/__init__.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/__init__.py new file mode 100644 index 0000000..9bad579 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/__init__.py @@ -0,0 +1 @@ +# coding=utf-8 diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py new file mode 100644 index 0000000..7cddb4f --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py @@ -0,0 +1,40 @@ +# coding=utf-8 +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.local_model_provider.model.embedding import LocalEmbedding + + +class XinferenceEmbeddingModelCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, 
model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + try: + model_list = provider.get_base_model_list(model_credential.get('api_base'), model_credential.get('api_key'), + 'embedding') + except Exception as e: + raise AppApiException(ValidCode.valid_error.value, "API 域名无效") + exist = provider.get_model_info_by_name(model_list, model_name) + model: LocalEmbedding = provider.get_model(model_type, model_name, model_credential) + if len(exist) == 0: + model.start_down_model_thread() + raise AppApiException(ValidCode.model_not_fount, "模型不存在,请先下载模型") + model.embed_query('你好') + return True + + def encryption_dict(self, model_info: Dict[str, object]): + return model_info + + def build_model(self, model_info: Dict[str, object]): + for key in ['model']: + if key not in model_info: + raise AppApiException(500, f'{key} 字段为必填字段') + return self + + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py new file mode 100644 index 0000000..dc01c79 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/llm.py @@ -0,0 +1,61 @@ +# coding=utf-8 + +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class XinferenceLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.7, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + 
TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=800, + _min=1, + _max=100000, + _step=1, + precision=0) + + +class XinferenceLLMModelCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + try: + model_list = provider.get_base_model_list(model_credential.get('api_base'), model_credential.get('api_key'), model_type) + except Exception as e: + raise AppApiException(ValidCode.valid_error.value, "API 域名无效") + exist = provider.get_model_info_by_name(model_list, model_name) + if len(exist) == 0: + raise AppApiException(ValidCode.valid_error.value, "模型不存在,请先下载模型") + model = provider.get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + return True + + def encryption_dict(self, model_info: Dict[str, object]): + return {**model_info, 'api_key': super().encryption(model_info.get('api_key', ''))} + + def build_model(self, model_info: Dict[str, object]): + for key in ['api_key', 'model']: + if key not in model_info: + raise AppApiException(500, f'{key} 字段为必填字段') + self.api_key = model_info.get('api_key') + return self + + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def get_model_params_setting_form(self, model_name): + return XinferenceLLMModelParams() diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/reranker.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/reranker.py new file mode 100644 index 0000000..87f2797 --- /dev/null +++ 
b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/reranker.py @@ -0,0 +1,47 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: reranker.py + @date:2024/9/10 9:46 + @desc: +""" +from typing import Dict + +from langchain_core.documents import Document + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class XInferenceRerankerModelCredential(BaseForm, BaseModelCredential): + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=True): + if not model_type == 'RERANKER': + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['server_url']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.compress_documents([Document(page_content='你好')], '你好') + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model_info: Dict[str, object]): + return model_info + + server_url = forms.TextInputField('API 域名', required=True) + + api_key = forms.PasswordInputField('API Key', required=False) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/stt.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/stt.py new file mode 100644 index 0000000..7d19fea --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/stt.py @@ -0,0 +1,42 @@ +# coding=utf-8 +from typing import Dict + +from 
common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class XInferenceSTTModelCredential(BaseForm, BaseModelCredential): + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.check_auth() + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + pass diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/tts.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/tts.py new file mode 100644 index 0000000..0bf3daa --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/credential/tts.py @@ -0,0 +1,60 @@ +# coding=utf-8 +from typing import Dict + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from 
setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class XInferenceTTSModelGeneralParams(BaseForm): + # ['中文女', '中文男', '日语男', '粤语女', '英文女', '英文男', '韩语女'] + voice = forms.SingleSelect( + TooltipLabel('音色', ''), + required=True, default_value='中文女', + text_field='value', + value_field='value', + option_list=[ + {'text': '中文女', 'value': '中文女'}, + {'text': '中文男', 'value': '中文男'}, + {'text': '日语男', 'value': '日语男'}, + {'text': '粤语女', 'value': '粤语女'}, + {'text': '英文女', 'value': '英文女'}, + {'text': '英文男', 'value': '英文男'}, + {'text': '韩语女', 'value': '韩语女'}, + ]) + + +class XInferenceTTSModelCredential(BaseForm, BaseModelCredential): + api_base = forms.TextInputField('API 域名', required=True) + api_key = forms.PasswordInputField('API Key', required=True) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + + for key in ['api_base', 'api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.check_auth() + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + def get_model_params_setting_form(self, model_name): + return XInferenceTTSModelGeneralParams() diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py 
# file: apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py
# coding=utf-8
import threading
from typing import Dict, Optional, List, Any

from langchain_community.embeddings import XinferenceEmbeddings
from langchain_core.embeddings import Embeddings

from setting.models_provider.base_model_provider import MaxKBBaseModel


class XinferenceEmbedding(MaxKBBaseModel, Embeddings):
    """Embedding model backed by an Xinference server, reached through its RESTful client."""

    # RESTful client instance created in __init__.
    client: Any
    # URL of the xinference server.
    server_url: Optional[str]
    # UID of the launched model.
    model_uid: Optional[str]

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from a credential dict.

        ``model_name`` is used as the model UID; the server URL and API key are read from the
        ``api_base`` and ``api_key`` credential entries. ``model_type`` and ``model_kwargs``
        are accepted but not used here.
        """
        server_url = model_credential.get('api_base')
        api_key = model_credential.get('api_key')
        return XinferenceEmbedding(model_uid=model_name, server_url=server_url, api_key=api_key)

    def down_model(self):
        """Ask the server to launch this model UID as an embedding model."""
        self.client.launch_model(model_name=self.model_uid, model_type="embedding")

    def start_down_model_thread(self):
        """Run :meth:`down_model` on a daemon thread and return immediately."""
        worker = threading.Thread(target=self.down_model)
        worker.daemon = True
        worker.start()

    def __init__(
            self, server_url: Optional[str] = None, model_uid: Optional[str] = None,
            api_key: Optional[str] = None
    ):
        """Create the RESTful client.

        Raises:
            ImportError: neither ``xinference`` nor ``xinference_client`` can be imported.
            ValueError: ``server_url`` or ``model_uid`` is ``None``.
        """
        # The client class is resolved first, so a missing package is reported before
        # the arguments are validated.
        try:
            from xinference.client import RESTfulClient
        except ImportError:
            try:
                from xinference_client import RESTfulClient
            except ImportError as e:
                raise ImportError(
                    "Could not import RESTfulClient from xinference. Please install it"
                    " with `pip install xinference` or `pip install xinference_client`."
                ) from e

        required = (
            (server_url, "Please provide server URL"),
            (model_uid, "Please provide the model UID"),
        )
        for value, message in required:
            if value is None:
                raise ValueError(message)

        self.server_url = server_url
        self.model_uid = model_uid
        self.api_key = api_key
        self.client = RESTfulClient(server_url, api_key)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding per entry of ``texts``, in the same order.

        The model handle is fetched once; ``create_embedding`` is then called once per text.
        """
        handle = self.client.get_model(self.model_uid)

        raw_vectors = []
        for text in texts:
            raw_vectors.append(handle.create_embedding(text)["data"][0]["embedding"])
        # Conversion to float happens only after every request has completed.
        return [[float(component) for component in vector] for vector in raw_vectors]

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding of a single query string as a list of floats."""
        handle = self.client.get_model(self.model_uid)
        response = handle.create_embedding(text)
        return [float(component) for component in response["data"][0]["embedding"]]


# file: apps/setting/models_provider/impl/xinference_model_provider/model/llm.py
# coding=utf-8

from typing import Dict, Optional, List, Any, Iterator
from urllib.parse import urlparse, ParseResult

from langchain_core.language_models import LanguageModelInput
from langchain_core.messages import BaseMessageChunk
from langchain_core.runnables import RunnableConfig

from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_chat_open_ai import BaseChatOpenAI


def get_base_url(url: str):
    """Return ``url`` reduced to scheme, netloc and path, minus one trailing ``/`` if present.

    Params, query string and fragment are discarded.
    """
    parts = urlparse(url)
    rebuilt = ParseResult(parts.scheme, parts.netloc, parts.path, '', '', '').geturl()
    if rebuilt.endswith("/"):
        return rebuilt[:-1]
    return rebuilt


class XinferenceChatModel(MaxKBBaseModel, BaseChatOpenAI):
    """Chat model that reaches an Xinference server through the ``BaseChatOpenAI`` base class."""

    @staticmethod
    def is_cache_model():
        """Always returns ``False``."""
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a chat model from a credential dict.

        The ``api_base`` credential is normalised with :func:`get_base_url` and ``/v1`` is
        appended unless the result already ends with it. Keyword arguments are filtered through
        ``MaxKBBaseModel.filter_optional_params`` before being forwarded to the constructor.
        """
        base_url = get_base_url(model_credential.get('api_base', ''))
        if not base_url.endswith('/v1'):
            base_url = base_url + '/v1'
        extra = MaxKBBaseModel.filter_optional_params(model_kwargs)
        return XinferenceChatModel(
            model=model_name,
            openai_api_base=base_url,
            openai_api_key=model_credential.get('api_key'),
            **extra
        )


# file: apps/setting/models_provider/impl/xinference_model_provider/model/reranker.py
# coding=utf-8
# @project: MaxKB
# @Author: 虎
# @file: reranker.py
# @date: 2024/9/10 9:45
# @desc:
from typing import Sequence, Optional, Any, Dict

from langchain_core.callbacks import Callbacks
from langchain_core.documents import BaseDocumentCompressor, Document
from xinference_client.client.restful.restful_client import RESTfulRerankModelHandle

from setting.models_provider.base_model_provider import MaxKBBaseModel


class XInferenceReranker(MaxKBBaseModel, BaseDocumentCompressor):
    """Document compressor that reranks documents with a model served by Xinference."""

    # RESTful client instance created in __init__.
    client: Any
    # URL of the xinference server.
    server_url: Optional[str]
    # UID of the launched model.
    model_uid: Optional[str]
    api_key: Optional[str]
    # Passed to the rerank call as its third positional argument.
    top_n: Optional[int] = 3

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build a reranker from a credential dict.

        Reads ``server_url`` and ``api_key`` from the credential and ``top_n`` (default 3)
        from ``model_kwargs``; ``model_name`` is used as the model UID.
        """
        return XInferenceReranker(
            server_url=model_credential.get('server_url'),
            model_uid=model_name,
            api_key=model_credential.get('api_key'),
            top_n=model_kwargs.get('top_n', 3),
        )

    def __init__(
            self, server_url: Optional[str] = None, model_uid: Optional[str] = None, top_n=3,
            api_key: Optional[str] = None
    ):
        """Create the RESTful client.

        Raises:
            ImportError: neither ``xinference`` nor ``xinference_client`` can be imported.
            ValueError: ``server_url`` or ``model_uid`` is ``None``.
        """
        try:
            from xinference.client import RESTfulClient
        except ImportError:
            try:
                from xinference_client import RESTfulClient
            except ImportError as e:
                raise ImportError(
                    "Could not import RESTfulClient from xinference. Please install it"
                    " with `pip install xinference` or `pip install xinference_client`."
                ) from e

        # Base initialisation runs before any attribute is assigned on the instance.
        super().__init__()

        for value, message in ((server_url, "Please provide server URL"),
                               (model_uid, "Please provide the model UID")):
            if value is None:
                raise ValueError(message)

        self.server_url = server_url
        self.model_uid = model_uid
        self.api_key = api_key
        self.client = RESTfulClient(server_url, api_key)
        self.top_n = top_n

    def compress_documents(self, documents: Sequence[Document], query: str, callbacks: Optional[Callbacks] = None) -> \
            Sequence[Document]:
        """Rerank ``documents`` against ``query``.

        Returns an empty list when ``documents`` is ``None`` or empty. Otherwise returns new
        ``Document`` objects built from the rerank response: the text comes from each result's
        ``document.text`` and the only metadata entry is ``relevance_score``.
        """
        if documents is None or len(documents) == 0:
            return []
        handle: RESTfulRerankModelHandle = self.client.get_model(self.model_uid)
        passages = [document.page_content for document in documents]
        response = handle.rerank(passages, query, self.top_n, return_documents=True)

        ranked = []
        for hit in response.get('results', []):
            ranked.append(Document(page_content=hit.get('document', {}).get('text'),
                                   metadata={'relevance_score': hit.get('relevance_score')}))
        return ranked


# file: apps/setting/models_provider/impl/xinference_model_provider/model/stt.py
import asyncio
import io
from typing import Dict

from openai import OpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_stt import BaseSpeechToText


def custom_get_token_ids(text: str):
    """Encode ``text`` with the tokenizer returned by ``TokenizerManage.get_tokenizer()``."""
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class XInferenceSpeechToText(MaxKBBaseModel, BaseSpeechToText):
    """Speech-to-text model that calls an Xinference server through the OpenAI client."""

    # Base URL handed to the OpenAI client.
    api_base: str
    # API key handed to the OpenAI client.
    api_key: str
    # Model name sent with each transcription request.
    model: str

    def __init__(self, **kwargs):
        """Store ``api_key``, ``api_base`` and ``model`` from the keyword arguments."""
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')
        # speech_to_text() reads self.model, so it is stored explicitly here, matching
        # XInferenceTextToSpeech.__init__, rather than relying on the base class to set it.
        self.model = kwargs.get('model')

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from a credential dict.

        ``api_base`` and ``api_key`` are read from ``model_credential``. ``max_tokens`` and
        ``temperature`` are forwarded to the constructor only when present and not ``None``.
        ``model_type`` is accepted but not used here.
        """
        optional_params = {
            key: model_kwargs[key]
            for key in ('max_tokens', 'temperature')
            if model_kwargs.get(key) is not None
        }
        return XInferenceSpeechToText(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def check_auth(self):
        """Issue a model-list request with the stored credentials.

        The response is discarded; any exception raised by the client propagates to the caller.
        """
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        client.models.with_raw_response.list()

    def speech_to_text(self, audio_file):
        """Transcribe ``audio_file`` and return the transcription text.

        ``audio_file`` must provide ``read()``; its whole content is read into memory. The
        request is sent with ``language="zh"``.
        """
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        audio_data = audio_file.read()
        buffer = io.BytesIO(audio_data)
        # The original author flagged this assignment as essential.
        # NOTE(review): presumably the client uses the name to identify the upload format - confirm.
        buffer.name = "file.mp3"
        res = client.audio.transcriptions.create(model=self.model, language="zh", file=buffer)
        return res.text

# file: apps/setting/models_provider/impl/xinference_model_provider/model/tts.py
from typing import Dict

from openai import OpenAI

from common.config.tokenizer_manage_config import TokenizerManage
from setting.models_provider.base_model_provider import MaxKBBaseModel
from setting.models_provider.impl.base_tts import BaseTextToSpeech


def custom_get_token_ids(text: str):
    """Encode ``text`` with the tokenizer returned by ``TokenizerManage.get_tokenizer()``."""
    tokenizer = TokenizerManage.get_tokenizer()
    return tokenizer.encode(text)


class XInferenceTextToSpeech(MaxKBBaseModel, BaseTextToSpeech):
    """Text-to-speech model that calls an Xinference server through the OpenAI client."""

    # Base URL handed to the OpenAI client.
    api_base: str
    # API key handed to the OpenAI client.
    api_key: str
    # Model name sent with each speech request.
    model: str
    # Extra keyword arguments splatted into every speech request (e.g. ``voice``).
    params: dict

    def __init__(self, **kwargs):
        """Store ``api_key``, ``api_base``, ``model`` and ``params`` from the keyword arguments."""
        super().__init__(**kwargs)
        self.api_key = kwargs.get('api_key')
        self.api_base = kwargs.get('api_base')
        self.model = kwargs.get('model')
        # text_to_speech() expands self.params with ``**``; fall back to an empty dict so a
        # missing or None ``params`` does not fail there with a TypeError.
        self.params = kwargs.get('params') or {}

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an instance from a credential dict.

        Request parameters start from a default ``voice``. Every entry of ``model_kwargs``
        except ``model_id``, ``use_local`` and ``streaming`` is then copied in, so a
        caller-supplied ``voice`` overrides the default. ``model_type`` is accepted but not
        used here.
        """
        optional_params = {'params': {'voice': '中文女'}}
        for key, value in model_kwargs.items():
            if key not in ['model_id', 'use_local', 'streaming']:
                optional_params['params'][key] = value
        return XInferenceTextToSpeech(
            model=model_name,
            api_base=model_credential.get('api_base'),
            api_key=model_credential.get('api_key'),
            **optional_params,
        )

    def check_auth(self):
        """Issue a model-list request with the stored credentials.

        The response is discarded; any exception raised by the client propagates to the caller.
        """
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        client.models.with_raw_response.list()

    def text_to_speech(self, text):
        """Synthesize ``text`` and return the value of ``response.read()``."""
        client = OpenAI(
            base_url=self.api_base,
            api_key=self.api_key
        )
        # Voice identifiers noted by the original author:
        # ['中文女', '中文男', '日语男', '粤语女', '英文女', '英文男', '韩语女']

        with client.audio.speech.with_streaming_response.create(
                model=self.model,
                input=text,
                **self.params
        ) as response:
            return response.read()

    @staticmethod
    def is_cache_model():
        """Always returns ``False``.

        Declared as a static method, like the other ``is_cache_model`` overrides in this code
        base, so it can be called on the class as well as on an instance.
        """
        return False
b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py new file mode 100644 index 0000000..0da07f6 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/xinference_model_provider/xinference_model_provider.py @@ -0,0 +1,417 @@ +# coding=utf-8 +import os +from urllib.parse import urlparse, ParseResult + +import requests + +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, \ + ModelInfoManage +from setting.models_provider.impl.xinference_model_provider.credential.embedding import \ + XinferenceEmbeddingModelCredential +from setting.models_provider.impl.xinference_model_provider.credential.llm import XinferenceLLMModelCredential +from setting.models_provider.impl.xinference_model_provider.credential.reranker import XInferenceRerankerModelCredential +from setting.models_provider.impl.xinference_model_provider.credential.stt import XInferenceSTTModelCredential +from setting.models_provider.impl.xinference_model_provider.credential.tts import XInferenceTTSModelCredential +from setting.models_provider.impl.xinference_model_provider.model.embedding import XinferenceEmbedding +from setting.models_provider.impl.xinference_model_provider.model.llm import XinferenceChatModel +from setting.models_provider.impl.xinference_model_provider.model.reranker import XInferenceReranker +from setting.models_provider.impl.xinference_model_provider.model.stt import XInferenceSpeechToText +from setting.models_provider.impl.xinference_model_provider.model.tts import XInferenceTextToSpeech +from smartdoc.conf import PROJECT_DIR + +xinference_llm_model_credential = XinferenceLLMModelCredential() +xinference_stt_model_credential = XInferenceSTTModelCredential() +xinference_tts_model_credential = XInferenceTTSModelCredential() + +model_info_list = [ + ModelInfo( + 'code-llama', + 'Code Llama 是一个专门用于代码生成的语言模型。', + 
ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'code-llama-instruct', + 'Code Llama Instruct 是 Code Llama 的指令微调版本,专为执行特定任务而设计。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'code-llama-python', + 'Code Llama Python 是一个专门用于 Python 代码生成的语言模型。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'codeqwen1.5', + 'CodeQwen 1.5 是一个用于代码生成的语言模型,具有较高的性能。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'codeqwen1.5-chat', + 'CodeQwen 1.5 Chat 是一个聊天模型版本的 CodeQwen 1.5。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'deepseek', + 'Deepseek 是一个大规模语言模型,具有 130 亿参数。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'deepseek-chat', + 'Deepseek Chat 是一个聊天模型版本的 Deepseek。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'deepseek-coder', + 'Deepseek Coder 是一个专为代码生成设计的模型。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'deepseek-coder-instruct', + 'Deepseek Coder Instruct 是 Deepseek Coder 的指令微调版本,专为执行特定任务而设计。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'deepseek-vl-chat', + 'Deepseek VL Chat 是 Deepseek 的视觉语言聊天模型版本,能够处理图像和文本输入。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'gpt-3.5-turbo', + 'GPT-3.5 Turbo 是一个高效能的通用语言模型,适用于多种应用场景。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'gpt-4', + 'GPT-4 是一个强大的多模态模型,不仅支持文本输入,还支持图像输入。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'gpt-4-vision-preview', + 'GPT-4 Vision Preview 是 GPT-4 的视觉预览版本,支持图像输入。', + 
ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'gpt4all', + 'GPT4All 是一个开源的多模态模型,支持文本和图像输入。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'llama2', + 'Llama2 是一个具有 700 亿参数的大规模语言模型,支持多种语言。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'llama2-chat', + 'Llama2 Chat 是一个聊天模型版本的 Llama2,支持多种语言。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'llama2-chat-32k', + 'Llama2 Chat 32K 是一个聊天模型版本的 Llama2,支持长达 32K 令牌的上下文。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen', + 'Qwen 是一个大规模语言模型,具有 130 亿参数。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen-chat', + 'Qwen Chat 是一个聊天模型版本的 Qwen。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen-chat-32k', + 'Qwen Chat 32K 是一个聊天模型版本的 Qwen,支持长达 32K 令牌的上下文。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen-code', + 'Qwen Code 是一个专门用于代码生成的语言模型。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen-code-chat', + 'Qwen Code Chat 是一个聊天模型版本的 Qwen Code。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen-vl', + 'Qwen VL 是 Qwen 的视觉语言模型版本,能够处理图像和文本输入。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen-vl-chat', + 'Qwen VL Chat 是 Qwen VL 的聊天模型版本,能够处理图像和文本输入。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen2-instruct', + 'Qwen2 Instruct 是 Qwen2 的指令微调版本,专为执行特定任务而设计。', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen2-72b-instruct', + '', + 
ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen2-57b-a14b-instruct', + '', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen2-7b-instruct', + '', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen2.5-72b-instruct', + '', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen2.5-32b-instruct', + '', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen2.5-14b-instruct', + '', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen2.5-7b-instruct', + '', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen2.5-1.5b-instruct', + '', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen2.5-0.5b-instruct', + '', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'qwen2.5-3b-instruct', + '', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), + ModelInfo( + 'minicpm-llama3-v-2_5', + 'MiniCPM-Llama3-V 2.5是MiniCPM-V系列中的最新型号,该模型基于SigLip-400M和Llama3-8B-Instruct构建,共有8B个参数', + ModelTypeConst.LLM, + xinference_llm_model_credential, + XinferenceChatModel + ), +] + +voice_model_info = [ + ModelInfo( + 'CosyVoice-300M-SFT', + 'CosyVoice-300M-SFT是一个小型的语音合成模型。', + ModelTypeConst.TTS, + xinference_tts_model_credential, + XInferenceTextToSpeech + ), + ModelInfo( + 'Belle-whisper-large-v3-zh', + 'Belle Whisper Large V3 是一个中文大型语音识别模型。', + ModelTypeConst.STT, + xinference_stt_model_credential, + XInferenceSpeechToText + ), +] + +xinference_embedding_model_credential = XinferenceEmbeddingModelCredential() + +# 生成embedding_model_info列表 +embedding_model_info = [ + ModelInfo('bce-embedding-base_v1', 'BCE 
嵌入模型的基础版本。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-base-en', 'BGE 英语基础版本的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-base-en-v1.5', 'BGE 英语基础版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-base-zh', 'BGE 中文基础版本的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-base-zh-v1.5', 'BGE 中文基础版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-large-en', 'BGE 英语大型版本的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-large-en-v1.5', 'BGE 英语大型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-large-zh', 'BGE 中文大型版本的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-large-zh-noinstruct', 'BGE 中文大型版本的嵌入模型,无指令调整。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-large-zh-v1.5', 'BGE 中文大型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-m3', 'BGE M3 版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, + XinferenceEmbedding), + ModelInfo('bge-small-en-v1.5', 'BGE 英语小型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-small-zh', 'BGE 中文小型版本的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('bge-small-zh-v1.5', 'BGE 中文小型版本 1.5 的嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('e5-large-v2', 'E5 大型版本 2 的嵌入模型。', 
ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('gte-base', 'GTE 基础版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, + XinferenceEmbedding), + ModelInfo('gte-large', 'GTE 大型版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, + XinferenceEmbedding), + ModelInfo('jina-embeddings-v2-base-en', 'Jina 嵌入模型的英语基础版本 2。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('jina-embeddings-v2-base-zh', 'Jina 嵌入模型的中文基础版本 2。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('jina-embeddings-v2-small-en', 'Jina 嵌入模型的英语小型版本 2。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('m3e-base', 'M3E 基础版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, + XinferenceEmbedding), + ModelInfo('m3e-large', 'M3E 大型版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, + XinferenceEmbedding), + ModelInfo('m3e-small', 'M3E 小型版本的嵌入模型。', ModelTypeConst.EMBEDDING, xinference_embedding_model_credential, + XinferenceEmbedding), + ModelInfo('multilingual-e5-large', '多语言大型版本的 E5 嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('text2vec-base-chinese', 'Text2Vec 的中文基础版本嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('text2vec-base-chinese-paraphrase', 'Text2Vec 的中文基础版本的同义句嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('text2vec-base-chinese-sentence', 'Text2Vec 的中文基础版本的句子嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + ModelInfo('text2vec-base-multilingual', 'Text2Vec 的多语言基础版本嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), + 
ModelInfo('text2vec-large-chinese', 'Text2Vec 的中文大型版本嵌入模型。', ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding), +] +rerank_list = [ModelInfo('bce-reranker-base_v1', + '发布新的重新排名器,建立在强大的 M3 和LLM (GEMMA 和 MiniCPM,实际上没那么大)骨干上,支持多语言处理和更大的输入,大幅提高 BEIR、C-MTEB/Retrieval 的排名性能、MIRACL、LlamaIndex 评估', + ModelTypeConst.RERANKER, XInferenceRerankerModelCredential(), XInferenceReranker)] +model_info_manage = (ModelInfoManage.builder() + .append_model_info_list(model_info_list) + .append_model_info_list(voice_model_info) + .append_default_model_info(voice_model_info[0]) + .append_default_model_info(voice_model_info[1]) + .append_default_model_info(ModelInfo('phi3', + 'Phi-3 Mini是Microsoft的3.8B参数,轻量级,最先进的开放模型。', + ModelTypeConst.LLM, xinference_llm_model_credential, + XinferenceChatModel)) + .append_model_info_list(embedding_model_info) + .append_default_model_info(ModelInfo('', + '', + ModelTypeConst.EMBEDDING, + xinference_embedding_model_credential, XinferenceEmbedding)) + .append_model_info_list(rerank_list) + .append_default_model_info(rerank_list[0]) + .build()) + + +def get_base_url(url: str): + parse = urlparse(url) + result_url = ParseResult(scheme=parse.scheme, netloc=parse.netloc, path=parse.path, params='', + query='', + fragment='').geturl() + return result_url[:-1] if result_url.endswith("/") else result_url + + +class XinferenceModelProvider(IModelProvider): + def get_model_info_manage(self): + return model_info_manage + + def get_model_provide_info(self): + return ModelProvideInfo(provider='model_xinference_provider', name='Xorbits Inference', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'xinference_model_provider', 'icon', + 'xinference_icon_svg'))) + + @staticmethod + def get_base_model_list(api_base, api_key, model_type): + base_url = get_base_url(api_base) + base_url = base_url if base_url.endswith('/v1') else (base_url + '/v1') + headers = {} + if api_key: + 
headers['Authorization'] = f"Bearer {api_key}" + r = requests.request(method="GET", url=f"{base_url}/models", headers=headers, timeout=5) + r.raise_for_status() + model_list = r.json().get('data') + return [model for model in model_list if model.get('model_type') == model_type] + + @staticmethod + def get_model_info_by_name(model_list, model_name): + if model_list is None: + return [] + return [model for model in model_list if model.get('model_name') == model_name or model.get('id') == model_name] diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/zhipu_model_provider/__init__.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/zhipu_model_provider/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py new file mode 100644 index 0000000..48c1194 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/zhipu_model_provider/credential/llm.py @@ -0,0 +1,67 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: llm.py + @date:2024/7/12 10:46 + @desc: +""" +from typing import Dict + +from langchain_core.messages import HumanMessage + +from common import forms +from common.exception.app_exception import AppApiException +from common.forms import BaseForm, TooltipLabel +from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode + + +class ZhiPuLLMModelParams(BaseForm): + temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'), + required=True, default_value=0.95, + _min=0.1, + _max=1.0, + _step=0.01, + precision=2) + + max_tokens = forms.SliderField( + TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'), + required=True, default_value=1024, + _min=1, + _max=100000, + _step=1, + precision=0) + + +class ZhiPuLLMModelCredential(BaseForm, BaseModelCredential): + + def is_valid(self, model_type: str, 
model_name, model_credential: Dict[str, object], provider, + raise_exception=False): + model_type_list = provider.get_model_type_list() + if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))): + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + for key in ['api_key']: + if key not in model_credential: + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') + else: + return False + try: + model = provider.get_model(model_type, model_name, model_credential) + model.invoke([HumanMessage(content='你好')]) + except Exception as e: + if isinstance(e, AppApiException): + raise e + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') + else: + return False + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'api_key': super().encryption(model.get('api_key', ''))} + + api_key = forms.PasswordInputField('API Key', required=True) + + def get_model_params_setting_form(self, model_name): + return ZhiPuLLMModelParams() diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py new file mode 100644 index 0000000..0369932 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/zhipu_model_provider/model/llm.py @@ -0,0 +1,107 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: llm.py + @date:2024/4/28 11:42 + @desc: +""" + +import json +from collections.abc import Iterator +from typing import Any, Dict, List, Optional + +from langchain_community.chat_models import ChatZhipuAI +from langchain_community.chat_models.zhipuai import _truncate_params, _get_jwt_token, connect_sse, \ + _convert_delta_to_message_chunk +from langchain_core.callbacks import ( + CallbackManagerForLLMRun, +) +from langchain_core.messages import ( + AIMessageChunk, + BaseMessage +) +from 
langchain_core.outputs import ChatGenerationChunk + +from setting.models_provider.base_model_provider import MaxKBBaseModel + + +class ZhipuChatModel(MaxKBBaseModel, ChatZhipuAI): + optional_params: dict + + @staticmethod + def is_cache_model(): + return False + + @staticmethod + def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs): + optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs) + zhipuai_chat = ZhipuChatModel( + api_key=model_credential.get('api_key'), + model=model_name, + streaming=model_kwargs.get('streaming', False), + optional_params=optional_params, + **optional_params, + ) + return zhipuai_chat + + usage_metadata: dict = {} + + def get_last_generation_info(self) -> Optional[Dict[str, Any]]: + return self.usage_metadata + + def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int: + return self.usage_metadata.get('prompt_tokens', 0) + + def get_num_tokens(self, text: str) -> int: + return self.usage_metadata.get('completion_tokens', 0) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + """Stream the chat response in chunks.""" + if self.zhipuai_api_key is None: + raise ValueError("Did not find zhipuai_api_key.") + if self.zhipuai_api_base is None: + raise ValueError("Did not find zhipu_api_base.") + message_dicts, params = self._create_message_dicts(messages, stop) + payload = {**params, **kwargs, **self.optional_params, "messages": message_dicts, "stream": True} + _truncate_params(payload) + headers = { + "Authorization": _get_jwt_token(self.zhipuai_api_key), + "Accept": "application/json", + } + + default_chunk_class = AIMessageChunk + import httpx + + with httpx.Client(headers=headers, timeout=60) as client: + with connect_sse( + client, "POST", self.zhipuai_api_base, json=payload + ) as event_source: + for sse in 
event_source.iter_sse(): + chunk = json.loads(sse.data) + if len(chunk["choices"]) == 0: + continue + choice = chunk["choices"][0] + generation_info = {} + if "usage" in chunk: + generation_info = chunk["usage"] + self.usage_metadata = generation_info + chunk = _convert_delta_to_message_chunk( + choice["delta"], default_chunk_class + ) + finish_reason = choice.get("finish_reason", None) + + chunk = ChatGenerationChunk( + message=chunk, generation_info=generation_info + ) + yield chunk + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) + if finish_reason is not None: + break diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py new file mode 100644 index 0000000..ab19b15 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/impl/zhipu_model_provider/zhipu_model_provider.py @@ -0,0 +1,36 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: zhipu_model_provider.py + @date:2024/04/19 13:5 + @desc: +""" +import os + +from common.util.file_util import get_file_content +from setting.models_provider.base_model_provider import ModelProvideInfo, ModelTypeConst, ModelInfo, IModelProvider, \ + ModelInfoManage +from setting.models_provider.impl.zhipu_model_provider.credential.llm import ZhiPuLLMModelCredential +from setting.models_provider.impl.zhipu_model_provider.model.llm import ZhipuChatModel +from smartdoc.conf import PROJECT_DIR + +qwen_model_credential = ZhiPuLLMModelCredential() +model_info_list = [ + ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel), + ModelInfo('glm-4v', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel), + ModelInfo('glm-3-turbo', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel) +] +model_info_manage = ModelInfoManage.builder().append_model_info_list(model_info_list).append_default_model_info( + 
ModelInfo('glm-4', '', ModelTypeConst.LLM, qwen_model_credential, ZhipuChatModel)).build() + + +class ZhiPuModelProvider(IModelProvider): + + def get_model_info_manage(self): + return model_info_manage + + def get_model_provide_info(self): + return ModelProvideInfo(provider='model_zhipu_provider', name='智谱AI', icon=get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl', 'zhipu_model_provider', 'icon', + 'zhipuai_icon_svg'))) diff --git a/src/MaxKB-1.7.2/apps/setting/models_provider/tools.py b/src/MaxKB-1.7.2/apps/setting/models_provider/tools.py new file mode 100644 index 0000000..6606043 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/models_provider/tools.py @@ -0,0 +1,33 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: tools.py + @date:2024/7/22 11:18 + @desc: +""" +from django.db.models import QuerySet + +from common.config.embedding_config import ModelManage +from setting.models import Model +from setting.models_provider import get_model + + +def get_model_by_id(_id, user_id): + model = QuerySet(Model).filter(id=_id).first() + if model is None: + raise Exception("模型不存在") + if model.permission_type == 'PRIVATE' and str(model.user_id) != str(user_id): + raise Exception(f"无权限使用此模型:{model.name}") + return model + + +def get_model_instance_by_model_user_id(model_id, user_id, **kwargs): + """ + 获取模型实例,根据模型相关数据 + @param model_id: 模型id + @param user_id: 用户id + @return: 模型实例 + """ + model = get_model_by_id(model_id, user_id) + return ModelManage.get_model(model_id, lambda _id: get_model(model, **kwargs)) diff --git a/src/MaxKB-1.7.2/apps/setting/serializers/model_apply_serializers.py b/src/MaxKB-1.7.2/apps/setting/serializers/model_apply_serializers.py new file mode 100644 index 0000000..fd41869 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/serializers/model_apply_serializers.py @@ -0,0 +1,73 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: model_apply_serializers.py + @date:2024/8/20 20:39 + @desc: 
+""" +from django.db.models import QuerySet +from langchain_core.documents import Document +from rest_framework import serializers + +from common.config.embedding_config import ModelManage +from common.util.field_message import ErrMessage +from setting.models import Model +from setting.models_provider import get_model + + +def get_embedding_model(model_id): + model = QuerySet(Model).filter(id=model_id).first() + embedding_model = ModelManage.get_model(model_id, + lambda _id: get_model(model, use_local=True)) + return embedding_model + + +class EmbedDocuments(serializers.Serializer): + texts = serializers.ListField(required=True, child=serializers.CharField(required=True, + error_messages=ErrMessage.char( + "向量文本")), + error_messages=ErrMessage.list("向量文本列表")) + + +class EmbedQuery(serializers.Serializer): + text = serializers.CharField(required=True, error_messages=ErrMessage.char("向量文本")) + + +class CompressDocument(serializers.Serializer): + page_content = serializers.CharField(required=True, error_messages=ErrMessage.char("文本")) + metadata = serializers.DictField(required=False, error_messages=ErrMessage.dict("元数据")) + + +class CompressDocuments(serializers.Serializer): + documents = CompressDocument(required=True, many=True) + query = serializers.CharField(required=True, error_messages=ErrMessage.char("查询query")) + + +class ModelApplySerializers(serializers.Serializer): + model_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) + + def embed_documents(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + EmbedDocuments(data=instance).is_valid(raise_exception=True) + + model = get_embedding_model(self.data.get('model_id')) + return model.embed_documents(instance.getlist('texts')) + + def embed_query(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + EmbedQuery(data=instance).is_valid(raise_exception=True) + + model = 
get_embedding_model(self.data.get('model_id')) + return model.embed_query(instance.get('text')) + + def compress_documents(self, instance, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + CompressDocuments(data=instance).is_valid(raise_exception=True) + model = get_embedding_model(self.data.get('model_id')) + return [{'page_content': d.page_content, 'metadata': d.metadata} for d in model.compress_documents( + [Document(page_content=document.get('page_content'), metadata=document.get('metadata')) for document in + instance.get('documents')], instance.get('query'))] diff --git a/src/MaxKB-1.7.2/apps/setting/serializers/provider_serializers.py b/src/MaxKB-1.7.2/apps/setting/serializers/provider_serializers.py new file mode 100644 index 0000000..e76e67d --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/serializers/provider_serializers.py @@ -0,0 +1,392 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: provider_serializers.py + @date:2023/11/2 14:01 + @desc: +""" +import json +import re +import threading +import time +import uuid +from typing import Dict + +from django.core import validators +from django.db.models import QuerySet, Q +from rest_framework import serializers + +from application.models import Application +from common.config.embedding_config import ModelManage +from common.exception.app_exception import AppApiException +from common.util.field_message import ErrMessage +from common.util.rsa_util import rsa_long_decrypt, rsa_long_encrypt +from dataset.models import DataSet +from setting.models.model_management import Model, Status, PermissionType +from setting.models_provider import get_model, get_model_credential +from setting.models_provider.base_model_provider import ValidCode, DownModelChunkStatus +from setting.models_provider.constants.model_provider_constants import ModelProvideConstants + + +class ModelPullManage: + + @staticmethod + def pull(model: Model, credential: Dict): + try: + response = 
ModelProvideConstants[model.provider].value.down_model(model.model_type, model.model_name, + credential) + down_model_chunk = {} + timestamp = time.time() + for chunk in response: + down_model_chunk[chunk.digest] = chunk.to_dict() + if time.time() - timestamp > 5: + model_new = QuerySet(Model).filter(id=model.id).first() + if model_new.status == Status.PAUSE_DOWNLOAD: + return + QuerySet(Model).filter(id=model.id).update( + meta={"down_model_chunk": list(down_model_chunk.values())}) + timestamp = time.time() + status = Status.ERROR + message = "" + down_model_chunk_list = list(down_model_chunk.values()) + for chunk in down_model_chunk_list: + if chunk.get('status') == DownModelChunkStatus.success.value: + status = Status.SUCCESS + if chunk.get('status') == DownModelChunkStatus.error.value: + message = chunk.get("digest") + QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": [], "message": message}, + status=status) + except Exception as e: + QuerySet(Model).filter(id=model.id).update(meta={"down_model_chunk": [], "message": str(e)}, + status=Status.ERROR) + + +class ModelSerializer(serializers.Serializer): + class Query(serializers.Serializer): + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + + name = serializers.CharField(required=False, max_length=64, + error_messages=ErrMessage.char("模型名称")) + + model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型")) + + model_name = serializers.CharField(required=False, error_messages=ErrMessage.char("基础模型")) + + provider = serializers.CharField(required=False, error_messages=ErrMessage.char("供应商")) + + permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char("权限")) + + create_user = serializers.CharField(required=False, error_messages=ErrMessage.char("创建者")) + + + def list(self, with_valid): + if with_valid: + self.is_valid(raise_exception=True) + user_id = self.data.get('user_id') + name = 
self.data.get('name') + create_user = self.data.get('create_user') + if create_user is not None: + # 当前用户能查看自己的模型,包括公开和私有的 + if create_user == user_id: + model_query_set = QuerySet(Model).filter(Q(user_id=create_user)) + # 当前用户能查看其他人的模型,只能查看公开的 + else: + model_query_set = QuerySet(Model).filter((Q(user_id=self.data.get('create_user')) & Q(permission_type='PUBLIC'))) + else: + model_query_set = QuerySet(Model).filter((Q(user_id=user_id) | Q(permission_type='PUBLIC'))) + query_params = {} + if name is not None: + query_params['name__contains'] = name + if self.data.get('model_type') is not None: + query_params['model_type'] = self.data.get('model_type') + if self.data.get('model_name') is not None: + query_params['model_name'] = self.data.get('model_name') + if self.data.get('provider') is not None: + query_params['provider'] = self.data.get('provider') + if self.data.get('permission_type') is not None: + query_params['permission_type'] = self.data.get('permission_type') + + + return [ + {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, + 'model_name': model.model_name, 'status': model.status, 'meta': model.meta, + 'permission_type': model.permission_type, 'user_id': model.user_id, 'username': model.user.username} for model in + model_query_set.filter(**query_params).order_by("-create_time")] + + class Edit(serializers.Serializer): + user_id = serializers.CharField(required=False, error_messages=ErrMessage.uuid("用户id")) + + name = serializers.CharField(required=False, max_length=64, + error_messages=ErrMessage.char("模型名称")) + + model_type = serializers.CharField(required=False, error_messages=ErrMessage.char("模型类型")) + + permission_type = serializers.CharField(required=False, error_messages=ErrMessage.char("权限"), validators=[ + validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"), + message="权限只支持PUBLIC|PRIVATE", code=500) + ]) + + model_name = serializers.CharField(required=False, 
error_messages=ErrMessage.char("模型类型")) + + credential = serializers.DictField(required=False, error_messages=ErrMessage.dict("认证信息")) + + def is_valid(self, model=None, raise_exception=False): + super().is_valid(raise_exception=True) + filter_params = {'user_id': self.data.get('user_id')} + if 'name' in self.data and self.data.get('name') is not None: + filter_params['name'] = self.data.get('name') + if QuerySet(Model).exclude(id=model.id).filter(**filter_params).exists(): + raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在') + + ModelSerializer.model_to_dict(model) + + provider = model.provider + model_type = self.data.get('model_type') + model_name = self.data.get( + 'model_name') + credential = self.data.get('credential') + provider_handler = ModelProvideConstants[provider].value + model_credential = ModelProvideConstants[provider].value.get_model_credential(model_type, + model_name) + source_model_credential = json.loads(rsa_long_decrypt(model.credential)) + source_encryption_model_credential = model_credential.encryption_dict(source_model_credential) + if credential is not None: + for k in source_encryption_model_credential.keys(): + if credential[k] == source_encryption_model_credential[k]: + credential[k] = source_model_credential[k] + return credential, model_credential, provider_handler + + class Create(serializers.Serializer): + user_id = serializers.CharField(required=True, error_messages=ErrMessage.uuid("用户id")) + + name = serializers.CharField(required=True, max_length=64, error_messages=ErrMessage.char("模型名称")) + + provider = serializers.CharField(required=True, error_messages=ErrMessage.char("供应商")) + + model_type = serializers.CharField(required=True, error_messages=ErrMessage.char("模型类型")) + + permission_type = serializers.CharField(required=True, error_messages=ErrMessage.char("权限"), validators=[ + validators.RegexValidator(regex=re.compile("^PUBLIC|PRIVATE$"), + message="权限只支持PUBLIC|PRIVATE", code=500) + ]) + + model_name = 
serializers.CharField(required=True, error_messages=ErrMessage.char("基础模型")) + + credential = serializers.DictField(required=True, error_messages=ErrMessage.dict("认证信息")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + if QuerySet(Model).filter(user_id=self.data.get('user_id'), + name=self.data.get('name')).exists(): + raise AppApiException(500, f'模型名称【{self.data.get("name")}】已存在') + ModelProvideConstants[self.data.get('provider')].value.is_valid_credential(self.data.get('model_type'), + self.data.get('model_name'), + self.data.get('credential'), + raise_exception=True + ) + + def insert(self, user_id, with_valid=False): + status = Status.SUCCESS + if with_valid: + try: + self.is_valid(raise_exception=True) + except AppApiException as e: + if e.code == ValidCode.model_not_fount: + status = Status.DOWNLOAD + else: + raise e + credential = self.data.get('credential') + name = self.data.get('name') + provider = self.data.get('provider') + model_type = self.data.get('model_type') + model_name = self.data.get('model_name') + permission_type = self.data.get('permission_type') + model_credential_str = json.dumps(credential) + model = Model(id=uuid.uuid1(), status=status, user_id=user_id, name=name, + credential=rsa_long_encrypt(model_credential_str), + provider=provider, model_type=model_type, model_name=model_name, + permission_type=permission_type) + model.save() + if status == Status.DOWNLOAD: + thread = threading.Thread(target=ModelPullManage.pull, args=(model, credential)) + thread.start() + return ModelSerializer.Operate(data={'id': model.id, 'user_id': user_id}).one(with_valid=True) + + @staticmethod + def model_to_dict(model: Model): + credential = json.loads(rsa_long_decrypt(model.credential)) + return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, + 'model_name': model.model_name, + 'status': model.status, + 'meta': model.meta, + 'credential': 
ModelProvideConstants[model.provider].value.get_model_credential(model.model_type, + model.model_name).encryption_dict( + credential), + 'permission_type': model.permission_type} + + class ModelParams(serializers.Serializer): + id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) + + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + model = QuerySet(Model).filter(id=self.data.get("id")).first() + if model is None: + raise AppApiException(500, '模型不存在') + if model.permission_type == PermissionType.PRIVATE and self.data.get('user_id') != str(model.user_id): + raise AppApiException(500, '没有权限访问到此模型') + + def get_model_params(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + model_id = self.data.get('id') + model = QuerySet(Model).filter(id=model_id).first() + credential = get_model_credential(model.provider, model.model_type, model.model_name) + # 已经保存过的模型参数表单 + if model.model_params_form is not None and len(model.model_params_form) > 0: + return model.model_params_form + # 没有保存过的LLM类型的 + if credential.get_model_params_setting_form(model.model_name) is not None: + return credential.get_model_params_setting_form(model.model_name).to_form_list() + # 其他的 + return model.model_params_form + + class ModelParamsForm(serializers.Serializer): + id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) + + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + model = QuerySet(Model).filter(id=self.data.get("id")).first() + if model is None: + raise AppApiException(500, '模型不存在') + if model.permission_type == PermissionType.PRIVATE and self.data.get('user_id') != str(model.user_id): + raise AppApiException(500, '没有权限访问到此模型') + + def 
save_model_params_form(self, model_params_form, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + if model_params_form is None: + model_params_form = [] + model_id = self.data.get('id') + model = QuerySet(Model).filter(id=model_id).first() + model.model_params_form = model_params_form + model.save() + return True + + class Operate(serializers.Serializer): + id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("模型id")) + + user_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("用户id")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + model = QuerySet(Model).filter(id=self.data.get("id"), user_id=self.data.get("user_id")).first() + if model is None: + raise AppApiException(500, '模型不存在') + + def one(self, with_valid=False): + if with_valid: + self.is_valid(raise_exception=True) + model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id')) + return ModelSerializer.model_to_dict(model) + + def one_meta(self, with_valid=False): + if with_valid: + self.is_valid(raise_exception=True) + model = QuerySet(Model).get(id=self.data.get('id'), user_id=self.data.get('user_id')) + return {'id': str(model.id), 'provider': model.provider, 'name': model.name, 'model_type': model.model_type, + 'model_name': model.model_name, + 'status': model.status, + 'meta': model.meta + } + + def delete(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + model_id = self.data.get('id') + model = Model.objects.filter(id=model_id).first() + if not model: + # 模型不存在,直接返回或抛出异常 + raise AppApiException(500, "模型不存在") + if model.model_type == 'LLM': + application_count = Application.objects.filter(model_id=model_id).count() + if application_count > 0: + raise AppApiException(500, f"该模型关联了{application_count} 个应用,无法删除该模型。") + elif model.model_type == 'EMBEDDING': + dataset_count = DataSet.objects.filter(embedding_mode_id=model_id).count() + if 
dataset_count > 0: + raise AppApiException(500, f"该模型关联了{dataset_count} 个知识库,无法删除该模型。") + elif model.model_type == 'TTS': + dataset_count = Application.objects.filter(tts_model_id=model_id).count() + if dataset_count > 0: + raise AppApiException(500, f"该模型关联了{dataset_count} 个应用,无法删除该模型。") + elif model.model_type == 'STT': + dataset_count = Application.objects.filter(stt_model_id=model_id).count() + if dataset_count > 0: + raise AppApiException(500, f"该模型关联了{dataset_count} 个应用,无法删除该模型。") + model.delete() + return True + + def pause_download(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + QuerySet(Model).filter(id=self.data.get('id')).update(status=Status.PAUSE_DOWNLOAD) + return True + + def edit(self, instance: Dict, user_id: str, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + model = QuerySet(Model).filter(id=self.data.get('id')).first() + + if model is None: + raise AppApiException(500, '不存在的id') + else: + credential, model_credential, provider_handler = ModelSerializer.Edit( + data={**instance, 'user_id': user_id}).is_valid( + model=model) + try: + model.status = Status.SUCCESS + # 校验模型认证数据 + provider_handler.is_valid_credential(model.model_type, + instance.get("model_name"), + credential, + raise_exception=True) + + except AppApiException as e: + if e.code == ValidCode.model_not_fount: + model.status = Status.DOWNLOAD + else: + raise e + update_keys = ['credential', 'name', 'model_type', 'model_name', 'permission_type'] + for update_key in update_keys: + if update_key in instance and instance.get(update_key) is not None: + if update_key == 'credential': + model_credential_str = json.dumps(credential) + model.__setattr__(update_key, rsa_long_encrypt(model_credential_str)) + else: + model.__setattr__(update_key, instance.get(update_key)) + # 修改模型时候删除缓存 + ModelManage.delete_key(str(model.id)) + model.save() + if model.status == Status.DOWNLOAD: + thread = threading.Thread(target=ModelPullManage.pull, 
args=(model, credential)) + thread.start() + return self.one(with_valid=False) + + +class ProviderSerializer(serializers.Serializer): + provider = serializers.CharField(required=True, error_messages=ErrMessage.char("供应商")) + + method = serializers.CharField(required=True, error_messages=ErrMessage.char("执行函数名称")) + + def exec(self, exec_params: Dict[str, object], with_valid=False): + if with_valid: + self.is_valid(raise_exception=True) + + provider = self.data.get('provider') + method = self.data.get('method') + return getattr(ModelProvideConstants[provider].value, method)(exec_params) diff --git a/src/MaxKB-1.7.2/apps/setting/serializers/system_setting.py b/src/MaxKB-1.7.2/apps/setting/serializers/system_setting.py new file mode 100644 index 0000000..a66b158 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/serializers/system_setting.py @@ -0,0 +1,67 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: system_setting.py + @date:2024/3/19 16:29 + @desc: +""" +from django.core.mail.backends.smtp import EmailBackend +from django.db.models import QuerySet +from rest_framework import serializers + +from common.exception.app_exception import AppApiException +from common.util.field_message import ErrMessage +from setting.models.system_management import SystemSetting, SettingType + + +class SystemSettingSerializer(serializers.Serializer): + class EmailSerializer(serializers.Serializer): + @staticmethod + def one(): + system_setting = QuerySet(SystemSetting).filter(type=SettingType.EMAIL.value).first() + if system_setting is None: + return {} + return system_setting.meta + + class Create(serializers.Serializer): + email_host = serializers.CharField(required=True, error_messages=ErrMessage.char("SMTP 主机")) + email_port = serializers.IntegerField(required=True, error_messages=ErrMessage.char("SMTP 端口")) + email_host_user = serializers.CharField(required=True, error_messages=ErrMessage.char("发件人邮箱")) + email_host_password = serializers.CharField(required=True, 
error_messages=ErrMessage.char("密码")) + email_use_tls = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否开启TLS")) + email_use_ssl = serializers.BooleanField(required=True, error_messages=ErrMessage.char("是否开启SSL")) + from_email = serializers.EmailField(required=True, error_messages=ErrMessage.char("发送人邮箱")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + try: + EmailBackend(self.data.get("email_host"), + self.data.get("email_port"), + self.data.get("email_host_user"), + self.data.get("email_host_password"), + self.data.get("email_use_tls"), + False, + self.data.get("email_use_ssl") + ).open() + except Exception as e: + raise AppApiException(1004, "邮箱校验失败") + + def update_or_save(self): + self.is_valid(raise_exception=True) + system_setting = QuerySet(SystemSetting).filter(type=SettingType.EMAIL.value).first() + if system_setting is None: + system_setting = SystemSetting(type=SettingType.EMAIL.value) + system_setting.meta = self.to_email_meta() + system_setting.save() + return system_setting.meta + + def to_email_meta(self): + return {'email_host': self.data.get('email_host'), + 'email_port': self.data.get('email_port'), + 'email_host_user': self.data.get('email_host_user'), + 'email_host_password': self.data.get('email_host_password'), + 'email_use_tls': self.data.get('email_use_tls'), + 'email_use_ssl': self.data.get('email_use_ssl'), + 'from_email': self.data.get('from_email') + } diff --git a/src/MaxKB-1.7.2/apps/setting/serializers/team_serializers.py b/src/MaxKB-1.7.2/apps/setting/serializers/team_serializers.py new file mode 100644 index 0000000..46266bb --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/serializers/team_serializers.py @@ -0,0 +1,320 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: team_serializers.py + @date:2023/9/5 16:32 + @desc: +""" +import itertools +import json +import os +import uuid +from typing import Dict, List + +from django.core import cache +from 
django.db import transaction +from django.db.models import QuerySet, Q +from drf_yasg import openapi +from rest_framework import serializers + +from common.constants.permission_constants import Operate +from common.db.sql_execute import select_list +from common.exception.app_exception import AppApiException +from common.mixins.api_mixin import ApiMixin +from common.response.result import get_api_response +from common.util.field_message import ErrMessage +from common.util.file_util import get_file_content +from setting.models import TeamMember, TeamMemberPermission, Team +from smartdoc.conf import PROJECT_DIR +from users.models.user import User +from users.serializers.user_serializers import UserSerializer + +user_cache = cache.caches['user_cache'] + + +def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'username', 'email', 'role', 'is_active', 'team_id', 'member_id'], + properties={ + 'id': openapi.Schema(type=openapi.TYPE_STRING, title="用户id", description="用户id"), + 'username': openapi.Schema(type=openapi.TYPE_STRING, title="用户名", description="用户名"), + 'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址"), + 'role': openapi.Schema(type=openapi.TYPE_STRING, title="角色", description="角色"), + 'is_active': openapi.Schema(type=openapi.TYPE_STRING, title="是否可用", description="是否可用"), + 'team_id': openapi.Schema(type=openapi.TYPE_STRING, title="团队id", description="团队id"), + 'member_id': openapi.Schema(type=openapi.TYPE_STRING, title="成员id", description="成员id"), + } + ) + + +class TeamMemberPermissionOperate(ApiMixin, serializers.Serializer): + USE = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("使用")) + MANAGE = serializers.BooleanField(required=True, error_messages=ErrMessage.boolean("管理")) + + def get_request_body_api(self): + return openapi.Schema(type=openapi.TYPE_OBJECT, + title="类型", + description="操作权限USE,MANAGE权限", + properties={ + 'USE': 
openapi.Schema(type=openapi.TYPE_BOOLEAN, + title="使用权限", + description="使用权限 True|False"), + 'MANAGE': openapi.Schema(type=openapi.TYPE_BOOLEAN, + title="管理权限", + description="管理权限 True|False") + } + ) + + +class UpdateTeamMemberItemPermissionSerializer(ApiMixin, serializers.Serializer): + target_id = serializers.CharField(required=True, error_messages=ErrMessage.char("目标id")) + type = serializers.CharField(required=True, error_messages=ErrMessage.char("目标类型")) + operate = TeamMemberPermissionOperate(required=True, many=False) + + def get_request_body_api(self): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id', 'type', 'operate'], + properties={ + 'target_id': openapi.Schema(type=openapi.TYPE_STRING, title="知识库/应用id", + description="知识库或者应用的id"), + 'type': openapi.Schema(type=openapi.TYPE_STRING, + title="类型", + description="DATASET|APPLICATION", + ), + 'operate': TeamMemberPermissionOperate().get_request_body_api() + } + ) + + +class UpdateTeamMemberPermissionSerializer(ApiMixin, serializers.Serializer): + team_member_permission_list = UpdateTeamMemberItemPermissionSerializer(required=True, many=True) + + def is_valid(self, *, user_id=None): + super().is_valid(raise_exception=True) + permission_list = self.data.get("team_member_permission_list") + illegal_target_id_list = select_list( + get_file_content( + os.path.join(PROJECT_DIR, "apps", "setting", 'sql', 'check_member_permission_target_exists.sql')), + [json.dumps(permission_list), user_id, user_id]) + if illegal_target_id_list is not None and len(illegal_target_id_list) > 0: + raise AppApiException(500, '不存在的 应用|知识库id[' + str(illegal_target_id_list) + ']') + + def update_or_save(self, member_id: str): + team_member_permission_list = self.data.get("team_member_permission_list") + # 获取数据库已有权限 从而判断是否是插入还是更新 + team_member_permission_exist_list = QuerySet(TeamMemberPermission).filter( + member_id=member_id) + update_list = [] + save_list = [] + for item in team_member_permission_list: + 
exist_list = list( + filter(lambda use: str(use.target) == item.get('target_id'), team_member_permission_exist_list)) + if len(exist_list) > 0: + exist_list[0].operate = list( + filter(lambda key: item.get('operate').get(key), + item.get('operate').keys())) + update_list.append(exist_list[0]) + else: + save_list.append(TeamMemberPermission(target=item.get('target_id'), auth_target_type=item.get('type'), + operate=list( + filter(lambda key: item.get('operate').get(key), + item.get('operate').keys())), + member_id=member_id)) + # 批量更新 + QuerySet(TeamMemberPermission).bulk_update(update_list, ['operate']) + # 批量插入 + QuerySet(TeamMemberPermission).bulk_create(save_list) + + def get_request_body_api(self): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['id'], + properties={ + 'team_member_permission_list': + openapi.Schema(type=openapi.TYPE_ARRAY, title="权限数据", + description="权限数据", + items=UpdateTeamMemberItemPermissionSerializer().get_request_body_api() + ), + } + ) + + +class TeamMemberSerializer(ApiMixin, serializers.Serializer): + team_id = serializers.UUIDField(required=True, error_messages=ErrMessage.uuid("团队id")) + + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + + @staticmethod + def get_bach_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_ARRAY, + title="用户id列表", + description="用户id列表", + items=openapi.Schema(type=openapi.TYPE_STRING) + ) + + @staticmethod + def get_request_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['username_or_email'], + properties={ + 'username_or_email': openapi.Schema(type=openapi.TYPE_STRING, title="用户名或者邮箱", + description="用户名或者邮箱"), + + } + ) + + @transaction.atomic + def batch_add_member(self, user_id_list: List[str], with_valid=True): + """ + 批量添加成员 + :param user_id_list: 用户id列表 + :param with_valid: 是否校验 + :return: 成员列表 + """ + if with_valid: + self.is_valid(raise_exception=True) + use_user_id_list = [str(u.id) for u in 
QuerySet(User).filter(id__in=user_id_list)] + + team_member_user_id_list = [str(team_member.user_id) for team_member in + QuerySet(TeamMember).filter(team_id=self.data.get('team_id'))] + team_id = self.data.get("team_id") + create_team_member_list = [ + self.to_member_model(add_user_id, team_member_user_id_list, use_user_id_list, team_id) for add_user_id in + user_id_list] + QuerySet(TeamMember).bulk_create(create_team_member_list) if len(create_team_member_list) > 0 else None + return TeamMemberSerializer( + data={'team_id': self.data.get("team_id")}).list_member() + + def to_member_model(self, add_user_id, team_member_user_id_list, use_user_id_list, user_id): + if use_user_id_list.__contains__(add_user_id): + if team_member_user_id_list.__contains__(add_user_id) or user_id == add_user_id: + raise AppApiException(500, "团队中已存在当前成员,不要重复添加") + else: + return TeamMember(team_id=self.data.get("team_id"), user_id=add_user_id) + else: + raise AppApiException(500, "不存在的用户") + + def add_member(self, username_or_email: str, with_valid=True): + """ + 添加一个成员 + :param with_valid: 是否校驗參數 + :param username_or_email: 添加成员的邮箱或者用户名 + :return: 成员列表 + """ + if with_valid: + self.is_valid(raise_exception=True) + if username_or_email is None: + raise AppApiException(500, "用户名或者邮箱必填") + user = QuerySet(User).filter( + Q(username=username_or_email) | Q(email=username_or_email)).first() + if user is None: + raise AppApiException(500, "不存在的用户") + if QuerySet(TeamMember).filter(Q(team_id=self.data.get('team_id')) & Q(user=user)).exists() or self.data.get( + "team_id") == str(user.id): + raise AppApiException(500, "团队中已存在当前成员,不要重复添加") + TeamMember(team_id=self.data.get("team_id"), user=user).save() + return self.list_member(with_valid=False) + + def list_member(self, with_valid=True): + """ + 获取 团队中的成员列表 + :return: 成员列表 + """ + if with_valid: + self.is_valid(raise_exception=True) + # 普通成員列表 + member_list = list(map(lambda t: {"id": t.id, 'email': t.user.email, 'username': t.user.username, + 
'team_id': self.data.get("team_id"), 'user_id': t.user.id, + 'type': 'member'}, + QuerySet(TeamMember).filter(team_id=self.data.get("team_id")))) + # 管理員成員 + manage_member = QuerySet(User).get(id=self.data.get('team_id')) + return [{'id': 'root', 'email': manage_member.email, 'username': manage_member.username, + 'team_id': self.data.get("team_id"), 'user_id': manage_member.id, 'type': 'manage' + }, *member_list] + + def get_response_body_api(self): + return get_api_response(openapi.Schema( + type=openapi.TYPE_ARRAY, title="成员列表", description="成员列表", + items=UserSerializer().get_response_body_api() + )) + + class Operate(ApiMixin, serializers.Serializer): + # 团队 成员id + member_id = serializers.CharField(required=True, error_messages=ErrMessage.char("成员id")) + # 团队id + team_id = serializers.CharField(required=True, error_messages=ErrMessage.char("团队id")) + + def is_valid(self, *, raise_exception=True): + super().is_valid(raise_exception=True) + if self.data.get('member_id') != 'root' and not QuerySet(TeamMember).filter( + team_id=self.data.get('team_id'), + id=self.data.get('member_id')).exists(): + raise AppApiException(500, "不存在的成员,请先添加成员") + + return True + + def list_member_permission(self, with_valid=True): + if with_valid: + self.is_valid(raise_exception=True) + team_id = self.data.get('team_id') + member_id = self.data.get("member_id") + # 查询当前团队成员所有的知识库和应用的权限 注意 operate为null是为设置权限 默认值都是false + member_permission_list = select_list( + get_file_content(os.path.join(PROJECT_DIR, "apps", "setting", 'sql', 'get_member_permission.sql')), + [team_id, team_id, (member_id if member_id != 'root' else uuid.uuid1())]) + + # 如果是管理员 则拥有所有权限 默认赋值 + if member_id == 'root': + member_permission_list = list( + map(lambda row: {**row, 'operate': {Operate.USE.value: True, Operate.MANAGE.value: True}}, + member_permission_list)) + # 分为 APPLICATION DATASET俩组 + groups = itertools.groupby( + sorted(list(map(lambda m: {**m, 'member_id': member_id, + 'operate': dict( + map(lambda key: 
(key, True if m.get('operate') is not None and m.get( + 'operate').__contains__(key) else False), + [Operate.USE.value, Operate.MANAGE.value]))}, + member_permission_list)), key=lambda x: x.get('type')), + key=lambda x: x.get('type')) + return dict([(key, list(group)) for key, group in groups]) + + def edit(self, member_permission: Dict): + self.is_valid(raise_exception=True) + member_id = self.data.get("member_id") + if member_id == 'root': + raise AppApiException(500, "管理员权限不允许修改") + s = UpdateTeamMemberPermissionSerializer(data=member_permission) + s.is_valid(user_id=self.data.get("team_id")) + s.update_or_save(member_id) + return self.list_member_permission(with_valid=False) + + def delete(self): + """ + 移除成员 + :return: + """ + self.is_valid(raise_exception=True) + member_id = self.data.get("member_id") + if member_id == 'root': + raise AppApiException(500, "无法移除团队管理员") + # 删除成员权限 + QuerySet(TeamMemberPermission).filter(member_id=member_id).delete() + # 删除成员 + QuerySet(TeamMember).filter(id=member_id).delete() + return True + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='member_id', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='团队成员id')] diff --git a/src/MaxKB-1.7.2/apps/setting/serializers/valid_serializers.py b/src/MaxKB-1.7.2/apps/setting/serializers/valid_serializers.py new file mode 100644 index 0000000..ee73152 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/serializers/valid_serializers.py @@ -0,0 +1,51 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: valid_serializers.py + @date:2024/7/8 18:00 + @desc: +""" +import re + +from django.core import validators +from django.db.models import QuerySet +from rest_framework import serializers + +from application.models import Application +from common.exception.app_exception import AppApiException +from common.models.db_model_manage import DBModelManage +from common.util.field_message import ErrMessage +from dataset.models 
import DataSet +from users.models import User + +model_message_dict = { + 'dataset': {'model': DataSet, 'count': 50, + 'message': '社区版最多支持 50 个知识库,如需拥有更多知识库,请联系我们(https://fit2cloud.com/)。'}, + 'application': {'model': Application, 'count': 5, + 'message': '社区版最多支持 5 个应用,如需拥有更多应用,请联系我们(https://fit2cloud.com/)。'}, + 'user': {'model': User, 'count': 2, + 'message': '社区版最多支持 2 个用户,如需拥有更多用户,请联系我们(https://fit2cloud.com/)。'} +} + + +class ValidSerializer(serializers.Serializer): + valid_type = serializers.CharField(required=True, error_messages=ErrMessage.char("类型"), validators=[ + validators.RegexValidator(regex=re.compile("^application|dataset|user$"), + message="类型只支持:application|dataset|user", code=500) + ]) + valid_count = serializers.IntegerField(required=True, error_messages=ErrMessage.integer("校验数量")) + + def valid(self, is_valid=True): + if is_valid: + self.is_valid(raise_exception=True) + model_value = model_message_dict.get(self.data.get('valid_type')) + xpack_cache = DBModelManage.get_model('xpack_cache') + is_license_valid = xpack_cache.get('XPACK_LICENSE_IS_VALID', False) if xpack_cache is not None else False + if not is_license_valid: + if self.data.get('valid_count') != model_value.get('count'): + raise AppApiException(400, model_value.get('message')) + if QuerySet( + model_value.get('model')).count() >= model_value.get('count'): + raise AppApiException(400, model_value.get('message')) + return True diff --git a/src/MaxKB-1.7.2/apps/setting/sql/check_member_permission_target_exists.sql b/src/MaxKB-1.7.2/apps/setting/sql/check_member_permission_target_exists.sql new file mode 100644 index 0000000..13c1aaa --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/sql/check_member_permission_target_exists.sql @@ -0,0 +1,32 @@ +SELECT + static_temp."target_id"::text +FROM + (SELECT * FROM json_to_recordset( + %s + ) AS x(target_id uuid,type text)) static_temp + LEFT JOIN ( + SELECT + "id", + 'DATASET' AS "type", + user_id, + ARRAY [ 'MANAGE', + 'USE', + 'DELETE' ] AS 
"operate" + FROM + dataset + WHERE + "user_id" = %s UNION + SELECT + "id", + 'APPLICATION' AS "type", + user_id, + ARRAY [ 'MANAGE', + 'USE', + 'DELETE' ] AS "operate" + FROM + application + WHERE + "user_id" = %s + ) "app_and_dataset_temp" + ON "app_and_dataset_temp"."id" = static_temp."target_id" and app_and_dataset_temp."type"=static_temp."type" + WHERE app_and_dataset_temp.id is NULL ; \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/setting/sql/get_member_permission.sql b/src/MaxKB-1.7.2/apps/setting/sql/get_member_permission.sql new file mode 100644 index 0000000..f6b2d95 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/sql/get_member_permission.sql @@ -0,0 +1,26 @@ +SELECT + app_or_dataset.*, + team_member_permission.member_id, + team_member_permission.operate +FROM + ( + SELECT + "id", + "name", + 'DATASET' AS "type", + user_id + FROM + dataset + WHERE + "user_id" = %s UNION + SELECT + "id", + "name", + 'APPLICATION' AS "type", + user_id + FROM + application + WHERE + "user_id" = %s + ) app_or_dataset + LEFT JOIN ( SELECT * FROM team_member_permission WHERE member_id = %s ) team_member_permission ON team_member_permission.target = app_or_dataset."id" \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/setting/sql/get_user_permission.sql b/src/MaxKB-1.7.2/apps/setting/sql/get_user_permission.sql new file mode 100644 index 0000000..c50e5ea --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/sql/get_user_permission.sql @@ -0,0 +1,30 @@ +SELECT + "id", + 'DATASET' AS "type", + user_id, + ARRAY [ 'MANAGE', + 'USE','DELETE' ] AS "operate" +FROM + dataset +WHERE + "user_id" = %s UNION +SELECT + "id", + 'APPLICATION' AS "type", + user_id, + ARRAY [ 'MANAGE', + 'USE','DELETE' ] AS "operate" +FROM + application +WHERE + "user_id" = %s UNION +SELECT + team_member_permission.target AS "id", + team_member_permission.auth_target_type AS "type", + team_member.user_id AS user_id, + team_member_permission.operate AS "operate" +FROM + team_member 
team_member + LEFT JOIN team_member_permission team_member_permission ON team_member.ID = team_member_permission.member_id +WHERE + team_member.user_id = %s AND team_member_permission.target IS NOT NULL \ No newline at end of file diff --git a/src/MaxKB-1.7.2/apps/setting/swagger_api/provide_api.py b/src/MaxKB-1.7.2/apps/setting/swagger_api/provide_api.py new file mode 100644 index 0000000..7544fdf --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/swagger_api/provide_api.py @@ -0,0 +1,188 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: provide_api.py + @date:2023/11/2 14:25 + @desc: +""" +from drf_yasg import openapi + +from common.mixins.api_mixin import ApiMixin + + +class ModelQueryApi(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='name', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='模型名称'), + openapi.Parameter(name='model_type', in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='模型类型'), + openapi.Parameter(name='model_name', in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='基础模型名称'), + openapi.Parameter(name='provider', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=False, + description='供应名称') + ] + + +class ModelEditApi(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema(type=openapi.TYPE_OBJECT, + title="调用函数所需要的参数", + description="调用函数所需要的参数", + required=['provide', 'model_info'], + properties={ + 'name': openapi.Schema(type=openapi.TYPE_STRING, + title="模型名称", + description="模型名称"), + 'model_type': openapi.Schema(type=openapi.TYPE_STRING, + title="供应商", + description="供应商"), + 'model_name': openapi.Schema(type=openapi.TYPE_STRING, + title="供应商", + description="供应商"), + 'credential': openapi.Schema(type=openapi.TYPE_OBJECT, + title="模型证书信息", + description="模型证书信息") + } + ) + + +class ModelCreateApi(ApiMixin): + + @staticmethod + def 
get_request_body_api(): + return openapi.Schema(type=openapi.TYPE_OBJECT, + title="调用函数所需要的参数", + description="调用函数所需要的参数", + required=['provide', 'model_info'], + properties={ + 'name': openapi.Schema(type=openapi.TYPE_STRING, + title="模型名称", + description="模型名称"), + 'provider': openapi.Schema(type=openapi.TYPE_STRING, + title="供应商", + description="供应商"), + 'permission_type': openapi.Schema(type=openapi.TYPE_STRING, title="权限", + description="PUBLIC|PRIVATE"), + 'model_type': openapi.Schema(type=openapi.TYPE_STRING, + title="供应商", + description="供应商"), + 'model_name': openapi.Schema(type=openapi.TYPE_STRING, + title="供应商", + description="供应商"), + 'credential': openapi.Schema(type=openapi.TYPE_OBJECT, + title="模型证书信息", + description="模型证书信息"), + + } + ) + + +class ProvideApi(ApiMixin): + class ModelTypeList(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='provider', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=True, + description='供应名称'), + ] + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['key', 'value'], + properties={ + 'key': openapi.Schema(type=openapi.TYPE_STRING, title="模型类型描述", + description="模型类型描述", default="大语言模型"), + 'value': openapi.Schema(type=openapi.TYPE_STRING, title="模型类型值", + description="模型类型值", default="LLM"), + + } + ) + + class ModelList(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='provider', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=True, + description='供应名称'), + openapi.Parameter(name='model_type', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=True, + description='模型类型'), + ] + + @staticmethod + def get_response_body_api(): + return openapi.Schema( + type=openapi.TYPE_OBJECT, + required=['name', 'desc', 'model_type'], + properties={ + 'name': openapi.Schema(type=openapi.TYPE_STRING, title="模型名称", + description="模型名称", 
default="模型名称"), + 'desc': openapi.Schema(type=openapi.TYPE_STRING, title="模型描述", + description="模型描述", default="xxx模型"), + 'model_type': openapi.Schema(type=openapi.TYPE_STRING, title="模型类型值", + description="模型类型值", default="LLM"), + + } + ) + + class ModelForm(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='provider', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=True, + description='供应名称'), + openapi.Parameter(name='model_type', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=True, + description='模型类型'), + openapi.Parameter(name='model_name', + in_=openapi.IN_QUERY, + type=openapi.TYPE_STRING, + required=True, + description='模型名称'), + ] + + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='provider', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='供应商'), + openapi.Parameter(name='method', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='需要执行的函数'), + ] + + @staticmethod + def get_request_body_api(): + return openapi.Schema(type=openapi.TYPE_OBJECT, + title="调用函数所需要的参数", + description="调用函数所需要的参数", + ) diff --git a/src/MaxKB-1.7.2/apps/setting/swagger_api/system_setting.py b/src/MaxKB-1.7.2/apps/setting/swagger_api/system_setting.py new file mode 100644 index 0000000..1246ff2 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/swagger_api/system_setting.py @@ -0,0 +1,77 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: system_setting.py + @date:2024/3/19 16:05 + @desc: +""" +from drf_yasg import openapi + +from common.mixins.api_mixin import ApiMixin + + +class SystemSettingEmailApi(ApiMixin): + @staticmethod + def get_request_body_api(): + return openapi.Schema(type=openapi.TYPE_OBJECT, + title="邮箱相关参数", + description="邮箱相关参数", + required=['email_host', 'email_port', 'email_host_user', 'email_host_password', + 'email_use_tls', 'email_use_ssl', 'from_email'], + properties={ + 
'email_host': openapi.Schema(type=openapi.TYPE_STRING, + title="SMTP 主机", + description="SMTP 主机"), + 'email_port': openapi.Schema(type=openapi.TYPE_NUMBER, + title="SMTP 端口", + description="SMTP 端口"), + 'email_host_user': openapi.Schema(type=openapi.TYPE_STRING, + title="发件人邮箱", + description="发件人邮箱"), + 'email_host_password': openapi.Schema(type=openapi.TYPE_STRING, + title="密码", + description="密码"), + 'email_use_tls': openapi.Schema(type=openapi.TYPE_BOOLEAN, + title="是否开启TLS", + description="是否开启TLS"), + 'email_use_ssl': openapi.Schema(type=openapi.TYPE_BOOLEAN, + title="是否开启SSL", + description="是否开启SSL"), + 'from_email': openapi.Schema(type=openapi.TYPE_STRING, + title="发送人邮箱", + description="发送人邮箱") + } + ) + + @staticmethod + def get_response_body_api(): + return openapi.Schema(type=openapi.TYPE_OBJECT, + title="邮箱相关参数", + description="邮箱相关参数", + required=['email_host', 'email_port', 'email_host_user', 'email_host_password', + 'email_use_tls', 'email_use_ssl', 'from_email'], + properties={ + 'email_host': openapi.Schema(type=openapi.TYPE_STRING, + title="SMTP 主机", + description="SMTP 主机"), + 'email_port': openapi.Schema(type=openapi.TYPE_NUMBER, + title="SMTP 端口", + description="SMTP 端口"), + 'email_host_user': openapi.Schema(type=openapi.TYPE_STRING, + title="发件人邮箱", + description="发件人邮箱"), + 'email_host_password': openapi.Schema(type=openapi.TYPE_STRING, + title="密码", + description="密码"), + 'email_use_tls': openapi.Schema(type=openapi.TYPE_BOOLEAN, + title="是否开启TLS", + description="是否开启TLS"), + 'email_use_ssl': openapi.Schema(type=openapi.TYPE_BOOLEAN, + title="是否开启SSL", + description="是否开启SSL"), + 'from_email': openapi.Schema(type=openapi.TYPE_STRING, + title="发送人邮箱", + description="发送人邮箱") + } + ) diff --git a/src/MaxKB-1.7.2/apps/setting/swagger_api/valid_api.py b/src/MaxKB-1.7.2/apps/setting/swagger_api/valid_api.py new file mode 100644 index 0000000..4fad9e8 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/swagger_api/valid_api.py @@ -0,0 +1,27 @@ +# 
coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: valid_api.py + @date:2024/7/8 17:52 + @desc: +""" +from drf_yasg import openapi + +from common.mixins.api_mixin import ApiMixin + + +class ValidApi(ApiMixin): + @staticmethod + def get_request_params_api(): + return [openapi.Parameter(name='valid_type', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='校验类型:application|dataset|user'), + openapi.Parameter(name='valid_count', + in_=openapi.IN_PATH, + type=openapi.TYPE_STRING, + required=True, + description='校验数量') + ] diff --git a/src/MaxKB-1.7.2/apps/setting/tests.py b/src/MaxKB-1.7.2/apps/setting/tests.py new file mode 100644 index 0000000..7ce503c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/tests.py @@ -0,0 +1,3 @@ +from django.test import TestCase + +# Create your tests here. diff --git a/src/MaxKB-1.7.2/apps/setting/urls.py b/src/MaxKB-1.7.2/apps/setting/urls.py new file mode 100644 index 0000000..42e8059 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/urls.py @@ -0,0 +1,37 @@ +import os + +from django.urls import path + +from . 
import views + +app_name = "team" +urlpatterns = [ + path('team/member', views.TeamMember.as_view(), name="team"), + path('team/member/_batch', views.TeamMember.Batch.as_view()), + path('team/member/', views.TeamMember.Operate.as_view(), name='member'), + path('provider//', views.Provide.Exec.as_view(), name='provide_exec'), + path('provider', views.Provide.as_view(), name='provide'), + path('provider/model_type_list', views.Provide.ModelTypeList.as_view(), name="provider/model_type_list"), + path('provider/model_list', views.Provide.ModelList.as_view(), + name="provider/model_name_list"), + path('provider/model_form', views.Provide.ModelForm.as_view(), + name="provider/model_form"), + path('model', views.Model.as_view(), name='model'), + path('model//model_params_form', views.Model.ModelParamsForm.as_view(), + name='model/model_params_form'), + path('model/', views.Model.Operate.as_view(), name='model/operate'), + path('model//pause_download', views.Model.PauseDownload.as_view(), name='model/operate'), + path('model//meta', views.Model.ModelMeta.as_view(), name='model/operate/meta'), + path('email_setting', views.SystemSetting.Email.as_view(), name='email_setting'), + path('valid//', views.Valid.as_view()) + +] +if os.environ.get('SERVER_NAME', 'web') == 'local_model': + urlpatterns += [ + path('model//embed_documents', views.ModelApply.EmbedDocuments.as_view(), + name='model/embed_documents'), + path('model//embed_query', views.ModelApply.EmbedQuery.as_view(), + name='model/embed_query'), + path('model//compress_documents', views.ModelApply.CompressDocuments.as_view(), + name='model/embed_query'), + ] diff --git a/src/MaxKB-1.7.2/apps/setting/views/Team.py b/src/MaxKB-1.7.2/apps/setting/views/Team.py new file mode 100644 index 0000000..71710e3 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/views/Team.py @@ -0,0 +1,90 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: Team.py + @date:2023/9/25 17:13 + @desc: +""" +from drf_yasg.utils import 
swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.views import APIView +from rest_framework.views import Request + +from common.auth import TokenAuth, has_permissions +from common.constants.permission_constants import PermissionConstants +from common.response import result +from setting.serializers.team_serializers import TeamMemberSerializer, get_response_body_api, \ + UpdateTeamMemberPermissionSerializer + + +class TeamMember(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取团队成员列表", + operation_id="获取团员成员列表", + responses=result.get_api_response(get_response_body_api()), + tags=["团队"]) + @has_permissions(PermissionConstants.TEAM_READ) + def get(self, request: Request): + return result.success(TeamMemberSerializer(data={'team_id': str(request.user.id)}).list_member()) + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="添加成员", + operation_id="添加成员", + request_body=TeamMemberSerializer().get_request_body_api(), + tags=["团队"]) + @has_permissions(PermissionConstants.TEAM_CREATE) + def post(self, request: Request): + team = TeamMemberSerializer(data={'team_id': str(request.user.id)}) + return result.success((team.add_member(**request.data))) + + class Batch(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="批量添加成员", + operation_id="批量添加成员", + request_body=TeamMemberSerializer.get_bach_request_body_api(), + tags=["团队"]) + @has_permissions(PermissionConstants.TEAM_CREATE) + def post(self, request: Request): + return result.success( + TeamMemberSerializer(data={'team_id': request.user.id}).batch_add_member(request.data)) + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取团队成员权限", + operation_id="获取团队成员权限", + 
manual_parameters=TeamMemberSerializer.Operate.get_request_params_api(), + tags=["团队"]) + @has_permissions(PermissionConstants.TEAM_READ) + def get(self, request: Request, member_id: str): + return result.success(TeamMemberSerializer.Operate( + data={'member_id': member_id, 'team_id': str(request.user.id)}).list_member_permission()) + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改团队成员权限", + operation_id="修改团队成员权限", + request_body=UpdateTeamMemberPermissionSerializer().get_request_body_api(), + manual_parameters=TeamMemberSerializer.Operate.get_request_params_api(), + tags=["团队"] + ) + @has_permissions(PermissionConstants.TEAM_EDIT) + def put(self, request: Request, member_id: str): + return result.success(TeamMemberSerializer.Operate( + data={'member_id': member_id, 'team_id': str(request.user.id)}).edit(request.data)) + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="移除成员", + operation_id="移除成员", + manual_parameters=TeamMemberSerializer.Operate.get_request_params_api(), + tags=["团队"] + ) + @has_permissions(PermissionConstants.TEAM_DELETE) + def delete(self, request: Request, member_id: str): + return result.success(TeamMemberSerializer.Operate( + data={'member_id': member_id, 'team_id': str(request.user.id)}).delete()) diff --git a/src/MaxKB-1.7.2/apps/setting/views/__init__.py b/src/MaxKB-1.7.2/apps/setting/views/__init__.py new file mode 100644 index 0000000..4fe5056 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/views/__init__.py @@ -0,0 +1,13 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: __init__.py.py + @date:2023/9/25 17:12 + @desc: +""" +from .Team import * +from .model import * +from .system_setting import * +from .valid import * +from .model_apply import * diff --git a/src/MaxKB-1.7.2/apps/setting/views/model.py b/src/MaxKB-1.7.2/apps/setting/views/model.py new file mode 100644 index 0000000..b5abf91 --- /dev/null +++ 
b/src/MaxKB-1.7.2/apps/setting/views/model.py @@ -0,0 +1,224 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: model.py + @date:2023/11/2 13:55 + @desc: +""" +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.views import APIView +from rest_framework.views import Request + +from common.auth import TokenAuth, has_permissions +from common.constants.permission_constants import PermissionConstants +from common.response import result +from common.util.common import query_params_to_single_dict +from setting.models_provider.constants.model_provider_constants import ModelProvideConstants +from setting.serializers.provider_serializers import ProviderSerializer, ModelSerializer +from setting.swagger_api.provide_api import ProvideApi, ModelCreateApi, ModelQueryApi, ModelEditApi + + +class Model(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="创建模型", + operation_id="创建模型", + request_body=ModelCreateApi.get_request_body_api() + , tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_CREATE) + def post(self, request: Request): + return result.success( + ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id, + with_valid=True)) + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="下载模型,只试用与Ollama平台", + operation_id="下载模型,只试用与Ollama平台", + request_body=ModelCreateApi.get_request_body_api() + , tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_CREATE) + def put(self, request: Request): + return result.success( + ModelSerializer.Create(data={**request.data, 'user_id': str(request.user.id)}).insert(request.user.id, + with_valid=True)) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取模型列表", + operation_id="获取模型列表", + manual_parameters=ModelQueryApi.get_request_params_api() + , tags=["模型"]) + 
@has_permissions(PermissionConstants.MODEL_READ) + def get(self, request: Request): + return result.success( + ModelSerializer.Query( + data={**query_params_to_single_dict(request.query_params), 'user_id': request.user.id}).list( + with_valid=True)) + + class ModelMeta(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="查询模型meta信息,该接口不携带认证信息", + operation_id="查询模型meta信息,该接口不携带认证信息", + tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_READ) + def get(self, request: Request, model_id: str): + return result.success( + ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one_meta(with_valid=True)) + + class PauseDownload(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="暂停模型下载", + operation_id="暂停模型下载", + tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_CREATE) + def put(self, request: Request, model_id: str): + return result.success( + ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).pause_download()) + + class ModelParamsForm(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取模型参数表单", + operation_id="获取模型参数表单", + manual_parameters=ProvideApi.ModelForm.get_request_params_api(), + tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_READ) + def get(self, request: Request, model_id: str): + return result.success( + ModelSerializer.ModelParams(data={'id': model_id, 'user_id': request.user.id}).get_model_params()) + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="保存模型参数表单", + operation_id="保存模型参数表单", + manual_parameters=ProvideApi.ModelForm.get_request_params_api(), + tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_READ) + def put(self, request: Request, model_id: str): + return result.success( + 
ModelSerializer.ModelParamsForm(data={'id': model_id, 'user_id': request.user.id}) + .save_model_params_form(request.data)) + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改模型", + operation_id="修改模型", + request_body=ModelEditApi.get_request_body_api() + , tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_CREATE) + def put(self, request: Request, model_id: str): + return result.success( + ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).edit(request.data, + str(request.user.id))) + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="删除模型", + operation_id="删除模型", + responses=result.get_default_response() + , tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_DELETE) + def delete(self, request: Request, model_id: str): + return result.success( + ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).delete()) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="查询模型详细信息", + operation_id="查询模型详细信息", + tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_READ) + def get(self, request: Request, model_id: str): + return result.success( + ModelSerializer.Operate(data={'id': model_id, 'user_id': request.user.id}).one(with_valid=True)) + + +class Provide(APIView): + authentication_classes = [TokenAuth] + + class Exec(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="调用供应商函数,获取表单数据", + operation_id="调用供应商函数,获取表单数据", + manual_parameters=ProvideApi.get_request_params_api(), + request_body=ProvideApi.get_request_body_api() + , tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_READ) + def post(self, request: Request, provider: str, method: str): + return result.success( + ProviderSerializer(data={'provider': provider, 'method': method}).exec(request.data, 
with_valid=True)) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取模型供应商数据", + operation_id="获取模型供应商列表" + , tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_READ) + def get(self, request: Request): + model_type = request.query_params.get('model_type') + if model_type: + providers = [] + for key in ModelProvideConstants.__members__: + if len([item for item in ModelProvideConstants[key].value.get_model_type_list() if + item['value'] == model_type]) > 0: + providers.append(ModelProvideConstants[key].value.get_model_provide_info().to_dict()) + return result.success(providers) + return result.success( + [ModelProvideConstants[key].value.get_model_provide_info().to_dict() for key in + ModelProvideConstants.__members__]) + + class ModelTypeList(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取模型类型列表", + operation_id="获取模型类型类型列表", + manual_parameters=ProvideApi.ModelTypeList.get_request_params_api(), + responses=result.get_api_array_response(ProvideApi.ModelTypeList.get_response_body_api()) + , tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_READ) + def get(self, request: Request): + provider = request.query_params.get('provider') + return result.success(ModelProvideConstants[provider].value.get_model_type_list()) + + class ModelList(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取模型列表", + operation_id="获取模型创建表单", + manual_parameters=ProvideApi.ModelList.get_request_params_api(), + responses=result.get_api_array_response(ProvideApi.ModelList.get_response_body_api()) + , tags=["模型"] + ) + @has_permissions(PermissionConstants.MODEL_READ) + def get(self, request: Request): + provider = request.query_params.get('provider') + model_type = request.query_params.get('model_type') + + return result.success( + 
ModelProvideConstants[provider].value.get_model_list( + model_type)) + + class ModelForm(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取模型创建表单", + operation_id="获取模型创建表单", + manual_parameters=ProvideApi.ModelForm.get_request_params_api(), + tags=["模型"]) + @has_permissions(PermissionConstants.MODEL_READ) + def get(self, request: Request): + provider = request.query_params.get('provider') + model_type = request.query_params.get('model_type') + model_name = request.query_params.get('model_name') + return result.success( + ModelProvideConstants[provider].value.get_model_credential(model_type, model_name).to_form_list()) diff --git a/src/MaxKB-1.7.2/apps/setting/views/model_apply.py b/src/MaxKB-1.7.2/apps/setting/views/model_apply.py new file mode 100644 index 0000000..6bd0b54 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/views/model_apply.py @@ -0,0 +1,48 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: model_apply.py + @date:2024/8/20 20:38 + @desc: +""" +from urllib.request import Request + +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.views import APIView + +from common.response import result +from setting.serializers.model_apply_serializers import ModelApplySerializers + + +class ModelApply(APIView): + class EmbedDocuments(APIView): + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="向量化文档", + operation_id="向量化文档", + responses=result.get_default_response(), + tags=["模型"]) + def post(self, request: Request, model_id): + return result.success( + ModelApplySerializers(data={'model_id': model_id}).embed_documents(request.data)) + + class EmbedQuery(APIView): + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="向量化文档", + operation_id="向量化文档", + responses=result.get_default_response(), + tags=["模型"]) + def post(self, request: Request, 
model_id): + return result.success( + ModelApplySerializers(data={'model_id': model_id}).embed_query(request.data)) + + class CompressDocuments(APIView): + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="重排序文档", + operation_id="重排序文档", + responses=result.get_default_response(), + tags=["模型"]) + def post(self, request: Request, model_id): + return result.success( + ModelApplySerializers(data={'model_id': model_id}).compress_documents(request.data)) diff --git a/src/MaxKB-1.7.2/apps/setting/views/system_setting.py b/src/MaxKB-1.7.2/apps/setting/views/system_setting.py new file mode 100644 index 0000000..e08a470 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/views/system_setting.py @@ -0,0 +1,57 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: system_setting.py + @date:2024/3/19 16:01 + @desc: +""" + +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.views import APIView + +from common.auth import TokenAuth, has_permissions +from common.constants.permission_constants import RoleConstants +from common.response import result +from setting.serializers.system_setting import SystemSettingSerializer +from setting.swagger_api.system_setting import SystemSettingEmailApi + + +class SystemSetting(APIView): + class Email(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="创建或者修改邮箱设置", + operation_id="创建或者修改邮箱设置", + request_body=SystemSettingEmailApi.get_request_body_api(), tags=["邮箱设置"], + responses=result.get_api_response(SystemSettingEmailApi.get_response_body_api())) + @has_permissions(RoleConstants.ADMIN) + def put(self, request: Request): + return result.success( + SystemSettingSerializer.EmailSerializer.Create( + data=request.data).update_or_save()) + + @action(methods=['POST'], detail=False) + 
@swagger_auto_schema(operation_summary="测试邮箱设置", + operation_id="测试邮箱设置", + request_body=SystemSettingEmailApi.get_request_body_api(), + responses=result.get_default_response(), + tags=["邮箱设置"]) + @has_permissions(RoleConstants.ADMIN) + def post(self, request: Request): + return result.success( + SystemSettingSerializer.EmailSerializer.Create( + data=request.data).is_valid()) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取邮箱设置", + operation_id="获取邮箱设置", + responses=result.get_api_response(SystemSettingEmailApi.get_response_body_api()), + tags=["邮箱设置"]) + @has_permissions(RoleConstants.ADMIN) + def get(self, request: Request): + return result.success( + SystemSettingSerializer.EmailSerializer.one()) diff --git a/src/MaxKB-1.7.2/apps/setting/views/valid.py b/src/MaxKB-1.7.2/apps/setting/views/valid.py new file mode 100644 index 0000000..f88c589 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/setting/views/valid.py @@ -0,0 +1,32 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: valid.py + @date:2024/7/8 17:50 + @desc: +""" +from drf_yasg.utils import swagger_auto_schema +from rest_framework.decorators import action +from rest_framework.request import Request +from rest_framework.views import APIView + +from common.auth import TokenAuth, has_permissions +from common.constants.permission_constants import RoleConstants +from common.response import result +from setting.serializers.valid_serializers import ValidSerializer +from setting.swagger_api.valid_api import ValidApi + + +class Valid(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取校验结果", + operation_id="获取校验结果", + manual_parameters=ValidApi.get_request_params_api(), + responses=result.get_default_response() + , tags=["校验"]) + @has_permissions(RoleConstants.ADMIN, RoleConstants.USER) + def get(self, request: Request, valid_type: str, valid_count: int): + return 
result.success(ValidSerializer(data={'valid_type': valid_type, 'valid_count': valid_count}).valid()) diff --git a/src/MaxKB-1.7.2/apps/smartdoc/__init__.py b/src/MaxKB-1.7.2/apps/smartdoc/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MaxKB-1.7.2/apps/smartdoc/asgi.py b/src/MaxKB-1.7.2/apps/smartdoc/asgi.py new file mode 100644 index 0000000..e68e6ce --- /dev/null +++ b/src/MaxKB-1.7.2/apps/smartdoc/asgi.py @@ -0,0 +1,16 @@ +""" +ASGI config for apps project. + +It exposes the ASGI callable as a module-level variable named ``application``. + +For more information on this file, see +https://docs.djangoproject.com/en/4.2/howto/deployment/asgi/ +""" + +import os + +from django.core.asgi import get_asgi_application + +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'smartdoc.settings') + +application = get_asgi_application() diff --git a/src/MaxKB-1.7.2/apps/smartdoc/conf.py b/src/MaxKB-1.7.2/apps/smartdoc/conf.py new file mode 100644 index 0000000..0349739 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/smartdoc/conf.py @@ -0,0 +1,225 @@ +# !/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +""" +配置分类: +1. Django使用的配置文件,写到settings中 +2. 程序需要, 用户不需要更改的写到settings中 +3. 
程序需要, 用户需要更改的写到本config中 +""" +import errno +import logging +import os +import re +from importlib import import_module +from urllib.parse import urljoin, urlparse + +import yaml + +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +PROJECT_DIR = os.path.dirname(BASE_DIR) +logger = logging.getLogger('smartdoc.conf') + + +def import_string(dotted_path): + try: + module_path, class_name = dotted_path.rsplit('.', 1) + except ValueError as err: + raise ImportError("%s doesn't look like a module path" % dotted_path) from err + + module = import_module(module_path) + + try: + return getattr(module, class_name) + except AttributeError as err: + raise ImportError( + 'Module "%s" does not define a "%s" attribute/class' % + (module_path, class_name)) from err + + +def is_absolute_uri(uri): + """ 判断一个uri是否是绝对地址 """ + if not isinstance(uri, str): + return False + + result = re.match(r'^http[s]?://.*', uri) + if result is None: + return False + + return True + + +def build_absolute_uri(base, uri): + """ 构建绝对uri地址 """ + if uri is None: + return base + + if isinstance(uri, int): + uri = str(uri) + + if not isinstance(uri, str): + return base + + if is_absolute_uri(uri): + return uri + + parsed_base = urlparse(base) + url = "{}://{}".format(parsed_base.scheme, parsed_base.netloc) + path = '{}/{}/'.format(parsed_base.path.strip('/'), uri.strip('/')) + return urljoin(url, path) + + +class DoesNotExist(Exception): + pass + + +class Config(dict): + defaults = { + # 数据库相关配置 + "DB_HOST": "127.0.0.1", + "DB_PORT": 5432, + "DB_USER": "root", + "DB_PASSWORD": "Password123@postgres", + "DB_ENGINE": "django.db.backends.postgresql_psycopg2", + # 向量模型 + "EMBEDDING_MODEL_NAME": "shibing624/text2vec-base-chinese", + "EMBEDDING_DEVICE": "cpu", + "EMBEDDING_MODEL_PATH": os.path.join(PROJECT_DIR, 'models'), + # 向量库配置 + "VECTOR_STORE_NAME": 'pg_vector', + "DEBUG": False, + 'SANDBOX': False, + 'LOCAL_MODEL_HOST': '127.0.0.1', + 'LOCAL_MODEL_PORT': '11636', + 'LOCAL_MODEL_PROTOCOL': 
"http" + + } + + def get_debug(self) -> bool: + return self.get('DEBUG') if 'DEBUG' in self else True + + def get_time_zone(self) -> str: + return self.get('TIME_ZONE') if 'TIME_ZONE' in self else 'Asia/Shanghai' + + def get_db_setting(self) -> dict: + return { + "NAME": self.get('DB_NAME'), + "HOST": self.get('DB_HOST'), + "PORT": self.get('DB_PORT'), + "USER": self.get('DB_USER'), + "PASSWORD": self.get('DB_PASSWORD'), + "ENGINE": self.get('DB_ENGINE') + } + + def __init__(self, *args): + super().__init__(*args) + + def __repr__(self): + return '<%s %s>' % (self.__class__.__name__, dict.__repr__(self)) + + def __getitem__(self, item): + return self.get(item) + + def __getattr__(self, item): + return self.get(item) + + +class ConfigManager: + config_class = Config + + def __init__(self, root_path=None): + self.root_path = root_path + self.config = self.config_class() + for key in self.config_class.defaults: + self.config[key] = self.config_class.defaults[key] + + def from_mapping(self, *mapping, **kwargs): + """Updates the config like :meth:`update` ignoring items with non-upper + keys. + + .. 
versionadded:: 0.11 + """ + mappings = [] + if len(mapping) == 1: + if hasattr(mapping[0], 'items'): + mappings.append(mapping[0].items()) + else: + mappings.append(mapping[0]) + elif len(mapping) > 1: + raise TypeError( + 'expected at most 1 positional argument, got %d' % len(mapping) + ) + mappings.append(kwargs.items()) + for mapping in mappings: + for (key, value) in mapping: + if key.isupper(): + self.config[key] = value + return True + + def from_yaml(self, filename, silent=False): + if self.root_path: + filename = os.path.join(self.root_path, filename) + try: + with open(filename, 'rt', encoding='utf8') as f: + obj = yaml.safe_load(f) + except IOError as e: + if silent and e.errno in (errno.ENOENT, errno.EISDIR): + return False + e.strerror = 'Unable to load configuration file (%s)' % e.strerror + raise + if obj: + return self.from_mapping(obj) + return True + + def load_from_yml(self): + for i in ['config_example.yml', 'config.yaml', 'config.yml']: + if not os.path.isfile(os.path.join(self.root_path, i)): + continue + loaded = self.from_yaml(i) + if loaded: + return True + msg = f""" + + Error: No config file found. + + You can run `cp config_example.yml {self.root_path}/config.yml`, and edit it. + + """ + raise ImportError(msg) + + def load_from_env(self): + keys = os.environ.keys() + config = {key.replace('MAXKB_', ''): os.environ.get(key) for key in keys if key.startswith('MAXKB_')} + if len(config.keys()) <= 1: + msg = f""" + + Error: No config env found. 
+ + Please set environment variables + MAXKB_CONFIG_TYPE: 配置文件读取方式 FILE: 使用配置文件配置 ENV: 使用ENV配置 + MAXKB_DB_NAME: 数据库名称 + MAXKB_DB_HOST: 数据库主机 + MAXKB_DB_PORT: 数据库端口 + MAXKB_DB_USER: 数据库用户名 + MAXKB_DB_PASSWORD: 数据库密码 + MAXKB_EMBEDDING_MODEL_PATH: 向量模型目录 + MAXKB_EMBEDDING_MODEL_NAME: 向量模型名称 + """ + raise ImportError(msg) + self.from_mapping(config) + return True + + @classmethod + def load_user_config(cls, root_path=None, config_class=None): + config_class = config_class or Config + cls.config_class = config_class + if not root_path: + root_path = PROJECT_DIR + manager = cls(root_path=root_path) + config_type = os.environ.get('MAXKB_CONFIG_TYPE') + if config_type is None or config_type != 'ENV': + manager.load_from_yml() + else: + manager.load_from_env() + config = manager.config + return config diff --git a/src/MaxKB-1.7.2/apps/smartdoc/const.py b/src/MaxKB-1.7.2/apps/smartdoc/const.py new file mode 100644 index 0000000..9b1159a --- /dev/null +++ b/src/MaxKB-1.7.2/apps/smartdoc/const.py @@ -0,0 +1,12 @@ +# -*- coding: utf-8 -*- +# +import os + +from .conf import ConfigManager + +__all__ = ['BASE_DIR', 'PROJECT_DIR', 'VERSION', 'CONFIG'] + +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +PROJECT_DIR = os.path.dirname(BASE_DIR) +VERSION = '1.0.0' +CONFIG = ConfigManager.load_user_config(root_path=os.path.abspath('/opt/maxkb/conf')) diff --git a/src/MaxKB-1.7.2/apps/smartdoc/settings/__init__.py b/src/MaxKB-1.7.2/apps/smartdoc/settings/__init__.py new file mode 100644 index 0000000..4e7ea78 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/smartdoc/settings/__init__.py @@ -0,0 +1,12 @@ +# coding=utf-8 +""" + @project: smart-doc + @Author:虎 + @file: __init__.py + @date:2023/9/14 15:45 + @desc: +""" +from .base import * +from .logging import * +from .auth import * +from .lib import * diff --git a/src/MaxKB-1.7.2/apps/smartdoc/settings/auth.py b/src/MaxKB-1.7.2/apps/smartdoc/settings/auth.py new file mode 100644 index 0000000..077f98b --- /dev/null +++ 
b/src/MaxKB-1.7.2/apps/smartdoc/settings/auth.py @@ -0,0 +1,19 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: auth.py + @date:2024/7/9 18:47 + @desc: +""" +USER_TOKEN_AUTH = 'common.auth.handle.impl.user_token.UserToken' + +PUBLIC_ACCESS_TOKEN_AUTH = 'common.auth.handle.impl.public_access_token.PublicAccessToken' + +APPLICATION_KEY_AUTH = 'common.auth.handle.impl.application_key.ApplicationKey' + +AUTH_HANDLES = [ + USER_TOKEN_AUTH, + PUBLIC_ACCESS_TOKEN_AUTH, + APPLICATION_KEY_AUTH +] diff --git a/src/MaxKB-1.7.2/apps/smartdoc/settings/base.py b/src/MaxKB-1.7.2/apps/smartdoc/settings/base.py new file mode 100644 index 0000000..785a3fe --- /dev/null +++ b/src/MaxKB-1.7.2/apps/smartdoc/settings/base.py @@ -0,0 +1,191 @@ +import datetime +import mimetypes +import os +from pathlib import Path + +from PIL import Image + +from ..const import CONFIG, PROJECT_DIR + +mimetypes.add_type("text/css", ".css", True) +mimetypes.add_type("text/javascript", ".js", True) +# Build paths inside the project like this: BASE_DIR / 'subdir'. +BASE_DIR = Path(__file__).resolve().parent.parent +Image.MAX_IMAGE_PIXELS = 20000000000 +# Quick-start development settings - unsuitable for production +# See https://docs.djangoproject.com/en/4.2/howto/deployment/checklist/ + +# SECURITY WARNING: keep the secret key used in production secret! +SECRET_KEY = 'django-insecure-g1u*$)1ddn20_3orw^f+g4(i(2dacj^awe*2vh-$icgqwfnbq(' +# SECURITY WARNING: don't run with debug turned on in production! 
+DEBUG = CONFIG.get_debug() + +ALLOWED_HOSTS = ['*'] + +DATABASES = { + 'default': CONFIG.get_db_setting() +} + +SECURE_PROXY_SSL_HEADER = ('HTTP_X_FORWARDED_PROTO', 'https') + +# Application definition + +INSTALLED_APPS = [ + 'users.apps.UsersConfig', + 'setting', + 'dataset', + 'application', + 'embedding', + 'django.contrib.contenttypes', + 'django.contrib.messages', + 'django.contrib.staticfiles', + 'rest_framework', + "drf_yasg", # swagger 接口 + 'django_filters', # 条件过滤 + 'django_apscheduler', + 'common', + 'function_lib', + 'django_celery_beat' + +] + +MIDDLEWARE = [ + 'django.middleware.security.SecurityMiddleware', + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.middleware.common.CommonMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', + 'common.middleware.static_headers_middleware.StaticHeadersMiddleware', + 'common.middleware.cross_domain_middleware.CrossDomainMiddleware' + +] + +JWT_AUTH = { + 'JWT_EXPIRATION_DELTA': datetime.timedelta(seconds=60 * 60 * 2) # <-- 设置token有效时间 +} + +APPS_DIR = os.path.join(PROJECT_DIR, 'apps') +ROOT_URLCONF = 'smartdoc.urls' +# FORCE_SCRIPT_NAME +TEMPLATES = [ + { + 'BACKEND': 'django.template.backends.django.DjangoTemplates', + 'DIRS': ['apps/static/ui'], + 'APP_DIRS': True, + 'OPTIONS': { + 'context_processors': [ + 'django.template.context_processors.debug', + 'django.template.context_processors.request', + 'django.contrib.auth.context_processors.auth', + 'django.contrib.messages.context_processors.messages', + ], + }, + }, +] + +SWAGGER_SETTINGS = { + 'DEFAULT_AUTO_SCHEMA_CLASS': 'common.config.swagger_conf.CustomSwaggerAutoSchema', + "DEFAULT_MODEL_RENDERING": "example", + 'USE_SESSION_AUTH': False, + 'SECURITY_DEFINITIONS': { + 'Bearer': { + 'type': 'apiKey', + 'name': 'AUTHORIZATION', + 'in': 'header', + } + } +} + +# 缓存配置 +CACHES = { + "default": { + 'BACKEND': 'django.core.cache.backends.locmem.LocMemCache', + 'LOCATION': 'unique-snowflake', + 'TIMEOUT': 60 * 30, + 'OPTIONS': 
{ + 'MAX_ENTRIES': 150, + 'CULL_FREQUENCY': 5, + } + }, + 'default_file': { + 'BACKEND': 'common.cache.file_cache.FileCache', + 'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "default_file_cache") # 文件夹路径 + }, + 'chat_cache': { + 'BACKEND': 'common.cache.file_cache.FileCache', + 'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "chat_cache") # 文件夹路径 + }, + # 存储用户信息 + 'user_cache': { + 'BACKEND': 'common.cache.file_cache.FileCache', + 'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "user_cache") # 文件夹路径 + }, + # 存储用户Token + "token_cache": { + 'BACKEND': 'common.cache.file_cache.FileCache', + 'LOCATION': os.path.join(PROJECT_DIR, 'data', 'cache', "token_cache") # 文件夹路径 + } +} + +REST_FRAMEWORK = { + 'EXCEPTION_HANDLER': 'common.handle.handle_exception.handle_exception', + 'DEFAULT_AUTHENTICATION_CLASSES': ['common.auth.authenticate.AnonymousAuthentication'] + +} +STATICFILES_DIRS = [(os.path.join(PROJECT_DIR, 'ui', 'dist'))] + +STATIC_ROOT = os.path.join(BASE_DIR.parent, 'static') + +WSGI_APPLICATION = 'smartdoc.wsgi.application' + +# 邮件配置 +EMAIL_ADDRESS = CONFIG.get('EMAIL_ADDRESS') +EMAIL_BACKEND = 'django.core.mail.backends.smtp.EmailBackend' +EMAIL_USE_TLS = CONFIG.get('EMAIL_USE_TLS') # 是否使用TLS安全传输协议(用于在两个通信应用程序之间提供保密性和数据完整性。) +EMAIL_USE_SSL = CONFIG.get('EMAIL_USE_SSL') # 是否使用SSL加密,qq企业邮箱要求使用 +EMAIL_HOST = CONFIG.get('EMAIL_HOST') # 发送邮件的邮箱 的 SMTP服务器,这里用了163邮箱 +EMAIL_PORT = CONFIG.get('EMAIL_PORT') # 发件箱的SMTP服务器端口 +EMAIL_HOST_USER = CONFIG.get('EMAIL_HOST_USER') # 发送邮件的邮箱地址 +EMAIL_HOST_PASSWORD = CONFIG.get('EMAIL_HOST_PASSWORD') # 发送邮件的邮箱密码(这里使用的是授权码) + +# Database +# https://docs.djangoproject.com/en/4.2/ref/settings/#databases + + +# Password validation +# https://docs.djangoproject.com/en/4.2/ref/settings/#auth-password-validators + +AUTH_PASSWORD_VALIDATORS = [ + { + 'NAME': 'django.contrib.auth.password_validation.UserAttributeSimilarityValidator', + }, + { + 'NAME': 'django.contrib.auth.password_validation.MinimumLengthValidator', + 
}, + { + 'NAME': 'django.contrib.auth.password_validation.CommonPasswordValidator', + }, + { + 'NAME': 'django.contrib.auth.password_validation.NumericPasswordValidator', + }, +] + +# Internationalization +# https://docs.djangoproject.com/en/4.2/topics/i18n/ + +LANGUAGE_CODE = 'en-us' + +TIME_ZONE = CONFIG.get_time_zone() + +USE_I18N = True + +USE_TZ = False + +# Static files (CSS, JavaScript, Images) +# https://docs.djangoproject.com/en/4.2/howto/static-files/ + +STATIC_URL = 'static/' + +# Default primary key field type +# https://docs.djangoproject.com/en/4.2/ref/settings/#default-auto-field + +DEFAULT_AUTO_FIELD = 'django.db.models.BigAutoField' diff --git a/src/MaxKB-1.7.2/apps/smartdoc/settings/lib.py b/src/MaxKB-1.7.2/apps/smartdoc/settings/lib.py new file mode 100644 index 0000000..e7b6d39 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/smartdoc/settings/lib.py @@ -0,0 +1,40 @@ +# coding=utf-8 +""" + @project: MaxKB + @Author:虎 + @file: lib.py + @date:2024/8/16 17:12 + @desc: +""" +import os + +from smartdoc.const import CONFIG, PROJECT_DIR + +# celery相关配置 +celery_data_dir = os.path.join(PROJECT_DIR, 'data', 'celery_task') +if not os.path.exists(celery_data_dir) or not os.path.isdir(celery_data_dir): + os.makedirs(celery_data_dir) +broker_path = os.path.join(celery_data_dir, "celery_db.sqlite3") +backend_path = os.path.join(celery_data_dir, "celery_results.sqlite3") +# 使用sql_lite 当做broker 和 响应接收 +CELERY_BROKER_URL = f'sqla+sqlite:///{broker_path}' +CELERY_result_backend = f'db+sqlite:///{backend_path}' +CELERY_timezone = CONFIG.TIME_ZONE +CELERY_ENABLE_UTC = False +CELERY_task_serializer = 'pickle' +CELERY_result_serializer = 'pickle' +CELERY_accept_content = ['json', 'pickle'] +CELERY_RESULT_EXPIRES = 600 +CELERY_WORKER_TASK_LOG_FORMAT = '%(asctime).19s %(message)s' +CELERY_WORKER_LOG_FORMAT = '%(asctime).19s %(message)s' +CELERY_TASK_EAGER_PROPAGATES = True +CELERY_WORKER_REDIRECT_STDOUTS = True +CELERY_WORKER_REDIRECT_STDOUTS_LEVEL = "INFO" 
+CELERY_TASK_SOFT_TIME_LIMIT = 3600 +CELERY_WORKER_CANCEL_LONG_RUNNING_TASKS_ON_CONNECTION_LOSS = True +CELERY_ONCE = { + 'backend': 'celery_once.backends.File', + 'settings': {'location': os.path.join(celery_data_dir, "celery_once")} +} +CELERY_BROKER_CONNECTION_RETRY_ON_STARTUP = True +CELERY_LOG_DIR = os.path.join(PROJECT_DIR, 'logs', 'celery') diff --git a/src/MaxKB-1.7.2/apps/smartdoc/settings/logging.py b/src/MaxKB-1.7.2/apps/smartdoc/settings/logging.py new file mode 100644 index 0000000..9c3df8c --- /dev/null +++ b/src/MaxKB-1.7.2/apps/smartdoc/settings/logging.py @@ -0,0 +1,128 @@ +# -*- coding: utf-8 -*- +# +import os + +from ..const import PROJECT_DIR, CONFIG + +LOG_DIR = os.path.join(PROJECT_DIR, 'data', 'logs') +MAX_KB_LOG_FILE = os.path.join(LOG_DIR, 'max_kb.log') +DRF_EXCEPTION_LOG_FILE = os.path.join(LOG_DIR, 'drf_exception.log') +UNEXPECTED_EXCEPTION_LOG_FILE = os.path.join(LOG_DIR, 'unexpected_exception.log') +LOG_LEVEL = "DEBUG" + +LOGGING = { + 'version': 1, + 'disable_existing_loggers': False, + 'formatters': { + 'verbose': { + 'format': '%(levelname)s %(asctime)s %(module)s %(process)d %(thread)d %(message)s' + }, + 'main': { + 'datefmt': '%Y-%m-%d %H:%M:%S', + 'format': '%(asctime)s [%(module)s %(levelname)s] %(message)s', + }, + 'exception': { + 'datefmt': '%Y-%m-%d %H:%M:%S', + 'format': '\n%(asctime)s [%(levelname)s] %(message)s', + }, + 'simple': { + 'format': '%(levelname)s %(message)s' + }, + 'syslog': { + 'format': 'jumpserver: %(message)s' + }, + 'msg': { + 'format': '%(message)s' + } + }, + 'handlers': { + 'null': { + 'level': 'DEBUG', + 'class': 'logging.NullHandler', + }, + 'console': { + 'level': 'DEBUG', + 'class': 'logging.StreamHandler', + 'formatter': 'main' + }, + 'file': { + 'encoding': 'utf8', + 'level': 'DEBUG', + 'class': 'logging.handlers.RotatingFileHandler', + 'maxBytes': 1024 * 1024 * 100, + 'backupCount': 7, + 'formatter': 'main', + 'filename': MAX_KB_LOG_FILE, + }, + 'drf_exception': { + 'encoding': 'utf8', + 
'level': 'DEBUG', + 'class': 'logging.handlers.RotatingFileHandler', + 'formatter': 'exception', + 'maxBytes': 1024 * 1024 * 100, + 'backupCount': 7, + 'filename': DRF_EXCEPTION_LOG_FILE, + }, + 'unexpected_exception': { + 'encoding': 'utf8', + 'level': 'DEBUG', + 'class': 'logging.handlers.RotatingFileHandler', + 'formatter': 'exception', + 'maxBytes': 1024 * 1024 * 100, + 'backupCount': 7, + 'filename': UNEXPECTED_EXCEPTION_LOG_FILE, + }, + 'syslog': { + 'level': 'INFO', + 'class': 'logging.NullHandler', + 'formatter': 'syslog' + }, + }, + 'loggers': { + 'django': { + 'handlers': ['null'], + 'propagate': False, + 'level': LOG_LEVEL, + }, + 'django.request': { + 'handlers': ['console', 'file', 'syslog'], + 'level': LOG_LEVEL, + 'propagate': False, + }, + 'sqlalchemy': { + 'handlers': ['console', 'file', 'syslog'], + 'level': "ERROR", + 'propagate': False, + }, + 'django.db.backends': { + 'handlers': ['console', 'file', 'syslog'], + 'propagate': False, + 'level': LOG_LEVEL, + }, + 'django.server': { + 'handlers': ['console', 'file', 'syslog'], + 'level': LOG_LEVEL, + 'propagate': False, + }, + 'max_kb_error': { + 'handlers': ['console', 'unexpected_exception'], + 'level': LOG_LEVEL, + 'propagate': False, + }, + 'max_kb': { + 'handlers': ['console', 'file'], + 'level': LOG_LEVEL, + 'propagate': False, + }, + 'common.event': { + 'handlers': ['console', 'file'], + 'level': "DEBUG", + 'propagate': False, + }, + } +} + +SYSLOG_ENABLE = CONFIG.SYSLOG_ENABLE + +if not os.path.isdir(LOG_DIR): + os.makedirs(LOG_DIR, mode=0o755) diff --git a/src/MaxKB-1.7.2/apps/smartdoc/urls.py b/src/MaxKB-1.7.2/apps/smartdoc/urls.py new file mode 100644 index 0000000..b243809 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/smartdoc/urls.py @@ -0,0 +1,74 @@ +""" +URL configuration for apps project. + +The `urlpatterns` list routes URLs to views. For more information please see: + https://docs.djangoproject.com/en/4.2/topics/http/urls/ +Examples: +Function views + 1. 
Add an import: from my_app import views + 2. Add a URL to urlpatterns: path('', views.home, name='home') +Class-based views + 1. Add an import: from other_app.views import Home + 2. Add a URL to urlpatterns: path('', Home.as_view(), name='home') +Including another URLconf + 1. Import the include() function_lib: from django.urls import include, path + 2. Add a URL to urlpatterns: path('blog/', include('blog.urls')) +""" +import os + +from django.http import HttpResponse +from django.urls import path, re_path, include +from django.views import static +from rest_framework import status + +from application.urls import urlpatterns as application_urlpatterns +from common.cache_data.static_resource_cache import get_index_html +from common.constants.cache_code_constants import CacheCodeConstants +from common.init.init_doc import init_doc +from common.response.result import Result +from common.util.cache_util import get_cache +from smartdoc import settings +from smartdoc.conf import PROJECT_DIR + +urlpatterns = [ + path("api/", include("users.urls")), + path("api/", include("dataset.urls")), + path("api/", include("setting.urls")), + path("api/", include("application.urls")), + path("api/", include("function_lib.urls")) +] + + +def pro(): + # 暴露静态主要是swagger资源 + urlpatterns.append( + re_path(r'^static/(?P.*)$', static.serve, {'document_root': settings.STATIC_ROOT}, name='static'), + ) + # 暴露ui静态资源 + urlpatterns.append( + re_path(r'^ui/(?P.*)$', static.serve, {'document_root': os.path.join(settings.STATIC_ROOT, "ui")}, + name='ui'), + ) + + +if not settings.DEBUG: + pro() + + +def page_not_found(request, exception): + """ + 页面不存在处理 + """ + if request.path.startswith("/api/"): + return Result(response_status=status.HTTP_404_NOT_FOUND, code=404, message="找不到接口") + index_path = os.path.join(PROJECT_DIR, 'apps', "static", 'ui', 'index.html') + if not os.path.exists(index_path): + return HttpResponse("页面不存在", status=404) + content = get_index_html(index_path) + if 
request.path.startswith('/ui/chat/'): + return HttpResponse(content, status=200) + return HttpResponse(content, status=200, headers={'X-Frame-Options': 'DENY'}) + + +handler404 = page_not_found +init_doc(urlpatterns, application_urlpatterns) diff --git a/src/MaxKB-1.7.2/apps/smartdoc/wsgi.py b/src/MaxKB-1.7.2/apps/smartdoc/wsgi.py new file mode 100644 index 0000000..6c7c681 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/smartdoc/wsgi.py @@ -0,0 +1,28 @@ +""" +WSGI config for apps project. + +It exposes the WSGI callable as a module-level variable named ``application``. + +For more information on this file, see +https://docs.djangoproject.com/en/4.2/howto/deployment/wsgi/ +""" + +import os + +from django.core.wsgi import get_wsgi_application + +os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'smartdoc.settings') + +application = get_wsgi_application() + + +def post_handler(): + from common import event + from common import job + from common.models.db_model_manage import DBModelManage + event.run() + job.run() + DBModelManage.init() + + +post_handler() diff --git a/src/MaxKB-1.7.2/apps/users/__init__.py b/src/MaxKB-1.7.2/apps/users/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MaxKB-1.7.2/apps/users/apps.py b/src/MaxKB-1.7.2/apps/users/apps.py new file mode 100644 index 0000000..8e08561 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/users/apps.py @@ -0,0 +1,9 @@ +from django.apps import AppConfig + + +class UsersConfig(AppConfig): + default_auto_field = 'django.db.models.BigAutoField' + name = 'users' + + def ready(self): + from ops.celery import signal_handler diff --git a/src/MaxKB-1.7.2/apps/users/migrations/0001_initial.py b/src/MaxKB-1.7.2/apps/users/migrations/0001_initial.py new file mode 100644 index 0000000..9565efa --- /dev/null +++ b/src/MaxKB-1.7.2/apps/users/migrations/0001_initial.py @@ -0,0 +1,44 @@ +# Generated by Django 4.1.10 on 2024-03-18 16:02 + +from django.db import migrations, models +import uuid + +from 
common.constants.permission_constants import RoleConstants +from users.models import password_encrypt + + +def insert_default_data(apps, schema_editor): + UserModel = apps.get_model('users', 'User') + UserModel.objects.create(id='f0dd8f71-e4ee-11ee-8c84-a8a1595801ab', email='', username='admin', + nick_name="系统管理员", + password=password_encrypt('MaxKB@123..'), + role=RoleConstants.ADMIN.name, + is_active=True) + + +class Migration(migrations.Migration): + initial = True + + dependencies = [ + ] + + operations = [ + migrations.CreateModel( + name='User', + fields=[ + ('id', models.UUIDField(default=uuid.uuid1, editable=False, primary_key=True, serialize=False, + verbose_name='主键id')), + ('email', models.EmailField(max_length=254, unique=True, verbose_name='邮箱')), + ('phone', models.CharField(default='', max_length=20, verbose_name='电话')), + ('nick_name', models.CharField(default='', max_length=150, verbose_name='昵称')), + ('username', models.CharField(max_length=150, unique=True, verbose_name='用户名')), + ('password', models.CharField(max_length=150, verbose_name='密码')), + ('role', models.CharField(max_length=150, verbose_name='角色')), + ('is_active', models.BooleanField(default=True)), + ], + options={ + 'db_table': 'user', + }, + ), + migrations.RunPython(insert_default_data) + ] diff --git a/src/MaxKB-1.7.2/apps/users/migrations/0002_user_create_time_user_update_time.py b/src/MaxKB-1.7.2/apps/users/migrations/0002_user_create_time_user_update_time.py new file mode 100644 index 0000000..68baae0 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/users/migrations/0002_user_create_time_user_update_time.py @@ -0,0 +1,23 @@ +# Generated by Django 4.1.13 on 2024-03-20 12:27 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('users', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='user', + name='create_time', + field=models.DateTimeField(auto_now_add=True, null=True, verbose_name='创建时间'), + ), + 
migrations.AddField( + model_name='user', + name='update_time', + field=models.DateTimeField(auto_now=True, null=True, verbose_name='修改时间'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/users/migrations/0003_user_source.py b/src/MaxKB-1.7.2/apps/users/migrations/0003_user_source.py new file mode 100644 index 0000000..7292cc1 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/users/migrations/0003_user_source.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.13 on 2024-07-11 19:16 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('users', '0002_user_create_time_user_update_time'), + ] + + operations = [ + migrations.AddField( + model_name='user', + name='source', + field=models.CharField(default='LOCAL', max_length=10, verbose_name='来源'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/users/migrations/0004_alter_user_email.py b/src/MaxKB-1.7.2/apps/users/migrations/0004_alter_user_email.py new file mode 100644 index 0000000..c77416b --- /dev/null +++ b/src/MaxKB-1.7.2/apps/users/migrations/0004_alter_user_email.py @@ -0,0 +1,18 @@ +# Generated by Django 4.2.13 on 2024-07-16 17:03 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('users', '0003_user_source'), + ] + + operations = [ + migrations.AlterField( + model_name='user', + name='email', + field=models.EmailField(blank=True, max_length=254, null=True, unique=True, verbose_name='邮箱'), + ), + ] diff --git a/src/MaxKB-1.7.2/apps/users/migrations/__init__.py b/src/MaxKB-1.7.2/apps/users/migrations/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/MaxKB-1.7.2/apps/users/models/__init__.py b/src/MaxKB-1.7.2/apps/users/models/__init__.py new file mode 100644 index 0000000..da7106a --- /dev/null +++ b/src/MaxKB-1.7.2/apps/users/models/__init__.py @@ -0,0 +1,9 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: __init__.py + @date:2023/9/4 10:08 + @desc: +""" +from .user import * \ No 
newline at end of file diff --git a/src/MaxKB-1.7.2/apps/users/models/user.py b/src/MaxKB-1.7.2/apps/users/models/user.py new file mode 100644 index 0000000..b16f073 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/users/models/user.py @@ -0,0 +1,85 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: users.py + @date:2023/9/4 10:09 + @desc: +""" +import hashlib +import os +import uuid + +from django.db import models + +from common.constants.permission_constants import Permission, Group, Operate +from common.db.sql_execute import select_list +from common.mixins.app_model_mixin import AppModelMixin +from common.util.file_util import get_file_content +from smartdoc.conf import PROJECT_DIR + +__all__ = ["User", "password_encrypt", 'get_user_dynamics_permission'] + + +def password_encrypt(raw_password): + """ + 密码 md5加密 + :param raw_password: 密码 + :return: 加密后密码 + """ + md5 = hashlib.md5() # 2,实例化md5() 方法 + md5.update(raw_password.encode()) # 3,对字符串的字节类型加密 + result = md5.hexdigest() # 4,加密 + return result + + +def to_dynamics_permission(group_type: str, operate: list[str], dynamic_tag: str): + """ + 转换为权限对象 + :param group_type: 分组类型 + :param operate: 操作 + :param dynamic_tag: 标记 + :return: 权限列表 + """ + return [Permission(group=Group[group_type], operate=Operate[o], dynamic_tag=dynamic_tag) + for o in operate] + + +def get_user_dynamics_permission(user_id: str): + """ + 获取 应用和数据集权限 + :param user_id: 用户id + :return: 用户 应用和数据集权限 + """ + member_permission_list = select_list( + get_file_content(os.path.join(PROJECT_DIR, "apps", "setting", 'sql', 'get_user_permission.sql')), + [user_id, user_id, user_id]) + result = [] + for member_permission in member_permission_list: + result += to_dynamics_permission(member_permission.get('type'), member_permission.get('operate'), + str(member_permission.get('id'))) + return result + + +class User(AppModelMixin): + id = models.UUIDField(primary_key=True, max_length=128, default=uuid.uuid1, editable=False, verbose_name="主键id") + email = 
models.EmailField(unique=True, null=True, blank=True, verbose_name="邮箱") + phone = models.CharField(max_length=20, verbose_name="电话", default="") + nick_name = models.CharField(max_length=150, verbose_name="昵称", default="") + username = models.CharField(max_length=150, unique=True, verbose_name="用户名") + password = models.CharField(max_length=150, verbose_name="密码") + role = models.CharField(max_length=150, verbose_name="角色") + source = models.CharField(max_length=10, verbose_name="来源", default="LOCAL") + is_active = models.BooleanField(default=True) + create_time = models.DateTimeField(verbose_name="创建时间", auto_now_add=True, null=True) + update_time = models.DateTimeField(verbose_name="修改时间", auto_now=True, null=True) + + USERNAME_FIELD = 'username' + REQUIRED_FIELDS = [] + + class Meta: + db_table = "user" + + def set_password(self, raw_password): + self.password = password_encrypt(raw_password) + self._password = raw_password diff --git a/src/MaxKB-1.7.2/apps/users/serializers/user_serializers.py b/src/MaxKB-1.7.2/apps/users/serializers/user_serializers.py new file mode 100644 index 0000000..e8d6963 --- /dev/null +++ b/src/MaxKB-1.7.2/apps/users/serializers/user_serializers.py @@ -0,0 +1,785 @@ +# coding=utf-8 +""" + @project: qabot + @Author:虎 + @file: team_serializers.py + @date:2023/9/5 16:32 + @desc: +""" +import datetime +import os +import random +import re +import uuid + +from django.conf import settings +from django.core import validators, signing, cache +from django.core.mail import send_mail +from django.core.mail.backends.smtp import EmailBackend +from django.db import transaction +from django.db.models import Q, QuerySet +from drf_yasg import openapi +from rest_framework import serializers + +from application.models import Application +from common.constants.authentication_type import AuthenticationType +from common.constants.exception_code_constants import ExceptionCodeConstants +from common.constants.permission_constants import RoleConstants, 
class SystemSerializer(ApiMixin, serializers.Serializer):
    """Read-only information about the running MaxKB instance."""

    @staticmethod
    def get_profile():
        """
        Build the system profile.

        :return: dict with ``version`` (from the ``MAXKB_VERSION`` environment variable),
                 ``IS_XPACK`` (whether settings define ``IS_XPACK``) and
                 ``XPACK_LICENSE_IS_VALID`` (``False`` when no ``xpack_cache`` model is registered)
        """
        maxkb_version = os.environ.get('MAXKB_VERSION')
        xpack_cache = DBModelManage.get_model('xpack_cache')
        profile = {'version': maxkb_version, 'IS_XPACK': hasattr(settings, 'IS_XPACK')}
        if xpack_cache is None:
            profile['XPACK_LICENSE_IS_VALID'] = False
        else:
            profile['XPACK_LICENSE_IS_VALID'] = xpack_cache.get('XPACK_LICENSE_IS_VALID', False)
        return profile

    @staticmethod
    def get_response_body_api():
        """Swagger schema of the profile response."""
        version_schema = openapi.Schema(type=openapi.TYPE_STRING, title="系统版本号", description="系统版本号")
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=[],
            properties={
                'version': version_schema,
            }
        )


class LoginSerializer(ApiMixin, serializers.Serializer):
    """Credentials check for username-or-email plus password login."""
    username = serializers.CharField(required=True,
                                     error_messages=ErrMessage.char("用户名"))

    password = serializers.CharField(required=True, error_messages=ErrMessage.char("密码"))

    def is_valid(self, *, raise_exception=False):
        """
        Validate the request and look up the matching user.

        :param raise_exception: ignored; field validation always raises on failure
        :return: the matching ``User`` (not a boolean)
        :raises: the INCORRECT_USERNAME_AND_PASSWORD exception when nothing matches,
                 ``AppApiException(1005)`` when the user is disabled
        """
        super().is_valid(raise_exception=True)
        login_name = self.data.get("username")
        hashed_password = password_encrypt(self.data.get("password"))
        # The login name may be either the username or the email address.
        credentials = Q(username=login_name, password=hashed_password) | Q(email=login_name,
                                                                           password=hashed_password)
        matched_user = QuerySet(User).filter(credentials).first()
        if matched_user is None:
            raise ExceptionCodeConstants.INCORRECT_USERNAME_AND_PASSWORD.value.to_app_api_exception()
        if not matched_user.is_active:
            raise AppApiException(1005, "用户已被禁用,请联系管理员!")
        return matched_user

    def get_user_token(self):
        """
        Validate the credentials and issue a signed token.

        :return: signed token carrying username, id, email and the USER authentication type
        """
        user = self.is_valid()
        payload = {'username': user.username, 'id': str(user.id), 'email': user.email,
                   'type': AuthenticationType.USER.value}
        return signing.dumps(payload)

    class Meta:
        model = User
        fields = '__all__'

    def get_request_body_api(self):
        """Swagger schema of the login request body."""
        properties = {
            'username': openapi.Schema(type=openapi.TYPE_STRING, title="用户名", description="用户名"),
            'password': openapi.Schema(type=openapi.TYPE_STRING, title="密码", description="密码")
        }
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['username', 'password'],
            properties=properties
        )

    def get_response_body_api(self):
        """Swagger schema of the login response (the token string)."""
        token_schema = openapi.Schema(
            type=openapi.TYPE_STRING,
            title="token",
            default="xxxx",
            description="认证token"
        )
        return get_api_response(token_schema)
"(?![0-9_!@#$%^&*`~()-+=]+$)[a-zA-Z0-9_!@#$%^&*`~.()-+=]{6,20}$") + , message="密码长度6-20个字符,必须字母、数字、特殊字符组合")]) + + re_password = serializers.CharField(required=True, + error_messages=ErrMessage.char("确认密码"), + validators=[validators.RegexValidator(regex=re.compile( + "^(?![a-zA-Z]+$)(?![A-Z0-9]+$)(?![A-Z_!@#$%^&*`~.()-+=]+$)(?![a-z0-9]+$)(?![a-z_!@#$%^&*`~()-+=]+$)" + "(?![0-9_!@#$%^&*`~()-+=]+$)[a-zA-Z0-9_!@#$%^&*`~.()-+=]{6,20}$") + , message="确认密码长度6-20个字符,必须字母、数字、特殊字符组合")]) + + code = serializers.CharField(required=True, error_messages=ErrMessage.char("验证码")) + + class Meta: + model = User + fields = '__all__' + + @lock(lock_key=lambda this, raise_exception: ( + this.initial_data.get("email") + ":register" + + )) + def is_valid(self, *, raise_exception=False): + super().is_valid(raise_exception=True) + if self.data.get('password') != self.data.get('re_password'): + raise ExceptionCodeConstants.PASSWORD_NOT_EQ_RE_PASSWORD.value.to_app_api_exception() + username = self.data.get("username") + email = self.data.get("email") + code = self.data.get("code") + code_cache_key = email + ":register" + cache_code = user_cache.get(code_cache_key) + if code != cache_code: + raise ExceptionCodeConstants.CODE_ERROR.value.to_app_api_exception() + u = QuerySet(User).filter(Q(username=username) | Q(email=email)).first() + if u is not None: + if u.email == email: + raise ExceptionCodeConstants.EMAIL_IS_EXIST.value.to_app_api_exception() + if u.username == username: + raise ExceptionCodeConstants.USERNAME_IS_EXIST.value.to_app_api_exception() + + return True + + @valid_license(model=User, count=2, + message='社区版最多支持 2 个用户,如需拥有更多用户,请联系我们(https://fit2cloud.com/)。') + @transaction.atomic + def save(self, **kwargs): + m = User( + **{'id': uuid.uuid1(), 'email': self.data.get("email"), 'username': self.data.get("username"), + 'role': RoleConstants.USER.name}) + m.set_password(self.data.get("password")) + # 插入用户 + m.save() + # 初始化用户团队 + Team(**{'user': m, 'name': m.username + 
class CheckCodeSerializer(ApiMixin, serializers.Serializer):
    """
    Checks an e-mailed verification code against the cached one.
    """
    email = serializers.EmailField(
        required=True,
        error_messages=ErrMessage.char("邮箱"),
        validators=[validators.EmailValidator(message=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.message,
                                              code=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.code)])
    code = serializers.CharField(required=True, error_messages=ErrMessage.char("验证码"))

    # The alternation is grouped so both anchors apply to both alternatives.
    # "^register|reset_password$" parses as "(^register)|(reset_password$)" and
    # accepted values such as "registerXYZ".
    type = serializers.CharField(required=True,
                                 error_messages=ErrMessage.char("类型"),
                                 validators=[
                                     validators.RegexValidator(regex=re.compile("^(?:register|reset_password)$"),
                                                               message="类型只支持register|reset_password", code=500)
                                 ])

    def is_valid(self, *, raise_exception=False):
        """
        Validate the fields and compare the submitted code with the cached code.

        :param raise_exception: ignored; validation failures always raise, as in the
                                other serializers of this module
        :return: ``True`` when the code matches
        :raises: a validation error for malformed input, the CODE_ERROR exception
                 when no code is cached or the code differs
        """
        # Raise on invalid fields; otherwise a missing email would reach the
        # string concatenation below and fail with a TypeError.
        super().is_valid(raise_exception=True)
        cache_key = self.data.get("email") + ":" + self.data.get("type")
        cached_code = user_cache.get(cache_key)
        if cached_code is None or cached_code != self.data.get("code"):
            raise ExceptionCodeConstants.CODE_ERROR.value.to_app_api_exception()
        return True

    class Meta:
        model = User
        fields = '__all__'

    def get_request_body_api(self):
        """Swagger schema of the request body."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['email', 'code', 'type'],
            properties={
                'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址"),
                'code': openapi.Schema(type=openapi.TYPE_STRING, title="验证码", description="验证码"),
                'type': openapi.Schema(type=openapi.TYPE_STRING, title="类型", description="register|reset_password")
            }
        )

    def get_response_body_api(self):
        """Swagger schema of the response (boolean success flag)."""
        return get_api_response(openapi.Schema(
            type=openapi.TYPE_BOOLEAN,
            title="是否成功",
            default=True,
            description="错误提示"))
class SendEmailSerializer(ApiMixin, serializers.Serializer):
    """Sends a six-digit verification code by e-mail for registration or password reset."""
    email = serializers.EmailField(
        required=True
        , error_messages=ErrMessage.char("邮箱"),
        validators=[validators.EmailValidator(message=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.message,
                                              code=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.code)])

    # Grouped alternation: the ungrouped "^register|reset_password$" also matched
    # strings that merely start with "register" or end with "reset_password".
    type = serializers.CharField(required=True, error_messages=ErrMessage.char("类型"), validators=[
        validators.RegexValidator(regex=re.compile("^(?:register|reset_password)$"),
                                  message="类型只支持register|reset_password", code=500)
    ])

    class Meta:
        model = User
        fields = '__all__'

    def is_valid(self, *, raise_exception=False):
        """
        Validate the request and enforce the resend lock.

        :param raise_exception: forwarded to the base field validation
        :return: ``True`` when a mail may be sent, ``False`` when field validation
                 failed and ``raise_exception`` is false
        :raises: EMAIL_IS_NOT_EXIST / EMAIL_IS_EXIST depending on ``type``,
                 ``AppApiException(500)`` while the resend lock is still cached
        """
        # Stop on invalid fields instead of concatenating possibly-missing values below.
        if not super().is_valid(raise_exception=raise_exception):
            return False
        email = self.data.get('email')
        mail_type = self.data.get("type")
        user_exists = QuerySet(User).filter(email=email).exists()
        if not user_exists and mail_type == 'reset_password':
            raise ExceptionCodeConstants.EMAIL_IS_NOT_EXIST.value.to_app_api_exception()
        elif user_exists and mail_type == 'register':
            raise ExceptionCodeConstants.EMAIL_IS_EXIST.value.to_app_api_exception()
        code_cache_key_lock = email + ":" + mail_type + "_lock"
        # NOTE(review): user_cache.ttl presumably returns the remaining lifetime as a
        # timedelta, or None when the key is absent. Confirm against the cache backend.
        ttl = user_cache.ttl(code_cache_key_lock)
        if ttl is not None:
            raise AppApiException(500, f"{ttl.total_seconds()}秒内请勿重复发送邮件")
        return True

    def send(self):
        """
        Generate a code, e-mail it and cache it.

        A one-minute lock key is cached before sending and removed again when sending
        fails; the code itself is cached for 30 minutes after a successful send.

        :return: ``True`` when the mail was sent
        :raises AppApiException: 1004 when no e-mail setting exists, 500 when sending fails
        """
        email = self.data.get("email")
        state = self.data.get("type")
        # The code is a secret, so draw it from the OS entropy source.
        secure_random = random.SystemRandom()
        code = "".join(secure_random.choice("0123456789") for _ in range(6))
        # Load the mail template; the context manager closes the file even on error.
        template_path = os.path.join(PROJECT_DIR, "apps", "common", 'template', 'email_template.html')
        with open(template_path, "r", encoding='utf-8') as file:
            content = file.read()
        code_cache_key = email + ":" + state
        code_cache_key_lock = code_cache_key + "_lock"
        # Resend lock
        user_cache.set(code_cache_key_lock, code, timeout=datetime.timedelta(minutes=1))
        system_setting = QuerySet(SystemSetting).filter(type=SettingType.EMAIL.value).first()
        if system_setting is None:
            user_cache.delete(code_cache_key_lock)
            raise AppApiException(1004, "邮箱服务未设置,请联系管理员到【邮箱设置】中设置邮箱服务。")
        try:
            connection = EmailBackend(system_setting.meta.get("email_host"),
                                      system_setting.meta.get('email_port'),
                                      system_setting.meta.get('email_host_user'),
                                      system_setting.meta.get('email_host_password'),
                                      system_setting.meta.get('email_use_tls'),
                                      False,
                                      system_setting.meta.get('email_use_ssl')
                                      )
            # Send the mail
            send_mail(f'【智能知识库问答系统-{"用户注册" if state == "register" else "修改密码"}】',
                      '',
                      html_message=f'{content.replace("${code}", code)}',
                      from_email=system_setting.meta.get('from_email'),
                      recipient_list=[email], fail_silently=False, connection=connection)
        except Exception as e:
            # Release the lock so the user can retry immediately.
            user_cache.delete(code_cache_key_lock)
            raise AppApiException(500, f"{str(e)}邮件发送失败") from e
        user_cache.set(code_cache_key, code, timeout=datetime.timedelta(minutes=30))
        return True

    def get_request_body_api(self):
        """Swagger schema of the request body."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['email', 'type'],
            properties={
                'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址"),
                'type': openapi.Schema(type=openapi.TYPE_STRING, title="类型", description="register|reset_password")
            }
        )

    def get_response_body_api(self):
        """Swagger schema of the response."""
        return get_api_response(openapi.Schema(type=openapi.TYPE_STRING, default=True))


class UserProfile(ApiMixin):
    """Profile of the current user, including the resolved permission list."""

    @staticmethod
    def get_user_profile(user: User):
        """
        Build the profile of a user.

        :param user: user object
        :return: dict with id, username, email, role, the stringified permissions
                 (dynamic permissions plus role permissions) and ``is_edit_password``
        """
        permission_list = get_user_dynamics_permission(str(user.id))
        permission_list += [p.value for p in get_permission_list_by_role(RoleConstants[user.role])]
        # NOTE(review): the literal is presumably the MD5 of the factory-default admin
        # password, i.e. the flag asks the admin to change it. Confirm.
        return {'id': user.id, 'username': user.username, 'email': user.email, 'role': user.role,
                'permissions': [str(p) for p in permission_list],
                'is_edit_password': user.password == 'd880e722c47a34d8e9fce789fc62389d' if user.role == 'ADMIN' else False}

    @staticmethod
    def get_response_body_api():
        """Swagger schema of the profile response."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['id', 'username', 'email', 'role', 'is_active'],
            properties={
                'id': openapi.Schema(type=openapi.TYPE_STRING, title="用户id", description="用户id"),
                'username': openapi.Schema(type=openapi.TYPE_STRING, title="用户名", description="用户名"),
                'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址"),
                'role': openapi.Schema(type=openapi.TYPE_STRING, title="角色", description="角色"),
                'is_active': openapi.Schema(type=openapi.TYPE_STRING, title="是否可用", description="是否可用"),
                "permissions": openapi.Schema(type=openapi.TYPE_ARRAY, title="权限列表", description="权限列表",
                                              items=openapi.Schema(type=openapi.TYPE_STRING))
            }
        )
# Password policy shared by the user-management serializers below
# (6-20 characters from the allowed set, rejecting the single-category combinations
# listed in the look-aheads). The pattern text is unchanged.
_USER_MANAGE_PASSWORD_REGEX = re.compile(
    "^(?![a-zA-Z]+$)(?![A-Z0-9]+$)(?![A-Z_!@#$%^&*`~.()-+=]+$)(?![a-z0-9]+$)(?![a-z_!@#$%^&*`~()-+=]+$)"
    "(?![0-9_!@#$%^&*`~()-+=]+$)[a-zA-Z0-9_!@#$%^&*`~.()-+=]{6,20}$")


class UserInstanceSerializer(ApiMixin, serializers.ModelSerializer):
    """Model serializer for the user fields exposed by the management API."""

    class Meta:
        model = User
        fields = ['id', 'username', 'email', 'phone', 'is_active', 'role', 'nick_name', 'create_time', 'update_time',
                  'source']

    @staticmethod
    def get_response_body_api():
        """Swagger schema of a serialized user."""
        return openapi.Schema(
            type=openapi.TYPE_OBJECT,
            required=['id', 'username', 'email', 'phone', 'is_active', 'role', 'nick_name', 'create_time',
                      'update_time'],
            properties={
                'id': openapi.Schema(type=openapi.TYPE_STRING, title="用户id", description="用户id"),
                'username': openapi.Schema(type=openapi.TYPE_STRING, title="用户名", description="用户名"),
                'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址"),
                'phone': openapi.Schema(type=openapi.TYPE_STRING, title="手机号", description="手机号"),
                'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否激活", description="是否激活"),
                'role': openapi.Schema(type=openapi.TYPE_STRING, title="角色", description="角色"),
                'source': openapi.Schema(type=openapi.TYPE_STRING, title="来源", description="来源"),
                'nick_name': openapi.Schema(type=openapi.TYPE_STRING, title="姓名", description="姓名"),
                # The description previously repeated the update-time text (copy-paste slip).
                'create_time': openapi.Schema(type=openapi.TYPE_STRING, title="创建时间", description="创建时间"),
                'update_time': openapi.Schema(type=openapi.TYPE_STRING, title="修改时间", description="修改时间")
            }
        )

    @staticmethod
    def get_request_params_api():
        """Swagger description of the ``user_id`` path parameter."""
        return [openapi.Parameter(name='user_id',
                                  in_=openapi.IN_PATH,
                                  type=openapi.TYPE_STRING,
                                  required=True,
                                  description='用户名id')

                ]


class UserManageSerializer(serializers.Serializer):
    """User administration: query, create, edit, delete and reset passwords."""

    class Query(ApiMixin, serializers.Serializer):
        """Lists or pages users, optionally filtered by a username / email substring."""
        email_or_username = serializers.CharField(required=False, allow_null=True,
                                                  error_messages=ErrMessage.char("邮箱或者用户名"))

        @staticmethod
        def get_request_params_api():
            """Swagger description of the optional ``email_or_username`` query parameter."""
            return [openapi.Parameter(name='email_or_username',
                                      in_=openapi.IN_QUERY,
                                      type=openapi.TYPE_STRING,
                                      required=False,
                                      description='邮箱或者用户名')]

        @staticmethod
        def get_response_body_api():
            """Swagger schema of one list entry."""
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['username', 'email', 'id'],
                properties={
                    'id': openapi.Schema(type=openapi.TYPE_STRING, title='用户主键id', description="用户主键id"),
                    'username': openapi.Schema(type=openapi.TYPE_STRING, title="用户名", description="用户名"),
                    'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址")
                }
            )

        def get_query_set(self):
            """
            Build the user query, newest first.

            :return: queryset filtered by substring match on username or email when a
                     filter value is present
            """
            email_or_username = self.data.get('email_or_username')
            query_set = QuerySet(User)
            if email_or_username is not None:
                query_set = query_set.filter(
                    Q(username__contains=email_or_username) | Q(email__contains=email_or_username))
            return query_set.order_by("-create_time")

        def list(self, with_valid=True):
            """
            :param with_valid: validate the query parameters first
            :return: list of ``{'id', 'username', 'email'}`` dicts
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            return [{'id': user_model.id, 'username': user_model.username, 'email': user_model.email} for user_model in
                    self.get_query_set()]

        def page(self, current_page: int, page_size: int, with_valid=True):
            """
            :param current_page: page number
            :param page_size: records per page
            :param with_valid: validate the query parameters first
            :return: result of ``page_search`` with every record serialized by ``UserInstanceSerializer``
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            return page_search(current_page, page_size,
                               self.get_query_set(),
                               post_records_handler=lambda u: UserInstanceSerializer(u).data)

    class UserInstance(ApiMixin, serializers.Serializer):
        """Payload for creating a user."""
        email = serializers.EmailField(
            required=True,
            error_messages=ErrMessage.char("邮箱"),
            validators=[validators.EmailValidator(message=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.message,
                                                  code=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.code)])

        username = serializers.CharField(required=True,
                                         error_messages=ErrMessage.char("用户名"),
                                         max_length=20,
                                         min_length=6,
                                         validators=[
                                             validators.RegexValidator(regex=re.compile("^.{6,20}$"),
                                                                       message="用户名字符数为 6-20 个字符")
                                         ])
        password = serializers.CharField(required=True, error_messages=ErrMessage.char("密码"),
                                         validators=[validators.RegexValidator(
                                             regex=_USER_MANAGE_PASSWORD_REGEX,
                                             message="密码长度6-20个字符,必须字母、数字、特殊字符组合")])

        nick_name = serializers.CharField(required=False, error_messages=ErrMessage.char("姓名"), max_length=64,
                                          allow_null=True, allow_blank=True)
        phone = serializers.CharField(required=False, error_messages=ErrMessage.char("手机号"), max_length=20,
                                      allow_null=True, allow_blank=True)

        def is_valid(self, *, raise_exception=True):
            """
            Validate the fields and reject duplicate email / username.

            :param raise_exception: ignored; failures always raise
            :return: ``True`` (previously ``None``, which is falsy for ``if x.is_valid()`` callers)
            :raises: EMAIL_IS_EXIST or USERNAME_IS_EXIST exception on duplicates
            """
            super().is_valid(raise_exception=True)
            username = self.data.get('username')
            email = self.data.get('email')
            u = QuerySet(User).filter(Q(username=username) | Q(email=email)).first()
            if u is not None:
                if u.email == email:
                    raise ExceptionCodeConstants.EMAIL_IS_EXIST.value.to_app_api_exception()
                if u.username == username:
                    raise ExceptionCodeConstants.USERNAME_IS_EXIST.value.to_app_api_exception()
            return True

        @staticmethod
        def get_request_body_api():
            """Swagger schema of the create-user request body."""
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['username', 'email', 'password'],
                properties={
                    'username': openapi.Schema(type=openapi.TYPE_STRING, title="用户名", description="用户名"),
                    'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱地址"),
                    'password': openapi.Schema(type=openapi.TYPE_STRING, title="密码", description="密码"),
                    'phone': openapi.Schema(type=openapi.TYPE_STRING, title="手机号", description="手机号"),
                    'nick_name': openapi.Schema(type=openapi.TYPE_STRING, title="姓名", description="姓名")
                }
            )

    class UserEditInstance(ApiMixin, serializers.Serializer):
        """Payload for editing a user; every field is optional."""
        email = serializers.EmailField(
            required=False,
            error_messages=ErrMessage.char("邮箱"),
            validators=[validators.EmailValidator(message=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.message,
                                                  code=ExceptionCodeConstants.EMAIL_FORMAT_ERROR.value.code)])

        nick_name = serializers.CharField(required=False, error_messages=ErrMessage.char("姓名"), max_length=64,
                                          allow_null=True, allow_blank=True)
        phone = serializers.CharField(required=False, error_messages=ErrMessage.char("手机号"), max_length=20,
                                      allow_null=True, allow_blank=True)
        is_active = serializers.BooleanField(required=False, error_messages=ErrMessage.char("是否可用"))

        def is_valid(self, *, user_id=None, raise_exception=False):
            """
            Validate the fields and make sure the email is not used by another user.

            :param user_id: id of the user being edited (excluded from the uniqueness check)
            :param raise_exception: ignored; failures always raise
            :return: ``True``
            :raises AppApiException: 1004 when the email belongs to another user
            """
            super().is_valid(raise_exception=True)
            email = self.data.get('email')
            if email is not None and QuerySet(User).filter(email=email).exclude(id=user_id).exists():
                raise AppApiException(1004, "邮箱已经被使用")
            return True

        @staticmethod
        def get_request_body_api():
            """Swagger schema of the edit-user request body."""
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                properties={
                    'email': openapi.Schema(type=openapi.TYPE_STRING, title="邮箱", description="邮箱"),
                    'nick_name': openapi.Schema(type=openapi.TYPE_STRING, title="姓名", description="姓名"),
                    'phone': openapi.Schema(type=openapi.TYPE_STRING, title="手机号", description="手机号"),
                    'is_active': openapi.Schema(type=openapi.TYPE_BOOLEAN, title="是否可用", description="是否可用"),
                }
            )

    class RePasswordInstance(ApiMixin, serializers.Serializer):
        """Payload for an administrator resetting a user's password."""
        password = serializers.CharField(required=True, error_messages=ErrMessage.char("密码"),
                                         validators=[validators.RegexValidator(
                                             regex=_USER_MANAGE_PASSWORD_REGEX,
                                             message="密码长度6-20个字符,必须字母、数字、特殊字符组合")])
        re_password = serializers.CharField(required=True, error_messages=ErrMessage.char("确认密码"),
                                            validators=[validators.RegexValidator(
                                                regex=_USER_MANAGE_PASSWORD_REGEX,
                                                message="确认密码长度6-20个字符,必须字母、数字、特殊字符组合")]
                                            )

        @staticmethod
        def get_request_body_api():
            """Swagger schema of the reset-password request body."""
            return openapi.Schema(
                type=openapi.TYPE_OBJECT,
                required=['password', 're_password'],
                properties={
                    'password': openapi.Schema(type=openapi.TYPE_STRING, title="密码", description="密码"),
                    're_password': openapi.Schema(type=openapi.TYPE_STRING, title="确认密码",
                                                  description="确认密码"),
                }
            )

        def is_valid(self, *, raise_exception=False):
            """
            Validate both fields and require them to be equal.

            :param raise_exception: ignored; failures always raise
            :return: ``True``
            :raises: PASSWORD_NOT_EQ_RE_PASSWORD exception when the two values differ
            """
            super().is_valid(raise_exception=True)
            if self.data.get('password') != self.data.get('re_password'):
                raise ExceptionCodeConstants.PASSWORD_NOT_EQ_RE_PASSWORD.value.to_app_api_exception()
            return True

    @valid_license(model=User, count=2,
                   message='社区版最多支持 2 个用户,如需拥有更多用户,请联系我们(https://fit2cloud.com/)。')
    @transaction.atomic
    def save(self, instance, with_valid=True):
        """
        Create a local, active user with the USER role together with its team.

        :param instance: dict with username, email, password and optional phone / nick_name
        :param with_valid: validate ``instance`` with ``UserInstance`` first
        :return: serialized user
        """
        if with_valid:
            UserManageSerializer.UserInstance(data=instance).is_valid(raise_exception=True)

        user = User(id=uuid.uuid1(), email=instance.get('email'),
                    phone="" if instance.get('phone') is None else instance.get('phone'),
                    nick_name="" if instance.get('nick_name') is None else instance.get('nick_name')
                    , username=instance.get('username'), password=password_encrypt(instance.get('password')),
                    role=RoleConstants.USER.name, source="LOCAL",
                    is_active=True)
        user.save()
        # Initialise the user's team
        Team(**{'user': user, 'name': user.username + '的团队'}).save()
        return UserInstanceSerializer(user).data

    class Operate(serializers.Serializer):
        """Operations on one existing user, addressed by id."""
        id = serializers.UUIDField(required=True, error_messages=ErrMessage.char("用户id"))

        def is_valid(self, *, raise_exception=False):
            """
            Validate the id and make sure the user exists.

            :param raise_exception: ignored; failures always raise
            :return: ``True``
            :raises AppApiException: 1004 when the user does not exist
            """
            super().is_valid(raise_exception=True)
            if not QuerySet(User).filter(id=self.data.get('id')).exists():
                raise AppApiException(1004, "用户不存在")
            return True

        def _get_existing_user(self):
            """Fetch the addressed user, raising the same 1004 error as ``is_valid`` when it is missing."""
            user = QuerySet(User).filter(id=self.data.get('id')).first()
            if user is None:
                raise AppApiException(1004, "用户不存在")
            return user

        @transaction.atomic
        def delete(self, with_valid=True):
            """
            Delete a non-admin user and the data owned by that user.

            :param with_valid: validate the id first
            :return: ``True``
            :raises AppApiException: 1004 when the user is missing or is an administrator
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            # Guards the with_valid=False path, which used to fail with AttributeError on None.
            user = self._get_existing_user()
            if user.role == RoleConstants.ADMIN.name:
                raise AppApiException(1004, "无法删除管理员")
            user_id = self.data.get('id')

            team_member_list = QuerySet(TeamMember).filter(Q(user_id=user_id) | Q(team_id=user_id))
            # Delete team member permissions
            QuerySet(TeamMemberPermission).filter(
                member_id__in=[team_member.id for team_member in team_member_list]).delete()
            # Delete team members
            team_member_list.delete()
            # Delete applications; related rows are removed by cascade
            QuerySet(Application).filter(user_id=user_id).delete()
            # Delete datasets and everything that belongs to them
            dataset_list = QuerySet(DataSet).filter(user_id=user_id)
            dataset_id_list = [str(dataset.id) for dataset in dataset_list]
            QuerySet(Document).filter(dataset_id__in=dataset_id_list).delete()
            QuerySet(Paragraph).filter(dataset_id__in=dataset_id_list).delete()
            QuerySet(ProblemParagraphMapping).filter(dataset_id__in=dataset_id_list).delete()
            QuerySet(Problem).filter(dataset_id__in=dataset_id_list).delete()
            delete_embedding_by_dataset_id_list(dataset_id_list)
            dataset_list.delete()
            # Delete the team
            QuerySet(Team).filter(user_id=user_id).delete()
            # Delete models
            QuerySet(Model).filter(user_id=user_id).delete()
            # Delete the user
            QuerySet(User).filter(id=user_id).delete()
            return True

        def edit(self, instance, with_valid=True):
            """
            Update email, nick_name, phone and is_active from ``instance``; keys that
            are absent or ``None`` are left untouched.

            :param instance: dict with the new values
            :param with_valid: validate the id and ``instance`` first
            :return: serialized user
            :raises AppApiException: 1004 when the user is missing or ``is_active`` is set on an administrator
            """
            if with_valid:
                self.is_valid(raise_exception=True)
                UserManageSerializer.UserEditInstance(data=instance).is_valid(user_id=self.data.get('id'),
                                                                              raise_exception=True)

            user = self._get_existing_user()
            if user.role == RoleConstants.ADMIN.name and 'is_active' in instance and instance.get(
                    'is_active') is not None:
                raise AppApiException(1004, "不能修改管理员状态")
            update_keys = ['email', 'nick_name', 'phone', 'is_active']
            for update_key in update_keys:
                if update_key in instance and instance.get(update_key) is not None:
                    setattr(user, update_key, instance.get(update_key))
            user.save()
            return UserInstanceSerializer(user).data

        def one(self, with_valid=True):
            """
            :param with_valid: validate the id first
            :return: serialized user
            """
            if with_valid:
                self.is_valid(raise_exception=True)
            user = QuerySet(User).filter(id=self.data.get('id')).first()
            return UserInstanceSerializer(user).data

        def re_password(self, instance, with_valid=True):
            """
            Set a new password for the user.

            :param instance: dict with ``password`` and ``re_password``
            :param with_valid: validate the id and ``instance`` first
            :return: ``True``
            :raises AppApiException: 1004 when the user is missing
            """
            if with_valid:
                self.is_valid(raise_exception=True)
                UserManageSerializer.RePasswordInstance(data=instance).is_valid(raise_exception=True)
            user = self._get_existing_user()
            user.password = password_encrypt(instance.get('password'))
            user.save()
            return True
app_name = "user"
# Routes of the user app.
# The three parameterised user_manage routes had lost their "<...>" path converters
# (they read "user_manage/", "user_manage//re_password" and "user_manage//"), so the
# views could never receive their arguments.
# NOTE(review): the converter names below are inferred from the serializers
# (path parameter 'user_id'; page(current_page, page_size)). Confirm that they
# match the keyword arguments of the view methods.
urlpatterns = [
    path('profile', views.Profile.as_view()),
    path('user', views.User.as_view(), name="profile"),
    path('user/list', views.User.Query.as_view()),
    path('user/login', views.Login.as_view(), name='login'),
    path('user/logout', views.Logout.as_view(), name='logout'),
    # path('user/register', views.Register.as_view(), name="register"),
    path("user/send_email", views.SendEmail.as_view(), name='send_email'),
    path("user/check_code", views.CheckCode.as_view(), name='check_code'),
    path("user/re_password", views.RePasswordView.as_view(), name='re_password'),
    path("user/current/send_email", views.SendEmailToCurrentUserView.as_view(), name="send_email_current"),
    path("user/current/reset_password", views.ResetCurrentUserPasswordView.as_view(), name="reset_password_current"),
    path("user_manage", views.UserManage.as_view(), name="user_manage"),
    path("user_manage/<str:user_id>", views.UserManage.Operate.as_view(), name="user_manage_operate"),
    path("user_manage/<str:user_id>/re_password", views.UserManage.RePassword.as_view(),
         name="user_manage_re_password"),
    # NOTE(review): this name duplicates the route above; it is kept so existing
    # reverse() callers keep working. Consider a distinct name such as "user_manage_page".
    path("user_manage/<int:current_page>/<int:page_size>", views.UserManage.Page.as_view(),
         name="user_manage_re_password"),
]
class Profile(APIView):
    """Public endpoint returning system information (version, xpack flags)."""

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="获取MaxKB相关信息",
                         operation_id="获取MaxKB相关信息",
                         responses=result.get_api_response(SystemSerializer.get_response_body_api()),
                         tags=['系统参数'])
    def get(self, request: Request):
        return result.success(SystemSerializer.get_profile())


class User(APIView):
    """Endpoints about the authenticated user."""
    authentication_classes = [TokenAuth]

    @action(methods=['GET'], detail=False)
    @swagger_auto_schema(operation_summary="获取当前用户信息",
                         operation_id="获取当前用户信息",
                         responses=result.get_api_response(UserProfile.get_response_body_api()),
                         tags=['用户'])
    @has_permissions(PermissionConstants.USER_READ)
    def get(self, request: Request):
        """Return the profile of the authenticated user."""
        return result.success(UserProfile.get_user_profile(request.user))

    class Query(APIView):
        """Looks up users by exact username or email."""
        authentication_classes = [TokenAuth]

        @action(methods=['GET'], detail=False)
        @swagger_auto_schema(operation_summary="获取用户列表",
                             operation_id="获取用户列表",
                             manual_parameters=UserSerializer.Query.get_request_params_api(),
                             responses=result.get_api_array_response(UserSerializer.Query.get_response_body_api()),
                             tags=['用户'])
        @has_permissions(PermissionConstants.USER_READ)
        def get(self, request: Request):
            """Return the users matching the ``email_or_username`` query parameter."""
            return result.success(
                UserSerializer.Query(data={'email_or_username': request.query_params.get('email_or_username')}).list())


class ResetCurrentUserPasswordView(APIView):
    """Lets the authenticated user change their own password with an e-mailed code."""
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @swagger_auto_schema(operation_summary="修改当前用户密码",
                         operation_id="修改当前用户密码",
                         request_body=openapi.Schema(
                             type=openapi.TYPE_OBJECT,
                             # 'email' is not part of the body: it is always taken from the session user.
                             required=['code', "password", 're_password'],
                             properties={
                                 'code': openapi.Schema(type=openapi.TYPE_STRING, title="验证码", description="验证码"),
                                 'password': openapi.Schema(type=openapi.TYPE_STRING, title="密码", description="密码"),
                                 're_password': openapi.Schema(type=openapi.TYPE_STRING, title="密码",
                                                               description="密码")
                             }
                         ),
                         responses=RePasswordSerializer().get_response_body_api(),
                         tags=['用户'])
    def post(self, request: Request):
        """
        Reset the password of the authenticated user and drop the token cached
        under the request's Authorization header.
        """
        data = {}
        data.update(request.data)
        # Set the email last so a client-supplied "email" cannot redirect the reset
        # to another account; this endpoint only ever acts on the current user.
        data['email'] = request.user.email
        serializer_obj = RePasswordSerializer(data=data)
        if serializer_obj.reset_password():
            token_cache.delete(request.META.get('HTTP_AUTHORIZATION'))
            return result.success(True)
        return result.error("修改密码失败")


class SendEmailToCurrentUserView(APIView):
    """Sends a password-reset verification code to the authenticated user's email."""
    authentication_classes = [TokenAuth]

    @action(methods=['POST'], detail=False)
    @permission_classes((AllowAny,))
    @swagger_auto_schema(operation_summary="发送邮件到当前用户",
                         operation_id="发送邮件到当前用户",
                         responses=SendEmailSerializer().get_response_body_api(),
                         tags=['用户'])
    def post(self, request: Request):
        """Validate the request (raising on failure) and send the reset-password mail."""
        serializer_obj = SendEmailSerializer(data={'email': request.user.email, 'type': "reset_password"})
        if serializer_obj.is_valid(raise_exception=True):
            return result.success(serializer_obj.send())
= [TokenAuth] + + @action(methods=['POST'], detail=False) + @permission_classes((AllowAny,)) + @swagger_auto_schema(operation_summary="登出", + operation_id="登出", + responses=SendEmailSerializer().get_response_body_api(), + tags=['用户']) + def post(self, request: Request): + token_cache.delete(request.META.get('HTTP_AUTHORIZATION')) + return result.success(True) + + +class Login(APIView): + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="登录", + operation_id="登录", + request_body=LoginSerializer().get_request_body_api(), + responses=LoginSerializer().get_response_body_api(), + security=[], + tags=['用户']) + def post(self, request: Request): + login_request = LoginSerializer(data=request.data) + # 校验请求参数 + user = login_request.is_valid(raise_exception=True) + token = login_request.get_user_token() + token_cache.set(token, user, timeout=JWT_AUTH['JWT_EXPIRATION_DELTA']) + return result.success(token) + + +class Register(APIView): + + @action(methods=['POST'], detail=False) + @permission_classes((AllowAny,)) + @swagger_auto_schema(operation_summary="用户注册", + operation_id="用户注册", + request_body=RegisterSerializer().get_request_body_api(), + responses=RegisterSerializer().get_response_body_api(), + security=[], + tags=['用户']) + def post(self, request: Request): + serializer_obj = RegisterSerializer(data=request.data) + if serializer_obj.is_valid(raise_exception=True): + serializer_obj.save() + return result.success("注册成功") + + +class RePasswordView(APIView): + + @action(methods=['POST'], detail=False) + @permission_classes((AllowAny,)) + @swagger_auto_schema(operation_summary="修改密码", + operation_id="修改密码", + request_body=RePasswordSerializer().get_request_body_api(), + responses=RePasswordSerializer().get_response_body_api(), + security=[], + tags=['用户']) + def post(self, request: Request): + serializer_obj = RePasswordSerializer(data=request.data) + return result.success(serializer_obj.reset_password()) + + +class CheckCode(APIView): + + 
@action(methods=['POST'], detail=False) + @permission_classes((AllowAny,)) + @swagger_auto_schema(operation_summary="校验验证码是否正确", + operation_id="校验验证码是否正确", + request_body=CheckCodeSerializer().get_request_body_api(), + responses=CheckCodeSerializer().get_response_body_api(), + security=[], + tags=['用户']) + def post(self, request: Request): + return result.success(CheckCodeSerializer(data=request.data).is_valid(raise_exception=True)) + + +class SendEmail(APIView): + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="发送邮件", + operation_id="发送邮件", + request_body=SendEmailSerializer().get_request_body_api(), + responses=SendEmailSerializer().get_response_body_api(), + security=[], + tags=['用户']) + def post(self, request: Request): + serializer_obj = SendEmailSerializer(data=request.data) + if serializer_obj.is_valid(raise_exception=True): + return result.success(serializer_obj.send()) + + +class UserManage(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['POST'], detail=False) + @swagger_auto_schema(operation_summary="添加用户", + operation_id="添加用户", + request_body=UserManageSerializer.UserInstance.get_request_body_api(), + responses=result.get_api_response(UserInstanceSerializer.get_response_body_api()), + tags=["用户管理"] + ) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN], + [PermissionConstants.USER_READ], + compare=CompareConstants.AND)) + def post(self, request: Request): + return result.success(UserManageSerializer().save(request.data)) + + class Page(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取用户分页列表", + operation_id="获取用户分页列表", + tags=["用户管理"], + manual_parameters=UserManageSerializer.Query.get_request_params_api(), + responses=result.get_page_api_response(UserInstanceSerializer.get_response_body_api()), + ) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN], + [PermissionConstants.USER_READ], + 
compare=CompareConstants.AND)) + def get(self, request: Request, current_page, page_size): + d = UserManageSerializer.Query( + data={'email_or_username': request.query_params.get('email_or_username', None), + 'user_id': str(request.user.id)}) + return result.success(d.page(current_page, page_size)) + + class RePassword(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改密码", + operation_id="修改密码", + manual_parameters=UserInstanceSerializer.get_request_params_api(), + request_body=UserManageSerializer.RePasswordInstance.get_request_body_api(), + responses=result.get_default_response(), + tags=["用户管理"]) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN], + [PermissionConstants.USER_READ], + compare=CompareConstants.AND)) + def put(self, request: Request, user_id): + return result.success( + UserManageSerializer.Operate(data={'id': user_id}).re_password(request.data, with_valid=True)) + + class Operate(APIView): + authentication_classes = [TokenAuth] + + @action(methods=['DELETE'], detail=False) + @swagger_auto_schema(operation_summary="删除用户", + operation_id="删除用户", + manual_parameters=UserInstanceSerializer.get_request_params_api(), + responses=result.get_default_response(), + tags=["用户管理"]) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN], + [PermissionConstants.USER_READ], + compare=CompareConstants.AND)) + def delete(self, request: Request, user_id): + return result.success(UserManageSerializer.Operate(data={'id': user_id}).delete(with_valid=True)) + + @action(methods=['GET'], detail=False) + @swagger_auto_schema(operation_summary="获取用户信息", + operation_id="获取用户信息", + manual_parameters=UserInstanceSerializer.get_request_params_api(), + responses=result.get_api_response(UserInstanceSerializer.get_response_body_api()), + tags=["用户管理"] + ) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN], + [PermissionConstants.USER_READ], + compare=CompareConstants.AND)) + 
def get(self, request: Request, user_id): + return result.success(UserManageSerializer.Operate(data={'id': user_id}).one(with_valid=True)) + + @action(methods=['PUT'], detail=False) + @swagger_auto_schema(operation_summary="修改用户信息", + operation_id="修改用户信息", + manual_parameters=UserInstanceSerializer.get_request_params_api(), + request_body=UserManageSerializer.UserEditInstance.get_request_body_api(), + responses=result.get_api_response(UserInstanceSerializer.get_response_body_api()), + tags=["用户管理"] + ) + @has_permissions(ViewPermission( + [RoleConstants.ADMIN], + [PermissionConstants.USER_READ], + compare=CompareConstants.AND)) + def put(self, request: Request, user_id): + return result.success( + UserManageSerializer.Operate(data={'id': user_id}).edit(request.data, with_valid=True)) diff --git a/src/MaxKB-1.7.2/config_example.yml b/src/MaxKB-1.7.2/config_example.yml new file mode 100644 index 0000000..e262de1 --- /dev/null +++ b/src/MaxKB-1.7.2/config_example.yml @@ -0,0 +1,11 @@ +# 数据库链接信息 +DB_NAME: maxkb +DB_HOST: localhost +DB_PORT: 5432 +DB_USER: root +DB_PASSWORD: xxxxxxx +DB_ENGINE: django.db.backends.postgresql_psycopg2 + +DEBUG: false + +TIME_ZONE: Asia/Shanghai diff --git a/src/MaxKB-1.7.2/installer/Dockerfile b/src/MaxKB-1.7.2/installer/Dockerfile new file mode 100644 index 0000000..c990f0a --- /dev/null +++ b/src/MaxKB-1.7.2/installer/Dockerfile @@ -0,0 +1,73 @@ +FROM ghcr.io/1panel-dev/maxkb-vector-model:v1.0.1 AS vector-model +FROM node:18-alpine3.18 AS web-build +COPY ui ui +RUN cd ui && \ + npm install && \ + npm run build && \ + rm -rf ./node_modules +FROM ghcr.io/1panel-dev/maxkb-python-pg:python3.11-pg15.8 AS stage-build + +ARG DEPENDENCIES=" \ + python3-pip" + +RUN apt-get update && \ + apt-get install -y --no-install-recommends $DEPENDENCIES && \ + apt-get clean all && \ + rm -rf /var/lib/apt/lists/* + +COPY . 
/opt/maxkb/app +RUN mkdir -p /opt/maxkb/app /opt/maxkb/model /opt/maxkb/conf && \ + rm -rf /opt/maxkb/app/ui + +COPY --from=web-build ui /opt/maxkb/app/ui +WORKDIR /opt/maxkb/app +RUN python3 -m venv /opt/py3 && \ + pip install poetry --break-system-packages && \ + poetry config virtualenvs.create false && \ + . /opt/py3/bin/activate && \ + if [ "$(uname -m)" = "x86_64" ]; then sed -i 's/^torch.*/torch = {version = "^2.2.1+cpu", source = "pytorch"}/g' pyproject.toml; fi && \ + poetry install + +FROM ghcr.io/1panel-dev/maxkb-python-pg:python3.11-pg15.8 +ARG DOCKER_IMAGE_TAG=dev \ + BUILD_AT \ + GITHUB_COMMIT + +ENV MAXKB_VERSION="${DOCKER_IMAGE_TAG} (build at ${BUILD_AT}, commit: ${GITHUB_COMMIT})" \ + MAXKB_CONFIG_TYPE=ENV \ + MAXKB_DB_NAME=maxkb \ + MAXKB_DB_HOST=127.0.0.1 \ + MAXKB_DB_PORT=5432 \ + MAXKB_DB_USER=root \ + MAXKB_DB_PASSWORD=Password123@postgres \ + MAXKB_EMBEDDING_MODEL_NAME=/opt/maxkb/model/embedding/shibing624_text2vec-base-chinese \ + MAXKB_EMBEDDING_MODEL_PATH=/opt/maxkb/model/embedding \ + MAXKB_SANDBOX=true \ + LANG=en_US.UTF-8 \ + PATH=/opt/py3/bin:$PATH \ + POSTGRES_USER=root \ + POSTGRES_PASSWORD=Password123@postgres \ + PIP_TARGET=/opt/maxkb/app/sandbox/python-packages \ + PYTHONPATH=/opt/maxkb/app/sandbox/python-packages \ + PYTHONUNBUFFERED=1 + +WORKDIR /opt/maxkb/app +COPY --from=stage-build /opt/maxkb /opt/maxkb +COPY --from=stage-build /opt/py3 /opt/py3 +COPY --from=vector-model /opt/maxkb/app/model /opt/maxkb/model + +RUN chmod 755 /opt/maxkb/app/installer/run-maxkb.sh && \ + cp -r /opt/maxkb/model/base/hub /opt/maxkb/model/tokenizer && \ + cp -f /opt/maxkb/app/installer/run-maxkb.sh /usr/bin/run-maxkb.sh && \ + cp -f /opt/maxkb/app/installer/init.sql /docker-entrypoint-initdb.d && \ + mkdir -p /opt/maxkb/app/sandbox/python-packages && \ + find /opt/maxkb/app -mindepth 1 -not -name 'sandbox' -exec chmod 700 {} + && \ + chmod 755 /tmp && \ + useradd --no-create-home --home /opt/maxkb/app/sandbox --shell /bin/bash sandbox && \ + chown 
sandbox:sandbox /opt/maxkb/app/sandbox + + +EXPOSE 8080 + +ENTRYPOINT ["bash", "-c"] +CMD [ "/usr/bin/run-maxkb.sh" ] diff --git a/src/MaxKB-1.7.2/installer/Dockerfile-python-pg b/src/MaxKB-1.7.2/installer/Dockerfile-python-pg new file mode 100644 index 0000000..0cd1ebb --- /dev/null +++ b/src/MaxKB-1.7.2/installer/Dockerfile-python-pg @@ -0,0 +1,18 @@ +FROM python:3.11-slim-bullseye AS python-stage +FROM postgres:15.8-bullseye + +ARG DEPENDENCIES=" \ + libexpat1-dev \ + libffi-dev \ + curl \ + ca-certificates \ + vim \ + postgresql-15-pgvector" + +RUN ln -sf /usr/share/zoneinfo/Asia/Shanghai /etc/localtime && \ + echo "Asia/Shanghai" > /etc/timezone && \ + apt-get update && apt-get install -y --no-install-recommends $DEPENDENCIES && \ + apt-get clean all && \ + rm -rf /var/lib/apt/lists/* + +COPY --from=python-stage /usr/local /usr/local \ No newline at end of file diff --git a/src/MaxKB-1.7.2/installer/Dockerfile-vector-model b/src/MaxKB-1.7.2/installer/Dockerfile-vector-model new file mode 100644 index 0000000..a732661 --- /dev/null +++ b/src/MaxKB-1.7.2/installer/Dockerfile-vector-model @@ -0,0 +1,10 @@ +FROM python:3.11-slim-bookworm AS vector-model + +COPY installer/install_model.py install_model.py +RUN pip3 install --upgrade pip setuptools && \ + pip install pycrawlers && \ + pip install transformers && \ + python3 install_model.py + +FROM scratch +COPY --from=vector-model model /opt/maxkb/app/model \ No newline at end of file diff --git a/src/MaxKB-1.7.2/installer/config.yaml b/src/MaxKB-1.7.2/installer/config.yaml new file mode 100644 index 0000000..c9f45db --- /dev/null +++ b/src/MaxKB-1.7.2/installer/config.yaml @@ -0,0 +1,20 @@ +# 邮箱配置 +EMAIL_ADDRESS: ${EMAIL_ADDRESS} +EMAIL_USE_TLS: ${EMAIL_USE_TLS} +EMAIL_USE_SSL: ${EMAIL_USE_SSL} +EMAIL_HOST: ${EMAIL_HOST} +EMAIL_PORT: ${EMAIL_PORT} +EMAIL_HOST_USER: ${EMAIL_HOST_USER} +EMAIL_HOST_PASSWORD: ${EMAIL_HOST_PASSWORD} + +# 数据库链接信息 +DB_NAME: maxkb +DB_HOST: 127.0.0.1 +DB_PORT: 5432 +DB_USER: root 
+DB_PASSWORD: Password123@postgres +DB_ENGINE: django.db.backends.postgresql_psycopg2 +EMBEDDING_MODEL_PATH: /opt/maxkb/model/embedding +EMBEDDING_MODEL_NAME: /opt/maxkb/model/embedding/shibing624_text2vec-base-chinese + +DEBUG: false \ No newline at end of file diff --git a/src/MaxKB-1.7.2/installer/init.sql b/src/MaxKB-1.7.2/installer/init.sql new file mode 100644 index 0000000..dfc30f9 --- /dev/null +++ b/src/MaxKB-1.7.2/installer/init.sql @@ -0,0 +1,5 @@ +CREATE DATABASE "maxkb"; + +\c "maxkb"; + +CREATE EXTENSION "vector"; \ No newline at end of file diff --git a/src/MaxKB-1.7.2/installer/install_model.py b/src/MaxKB-1.7.2/installer/install_model.py new file mode 100644 index 0000000..fb46461 --- /dev/null +++ b/src/MaxKB-1.7.2/installer/install_model.py @@ -0,0 +1,69 @@ +# coding=utf-8 +""" + @project: maxkb + @Author:虎 + @file: install_model.py + @date:2023/12/18 14:02 + @desc: +""" +import json +import os.path +from pycrawlers import huggingface +from transformers import GPT2TokenizerFast +hg = huggingface() +prefix_dir = "./model" +model_config = [ + { + 'download_params': { + 'cache_dir': os.path.join(prefix_dir, 'base/hub'), + 'pretrained_model_name_or_path': 'gpt2' + }, + 'download_function': GPT2TokenizerFast.from_pretrained + }, + { + 'download_params': { + 'cache_dir': os.path.join(prefix_dir, 'base/hub'), + 'pretrained_model_name_or_path': 'gpt2-medium' + }, + 'download_function': GPT2TokenizerFast.from_pretrained + }, + { + 'download_params': { + 'cache_dir': os.path.join(prefix_dir, 'base/hub'), + 'pretrained_model_name_or_path': 'gpt2-large' + }, + 'download_function': GPT2TokenizerFast.from_pretrained + }, + { + 'download_params': { + 'cache_dir': os.path.join(prefix_dir, 'base/hub'), + 'pretrained_model_name_or_path': 'gpt2-xl' + }, + 'download_function': GPT2TokenizerFast.from_pretrained + }, + { + 'download_params': { + 'cache_dir': os.path.join(prefix_dir, 'base/hub'), + 'pretrained_model_name_or_path': 'distilgpt2' + }, + 
'download_function': GPT2TokenizerFast.from_pretrained + }, + { + 'download_params': { + 'urls': ["https://huggingface.co/shibing624/text2vec-base-chinese/tree/main"], + 'file_save_paths': [os.path.join(prefix_dir, 'embedding',"shibing624_text2vec-base-chinese")] + }, + 'download_function': hg.get_batch_data + } + +] + + +def install(): + for model in model_config: + print(json.dumps(model.get('download_params'))) + model.get('download_function')(**model.get('download_params')) + + +if __name__ == '__main__': + install() diff --git a/src/MaxKB-1.7.2/installer/run-maxkb.sh b/src/MaxKB-1.7.2/installer/run-maxkb.sh new file mode 100644 index 0000000..43374df --- /dev/null +++ b/src/MaxKB-1.7.2/installer/run-maxkb.sh @@ -0,0 +1,10 @@ +#!/bin/bash +rm -f /opt/maxkb/app/tmp/*.pid +# Start postgresql +docker-entrypoint.sh postgres & +sleep 10 +# Wait postgresql +until pg_isready --host=127.0.0.1; do sleep 1 && echo "waiting for postgres"; done + +# Start MaxKB +python /opt/maxkb/app/main.py start \ No newline at end of file diff --git a/src/MaxKB-1.7.2/installer/start-maxkb.sh b/src/MaxKB-1.7.2/installer/start-maxkb.sh new file mode 100644 index 0000000..4e88eff --- /dev/null +++ b/src/MaxKB-1.7.2/installer/start-maxkb.sh @@ -0,0 +1,3 @@ +#!/bin/bash +rm -f /opt/maxkb/app/tmp/*.pid +python /opt/maxkb/app/main.py start \ No newline at end of file diff --git a/src/MaxKB-1.7.2/main.py b/src/MaxKB-1.7.2/main.py new file mode 100644 index 0000000..a8bd74a --- /dev/null +++ b/src/MaxKB-1.7.2/main.py @@ -0,0 +1,122 @@ +import argparse +import logging +import os +import sys +import time + +import django +from django.core import management + +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) +APP_DIR = os.path.join(BASE_DIR, 'apps') + +os.chdir(BASE_DIR) +sys.path.insert(0, APP_DIR) +os.environ.setdefault("DJANGO_SETTINGS_MODULE", "smartdoc.settings") +django.setup() + + +def collect_static(): + """ + 收集静态文件到指定目录 + 本项目主要是将前端vue/dist的前端项目放到静态目录下面 + :return: + """ + 
logging.info("Collect static files") + try: + management.call_command('collectstatic', '--no-input', '-c', verbosity=0, interactive=False) + logging.info("Collect static files done") + except Exception: # best-effort step: report the failure instead of hiding it, then continue + logging.warning("Collect static files failed", exc_info=True) + + +def perform_db_migrate(): + """ + Apply pending database migrations; log the error and exit with status 11 on failure. + """ + logging.info("Check database structure change ...") + logging.info("Migrate model change to database ...") + try: + management.call_command('migrate') + except Exception: + logging.error('Perform migrate failed, exit', exc_info=True) + sys.exit(11) + + +def start_services(): # Build CLI flags from the parsed args and run the management command named by `action` + services = args.services if isinstance(args.services, list) else [args.services] + start_args = [] + if args.daemon: + start_args.append('--daemon') + if args.force: + start_args.append('--force') + if args.worker: + start_args.extend(['--worker', str(args.worker)]) + else: + worker = os.environ.get('CORE_WORKER') + if isinstance(worker, str) and worker.isdigit(): + start_args.extend(['--worker', worker]) + + try: + management.call_command(action, *services, *start_args) + except KeyboardInterrupt: + logging.info('Cancel ...') + time.sleep(2) + except Exception as exc: + logging.error("Start service error {}: {}".format(services, exc)) + time.sleep(2) + + +def dev(): # Run exactly one development command, chosen by the first matching service name + services = args.services if isinstance(args.services, list) else [args.services] # normalise to a list, as start_services does + if 'web' in services: + management.call_command('runserver', "0.0.0.0:8080") + elif 'celery' in services: + management.call_command('celery', 'celery') + elif 'local_model' in services: + os.environ.setdefault('SERVER_NAME', 'local_model') + from smartdoc.const import CONFIG + bind = f'{CONFIG.get("LOCAL_MODEL_HOST")}:{CONFIG.get("LOCAL_MODEL_PORT")}' + management.call_command('runserver', bind) + + +if __name__ == '__main__': + os.environ['HF_HOME'] = '/opt/maxkb/model/base' + parser = argparse.ArgumentParser( + description=""" + qabot service control tools; + + Example: \r\n + + %(prog)s start all -d; + """ + ) + parser.add_argument( + 'action', type=str, 
choices=("start", "dev", "upgrade_db", "collect_static"), + help="Action to run" + ) + args, e = parser.parse_known_args() + parser.add_argument( + "services", type=str, default='all' if args.action == 'start' else 'web', nargs="*", + choices=("all", "web", "task") if args.action == 'start' else ("web", "celery", 'local_model'), + help="The service to start", + ) + + parser.add_argument('-d', '--daemon', nargs="?", const=True) + parser.add_argument('-w', '--worker', type=int, nargs="?") + parser.add_argument('-f', '--force', nargs="?", const=True) + args = parser.parse_args() + action = args.action + if action == "upgrade_db": + perform_db_migrate() + elif action == "collect_static": + collect_static() + elif action == 'dev': + collect_static() + perform_db_migrate() + dev() + else: + collect_static() + perform_db_migrate() + start_services() diff --git a/src/MaxKB-1.7.2/package-lock.json b/src/MaxKB-1.7.2/package-lock.json new file mode 100644 index 0000000..d70a5c3 --- /dev/null +++ b/src/MaxKB-1.7.2/package-lock.json @@ -0,0 +1,6 @@ +{ + "name": "MaxKB", + "lockfileVersion": 3, + "requires": true, + "packages": {} +} diff --git a/src/MaxKB-1.7.2/pyproject.toml b/src/MaxKB-1.7.2/pyproject.toml new file mode 100644 index 0000000..24bbcb0 --- /dev/null +++ b/src/MaxKB-1.7.2/pyproject.toml @@ -0,0 +1,64 @@ +[tool.poetry] +name = "maxkb" +version = "0.1.0" +description = "智能知识库问答系统" +authors = ["shaohuzhang1 "] +readme = "README.md" + +[tool.poetry.dependencies] +python = ">=3.11,<3.12" +django = "4.2.15" +djangorestframework = "^3.15.2" +drf-yasg = "1.21.7" +django-filter = "23.2" +langchain = "0.2.16" +langchain_community = "0.2.17" +langchain-huggingface = "^0.0.3" +psycopg2-binary = "2.9.7" +jieba = "^0.42.1" +diskcache = "^5.6.3" +pillow = "^10.2.0" +filetype = "^1.2.0" +torch = "2.2.1" +sentence-transformers = "^2.2.2" +openai = "^1.13.3" +tiktoken = "^0.7.0" +qianfan = "^0.3.6.1" +pycryptodome = "^3.19.0" +beautifulsoup4 = "^4.12.2" +html2text = "^2024.2.26" 
+langchain-openai = "^0.1.8" +django-ipware = "^6.0.4" +django-apscheduler = "^0.6.2" +pymupdf = "1.24.9" +pypdf = "4.3.1" +rapidocr-onnxruntime = "1.3.24" +python-docx = "^1.1.0" +xlwt = "^1.3.0" +dashscope = "^1.17.0" +zhipuai = "^2.0.1" +httpx = "^0.27.0" +httpx-sse = "^0.4.0" +websockets = "^13.0" +langchain-google-genai = "^1.0.3" +openpyxl = "^3.1.2" +xlrd = "^2.0.1" +gunicorn = "^22.0.0" +python-daemon = "3.0.1" +boto3 = "^1.34.160" +tencentcloud-sdk-python = "^3.0.1209" +xinference-client = "^0.14.1.post1" +psutil = "^6.0.0" +celery = { extras = ["sqlalchemy"], version = "^5.4.0" } +django-celery-beat = "^2.6.0" +celery-once = "^3.0.1" +anthropic = "^0.34.2" +pylint = "3.1.0" +[build-system] +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" + +[[tool.poetry.source]] +name = "pytorch" +url = "https://download.pytorch.org/whl/cpu" +priority = "explicit" \ No newline at end of file diff --git a/src/MaxKB-1.7.2/ui/.eslintrc.cjs b/src/MaxKB-1.7.2/ui/.eslintrc.cjs new file mode 100644 index 0000000..d6c3088 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/.eslintrc.cjs @@ -0,0 +1,21 @@ +/* eslint-env node */ +require('@rushstack/eslint-patch/modern-module-resolution') + +module.exports = { + root: true, + 'extends': [ + 'plugin:vue/vue3-essential', + 'eslint:recommended', + '@vue/eslint-config-typescript', + '@vue/eslint-config-prettier/skip-formatting' + ], + parserOptions: { + ecmaVersion: 'latest' + }, + rules: { + // 添加组件命名忽略规则 + "vue/multi-word-component-names": ["error",{ + "ignores": ["index","main"]//需要忽略的组件名 + }] + } +} diff --git a/src/MaxKB-1.7.2/ui/.gitignore b/src/MaxKB-1.7.2/ui/.gitignore new file mode 100644 index 0000000..38adffa --- /dev/null +++ b/src/MaxKB-1.7.2/ui/.gitignore @@ -0,0 +1,28 @@ +# Logs +logs +*.log +npm-debug.log* +yarn-debug.log* +yarn-error.log* +pnpm-debug.log* +lerna-debug.log* + +node_modules +.DS_Store +dist +dist-ssr +coverage +*.local + +/cypress/videos/ +/cypress/screenshots/ + +# Editor directories and files 
+.vscode/* +!.vscode/extensions.json +.idea +*.suo +*.ntvs* +*.njsproj +*.sln +*.sw? diff --git a/src/MaxKB-1.7.2/ui/.prettierrc.json b/src/MaxKB-1.7.2/ui/.prettierrc.json new file mode 100644 index 0000000..66e2335 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/.prettierrc.json @@ -0,0 +1,8 @@ +{ + "$schema": "https://json.schemastore.org/prettierrc", + "semi": false, + "tabWidth": 2, + "singleQuote": true, + "printWidth": 100, + "trailingComma": "none" +} \ No newline at end of file diff --git a/src/MaxKB-1.7.2/ui/README.md b/src/MaxKB-1.7.2/ui/README.md new file mode 100644 index 0000000..12c6c8c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/README.md @@ -0,0 +1,52 @@ +# web + +This template should help get you started developing with Vue 3 in Vite. + +## Recommended IDE Setup + +[VSCode](https://code.visualstudio.com/) + [Volar](https://marketplace.visualstudio.com/items?itemName=Vue.volar) (and disable Vetur) + [TypeScript Vue Plugin (Volar)](https://marketplace.visualstudio.com/items?itemName=Vue.vscode-typescript-vue-plugin). + +## Type Support for `.vue` Imports in TS + +TypeScript cannot handle type information for `.vue` imports by default, so we replace the `tsc` CLI with `vue-tsc` for type checking. In editors, we need [TypeScript Vue Plugin (Volar)](https://marketplace.visualstudio.com/items?itemName=Vue.vscode-typescript-vue-plugin) to make the TypeScript language service aware of `.vue` types. + +If the standalone TypeScript plugin doesn't feel fast enough to you, Volar has also implemented a [Take Over Mode](https://github.com/johnsoncodehk/volar/discussions/471#discussioncomment-1361669) that is more performant. You can enable it by the following steps: + +1. Disable the built-in TypeScript Extension + 1) Run `Extensions: Show Built-in Extensions` from VSCode's command palette + 2) Find `TypeScript and JavaScript Language Features`, right click and select `Disable (Workspace)` +2. 
Reload the VSCode window by running `Developer: Reload Window` from the command palette. + +## Customize configuration + +See [Vite Configuration Reference](https://vitejs.dev/config/). + +## Project Setup + +```sh +npm install +``` + +### Compile and Hot-Reload for Development + +```sh +npm run dev +``` + +### Type-Check, Compile and Minify for Production + +```sh +npm run build +``` + +### Run Unit Tests with [Vitest](https://vitest.dev/) + +```sh +npm run test:unit +``` + +### Lint with [ESLint](https://eslint.org/) + +```sh +npm run lint +``` diff --git a/src/MaxKB-1.7.2/ui/env.d.ts b/src/MaxKB-1.7.2/ui/env.d.ts new file mode 100644 index 0000000..52f5452 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/env.d.ts @@ -0,0 +1,14 @@ +/// +declare module 'element-plus/dist/locale/zh-cn.mjs' +declare module 'markdown-it-task-lists' +declare module 'markdown-it-abbr' +declare module 'markdown-it-anchor' +declare module 'markdown-it-footnote' +declare module 'markdown-it-sub' +declare module 'markdown-it-sup' +declare module 'markdown-it-toc-done-right' +declare module 'katex' +interface ImportMeta { + readonly env: ImportMetaEnv +} +declare type Recordable = Record; diff --git a/src/MaxKB-1.7.2/ui/index.html b/src/MaxKB-1.7.2/ui/index.html new file mode 100644 index 0000000..09bec9a --- /dev/null +++ b/src/MaxKB-1.7.2/ui/index.html @@ -0,0 +1,18 @@ + + + + + + + + %VITE_APP_TITLE% + + +
+ + + diff --git a/src/MaxKB-1.7.2/ui/package.json b/src/MaxKB-1.7.2/ui/package.json new file mode 100644 index 0000000..e6067f6 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/package.json @@ -0,0 +1,75 @@ +{ + "name": "web", + "version": "v1.0.0", + "private": true, + "scripts": { + "dev": "vite", + "build": "run-p type-check build-only", + "preview": "vite preview", + "test:unit": "vitest", + "build-only": "vite build", + "type-check": "vue-tsc --noEmit -p tsconfig.vitest.json --composite false", + "lint": "eslint . --ext .vue,.js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix --ignore-path .gitignore", + "format": "prettier --write src/" + }, + "dependencies": { + "@codemirror/theme-one-dark": "^6.1.2", + "@ctrl/tinycolor": "^4.1.0", + "@logicflow/core": "^1.2.27", + "@logicflow/extension": "^1.2.27", + "@vueuse/core": "^10.9.0", + "@wecom/jssdk": "^2.1.0", + "axios": "^0.28.0", + "codemirror": "^6.0.1", + "cropperjs": "^1.6.2", + "echarts": "^5.5.0", + "element-plus": "^2.5.6", + "file-saver": "^2.0.5", + "highlight.js": "^11.9.0", + "install": "^0.13.0", + "katex": "^0.16.10", + "lodash": "^4.17.21", + "marked": "^12.0.2", + "md-editor-v3": "^4.16.7", + "medium-zoom": "^1.1.0", + "mermaid": "^10.9.0", + "mitt": "^3.0.0", + "moment": "^2.30.1", + "npm": "^10.2.4", + "nprogress": "^0.2.0", + "pinia": "^2.1.6", + "pinyin-pro": "^3.18.2", + "recorder-core": "^1.3.24040900", + "screenfull": "^6.0.2", + "use-element-plus-theme": "^0.0.5", + "vue": "^3.3.4", + "vue-clipboard3": "^2.0.0", + "vue-codemirror": "^6.1.1", + "vue-i18n": "^9.13.1", + "vue-router": "^4.2.4" + }, + "devDependencies": { + "@rushstack/eslint-patch": "^1.3.2", + "@tsconfig/node18": "^18.2.0", + "@types/file-saver": "^2.0.7", + "@types/jsdom": "^21.1.1", + "@types/node": "^18.17.5", + "@types/nprogress": "^0.2.0", + "@vitejs/plugin-vue": "^4.3.1", + "@vue/eslint-config-prettier": "^8.0.0", + "@vue/eslint-config-typescript": "^11.0.3", + "@vue/test-utils": "^2.4.1", + "@vue/tsconfig": "^0.4.0", + "eslint": 
"^8.46.0", + "eslint-plugin-vue": "^9.16.1", + "jsdom": "^22.1.0", + "npm-run-all": "^4.1.5", + "prettier": "^3.0.0", + "sass": "1.66.1", + "typescript": "~5.1.6", + "unplugin-vue-define-options": "^1.3.18", + "vite": "^4.4.9", + "vitest": "^0.34.2", + "vue-tsc": "^1.8.8" + } +} diff --git a/src/MaxKB-1.7.2/ui/public/MaxKB.gif b/src/MaxKB-1.7.2/ui/public/MaxKB.gif new file mode 100644 index 0000000..f18b93b Binary files /dev/null and b/src/MaxKB-1.7.2/ui/public/MaxKB.gif differ diff --git a/src/MaxKB-1.7.2/ui/public/favicon.ico b/src/MaxKB-1.7.2/ui/public/favicon.ico new file mode 100644 index 0000000..7d9781e Binary files /dev/null and b/src/MaxKB-1.7.2/ui/public/favicon.ico differ diff --git a/src/MaxKB-1.7.2/ui/src/App.vue b/src/MaxKB-1.7.2/ui/src/App.vue new file mode 100644 index 0000000..8664306 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/App.vue @@ -0,0 +1,7 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/api/application-overview.ts b/src/MaxKB-1.7.2/ui/src/api/application-overview.ts new file mode 100644 index 0000000..0513a0d --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/application-overview.ts @@ -0,0 +1,90 @@ +import { Result } from '@/request/Result' +import { get, post, del, put } from '@/request/index' + +import { type Ref } from 'vue' + +const prefix = '/application' + +/** + * API_KEY列表 + * @param 参数 application_id + */ +const getAPIKey: (application_id: string, loading?: Ref) => Promise> = ( + application_id, + loading +) => { + return get(`${prefix}/${application_id}/api_key`, undefined, loading) +} + +/** + * 新增API_KEY + * @param 参数 application_id + */ +const postAPIKey: (application_id: string, loading?: Ref) => Promise> = ( + application_id, + loading +) => { + return post(`${prefix}/${application_id}/api_key`, {}, undefined, loading) +} + +/** + * 删除API_KEY + * @param 参数 application_id api_key_id + */ +const delAPIKey: ( + application_id: String, + api_key_id: String, + loading?: Ref +) => Promise> = (application_id, api_key_id, loading) => { 
+ return del(`${prefix}/${application_id}/api_key/${api_key_id}`, undefined, undefined, loading) +} + +/** + * Update an API key (NOTE(review): the Result payload is typed `any` here because the original type argument is not recoverable from this text; confirm) + * @param application_id, api_key_id + * data { + * is_active: boolean + * } + */ +const putAPIKey: ( + application_id: string, + api_key_id: string, + data: any, + loading?: Ref +) => Promise<Result<any>> = (application_id, api_key_id, data, loading) => { + return put(`${prefix}/${application_id}/api_key/${api_key_id}`, data, undefined, loading) +} + +/** + * Fetch chat-record statistics (aggregate trend) + * @param application_id, data + */ +const getStatistics: ( + application_id: string, + data: any, + loading?: Ref +) => Promise<Result<any>> = (application_id, data, loading) => { + return get(`${prefix}/${application_id}/statistics/chat_record_aggregate_trend`, data, loading) +} + +/** + * Update the application icon + * @param application_id + * data: file + */ +const putAppIcon: ( + application_id: string, + data: any, + loading?: Ref +) => Promise<Result<any>> = (application_id, data, loading) => { + return put(`${prefix}/${application_id}/edit_icon`, data, undefined, loading) +} + +export default { + getAPIKey, + postAPIKey, + delAPIKey, + putAPIKey, + getStatistics, + putAppIcon +} diff --git a/src/MaxKB-1.7.2/ui/src/api/application-xpack.ts b/src/MaxKB-1.7.2/ui/src/api/application-xpack.ts new file mode 100644 index 0000000..25e973f --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/application-xpack.ts @@ -0,0 +1,41 @@ +import { Result } from '@/request/Result' +import { get, put } from '@/request/index' +import { type Ref } from 'vue' + +const prefix = '/application' + +/** + * Replaces the community-edition call: get the AccessToken settings + * @param application_id + */ +const getAccessToken: (application_id: string, loading?: Ref) => Promise<Result<any>> = ( + application_id, + loading +) => { + return get(`${prefix}/${application_id}/setting`, undefined, loading) +} + +/** + * Replaces the community-edition call: update the AccessToken settings + * @param application_id + * data { + * "show_source": boolean, + * "show_history": boolean, + * "draggable": boolean, + * "show_guide": boolean, + * "avatar": file, + * 
"float_icon": file, + * } + */ +const putAccessToken: ( + application_id: string, + data: any, + loading?: Ref +) => Promise> = (application_id, data, loading) => { + return put(`${prefix}/${application_id}/setting`, data, undefined, loading) +} + +export default { + getAccessToken, + putAccessToken +} diff --git a/src/MaxKB-1.7.2/ui/src/api/application.ts b/src/MaxKB-1.7.2/ui/src/api/application.ts new file mode 100644 index 0000000..ac81e41 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/application.ts @@ -0,0 +1,496 @@ +import { Result } from '@/request/Result' +import { get, post, postStream, del, put, request, download } from '@/request/index' +import type { pageRequest } from '@/api/type/common' +import type { ApplicationFormType } from '@/api/type/application' +import { type Ref } from 'vue' +import type { FormField } from '@/components/dynamics-form/type' + +const prefix = '/application' + +/** + * 获取全部应用 + * @param 参数 + */ +const getAllAppilcation: () => Promise> = () => { + return get(`${prefix}`) +} + +/** + * 获取分页应用 + * page { + "current_page": "string", + "page_size": "string", + } + * param { + "name": "string", + } + */ +const getApplication: ( + page: pageRequest, + param: any, + loading?: Ref +) => Promise> = (page, param, loading) => { + return get(`${prefix}/${page.current_page}/${page.page_size}`, param, loading) +} + +/** + * 创建应用 + * @param 参数 + */ +const postApplication: ( + data: ApplicationFormType, + loading?: Ref +) => Promise> = (data, loading) => { + return post(`${prefix}`, data, undefined, loading) +} + +/** + * 修改应用 + * @param 参数 + */ +const putApplication: ( + application_id: String, + data: ApplicationFormType, + loading?: Ref +) => Promise> = (application_id, data, loading) => { + return put(`${prefix}/${application_id}`, data, undefined, loading) +} + +/** + * 删除应用 + * @param 参数 application_id + */ +const delApplication: ( + application_id: String, + loading?: Ref +) => Promise> = (application_id, loading) => { + return 
del(`${prefix}/${application_id}`, undefined, {}, loading) +} + +/** + * 应用详情 + * @param 参数 application_id + */ +const getApplicationDetail: ( + application_id: string, + loading?: Ref +) => Promise> = (application_id, loading) => { + return get(`${prefix}/${application_id}`, undefined, loading) +} + +/** + * 获得当前应用可使用的知识库 + * @param 参数 application_id + */ +const getApplicationDataset: ( + application_id: string, + loading?: Ref +) => Promise> = (application_id, loading) => { + return get(`${prefix}/${application_id}/list_dataset`, undefined, loading) +} + +/** + * 获取AccessToken + * @param 参数 application_id + */ +const getAccessToken: (application_id: string, loading?: Ref) => Promise> = ( + application_id, + loading +) => { + return get(`${prefix}/${application_id}/access_token`, undefined, loading) +} + +/** + * 修改AccessToken + * @param 参数 application_id + * data { + * "is_active": true + * } + */ +const putAccessToken: ( + application_id: string, + data: any, + loading?: Ref +) => Promise> = (application_id, data, loading) => { + return put(`${prefix}/${application_id}/access_token`, data, undefined, loading) +} + +/** + * 应用认证 + * @param 参数 + { + "access_token": "string" + } + */ +const postAppAuthentication: ( + access_token: string, + loading?: Ref, + authentication_value?: any +) => Promise = (access_token, loading, authentication_value) => { + return post( + `${prefix}/authentication`, + { access_token: access_token, authentication_value }, + undefined, + loading + ) +} + +/** + * 对话获取应用相关信息 + * @param 参数 + { + "access_token": "string" + } + */ +const getAppProfile: (loading?: Ref) => Promise = (loading) => { + return get(`${prefix}/profile`, undefined, loading) +} + +/** + * 获得临时回话Id + * @param 参数 + + } + */ +const postChatOpen: (data: ApplicationFormType) => Promise> = (data) => { + return post(`${prefix}/chat/open`, data) +} + +/** + * 获得工作流临时回话Id + * @param 参数 + + } + */ +const postWorkflowChatOpen: (data: ApplicationFormType) => Promise> = (data) => { 
+ return post(`${prefix}/chat_workflow/open`, data) +} + +/** + * 正式回话Id + * @param 参数 + * { + "model_id": "string", + "multiple_rounds_dialogue": true, + "dataset_id_list": [ + "string" + ] + } + */ +const getChatOpen: (application_id: String) => Promise> = (application_id) => { + return get(`${prefix}/${application_id}/chat/open`) +} +/** + * 对话 + * @param 参数 + * chat_id: string + * data + */ +const postChatMessage: (chat_id: string, data: any) => Promise = (chat_id, data) => { + return postStream(`/api${prefix}/chat_message/${chat_id}`, data) +} + +/** + * 点赞、点踩 + * @param 参数 + * application_id : string; chat_id : string; chat_record_id : string + * { + "vote_status": "string", // -1 0 1 + } + */ +const putChatVote: ( + application_id: string, + chat_id: string, + chat_record_id: string, + vote_status: string, + loading?: Ref +) => Promise = (application_id, chat_id, chat_record_id, vote_status, loading) => { + return put( + `${prefix}/${application_id}/chat/${chat_id}/chat_record/${chat_record_id}/vote`, + { + vote_status + }, + undefined, + loading + ) +} + +/** + * 命中测试列表 + * @param application_id + * @param loading + * @query { query_text: string, top_number: number, similarity: number } + * @returns + */ +const getApplicationHitTest: ( + application_id: string, + data: any, + loading?: Ref +) => Promise>> = (application_id, data, loading) => { + return get(`${prefix}/${application_id}/hit_test`, data, loading) +} + +/** + * 获取当前用户可使用的模型列表 + * @param application_id + * @param loading + * @query { query_text: string, top_number: number, similarity: number } + * @returns + */ +const getApplicationModel: ( + application_id: string, + loading?: Ref +) => Promise>> = (application_id, loading) => { + return get(`${prefix}/${application_id}/model`, loading) +} + +/** + * 获取当前用户可使用的模型列表 + * @param application_id + * @param loading + * @query { query_text: string, top_number: number, similarity: number } + * @returns + */ +const getApplicationRerankerModel: ( + 
application_id: string, + loading?: Ref +) => Promise>> = (application_id, loading) => { + return get(`${prefix}/${application_id}/model`, { model_type: 'RERANKER' }, loading) +} + +/** + * 获取当前用户可使用的模型列表 + * @param application_id + * @param loading + * @query { query_text: string, top_number: number, similarity: number } + * @returns + */ +const getApplicationSTTModel: ( + application_id: string, + loading?: Ref +) => Promise>> = (application_id, loading) => { + return get(`${prefix}/${application_id}/model`, { model_type: 'STT' }, loading) +} + +/** + * 获取当前用户可使用的模型列表 + * @param application_id + * @param loading + * @query { query_text: string, top_number: number, similarity: number } + * @returns + */ +const getApplicationTTSModel: ( + application_id: string, + loading?: Ref +) => Promise>> = (application_id, loading) => { + return get(`${prefix}/${application_id}/model`, { model_type: 'TTS' }, loading) +} + +/** + * 发布应用 + * @param 参数 + */ +const putPublishApplication: ( + application_id: String, + data: ApplicationFormType, + loading?: Ref +) => Promise> = (application_id, data, loading) => { + return put(`${prefix}/${application_id}/publish`, data, undefined, loading) +} +/** + * 获取应用所属的函数库列表 + * @param application_id 应用id + * @param loading + * @returns + */ +const listFunctionLib: (application_id: String, loading?: Ref) => Promise> = ( + application_id, + loading +) => { + return get(`${prefix}/${application_id}/function_lib`, undefined, loading) +} +/** + * 获取应用所属的函数库 + * @param application_id + * @param function_lib_id + * @param loading + * @returns + */ +const getFunctionLib: ( + application_id: String, + function_lib_id: String, + loading?: Ref +) => Promise> = (application_id, function_lib_id, loading) => { + return get(`${prefix}/${application_id}/function_lib/${function_lib_id}`, undefined, loading) +} +/** + * 获取模型参数表单 + * @param application_id 应用id + * @param model_id 模型id + * @param loading + * @returns + */ +const getModelParamsForm: ( + 
application_id: String, + model_id: String, + loading?: Ref +) => Promise>> = (application_id, model_id, loading) => { + return get(`${prefix}/${application_id}/model_params_form/${model_id}`, undefined, loading) +} + +/** + * 语音转文本 + */ +const postSpeechToText: ( + application_id: String, + data: any, + loading?: Ref +) => Promise> = (application_id, data, loading) => { + return post(`${prefix}/${application_id}/speech_to_text`, data, undefined, loading) +} + +/** + * 文本转语音 + */ +const postTextToSpeech: ( + application_id: String, + data: any, + loading?: Ref +) => Promise> = (application_id, data, loading) => { + return download(`${prefix}/${application_id}/text_to_speech`, 'post', data, undefined, loading) +} + +/** + * 播放测试文本 + */ +const playDemoText: ( + application_id: String, + data: any, + loading?: Ref +) => Promise> = (application_id, data, loading) => { + return download(`${prefix}/${application_id}/play_demo_text`, 'post', data, undefined, loading) +} +/** + * 获取平台状态 + */ +const getPlatformStatus: (application_id: string) => Promise> = (application_id) => { + return get(`/platform/${application_id}/status`) +} +/** + * 获取平台配置 + */ +const getPlatformConfig: (application_id: string, type: string) => Promise> = ( + application_id, + type +) => { + return get(`/platform/${application_id}/${type}`) +} +/** + * 更新平台配置 + */ +const updatePlatformConfig: ( + application_id: string, + type: string, + data: any +) => Promise> = (application_id, type, data) => { + return post(`/platform/${application_id}/${type}`, data) +} +/** + * 更新平台状态 + */ +const updatePlatformStatus: (application_id: string, data: any) => Promise> = ( + application_id, + data +) => { + return post(`/platform/${application_id}/status`, data) +} +/** + * 验证密码 + */ +const validatePassword: ( + application_id: string, + password: string, + loading?: Ref +) => Promise> = (application_id, password, loading) => { + return get(`/application/${application_id}/auth/${password}`, undefined, loading) +} + 
+/** + * workflow历史版本 + */ +const getWorkFlowVersion: ( + application_id: string, + loading?: Ref +) => Promise> = (application_id, loading) => { + return get(`/application/${application_id}/work_flow_version`, undefined, loading) +} + +/** + * workflow历史版本详情 + */ +const getWorkFlowVersionDetail: ( + application_id: string, + application_version_id: string, + loading?: Ref +) => Promise> = (application_id, application_version_id, loading) => { + return get( + `/application/${application_id}/work_flow_version/${application_version_id}`, + undefined, + loading + ) +} +/** + * 修改workflow历史版本 + */ +const putWorkFlowVersion: ( + application_id: string, + application_version_id: string, + data: any, + loading?: Ref +) => Promise> = (application_id, application_version_id, data, loading) => { + return put( + `/application/${application_id}/work_flow_version/${application_version_id}`, + data, + undefined, + loading + ) +} + +export default { + getAllAppilcation, + getApplication, + postApplication, + putApplication, + postChatOpen, + getChatOpen, + postChatMessage, + delApplication, + getApplicationDetail, + getApplicationDataset, + getAccessToken, + putAccessToken, + postAppAuthentication, + getAppProfile, + putChatVote, + getApplicationHitTest, + getApplicationModel, + putPublishApplication, + postWorkflowChatOpen, + listFunctionLib, + getFunctionLib, + getModelParamsForm, + getApplicationRerankerModel, + getApplicationSTTModel, + getApplicationTTSModel, + postSpeechToText, + postTextToSpeech, + getPlatformStatus, + getPlatformConfig, + updatePlatformConfig, + updatePlatformStatus, + validatePassword, + getWorkFlowVersion, + getWorkFlowVersionDetail, + putWorkFlowVersion, + playDemoText +} diff --git a/src/MaxKB-1.7.2/ui/src/api/auth-setting.ts b/src/MaxKB-1.7.2/ui/src/api/auth-setting.ts new file mode 100644 index 0000000..e1d239b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/auth-setting.ts @@ -0,0 +1,39 @@ +import {Result} from '@/request/Result' +import {get, post, 
del, put} from '@/request/index' +import type {pageRequest} from '@/api/type/common' +import {type Ref} from 'vue' + +const prefix = '/auth' +/** + * 获取认证设置 + */ +const getAuthSetting: (auth_type: string, loading?: Ref) => Promise> = (auth_type, loading) => { + return get(`${prefix}/${auth_type}/detail`, undefined, loading) +} + +/** + * 邮箱测试 + */ +const postAuthSetting: (data: any, loading?: Ref) => Promise> = ( + data, + loading +) => { + return post(`${prefix}/connection`, data, undefined, loading) +} + +/** + * 修改邮箱设置 + */ +const putAuthSetting: (auth_type: string, data: any, loading?: Ref) => Promise> = ( + auth_type, + data, + loading +) => { + return put(`${prefix}/${auth_type}/info`, data, undefined, loading) +} + +export default { + getAuthSetting, + postAuthSetting, + putAuthSetting +} diff --git a/src/MaxKB-1.7.2/ui/src/api/dataset.ts b/src/MaxKB-1.7.2/ui/src/api/dataset.ts new file mode 100644 index 0000000..c50e645 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/dataset.ts @@ -0,0 +1,236 @@ +import { Result } from '@/request/Result' +import { get, post, del, put, exportExcel } from '@/request/index' +import type { datasetData } from '@/api/type/dataset' +import type { pageRequest } from '@/api/type/common' +import type { ApplicationFormType } from '@/api/type/application' +import { type Ref } from 'vue' +const prefix = '/dataset' + +/** + * 获取分页知识库 + * @param 参数 + * page { + "current_page": "string", + "page_size": "string", + } + * param { + "name": "string", + } + */ +const getDataset: ( + page: pageRequest, + param: any, + loading?: Ref +) => Promise> = (page, param, loading) => { + return get(`${prefix}/${page.current_page}/${page.page_size}`, param, loading) +} + +/** + * 获取全部知识库 + * @param 参数 + */ +const getAllDataset: (loading?: Ref) => Promise> = (loading) => { + return get(`${prefix}`, undefined, loading) +} + +/** + * 删除知识库 + * @param 参数 dataset_id + */ +const delDataset: (dataset_id: String, loading?: Ref) => Promise> = ( + dataset_id, + 
loading +) => { + return del(`${prefix}/${dataset_id}`, undefined, {}, loading) +} + +/** + * 创建知识库 + * @param 参数 + * { + "name": "string", + "desc": "string", + "documents": [ + { + "name": "string", + "paragraphs": [ + { + "content": "string", + "title": "string", + "problem_list": [ + { + "id": "string", + "content": "string" + } + ] + } + ] + } + ] +} + */ +const postDataset: (data: datasetData, loading?: Ref) => Promise> = ( + data, + loading +) => { + return post(`${prefix}`, data, undefined, loading, 1000 * 60 * 5) +} + +/** + * 创建Web知识库 + * @param 参数 + * { + "name": "string", + "desc": "string", + "source_url": "string", + "selector": "string", +} + */ +const postWebDataset: (data: any, loading?: Ref) => Promise> = ( + data, + loading +) => { + return post(`${prefix}/web`, data, undefined, loading) +} + +/** + * 创建QA知识库 + * @param 参数 formData + * { + "file": "file", + "name": "string", + "desc": "string", + } + */ +const postQADataset: (data: any, loading?: Ref) => Promise> = ( + data, + loading +) => { + return post(`${prefix}/qa`, data, undefined, loading) +} + +/** + * 知识库详情 + * @param 参数 dataset_id + */ +const getDatasetDetail: (dataset_id: string, loading?: Ref) => Promise> = ( + dataset_id, + loading +) => { + return get(`${prefix}/${dataset_id}`, undefined, loading) +} + +/** + * 修改知识库信息 + * @param 参数 + * dataset_id + * { + "name": "string", + "desc": true + } + */ +const putDataset: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return put(`${prefix}/${dataset_id}`, data, undefined, loading) +} +/** + * 获取知识库 可关联的应用列表 + * @param dataset_id + * @param loading + * @returns + */ +const listUsableApplication: ( + dataset_id: string, + loading?: Ref +) => Promise>> = (dataset_id, loading) => { + return get(`${prefix}/${dataset_id}/application`, {}, loading) +} + +/** + * 命中测试列表 + * @param dataset_id + * @param loading + * @query { query_text: string, top_number: number, similarity: number } + * 
@returns + */ +const getDatasetHitTest: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise>> = (dataset_id, data, loading) => { + return get(`${prefix}/${dataset_id}/hit_test`, data, loading) +} + +/** + * 同步知识库 + * @param 参数 dataset_id + * @query 参数 sync_type // 同步类型->replace:替换同步,complete:完整同步 + */ +const putSyncWebDataset: ( + dataset_id: string, + sync_type: string, + loading?: Ref +) => Promise> = (dataset_id, sync_type, loading) => { + return put(`${prefix}/${dataset_id}/sync_web`, undefined, { sync_type }, loading) +} + +/** + * 重新向量化知识库 + * @param 参数 dataset_id + */ +const putReEmbeddingDataset: ( + dataset_id: string, + loading?: Ref +) => Promise> = (dataset_id, loading) => { + return put(`${prefix}/${dataset_id}/re_embedding`, undefined, undefined, loading) +} + +/** + * 导出知识库 + * @param dataset_name 知识库名称 + * @param dataset_id 知识库id + * @returns + */ +const exportDataset: ( + dataset_name: string, + dataset_id: string, + loading?: Ref +) => Promise = (dataset_name, dataset_id, loading) => { + return exportExcel(dataset_name + '.xlsx', `dataset/${dataset_id}/export`, undefined, loading) +} + + +/** + * 获取当前用户可使用的模型列表 + * @param application_id + * @param loading + * @query { query_text: string, top_number: number, similarity: number } + * @returns + */ +const getDatasetModel: ( + dataset_id: string, + loading?: Ref +) => Promise>> = (dataset_id, loading) => { + return get(`${prefix}/${dataset_id}/model`, loading) +} + + +export default { + getDataset, + getAllDataset, + delDataset, + postDataset, + getDatasetDetail, + putDataset, + listUsableApplication, + getDatasetHitTest, + postWebDataset, + putSyncWebDataset, + putReEmbeddingDataset, + postQADataset, + exportDataset, + getDatasetModel +} diff --git a/src/MaxKB-1.7.2/ui/src/api/document.ts b/src/MaxKB-1.7.2/ui/src/api/document.ts new file mode 100644 index 0000000..28954d0 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/document.ts @@ -0,0 +1,356 @@ +import { Result } from 
'@/request/Result' +import { get, post, del, put, exportExcel } from '@/request/index' +import type { Ref } from 'vue' +import type { KeyValue } from '@/api/type/common' +import type { pageRequest } from '@/api/type/common' + +const prefix = '/dataset' + +/** + * 分段预览(上传文档) + * @param 参数 file:file,limit:number,patterns:array,with_filter:boolean + */ +const postSplitDocument: (data: any) => Promise> = (data) => { + return post(`${prefix}/document/split`, data, undefined, undefined, 1000 * 60 * 60) +} + +/** + * 分段标识列表 + * @param loading 加载器 + * @returns 分段标识列表 + */ +const listSplitPattern: ( + loading?: Ref +) => Promise>>> = (loading) => { + return get(`${prefix}/document/split_pattern`, {}, loading) +} + +/** + * 文档分页列表 + * @param 参数 dataset_id, + * page { + "current_page": "string", + "page_size": "string", + } + * param { + "name": "string", + } + */ + +const getDocument: ( + dataset_id: string, + page: pageRequest, + param: any, + loading?: Ref +) => Promise> = (dataset_id, page, param, loading) => { + return get( + `${prefix}/${dataset_id}/document/${page.current_page}/${page.page_size}`, + param, + loading + ) +} + +const getAllDocument: (dataset_id: string, loading?: Ref) => Promise> = ( + dataset_id, + loading +) => { + return get(`${prefix}/${dataset_id}/document`, undefined, loading) +} + +/** + * 创建批量文档 + * @param 参数 + * { + "name": "string", + "paragraphs": [ + { + "content": "string", + "title": "string", + "problem_list": [ + { + "id": "string", + "content": "string" + } + ] + } + ] + } + */ +const postDocument: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return post(`${prefix}/${dataset_id}/document/_bach`, data, {}, loading, 1000 * 60 * 5) +} + +/** + * 修改文档 + * @param 参数 + * dataset_id, document_id, + * { + "name": "string", + "is_active": true, + "meta": {} + } + */ +const putDocument: ( + dataset_id: string, + document_id: string, + data: any, + loading?: Ref +) => Promise> = 
(dataset_id, document_id, data: any, loading) => { + return put(`${prefix}/${dataset_id}/document/${document_id}`, data, undefined, loading) +} + +/** + * 删除文档 + * @param 参数 dataset_id, document_id, + */ +const delDocument: ( + dataset_id: string, + document_id: string, + loading?: Ref +) => Promise> = (dataset_id, document_id, loading) => { + return del(`${prefix}/${dataset_id}/document/${document_id}`, loading) +} +/** + * 批量删除文档 + * @param 参数 dataset_id, + */ +const delMulDocument: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return del(`${prefix}/${dataset_id}/document/_bach`, undefined, { id_list: data }, loading) +} + +const batchRefresh: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return put( + `${prefix}/${dataset_id}/document/batch_refresh`, + { id_list: data }, + undefined, + loading + ) +} +/** + * 文档详情 + * @param 参数 dataset_id + */ +const getDocumentDetail: (dataset_id: string, document_id: string) => Promise> = ( + dataset_id, + document_id +) => { + return get(`${prefix}/${dataset_id}/document/${document_id}`) +} + +/** + * 刷新文档向量库 + * @param 参数 + * dataset_id, document_id, + */ +const putDocumentRefresh: ( + dataset_id: string, + document_id: string, + loading?: Ref +) => Promise> = (dataset_id, document_id, loading) => { + return put( + `${prefix}/${dataset_id}/document/${document_id}/refresh`, + undefined, + undefined, + loading + ) +} + +/** + * 同步web站点类型 + * @param 参数 + * dataset_id, document_id, + */ +const putDocumentSync: ( + dataset_id: string, + document_id: string, + loading?: Ref +) => Promise> = (dataset_id, document_id, loading) => { + return put(`${prefix}/${dataset_id}/document/${document_id}/sync`, undefined, undefined, loading) +} + +/** + * 批量同步文档 + * @param 参数 dataset_id, + */ +const delMulSyncDocument: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return 
put(`${prefix}/${dataset_id}/document/_bach`, { id_list: data }, undefined, loading) +} + +/** + * 创建Web站点文档 + * @param 参数 + * { + "source_url_list": [ + "string" + ], + "selector": "string" + } + } + */ +const postWebDocument: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return post(`${prefix}/${dataset_id}/document/web`, data, undefined, loading) +} + +/** + * 导入QA文档 + * @param 参数 + * file + } + */ +const postQADocument: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return post(`${prefix}/${dataset_id}/document/qa`, data, undefined, loading) +} + +/** + * 导入表格 + * @param 参数 + * file + */ +const postTableDocument: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return post(`${prefix}/${dataset_id}/document/table`, data, undefined, loading) +} + +/** + * 批量迁移文档 + * @param 参数 dataset_id,target_dataset_id, + */ +const putMigrateMulDocument: ( + dataset_id: string, + target_dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, target_dataset_id, data, loading) => { + return put( + `${prefix}/${dataset_id}/document/migrate/${target_dataset_id}`, + data, + undefined, + loading + ) +} + +/** + * 批量修改命中方式 + * @param dataset_id 知识库id + * @param data {id_list:[],hit_handling_method:'directly_return|optimization'} + * @param loading + * @returns + */ +const batchEditHitHandling: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return put(`${prefix}/${dataset_id}/document/batch_hit_handling`, data, undefined, loading) +} + +/** + * 获得QA模版 + * @param 参数 fileName,type, + */ +const exportQATemplate: (fileName: string, type: string, loading?: Ref) => void = ( + fileName, + type, + loading +) => { + return exportExcel(fileName, `${prefix}/document/template/export`, { type }, loading) +} + +/** + * 获得table模版 + * @param 参数 
fileName,type, + */ +const exportTableTemplate: (fileName: string, type: string, loading?: Ref) => void = ( + fileName, + type, + loading +) => { + return exportExcel(fileName, `${prefix}/document/table_template/export`, { type }, loading) +} + +/** + * 导出文档 + * @param document_name 文档名称 + * @param dataset_id 数据集id + * @param document_id 文档id + * @param loading 加载器 + * @returns + */ +const exportDocument: ( + document_name: string, + dataset_id: string, + document_id: string, + loading?: Ref +) => Promise = (document_name, dataset_id, document_id, loading) => { + return exportExcel( + document_name + '.xlsx', + `${prefix}/${dataset_id}/document/${document_id}/export`, + {}, + loading + ) +} + +const batchGenerateRelated: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return put( + `${prefix}/${dataset_id}/document/batch_generate_related`, + data, + undefined, + loading + ) +} + +export default { + postSplitDocument, + getDocument, + getAllDocument, + postDocument, + putDocument, + delDocument, + delMulDocument, + getDocumentDetail, + listSplitPattern, + putDocumentRefresh, + putDocumentSync, + delMulSyncDocument, + postWebDocument, + putMigrateMulDocument, + batchEditHitHandling, + exportQATemplate, + exportTableTemplate, + postQADocument, + postTableDocument, + exportDocument, + batchRefresh, + batchGenerateRelated +} diff --git a/src/MaxKB-1.7.2/ui/src/api/email-setting.ts b/src/MaxKB-1.7.2/ui/src/api/email-setting.ts new file mode 100644 index 0000000..9fc8bf5 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/email-setting.ts @@ -0,0 +1,38 @@ +import { Result } from '@/request/Result' +import { get, post, del, put } from '@/request/index' +import type { pageRequest } from '@/api/type/common' +import { type Ref } from 'vue' + +const prefix = '/email_setting' +/** + * 获取邮箱设置 + */ +const getEmailSetting: (loading?: Ref) => Promise> = (loading) => { + return get(`${prefix}`, undefined, loading) +} + +/** + * 
邮箱测试 + */ +const postTestEmail: (data: any, loading?: Ref) => Promise> = ( + data, + loading +) => { + return post(`${prefix}`, data, undefined, loading) +} + +/** + * 修改邮箱设置 + */ +const putEmailSetting: (data: any, loading?: Ref) => Promise> = ( + data, + loading +) => { + return put(`${prefix}`, data, undefined, loading) +} + +export default { + getEmailSetting, + postTestEmail, + putEmailSetting +} diff --git a/src/MaxKB-1.7.2/ui/src/api/function-lib.ts b/src/MaxKB-1.7.2/ui/src/api/function-lib.ts new file mode 100644 index 0000000..c6a8200 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/function-lib.ts @@ -0,0 +1,111 @@ +import { Result } from '@/request/Result' +import { get, post, del, put } from '@/request/index' +import type { pageRequest } from '@/api/type/common' +import type { functionLibData } from '@/api/type/function-lib' +import { type Ref } from 'vue' + +const prefix = '/function_lib' + +/** + * 获取函数列表 + * param { + "name": "string", + } + */ +const getAllFunctionLib: (param?: any, loading?: Ref) => Promise> = ( + param, + loading +) => { + return get(`${prefix}`, param || {}, loading) +} + +/** + * 获取分页函数列表 + * page { + "current_page": "string", + "page_size": "string", + } + * param { + "name": "string", + } + */ +const getFunctionLib: ( + page: pageRequest, + param: any, + loading?: Ref +) => Promise> = (page, param, loading) => { + return get(`${prefix}/${page.current_page}/${page.page_size}`, param, loading) +} + +/** + * 创建函数 + * @param 参数 + */ +const postFunctionLib: (data: functionLibData, loading?: Ref) => Promise> = ( + data, + loading +) => { + return post(`${prefix}`, data, undefined, loading) +} + +/** + * 修改函数 + * @param 参数 + + */ +const putFunctionLib: ( + function_lib_id: string, + data: functionLibData, + loading?: Ref +) => Promise> = (function_lib_id, data, loading) => { + return put(`${prefix}/${function_lib_id}`, data, undefined, loading) +} + +/** + * 调试函数 + * @param 参数 + + */ +const postFunctionLibDebug: (data: any, loading?: 
Ref) => Promise> = ( + data: any, + loading +) => { + return post(`${prefix}/debug`, data, undefined, loading) +} + +/** + * 删除函数 + * @param 参数 function_lib_id + */ +const delFunctionLib: ( + function_lib_id: String, + loading?: Ref +) => Promise> = (function_lib_id, loading) => { + return del(`${prefix}/${function_lib_id}`, undefined, {}, loading) +} +/** + * 获取函数详情 + * @param function_lib_id 函数id + * @param loading 加载器 + * @returns 函数详情 + */ +const getFunctionLibById: ( + function_lib_id: String, + loading?: Ref +) => Promise> = (function_lib_id, loading) => { + return get(`${prefix}/${function_lib_id}`, undefined, loading) +} +const pylint: (code: string, loading?: Ref) => Promise> = (code, loading) => { + return post(`${prefix}/pylint`, { code }, {}, loading) +} + +export default { + getFunctionLib, + postFunctionLib, + putFunctionLib, + postFunctionLibDebug, + getAllFunctionLib, + delFunctionLib, + getFunctionLibById, + pylint +} diff --git a/src/MaxKB-1.7.2/ui/src/api/image.ts b/src/MaxKB-1.7.2/ui/src/api/image.ts new file mode 100644 index 0000000..425e8c6 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/image.ts @@ -0,0 +1,15 @@ +import { Result } from '@/request/Result' +import { get, post, del, put } from '@/request/index' + +const prefix = '/image' +/** + * 上传图片 + * @param 参数 file:file + */ +const postImage: (data: any) => Promise> = (data) => { + return post(`${prefix}`, data) +} + +export default { + postImage +} diff --git a/src/MaxKB-1.7.2/ui/src/api/license.ts b/src/MaxKB-1.7.2/ui/src/api/license.ts new file mode 100644 index 0000000..16e5acd --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/license.ts @@ -0,0 +1,24 @@ +import { Result } from '@/request/Result' +import { get, post, del, put } from '@/request/index' +import { type Ref } from 'vue' + +const prefix = '/license' + +/** + * 获得license信息 + */ +const getLicense: (loading?: Ref) => Promise> = (loading) => { + return get(`${prefix}/profile`, undefined, loading) +} +/** + * 更新license信息 + * @param 
参数 license_file:file + */ +const putLicense: (data: any, loading?: Ref) => Promise> = (data, loading) => { + return put(`${prefix}/profile`, data, undefined, loading) +} + +export default { + getLicense, + putLicense +} diff --git a/src/MaxKB-1.7.2/ui/src/api/log.ts b/src/MaxKB-1.7.2/ui/src/api/log.ts new file mode 100644 index 0000000..6dbca46 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/log.ts @@ -0,0 +1,213 @@ +import { Result } from '@/request/Result' +import { get, del, put, exportExcel, exportExcelPost } from '@/request/index' +import type { pageRequest } from '@/api/type/common' +import { type Ref } from 'vue' + +const prefix = '/application' +/** + * 对话日志 + * @param 参数 + * application_id, history_day + * page { + "current_page": "string", + "page_size": "string", + } +* param { + "history_day": "string", + "search": "string", + } + */ +const getChatLog: ( + application_id: String, + page: pageRequest, + param: any, + loading?: Ref +) => Promise> = (application_id, page, param, loading) => { + return get( + `${prefix}/${application_id}/chat/${page.current_page}/${page.page_size}`, + param, + loading + ) +} + +const exportChatLog: ( + application_id: string, + application_name: string, + param: any, + data: any, + loading?: Ref +) => void = (application_id, application_name, param, data, loading) => { + exportExcelPost(application_name, `${prefix}/${application_id}/chat/export`, param, data, loading) +} + +/** + * 删除日志 + * @param 参数 application_id, chat_id, + */ +const delChatLog: ( + application_id: string, + chat_id: string, + loading?: Ref +) => Promise> = (application_id, chat_id, loading) => { + return del(`${prefix}/${application_id}/chat/${chat_id}`, undefined, {}, loading) +} + +/** + * 日志记录 + * @param 参数 + * application_id, chart_id + * page { + "current_page": "string", + "page_size": "string", + } + */ +const getChatRecordLog: ( + application_id: String, + chart_id: String, + page: pageRequest, + loading?: Ref, + order_asc?: boolean +) => 
Promise> = (application_id, chart_id, page, loading, order_asc) => { + return get( + `${prefix}/${application_id}/chat/${chart_id}/chat_record/${page.current_page}/${page.page_size}`, + { order_asc: order_asc !== undefined ? order_asc : true }, + loading + ) +} + +/** + * 修改日志内容 + * @param 参数 + * application_id, chart_id, chart_record_id, dataset_id, document_id + * data { + "title": "string", + "content": "string", + } + */ +const putChatRecordLog: ( + application_id: String, + chart_id: String, + chart_record_id: String, + dataset_id: String, + document_id: String, + data: any, + loading?: Ref +) => Promise> = ( + application_id, + chart_id, + chart_record_id, + dataset_id, + document_id, + data, + loading +) => { + return put( + `${prefix}/${application_id}/chat/${chart_id}/chat_record/${chart_record_id}/dataset/${dataset_id}/document_id/${document_id}/improve`, + data, + undefined, + loading + ) +} + +/** + * 获取标注段落列表信息 + * @param 参数 + * application_id, chart_id, chart_record_id + */ +const getMarkRecord: ( + application_id: String, + chart_id: String, + chart_record_id: String, + loading?: Ref +) => Promise> = (application_id, chart_id, chart_record_id, loading) => { + return get( + `${prefix}/${application_id}/chat/${chart_id}/chat_record/${chart_record_id}/improve`, + undefined, + loading + ) +} + +/** + * 删除标注 + * @param 参数 + * application_id, chart_id, chart_record_id, dataset_id, document_id,paragraph_id + */ +const delMarkRecord: ( + application_id: String, + chart_id: String, + chart_record_id: String, + dataset_id: String, + document_id: String, + paragraph_id: String, + loading?: Ref +) => Promise> = ( + application_id, + chart_id, + chart_record_id, + dataset_id, + document_id, + paragraph_id, + loading +) => { + return del( + `${prefix}/${application_id}/chat/${chart_id}/chat_record/${chart_record_id}/dataset/${dataset_id}/document_id/${document_id}/improve/${paragraph_id}`, + undefined, + {}, + loading + ) +} + +/** + * 获取对话记录详情 + * @param 参数 + * 
application_id, chart_id, chart_record_id + */ +const getRecordDetail: ( + application_id: String, + chart_id: String, + chart_record_id: String, + loading?: Ref +) => Promise> = (application_id, chart_id, chart_record_id, loading) => { + return get( + `${prefix}/${application_id}/chat/${chart_id}/chat_record/${chart_record_id}`, + undefined, + loading + ) +} + +const getChatLogClient: ( + application_id: String, + page: pageRequest, + loading?: Ref +) => Promise> = (application_id, page, loading) => { + return get( + `${prefix}/${application_id}/chat/client/${page.current_page}/${page.page_size}`, + null, + loading + ) +} + +/** + * 客户端删除日志 + * @param 参数 application_id, chat_id, + */ +const delChatClientLog: ( + application_id: string, + chat_id: string, + loading?: Ref +) => Promise> = (application_id, chat_id, loading) => { + return del(`${prefix}/${application_id}/chat/client/${chat_id}`, undefined, {}, loading) +} + +export default { + getChatLog, + delChatLog, + getChatRecordLog, + putChatRecordLog, + getMarkRecord, + getRecordDetail, + delMarkRecord, + exportChatLog, + getChatLogClient, + delChatClientLog +} diff --git a/src/MaxKB-1.7.2/ui/src/api/model.ts b/src/MaxKB-1.7.2/ui/src/api/model.ts new file mode 100644 index 0000000..6519f1b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/model.ts @@ -0,0 +1,199 @@ +import { request } from './../request/index' +import { Result } from '@/request/Result' +import { get, post, del, put } from '@/request/index' +import { type Ref } from 'vue' +import type { + modelRequest, + Provider, + ListModelRequest, + Model, + BaseModel, + CreateModelRequest, + EditModelRequest +} from '@/api/type/model' +import type { FormField } from '@/components/dynamics-form/type' +import type { KeyValue } from './type/common' +const prefix = '/model' +const prefix_provider = '/provider' + +/** + * 获得模型列表 + * @params 参数 name, model_type, model_name + */ +const getModel: ( + request?: ListModelRequest, + loading?: Ref +) => Promise>> = (data, 
loading) => { + return get(`${prefix}`, data, loading) +} + +/** + * 获得供应商列表 + */ +const getProvider: (loading?: Ref) => Promise>> = (loading) => { + return get(`${prefix_provider}`, {}, loading) +} + +/** + * 获得供应商列表 + */ +const getProviderByModelType: (model_type: string, loading?: Ref) => Promise>> = (model_type, loading) => { + return get(`${prefix_provider}`, {model_type}, loading) +} + +/** + * 获取模型创建表单 + * @param provider + * @param model_type + * @param model_name + * @param loading + * @returns + */ +const getModelCreateForm: ( + provider: string, + model_type: string, + model_name: string, + loading?: Ref +) => Promise>> = (provider, model_type, model_name, loading) => { + return get(`${prefix_provider}/model_form`, { provider, model_type, model_name }, loading) +} + +/** + * 获取模型参数表单 + * @param model_id 模型id + * @param loading + * @returns + */ +const getModelParamsForm: ( + model_id: string, + loading?: Ref +) => Promise>> = (model_id, loading) => { + return get(`model/${model_id}/model_params_form`, {}, loading) +} +/** + * 获取模型类型列表 + * @param provider 供应商 + * @param loading 加载器 + * @returns 模型类型列表 + */ +const listModelType: ( + provider: string, + loading?: Ref +) => Promise>>> = (provider, loading?: Ref) => { + return get(`${prefix_provider}/model_type_list`, { provider }, loading) +} + +/** + * 获取基础模型列表 + * @param provider + * @param model_type + * @param loading + * @returns + */ +const listBaseModel: ( + provider: string, + model_type: string, + loading?: Ref +) => Promise>> = (provider, model_type, loading) => { + return get(`${prefix_provider}/model_list`, { provider, model_type }, loading) +} + +/** + * 创建模型 + * @param request 请求对象 + * @param loading 加载器 + * @returns + */ +const createModel: ( + request: CreateModelRequest, + loading?: Ref +) => Promise> = (request, loading) => { + return post(`${prefix}`, request, {}, loading) +} + +/** + * 修改模型 + * @param request 請求對象 + * @param loading 加載器 + * @returns + */ +const updateModel: ( + model_id: 
string, + request: EditModelRequest, + loading?: Ref +) => Promise> = (model_id, request, loading) => { + return put(`${prefix}/${model_id}`, request, {}, loading) +} + +/** + * 修改模型参数配置 + * @param request 請求對象 + * @param loading 加載器 + * @returns + */ +const updateModelParamsForm: ( + model_id: string, + request: any[], + loading?: Ref +) => Promise> = (model_id, request, loading) => { + return put(`${prefix}/${model_id}/model_params_form`, request, {}, loading) +} + +/** + * 获取模型详情根据模型id 包括认证信息 + * @param model_id 模型id + * @param loading 加载器 + * @returns + */ +const getModelById: (model_id: string, loading?: Ref) => Promise> = ( + model_id, + loading +) => { + return get(`${prefix}/${model_id}`, {}, loading) +} +/** + * 获取模型信息不包括认证信息根据模型id + * @param model_id 模型id + * @param loading 加载器 + * @returns + */ +const getModelMetaById: (model_id: string, loading?: Ref) => Promise> = ( + model_id, + loading +) => { + return get(`${prefix}/${model_id}/meta`, {}, loading) +} +/** + * 暂停下载 + * @param model_id 模型id + * @param loading 加载器 + * @returns + */ +const pauseDownload: (model_id: string, loading?: Ref) => Promise> = ( + model_id, + loading +) => { + return put(`${prefix}/${model_id}/pause_download`, undefined, {}, loading) +} +const deleteModel: (model_id: string, loading?: Ref) => Promise> = ( + model_id, + loading +) => { + return del(`${prefix}/${model_id}`, undefined, {}, loading) +} +export default { + getModel, + getProvider, + getModelCreateForm, + listModelType, + listBaseModel, + createModel, + updateModel, + deleteModel, + getModelById, + getModelMetaById, + pauseDownload, + getModelParamsForm, + updateModelParamsForm, + getProviderByModelType +} diff --git a/src/MaxKB-1.7.2/ui/src/api/paragraph.ts b/src/MaxKB-1.7.2/ui/src/api/paragraph.ts new file mode 100644 index 0000000..4a7d29b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/paragraph.ts @@ -0,0 +1,256 @@ +import { Result } from '@/request/Result' +import { get, post, del, put } from '@/request/index' 
+import type { pageRequest } from '@/api/type/common' +import type { Ref } from 'vue' +const prefix = '/dataset' + +/** + * 段落列表 + * @param 参数 dataset_id document_id + * page { + "current_page": "string", + "page_size": "string", + } + * param { + "title": "string", + "content": "string", + } + */ +const getParagraph: ( + dataset_id: string, + document_id: string, + page: pageRequest, + param: any, + loading?: Ref +) => Promise> = (dataset_id, document_id, page, param, loading) => { + return get( + `${prefix}/${dataset_id}/document/${document_id}/paragraph/${page.current_page}/${page.page_size}`, + param, + loading + ) +} + +/** + * 删除段落 + * @param 参数 dataset_id, document_id, paragraph_id + */ +const delParagraph: ( + dataset_id: string, + document_id: string, + paragraph_id: string, + loading?: Ref +) => Promise> = (dataset_id, document_id, paragraph_id, loading) => { + return del( + `${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}`, + undefined, + {}, + loading + ) +} + +/** + * 批量删除段落 + * @param 参数 dataset_id, document_id + */ +const delMulParagraph: ( + dataset_id: string, + document_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, document_id, data, loading) => { + return del( + `${prefix}/${dataset_id}/document/${document_id}/paragraph/_batch`, + undefined, + { id_list: data }, + loading + ) +} + +/** + * 创建段落 + * @param 参数 + * dataset_id, document_id + * { + "content": "string", + "title": "string", + "is_active": true, + "problem_list": [ + { + "id": "string", + "content": "string" + } + ] + } + */ +const postParagraph: ( + dataset_id: string, + document_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, document_id, data, loading) => { + return post(`${prefix}/${dataset_id}/document/${document_id}/paragraph`, data, undefined, loading) +} + +/** + * 修改段落 + * @param 参数 + * dataset_id, document_id, paragraph_id + * { + "content": "string", + "title": "string", + "is_active": true, + 
"problem_list": [ + { + "id": "string", + "content": "string" + } + ] + } + */ +const putParagraph: ( + dataset_id: string, + document_id: string, + paragraph_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, document_id, paragraph_id, data, loading) => { + return put( + `${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}`, + data, + undefined, + loading + ) +} + +/** + * 批量迁移段落 + * @param 参数 dataset_id,target_dataset_id, + */ +const putMigrateMulParagraph: ( + dataset_id: string, + document_id: string, + target_dataset_id: string, + target_document_id: string, + data: any, + loading?: Ref +) => Promise> = ( + dataset_id, + document_id, + target_dataset_id, + target_document_id, + data, + loading +) => { + return put( + `${prefix}/${dataset_id}/document/${document_id}/paragraph/migrate/dataset/${target_dataset_id}/document/${target_document_id}`, + data, + undefined, + loading + ) +} + +/** + * 问题列表 + * @param 参数 dataset_id,document_id,paragraph_id + */ +const getProblem: ( + dataset_id: string, + document_id: string, + paragraph_id: string +) => Promise> = (dataset_id, document_id, paragraph_id: string) => { + return get(`${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem`) +} + +/** + * 创建问题 + * @param 参数 + * dataset_id, document_id, paragraph_id + * { + "id": "string", + content": "string" + } + */ +const postProblem: ( + dataset_id: string, + document_id: string, + paragraph_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, document_id, paragraph_id, data: any, loading) => { + return post( + `${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem`, + data, + {}, + loading + ) +} +/** + * + * @param dataset_id 数据集id + * @param document_id 文档id + * @param paragraph_id 段落id + * @param problem_id 问题id + * @param loading 加载器 + * @returns + */ +const associationProblem: ( + dataset_id: string, + document_id: string, + paragraph_id: string, + 
problem_id: string, + loading?: Ref +) => Promise> = (dataset_id, document_id, paragraph_id, problem_id, loading) => { + return put( + `${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem/${problem_id}/association`, + {}, + {}, + loading + ) +} +/** + * 解除关联问题 + * @param 参数 dataset_id, document_id, paragraph_id,problem_id + */ +const disassociationProblem: ( + dataset_id: string, + document_id: string, + paragraph_id: string, + problem_id: string, + loading?: Ref +) => Promise> = (dataset_id, document_id, paragraph_id, problem_id, loading) => { + return put( + `${prefix}/${dataset_id}/document/${document_id}/paragraph/${paragraph_id}/problem/${problem_id}/un_association`, + {}, + {}, + loading + ) +} + +const batchGenerateRelated: ( + dataset_id: string, + document_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, document_id, data, loading) => { + return put( + `${prefix}/${dataset_id}/document/${document_id}/paragraph/batch_generate_related`, + data, + undefined, + loading + ) +} + + +export default { + getParagraph, + delParagraph, + putParagraph, + postParagraph, + getProblem, + postProblem, + disassociationProblem, + associationProblem, + delMulParagraph, + putMigrateMulParagraph, + batchGenerateRelated +} diff --git a/src/MaxKB-1.7.2/ui/src/api/platform-source.ts b/src/MaxKB-1.7.2/ui/src/api/platform-source.ts new file mode 100644 index 0000000..defcc84 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/platform-source.ts @@ -0,0 +1,28 @@ +import { Result } from '@/request/Result' +import { get, post, del, put } from '@/request/index' +import type { pageRequest } from '@/api/type/common' +import { type Ref } from 'vue' + +const prefix = '/platform' +const getPlatformInfo: (loading?: Ref) => Promise> = (loading) => { + return get(`${prefix}/source`, undefined, loading) +} + +const updateConfig: (data: any, loading?: Ref) => Promise> = ( + data, + loading +) => { + return post(`${prefix}/source`, data, undefined, 
loading) +} + +const validateConnection: (data: any, loading?: Ref) => Promise> = ( + data, + loading +) => { + return put(`${prefix}/source`, data, undefined, loading) +} +export default { + getPlatformInfo, + updateConfig, + validateConnection +} diff --git a/src/MaxKB-1.7.2/ui/src/api/problem.ts b/src/MaxKB-1.7.2/ui/src/api/problem.ts new file mode 100644 index 0000000..4625d6d --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/problem.ts @@ -0,0 +1,124 @@ +import { Result } from '@/request/Result' +import { get, post, del, put } from '@/request/index' +import type { Ref } from 'vue' +import type { KeyValue } from '@/api/type/common' +import type { pageRequest } from '@/api/type/common' +const prefix = '/dataset' + +/** + * 文档分页列表 + * @param 参数 dataset_id, + * page { + "current_page": "string", + "page_size": "string", + } +* query { + "content": "string", + } + */ + +const getProblems: ( + dataset_id: string, + page: pageRequest, + param: any, + loading?: Ref +) => Promise> = (dataset_id, page, param, loading) => { + return get( + `${prefix}/${dataset_id}/problem/${page.current_page}/${page.page_size}`, + param, + loading + ) +} + +/** + * 创建问题 + * @param 参数 dataset_id + * data: array[string] + */ +const postProblems: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return post(`${prefix}/${dataset_id}/problem`, data, undefined, loading) +} + +/** + * Delete a problem; loading is forwarded as the 4th argument of del, like the other del calls + * @param dataset_id, problem_id, loading + */ +const delProblems: ( + dataset_id: string, + problem_id: string, + loading?: Ref +) => Promise> = (dataset_id, problem_id, loading) => { + return del(`${prefix}/${dataset_id}/problem/${problem_id}`, undefined, {}, loading) +} + +/** + * 批量删除问题 + * @param 参数 dataset_id, + */ +const delMulProblem: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return del(`${prefix}/${dataset_id}/problem/_batch`, undefined, data, loading) +} + +/** + * 修改问题 + * @param 参数 + * dataset_id, 
problem_id, + * { + "content": "string", + } + */ +const putProblems: ( + dataset_id: string, + problem_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, problem_id, data: any, loading) => { + return put(`${prefix}/${dataset_id}/problem/${problem_id}`, data, undefined, loading) +} + +/** + * 问题详情 + * @param 参数 + * dataset_id, problem_id, + */ +const getDetailProblems: ( + dataset_id: string, + problem_id: string, + loading?: Ref +) => Promise> = (dataset_id, problem_id, loading) => { + return get(`${prefix}/${dataset_id}/problem/${problem_id}/paragraph`, undefined, loading) +} + +/** + * 批量关联段落 + * @param 参数 dataset_id, + * { + "problem_id_list": "Array", + "paragraph_list": "Array", + } + */ +const postMulAssociationProblem: ( + dataset_id: string, + data: any, + loading?: Ref +) => Promise> = (dataset_id, data, loading) => { + return post(`${prefix}/${dataset_id}/problem/_batch`, data, undefined, loading) +} + +export default { + getProblems, + postProblems, + delProblems, + putProblems, + getDetailProblems, + delMulProblem, + postMulAssociationProblem +} diff --git a/src/MaxKB-1.7.2/ui/src/api/provider.ts b/src/MaxKB-1.7.2/ui/src/api/provider.ts new file mode 100644 index 0000000..2099658 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/provider.ts @@ -0,0 +1,12 @@ +import { Result } from '@/request/Result' +import { get, post } from '@/request/index' +import type { Ref } from 'vue' +const trigger: ( + provider: string, + method: string, + request_body: any, + loading?: Ref +) => Promise | string>> = (provider, method, request_body, loading) => { + return post(`provider/${provider}/${method}`, {}, request_body, loading) +} +export default { trigger, get } diff --git a/src/MaxKB-1.7.2/ui/src/api/system-api-key.ts b/src/MaxKB-1.7.2/ui/src/api/system-api-key.ts new file mode 100644 index 0000000..9d66bc7 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/system-api-key.ts @@ -0,0 +1,58 @@ +import {Result} from '@/request/Result' +import {get, post, 
del, put} from '@/request/index' + +import {type Ref} from 'vue' + +const prefix = '/system/api_key' + +/** + * List API keys; forwards the optional loading ref to get + */ +const getAPIKey: (loading?: Ref) => Promise> = (loading) => { + return get(`${prefix}/`, undefined, loading) +} + +/** + * 新增API_KEY + */ +const postAPIKey: (loading?: Ref) => Promise> = ( + loading +) => { + return post(`${prefix}/`, {}, undefined, loading) +} + +/** + * 删除API_KEY + * @param 参数 application_id api_key_id + */ +const delAPIKey: ( + api_key_id: String, + loading?: Ref +) => Promise> = (api_key_id, loading) => { + return del(`${prefix}/${api_key_id}/`, undefined, undefined, loading) +} + +/** + * 修改API_KEY + * data { + * is_active: boolean + * } + * @param api_key_id + * @param data + * @param loading + */ +const putAPIKey: ( + api_key_id: String, + data: any, + loading?: Ref +) => Promise> = (api_key_id, data, loading) => { + return put(`${prefix}/${api_key_id}/`, data, undefined, loading) +} + + +export default { + getAPIKey, + postAPIKey, + delAPIKey, + putAPIKey +} diff --git a/src/MaxKB-1.7.2/ui/src/api/team.ts b/src/MaxKB-1.7.2/ui/src/api/team.ts new file mode 100644 index 0000000..82e8f98 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/team.ts @@ -0,0 +1,67 @@ +import { Result } from '@/request/Result' +import { get, post, del, put } from '@/request/index' +import type { TeamMember } from '@/api/type/team' + +const prefix = '/team/member' + +/** + * 获取团队成员列表 + */ +const getTeamMember: () => Promise> = () => { + return get(`${prefix}`) +} + +/** + * 添加成员 + * @param 参数 [] + */ +const postCreatTeamMember: (data: Array) => Promise> = (data) => { + return post(`${prefix}/_batch`, data) +} + +/** + * 删除成员 + * @param 参数 member_id + */ +const delTeamMember: (member_id: String) => Promise> = (member_id) => { + return del(`${prefix}/${member_id}`) +} + +/** + * 获取成员权限 + * @param 参数 member_id + */ +const getMemberPermissions: (member_id: String) => Promise> = (member_id) => { + return get(`${prefix}/${member_id}`) +} + +/** + * 获取成员权限 + * @param 参数 
member_id + * @param 参数 { + "team_member_permission_list": [ + { + "target_id": "string", + "type": "string", + "operate": { + "USE": true, + "MANAGE": true + } + } + ] + } + */ +const putMemberPermissions: (member_id: String, body: any) => Promise> = ( + member_id, + body +) => { + return put(`${prefix}/${member_id}`, body) +} + +export default { + getTeamMember, + postCreatTeamMember, + delTeamMember, + getMemberPermissions, + putMemberPermissions +} diff --git a/src/MaxKB-1.7.2/ui/src/api/theme.ts b/src/MaxKB-1.7.2/ui/src/api/theme.ts new file mode 100644 index 0000000..6e696e5 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/theme.ts @@ -0,0 +1,35 @@ +import { Result } from '@/request/Result' +import { get, post, del, put } from '@/request/index' +import type { Ref } from 'vue' +const prefix = '/display' + +/** + * 查看外观设置 + */ +const getThemeInfo: (loading?: Ref) => Promise> = (loading) => { + return get(`${prefix}/info`, undefined, loading) +} + +/** + * 更新外观设置 + * @param 参数 + * * formData { + * theme + * icon + * loginLogo + * loginImage + * title + * slogan + * } + */ +const postThemeInfo: (data: any, loading?: Ref) => Promise> = ( + data, + loading +) => { + return post(`${prefix}/update`, data, undefined, loading) +} + +export default { + getThemeInfo, + postThemeInfo +} diff --git a/src/MaxKB-1.7.2/ui/src/api/type/application.ts b/src/MaxKB-1.7.2/ui/src/api/type/application.ts new file mode 100644 index 0000000..00ad179 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/type/application.ts @@ -0,0 +1,184 @@ +import { type Dict } from '@/api/type/common' +import { type Ref } from 'vue' +interface ApplicationFormType { + name?: string + desc?: string + model_id?: string + dialogue_number?: number + prologue?: string + dataset_id_list?: string[] + dataset_setting?: any + model_setting?: any + problem_optimization?: boolean + problem_optimization_prompt?: string + icon?: string | undefined + type?: string + work_flow?: any + model_params_setting?: any + 
tts_model_params_setting?: any + stt_model_id?: string + tts_model_id?: string + stt_model_enable?: boolean + tts_model_enable?: boolean + tts_type?: string +} +interface chatType { + id: string + problem_text: string + answer_text: string + buffer: Array + /** + * 是否写入结束 + */ + write_ed?: boolean + /** + * 是否暂停 + */ + is_stop?: boolean + record_id: string + vote_status: string + status?: number +} + +export class ChatRecordManage { + id?: any + ms: number + chat: chatType + is_close?: boolean + write_ed?: boolean + is_stop?: boolean + loading?: Ref + constructor(chat: chatType, ms?: number, loading?: Ref) { + this.ms = ms ? ms : 10 + this.chat = chat + this.loading = loading + this.is_stop = false + this.is_close = false + this.write_ed = false + } + write() { + this.chat.is_stop = false + this.is_stop = false + if (this.loading) { + this.loading.value = true + } + this.id = setInterval(() => { + if (this.chat.buffer.length > 20) { + this.chat.answer_text = + this.chat.answer_text + this.chat.buffer.splice(0, this.chat.buffer.length - 20).join('') + } else if (this.is_close) { + this.chat.answer_text = this.chat.answer_text + this.chat.buffer.splice(0).join('') + this.chat.write_ed = true + this.write_ed = true + if (this.loading) { + this.loading.value = false + } + if (this.id) { + clearInterval(this.id) + } + } else { + const s = this.chat.buffer.shift() + if (s !== undefined) { + this.chat.answer_text = this.chat.answer_text + s + } + } + }, this.ms) + } + stop() { + clearInterval(this.id) + this.is_stop = true + this.chat.is_stop = true + if (this.loading) { + this.loading.value = false + } + } + close() { + this.is_close = true + } + append(answer_text_block: string) { + for (let index = 0; index < answer_text_block.length; index++) { + this.chat.buffer.push(answer_text_block[index]) + } + } +} + +export class ChatManagement { + static chatMessageContainer: Dict = {} + + static addChatRecord(chat: chatType, ms: number, loading?: Ref) { + 
this.chatMessageContainer[chat.id] = new ChatRecordManage(chat, ms, loading) + } + static append(chatRecordId: string, content: string) { + const chatRecord = this.chatMessageContainer[chatRecordId] + if (chatRecord) { + chatRecord.append(content) + } + } + static updateStatus(chatRecordId: string, code: number) { + const chatRecord = this.chatMessageContainer[chatRecordId] + if (chatRecord) { + chatRecord.chat.status = code + } + } + /** + * 持续从缓存区 写出数据 + * @param chatRecordId 对话记录id + */ + static write(chatRecordId: string) { + const chatRecord = this.chatMessageContainer[chatRecordId] + if (chatRecord) { + chatRecord.write() + } + } + /** + * 等待所有数据输出完毕后 才会关闭流 + * @param chatRecordId 对话记录id + * @returns boolean + */ + static close(chatRecordId: string) { + const chatRecord = this.chatMessageContainer[chatRecordId] + if (chatRecord) { + chatRecord.close() + } + } + /** + * 停止输出 立即关闭定时任务输出 + * @param chatRecordId 对话记录id + * @returns boolean + */ + static stop(chatRecordId: string) { + const chatRecord = this.chatMessageContainer[chatRecordId] + if (chatRecord) { + chatRecord.stop() + } + } + /** + * 判断是否输出完成 + * @param chatRecordId 对话记录id + * @returns boolean + */ + static isClose(chatRecordId: string) { + const chatRecord = this.chatMessageContainer[chatRecordId] + return chatRecord ? chatRecord.is_close && chatRecord.write_ed : false + } + /** + * 判断是否停止输出 + * @param chatRecordId 对话记录id + * @returns + */ + static isStop(chatRecordId: string) { + const chatRecord = this.chatMessageContainer[chatRecordId] + return chatRecord ? 
chatRecord.is_stop : false + } + /** + * Remove stale records: deletes every container entry whose is_close flag is set (iterates the real keys, not array indices) + */ + static clean() { + for (const key of Object.keys(this.chatMessageContainer)) { + if (this.chatMessageContainer[key].is_close) { + delete this.chatMessageContainer[key] + } + } + } +} +export type { ApplicationFormType, chatType } diff --git a/src/MaxKB-1.7.2/ui/src/api/type/common.ts b/src/MaxKB-1.7.2/ui/src/api/type/common.ts new file mode 100644 index 0000000..1912b77 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/type/common.ts @@ -0,0 +1,14 @@ +interface KeyValue { + key: K + value: V +} +interface Dict { + [propName: string]: V +} + +interface pageRequest { + current_page: number + page_size: number +} + +export type { KeyValue, Dict, pageRequest } diff --git a/src/MaxKB-1.7.2/ui/src/api/type/dataset.ts b/src/MaxKB-1.7.2/ui/src/api/type/dataset.ts new file mode 100644 index 0000000..a30c5c9 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/type/dataset.ts @@ -0,0 +1,9 @@ +interface datasetData { + name: String + desc: String + documents?: Array + type?: String + embedding_mode_id?: String +} + +export type { datasetData } diff --git a/src/MaxKB-1.7.2/ui/src/api/type/function-lib.ts b/src/MaxKB-1.7.2/ui/src/api/type/function-lib.ts new file mode 100644 index 0000000..2c5efe2 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/type/function-lib.ts @@ -0,0 +1,11 @@ +interface functionLibData { + id?: String + name?: String + desc?: String + code?: String + permission_type?: 'PRIVATE' | 'PUBLIC' + input_field_list?: Array + is_active?: Boolean +} + +export type { functionLibData } diff --git a/src/MaxKB-1.7.2/ui/src/api/type/model.ts b/src/MaxKB-1.7.2/ui/src/api/type/model.ts new file mode 100644 index 0000000..6672920 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/type/model.ts @@ -0,0 +1,148 @@ +import { store } from '@/stores' +import type { Dict } from './common' +interface modelRequest { + name: string + model_type: string + model_name: string +} + +interface Provider { + /** + * 供应商代号 + */ + 
provider: string + /** + * 供应商名称 + */ + name: string + /** + * 供应商icon + */ + icon: string +} + +interface ListModelRequest { + /** + * 模型名称 + */ + name?: string + /** + * 模型类型 + */ + model_type?: string + /** + * 基础模型名称 + */ + model_name?: string + /** + * 供应商 + */ + provider?: string +} + +interface Model { + /** + * 主键id + */ + id: string + /** + * 模型名 + */ + name: string + /** + * 模型类型 + */ + model_type: string + user_id: string + username: string + permission_type: 'PUBLIC' | 'PRIVATE' + /** + * 基础模型 + */ + model_name: string + /** + * 认证信息 + */ + credential: any + /** + * 供应商 + */ + provider: string + /** + * 状态 + */ + status: 'SUCCESS' | 'DOWNLOAD' | 'ERROR' | 'PAUSE_DOWNLOAD' + /** + * 元数据 + */ + meta: Dict + /** + * 模型参数配置 + */ + model_params_form: Dict[] +} +interface CreateModelRequest { + /** + * 模型名 + */ + name: string + /** + * 模型类型 + */ + model_type: string + /** + * 基础模型 + */ + model_name: string + /** + * 认证信息 + */ + credential: any + /** + * 供应商 + */ + provider: string +} + +interface EditModelRequest { + /** + * 模型名 + */ + name: string + /** + * 模型类型 + */ + model_type: string + /** + * 基础模型 + */ + model_name: string + /** + * 认证信息 + */ + credential: any +} + +interface BaseModel { + /** + * 基础模型名称 + */ + name: string + /** + * 基础模型描述 + */ + desc: string + /** + * 基础模型类型 + */ + model_type: string +} +export type { + modelRequest, + Provider, + ListModelRequest, + Model, + BaseModel, + CreateModelRequest, + EditModelRequest +} diff --git a/src/MaxKB-1.7.2/ui/src/api/type/team.ts b/src/MaxKB-1.7.2/ui/src/api/type/team.ts new file mode 100644 index 0000000..9e17a84 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/type/team.ts @@ -0,0 +1,13 @@ +interface TeamMember { + id: string + username: string + email: string + team_id: string + /** + * 类型:type:manage 所有者; + */ + type: string + user_id: string +} + +export type { TeamMember } diff --git a/src/MaxKB-1.7.2/ui/src/api/type/user.ts b/src/MaxKB-1.7.2/ui/src/api/type/user.ts new file mode 100644 index 
0000000..a91a40e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/type/user.ts @@ -0,0 +1,120 @@ +interface User { + /** + * 用户id + */ + id: string + /** + * 用户名 + */ + username: string + /** + * 邮箱 + */ + email: string + /** + * 用户角色 + */ + role: string + /** + * 用户权限 + */ + permissions: Array + /** + * 是否需要修改密码 + */ + is_edit_password?: boolean + IS_XPACK?: boolean + XPACK_LICENSE_IS_VALID?: boolean +} + +interface LoginRequest { + /** + * 用户名 + */ + username: string + /** + * 密码 + */ + password: string +} + +interface RegisterRequest { + /** + * 用户名 + */ + username: string + /** + * 密码 + */ + password: string + /** + * 确定密码 + */ + re_password: string + /** + * 邮箱 + */ + email: string + /** + * 验证码 + */ + code: string +} + +interface CheckCodeRequest { + /** + * 邮箱 + */ + email: string + /** + *验证码 + */ + code: string + /** + * 类型 + */ + type: 'register' | 'reset_password' +} + +interface ResetCurrentUserPasswordRequest { + /** + * 验证码 + */ + code: string + /** + *密码 + */ + password: string + /** + * 确认密码 + */ + re_password: string +} + +interface ResetPasswordRequest { + /** + * 邮箱 + */ + email?: string + /** + * 验证码 + */ + code?: string + /** + * 密码 + */ + password: string + /** + * 确认密码 + */ + re_password: string +} + +export type { + LoginRequest, + RegisterRequest, + CheckCodeRequest, + ResetPasswordRequest, + User, + ResetCurrentUserPasswordRequest +} diff --git a/src/MaxKB-1.7.2/ui/src/api/user-manage.ts b/src/MaxKB-1.7.2/ui/src/api/user-manage.ts new file mode 100644 index 0000000..ceed082 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/user-manage.ts @@ -0,0 +1,77 @@ +import { Result } from '@/request/Result' +import { get, post, del, put } from '@/request/index' +import type { pageRequest } from '@/api/type/common' +import { type Ref } from 'vue' + +const prefix = '/user_manage' +/** + * 用户分页列表 + * @param 参数 + * page { + "current_page": "string", + "page_size": "string", + } + * @query 参数 + email_or_username: string + */ +const getUserManage: ( + page: 
pageRequest, + email_or_username: string, + loading?: Ref +) => Promise> = (page, email_or_username, loading) => { + return get( + `${prefix}/${page.current_page}/${page.page_size}`, + email_or_username ? { email_or_username } : undefined, + loading + ) +} + +/** + * 删除用户 + * @param 参数 user_id, + */ +const delUserManage: (user_id: string, loading?: Ref) => Promise> = ( + user_id, + loading +) => { + return del(`${prefix}/${user_id}`, undefined, {}, loading) +} + +/** + * 创建用户 + */ +const postUserManage: (data: any, loading?: Ref) => Promise> = ( + data, + loading +) => { + return post(`${prefix}`, data, undefined, loading) +} + +/** + * 编辑用户 + */ +const putUserManage: ( + user_id: string, + data: any, + loading?: Ref +) => Promise> = (user_id, data, loading) => { + return put(`${prefix}/${user_id}`, data, undefined, loading) +} +/** + * 修改用户密码 + */ +const putUserManagePassword: ( + user_id: string, + data: any, + loading?: Ref +) => Promise> = (user_id, data, loading) => { + return put(`${prefix}/${user_id}/re_password`, data, undefined, loading) +} + +export default { + getUserManage, + delUserManage, + postUserManage, + putUserManage, + putUserManagePassword +} diff --git a/src/MaxKB-1.7.2/ui/src/api/user.ts b/src/MaxKB-1.7.2/ui/src/api/user.ts new file mode 100644 index 0000000..b0aca1e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/api/user.ts @@ -0,0 +1,196 @@ +import { Result } from '@/request/Result' +import { get, post } from '@/request/index' +import type { + LoginRequest, + RegisterRequest, + CheckCodeRequest, + ResetPasswordRequest, + User, + ResetCurrentUserPasswordRequest +} from '@/api/type/user' +import type { Ref } from 'vue' + +/** + * 登录 + * @param auth_type + * @param request 登录接口请求表单 + * @param loading 接口加载器 + * @returns 认证数据 + */ +const login: ( + auth_type: string, + request: LoginRequest, + loading?: Ref +) => Promise> = (auth_type, request, loading) => { + if (auth_type !== '') { + return post(`/${auth_type}/login`, request, undefined, loading) 
+ } + return post('/user/login', request, undefined, loading) +} +/** + * 登出 + * @param loading 接口加载器 + * @returns + */ +const logout: (loading?: Ref) => Promise> = (loading) => { + return post('/user/logout', undefined, undefined, loading) +} + +/** + * 注册用户 + * @param request 注册请求对象 + * @param loading 接口加载器 + * @returns + */ +const register: (request: RegisterRequest, loading?: Ref) => Promise> = ( + request, + loading +) => { + return post('/user/register', request, undefined, loading) +} + +/** + * 校验验证码 + * @param request 请求对象 + * @param loading 接口加载器 + * @returns + */ +const checkCode: (request: CheckCodeRequest, loading?: Ref) => Promise> = ( + request, + loading +) => { + return post('/user/check_code', request, undefined, loading) +} + +/** + * 发送邮件 + * @param email 邮件地址 + * @param loading 接口加载器 + * @returns + */ +const sendEmit: ( + email: string, + type: 'register' | 'reset_password', + loading?: Ref +) => Promise> = (email, type, loading) => { + return post('/user/send_email', { email, type }, undefined, loading) +} +/** + * 发送邮件到当前用户 + * @param loading 发送验证码到当前用户 + * @returns + */ +const sendEmailToCurrent: (loading?: Ref) => Promise> = (loading) => { + return post('/user/current/send_email', undefined, undefined, loading) +} +/** + * 修改当前用户密码 + * @param request 请求对象 + * @param loading 加载器 + * @returns + */ +const resetCurrentUserPassword: ( + request: ResetCurrentUserPasswordRequest, + loading?: Ref +) => Promise> = (request, loading) => { + return post('/user/current/reset_password', request, undefined, loading) +} +/** + * 获取用户基本信息 + * @param loading 接口加载器 + * @returns 用户基本信息 + */ +const profile: (loading?: Ref) => Promise> = (loading) => { + return get('/user', undefined, loading) +} + +/** + * 重置密码 + * @param request 重置密码请求参数 + * @param loading 接口加载器 + * @returns + */ +const resetPassword: ( + request: ResetPasswordRequest, + loading?: Ref +) => Promise> = (request, loading) => { + return post('/user/re_password', request, undefined, loading) +} + 
+/** + * Query the user list (needed when adding team members): GET /user/list. + * @param loading optional loading Ref, forwarded as the last argument to get() + * @param email_or_username search value, passed to get() as { email_or_username } + */ +const getUserList: (email_or_username: string, loading?: Ref) => Promise> = ( + email_or_username, + loading +) => { + return get('/user/list', { email_or_username }, loading) +} + +/** + * Fetch the profile: GET /profile. + */ +const getProfile: (loading?: Ref) => Promise> = (loading) => { + return get('/profile', undefined, loading) +} + +/** + * Fetch a validation result: GET /valid/{valid_type}/{valid_count}. + * @param valid_type validation type: application|dataset|user + * @param valid_count validation count: 5 | 50 | 2 + */ +const getValid: ( + valid_type: string, + valid_count: number, + loading?: Ref +) => Promise> = (valid_type, valid_count, loading) => { + return get(`/valid/${valid_type}/${valid_count}`, undefined, loading) +} +/** + * Fetch the available login methods: GET auth/types (this path is written without a leading slash). + */ +const getAuthType: (loading?: Ref) => Promise> = (loading) => { + return get('auth/types', undefined, loading) +} + +/** + * Fetch the QR code type: GET qr_type (this path is written without a leading slash). + */ +const getQrType: (loading?: Ref) => Promise> = (loading) => { + return get('qr_type', undefined, loading) +} + +const getDingCallback: (code: string, loading?: Ref) => Promise> = ( + code, + loading +) => { + return get('dingtalk', { code }, loading) +} + +const getWecomCallback: (code: string, loading?: Ref) => Promise> = ( + code, + loading +) => { + return get('wecom', { code }, loading) +} + +export default { + login, + register, + sendEmit, + checkCode, + profile, + resetPassword, + sendEmailToCurrent, + resetCurrentUserPassword, + logout, + getUserList, + getProfile, + getValid, + getAuthType, + getDingCallback, + getQrType, + getWecomCallback +} diff --git a/src/MaxKB-1.7.2/ui/src/assets/404.png b/src/MaxKB-1.7.2/ui/src/assets/404.png new file mode 100644 index 0000000..e6ed7d3 Binary files /dev/null and b/src/MaxKB-1.7.2/ui/src/assets/404.png differ diff --git a/src/MaxKB-1.7.2/ui/src/assets/csv-icon.svg b/src/MaxKB-1.7.2/ui/src/assets/csv-icon.svg new file mode 100644 index 0000000..85147cc --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/csv-icon.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git 
a/src/MaxKB-1.7.2/ui/src/assets/display-bg1.png b/src/MaxKB-1.7.2/ui/src/assets/display-bg1.png new file mode 100644 index 0000000..dbf63be Binary files /dev/null and b/src/MaxKB-1.7.2/ui/src/assets/display-bg1.png differ diff --git a/src/MaxKB-1.7.2/ui/src/assets/display-bg2.png b/src/MaxKB-1.7.2/ui/src/assets/display-bg2.png new file mode 100644 index 0000000..606a5d9 Binary files /dev/null and b/src/MaxKB-1.7.2/ui/src/assets/display-bg2.png differ diff --git a/src/MaxKB-1.7.2/ui/src/assets/display-bg3.png b/src/MaxKB-1.7.2/ui/src/assets/display-bg3.png new file mode 100644 index 0000000..52d0f92 Binary files /dev/null and b/src/MaxKB-1.7.2/ui/src/assets/display-bg3.png differ diff --git a/src/MaxKB-1.7.2/ui/src/assets/doc-icon.svg b/src/MaxKB-1.7.2/ui/src/assets/doc-icon.svg new file mode 100644 index 0000000..899a008 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/doc-icon.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/docx-icon.svg b/src/MaxKB-1.7.2/ui/src/assets/docx-icon.svg new file mode 100644 index 0000000..899a008 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/docx-icon.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/hit-test-empty.png b/src/MaxKB-1.7.2/ui/src/assets/hit-test-empty.png new file mode 100644 index 0000000..83a2c9a Binary files /dev/null and b/src/MaxKB-1.7.2/ui/src/assets/hit-test-empty.png differ diff --git a/src/MaxKB-1.7.2/ui/src/assets/html-icon.svg b/src/MaxKB-1.7.2/ui/src/assets/html-icon.svg new file mode 100644 index 0000000..b59a488 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/html-icon.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_condition.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_condition.svg new file mode 100644 index 0000000..2bc80a2 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_condition.svg @@ -0,0 +1,3 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_document.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_document.svg 
new file mode 100644 index 0000000..34fe1b3 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_document.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_file-folder_colorful.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_file-folder_colorful.svg new file mode 100644 index 0000000..7aa4703 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_file-folder_colorful.svg @@ -0,0 +1,4 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_function_outlined.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_function_outlined.svg new file mode 100644 index 0000000..dbdef4c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_function_outlined.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_globe_color.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_globe_color.svg new file mode 100644 index 0000000..7ede591 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_globe_color.svg @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_hi.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_hi.svg new file mode 100644 index 0000000..84bb36a --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_hi.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_qr_outlined.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_qr_outlined.svg new file mode 100644 index 0000000..1d3cf43 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_qr_outlined.svg @@ -0,0 +1,3 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_reply.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_reply.svg new file mode 100644 index 0000000..430fc7f --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_reply.svg @@ -0,0 +1,3 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_reranker.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_reranker.svg new file mode 100644 index 0000000..e561122 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_reranker.svg @@ -0,0 +1,21 @@ + + + + + + + + + \ No newline at 
end of file diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_robot.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_robot.svg new file mode 100644 index 0000000..cca9ee6 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_robot.svg @@ -0,0 +1 @@ +MaxKB \ No newline at end of file diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_send.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_send.svg new file mode 100644 index 0000000..79ff642 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_send.svg @@ -0,0 +1,4 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_send_colorful.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_send_colorful.svg new file mode 100644 index 0000000..b6a1dac --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_send_colorful.svg @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_setting.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_setting.svg new file mode 100644 index 0000000..afa9736 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_setting.svg @@ -0,0 +1,3 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_start.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_start.svg new file mode 100644 index 0000000..0b8d730 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_start.svg @@ -0,0 +1,4 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/icon_web.svg b/src/MaxKB-1.7.2/ui/src/assets/icon_web.svg new file mode 100644 index 0000000..f958baf --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/icon_web.svg @@ -0,0 +1,4 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/load_error.png b/src/MaxKB-1.7.2/ui/src/assets/load_error.png new file mode 100644 index 0000000..695abfd Binary files /dev/null and b/src/MaxKB-1.7.2/ui/src/assets/load_error.png differ diff --git a/src/MaxKB-1.7.2/ui/src/assets/logo/MaxKB-logo-currentColor.svg b/src/MaxKB-1.7.2/ui/src/assets/logo/MaxKB-logo-currentColor.svg new file mode 100644 index 0000000..9428164 --- /dev/null +++ 
b/src/MaxKB-1.7.2/ui/src/assets/logo/MaxKB-logo-currentColor.svg @@ -0,0 +1,20 @@ + + + + + + + + + + + + + + + + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/logo/MaxKB-logo.svg b/src/MaxKB-1.7.2/ui/src/assets/logo/MaxKB-logo.svg new file mode 100644 index 0000000..beb86aa --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/logo/MaxKB-logo.svg @@ -0,0 +1,64 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/logo/logo-currentColor.svg b/src/MaxKB-1.7.2/ui/src/assets/logo/logo-currentColor.svg new file mode 100644 index 0000000..5f50e4c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/logo/logo-currentColor.svg @@ -0,0 +1 @@ +MaxKB \ No newline at end of file diff --git a/src/MaxKB-1.7.2/ui/src/assets/logo/logo.svg b/src/MaxKB-1.7.2/ui/src/assets/logo/logo.svg new file mode 100644 index 0000000..2e601bb --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/logo/logo.svg @@ -0,0 +1 @@ +MaxKB \ No newline at end of file diff --git a/src/MaxKB-1.7.2/ui/src/assets/logo_dingtalk.svg b/src/MaxKB-1.7.2/ui/src/assets/logo_dingtalk.svg new file mode 100644 index 0000000..64d957d --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/logo_dingtalk.svg @@ -0,0 +1,3 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/logo_lark.svg b/src/MaxKB-1.7.2/ui/src/assets/logo_lark.svg new file mode 100644 index 0000000..938c505 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/logo_lark.svg @@ -0,0 +1,12 @@ + + + + + + + + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/logo_wechat-work.svg b/src/MaxKB-1.7.2/ui/src/assets/logo_wechat-work.svg new file mode 100644 index 0000000..ea86012 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/logo_wechat-work.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/logo_wechat.svg b/src/MaxKB-1.7.2/ui/src/assets/logo_wechat.svg new file mode 100644 index 0000000..6c0e78d --- /dev/null +++ 
b/src/MaxKB-1.7.2/ui/src/assets/logo_wechat.svg @@ -0,0 +1,3 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/md-icon.svg b/src/MaxKB-1.7.2/ui/src/assets/md-icon.svg new file mode 100644 index 0000000..7b35a92 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/md-icon.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/pdf-icon.svg b/src/MaxKB-1.7.2/ui/src/assets/pdf-icon.svg new file mode 100644 index 0000000..17a4be0 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/pdf-icon.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/theme/default.jpg b/src/MaxKB-1.7.2/ui/src/assets/theme/default.jpg new file mode 100644 index 0000000..162ebe9 Binary files /dev/null and b/src/MaxKB-1.7.2/ui/src/assets/theme/default.jpg differ diff --git a/src/MaxKB-1.7.2/ui/src/assets/theme/green.jpg b/src/MaxKB-1.7.2/ui/src/assets/theme/green.jpg new file mode 100644 index 0000000..937e8a0 Binary files /dev/null and b/src/MaxKB-1.7.2/ui/src/assets/theme/green.jpg differ diff --git a/src/MaxKB-1.7.2/ui/src/assets/theme/orange.jpg b/src/MaxKB-1.7.2/ui/src/assets/theme/orange.jpg new file mode 100644 index 0000000..64b4c6a Binary files /dev/null and b/src/MaxKB-1.7.2/ui/src/assets/theme/orange.jpg differ diff --git a/src/MaxKB-1.7.2/ui/src/assets/theme/purple.jpg b/src/MaxKB-1.7.2/ui/src/assets/theme/purple.jpg new file mode 100644 index 0000000..843a425 Binary files /dev/null and b/src/MaxKB-1.7.2/ui/src/assets/theme/purple.jpg differ diff --git a/src/MaxKB-1.7.2/ui/src/assets/theme/red.jpg b/src/MaxKB-1.7.2/ui/src/assets/theme/red.jpg new file mode 100644 index 0000000..cabf84f Binary files /dev/null and b/src/MaxKB-1.7.2/ui/src/assets/theme/red.jpg differ diff --git a/src/MaxKB-1.7.2/ui/src/assets/tipIMG.jpg b/src/MaxKB-1.7.2/ui/src/assets/tipIMG.jpg new file mode 100644 index 0000000..9f6955d Binary files /dev/null and b/src/MaxKB-1.7.2/ui/src/assets/tipIMG.jpg differ diff --git a/src/MaxKB-1.7.2/ui/src/assets/txt-icon.svg 
b/src/MaxKB-1.7.2/ui/src/assets/txt-icon.svg new file mode 100644 index 0000000..051ea2b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/txt-icon.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/unknow-icon.svg b/src/MaxKB-1.7.2/ui/src/assets/unknow-icon.svg new file mode 100644 index 0000000..20270ac --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/unknow-icon.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/upload-icon.svg b/src/MaxKB-1.7.2/ui/src/assets/upload-icon.svg new file mode 100644 index 0000000..3a2466c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/upload-icon.svg @@ -0,0 +1,6 @@ + + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/user-icon.svg b/src/MaxKB-1.7.2/ui/src/assets/user-icon.svg new file mode 100644 index 0000000..5dd0f63 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/user-icon.svg @@ -0,0 +1,14 @@ + + + + + + + + + + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/window1.png b/src/MaxKB-1.7.2/ui/src/assets/window1.png new file mode 100644 index 0000000..a96b907 Binary files /dev/null and b/src/MaxKB-1.7.2/ui/src/assets/window1.png differ diff --git a/src/MaxKB-1.7.2/ui/src/assets/window2.png b/src/MaxKB-1.7.2/ui/src/assets/window2.png new file mode 100644 index 0000000..02d1ec9 Binary files /dev/null and b/src/MaxKB-1.7.2/ui/src/assets/window2.png differ diff --git a/src/MaxKB-1.7.2/ui/src/assets/xls-icon.svg b/src/MaxKB-1.7.2/ui/src/assets/xls-icon.svg new file mode 100644 index 0000000..22cb869 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/xls-icon.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/assets/xlsx-icon.svg b/src/MaxKB-1.7.2/ui/src/assets/xlsx-icon.svg new file mode 100644 index 0000000..22cb869 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/assets/xlsx-icon.svg @@ -0,0 +1,5 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/bus/index.ts b/src/MaxKB-1.7.2/ui/src/bus/index.ts new file mode 100644 index 0000000..c1ab013 --- /dev/null +++ 
b/src/MaxKB-1.7.2/ui/src/bus/index.ts @@ -0,0 +1,8 @@ +import mitt from "mitt"; +const bus: any = {}; /* plain object; the lines below copy on, off and emit from one shared mitt() emitter onto it, and it is the module's default export */ +const emitter = mitt(); +bus.on = emitter.on; +bus.off = emitter.off; +bus.emit = emitter.emit; + +export default bus; diff --git a/src/MaxKB-1.7.2/ui/src/components/ai-chat/ExecutionDetailDialog.vue b/src/MaxKB-1.7.2/ui/src/components/ai-chat/ExecutionDetailDialog.vue new file mode 100644 index 0000000..ebee3cc --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/ai-chat/ExecutionDetailDialog.vue @@ -0,0 +1,305 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/ai-chat/KnowledgeSource.vue b/src/MaxKB-1.7.2/ui/src/components/ai-chat/KnowledgeSource.vue new file mode 100644 index 0000000..1fe0ef5 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/ai-chat/KnowledgeSource.vue @@ -0,0 +1,80 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/ai-chat/LogOperationButton.vue b/src/MaxKB-1.7.2/ui/src/components/ai-chat/LogOperationButton.vue new file mode 100644 index 0000000..10dd471 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/ai-chat/LogOperationButton.vue @@ -0,0 +1,218 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/ai-chat/OperationButton.vue b/src/MaxKB-1.7.2/ui/src/components/ai-chat/OperationButton.vue new file mode 100644 index 0000000..6a12542 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/ai-chat/OperationButton.vue @@ -0,0 +1,237 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/ai-chat/ParagraphSourceDialog.vue b/src/MaxKB-1.7.2/ui/src/components/ai-chat/ParagraphSourceDialog.vue new file mode 100644 index 0000000..0f6bcf2 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/ai-chat/ParagraphSourceDialog.vue @@ -0,0 +1,84 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/ai-chat/component/ParagraphCard.vue b/src/MaxKB-1.7.2/ui/src/components/ai-chat/component/ParagraphCard.vue new file mode 100644 index 0000000..d15fee8 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/ai-chat/component/ParagraphCard.vue 
@@ -0,0 +1,58 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/ai-chat/index.vue b/src/MaxKB-1.7.2/ui/src/components/ai-chat/index.vue new file mode 100644 index 0000000..e792d6b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/ai-chat/index.vue @@ -0,0 +1,1213 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/app-avatar/index.vue b/src/MaxKB-1.7.2/ui/src/components/app-avatar/index.vue new file mode 100644 index 0000000..a35217c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/app-avatar/index.vue @@ -0,0 +1,58 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/app-charts/components/LineCharts.vue b/src/MaxKB-1.7.2/ui/src/components/app-charts/components/LineCharts.vue new file mode 100644 index 0000000..b4215c9 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/app-charts/components/LineCharts.vue @@ -0,0 +1,135 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/app-charts/index.vue b/src/MaxKB-1.7.2/ui/src/components/app-charts/index.vue new file mode 100644 index 0000000..e3b042b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/app-charts/index.vue @@ -0,0 +1,30 @@ + + diff --git a/src/MaxKB-1.7.2/ui/src/components/app-table/index.vue b/src/MaxKB-1.7.2/ui/src/components/app-table/index.vue new file mode 100644 index 0000000..1b77308 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/app-table/index.vue @@ -0,0 +1,158 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/auto-tooltip/index.vue b/src/MaxKB-1.7.2/ui/src/components/auto-tooltip/index.vue new file mode 100644 index 0000000..cc29d99 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/auto-tooltip/index.vue @@ -0,0 +1,39 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/back-button/index.vue b/src/MaxKB-1.7.2/ui/src/components/back-button/index.vue new file mode 100644 index 0000000..0a0ea69 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/back-button/index.vue @@ -0,0 +1,31 @@ + + + + + diff --git 
a/src/MaxKB-1.7.2/ui/src/components/card-add/index.vue b/src/MaxKB-1.7.2/ui/src/components/card-add/index.vue new file mode 100644 index 0000000..3e993b5 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/card-add/index.vue @@ -0,0 +1,47 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/card-box/index.vue b/src/MaxKB-1.7.2/ui/src/components/card-box/index.vue new file mode 100644 index 0000000..ea58579 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/card-box/index.vue @@ -0,0 +1,108 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/card-checkbox/index.vue b/src/MaxKB-1.7.2/ui/src/components/card-checkbox/index.vue new file mode 100644 index 0000000..e5d7282 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/card-checkbox/index.vue @@ -0,0 +1,100 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/codemirror-editor/index.vue b/src/MaxKB-1.7.2/ui/src/components/codemirror-editor/index.vue new file mode 100644 index 0000000..5e63c7c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/codemirror-editor/index.vue @@ -0,0 +1,56 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/common-list/index.vue b/src/MaxKB-1.7.2/ui/src/components/common-list/index.vue new file mode 100644 index 0000000..3a8194d --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/common-list/index.vue @@ -0,0 +1,91 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/Demo.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/Demo.vue new file mode 100644 index 0000000..dce7f8f --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/Demo.vue @@ -0,0 +1,299 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/DemoConstructor.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/DemoConstructor.vue new file mode 100644 index 0000000..b8e0b60 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/DemoConstructor.vue @@ -0,0 +1,55 @@ + + + diff --git 
a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/FormItem.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/FormItem.vue new file mode 100644 index 0000000..fbf26b1 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/FormItem.vue @@ -0,0 +1,182 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/FormItemLabel.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/FormItemLabel.vue new file mode 100644 index 0000000..b84dc1e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/FormItemLabel.vue @@ -0,0 +1,11 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/data.ts b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/data.ts new file mode 100644 index 0000000..ff100fa --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/data.ts @@ -0,0 +1,27 @@ +const input_type_list = [ /* label/value pairs, one per input type; NOTE(review): each value presumably names a dynamics-form item component (TextInput, Slider, ...) - confirm against the consumer */ + { + label: '文本框', + value: 'TextInput' + }, + { + label: '滑块', + value: 'Slider' + }, + { + label: '开关', + value: 'SwitchInput' + }, + { + label: '单选框', + value: 'SingleSelect' + }, + { + label: '日期', + value: 'DatePicker' + }, + { + label: 'JSON文本框', + value: 'JsonInput' + } +] +export { input_type_list } diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/index.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/index.vue new file mode 100644 index 0000000..8c3297b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/index.vue @@ -0,0 +1,139 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/DatePickerConstructor.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/DatePickerConstructor.vue new file mode 100644 index 0000000..e370dd6 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/DatePickerConstructor.vue @@ -0,0 +1,109 @@ + + + diff --git 
a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/JsonInputConstructor.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/JsonInputConstructor.vue new file mode 100644 index 0000000..10eb956 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/JsonInputConstructor.vue @@ -0,0 +1,65 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/SingleSelectConstructor.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/SingleSelectConstructor.vue new file mode 100644 index 0000000..4e4f2a0 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/SingleSelectConstructor.vue @@ -0,0 +1,101 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/SliderConstructor.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/SliderConstructor.vue new file mode 100644 index 0000000..12c9830 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/SliderConstructor.vue @@ -0,0 +1,143 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/SwitchInputConstructor.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/SwitchInputConstructor.vue new file mode 100644 index 0000000..92a52a0 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/SwitchInputConstructor.vue @@ -0,0 +1,43 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/TextInputConstructor.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/TextInputConstructor.vue new file mode 100644 index 0000000..553cb44 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/constructor/items/TextInputConstructor.vue @@ -0,0 +1,158 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/index.ts 
b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/index.ts new file mode 100644 index 0000000..f4d69a0 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/index.ts @@ -0,0 +1,25 @@ +/* Plugin entry for the dynamics-form components. install(app) walks every eagerly globbed .vue module, registers it with app.component() under its file base name (text after the last slash, with .vue removed), skips the constructor index.vue key, and finally registers DynamicsForm itself. */ import type { App } from 'vue' +import type { Dict } from '@/api/type/common' +import DynamicsForm from '@/components/dynamics-form/index.vue' +let components: Dict = import.meta.glob('@/components/dynamics-form/**/**.vue', { + eager: true +}) +components = { + ...components, + ...import.meta.glob('@/components/dynamics-form/**/**/**.vue', { + eager: true + }) +} + +const install = (app: App) => { + Object.keys(components).forEach((key: string) => { + const commentName: string = key + .substring(key.lastIndexOf('/') + 1, key.length) + .replace('.vue', '') + if (key !== '/src/components/dynamics-form/constructor/index.vue') { + app.component(commentName, components[key].default) + } + }) + app.component('DynamicsForm', DynamicsForm) +} +export default { install } diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/index.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/index.vue new file mode 100644 index 0000000..e199980 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/index.vue @@ -0,0 +1,206 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/DatePicker.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/DatePicker.vue new file mode 100644 index 0000000..291978c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/DatePicker.vue @@ -0,0 +1,5 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/JsonInput.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/JsonInput.vue new file mode 100644 index 0000000..4a2804b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/JsonInput.vue @@ -0,0 +1,133 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/PasswordInput.vue 
b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/PasswordInput.vue new file mode 100644 index 0000000..2111d24 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/PasswordInput.vue @@ -0,0 +1,5 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/TextInput.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/TextInput.vue new file mode 100644 index 0000000..46ca9b4 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/TextInput.vue @@ -0,0 +1,5 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/complex/ArrayObjectCard.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/complex/ArrayObjectCard.vue new file mode 100644 index 0000000..6b7ab08 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/complex/ArrayObjectCard.vue @@ -0,0 +1,156 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/complex/ObjectCard.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/complex/ObjectCard.vue new file mode 100644 index 0000000..a007b00 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/complex/ObjectCard.vue @@ -0,0 +1,75 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/complex/TabCard.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/complex/TabCard.vue new file mode 100644 index 0000000..ff5ba79 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/complex/TabCard.vue @@ -0,0 +1,123 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/label/TooltipLabel.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/label/TooltipLabel.vue new file mode 100644 index 0000000..ae8f477 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/label/TooltipLabel.vue @@ -0,0 +1,42 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/radio/Radio.vue 
b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/radio/Radio.vue new file mode 100644 index 0000000..9c94a3f --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/radio/Radio.vue @@ -0,0 +1,38 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/radio/RadioButton.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/radio/RadioButton.vue new file mode 100644 index 0000000..874d61d --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/radio/RadioButton.vue @@ -0,0 +1,38 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/radio/RadioCard.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/radio/RadioCard.vue new file mode 100644 index 0000000..492cff4 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/radio/RadioCard.vue @@ -0,0 +1,92 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/select/MultiSelect.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/select/MultiSelect.vue new file mode 100644 index 0000000..f78f898 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/select/MultiSelect.vue @@ -0,0 +1,65 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/select/SingleSelect.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/select/SingleSelect.vue new file mode 100644 index 0000000..a4d7656 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/select/SingleSelect.vue @@ -0,0 +1,77 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/slider/Slider.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/slider/Slider.vue new file mode 100644 index 0000000..3892f15 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/slider/Slider.vue @@ -0,0 +1,11 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/switch/SwitchInput.vue 
b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/switch/SwitchInput.vue new file mode 100644 index 0000000..c787945 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/switch/SwitchInput.vue @@ -0,0 +1,7 @@ + + + + \ No newline at end of file diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/table/ProgressTableItem.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/table/ProgressTableItem.vue new file mode 100644 index 0000000..baf9e3e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/table/ProgressTableItem.vue @@ -0,0 +1,70 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/table/TableCheckbox.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/table/TableCheckbox.vue new file mode 100644 index 0000000..12db31a --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/table/TableCheckbox.vue @@ -0,0 +1,214 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/table/TableColumn.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/table/TableColumn.vue new file mode 100644 index 0000000..9b6989e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/table/TableColumn.vue @@ -0,0 +1,22 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/table/TableRadio.vue b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/table/TableRadio.vue new file mode 100644 index 0000000..61c0078 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/items/table/TableRadio.vue @@ -0,0 +1,202 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/dynamics-form/type.ts b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/type.ts new file mode 100644 index 0000000..20bb29d --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/dynamics-form/type.ts @@ -0,0 +1,172 @@ +import type { Dict } from '@/api/type/common' + +interface ViewCardItem { + /** + * Item type + */ + type: 'eval' | 'default' 
+ /** + * Title + */ + title: string + /** + * Value; how it is obtained depends on the type: default = row[value_field], eval = a template such as `${parseFloat(row.number).toLocaleString("zh-CN",{style: "decimal",maximumFractionDigits:1})}%` + */ + value_field: string +} + +interface TableColumn { + /** + * Field | component name | evaluable template string + */ + property: string + /** + * Column header + */ + label: string + /** + * Field of the table row data + */ + value_field?: string + + attrs?: Attrs + /** + * Column type + */ + type: 'eval' | 'component' | 'default' + + props_info?: PropsInfo +} +interface ColorItem { + /** + * Color, e.g. #f56c6c + */ + color: string + /** + * Progress + */ + percentage: number +} +interface Attrs { + /** + * Placeholder text + */ + placeholder?: string + /** + * Width of the label, e.g. '50px'. A form-item that is a direct child of the Form inherits this value. auto can be used. + */ + labelWidth?: string + /** + * Suffix of the form field label + */ + labelSuffix?: string + /** + * Position of the required asterisk. + */ + requireAsteriskPosition?: 'left' | 'right' + + color?: Array + + [propName: string]: any +} +interface PropsInfo { + /** + * Cards for the table selection + */ + view_card?: Array + /** + * Table selection columns + */ + table_columns?: Array + /** + * Message for the selected item + */ + active_msg?: string + + /** + * Component style + */ + style?: Dict + + /** + * el-form-item style + */ + item_style?: Dict + /** + * Form validation rules; same as Element's validation + */ + rules?: Dict + /** + * Default message for the not-empty validation + */ + err_msg?: string + /** + * Used with tabs + */ + tabs_label?: string + + [propName: string]: any +} + +interface FormField { + field: string + /** + * Input type + */ + input_type: string + /** + * Label / hint + */ + label?: string | any + /** + * Whether the field is required + */ + required?: boolean + /** + * Default value + */ + default_value?: any + /** + * {field:field_value_list}: shown only when field has a value and that value is in field_value_list + */ + relation_show_field_dict?: Dict> + /** + * {field:field_value_list}: only when field has a value and that value is in field_value_list is the function executed to fetch data + */ + relation_trigger_field_dict?: Dict> + /** + * Trigger type: OPTION_LIST requests Option_list data, CHILD_FORMS requests child forms + */ + trigger_type?: 'OPTION_LIST' | 'CHILD_FORMS' + /** + * Front-end attr data + */ + attrs?: Attrs + /** + * Other extra information + */ + props_info?: PropsInfo + /** + * Select (dropdown) text field + 
*/ + text_field?: string + /** + * Select (dropdown) value + */ + value_field?: string + /** + * Select (dropdown) data + */ + option_list?: Array + /** + * Provider + */ + provider?: string + /** + * Function to execute + */ + method?: string + + children?: Array +} +export type { FormField } diff --git a/src/MaxKB-1.7.2/ui/src/components/icons/AppIcon.vue b/src/MaxKB-1.7.2/ui/src/components/icons/AppIcon.vue new file mode 100644 index 0000000..346405b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/icons/AppIcon.vue @@ -0,0 +1,33 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/icons/index.ts b/src/MaxKB-1.7.2/ui/src/components/icons/index.ts new file mode 100644 index 0000000..35e94fa --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/icons/index.ts @@ -0,0 +1,1252 @@ +import { h } from 'vue' +export const iconMap: any = { + 'app-404': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + viewBox: '0 0 1024 1024', + version: '1.1', + style: 'height:14px;width:14px', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M260.266667 789.333333c-21.333333 0-38.4-17.066667-38.4-38.4v-59.733333H38.4c-12.8 0-29.866667-8.533333-34.133333-21.333333-4.266667-17.066667-4.266667-29.866667 4.266666-42.666667l221.866667-294.4c8.533333-12.8 25.6-17.066667 42.666667-12.8 17.066667 4.266667 25.6 21.333333 25.6 38.4v256h34.133333c21.333333 0 38.4 17.066667 38.4 38.4s-17.066667 38.4-38.4 38.4H298.666667v59.733333c0 21.333333-17.066667 38.4-38.4 38.4z m-145.066667-179.2h106.666667V469.333333l-106.666667 140.8zM913.066667 742.4c-21.333333 0-38.4-17.066667-38.4-38.4v-59.733333h-183.466667c-12.8 0-29.866667-8.533333-34.133333-21.333334-8.533333-12.8-4.266667-29.866667 4.266666-38.4l221.866667-294.4c8.533333-12.8 25.6-17.066667 42.666667-12.8 17.066667 4.266667 25.6 21.333333 25.6 38.4v256h34.133333c21.333333 0 38.4 17.066667 38.4 38.4s-17.066667 38.4-38.4 38.4h-34.133333v59.733334c0 17.066667-17.066667 34.133333-38.4 34.133333zM768 567.466667h106.666667V426.666667L768 567.466667zM533.333333 
597.333333c-46.933333 0-85.333333-25.6-119.466666-68.266666-29.866667-38.4-42.666667-93.866667-42.666667-145.066667 0-55.466667 17.066667-106.666667 42.666667-145.066667 29.866667-42.666667 72.533333-68.266667 119.466666-68.266666 46.933333 0 85.333333 25.6 119.466667 68.266666 29.866667 38.4 42.666667 93.866667 42.666667 145.066667 0 55.466667-17.066667 106.666667-42.666667 145.066667-34.133333 46.933333-76.8 68.266667-119.466667 68.266666z m0-362.666666c-55.466667 0-98.133333 68.266667-98.133333 149.333333s46.933333 149.333333 98.133333 149.333333c55.466667 0 98.133333-68.266667 98.133334-149.333333s-46.933333-149.333333-98.133334-149.333333z', + fill: '#978CFF' + }), + h('path', { + d: 'M354.133333 691.2a162.133333 21.333333 0 1 0 324.266667 0 162.133333 21.333333 0 1 0-324.266667 0Z', + fill: '#E3E5FC' + }), + h('path', { + d: 'M8.533333 832a162.133333 21.333333 0 1 0 324.266667 0 162.133333 21.333333 0 1 0-324.266667 0Z', + fill: '#E3E5FC' + }), + h('path', { + d: 'M661.333333 797.866667a162.133333 21.333333 0 1 0 324.266667 0 162.133333 21.333333 0 1 0-324.266667 0Z', + fill: '#E3E5FC' + }) + ] + ) + ]) + } + }, + + 'app-add-users': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + viewBox: '0 0 20 20', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M6.24984 5.41667C6.24984 6.7975 7.37067 7.91667 8.74984 7.91667C10.129 7.91667 11.2498 6.7975 11.2498 5.41667C11.2498 4.03583 10.129 2.91667 8.74984 2.91667C7.37067 2.91667 6.24984 4.03583 6.24984 5.41667ZM8.74984 1.25C11.0498 1.25 12.9165 3.11542 12.9165 5.41667C12.9165 7.71792 11.0498 9.58333 8.74984 9.58333C6.44984 9.58333 4.58317 7.71792 4.58317 5.41667C4.58317 3.11542 6.44984 1.25 8.74984 1.25ZM3.43734 15C3.37067 15.2663 3.33317 15.5454 3.33317 15.8333V16.6667H10.854C11.0841 16.6667 11.2706 16.8532 11.2706 17.0833V17.9167C11.2706 18.1468 11.0841 18.3333 10.854 18.3333H2.49984C2.0415 18.3333 1.6665 17.9604 1.6665 17.5V15.8333C1.6665 13.0721 3.904 10.8333 6.6665 
10.8333H10.854C11.0841 10.8333 11.2706 11.0199 11.2706 11.25V12.0833C11.2706 12.3135 11.0841 12.5 10.854 12.5H6.6665C5.11234 12.5 3.80817 13.5625 3.43734 15ZM15.4165 11.6667C15.6466 11.6667 15.8332 11.8532 15.8332 12.0833V14.1667H17.9165C18.1466 14.1667 18.3332 14.3532 18.3332 14.5833V15.4167C18.3332 15.6468 18.1466 15.8333 17.9165 15.8333H15.8332V17.9167C15.8332 18.1468 15.6466 18.3333 15.4165 18.3333H14.5832C14.3531 18.3333 14.1665 18.1468 14.1665 17.9167V15.8333H12.0832C11.8531 15.8333 11.6665 15.6468 11.6665 15.4167V14.5833C11.6665 14.3532 11.8531 14.1667 12.0832 14.1667H14.1665V12.0833C14.1665 11.8532 14.3531 11.6667 14.5832 11.6667H15.4165Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + + 'app-exit': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M874.666667 855.744a19.093333 19.093333 0 0 1-19.136 18.922667H168.469333A19.2 19.2 0 0 1 149.333333 855.530667V168.469333A19.2 19.2 0 0 1 168.469333 149.333333h687.061334c10.581333 0 19.136 8.533333 19.136 18.922667V320h42.666666V168.256A61.717333 61.717333 0 0 0 855.530667 106.666667H168.469333A61.866667 61.866667 0 0 0 106.666667 168.469333v687.061334A61.866667 61.866667 0 0 0 168.469333 917.333333h687.061334A61.76 61.76 0 0 0 917.333333 855.744V704h-42.666666v151.744zM851.84 533.333333l-131.797333 131.754667a21.141333 21.141333 0 0 0 0.213333 29.973333 21.141333 21.141333 0 0 0 29.973333 0.192l165.589334-165.589333a20.821333 20.821333 0 0 0 6.122666-14.976 21.44 21.44 0 0 0-6.314666-14.997333l-168.533334-168.533334a21.141333 21.141333 0 0 0-29.952-0.213333 21.141333 21.141333 0 0 0 0.213334 29.973333L847.296 490.666667H469.333333v42.666666h382.506667z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + + 'app-team': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + viewBox: '0 0 20 20', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 
'M7.08317 4.16667C5.93317 4.16667 4.99984 5.09958 4.99984 6.25C4.99984 7.40042 5.93317 8.33333 7.08317 8.33333C8.23317 8.33333 9.1665 7.40042 9.1665 6.25C9.1665 5.09958 8.23317 4.16667 7.08317 4.16667ZM3.33317 6.25C3.33317 4.17875 5.01234 2.5 7.08317 2.5C9.154 2.5 10.8332 4.17875 10.8332 6.25C10.8332 8.32125 9.154 10 7.08317 10C5.01234 10 3.33317 8.32125 3.33317 6.25ZM4.86234 12.5C3.34567 12.5 2.08317 13.7488 2.08317 15.3333V16.6667H12.0832V15.3333C12.0832 13.7488 10.8207 12.5 9.304 12.5H4.86234ZM0.416504 15.3333C0.416504 12.8479 2.40817 10.8333 4.86234 10.8333H9.304C11.7582 10.8333 13.7498 12.8479 13.7498 15.3333V17.5833C13.7498 17.9975 13.4165 18.3333 13.0082 18.3333H1.15817C0.749837 18.3333 0.416504 17.9975 0.416504 17.5833V15.3333ZM19.029 17.5H15.304V17.1592V15.8333H17.9165V15.3333C17.9165 14.4983 17.2123 13.75 16.2498 13.75H15.1582C14.9998 13.1342 14.7165 12.5692 14.3373 12.0833H16.2498C18.0915 12.0833 19.5832 13.5383 19.5832 15.3333V16.9583C19.5832 17.2575 19.3332 17.5 19.029 17.5ZM13.7498 8.33333C13.7498 7.87292 14.1248 7.5 14.5832 7.5C15.0415 7.5 15.4165 7.87292 15.4165 8.33333C15.4165 8.79375 15.0415 9.16667 14.5832 9.16667C14.1248 9.16667 13.7498 8.79375 13.7498 8.33333ZM14.5832 5.83333C13.204 5.83333 12.0832 6.9525 12.0832 8.33333C12.0832 9.71417 13.204 10.8333 14.5832 10.8333C15.9623 10.8333 17.0832 9.71417 17.0832 8.33333C17.0832 6.9525 15.9623 5.83333 14.5832 5.83333Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-team-active': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + viewBox: '0 0 20 20', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M7.08317 10C9.15424 10 10.8332 8.32107 10.8332 6.25C10.8332 4.17893 9.15424 2.5 7.08317 2.5C5.0121 2.5 3.33317 4.17893 3.33317 6.25C3.33317 8.32107 5.0121 10 7.08317 10Z', + fill: 'currentColor' + }), + h('path', { + d: 'M1.24984 18.3333C0.7896 18.3333 0.416504 17.9602 0.416504 17.5V15.8889C0.416504 13.0968 2.76035 10.8333 5.47333 
10.8333H8.70065C11.4136 10.8333 13.7498 13.0968 13.7498 15.8889V17.5C13.7498 17.9602 13.3767 18.3333 12.9165 18.3333H1.24984Z', + fill: 'currentColor' + }), + h('path', { + d: 'M15.4165 17.5V17.2535C15.4165 15.3267 15.4165 13.3333 13.7498 12.0833C13.8196 12.0773 13.9366 12.0794 14.0491 12.0814C14.1036 12.0824 14.157 12.0833 14.2034 12.0833H15.8332C17.8679 12.0833 19.5832 13.3643 19.5832 15.4583V16.875C19.5832 17.2202 19.3033 17.5 18.9582 17.5H15.4165Z', + fill: 'currentColor' + }), + h('path', { + d: 'M14.5832 10.8333C15.9639 10.8333 17.0832 9.71405 17.0832 8.33333C17.0832 6.95262 15.9639 5.83333 14.5832 5.83333C13.2025 5.83333 12.0832 6.95262 12.0832 8.33333C12.0832 9.71405 13.2025 10.8333 14.5832 10.8333Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-template': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + viewBox: '0 0 20 20', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M0.833496 6.36987C0.833496 6.04308 1.0245 5.74646 1.32199 5.61123L9.65533 1.82335C9.87443 1.72376 10.1259 1.72376 10.345 1.82335L18.6783 5.61123C18.9758 5.74646 19.1668 6.04308 19.1668 6.36987V14.4849C19.1668 14.8006 18.9885 15.0891 18.7062 15.2303L10.3728 19.3969C10.1382 19.5142 9.86209 19.5142 9.62748 19.3969L1.29415 15.2303C1.01183 15.0891 0.833496 14.8006 0.833496 14.4849V6.36987ZM16.015 6.2314L10.0002 3.49737L3.95283 6.24616L9.9668 8.83681L16.015 6.2314ZM10.8335 10.2782V17.3032L17.5002 13.9699V7.40638L10.8335 10.2782ZM2.50016 7.43512V13.9699L9.16683 17.3032V10.3069L2.50016 7.43512Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-template-active': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + viewBox: '0 0 20 20', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M0.833496 7.40037C0.833496 7.04018 1.20284 6.79815 1.53309 6.94194L8.77383 10.0945C8.95625 10.1739 9.07423 10.3539 9.07423 10.5529V18.339C9.07423 18.7151 8.67454 18.9566 8.34159 18.7816L1.1439 
14.9982C1.05114 14.9474 0.973682 14.8737 0.919239 14.7847C0.864795 14.6956 0.835262 14.5943 0.833578 14.4907L0.833496 14.4814V7.40037ZM18.4646 6.9322C18.7952 6.78604 19.1668 7.02807 19.1668 7.38949V14.4814C19.1668 14.5866 19.1381 14.6899 19.0835 14.7807C19.029 14.8715 18.9506 14.9466 18.8564 14.9982L11.6587 18.7816C11.3258 18.9566 10.9261 18.7151 10.9261 18.339L10.9261 10.5912C10.9261 10.3932 11.0429 10.2139 11.2239 10.1339L18.4646 6.9322ZM9.70006 1.74337C9.79165 1.69312 9.89502 1.66672 10.0002 1.66672C10.1053 1.66672 10.2087 1.69312 10.3003 1.74337L17.1982 4.80724C17.5964 4.9841 17.5936 5.55021 17.1937 5.72313L10.1986 8.74754C10.072 8.80229 9.92836 8.80229 9.80173 8.74754L2.80663 5.72313C2.4067 5.55021 2.40389 4.9841 2.80209 4.80724L9.70006 1.74337Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-setting': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + viewBox: '0 0 20 20', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M3.60734 16.4448L3.34807 16.1624C2.44036 15.1737 1.75935 13.9944 1.36011 12.7066L1.24756 12.3435L2.95427 10.0001L1.24756 7.65668L1.36011 7.29361C1.75935 6.00574 2.44036 4.82649 3.34807 3.83779L3.60734 3.55539L6.47552 3.86889L7.64049 1.21319L8.01405 1.12909C8.66134 0.983366 9.32633 0.90918 10.0004 0.90918C10.6744 0.90918 11.3394 0.983366 11.9867 1.12909L12.3603 1.21319L13.5252 3.86889L16.3934 3.55539L16.6527 3.83779C17.5604 4.82649 18.2414 6.00574 18.6406 7.29361L18.7532 7.65668L17.0465 10.0001L18.7532 12.3435L18.6406 12.7066C18.2414 13.9944 17.5604 15.1737 16.6527 16.1624L16.3934 16.4448L13.5252 16.1313L12.3603 18.787L11.9867 18.8711C11.3394 19.0168 10.6744 19.091 10.0004 19.091C9.32633 19.091 8.66134 19.0168 8.01405 18.8711L7.64049 18.787L6.47552 16.1313L3.60734 16.4448ZM6.51159 14.6031C7.05002 14.5443 7.56436 14.8417 7.78194 15.3377L8.71565 17.4662C9.13677 17.5389 9.56603 17.5758 10.0004 17.5758C10.4347 17.5758 10.864 17.5389 11.2851 17.4662L12.2188 15.3377C12.4364 14.8417 12.9507 
14.5443 13.4892 14.6031L15.7844 14.854C16.3387 14.1868 16.7757 13.4286 17.0741 12.6116L15.7038 10.7301C15.3869 10.295 15.3869 9.70511 15.7038 9.26999L17.0741 7.38847C16.7757 6.57146 16.3387 5.81331 15.7844 5.14609L13.4892 5.39696C12.9507 5.45581 12.4364 5.1584 12.2188 4.66238L11.2851 2.53389C10.864 2.46117 10.4347 2.42429 10.0004 2.42429C9.56603 2.42429 9.13677 2.46117 8.71565 2.53389L7.78194 4.66238C7.56436 5.1584 7.05002 5.45581 6.51159 5.39696L4.21641 5.14609C3.66208 5.81331 3.22502 6.57146 2.92666 7.38847L4.29697 9.26999C4.61387 9.70511 4.61387 10.295 4.29697 10.7301L2.92666 12.6116C3.22502 13.4286 3.66208 14.1868 4.21641 14.854L6.51159 14.6031ZM10.0004 13.788C7.91555 13.788 6.22693 12.0913 6.22693 10.0001C6.22693 7.9089 7.91555 6.2122 10.0004 6.2122C12.0852 6.2122 13.7738 7.9089 13.7738 10.0001C13.7738 12.0913 12.0852 13.788 10.0004 13.788ZM10.0004 12.2729C11.2468 12.2729 12.2587 11.2561 12.2587 10.0001C12.2587 8.74413 11.2468 7.72741 10.0004 7.72741C8.75397 7.72741 7.74208 8.74413 7.74208 10.0001C7.74208 11.2561 8.75397 12.2729 10.0004 12.2729Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-setting-active': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + viewBox: '0 0 20 20', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M3.26425 16.2151C2.35478 15.2292 1.65887 14.0432 1.25 12.7305L2.70785 10.7384C3.02952 10.2988 3.02952 9.70154 2.70785 9.26197L1.25 7.26979C1.65887 5.95714 2.35478 4.77112 3.26425 3.78522L5.71416 4.05172C6.25589 4.11065 6.77338 3.81185 6.99316 3.31321L7.98848 1.05505C8.63579 0.910018 9.30896 0.833496 10 0.833496C10.691 0.833496 11.3642 0.910018 12.0115 1.05505L13.0068 3.31321C13.2266 3.81185 13.7441 4.11065 14.2858 4.05172L16.7357 3.78522C17.6452 4.77112 18.3411 5.95714 18.75 7.26979L17.2921 9.26197C16.9705 9.70154 16.9705 10.2988 17.2921 10.7384L18.75 12.7305C18.3411 14.0432 17.6452 15.2292 16.7357 16.2151L14.2858 15.9486C13.7441 15.8897 13.2266 16.1885 13.0068 16.6871L12.0115 
18.9453C11.3642 19.0903 10.691 19.1668 10 19.1668C9.30896 19.1668 8.63579 19.0903 7.98848 18.9453L6.99316 16.6871C6.77338 16.1885 6.25589 15.8897 5.71416 15.9486L3.26425 16.2151ZM10 13.3335C11.8409 13.3335 13.3333 11.8411 13.3333 10.0002C13.3333 8.15921 11.8409 6.66683 10 6.66683C8.15905 6.66683 6.66667 8.15921 6.66667 10.0002C6.66667 11.8411 8.15905 13.3335 10 13.3335Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-document': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + viewBox: '0 0 20 20', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M13.3333 2.50016H4.16667V17.5002H15.8333V5.01641H13.75C13.6395 5.01641 13.5335 4.97251 13.4554 4.89437C13.3772 4.81623 13.3333 4.71025 13.3333 4.59975V2.50016ZM3.33333 0.833496H14.2379C14.3474 0.833465 14.4558 0.855013 14.557 0.896908C14.6582 0.938804 14.7501 1.00023 14.8275 1.07766L17.2563 3.50725C17.4124 3.66356 17.5001 3.87548 17.5 4.09641V18.3335C17.5 18.5545 17.4122 18.7665 17.2559 18.9228C17.0996 19.079 16.8877 19.1668 16.6667 19.1668H3.33333C3.11232 19.1668 2.90036 19.079 2.74408 18.9228C2.5878 18.7665 2.5 18.5545 2.5 18.3335V1.66683C2.5 1.44582 2.5878 1.23385 2.74408 1.07757C2.90036 0.921293 3.11232 0.833496 3.33333 0.833496ZM6.66667 8.3335H13.3333C13.4438 8.3335 13.5498 8.3774 13.628 8.45554C13.7061 8.53368 13.75 8.63966 13.75 8.75016V9.5835C13.75 9.694 13.7061 9.79998 13.628 9.87812C13.5498 9.95626 13.4438 10.0002 13.3333 10.0002H6.66667C6.55616 10.0002 6.45018 9.95626 6.37204 9.87812C6.2939 9.79998 6.25 9.694 6.25 9.5835V8.75016C6.25 8.63966 6.2939 8.53368 6.37204 8.45554C6.45018 8.3774 6.55616 8.3335 6.66667 8.3335ZM6.66667 12.5002H10.4167C10.4714 12.5002 10.5256 12.5109 10.5761 12.5319C10.6267 12.5528 10.6726 12.5835 10.7113 12.6222C10.75 12.6609 10.7807 12.7068 10.8016 12.7574C10.8226 12.8079 10.8333 12.8621 10.8333 12.9168V13.7502C10.8333 13.8049 10.8226 13.8591 10.8016 13.9096C10.7807 13.9602 10.75 14.0061 10.7113 14.0448C10.6726 14.0835 10.6267 
14.1142 10.5761 14.1351C10.5256 14.1561 10.4714 14.1668 10.4167 14.1668H6.66667C6.55616 14.1668 6.45018 14.1229 6.37204 14.0448C6.2939 13.9667 6.25 13.8607 6.25 13.7502V12.9168C6.25 12.8063 6.2939 12.7003 6.37204 12.6222C6.45018 12.5441 6.55616 12.5002 6.66667 12.5002Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-document-active': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + viewBox: '0 0 20 20', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M3.3335 2.08333C3.3335 1.6231 3.70659 1.25 4.16683 1.25H12.3842C12.4959 1.25 12.603 1.29489 12.6813 1.37459L16.5473 5.30784C16.6239 5.38576 16.6668 5.49065 16.6668 5.59992V17.9167C16.6668 18.3769 16.2937 18.75 15.8335 18.75H4.16683C3.70659 18.75 3.3335 18.3769 3.3335 17.9167V2.08333Z', + fill: 'currentColor' + }), + h('path', { + d: 'M12.5 1.2666C12.568 1.28633 12.6306 1.32327 12.6812 1.37472L16.5472 5.30797C16.5788 5.34017 16.6047 5.37698 16.6242 5.4168H13.4459C12.9235 5.4168 12.5 4.99328 12.5 4.47085V1.2666Z', + fill: '#2B5FD9' + }), + h('path', { + d: 'M6.71305 7.72705C6.48293 7.72705 6.29639 7.9136 6.29639 8.14372V8.82554C6.29639 9.05565 6.48294 9.2422 6.71305 9.2422H13.2871C13.5172 9.2422 13.7038 9.05565 13.7038 8.82554V8.14372C13.7038 7.9136 13.5172 7.72705 13.2871 7.72705H6.71305Z', + fill: 'white' + }), + h('path', { + d: 'M6.71305 11.5149C6.48293 11.5149 6.29639 11.7015 6.29639 11.9316V12.6134C6.29639 12.8435 6.48294 13.0301 6.71305 13.0301H9.58342C9.81354 13.0301 10.0001 12.8435 10.0001 12.6134V11.9316C10.0001 11.7015 9.81354 11.5149 9.58342 11.5149H6.71305Z', + fill: 'white' + }) + ] + ) + ]) + } + }, + 'app-view': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + viewBox: '0 0 16 12', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M6.9649 8.5176L10.8075 6.59629C10.9365 6.53178 11.0412 6.42717 11.1057 6.29815C11.2703 5.96883 11.1368 5.56838 10.8075 5.40372L6.9649 3.48241C6.87233 3.43612 6.77025 
3.41203 6.66675 3.41203C6.29856 3.41203 6.00009 3.71051 6.00009 4.07869V7.92132C6.00009 8.02481 6.02418 8.12689 6.07047 8.21946C6.23513 8.54878 6.63558 8.68226 6.9649 8.5176Z', + fill: 'currentColor' + }), + h('path', { + d: 'M15.3334 0.75C15.3334 0.335786 15.0349 0 14.6667 0H1.33341C0.965225 0 0.666748 0.335786 0.666748 0.75V11.25C0.666748 11.6642 0.965225 12 1.33341 12H14.6667C15.0349 12 15.3334 11.6642 15.3334 11.25V0.75ZM2.00008 1.5H14.0001V10.5H2.00008V1.5Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-all-menu': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 20 20', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M2.91683 2.0835H8.3335C8.79373 2.0835 9.16683 2.45659 9.16683 2.91683V8.3335C9.16683 8.79373 8.79373 9.16683 8.3335 9.16683H2.91683C2.45659 9.16683 2.0835 8.79373 2.0835 8.3335V2.91683C2.0835 2.45659 2.45659 2.0835 2.91683 2.0835ZM3.75016 3.75016V7.50016H7.50016V3.75016H3.75016Z', + fill: 'currentColor' + }), + h('path', { + d: 'M2.91683 10.8335H8.3335C8.79373 10.8335 9.16683 11.2066 9.16683 11.6668V17.0835C9.16683 17.5437 8.79373 17.9168 8.3335 17.9168H2.91683C2.45659 17.9168 2.0835 17.5437 2.0835 17.0835V11.6668C2.0835 11.2066 2.45659 10.8335 2.91683 10.8335ZM3.75016 16.2502H7.50016V12.5002H3.75016V16.2502Z', + fill: 'currentColor' + }), + h('path', { + d: 'M11.6668 2.0835H17.0835C17.5437 2.0835 17.9168 2.45659 17.9168 2.91683V8.3335C17.9168 8.79373 17.5437 9.16683 17.0835 9.16683H11.6668C11.2066 9.16683 10.8335 8.79373 10.8335 8.3335V2.91683C10.8335 2.45659 11.2066 2.0835 11.6668 2.0835ZM12.5002 7.50016H16.2502V3.75016H12.5002V7.50016Z', + fill: 'currentColor' + }), + h('path', { + d: 'M11.6668 10.8335H17.0835C17.5437 10.8335 17.9168 11.2066 17.9168 11.6668V17.0835C17.9168 17.5437 17.5437 17.9168 17.0835 17.9168H11.6668C11.2066 17.9168 10.8335 17.5437 10.8335 17.0835V11.6668C10.8335 11.2066 11.2066 10.8335 11.6668 
10.8335ZM12.5002 12.5002V16.2502H16.2502V12.5002H12.5002Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-all-menu-active': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 20 20', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M8.33317 1.6665H2.49984C2.0396 1.6665 1.6665 2.0396 1.6665 2.49984V8.33317C1.6665 8.79341 2.0396 9.1665 2.49984 9.1665H8.33317C8.79341 9.1665 9.1665 8.79341 9.1665 8.33317V2.49984C9.1665 2.0396 8.79341 1.6665 8.33317 1.6665Z', + fill: 'currentColor' + }), + h('path', { + d: 'M8.33317 10.8332H2.49984C2.0396 10.8332 1.6665 11.2063 1.6665 11.6665V17.4998C1.6665 17.9601 2.0396 18.3332 2.49984 18.3332H8.33317C8.79341 18.3332 9.1665 17.9601 9.1665 17.4998V11.6665C9.1665 11.2063 8.79341 10.8332 8.33317 10.8332Z', + fill: 'currentColor' + }), + h('path', { + d: 'M17.4998 1.6665H11.6665C11.2063 1.6665 10.8332 2.0396 10.8332 2.49984V8.33317C10.8332 8.79341 11.2063 9.1665 11.6665 9.1665H17.4998C17.9601 9.1665 18.3332 8.79341 18.3332 8.33317V2.49984C18.3332 2.0396 17.9601 1.6665 17.4998 1.6665Z', + fill: 'currentColor' + }), + h('path', { + d: 'M17.4508 10.8332H11.7155C11.2282 10.8332 10.8332 11.2282 10.8332 11.7155V17.4508C10.8332 17.9381 11.2282 18.3332 11.7155 18.3332H17.4508C17.9381 18.3332 18.3332 17.9381 18.3332 17.4508V11.7155C18.3332 11.2282 17.9381 10.8332 17.4508 10.8332Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-restore': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 16 16', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M3.33333 5.3335V13.3335H10V5.3335H3.33333ZM11.3333 4.66683V14.0742C11.3333 14.4015 11.0548 14.6668 10.7111 14.6668H2.62222C2.27858 14.6668 2 14.4015 2 14.0742V4.59276C2 4.26548 2.27858 4.00016 2.62222 4.00016H10.6667C11.0349 4.00016 11.3333 4.29864 11.3333 4.66683ZM13.8047 
1.52876C13.9254 1.6494 14 1.81607 14 2.00016V10.3335C14 10.5176 13.8508 10.6668 13.6667 10.6668H13C12.8159 10.6668 12.6667 10.5176 12.6667 10.3335V2.66683H6.33333C6.14924 2.66683 6 2.51759 6 2.3335V1.66683C6 1.48273 6.14924 1.3335 6.33333 1.3335H13.3333C13.5174 1.3335 13.6841 1.40812 13.8047 1.52876Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-copy': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M213.333333 341.333333v512h426.666667V341.333333H213.333333z m512-42.666666v602.069333c0 20.949333-17.834667 37.930667-39.808 37.930667H167.808C145.834667 938.666667 128 921.685333 128 900.736V293.973333C128 272.981333 145.834667 256 167.808 256H682.666667a42.666667 42.666667 0 0 1 42.666666 42.666667z m158.165334-200.832A42.538667 42.538667 0 0 1 896 128v533.333333a21.333333 21.333333 0 0 1-21.333333 21.333334h-42.666667a21.333333 21.333333 0 0 1-21.333333-21.333334V170.666667H405.333333a21.333333 21.333333 0 0 1-21.333333-21.333334v-42.666666a21.333333 21.333333 0 0 1 21.333333-21.333334H853.333333c11.776 0 22.442667 4.778667 30.165334 12.501334z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-like': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 16 16', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M2.00518 14.6608H0.666612C0.666097 14.6874 0.666707 5.33317 0.666612 5.29087H2.00518C2.00004 5.33317 1.98014 14.6874 2.00518 14.6608ZM9.70096 5.28984H12.5717C14.5687 5.28984 15.0274 7.05264 14.5687 8.37353L12.5717 13.6308C12.4029 14.2423 11.8409 14.6665 11.1995 14.6665H3.33882C3.154 14.6665 3.00418 14.5167 3.00418 14.3319V5.62448C3.00418 5.43966 3.154 5.28984 3.33882 5.28984H4.02656C4.24449 5.28984 4.44877 5.18374 4.5741 5.00545L7.35254 1.05296C7.5406 0.753754 8.04824 
0.52438 8.5893 0.770777C9.40089 1.14037 10.3724 1.94718 10.3724 3.28394C10.3724 3.78809 10.1486 4.45673 9.70096 5.28984ZM12.5717 6.62841H7.46215L8.52183 4.65626C8.87422 4.00045 9.03388 3.52351 9.03388 3.28394C9.03388 2.89556 8.9524 2.45627 8.25544 2.09612L5.26934 6.34402C5.14401 6.5223 4.93973 6.62841 4.72181 6.62841H4.34275V13.3279H11.1995C11.2411 13.3279 11.2734 13.3035 11.2813 13.2747L11.298 13.2142L13.3098 7.91815C13.5743 7.13902 13.3105 6.62841 12.5717 6.62841Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-like-color': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 16 16', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M2.00497 14.6608H2.00518C2.00511 14.6609 2.00504 14.6609 2.00497 14.6608H0.666612C0.666097 14.6874 0.666707 5.33317 0.666612 5.29087H2.00518C2.00006 5.33305 1.98026 14.6344 2.00497 14.6608Z', + fill: '#FFC60A' + }), + h('path', { + d: 'M12.5717 5.28984H9.70096C10.1486 4.45673 10.3724 3.78809 10.3724 3.28394C10.3724 1.94718 9.40089 1.14037 8.5893 0.770777C8.04824 0.52438 7.5406 0.753754 7.35254 1.05296L4.5741 5.00545C4.44877 5.18374 4.24449 5.28984 4.02656 5.28984H3.33882C3.154 5.28984 3.00418 5.43966 3.00418 5.62448V14.3319C3.00418 14.5167 3.154 14.6665 3.33882 14.6665H11.1995C11.8409 14.6665 12.4029 14.2423 12.5717 13.6308L14.5687 8.37353C15.0274 7.05264 14.5687 5.28984 12.5717 5.28984Z', + fill: '#FFC60A' + }) + ] + ) + ]) + } + }, + 'app-oppose': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 16 16', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M2.00518 1.28008H0.666616C0.666616 1.33341 0.666504 10.6667 0.666616 10.65H2.00518C1.99984 10.6667 1.99984 1.33341 2.00518 1.28008ZM9.70097 10.6511H12.5717C14.5687 10.6511 15.0274 8.88828 14.5687 7.56739L12.5717 2.3101C12.4029 1.69862 11.8409 1.27441 11.1996 
1.27441H3.33883C3.15401 1.27441 3.00418 1.42424 3.00418 1.60906V10.3164C3.00418 10.5013 3.15401 10.6511 3.33883 10.6511H4.02656C4.24449 10.6511 4.44877 10.7572 4.5741 10.9355L7.35254 14.888C7.5406 15.1872 8.04825 15.4165 8.58931 15.1701C9.40089 14.8005 10.3724 13.9937 10.3724 12.657C10.3724 12.1528 10.1486 11.4842 9.70097 10.6511ZM12.5717 9.31251H7.46216L8.52184 11.2847C8.87422 11.9405 9.03388 12.4174 9.03388 12.657C9.03388 13.0454 8.95241 13.4846 8.25545 13.8448L5.26935 9.5969C5.14402 9.41861 4.93974 9.31251 4.72181 9.31251H4.34275V2.61298H11.1996C11.2411 2.61298 11.2734 2.63737 11.2813 2.6662L11.298 2.72673L13.3098 8.02277C13.5743 8.8019 13.3105 9.31251 12.5717 9.31251Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-oppose-color': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 16 16', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M9.70106 10.7102H12.5718C14.5688 10.7102 15.0275 8.94736 14.5688 7.62647L12.5718 2.36918C12.403 1.7577 11.841 1.3335 11.1996 1.3335H3.33891C3.1541 1.3335 3.00427 1.48332 3.00427 1.66814V10.3755C3.00427 10.5603 3.1541 10.7102 3.33891 10.7102H4.02665C4.24458 10.7102 4.44886 10.8163 4.57419 10.9945L7.35263 14.947C7.54069 15.2462 8.04834 15.4756 8.58939 15.2292C9.40098 14.8596 10.3725 14.0528 10.3725 12.7161C10.3725 12.2119 10.1487 11.5433 9.70106 10.7102Z', + fill: '#F54A45' + }), + h('path', { + d: 'M2.00004 1.3335H0.661473C0.661473 1.3335 0.660982 10.7764 0.661473 10.7035H2.00001C1.99469 10.6868 1.9947 1.38674 2.00004 1.3335Z', + fill: '#F54A45' + }) + ] + ) + ]) + } + }, + 'app-hit-test': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 20 20', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + + [ + h('path', { + d: 'M1.6665 9.99986C1.6665 5.3975 5.39748 1.66653 9.99984 1.66653H10.8332V3.3332H9.99984C6.31795 3.3332 3.33317 6.31797 
3.33317 9.99986C3.33317 13.6818 6.31795 16.6665 9.99984 16.6665C13.6817 16.6665 16.6665 13.6818 16.6665 9.99986V9.16653H18.3332V9.99986C18.3332 14.6022 14.6022 18.3332 9.99984 18.3332C5.39748 18.3332 1.6665 14.6022 1.6665 9.99986Z', + fill: 'currentColor', + fillRule: 'evenodd', + clipRule: 'evenodd' + }), + h('path', { + d: 'M5.4165 9.99986C5.4165 7.46854 7.46852 5.41653 9.99984 5.41653H10.8332V7.0832H9.99984C8.38899 7.0832 7.08317 8.38902 7.08317 9.99986C7.08317 11.6107 8.38899 12.9165 9.99984 12.9165C11.6107 12.9165 12.9165 11.6107 12.9165 9.99986V9.16653H14.5832V9.99986C14.5832 12.5312 12.5312 14.5832 9.99984 14.5832C7.46852 14.5832 5.4165 12.5312 5.4165 9.99986Z', + fill: 'currentColor', + fillRule: 'evenodd', + clipRule: 'evenodd' + }), + h('path', { + d: 'M13.2138 6.78296C13.5394 7.10825 13.5397 7.63588 13.2144 7.96147L10.5894 10.5889C10.2641 10.9145 9.73644 10.9147 9.41085 10.5894C9.08527 10.2641 9.08502 9.73651 9.41031 9.41092L12.0353 6.7835C12.3606 6.45792 12.8882 6.45767 13.2138 6.78296Z', + fill: 'currentColor', + fillRule: 'evenodd', + clipRule: 'evenodd' + }), + h('path', { + d: 'M15.1942 1.72962C15.506 1.8584 15.7095 2.16249 15.7095 2.49986V4.29161H17.4998C17.8365 4.29161 18.1401 4.49423 18.2693 4.80516C18.3985 5.11608 18.3279 5.47421 18.0904 5.71284L15.8508 7.96276C15.6944 8.11987 15.4819 8.2082 15.2602 8.2082H12.6248C12.1645 8.2082 11.7914 7.8351 11.7914 7.37486V4.76086C11.7914 4.54046 11.8787 4.32904 12.0342 4.17287L14.2856 1.91186C14.5237 1.6728 14.8824 1.60085 15.1942 1.72962ZM13.4581 5.105V6.54153H14.9139L15.4945 5.95828H14.8761C14.4159 5.95828 14.0428 5.58518 14.0428 5.12495V4.51779L13.4581 5.105Z', + fill: 'currentColor', + fillRule: 'evenodd', + clipRule: 'evenodd' + }) + ] + ) + ]) + } + }, + 'app-warning': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M512 
234.666667A53.333333 53.333333 0 1 1 512 341.333333a53.333333 53.333333 0 0 1 0-106.666666zM522.666667 384h-64a21.333333 21.333333 0 0 0-21.333334 21.333333v42.666667a21.333333 21.333333 0 0 0 21.333334 21.333333h21.333333v213.333334H426.666667a21.333333 21.333333 0 0 0-21.333334 21.333333v42.666667a21.333333 21.333333 0 0 0 21.333334 21.333333h192a21.333333 21.333333 0 0 0 21.333333-21.333333v-42.666667a21.333333 21.333333 0 0 0-21.333333-21.333333h-53.333334v-256a42.666667 42.666667 0 0 0-42.666666-42.666667z', + fill: 'currentColor' + }), + h('path', { + d: 'M512 981.333333C252.8 981.333333 42.666667 771.2 42.666667 512S252.8 42.666667 512 42.666667s469.333333 210.133333 469.333333 469.333333-210.133333 469.333333-469.333333 469.333333z m0-85.333333a384 384 0 1 0 0-768 384 384 0 0 0 0 768z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-warning-colorful': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M42.666667 512c0 259.2 210.133333 469.333333 469.333333 469.333333s469.333333-210.133333 469.333333-469.333333S771.2 42.666667 512 42.666667 42.666667 252.8 42.666667 512z m469.333333-277.333333A53.333333 53.333333 0 1 1 512 341.333333a53.333333 53.333333 0 0 1 0-106.666666zM458.666667 384h64a42.666667 42.666667 0 0 1 42.666666 42.666667v256h53.333334a21.333333 21.333333 0 0 1 21.333333 21.333333v42.666667a21.333333 21.333333 0 0 1-21.333333 21.333333H426.666667a21.333333 21.333333 0 0 1-21.333334-21.333333v-42.666667a21.333333 21.333333 0 0 1 21.333334-21.333333h53.333333v-213.333334h-21.333333a21.333333 21.333333 0 0 1-21.333334-21.333333v-42.666667a21.333333 21.333333 0 0 1 21.333334-21.333333z', + fill: '#3370FF' + }) + ] + ) + ]) + } + }, + 'app-operation': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 16 16', 
+ version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M11.0002 11.3333H14.2395C14.3881 11.3333 14.442 11.3487 14.4963 11.3778C14.5506 11.4068 14.5933 11.4495 14.6223 11.5038C14.6514 11.5581 14.6668 11.612 14.6668 11.7606V12.2393C14.6668 12.3879 14.6514 12.4417 14.6223 12.4961C14.5933 12.5504 14.5506 12.593 14.4963 12.6221C14.442 12.6511 14.3881 12.6666 14.2395 12.6666H11.0002V14.2393C11.0002 14.3879 10.9847 14.4417 10.9556 14.4961C10.9266 14.5504 10.884 14.593 10.8296 14.6221C10.7753 14.6511 10.7214 14.6666 10.5728 14.6666H10.0941C9.94556 14.6666 9.89168 14.6511 9.83736 14.6221C9.78304 14.593 9.7404 14.5504 9.71135 14.4961C9.6823 14.4417 9.66683 14.3879 9.66683 14.2393V12.6666H1.76081C1.61222 12.6666 1.55834 12.6511 1.50402 12.6221C1.4497 12.593 1.40707 12.5504 1.37802 12.4961C1.34897 12.4417 1.3335 12.3879 1.3335 12.2393V11.7606C1.3335 11.612 1.34897 11.5581 1.37802 11.5038C1.40707 11.4495 1.4497 11.4068 1.50402 11.3778C1.55834 11.3487 1.61222 11.3333 1.76081 11.3333H9.66683V9.76057C9.66683 9.61198 9.6823 9.5581 9.71135 9.50378C9.7404 9.44946 9.78304 9.40683 9.83736 9.37778C9.89168 9.34872 9.94556 9.33325 10.0941 9.33325H10.5728C10.7214 9.33325 10.7753 9.34872 10.8296 9.37778C10.884 9.40683 10.9266 9.44946 10.9556 9.50378C10.9847 9.5581 11.0002 9.61198 11.0002 9.76057V11.3333ZM5.00016 3.33325V1.76057C5.00016 1.61198 5.01563 1.5581 5.04469 1.50378C5.07374 1.44946 5.11637 1.40683 5.17069 1.37777C5.22501 1.34872 5.27889 1.33325 5.42748 1.33325H5.90618C6.05477 1.33325 6.10865 1.34872 6.16297 1.37777C6.21729 1.40683 6.25992 1.44946 6.28897 1.50378C6.31803 1.5581 6.3335 1.61198 6.3335 1.76057V3.33325H14.2395C14.3881 3.33325 14.442 3.34872 14.4963 3.37777C14.5506 3.40683 14.5933 3.44946 14.6223 3.50378C14.6514 3.5581 14.6668 3.61198 14.6668 3.76057V4.23927C14.6668 4.38786 14.6514 4.44174 14.6223 4.49606C14.5933 4.55038 14.5506 4.59301 14.4963 4.62206C14.442 4.65111 14.3881 4.66659 14.2395 4.66659H6.3335V6.23927C6.3335 6.38786 6.31803 
6.44174 6.28897 6.49606C6.25992 6.55038 6.21729 6.59301 6.16297 6.62206C6.10865 6.65111 6.05477 6.66659 5.90618 6.66659H5.42748C5.27889 6.66659 5.22501 6.65111 5.17069 6.62206C5.11637 6.59301 5.07374 6.55038 5.04469 6.49606C5.01563 6.44174 5.00016 6.38786 5.00016 6.23927V4.66659H1.76081C1.61222 4.66659 1.55834 4.65111 1.50402 4.62206C1.4497 4.59301 1.40707 4.55038 1.37802 4.49606C1.34897 4.44174 1.3335 4.38786 1.3335 4.23927V3.76057C1.3335 3.61198 1.34897 3.5581 1.37802 3.50378C1.40707 3.44946 1.4497 3.40683 1.50402 3.37777C1.55834 3.34872 1.61222 3.33325 1.76081 3.33325H5.00016Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-reading': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M768 128H256a85.333333 85.333333 0 0 0-85.333333 85.333333v426.666667h512V64h85.333333v640a21.333333 21.333333 0 0 1-21.333333 21.333333H256a85.333333 85.333333 0 0 0-0.128 170.666667H832a21.333333 21.333333 0 0 0 21.333333-21.333333V341.333333h85.333334v597.333334a42.666667 42.666667 0 0 1-42.666667 42.666666H256c-94.293333 0-170.666667-76.16-170.666667-170.410666V213.248C85.333333 119.04 161.706667 42.666667 256 42.666667h469.333333a42.666667 42.666667 0 0 1 42.666667 42.666666v42.666667z', + fill: 'currentColor' + }), + h('path', { + d: 'M277.333333 768a21.333333 21.333333 0 0 0-21.333333 21.333333v42.666667a21.333333 21.333333 0 0 0 21.333333 21.333333h469.333334a21.333333 21.333333 0 0 0 21.333333-21.333333v-42.666667a21.333333 21.333333 0 0 0-21.333333-21.333333h-469.333334z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-github': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M511.6 76.3C264.3 76.2 64 276.4 64 523.5 64 718.9 
189.3 885 363.8 946c23.5 5.9 19.9-10.8 19.9-22.2v-77.5c-135.7 15.9-141.2-73.9-150.3-88.9C215 726 171.5 718 184.5 703c30.9-15.9 62.4 4 98.9 57.9 26.4 39.1 77.9 32.5 104 26 5.7-23.5 17.9-44.5 34.7-60.8-140.6-25.2-199.2-111-199.2-213 0-49.5 16.3-95 48.3-131.7-20.4-60.5 1.9-112.3 4.9-120 58.1-5.2 118.5 41.6 123.2 45.3 33-8.9 70.7-13.6 112.9-13.6 42.4 0 80.2 4.9 113.5 13.9 11.3-8.6 67.3-48.8 121.3-43.9 2.9 7.7 24.7 58.3 5.5 118 32.4 36.8 48.9 82.7 48.9 132.3 0 102.2-59 188.1-200 212.9 23.5 23.2 38.1 55.4 38.1 91v112.5c0.8 9 0 17.9 15 17.9 177.1-59.7 304.6-227 304.6-424.1 0-247.2-200.4-447.3-447.5-447.3z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-help': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M512 896a384 384 0 1 0 0-768 384 384 0 0 0 0 768z m0 85.333333C252.8 981.333333 42.666667 771.2 42.666667 512S252.8 42.666667 512 42.666667s469.333333 210.133333 469.333333 469.333333-210.133333 469.333333-469.333333 469.333333z m-21.333333-298.666666h42.666666a21.333333 21.333333 0 0 1 21.333334 21.333333v42.666667a21.333333 21.333333 0 0 1-21.333334 21.333333h-42.666666a21.333333 21.333333 0 0 1-21.333334-21.333333v-42.666667a21.333333 21.333333 0 0 1 21.333334-21.333333zM343.466667 396.032c0.554667-4.778667 1.109333-8.746667 1.664-11.946667 8.32-46.293333 29.397333-80.341333 63.189333-102.144 26.453333-17.28 59.008-25.941333 97.621333-25.941333 50.730667 0 92.842667 12.288 126.378667 36.864 33.578667 24.533333 50.346667 60.928 50.346667 109.141333 0 29.568-7.253333 54.485333-21.888 74.752-8.533333 12.245333-24.917333 27.946667-49.152 47.061334l-23.893334 18.773333c-13.013333 10.24-21.632 22.186667-25.898666 35.84-1.152 3.712-2.176 10.624-3.072 20.736a21.333333 21.333333 0 0 1-21.248 19.498667h-47.786667a21.333333 21.333333 0 0 1-21.248-23.296c2.773333-29.696 5.717333-48.469333 
8.832-56.362667 5.845333-14.677333 20.906667-31.573333 45.141333-50.688l24.533334-19.413333c8.106667-6.144 49.749333-35.456 49.749333-61.44 0-25.941333-4.522667-35.498667-17.578667-49.749334-13.013333-14.208-42.368-18.773333-68.864-18.773333-26.026667 0-48.256 6.869333-59.136 24.405333-5.034667 8.106667-9.173333 16.768-12.117333 25.6a89.472 89.472 0 0 0-3.114667 13.098667 21.333333 21.333333 0 0 1-21.034666 17.706667H364.672a21.333333 21.333333 0 0 1-21.205333-23.722667z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-pricing': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M723.584 128c24.618667 0 48.213333 10.112 64.768 28.074667l170.965333 185.472c28.288 30.72 29.226667 76.373333 4.48 106.837333l-403.797333 457.685333a64 64 0 0 1-96 0l-397.824-450.986666a82.090667 82.090667 0 0 1-1.493333-113.493334l170.965333-185.514666C252.16 138.154667 275.754667 128 300.373333 128h423.168z m0 85.333333H300.373333c-1.024 0-1.834667 0.341333-2.048 0.597334L129.152 397.482667 512 831.488l382.848-433.92-169.216-183.637333a2.304 2.304 0 0 0-1.109333-0.512L723.584 213.333333z m-12.586667 202.794667a42.666667 42.666667 0 0 1 3.541334 60.202667l-170.666667 192a42.666667 42.666667 0 0 1-63.744 0l-170.666667-192a42.666667 42.666667 0 1 1 63.744-56.661334L512 575.744l138.794667-156.074667a42.666667 42.666667 0 0 1 60.202666-3.541333z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-translate': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + xmlns: 'http://www.w3.org/2000/svg', + viewBox: '0 0 20 20', + fill: 'currentColor' + }, + [ + h('path', { + d: 'M7.75 2.75a.75.75 0 0 0-1.5 0v1.258a32.987 32.987 0 0 0-3.599.278.75.75 0 1 0 .198 1.487A31.545 31.545 0 0 1 8.7 5.545 19.381 19.381 0 0 1 7 9.56a19.418 19.418 0 0 1-1.002-2.05.75.75 0 0 0-1.384.577 20.935 20.935 0 0 0 1.492 2.91 19.613 
19.613 0 0 1-3.828 4.154.75.75 0 1 0 .945 1.164A21.116 21.116 0 0 0 7 12.331c.095.132.192.262.29.391a.75.75 0 0 0 1.194-.91c-.204-.266-.4-.538-.59-.815a20.888 20.888 0 0 0 2.333-5.332c.31.031.618.068.924.108a.75.75 0 0 0 .198-1.487 32.832 32.832 0 0 0-3.599-.278V2.75Z' + }), + h('path', { + 'fill-rule': 'evenodd', + d: 'M13 8a.75.75 0 0 1 .671.415l4.25 8.5a.75.75 0 1 1-1.342.67L15.787 16h-5.573l-.793 1.585a.75.75 0 1 1-1.342-.67l4.25-8.5A.75.75 0 0 1 13 8Zm2.037 6.5L13 10.427 10.964 14.5h4.073Z', + 'clip-rule': 'evenodd' + }) + ] + ) + ]) + } + }, + 'app-user': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 24 24', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M15 13H9C6.23858 13 3 14.9314 3 18.4V21.1C3 21.597 3.44772 22 4 22H20C20.5523 22 21 21.597 21 21.1V18.4C21 14.9285 17.7614 13 15 13Z', + fill: 'currentColor' + }), + h('path', { + d: 'M7 6.99997C7 9.76139 9.23858 12 12 12C14.7614 12 17 9.76139 17 6.99997C17 4.23855 14.7614 1.99997 12 1.99997C9.23858 1.99997 7 4.23855 7 6.99997Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-question': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 24 24', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M12.7071 22.2009L17 18.5111H21.5C22.0523 18.5111 22.5 18.0539 22.5 17.4899V2.52112C22.5 1.95715 22.0523 1.49997 21.5 1.49997H2C1.44772 1.49997 1 1.95715 1 2.52112V17.4899C1 18.0539 1.44772 18.5111 2 18.5111H7L11.2929 22.2009C11.6834 22.5997 12.3166 22.5997 12.7071 22.2009ZM6.5 8.49997H7.5C8.05228 8.49997 8.5 8.94768 8.5 9.49997V10.5C8.5 11.0523 8.05228 11.5 7.5 11.5H6.5C5.94772 11.5 5.5 11.0523 5.5 10.5V9.49997C5.5 8.94768 5.94772 8.49997 6.5 8.49997ZM10.5 9.49997C10.5 8.94768 10.9477 8.49997 11.5 8.49997H12.5C13.0523 8.49997 13.5 8.94768 13.5 9.49997V10.5C13.5 11.0523 13.0523 11.5 12.5 
11.5H11.5C10.9477 11.5 10.5 11.0523 10.5 10.5V9.49997ZM16.5 8.49997H17.5C18.0523 8.49997 18.5 8.94768 18.5 9.49997V10.5C18.5 11.0523 18.0523 11.5 17.5 11.5H16.5C15.9477 11.5 15.5 11.0523 15.5 10.5V9.49997C15.5 8.94768 15.9477 8.49997 16.5 8.49997Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-tokens': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 24 24', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M15.6 2.39996C12.288 2.39996 9.60002 5.08796 9.60002 8.39996C9.60002 9.11996 9.74402 9.79196 9.97202 10.428L2.47325 17.9267C2.42636 17.9736 2.40002 18.0372 2.40002 18.1035V21.1C2.40002 21.3761 2.62388 21.6 2.90002 21.6H4.30002C4.57617 21.6 4.80002 21.3761 4.80002 21.1V20.4H6.70003C6.97617 20.4 7.20002 20.1761 7.20002 19.9V18H8.40002L10.8 15.6H12L13.572 14.028C14.208 14.256 14.88 14.4 15.6 14.4C18.912 14.4 21.6 11.712 21.6 8.39996C21.6 5.08796 18.912 2.39996 15.6 2.39996ZM17.4 8.39996C16.404 8.39996 15.6 7.59596 15.6 6.59996C15.6 5.60396 16.404 4.79996 17.4 4.79996C18.396 4.79996 19.2 5.60396 19.2 6.59996C19.2 7.59596 18.396 8.39996 17.4 8.39996Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-user-stars': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 24 24', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M12 23C18.0751 23 23 18.0751 23 12C23 5.92484 18.0751 0.999969 12 0.999969C5.92487 0.999969 1 5.92484 1 12C1 18.0751 5.92487 23 12 23ZM8.5 10.5C7.67157 10.5 7 9.8284 7 8.99997C7 8.17154 7.67157 7.49997 8.5 7.49997C9.32843 7.49997 10 8.17154 10 8.99997C10 9.8284 9.32843 10.5 8.5 10.5ZM17 8.99997C17 9.8284 16.3284 10.5 15.5 10.5C14.6716 10.5 14 9.8284 14 8.99997C14 8.17154 14.6716 7.49997 15.5 7.49997C16.3284 7.49997 17 8.17154 17 8.99997ZM16.9779 13.4994C16.7521 16.0264 14.8169 18 12 18C9.18312 18 7.24789 16.0264 7.02213 
13.4994C6.99756 13.2244 7.22386 13 7.5 13H16.5C16.7761 13 17.0024 13.2244 16.9779 13.4994Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-problems': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M512 896a384 384 0 1 0 0-768 384 384 0 0 0 0 768z m0 85.333333C252.8 981.333333 42.666667 771.2 42.666667 512S252.8 42.666667 512 42.666667s469.333333 210.133333 469.333333 469.333333-210.133333 469.333333-469.333333 469.333333z m-21.333333-298.666666h42.666666a21.333333 21.333333 0 0 1 21.333334 21.333333v42.666667a21.333333 21.333333 0 0 1-21.333334 21.333333h-42.666666a21.333333 21.333333 0 0 1-21.333334-21.333333v-42.666667a21.333333 21.333333 0 0 1 21.333334-21.333333zM343.466667 396.032c0.554667-4.778667 1.109333-8.746667 1.664-11.946667 8.32-46.293333 29.397333-80.341333 63.189333-102.144 26.453333-17.28 59.008-25.941333 97.621333-25.941333 50.730667 0 92.842667 12.288 126.378667 36.864 33.578667 24.533333 50.346667 60.928 50.346667 109.141333 0 29.568-7.253333 54.485333-21.888 74.752-8.533333 12.245333-24.917333 27.946667-49.152 47.061334l-23.893334 18.773333c-13.013333 10.24-21.632 22.186667-25.898666 35.84-1.152 3.712-2.176 10.624-3.072 20.736a21.333333 21.333333 0 0 1-21.248 19.498667h-47.786667a21.333333 21.333333 0 0 1-21.248-23.296c2.773333-29.696 5.717333-48.469333 8.832-56.362667 5.845333-14.677333 20.906667-31.573333 45.141333-50.688l24.533334-19.413333c8.106667-6.144 49.749333-35.456 49.749333-61.44 0-25.941333-4.522667-35.498667-17.578667-49.749334-13.013333-14.208-42.368-18.773333-68.864-18.773333-26.026667 0-48.256 6.869333-59.136 24.405333-5.034667 8.106667-9.173333 16.768-12.117333 25.6a89.472 89.472 0 0 0-3.114667 13.098667 21.333333 21.333333 0 0 1-21.034666 17.706667H364.672a21.333333 21.333333 0 0 1-21.205333-23.722667z', + fill: 'currentColor' + }) + ] + ) + ]) + } + 
}, + 'app-quxiaoguanlian': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M544 298.688a32 32 0 0 1 32-32h320c41.216 0 74.688 33.408 74.688 74.624V640c0 41.216-33.472 74.688-74.688 74.688h-85.312a32 32 0 1 1 0-64H896a10.688 10.688 0 0 0 10.688-10.688V341.312A10.688 10.688 0 0 0 896 330.688H576a32 32 0 0 1-32-32zM53.312 341.312c0-41.216 33.472-74.624 74.688-74.624h106.688a32 32 0 1 1 0 64H128a10.688 10.688 0 0 0-10.688 10.624V640c0 5.888 4.8 10.688 10.688 10.688h320a32 32 0 1 1 0 64H128A74.688 74.688 0 0 1 53.312 640V341.312zM282.432 100.416a32 32 0 0 1 43.84 11.392l426.624 725.312a32 32 0 0 1-55.168 32.448L271.104 144.256a32 32 0 0 1 11.328-43.84zM650.688 490.688a32 32 0 0 1 32-32H768a32 32 0 1 1 0 64h-85.312a32 32 0 0 1-32-32zM224 490.688a32 32 0 0 1 32-32h85.312a32 32 0 1 1 0 64H256a32 32 0 0 1-32-32z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-migrate': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M896.128 113.792a42.666667 42.666667 0 0 1 42.24 36.864l0.426667 5.802667v711.509333a42.666667 42.666667 0 0 1-36.906667 42.24l-5.76 0.426667h-263.082667a21.333333 21.333333 0 0 1-20.906666-17.066667l-0.426667-4.266667v-42.666666a21.333333 21.333333 0 0 1 17.066667-20.906667l4.266666-0.426667h220.416V199.125333H281.941333l0.042667 192.170667a21.333333 21.333333 0 0 1-21.333333 21.333333h-42.666667a21.333333 21.333333 0 0 1-21.333333-21.333333V135.125333a21.333333 21.333333 0 0 1 17.066666-20.906666l4.266667-0.426667h678.144zM424.96 485.973333c6.272 0 12.373333 2.218667 17.152 6.272l178.858667 151.338667a26.538667 26.538667 0 0 1 0 40.533333l-178.858667 151.381334a26.538667 26.538667 0 0 
1-43.690667-20.266667v-103.765333H135.168a21.333333 21.333333 0 0 1-21.333333-21.333334v-42.666666a21.333333 21.333333 0 0 1 21.333333-21.333334H398.506667l-0.042667-113.621333c0-14.677333 11.904-26.538667 26.538667-26.538667z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-chat-record': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 16 16', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M11.3333 7.33334C11.3333 6.96515 11.6318 6.66667 12 6.66667H14.6667C15.0349 6.66667 15.3333 6.96515 15.3333 7.33334V12.6667C15.3333 13.0349 15.0349 13.3333 14.6667 13.3333H13.2761L12.4714 14.1381C12.2111 14.3984 11.7889 14.3984 11.5286 14.1381L10.7239 13.3333H7.33334C6.96515 13.3333 6.66667 13.0349 6.66667 12.6667V10C6.66667 9.63182 6.96515 9.33334 7.33334 9.33334H11.3333V7.33334ZM12.6667 8.00001V10C12.6667 10.3682 12.3682 10.6667 12 10.6667H8.00001V12H11C11.1768 12 11.3464 12.0702 11.4714 12.1953L12 12.7239L12.5286 12.1953C12.6536 12.0702 12.8232 12 13 12H14V8.00001H12.6667Z', + fill: 'currentColor' + }), + h('path', { + d: 'M1.33334 1.33333C0.965149 1.33333 0.666672 1.63181 0.666672 1.99999V10C0.666672 10.3682 0.965149 10.6667 1.33334 10.6667H2.72386L3.86193 11.8047C4.12228 12.0651 4.54439 12.0651 4.80474 11.8047L5.94281 10.6667H12C12.3682 10.6667 12.6667 10.3682 12.6667 10V1.99999C12.6667 1.63181 12.3682 1.33333 12 1.33333H1.33334ZM4.66667 5.99999C4.66667 6.36818 4.36819 6.66666 4.00001 6.66666C3.63182 6.66666 3.33334 6.36818 3.33334 5.99999C3.33334 5.6318 3.63182 5.33333 4.00001 5.33333C4.36819 5.33333 4.66667 5.6318 4.66667 5.99999ZM7.33334 5.99999C7.33334 6.36818 7.03486 6.66666 6.66667 6.66666C6.29848 6.66666 6 6.36818 6 5.99999C6 5.6318 6.29848 5.33333 6.66667 5.33333C7.03486 5.33333 7.33334 5.6318 7.33334 5.99999ZM10 5.99999C10 6.36818 9.70153 6.66666 9.33334 6.66666C8.96515 6.66666 8.66667 6.36818 8.66667 5.99999C8.66667 5.6318 8.96515 5.33333 9.33334 
5.33333C9.70153 5.33333 10 5.6318 10 5.99999Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-history-outlined': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 20 20', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M18.6667 10.0001C18.6667 14.6025 14.9358 18.3334 10.3334 18.3334C7.68359 18.3334 5.32266 17.0967 3.79633 15.1689L5.12054 14.1563C6.3421 15.6864 8.22325 16.6667 10.3334 16.6667C14.0153 16.6667 17 13.682 17 10.0001C17 6.31818 14.0153 3.33341 10.3334 3.33341C7.03005 3.33341 4.28786 5.73596 3.75889 8.88897H4.3469C4.70187 8.88897 4.9136 9.28459 4.7167 9.57995L3.32493 11.6676C3.14901 11.9315 2.76125 11.9315 2.58533 11.6676L1.19356 9.57995C0.996651 9.28459 1.20838 8.88897 1.56336 8.88897H2.07347C2.61669 4.8119 6.10774 1.66675 10.3334 1.66675C14.9358 1.66675 18.6667 5.39771 18.6667 10.0001Z', + fill: 'currentColor' + }), + h('path', { + d: 'M10.8334 9.7223V7.11119C10.8334 6.86573 10.6344 6.66675 10.3889 6.66675H9.61115C9.36569 6.66675 9.16671 6.86573 9.16671 7.11119V10.9445C9.16671 11.19 9.36569 11.389 9.61115 11.389H13.1667C13.4122 11.389 13.6112 11.19 13.6112 10.9445V10.1667C13.6112 9.92129 13.4122 9.7223 13.1667 9.7223H10.8334Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-document-refresh': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M512 170.666667a85.333333 85.333333 0 0 1 85.333333-85.333334h256a85.333333 85.333333 0 0 1 85.333334 85.333334v256a85.333333 85.333333 0 0 1-85.333334 85.333333h-256a85.333333 85.333333 0 0 1-85.333333-85.333333V170.666667z m85.333333 0v256h256V170.666667h-256zM85.333333 597.333333a85.333333 85.333333 0 0 1 85.333334-85.333333h256a85.333333 85.333333 0 0 1 85.333333 85.333333v256a85.333333 85.333333 0 0 1-85.333333 
85.333334H170.666667a85.333333 85.333333 0 0 1-85.333334-85.333334v-256z m85.333334 0v256h256v-256H170.666667zM128 298.666667a213.333333 213.333333 0 0 1 213.333333-213.333334h85.333334v85.333334H341.333333a128 128 0 0 0-128 128h57.514667a12.8 12.8 0 0 1 9.728 21.12l-100.181333 116.906666a12.8 12.8 0 0 1-19.456 0l-100.181334-116.906666A12.8 12.8 0 0 1 70.485333 298.666667H128zM896 725.333333a213.333333 213.333333 0 0 1-213.333333 213.333334h-85.333334v-85.333334h85.333334a128 128 0 0 0 128-128v-21.333333h-57.514667a12.8 12.8 0 0 1-9.728-21.12l100.181333-116.906667a12.8 12.8 0 0 1 19.456 0l100.181334 116.906667a12.8 12.8 0 0 1-9.728 21.12H896v21.333333z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-export': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M791.04 554.24l-386.432-1.728a21.248 21.248 0 0 1-21.12-21.248L383.36 490.88c-0.064-11.776 9.408-21.376 21.12-21.44h0.192l394.112 1.728-97.664-98.112a21.44 21.44 0 0 1 0-30.208l30.08-30.144a21.12 21.12 0 0 1 29.952 0l165.12 165.952a42.88 42.88 0 0 1 0 60.288l-165.12 165.952a21.12 21.12 0 0 1-30.016 0l-30.016-30.144a21.44 21.44 0 0 1 0-30.208L791.04 554.24z m-132.672-383.552H170.24v682.624h488.128c11.712 0 21.184 9.6 21.184 21.376v42.624a21.248 21.248 0 0 1-21.248 21.376h-530.56A42.56 42.56 0 0 1 85.376 896V128c0-23.552 19.008-42.688 42.496-42.688h530.56c11.712 0 21.184 9.6 21.184 21.376v42.624a21.248 21.248 0 0 1-21.248 21.376z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-fitview': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M128 85.333333h192a21.333333 21.333333 0 0 1 21.333333 21.333334v42.666666a21.333333 21.333333 0 0 1-21.333333 
21.333334H170.666667v149.333333a21.333333 21.333333 0 0 1-21.333334 21.333333h-42.666666a21.333333 21.333333 0 0 1-21.333334-21.333333V128a42.666667 42.666667 0 0 1 42.666667-42.666667z m768 853.333334h-192a21.333333 21.333333 0 0 1-21.333333-21.333334v-42.666666a21.333333 21.333333 0 0 1 21.333333-21.333334H853.333333v-149.333333a21.333333 21.333333 0 0 1 21.333334-21.333333h42.666666a21.333333 21.333333 0 0 1 21.333334 21.333333V896a42.666667 42.666667 0 0 1-42.666667 42.666667zM85.333333 896v-192a21.333333 21.333333 0 0 1 21.333334-21.333333h42.666666a21.333333 21.333333 0 0 1 21.333334 21.333333V853.333333h149.333333a21.333333 21.333333 0 0 1 21.333333 21.333334v42.666666a21.333333 21.333333 0 0 1-21.333333 21.333334H128a42.666667 42.666667 0 0 1-42.666667-42.666667zM938.666667 128v192a21.333333 21.333333 0 0 1-21.333334 21.333333h-42.666666a21.333333 21.333333 0 0 1-21.333334-21.333333V170.666667h-149.333333a21.333333 21.333333 0 0 1-21.333333-21.333334v-42.666666a21.333333 21.333333 0 0 1 21.333333-21.333334H896a42.666667 42.666667 0 0 1 42.666667 42.666667z', + fill: 'currentColor' + }), + h('path', { + d: 'M512 512m-170.666667 0a170.666667 170.666667 0 1 0 341.333334 0 170.666667 170.666667 0 1 0-341.333334 0Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-magnify': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M366.165333 593.749333a21.333333 21.333333 0 0 1 30.208 0l30.165334 30.165334a21.333333 21.333333 0 0 1 0 30.208l-170.752 170.666666H377.173333a21.333333 21.333333 0 0 1 21.333334 21.333334v42.666666a21.333333 21.333333 0 0 1-21.333334 21.333334H156.458667a42.538667 42.538667 0 0 1-42.666667-42.666667v-220.16a21.333333 21.333333 0 0 1 21.333333-21.333333h42.666667a21.333333 21.333333 0 0 1 21.333333 21.333333v113.493333l167.04-167.04z m500.992-480a42.538667 42.538667 0 0 
1 42.666667 42.666667v220.16a21.333333 21.333333 0 0 1-21.333333 21.333333h-42.666667a21.333333 21.333333 0 0 1-21.333333-21.333333v-113.493333l-167.04 167.04a21.333333 21.333333 0 0 1-30.165334 0l-30.165333-30.165334a21.333333 21.333333 0 0 1 0-30.165333l170.709333-170.666667h-121.344a21.333333 21.333333 0 0 1-21.333333-21.333333v-42.666667a21.333333 21.333333 0 0 1 21.333333-21.333333h220.672z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-minify': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M384.341333 597.205333a42.538667 42.538667 0 0 1 42.666667 42.666667v220.16a21.333333 21.333333 0 0 1-21.333333 21.333333h-42.666667a21.333333 21.333333 0 0 1-21.333333-21.333333v-113.493333l-167.04 167.04a21.333333 21.333333 0 0 1-30.165334 0l-30.165333-30.208a21.333333 21.333333 0 0 1 0-30.165334l170.709333-170.666666H163.669333a21.333333 21.333333 0 0 1-21.333333-21.333334v-42.666666a21.333333 21.333333 0 0 1 21.333333-21.333334h220.672zM849.92 110.506667a21.333333 21.333333 0 0 1 30.165333 0l30.165334 30.165333a21.333333 21.333333 0 0 1 0 30.165333l-170.709334 170.666667h121.344a21.333333 21.333333 0 0 1 21.333334 21.333333v42.666667a21.333333 21.333333 0 0 1-21.333334 21.333333h-220.672a42.538667 42.538667 0 0 1-42.666666-42.666666v-220.16a21.333333 21.333333 0 0 1 21.333333-21.333334h42.666667a21.333333 21.333333 0 0 1 21.333333 21.333334v113.493333l167.04-166.997333z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-play-outlined': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 14 14', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M2.63333 1.82346C2.81847 1.72056 3.04484 1.72611 3.22472 1.83795L10.8081 6.55299C10.9793 6.65945 11.0834 6.84677 11.0834 7.04838C11.0834 
7.24999 10.9793 7.43731 10.8081 7.54376L3.22472 12.2588C3.04484 12.3707 2.81847 12.3762 2.63333 12.2733C2.44819 12.1704 2.33337 11.9752 2.33337 11.7634V2.33333C2.33337 2.12152 2.44819 1.92635 2.63333 1.82346ZM3.50004 3.38293V10.7138L9.39529 7.04838L3.50004 3.38293Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-save-outlined': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 14 14', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M1.16666 2.53734C1.16666 1.78025 1.7804 1.1665 2.53749 1.1665H11.4625C12.2196 1.1665 12.8333 1.78025 12.8333 2.53734V11.4623C12.8333 12.2194 12.2196 12.8332 11.4625 12.8332H2.53749C1.7804 12.8332 1.16666 12.2194 1.16666 11.4623V2.53734ZM2.53749 2.33317C2.42473 2.33317 2.33332 2.42458 2.33332 2.53734V11.4623C2.33332 11.5751 2.42473 11.6665 2.53749 11.6665H11.4625C11.5753 11.6665 11.6667 11.5751 11.6667 11.4623V2.53734C11.6667 2.42457 11.5753 2.33317 11.4625 2.33317H2.53749Z', + fill: 'currentColor' + }), + h('path', { + d: 'M3.79166 1.74984C3.79166 1.42767 4.05282 1.1665 4.37499 1.1665H9.33332C9.65549 1.1665 9.91666 1.42767 9.91666 1.74984V6.99984C9.91666 7.322 9.65549 7.58317 9.33332 7.58317H4.37499C4.05282 7.58317 3.79166 7.322 3.79166 6.99984V1.74984ZM4.95832 2.33317V6.4165H8.74999V2.33317H4.95832Z', + fill: 'currentColor' + }), + h('path', { + d: 'M7.58333 3.2085C7.9055 3.2085 8.16667 3.46966 8.16667 3.79183V4.9585C8.16667 5.28066 7.9055 5.54183 7.58333 5.54183C7.26117 5.54183 7 5.28066 7 4.9585V3.79183C7 3.46966 7.26117 3.2085 7.58333 3.2085Z', + fill: 'currentColor' + }), + h('path', { + d: 'M2.62415 1.74984C2.62415 1.42767 2.88531 1.1665 3.20748 1.1665H10.4996C10.8217 1.1665 11.0829 1.42767 11.0829 1.74984C11.0829 2.072 10.8217 2.33317 10.4996 2.33317H3.20748C2.88531 2.33317 2.62415 2.072 2.62415 1.74984Z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-reference-outlined': { + iconReader: () => { + return 
h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M121.216 714.368c-7.082667-17.493333-7.466667-83.413333-7.424-104.32 0.341333-142.72 34.048-256.426667 88.32-330.112C262.4 198.229333 351.701333 161.024 460.8 172.8c7.893333 0.853333 11.946667 7.338667 10.581333 16.981333l-7.381333 51.285334c-1.749333 12.202667-9.813333 12.885333-17.621333 12.202666-138.709333-11.946667-232.576 84.053333-245.76 296.704a165.632 165.632 0 0 1 83.754666-22.528c91.050667 0 164.906667 72.96 164.906667 162.944C449.28 780.373333 375.466667 853.333333 284.373333 853.333333c-82.858667 0-151.424-60.330667-163.157333-138.965333z m438.570667 0c-7.082667-17.493333-7.509333-83.413333-7.466667-104.32 0.426667-142.72 34.090667-256.426667 88.405333-330.112 60.202667-81.706667 149.504-118.912 258.645334-107.136 7.893333 0.853333 11.946667 7.338667 10.581333 16.981333l-7.381333 51.285334c-1.749333 12.202667-9.813333 12.885333-17.621334 12.202666-138.752-11.946667-232.576 84.053333-245.76 296.704a165.632 165.632 0 0 1 83.712-22.528c91.093333 0 164.906667 72.96 164.906667 162.944 0 90.026667-73.813333 162.944-164.906667 162.944-82.773333 0-151.381333-60.330667-163.114666-138.965333z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-access': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M490.368 48.554667a42.666667 42.666667 0 0 1 43.264 0l362.666667 213.333333A42.666667 42.666667 0 0 1 917.333333 298.666667v426.666666a42.666667 42.666667 0 0 1-21.034666 36.778667l-362.666667 213.333333a42.666667 42.666667 0 0 1-43.264 0l-362.666667-213.333333A42.666667 42.666667 0 0 1 106.666667 725.333333V298.666667a42.666667 42.666667 0 0 1 21.034666-36.778667l362.666667-213.333333zM192 323.072v377.856L512 
889.173333l320-188.245333V323.072L512 134.826667 192 323.072z', + fill: 'currentColor' + }), + h('path', { + d: 'M705.194667 441.472a42.666667 42.666667 0 1 0-45.226667-72.362667l-148.096 92.586667L363.946667 369.066667a42.666667 42.666667 0 1 0-45.312 72.362666L469.333333 535.722667V704a42.666667 42.666667 0 1 0 85.333334 0v-168.448l150.528-94.08z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-access-active': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M533.632 48.554667a42.666667 42.666667 0 0 0-43.264 0l-362.666667 213.333333A42.666667 42.666667 0 0 0 106.666667 298.666667v426.666666a42.666667 42.666667 0 0 0 21.034666 36.778667l362.666667 213.333333a42.666667 42.666667 0 0 0 43.264 0l362.666667-213.333333A42.666667 42.666667 0 0 0 917.333333 725.333333V298.666667a42.666667 42.666667 0 0 0-21.034666-36.778667l-362.666667-213.333333z m185.130667 334.08a42.666667 42.666667 0 0 1-13.568 58.837333L554.666667 535.552V704a42.666667 42.666667 0 1 1-85.333334 0v-168.277333l-150.613333-94.293334a42.666667 42.666667 0 0 1 45.226667-72.32l147.925333 92.586667 148.053333-92.586667a42.666667 42.666667 0 0 1 58.837334 13.568z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-video-play': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M512 896a384 384 0 1 0 0-768 384 384 0 0 0 0 768z m469.333333-384c0 259.2-210.133333 469.333333-469.333333 469.333333S42.666667 771.2 42.666667 512 252.8 42.666667 512 42.666667s469.333333 210.133333 469.333333 469.333333z', + fill: 'currentColor' + }), + h('path', { + d: 'M686.890667 539.776l-253.141334 159.274667a32.298667 32.298667 0 0 1-44.8-10.453334 32.896 32.896 0 0 
1-4.949333-17.322666V352.768a32.64 32.64 0 0 1 32.512-32.768c6.101333 0 12.074667 1.706667 17.28 4.992l253.098667 159.232a32.853333 32.853333 0 0 1 0 55.552z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-video-stop': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M981.333333 512c0 259.2-210.133333 469.333333-469.333333 469.333333S42.666667 771.2 42.666667 512 252.8 42.666667 512 42.666667s469.333333 210.133333 469.333333 469.333333z m-85.333333 0a384 384 0 1 0-768 0 384 384 0 0 0 768 0zM384 341.333333h256c23.466667 0 42.666667 19.072 42.666667 42.666667v256c0 23.552-19.2 42.666667-42.666667 42.666667H384c-23.466667 0-42.666667-19.114667-42.666667-42.666667V384c0-23.594667 19.2-42.666667 42.666667-42.666667z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + 'app-video-pause': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M405.333333 341.333333a21.333333 21.333333 0 0 0-21.333333 21.333334v298.666666a21.333333 21.333333 0 0 0 21.333333 21.333334h42.666667a21.333333 21.333333 0 0 0 21.333333-21.333334v-298.666666a21.333333 21.333333 0 0 0-21.333333-21.333334h-42.666667zM576 341.333333a21.333333 21.333333 0 0 0-21.333333 21.333334v298.666666a21.333333 21.333333 0 0 0 21.333333 21.333334h42.666667a21.333333 21.333333 0 0 0 21.333333-21.333334v-298.666666a21.333333 21.333333 0 0 0-21.333333-21.333334h-42.666667z', + fill: 'currentColor' + }), + h('path', { + d: 'M512 42.666667C252.8 42.666667 42.666667 252.8 42.666667 512s210.133333 469.333333 469.333333 469.333333 469.333333-210.133333 469.333333-469.333333S771.2 42.666667 512 42.666667zM128 512a384 384 0 1 1 768 0 384 384 0 0 1-768 0z', + fill: 'currentColor' + }) + ] 
+ ) + ]) + } + }, + 'app-invisible': { + iconReader: () => { + return h('i', [ + h( + 'svg', + { + style: { height: '100%', width: '100%' }, + viewBox: '0 0 1024 1024', + version: '1.1', + xmlns: 'http://www.w3.org/2000/svg' + }, + [ + h('path', { + d: 'M512 640c-28.032 0-55.466667-2.218667-82.090667-6.4l-21.248 79.274667a21.333333 21.333333 0 0 1-26.154666 15.061333L341.333333 716.885333a21.333333 21.333333 0 0 1-15.061333-26.112l20.821333-77.653333a473.770667 473.770667 0 0 1-97.152-45.653333l-67.84 67.84a21.333333 21.333333 0 0 1-30.122666 0l-30.165334-30.208a21.333333 21.333333 0 0 1 0-30.165334l59.733334-59.733333A386.389333 386.389333 0 0 1 104.789333 416.426667a37.76 37.76 0 0 1 7.594667-45.397334c10.496-9.514667 17.877333-16 24.32-22.442666a170.24 170.24 0 0 0 1.834667-1.92c9.301333-9.6 25.173333-6.016 30.634666 6.186666C222.336 471.936 349.568 554.666667 512 554.666667c155.648 0 285.866667-80.512 338.090667-190.976 1.365333-2.858667 2.901333-6.485333 4.437333-10.325334a18.346667 18.346667 0 0 1 29.866667-6.613333l27.392 27.434667a36.565333 36.565333 0 0 1 6.997333 42.666666c-1.792 3.456-3.541333 6.698667-5.034667 9.301334a390.4 390.4 0 0 1-76.928 94.293333l54.442667 54.485333a21.333333 21.333333 0 0 1 0 30.165334l-30.165333 30.165333a21.333333 21.333333 0 0 1-30.165334 0l-63.658666-63.658667a475.306667 475.306667 0 0 1-90.282667 41.514667l20.778667 77.653333a21.333333 21.333333 0 0 1-15.061334 26.112l-41.216 11.093334a21.333333 21.333333 0 0 1-26.154666-15.104l-21.248-79.317334c-26.581333 4.266667-54.058667 6.442667-82.090667 6.442667z', + fill: 'currentColor' + }) + ] + ) + ]) + } + }, + // 'app-history-outlined': { + // iconReader: () => { + // return h('i', [ + // h( + // 'svg', + // { + // style: { height: '100%', width: '100%' }, + // viewBox: '0 0 1024 1024', + // version: '1.1', + // xmlns: 'http://www.w3.org/2000/svg' + // }, + // [ + // h('path', { + // d: 'M955.733333 512c0 235.648-191.018667 426.666667-426.666666 426.666667a425.898667 425.898667 
0 0 1-334.677334-162.005334l67.797334-51.84a341.333333 341.333333 0 1 0-69.717334-269.653333h30.08c18.176-0.042667 29.013333 20.181333 18.944 35.328L170.24 597.333333a22.741333 22.741333 0 0 1-37.888 0l-71.253333-106.88a22.741333 22.741333 0 0 1 18.944-35.413333h26.112C133.973333 246.4 312.746667 85.333333 529.066667 85.333333c235.648 0 426.666667 191.018667 426.666666 426.666667z" p-id="16742">
+ +
+
+ + 加载中... + + + 到底啦! + +
+ + + diff --git a/src/MaxKB-1.7.2/ui/src/components/layout-container/index.vue b/src/MaxKB-1.7.2/ui/src/components/layout-container/index.vue new file mode 100644 index 0000000..ec5d8c3 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/layout-container/index.vue @@ -0,0 +1,48 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/loading/DownloadLoading.vue b/src/MaxKB-1.7.2/ui/src/components/loading/DownloadLoading.vue new file mode 100644 index 0000000..83332c8 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/loading/DownloadLoading.vue @@ -0,0 +1,93 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/login-container/index.vue b/src/MaxKB-1.7.2/ui/src/components/login-container/index.vue new file mode 100644 index 0000000..0714533 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/login-container/index.vue @@ -0,0 +1,38 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/login-layout/index.vue b/src/MaxKB-1.7.2/ui/src/components/login-layout/index.vue new file mode 100644 index 0000000..9d9b87d --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/login-layout/index.vue @@ -0,0 +1,56 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/logo/LogoFull.vue b/src/MaxKB-1.7.2/ui/src/components/logo/LogoFull.vue new file mode 100644 index 0000000..0c67339 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/logo/LogoFull.vue @@ -0,0 +1,95 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/logo/LogoIcon.vue b/src/MaxKB-1.7.2/ui/src/components/logo/LogoIcon.vue new file mode 100644 index 0000000..51d47db --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/logo/LogoIcon.vue @@ -0,0 +1,59 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/logo/SendIcon.vue b/src/MaxKB-1.7.2/ui/src/components/logo/SendIcon.vue new file mode 100644 index 0000000..933976c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/logo/SendIcon.vue @@ -0,0 +1,44 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/markdown/EchartsRander.vue 
b/src/MaxKB-1.7.2/ui/src/components/markdown/EchartsRander.vue new file mode 100644 index 0000000..6e3d2e8 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/markdown/EchartsRander.vue @@ -0,0 +1,124 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/markdown/HtmlRander.vue b/src/MaxKB-1.7.2/ui/src/components/markdown/HtmlRander.vue new file mode 100644 index 0000000..a8be059 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/markdown/HtmlRander.vue @@ -0,0 +1,36 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/markdown/MdEditor.vue b/src/MaxKB-1.7.2/ui/src/components/markdown/MdEditor.vue new file mode 100644 index 0000000..ec1c3ae --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/markdown/MdEditor.vue @@ -0,0 +1,14 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/markdown/MdEditorMagnify.vue b/src/MaxKB-1.7.2/ui/src/components/markdown/MdEditorMagnify.vue new file mode 100644 index 0000000..cf12d42 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/markdown/MdEditorMagnify.vue @@ -0,0 +1,63 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/markdown/MdPreview.vue b/src/MaxKB-1.7.2/ui/src/components/markdown/MdPreview.vue new file mode 100644 index 0000000..7632a5a --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/markdown/MdPreview.vue @@ -0,0 +1,8 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/markdown/MdRenderer.vue b/src/MaxKB-1.7.2/ui/src/components/markdown/MdRenderer.vue new file mode 100644 index 0000000..517a3bc --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/markdown/MdRenderer.vue @@ -0,0 +1,197 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/markdown/assets/markdown-iconfont.js b/src/MaxKB-1.7.2/ui/src/components/markdown/assets/markdown-iconfont.js new file mode 100644 index 0000000..6b8505f --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/markdown/assets/markdown-iconfont.js @@ -0,0 +1 @@ +window._iconfont_svg_string_2605852='',function(l){var 
a=(a=document.getElementsByTagName("script"))[a.length-1],c=a.getAttribute("data-injectcss"),a=a.getAttribute("data-disable-injectsvg");if(!a){var o,t,i,e,h,d=function(a,c){c.parentNode.insertBefore(a,c)};if(c&&!l.__iconfont__svg__cssinject__){l.__iconfont__svg__cssinject__=!0;try{document.write("")}catch(a){console&&console.log(a)}}o=function(){var a,c=document.createElement("div");c.innerHTML=l._iconfont_svg_string_2605852,(c=c.getElementsByTagName("svg")[0])&&(c.setAttribute("aria-hidden","true"),c.style.position="absolute",c.style.width=0,c.style.height=0,c.style.overflow="hidden",c=c,(a=document.body).firstChild?d(c,a.firstChild):a.appendChild(c))},document.addEventListener?~["complete","loaded","interactive"].indexOf(document.readyState)?setTimeout(o,0):(t=function(){document.removeEventListener("DOMContentLoaded",t,!1),o()},document.addEventListener("DOMContentLoaded",t,!1)):document.attachEvent&&(i=o,e=l.document,h=!1,v(),e.onreadystatechange=function(){"complete"==e.readyState&&(e.onreadystatechange=null,m())})}function m(){h||(h=!0,i())}function v(){try{e.documentElement.doScroll("left")}catch(a){return void setTimeout(v,50)}m()}}(window); \ No newline at end of file diff --git a/src/MaxKB-1.7.2/ui/src/components/read-write/index.vue b/src/MaxKB-1.7.2/ui/src/components/read-write/index.vue new file mode 100644 index 0000000..8a13186 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/read-write/index.vue @@ -0,0 +1,126 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/tag-ellipsis/index.vue b/src/MaxKB-1.7.2/ui/src/components/tag-ellipsis/index.vue new file mode 100644 index 0000000..aa08d86 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/tag-ellipsis/index.vue @@ -0,0 +1,36 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/components/tags-input/index.vue b/src/MaxKB-1.7.2/ui/src/components/tags-input/index.vue new file mode 100644 index 0000000..156827c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/components/tags-input/index.vue @@ -0,0 +1,116 @@ 
+ + + diff --git a/src/MaxKB-1.7.2/ui/src/directives/clickoutside.ts b/src/MaxKB-1.7.2/ui/src/directives/clickoutside.ts new file mode 100644 index 0000000..0e93cf7 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/directives/clickoutside.ts @@ -0,0 +1,7 @@ +import type { App } from 'vue' +import { ClickOutside as vClickOutside } from 'element-plus' +export default { + install: (app: App) => { + app.directive('click-outside', vClickOutside) + } +} diff --git a/src/MaxKB-1.7.2/ui/src/directives/hasPermission.ts b/src/MaxKB-1.7.2/ui/src/directives/hasPermission.ts new file mode 100644 index 0000000..69c9a47 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/directives/hasPermission.ts @@ -0,0 +1,27 @@ +import type { App } from 'vue' +import { hasPermission } from '@/utils/permission' + +const display = async (el: any, binding: any) => { + const has = hasPermission( + binding.value?.permission || binding.value, + binding.value?.compare || 'OR' + ) + if (!has) { + el.style.display = 'none' + } else { + delete el.style.display + } +} + +export default { + install: (app: App) => { + app.directive('hasPermission', { + async created(el: any, binding: any) { + display(el, binding) + }, + async beforeUpdate(el: any, binding: any) { + display(el, binding) + } + }) + } +} diff --git a/src/MaxKB-1.7.2/ui/src/directives/index.ts b/src/MaxKB-1.7.2/ui/src/directives/index.ts new file mode 100644 index 0000000..de8ee80 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/directives/index.ts @@ -0,0 +1,14 @@ +import type { App } from 'vue' + +const directives = import.meta.glob('./*.ts', { eager: true }) +const install = (app: App) => { + Object.keys(directives) + .filter((key: string) => { + return !key.endsWith('index.ts') + }) + .forEach((key: string) => { + const directive: any = directives[key] + app.use(directive.default) + }) +} +export default { install } diff --git a/src/MaxKB-1.7.2/ui/src/directives/infiniteScrollUp.ts b/src/MaxKB-1.7.2/ui/src/directives/infiniteScrollUp.ts new file mode 100644 
index 0000000..c5e2410 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/directives/infiniteScrollUp.ts @@ -0,0 +1,151 @@ +import { nextTick } from 'vue' + +import { throttle } from 'lodash-unified' +import { getScrollContainer } from 'element-plus/es/utils/index' +import type { App } from 'vue' +export const SCOPE = 'InfiniteScrollUP' +export const CHECK_INTERVAL = 50 +export const DEFAULT_DELAY = 200 +export const DEFAULT_DISTANCE = 0 + +const attributes = { + delay: { + type: Number, + default: DEFAULT_DELAY + }, + distance: { + type: Number, + default: DEFAULT_DISTANCE + }, + disabled: { + type: Boolean, + default: false + }, + immediate: { + type: Boolean, + default: true + } +} + +type Attrs = typeof attributes +type ScrollOptions = { [K in keyof Attrs]: Attrs[K]['default'] } +type InfiniteScrollCallback = () => void +type InfiniteScrollEl = HTMLElement & { + [SCOPE]: { + container: HTMLElement | Window + containerEl: HTMLElement + instance: any + delay: number // export for test + lastScrollTop: number + cb: InfiniteScrollCallback + onScroll: () => void + observer?: MutationObserver + } +} + +const getScrollOptions = (el: HTMLElement, instance: any): ScrollOptions => { + return Object.entries(attributes).reduce((acm: any, [name, option]) => { + const { type, default: defaultValue } = option + const attrVal: any = el.getAttribute(`infinite-scroll-up-${name}`) + let value = instance[attrVal] ?? attrVal ?? defaultValue + value = value === 'false' ? false : value + value = type(value) + acm[name] = Number.isNaN(value) ? 
defaultValue : value + return acm + }, {} as ScrollOptions) +} + +const destroyObserver = (el: InfiniteScrollEl) => { + const { observer } = el[SCOPE] + + if (observer) { + observer.disconnect() + delete el[SCOPE].observer + } +} + +const handleScroll = (el: InfiniteScrollEl, cb: InfiniteScrollCallback) => { + const { container, containerEl, instance, observer, lastScrollTop } = el[SCOPE] + const { disabled } = getScrollOptions(el, instance) + const { scrollTop } = containerEl + + el[SCOPE].lastScrollTop = scrollTop + + // trigger only if full check has done and not disabled and scroll down + + if (observer || disabled || scrollTop > 0) return + + if (scrollTop == 0) { + cb.call(instance) + } +} + +function checkFull(el: InfiniteScrollEl, cb: InfiniteScrollCallback) { + const { containerEl, instance } = el[SCOPE] + const { disabled } = getScrollOptions(el, instance) + + if (disabled || containerEl.clientHeight == 0) return + + if (containerEl.scrollTop <= 0) { + cb.call(instance) + } else { + destroyObserver(el) + } +} + +const InfiniteScroll = { + async mounted(el: any, binding: any) { + const { instance, value: cb } = binding + + // ensure parentNode mounted + await nextTick() + + const { delay, immediate } = getScrollOptions(el, instance) + const container = getScrollContainer(el, true) + const containerEl = container === window ? 
document.documentElement : (container as HTMLElement) + const onScroll = throttle(handleScroll.bind(null, el, cb), delay) + + if (!container) return + + el[SCOPE] = { + instance, + container, + containerEl, + delay, + cb, + onScroll, + lastScrollTop: containerEl.scrollTop + } + + if (immediate) { + const observer = new MutationObserver(throttle(checkFull.bind(null, el, cb), CHECK_INTERVAL)) + el[SCOPE].observer = observer + observer.observe(el, { childList: true, subtree: true }) + checkFull(el, cb) + } + + container.addEventListener('scroll', onScroll) + }, + unmounted(el: any) { + if (!el[SCOPE]) return + const { container, onScroll } = el[SCOPE] + + container?.removeEventListener('scroll', onScroll) + destroyObserver(el) + }, + async updated(el: any) { + if (!el[SCOPE]) { + await nextTick() + } else { + const { containerEl, cb, observer } = el[SCOPE] + if (containerEl.clientHeight && observer) { + checkFull(el, cb) + } + } + } +} +export default { + install: (app: App) => { + app.directive('infinite-scroll-up', InfiniteScroll) + } +} diff --git a/src/MaxKB-1.7.2/ui/src/directives/resize.ts b/src/MaxKB-1.7.2/ui/src/directives/resize.ts new file mode 100644 index 0000000..255b94b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/directives/resize.ts @@ -0,0 +1,31 @@ +import type { App } from 'vue' +export default { + install: (app: App) => { + app.directive('resize', { + created(el: any, binding: any) { + // 记录长宽 + let width = '' + let height = '' + function getSize() { + const style = (document.defaultView as any).getComputedStyle(el) + // 如果当前长宽和历史长宽不同 + if (width !== style.width || height !== style.height) { + // binding.value在这里就是下面的resizeChart函数 + + binding.value({ + width: parseFloat(style.width), + height: parseFloat(style.height) + }) + } + width = style.width + height = style.height + } + + ;(el as any).__vueDomResize__ = setInterval(getSize, 500) + }, + unmounted(el: any, binding: any) { + clearInterval((el as any).__vueDomResize__) + } + }) + } +} diff --git 
a/src/MaxKB-1.7.2/ui/src/enums/application.ts b/src/MaxKB-1.7.2/ui/src/enums/application.ts new file mode 100644 index 0000000..ee0f6d4 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/enums/application.ts @@ -0,0 +1,5 @@ +export enum SearchMode { + embedding = '向量检索', + keywords = '全文检索', + blend = '混合检索' +} diff --git a/src/MaxKB-1.7.2/ui/src/enums/common.ts b/src/MaxKB-1.7.2/ui/src/enums/common.ts new file mode 100644 index 0000000..3afd8fc --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/enums/common.ts @@ -0,0 +1,16 @@ +export enum DeviceType { + Mobile = 'Mobile', + Desktop = 'Desktop' +} + +export enum ValidType { + Application = 'application', + Dataset = 'dataset', + User = 'user' +} + +export enum ValidCount { + Application = 5, + Dataset = 50, + User = 2 +} diff --git a/src/MaxKB-1.7.2/ui/src/enums/document.ts b/src/MaxKB-1.7.2/ui/src/enums/document.ts new file mode 100644 index 0000000..f3a7d24 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/enums/document.ts @@ -0,0 +1,16 @@ +export enum hitHandlingMethod { + optimization = '模型优化', + directly_return = '直接回答' +} + +export enum hitStatus { + waiting = '等待中', + processing = '处理中', + completed = '已完成', + failed = '失败' +} + +export enum isActivated { + true = '启用', + false = '禁用' +} diff --git a/src/MaxKB-1.7.2/ui/src/enums/model.ts b/src/MaxKB-1.7.2/ui/src/enums/model.ts new file mode 100644 index 0000000..9219f45 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/enums/model.ts @@ -0,0 +1,16 @@ +export enum PermissionType { + PRIVATE = '私有', + PUBLIC = '公用' +} +export enum PermissionDesc { + PRIVATE = '仅当前用户使用', + PUBLIC = '所有用户都可使用,不能编辑' +} + +export enum modelType { + EMBEDDING = '向量模型', + LLM = '大语言模型', + STT = '语音识别', + TTS = '语音合成', + RERANKER = '重排模型' +} diff --git a/src/MaxKB-1.7.2/ui/src/enums/team.ts b/src/MaxKB-1.7.2/ui/src/enums/team.ts new file mode 100644 index 0000000..4e72f36 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/enums/team.ts @@ -0,0 +1,6 @@ +export enum TeamEnum { + MANAGE = 'MANAGE', + USE = 'USE', + 
DATASET = 'DATASET', + APPLICATION = 'APPLICATION' +} diff --git a/src/MaxKB-1.7.2/ui/src/enums/workflow.ts b/src/MaxKB-1.7.2/ui/src/enums/workflow.ts new file mode 100644 index 0000000..be2571a --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/enums/workflow.ts @@ -0,0 +1,12 @@ +export enum WorkflowType { + Base = 'base-node', + Start = 'start-node', + AiChat = 'ai-chat-node', + SearchDataset = 'search-dataset-node', + Question = 'question-node', + Condition = 'condition-node', + Reply = 'reply-node', + FunctionLib = 'function-lib-node', + FunctionLibCustom = 'function-node', + RrerankerNode = 'reranker-node' +} diff --git a/src/MaxKB-1.7.2/ui/src/layout/components/app-header/index.vue b/src/MaxKB-1.7.2/ui/src/layout/components/app-header/index.vue new file mode 100644 index 0000000..dae3957 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/components/app-header/index.vue @@ -0,0 +1,34 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/components/app-main/index.vue b/src/MaxKB-1.7.2/ui/src/layout/components/app-main/index.vue new file mode 100644 index 0000000..1598772 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/components/app-main/index.vue @@ -0,0 +1,24 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/components/breadcrumb/index.vue b/src/MaxKB-1.7.2/ui/src/layout/components/breadcrumb/index.vue new file mode 100644 index 0000000..c380d37 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/components/breadcrumb/index.vue @@ -0,0 +1,247 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/components/index.ts b/src/MaxKB-1.7.2/ui/src/layout/components/index.ts new file mode 100644 index 0000000..c6d38f7 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/components/index.ts @@ -0,0 +1,4 @@ +export { default as Sidebar } from './sidebar/index.vue' +export { default as AppMain } from './app-main/index.vue' +export { default as TopBar } from './top-bar/index.vue' +export { default as AppHeader } from './app-header/index.vue' diff --git 
a/src/MaxKB-1.7.2/ui/src/layout/components/sidebar/SidebarItem.vue b/src/MaxKB-1.7.2/ui/src/layout/components/sidebar/SidebarItem.vue new file mode 100644 index 0000000..8601d63 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/components/sidebar/SidebarItem.vue @@ -0,0 +1,109 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/components/sidebar/index.vue b/src/MaxKB-1.7.2/ui/src/layout/components/sidebar/index.vue new file mode 100644 index 0000000..4fc4523 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/components/sidebar/index.vue @@ -0,0 +1,56 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/APIKeyDialog.vue b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/APIKeyDialog.vue new file mode 100644 index 0000000..39d6797 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/APIKeyDialog.vue @@ -0,0 +1,185 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/AboutDialog.vue b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/AboutDialog.vue new file mode 100644 index 0000000..31cfa27 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/AboutDialog.vue @@ -0,0 +1,149 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/ResetPassword.vue b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/ResetPassword.vue new file mode 100644 index 0000000..02e332b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/ResetPassword.vue @@ -0,0 +1,218 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/SettingAPIKeyDialog.vue b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/SettingAPIKeyDialog.vue new file mode 100644 index 0000000..e7d95a8 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/SettingAPIKeyDialog.vue @@ -0,0 +1,106 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/index.vue 
b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/index.vue new file mode 100644 index 0000000..dd7b65e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/avatar/index.vue @@ -0,0 +1,99 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/index.vue b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/index.vue new file mode 100644 index 0000000..d18d1cc --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/index.vue @@ -0,0 +1,118 @@ +· + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/top-menu/MenuItem.vue b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/top-menu/MenuItem.vue new file mode 100644 index 0000000..8e63664 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/top-menu/MenuItem.vue @@ -0,0 +1,63 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/top-menu/index.vue b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/top-menu/index.vue new file mode 100644 index 0000000..b8b7390 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/components/top-bar/top-menu/index.vue @@ -0,0 +1,23 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/hooks/useResize.ts b/src/MaxKB-1.7.2/ui/src/layout/hooks/useResize.ts new file mode 100644 index 0000000..d1c0a34 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/hooks/useResize.ts @@ -0,0 +1,38 @@ +import { nextTick, onBeforeMount, onMounted, onBeforeUnmount } from 'vue' +import { useRoute } from 'vue-router' +import useStore from '@/stores' +import { DeviceType } from '@/enums/common' +/** 参考 Bootstrap 的响应式设计 WIDTH = 600 */ +const WIDTH = 600 + +/** 根据大小变化重新布局 */ +export default () => { + const { common } = useStore() + const _isMobile = () => { + const rect = document.body?.getBoundingClientRect() + return rect.width - 1 < WIDTH + } + + const _resizeHandler = () => { + if (!document.hidden) { + const isMobile = _isMobile() + common.toggleDevice(isMobile ? 
DeviceType.Mobile : DeviceType.Desktop) + } + } + + onBeforeMount(() => { + window.addEventListener('resize', _resizeHandler) + }) + + onMounted(() => { + nextTick(() => { + if (_isMobile()) { + common.toggleDevice(DeviceType.Mobile) + } + }) + }) + + onBeforeUnmount(() => { + window.removeEventListener('resize', _resizeHandler) + }) +} diff --git a/src/MaxKB-1.7.2/ui/src/layout/layout-template/AppLayout.vue b/src/MaxKB-1.7.2/ui/src/layout/layout-template/AppLayout.vue new file mode 100644 index 0000000..71bca0d --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/layout-template/AppLayout.vue @@ -0,0 +1,17 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/layout-template/DetailLayout.vue b/src/MaxKB-1.7.2/ui/src/layout/layout-template/DetailLayout.vue new file mode 100644 index 0000000..9e1a6c2 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/layout-template/DetailLayout.vue @@ -0,0 +1,18 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/layout-template/SystemLayout.vue b/src/MaxKB-1.7.2/ui/src/layout/layout-template/SystemLayout.vue new file mode 100644 index 0000000..46221f3 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/layout-template/SystemLayout.vue @@ -0,0 +1,24 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/layout/layout-template/index.scss b/src/MaxKB-1.7.2/ui/src/layout/layout-template/index.scss new file mode 100644 index 0000000..8dcba63 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/layout/layout-template/index.scss @@ -0,0 +1,26 @@ +.app-layout { + background-color: var(--app-layout-bg-color); + height: 100%; +} + +.app-main { + position: relative; + height: 100%; + padding: var(--app-header-height) 0 0 !important; + box-sizing: border-box; + overflow: auto; + &.isExpire { + padding-top: calc(var(--app-header-height) + 40px) !important; + } +} + +.sidebar-container { + box-sizing: border-box; + transition: width 0.28s; + width: var(--sidebar-width); + min-width: var(--sidebar-width); + background-color: var(--sidebar-bg-color); +} 
+.view-container { + width: calc(100% - var(--sidebar-width)); +} diff --git a/src/MaxKB-1.7.2/ui/src/locales/index.ts b/src/MaxKB-1.7.2/ui/src/locales/index.ts new file mode 100644 index 0000000..51f94c1 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/index.ts @@ -0,0 +1,66 @@ +import { useLocalStorage, usePreferredLanguages } from '@vueuse/core'; +import { computed } from 'vue'; +import { createI18n } from 'vue-i18n'; + +// 导入语言文件 +const langModules = import.meta.glob('./lang/*/index.ts', { eager: true }) as Record Promise<{ default: Object }>>; + +const langModuleMap = new Map(); + +export const langCode: Array = []; + +export const localeConfigKey = 'MaxKB-locale'; + +// 获取浏览器默认语言环境 +const languages = usePreferredLanguages(); + +// 生成语言模块列表 +const generateLangModuleMap = () => { + const fullPaths = Object.keys(langModules); + fullPaths.forEach((fullPath) => { + const k = fullPath.replace('./lang', ''); + const startIndex = 1; + const lastIndex = k.lastIndexOf('/'); + const code = k.substring(startIndex, lastIndex); + langCode.push(code); + langModuleMap.set(code, langModules[fullPath]); + }); +}; + +// 导出 Message +const importMessages = computed(() => { + generateLangModuleMap(); + + const message: Recordable = {}; + langModuleMap.forEach((value: any, key) => { + message[key] = value.default; + }); + return message; +}); + +export const i18n = createI18n({ + legacy: false, + locale: useLocalStorage(localeConfigKey, 'zh_CN').value || languages.value[0] || 'zh_CN', + fallbackLocale: 'zh_CN', + messages: importMessages.value, + globalInjection: true, +}); + +export const langList = computed(() => { + if (langModuleMap.size === 0) generateLangModuleMap(); + + const list:any=[] + langModuleMap.forEach((value: any, key) => { + list.push({ + label: value.default.lang, + value: key, + }); + }); + + return list; +}); + +// @ts-ignore +export const { t } = i18n.global; + +export default i18n; diff --git a/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/components/index.ts 
b/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/components/index.ts new file mode 100644 index 0000000..bd77588 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/components/index.ts @@ -0,0 +1,4 @@ + +export default { + +}; diff --git a/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/index.ts b/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/index.ts new file mode 100644 index 0000000..3e80c0a --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/index.ts @@ -0,0 +1,66 @@ +import en from 'element-plus/es/locale/lang/en'; +import components from './components'; +import layout from './layout'; +import views from './views'; + +export default { + lang: 'English', + layout, + views, + components, + en, + login: { + authentication: 'Login Authentication', + ldap: { + title: 'LDAP Settings', + address: 'LDAP Address', + serverPlaceholder: 'Please enter LDAP address', + bindDN: 'Bind DN', + bindDNPlaceholder: 'Please enter Bind DN', + password: 'Password', + passwordPlaceholder: 'Please enter password', + ou: 'User OU', + ouPlaceholder: 'Please enter User OU', + ldap_filter: 'User Filter', + ldap_filterPlaceholder: 'Please enter User Filter', + ldap_mapping: 'LDAP Attribute Mapping', + ldap_mappingPlaceholder: 'Please enter LDAP Attribute Mapping', + test: 'Test Connection', + enableAuthentication: 'Enable LDAP Authentication', + save: 'Save', + testConnectionSuccess: 'Test Connection Success', + testConnectionFailed: 'Test Connection Failed', + saveSuccess: 'Save Success', + }, + cas: { + title: 'CAS Settings', + ldpUri: 'ldpUri', + ldpUriPlaceholder: 'Please enter ldpUri', + redirectUrl: 'Callback Address', + redirectUrlPlaceholder: 'Please enter Callback Address', + enableAuthentication: 'Enable CAS Authentication', + saveSuccess: 'Save Success', + save: 'Save', + }, + oidc: { + title: 'OIDC Settings', + authEndpoint: 'Auth Endpoint', + authEndpointPlaceholder: 'Please enter Auth Endpoint', + tokenEndpoint: 'Token Endpoint', + tokenEndpointPlaceholder: 'Please enter 
Token Endpoint', + userInfoEndpoint: 'User Info Endpoint', + userInfoEndpointPlaceholder: 'Please enter User Info Endpoint', + clientId: 'Client ID', + clientIdPlaceholder: 'Please enter Client ID', + clientSecret: 'Client Secret', + clientSecretPlaceholder: 'Please enter Client Secret', + logoutEndpoint: 'Logout Endpoint', + logoutEndpointPlaceholder: 'Please enter Logout Endpoint', + redirectUrl: 'Redirect URL', + redirectUrlPlaceholder: 'Please enter Redirect URL', + enableAuthentication: 'Enable OIDC Authentication', + }, + jump_tip: 'Jumping to the authentication source page for authentication', + jump: 'Jump', + }, +}; diff --git a/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/layout.ts b/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/layout.ts new file mode 100644 index 0000000..bdd3d77 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/layout.ts @@ -0,0 +1,37 @@ +export default { + breadcrumb: {}, + sidebar: {}, + topbar: { + github: "Project address", + wiki: "User manual", + forum: "Forum for help", + MenuItem: { + application: "Application", + dataset: "Knowledge base", + setting: "System settings" + }, + avatar: { + resetPassword: "Change password", + about: "About", + logout: "Logout", + version: "Version", + apiKey: "API Key", + apiServiceAddress: "API Service Address", + dialog: { + newPassword: "New password", + enterPassword: "Please enter new password", + confirmPassword: "Confirm password", + passwordLength: "Password length should be between 6 and 20 characters", + passwordMismatch: "Passwords do not match", + useEmail: "Use email", + enterEmail: "Please enter email", + enterVerificationCode: "Please enter verification code", + getVerificationCode: "Get verification code", + verificationCodeSentSuccess: "Verification code sent successfully", + resend: "Resend", + cancel: "Cancel", + save: "Save", + } + } + }, +}; diff --git a/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/views/404.ts b/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/views/404.ts new 
file mode 100644 index 0000000..c637111 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/views/404.ts @@ -0,0 +1,5 @@ +export default { + title: "404", + message: "Unable to Access Application", + operate: "Back to Home", +}; diff --git a/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/views/application-overview.ts b/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/views/application-overview.ts new file mode 100644 index 0000000..2875f33 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/views/application-overview.ts @@ -0,0 +1,110 @@ +export default { + title: 'Overview', + appInfo: { + header: 'Application Info', + publicAccessLink: 'Public Access Link', + openText: 'On', + closeText: 'Off', + copyLinkText: 'Copy Link', + refreshLinkText: 'Refresh Link', + demo: 'Demo', + embedThirdParty: 'Embed Third Party', + accessRestrictions: 'Access Restrictions', + displaySetting: 'Display Setting', + apiAccessCredentials: 'API Access Credentials', + apiKey: 'API Key', + refreshToken: { + msgConfirm1: 'Do you want to regenerate the public access link?', + msgConfirm2: + 'Regenerating the public access link will affect third-party embedded scripts changes and will require re-embedding the new script into third-party sites. Please proceed with caution!', + confirm: 'Confirm', + cancel: 'Cancel', + refreshSuccess: 'Refresh Successful' + }, + changeState: { + enableSuccess: 'Enable Successful', + disableSuccess: 'Disable Successful' + }, + APIKeyDialog: { + creatApiKey: 'Create', + status: 'Status', + creationDate: 'Creation Date', + operations: 'Operations', + settings: 'Settings', + delete: 'Delete', + saveSettings: 'Save Settings', + msgConfirm1: 'Are you sure you want to delete the API Key?', + msgConfirm2: + 'Deleting the API Key cannot be undone. 
Please confirm if you want to delete it!', + confirmDelete: 'Delete', + deleteSuccess: 'Delete Successful', + cancel: 'Cancel', + enabledSuccess: 'Enabled', + disabledSuccess: 'Disabled' + }, + EditAvatarDialog: { + title: 'Edit Logo', + customizeUpload: 'Custom Upload', + upload: 'Upload', + default: 'Default Logo', + custom: 'Custom', + sizeTip: 'Suggested size 32*32, supports jpg, png, gif, size no more than 10 MB', + cancel: 'Cancel', + save: 'Save', + fileSizeExceeded: 'File size exceeds 10 MB', + setSuccess: 'Setting Successful', + uploadImagePrompt: 'Please upload an image' + }, + EmbedDialog: { + embedDialogTitle: 'Embed Third Party', + fullscreenModeTitle: 'Fullscreen Mode', + copyInstructions: 'Copy the following code to embed', + floatingModeTitle: 'Floating Mode' + }, + LimitDialog: { + dialogTitle: 'Access Restrictions', + showSourceLabel: 'Show Source', + clientQueryLimitLabel: 'Each Client Query Limit', + authentication: 'Authentication', + authenticationValue: 'Authentication Password', + timesDays: 'Times/Day', + whitelistLabel: 'Whitelist', + whitelistPlaceholder: + 'Please enter allowed third-party source addresses, one per line, such as:\nhttp://127.0.0.1:5678\nhttps://dataease.io', + cancelButtonText: 'Cancel', + saveButtonText: 'Save', + settingSuccessMessage: 'Setting Successful' + }, + SettingAPIKeyDialog: { + dialogTitle: 'Settings', + allowCrossDomainLabel: 'Allow Cross-Domain Address', + crossDomainPlaceholder: + 'Please enter allowed cross-domain addresses, if open without inputting addresses, there are no restrictions.\nCross-domain addresses one per line, such as:\nhttp://127.0.0.1:5678\nhttps://dataease.io', + cancelButtonText: 'Cancel', + saveButtonText: 'Save', + successMessage: 'Setting Successful' + } + }, + monitor: { + monitoringStatistics: 'Monitoring Statistics', + customRange: 'Custom Range', + startDatePlaceholder: 'Start Date', + endDatePlaceholder: 'End Date', + pastDayOptions: { + past7Days: 'Past 7 Days', + past30Days: 
'Past 30 Days', + past90Days: 'Past 90 Days', + past183Days: 'Past Half Year', + other: 'Custom' + }, + charts: { + customerTotal: 'Total Customers', + customerNew: 'New Customers', + queryCount: 'Query Count', + tokensTotal: 'Total Tokens', + userSatisfaction: 'User Satisfaction', + approval: 'Approval', + disapproval: 'Disapproval' + } + } +} diff --git a/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/views/application.ts b/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/views/application.ts new file mode 100644 index 0000000..0c153f5 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/views/application.ts @@ -0,0 +1,116 @@ +export default { + applicationList: { + title: 'Applications', + searchBar: { + placeholder: 'Search by name' + }, + card: { + createApplication: 'Create Application', + overview: 'Overview', + demo: 'Demo', + setting: 'Settings', + delete: { + tooltip: 'Delete', + confirmTitle: 'Are you sure you want to delete this application?', + confirmMessage: + 'Deleting this application will no longer provide its services. 
Please proceed with caution.', + confirmButton: 'Delete', + cancelButton: 'Cancel', + successMessage: 'Successfully deleted' + } + }, + tooltips: { + demo: 'Demo', + setting: 'Settings', + delete: 'Delete' + } + }, + applicationForm: { + title: { + create: 'Create Application', + edit: 'Edit Settings', + info: 'Application Information', + copy: 'Copy Application' + }, + form: { + appName: { + label: 'Application Name', + placeholder: 'Please enter the application name', + requiredMessage: 'Application name is required' + }, + appDescription: { + label: 'Application Description', + placeholder: + 'Describe the application scenario and use, e.g.: MaxKB assistant answering user questions about MaxKB product usage' + }, + aiModel: { + label: 'AI Model', + placeholder: 'Please select an AI model', + unavailable: '(Unavailable)' + }, + prompt: { + label: 'Prompt', + placeholder: 'Please enter prompt', + tooltip: + 'By adjusting the content of the prompt, you can guide the direction of the large model conversation. This prompt will be fixed at the beginning of the context. Variables used: {data} carries known information from the knowledge base; {question} is the question posed by the user.' + }, + multipleRoundsDialogue: 'Multiple Rounds Dialogue', + relatedKnowledgeBase: 'Related Knowledge Base', + relatedKnowledgeBaseWhere: 'Associated knowledge bases are displayed here', + prologue: 'Prologue', + problemOptimization: { + label: 'Problem Optimization', + tooltip: + 'Optimize the current question based on historical chat to better match knowledge points.' 
+ }, + addModel: 'Add Model', + paramSetting: 'Parameter Settings', + add: 'Add', + apptest: 'Debug Preview' + }, + buttons: { + confirm: 'Confirm', + cancel: 'Cancel', + create: 'Create', + createSuccess: 'Create Success', + save: 'Save', + saveSuccess: 'Save Success', + copy: 'Copy', + copySuccess: 'Copy Success' + }, + dialogues: { + addDataset: 'Add Related Knowledge Base', + removeDataset: 'Remove Knowledge Base', + paramSettings: 'Parameter Settings', + refresh: 'Refresh', + selectSearchMode: 'Search Mode', + vectorSearch: 'Vector Search', + vectorSearchTooltip: + 'Vector search is a retrieval method based on vector distance calculations, suitable for large data volumes in the knowledge base.', + fullTextSearch: 'Full-text Search', + fullTextSearchTooltip: + 'Full-text search is a retrieval method based on text similarity, suitable for small data volumes in the knowledge base.', + hybridSearch: 'Hybrid Search', + hybridSearchTooltip: + 'Hybrid search is a retrieval method based on both vector and text similarity, suitable for medium data volumes in the knowledge base.', + similarityThreshold: 'Similarity Threshold', + topReferences: 'Top N References', + maxCharacters: 'Maximum Characters per Reference', + noReferencesAction: 'When there are no knowledge base references', + continueQuestioning: 'Continue Questioning AI Model', + provideAnswer: 'Provide a Specific Answer', + prompt: 'Prompt', + promptPlaceholder: 'Please enter a prompt', + concent: 'Content', + concentPlaceholder: 'Please enter content', + designated_answer: + 'Hello, I am MaxKB Assistant. My knowledge base only contains information related to MaxKB products. Please rephrase your question.' 
+ } + }, + prompt: { + defaultPrompt: + 'Known information:\n{data}\nResponse requirements:\n- Please use concise and professional language to answer the user\'s question.\n- If you do not know the answer, reply, "No relevant information was found in the knowledge base; it is recommended to consult technical support or refer to the official documentation for operations."\n- Avoid mentioning that your knowledge is obtained from known information.\n- Ensure the answer is consistent with the information described in the known data.\n- Please use Markdown syntax to optimize the format of the answer.\n- Directly return any images, link addresses, and script languages found in the known information.\n- Please respond in the same language as the question.\nQuestion:\n{question}', + defaultPrologue: + 'Hello, I am MaxKB Assistant. You can ask me questions about using MaxKB.\n- What are the main features of MaxKB?\n- Which large language models does MaxKB support?\n- What document types does MaxKB support?' 
+ } +} diff --git a/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/views/index.ts b/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/views/index.ts new file mode 100644 index 0000000..076cd0a --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/lang/en_US/views/index.ts @@ -0,0 +1,8 @@ +import notFound from './404'; +import application from './application'; +import applicationOverview from './application-overview'; +export default { + notFound, + application, + applicationOverview +}; diff --git a/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/components/index.ts b/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/components/index.ts new file mode 100644 index 0000000..bd77588 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/components/index.ts @@ -0,0 +1,4 @@ + +export default { + +}; diff --git a/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/index.ts b/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/index.ts new file mode 100644 index 0000000..26a9a0c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/index.ts @@ -0,0 +1,66 @@ +import zhCn from 'element-plus/es/locale/lang/zh-cn'; +import components from './components'; +import layout from './layout'; +import views from './views'; + +export default { + lang: '简体中文', + layout, + views, + components, + zhCn, + login: { + authentication: '登录认证', + ldap: { + title: 'LDAP 设置', + address: 'LDAP 地址', + serverPlaceholder: '请输入LDAP 地址', + bindDN: '绑定DN', + bindDNPlaceholder: '请输入绑定 DN', + password: '密码', + passwordPlaceholder: '请输入密码', + ou: '用户OU', + ouPlaceholder: '请输入用户 OU', + ldap_filter: '用户过滤器', + ldap_filterPlaceholder: '请输入用户过滤器', + ldap_mapping: 'LDAP 属性映射', + ldap_mappingPlaceholder: '请输入 LDAP 属性映射', + test: '测试连接', + enableAuthentication: '启用 LDAP 认证', + save: '保存', + testConnectionSuccess: '测试连接成功', + testConnectionFailed: '测试连接失败', + saveSuccess: '保存成功', + }, + cas: { + title: 'CAS 设置', + ldpUri: 'ldpUri', + ldpUriPlaceholder: '请输入ldpUri', + redirectUrl: '回调地址', + redirectUrlPlaceholder: '请输入回调地址', + enableAuthentication: 
'启用CAS认证', + saveSuccess: '保存成功', + save: '保存', + }, + oidc: { + title: 'OIDC 设置', + authEndpoint: '授权端地址', + authEndpointPlaceholder: '请输入授权端地址', + tokenEndpoint: 'Token端地址', + tokenEndpointPlaceholder: '请输入Token端地址', + userInfoEndpoint: '用户信息端地址', + userInfoEndpointPlaceholder: '请输入用户信息端地址', + clientId: '客户端ID', + clientIdPlaceholder: '请输入客户端ID', + clientSecret: '客户端密钥', + clientSecretPlaceholder: '请输入客户端密钥', + logoutEndpoint: '注销端地址', + logoutEndpointPlaceholder: '请输入注销端地址', + redirectUrl: '回调地址', + redirectUrlPlaceholder: '请输入回调地址', + enableAuthentication: '启用OIDC认证', + }, + jump_tip: '即将跳转至认证源页面进行认证', + jump: '跳转', + }, +}; diff --git a/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/layout.ts b/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/layout.ts new file mode 100644 index 0000000..7dd28ac --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/layout.ts @@ -0,0 +1,41 @@ +export default { + breadcrumb: { + + }, + sidebar: { + + }, + topbar: { + github: "项目地址", + wiki: "用户手册", + forum: "论坛求助", + MenuItem: { + application: "应用", + dataset: "知识库", + setting: "系统管理" + }, + avatar: { + resetPassword: "修改密码", + about: "关于", + logout: "退出", + version:"版本号", + apiKey: "API Key 管理", + apiServiceAddress: "API 服务地址", + dialog:{ + newPassword:"新密码", + enterPassword: "请输入修改密码", + confirmPassword: "确认密码", + passwordLength:"密码长度在 6 到 20 个字符", + passwordMismatch:"两次密码输入不一致", + useEmail:"使用邮箱", + enterEmail: "请输入邮箱", + enterVerificationCode: "请输入验证码", + getVerificationCode: "获取验证码", + verificationCodeSentSuccess:"验证码发送成功", + resend:"重新发送", + cancel:"取消", + save:"保存", + } + } + }, +}; diff --git a/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/views/404.ts b/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/views/404.ts new file mode 100644 index 0000000..a65dcbb --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/views/404.ts @@ -0,0 +1,5 @@ +export default { + title: "404", + message: "无法访问应用", + operate: "返回首页", +}; diff --git 
a/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/views/application-overview.ts b/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/views/application-overview.ts new file mode 100644 index 0000000..88e71fc --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/views/application-overview.ts @@ -0,0 +1,109 @@ +export default { + title: '概览', + appInfo: { + header: '应用信息', + publicAccessLink: '公开访问链接', + openText: '开', + closeText: '关', + copyLinkText: '复制链接', + refreshLinkText: '刷新链接', + demo: '演示', + embedThirdParty: '嵌入第三方', + accessRestrictions: '访问限制', + displaySetting: '显示设置', + apiAccessCredentials: 'API 访问凭据', + apiKey: 'API Key', + refreshToken: { + msgConfirm1: '是否重新生成公开访问链接?', + msgConfirm2: + '重新生成公开访问链接会影响嵌入第三方脚本变更,需要将新脚本重新嵌入第三方,请谨慎操作!', + confirm: '确认', + cancel: '取消', + refreshSuccess: '刷新成功' + }, + changeState: { + enableSuccess: '启用成功', + disableSuccess: '禁用成功' + }, + APIKeyDialog: { + creatApiKey: '创建', + status: '状态', + creationDate: '创建日期', + operations: '操作', + settings: '设置', + delete: '删除', + saveSettings: '保存设置', + msgConfirm1: '是否删除API Key', + msgConfirm2: '删除API Key后将无法恢复,请确认是否删除?', + confirmDelete: '删除', + deleteSuccess: '删除成功', + cancel: '取消', + enabledSuccess: '已启用', + disabledSuccess: '已禁用' + }, + EditAvatarDialog: { + title: '应用头像', + customizeUpload: '自定义上传', + upload: '上传', + default: '默认logo', + custom: '自定义', + sizeTip: '建议尺寸 32*32,支持 JPG、PNG、GIF,大小不超过 10 MB', + cancel: '取消', + save: '保存', + fileSizeExceeded: '文件大小超过 10 MB', + setSuccess: '设置成功', + uploadImagePrompt: '请上传一张图片' + }, + EmbedDialog: { + embedDialogTitle: '嵌入第三方', + fullscreenModeTitle: '全屏模式', + copyInstructions: '复制以下代码进行嵌入', + floatingModeTitle: '浮窗模式' + }, + LimitDialog: { + dialogTitle: '访问限制', + showSourceLabel: '显示知识来源', + clientQueryLimitLabel: '每个客户端提问限制', + timesDays: '次/天', + authentication: '身份验证', + authenticationValue: '验证密码', + whitelistLabel: '白名单', + whitelistPlaceholder: + '请输入允许嵌入第三方的源地址,一行一个,如:\nhttp://127.0.0.1:5678\nhttps://dataease.io', + 
cancelButtonText: '取消', + saveButtonText: '保存', + settingSuccessMessage: '设置成功' + }, + SettingAPIKeyDialog: { + dialogTitle: '设置', + allowCrossDomainLabel: '允许跨域地址', + crossDomainPlaceholder: + '请输入允许的跨域地址,开启后不输入跨域地址则不限制。\n跨域地址一行一个,如:\nhttp://127.0.0.1:5678 \nhttps://dataease.io', + cancelButtonText: '取消', + saveButtonText: '保存', + successMessage: '设置成功' + } + }, + monitor: { + monitoringStatistics: '监控统计', + customRange: '自定义范围', + startDatePlaceholder: '开始时间', + endDatePlaceholder: '结束时间', + pastDayOptions: { + past7Days: '过去7天', + past30Days: '过去30天', + past90Days: '过去90天', + past183Days: '过去半年', + other: '自定义' + }, + charts: { + customerTotal: '用户总数', + customerNew: '用户新增数', + queryCount: '提问次数', + tokensTotal: 'Tokens 总数', + userSatisfaction: '用户满意度', + approval: '赞同', + disapproval: '反对' + } + } +} diff --git a/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/views/application.ts b/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/views/application.ts new file mode 100644 index 0000000..81c4588 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/views/application.ts @@ -0,0 +1,114 @@ +export default { + applicationList: { + title: '应用', + searchBar: { + placeholder: '按名称搜索' + }, + card: { + createApplication: '创建应用', + overview: '概览', + demo: '演示', + setting: '设置', + delete: { + tooltip: '删除', + confirmTitle: '是否删除应用:', + confirmMessage: '删除后该应用将不再提供服务,请谨慎操作。', + confirmButton: '删除', + cancelButton: '取消', + successMessage: '删除成功' + } + }, + tooltips: { + demo: '演示', + setting: '设置', + delete: '删除' + } + }, + applicationForm: { + title: { + create: '创建应用', + edit: '设置', + info: '应用信息', + copy: '复制应用' + }, + form: { + appName: { + label: '应用名称', + placeholder: '请输入应用名称', + requiredMessage: '请输入应用名称' + }, + appDescription: { + label: '应用描述', + placeholder: '描述该应用的应用场景及用途,如:XXX 小助手回答用户提出的 XXX 产品使用问题' + }, + aiModel: { + label: 'AI 模型', + placeholder: '请选择 AI 模型', + unavailable: '(不可用)' + }, + prompt: { + label: '提示词', + placeholder: '请输入提示词', + tooltip: + 
'通过调整提示词内容,可以引导大模型聊天方向,该提示词会被固定在上下文的开头。可以使用变量:{data} 是携带知识库中已知信息;{question} 是用户提出的问题。' + }, + multipleRoundsDialogue: '多轮对话', + relatedKnowledgeBase: '关联知识库', + relatedKnowledgeBaseWhere: '关联知识库展示在这里', + prologue: '开场白', + problemOptimization: { + label: '问题优化', + tooltip: '根据历史聊天优化完善当前问题,更利于匹配知识点。' + }, + addModel: '添加模型', + paramSetting: '参数设置', + add: '添加', + apptest: '调试预览' + }, + buttons: { + confirm: '确认', + cancel: '取消', + create: '创建', + createSuccess: '创建成功', + save: '保存', + saveSuccess: '保存成功', + copy: '复制', + copySuccess: '复制成功' + }, + dialogues: { + addDataset: '添加关联知识库', + removeDataset: '移除知识库', + paramSettings: '参数设置', + refresh: '刷新', + selectSearchMode: '检索模式', + vectorSearch: '向量检索', + vectorSearchTooltip: '向量检索是一种基于向量相似度的检索方式,适用于知识库中的大数据量场景。', + fullTextSearch: '全文检索', + fullTextSearchTooltip: + '全文检索是一种基于文本相似度的检索方式,适用于知识库中的小数据量场景。', + hybridSearch: '混合检索', + hybridSearchTooltip: + '混合检索是一种基于向量和文本相似度的检索方式,适用于知识库中的中等数据量场景。', + similarityThreshold: '相似度高于', + topReferences: '引用分段数 TOP', + maxCharacters: '最多引用字符数', + noReferencesAction: '无引用知识库分段时', + continueQuestioning: '继续向 AI 模型提问', + provideAnswer: '指定回答内容', + prompt: '提示词', + promptPlaceholder: '请输入提示词', + concent: '内容', + concentPlaceholder: '请输入内容', + designated_answer: + '你好,我是 XXX 小助手,我的知识库只包含了 XXX 产品相关知识,请重新描述您的问题。' + } + }, + prompt: { + defaultPrompt: `已知信息:{data} +用户问题:{question} +回答要求: + - 请使用中文回答用户问题`, + defaultPrologue: + '您好,我是 XXX 小助手,您可以向我提出 XXX 使用问题。\n- XXX 主要功能有什么?\n- XXX 如何收费?\n- 需要转人工服务' + } +} diff --git a/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/views/index.ts b/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/views/index.ts new file mode 100644 index 0000000..076cd0a --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/lang/zh_CN/views/index.ts @@ -0,0 +1,8 @@ +import notFound from './404'; +import application from './application'; +import applicationOverview from './application-overview'; +export default { + notFound, + application, + applicationOverview +}; diff --git 
a/src/MaxKB-1.7.2/ui/src/locales/useLocale.ts b/src/MaxKB-1.7.2/ui/src/locales/useLocale.ts new file mode 100644 index 0000000..c802162 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/locales/useLocale.ts @@ -0,0 +1,28 @@ +import { useLocalStorage } from '@vueuse/core'; +import { computed } from 'vue'; +import { useI18n } from 'vue-i18n'; + +import { i18n, langCode, localeConfigKey } from '@/locales/index'; + +export function useLocale() { + const { locale } = useI18n({ useScope: 'global' }); + function changeLocale(lang: string) { + // 如果切换的语言不在对应语言文件里则默认为简体中文 + if (!langCode.includes(lang)) { + lang = 'zh_CN'; + } + + locale.value = lang; + useLocalStorage(localeConfigKey, 'zh_CN').value = lang; + } + + const getComponentsLocale = computed(() => { + return i18n.global.getLocaleMessage(locale.value).componentsLocale; + }); + + return { + changeLocale, + getComponentsLocale, + locale, + }; +} diff --git a/src/MaxKB-1.7.2/ui/src/main.ts b/src/MaxKB-1.7.2/ui/src/main.ts new file mode 100644 index 0000000..a93fe9e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/main.ts @@ -0,0 +1,62 @@ +import '@/styles/index.scss' +import ElementPlus from 'element-plus' +import * as ElementPlusIcons from '@element-plus/icons-vue' +import zhCn from 'element-plus/dist/locale/zh-cn.mjs' +import { createApp } from 'vue' +import { store } from '@/stores' +import directives from '@/directives' +import App from './App.vue' +import router from '@/router' +import Components from '@/components' +import i18n from './locales' +import { config } from 'md-editor-v3' + +import screenfull from 'screenfull' + +import katex from 'katex' +import 'katex/dist/katex.min.css' + +import Cropper from 'cropperjs' +import 'cropperjs/dist/cropper.css' + +import mermaid from 'mermaid' + +import highlight from 'highlight.js' +import 'highlight.js/styles/atom-one-dark.css' + +config({ + editorExtensions: { + highlight: { + instance: highlight + }, + screenfull: { + instance: screenfull + }, + katex: { + instance: katex + }, 
+ cropper: { + instance: Cropper + }, + mermaid: { + instance: mermaid + } + } +}) + +const app = createApp(App) +app.use(store) +app.use(directives) + +for (const [key, component] of Object.entries(ElementPlusIcons)) { + app.component(key, component) +} +app.use(ElementPlus, { + locale: zhCn +}) + +app.use(router) +app.use(i18n) +app.use(Components) +app.mount('#app') +export { app } diff --git a/src/MaxKB-1.7.2/ui/src/request/Result.ts b/src/MaxKB-1.7.2/ui/src/request/Result.ts new file mode 100644 index 0000000..c1b04e5 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/request/Result.ts @@ -0,0 +1,42 @@ +export class Result { + message: string; + code: number; + data: T; + constructor(message: string, code: number, data: T) { + this.message = message; + this.code = code; + this.data = data; + } + + static success(data: any) { + return new Result("请求成功", 200, data); + } + static error(message: string, code: number) { + return new Result(message, code, null); + } +} + +interface Page { + /** + *分页数据 + */ + records: Array; + /** + *当前页 + */ + current: number; + /** + * 每页展示size + */ + size: number; + /** + *总数 + */ + total: number; + /** + *是否有下一页 + */ + hasNext: boolean; +} +export type { Page }; +export default Result; diff --git a/src/MaxKB-1.7.2/ui/src/request/index.ts b/src/MaxKB-1.7.2/ui/src/request/index.ts new file mode 100644 index 0000000..204fa29 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/request/index.ts @@ -0,0 +1,318 @@ +import axios, { type AxiosRequestConfig } from 'axios' +import { MsgError } from '@/utils/message' +import type { NProgress } from 'nprogress' +import type { Ref } from 'vue' +import type { Result } from '@/request/Result' +import useStore from '@/stores' +import router from '@/router' + +import { ref, type WritableComputedRef } from 'vue' + +const axiosConfig = { + baseURL: '/api', + withCredentials: false, + timeout: 600000, + headers: {} +} + +const instance = axios.create(axiosConfig) + +/* 设置请求拦截器 */ +instance.interceptors.request.use( 
+ (config: AxiosRequestConfig) => { + if (config.headers === undefined) { + config.headers = {} + } + const { user } = useStore() + const token = user.getToken() + if (token) { + config.headers['AUTHORIZATION'] = `${token}` + } + return config + }, + (err: any) => { + return Promise.reject(err) + } +) + +//设置响应拦截器 +instance.interceptors.response.use( + (response: any) => { + if (response.data) { + if (response.data.code !== 200 && !(response.data instanceof Blob)) { + if (response.config.url.includes('/application/authentication')) { + return Promise.reject(response.data) + } + if ( + !response.config.url.includes('/valid') && + !response.config.url.includes('/function_lib/debug') + ) { + MsgError(response.data.message) + return Promise.reject(response.data) + } + } + } + return response + }, + (err: any) => { + if (err.code === 'ECONNABORTED') { + MsgError(err.message) + console.error(err) + } + if (err.response?.status === 404) { + if (!err.response.config.url.includes('/application/authentication')) { + router.push('/404 ') + } + } + if (err.response?.status === 401) { + if ( + !err.response.config.url.includes('chat/open') && + !err.response.config.url.includes('application/profile') + ) { + router.push({ name: 'login' }) + } + } + + if (err.response?.status === 403 && !err.response.config.url.includes('chat/open')) { + MsgError( + err.response.data && err.response.data.message ? 
err.response.data.message : '没有权限访问' + ) + } + return Promise.reject(err) + } +) + +export const request = instance + +/* 简化请求方法,统一处理返回结果,并增加loading处理,这里以{success,data,message}格式的返回值为例,具体项目根据实际需求修改 */ +const promise: ( + request: Promise, + loading?: NProgress | Ref | WritableComputedRef +) => Promise> = (request, loading = ref(false)) => { + return new Promise((resolve, reject) => { + if ((loading as NProgress).start) { + ;(loading as NProgress).start() + } else { + ;(loading as Ref).value = true + } + request + .then((response) => { + // blob类型的返回状态是response.status + if (response.status === 200) { + resolve(response?.data || response) + } else { + reject(response?.data || response) + } + }) + .catch((error) => { + reject(error) + }) + .finally(() => { + if ((loading as NProgress).start) { + ;(loading as NProgress).done() + } else { + ;(loading as Ref).value = false + } + }) + }) +} + +/** + * 发送get请求 一般用来请求资源 + * @param url 资源url + * @param params 参数 + * @param loading loading + * @returns 异步promise对象 + */ +export const get: ( + url: string, + params?: unknown, + loading?: NProgress | Ref, + timeout?: number +) => Promise> = ( + url: string, + params: unknown, + loading?: NProgress | Ref, + timeout?: number +) => { + return promise(request({ url: url, method: 'get', params, timeout: timeout }), loading) +} + +/** + * faso post请求 一般用来添加资源 + * @param url 资源url + * @param params 参数 + * @param data 添加数据 + * @param loading loading + * @returns 异步promise对象 + */ +export const post: ( + url: string, + data?: unknown, + params?: unknown, + loading?: NProgress | Ref, + timeout?: number +) => Promise | any> = (url, data, params, loading, timeout) => { + return promise(request({ url: url, method: 'post', data, params, timeout }), loading) +} + +/**| + * 发送put请求 用于修改服务器资源 + * @param url 资源地址 + * @param params params参数地址 + * @param data 需要修改的数据 + * @param loading 进度条 + * @returns + */ +export const put: ( + url: string, + data?: unknown, + params?: unknown, + loading?: 
NProgress | Ref, + timeout?: number +) => Promise> = (url, data, params, loading, timeout) => { + return promise(request({ url: url, method: 'put', data, params, timeout }), loading) +} + +/** + * 删除 + * @param url 删除url + * @param params params参数 + * @param loading 进度条 + * @returns + */ +export const del: ( + url: string, + params?: unknown, + data?: unknown, + loading?: NProgress | Ref, + timeout?: number +) => Promise> = (url, params, data, loading, timeout) => { + return promise(request({ url: url, method: 'delete', params, data, timeout }), loading) +} + +/** + * 流处理 + * @param url url地址 + * @param data 请求body + * @returns + */ +export const postStream: (url: string, data?: unknown) => Promise | any> = ( + url, + data +) => { + const { user } = useStore() + const token = user.getToken() + const headers: HeadersInit = { 'Content-Type': 'application/json' } + if (token) { + headers['AUTHORIZATION'] = `${token}` + } + return fetch(url, { + method: 'POST', + body: data ? JSON.stringify(data) : undefined, + headers: headers + }) +} + +export const exportExcel: ( + fileName: string, + url: string, + params: any, + loading?: NProgress | Ref +) => Promise = ( + fileName: string, + url: string, + params: any, + loading?: NProgress | Ref +) => { + return promise(request({ url: url, method: 'get', params, responseType: 'blob' }), loading) + .then((res: any) => { + if (res) { + const blob = new Blob([res], { + type: 'application/vnd.ms-excel' + }) + const link = document.createElement('a') + link.href = window.URL.createObjectURL(blob) + link.download = fileName + link.click() + //释放内存 + window.URL.revokeObjectURL(link.href) + } + return true + }) + .catch((e) => {}) +} + +export const exportExcelPost: ( + fileName: string, + url: string, + params: any, + data: any, + loading?: NProgress | Ref +) => Promise = ( + fileName: string, + url: string, + params: any, + data: any, + loading?: NProgress | Ref +) => { + return promise( + request({ + url: url, + method: 'post', + 
params, // 查询字符串参数 + data, // 请求体数据 + responseType: 'blob' + }), + loading + ) + .then((res: any) => { + if (res) { + const blob = new Blob([res], { + type: 'application/vnd.ms-excel' + }) + const link = document.createElement('a') + link.href = window.URL.createObjectURL(blob) + link.download = fileName + link.click() + // 释放内存 + window.URL.revokeObjectURL(link.href) + } + return true + }) + .catch((e) => {}) +} + +export const download: ( + url: string, + method: string, + data?: any, + params?: any, + loading?: NProgress | Ref +) => Promise = ( + url: string, + method: string, + data?: any, + params?: any, + loading?: NProgress | Ref +) => { + return promise(request({ url: url, method: method, data, params, responseType: 'blob' }), loading) +} + +/** + * 与服务器建立ws链接 + * @param url websocket路径 + * @returns 返回一个websocket实例 + */ +export const socket = (url: string) => { + let protocol = 'ws://' + if (window.location.protocol === 'https:') { + protocol = 'wss://' + } + let uri = protocol + window.location.host + url + if (!import.meta.env.DEV) { + uri = protocol + window.location.host + import.meta.env.VITE_BASE_PATH + url + } + return new WebSocket(uri) +} +export default instance diff --git a/src/MaxKB-1.7.2/ui/src/router/index.ts b/src/MaxKB-1.7.2/ui/src/router/index.ts new file mode 100644 index 0000000..b760c85 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/router/index.ts @@ -0,0 +1,76 @@ +import { hasPermission } from '@/utils/permission/index' +import { + createRouter, + createWebHistory, + type NavigationGuardNext, + type RouteLocationNormalized, + type RouteRecordRaw, + type RouteRecordName +} from 'vue-router' +import useStore from '@/stores' +import { routes } from '@/router/routes' +const router = createRouter({ + history: createWebHistory(import.meta.env.BASE_URL), + routes: routes +}) + +// 路由前置拦截器 +router.beforeEach( + async (to: RouteLocationNormalized, from: RouteLocationNormalized, next: NavigationGuardNext) => { + if (to.name === '404') { + next() + 
return + } + const { user } = useStore() + const notAuthRouteNameList = ['register', 'login', 'forgot_password', 'reset_password', 'Chat'] + + if (!notAuthRouteNameList.includes(to.name ? to.name.toString() : '')) { + if (to.query && to.query.token) { + localStorage.setItem('token', to.query.token.toString()) + } + const token = user.getToken() + if (!token) { + next({ + path: '/login' + }) + return + } + if (!user.userInfo) { + await user.profile() + } + } + // 判断是否有菜单权限 + if (to.meta.permission ? hasPermission(to.meta.permission as any, 'OR') : true) { + next() + } else { + // 如果没有权限则直接取404页面 + next('404') + } + } +) + +export const getChildRouteListByPathAndName = (path: any, name?: RouteRecordName | any) => { + return getChildRouteList(routes, path, name) +} + +export const getChildRouteList: ( + routeList: Array, + path: string, + name?: RouteRecordName | null | undefined +) => Array = (routeList, path, name) => { + for (let index = 0; index < routeList.length; index++) { + const route = routeList[index] + if (name === route.name && path === route.path) { + return route.children || [] + } + if (route.children && route.children.length > 0) { + const result = getChildRouteList(route.children, path, name) + if (result && result?.length > 0) { + return result + } + } + } + return [] +} + +export default router diff --git a/src/MaxKB-1.7.2/ui/src/router/modules/application.ts b/src/MaxKB-1.7.2/ui/src/router/modules/application.ts new file mode 100644 index 0000000..07fc42e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/router/modules/application.ts @@ -0,0 +1,92 @@ +import Layout from '@/layout/layout-template/DetailLayout.vue' +import { ComplexPermission } from '@/utils/permission/type' +const applicationRouter = { + path: '/application', + name: 'application', + meta: { title: '应用', permission: 'APPLICATION:READ' }, + redirect: '/application', + component: () => import('@/layout/layout-template/AppLayout.vue'), + children: [ + { + path: '/application', + name: 
'application', + component: () => import('@/views/application/index.vue') + }, + { + path: '/application/:id/:type', + name: 'ApplicationDetail', + meta: { title: '应用详情', activeMenu: '/application' }, + component: Layout, + hidden: true, + children: [ + { + path: 'overview', + name: 'AppOverview', + meta: { + icon: 'app-all-menu', + iconActive: 'app-all-menu-active', + title: '概览', + active: 'overview', + parentPath: '/application/:id/:type', + parentName: 'ApplicationDetail' + }, + component: () => import('@/views/application-overview/index.vue') + }, + { + path: 'setting', + name: 'AppSetting', + meta: { + icon: 'app-setting', + iconActive: 'app-setting-active', + title: '设置', + active: 'setting', + parentPath: '/application/:id/:type', + parentName: 'ApplicationDetail' + }, + component: () => import('@/views/application/ApplicationSetting.vue') + }, + { + path: 'access', + name: 'AppAccess', + meta: { + icon: 'app-access', + iconActive: 'app-access-active', + title: '应用接入', + active: 'access', + parentPath: '/application/:id/:type', + parentName: 'ApplicationDetail', + permission: new ComplexPermission([], ['x-pack'], 'OR') + }, + component: () => import('@/views/application/ApplicationAccess.vue') + }, + { + path: 'hit-test', + name: 'AppHitTest', + meta: { + icon: 'app-hit-test', + title: '命中测试', + active: 'hit-test', + parentPath: '/application/:id/:type', + parentName: 'ApplicationDetail' + }, + component: () => import('@/views/hit-test/index.vue') + }, + { + path: 'log', + name: 'Log', + meta: { + icon: 'app-document', + iconActive: 'app-document-active', + title: '对话日志', + active: 'log', + parentPath: '/application/:id/:type', + parentName: 'ApplicationDetail' + }, + component: () => import('@/views/log/index.vue') + } + ] + } + ] +} + +export default applicationRouter diff --git a/src/MaxKB-1.7.2/ui/src/router/modules/dataset.ts b/src/MaxKB-1.7.2/ui/src/router/modules/dataset.ts new file mode 100644 index 0000000..b6d4c2b --- /dev/null +++ 
b/src/MaxKB-1.7.2/ui/src/router/modules/dataset.ts @@ -0,0 +1,91 @@ +import Layout from '@/layout/layout-template/DetailLayout.vue' +const datasetRouter = { + path: '/dataset', + name: 'dataset', + meta: { title: '知识库', permission: 'DATASET:READ' }, + component: () => import('@/layout/layout-template/AppLayout.vue'), + redirect: '/dataset', + children: [ + { + path: '/dataset', + name: 'dataset', + component: () => import('@/views/dataset/index.vue') + }, + { + path: '/dataset/:type', // upload + name: 'UploadDocumentDataset', + meta: { activeMenu: '/dataset' }, + component: () => import('@/views/dataset/UploadDocumentDataset.vue'), + hidden: true + }, + { + path: '/dataset/:id', + name: 'DatasetDetail', + meta: { title: '文档', activeMenu: '/dataset' }, + component: Layout, + hidden: true, + children: [ + { + path: 'document', + name: 'Document', + meta: { + icon: 'app-document', + iconActive: 'app-document-active', + title: '文档', + active: 'document', + parentPath: '/dataset/:id', + parentName: 'DatasetDetail' + }, + component: () => import('@/views/document/index.vue') + }, + { + path: 'problem', + name: 'Problem', + meta: { + icon: 'app-problems', + iconActive: 'QuestionFilled', + title: '问题', + active: 'problem', + parentPath: '/dataset/:id', + parentName: 'DatasetDetail' + }, + component: () => import('@/views/problem/index.vue') + }, + { + path: 'hit-test', + name: 'DatasetHitTest', + meta: { + icon: 'app-hit-test', + title: '命中测试', + active: 'hit-test', + parentPath: '/dataset/:id', + parentName: 'DatasetDetail' + }, + component: () => import('@/views/hit-test/index.vue') + }, + { + path: 'setting', + name: 'DatasetSetting', + meta: { + icon: 'app-setting', + iconActive: 'app-setting-active', + title: '设置', + active: 'setting', + parentPath: '/dataset/:id', + parentName: 'DatasetDetail' + }, + component: () => import('@/views/dataset/DatasetSetting.vue') + } + ] + }, + { + path: '/dataset/:id/:documentId', // 分段详情 + name: 'Paragraph', + meta: { activeMenu: 
'/dataset' }, + component: () => import('@/views/paragraph/index.vue'), + hidden: true + } + ] +} + +export default datasetRouter diff --git a/src/MaxKB-1.7.2/ui/src/router/modules/function-lib.ts b/src/MaxKB-1.7.2/ui/src/router/modules/function-lib.ts new file mode 100644 index 0000000..7671236 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/router/modules/function-lib.ts @@ -0,0 +1,17 @@ +import Layout from '@/layout/layout-template/DetailLayout.vue' +const functionLibRouter = { + path: '/function-lib', + name: 'function-lib', + meta: { title: '函数库', permission: 'APPLICATION:READ' }, + redirect: '/function-lib', + component: () => import('@/layout/layout-template/AppLayout.vue'), + children: [ + { + path: '/function-lib', + name: 'function-lib', + component: () => import('@/views/function-lib/index.vue') + } + ] +} + +export default functionLibRouter diff --git a/src/MaxKB-1.7.2/ui/src/router/modules/setting.ts b/src/MaxKB-1.7.2/ui/src/router/modules/setting.ts new file mode 100644 index 0000000..719aef7 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/router/modules/setting.ts @@ -0,0 +1,110 @@ +import { hasPermission } from '@/utils/permission/index' +import Layout from '@/layout/layout-template/SystemLayout.vue' +import { Role, ComplexPermission } from '@/utils/permission/type' +const settingRouter = { + path: '/setting', + name: 'setting', + meta: { icon: 'Setting', title: '系统设置', permission: 'SETTING:READ' }, + redirect: () => { + if (hasPermission(new Role('ADMIN'), 'AND')) { + return '/user' + } + return '/team' + }, + component: Layout, + children: [ + { + path: '/user', + name: 'user', + meta: { + icon: 'User', + iconActive: 'UserFilled', + title: '用户管理', + activeMenu: '/setting', + parentPath: '/setting', + parentName: 'setting', + permission: new Role('ADMIN') + }, + component: () => import('@/views/user-manage/index.vue') + }, + { + path: '/team', + name: 'team', + meta: { + icon: 'app-team', + iconActive: 'app-team-active', + title: '团队成员', + activeMenu: 
'/setting', + parentPath: '/setting', + parentName: 'setting' + }, + component: () => import('@/views/team/index.vue') + }, + { + path: '/template', + name: 'template', + meta: { + icon: 'app-template', + iconActive: 'app-template-active', + title: '模型设置', + activeMenu: '/setting', + parentPath: '/setting', + parentName: 'setting' + }, + component: () => import('@/views/template/index.vue') + }, + { + path: '/system', + name: 'system', + meta: { + icon: 'app-setting', + iconActive: 'app-setting-active', + title: '系统设置', + activeMenu: '/setting', + parentPath: '/setting', + parentName: 'setting', + permission: new Role('ADMIN') + }, + children: [ + { + path: '/system/theme', + name: 'theme', + meta: { + title: '外观设置', + activeMenu: '/setting', + parentPath: '/setting', + parentName: 'setting', + permission: new ComplexPermission(['ADMIN'], ['x-pack'], 'AND') + }, + component: () => import('@/views/theme/index.vue') + }, + { + path: '/system/authentication', + name: 'authentication', + meta: { + title: '登录认证', + activeMenu: '/setting', + parentPath: '/setting', + parentName: 'setting', + permission: new ComplexPermission(['ADMIN'], ['x-pack'], 'AND') + }, + component: () => import('@/views/authentication/index.vue') + }, + { + path: '/system/email', + name: 'email', + meta: { + title: '邮箱配置', + activeMenu: '/setting', + parentPath: '/setting', + parentName: 'setting', + permission: new Role('ADMIN') + }, + component: () => import('@/views/email/index.vue') + } + ] + } + ] +} + +export default settingRouter diff --git a/src/MaxKB-1.7.2/ui/src/router/routes.ts b/src/MaxKB-1.7.2/ui/src/router/routes.ts new file mode 100644 index 0000000..82ddb05 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/router/routes.ts @@ -0,0 +1,54 @@ +import type { RouteRecordRaw } from 'vue-router' +import { Role } from '@/utils/permission/type' + +const modules: any = import.meta.glob('./modules/*.ts', { eager: true }) +const rolesRoutes: RouteRecordRaw[] = [...Object.keys(modules).map((key) => 
modules[key].default)] + +export const routes: Array = [ + { + path: '/', + name: 'home', + redirect: '/application', + children: [...rolesRoutes] + }, + + // 高级编排 + { + path: '/application/:id/workflow', + name: 'ApplicationWorkflow', + meta: { activeMenu: '/application' }, + component: () => import('@/views/application-workflow/index.vue') + }, + + { + path: '/chat/:accessToken', + name: 'Chat', + component: () => import('@/views/chat/index.vue') + }, + + { + path: '/login', + name: 'login', + component: () => import('@/views/login/index.vue') + }, + { + path: '/register', + name: 'register', + component: () => import('@/views/login/register/index.vue') + }, + { + path: '/forgot_password', + name: 'forgot_password', + component: () => import('@/views/login/forgot-password/index.vue') + }, + { + path: '/reset_password/:code/:email', + name: 'reset_password', + component: () => import('@/views/login/reset-password/index.vue') + }, + { + path: '/:pathMatch(.*)', + name: '404', + component: () => import('@/views/404/index.vue') + } +] diff --git a/src/MaxKB-1.7.2/ui/src/stores/index.ts b/src/MaxKB-1.7.2/ui/src/stores/index.ts new file mode 100644 index 0000000..7329855 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/stores/index.ts @@ -0,0 +1,28 @@ +import { createPinia } from 'pinia' +const store = createPinia() +export { store } +import useCommonStore from './modules/common' +import useUserStore from './modules/user' +import useDatasetStore from './modules/dataset' +import useParagraphStore from './modules/paragraph' +import useModelStore from './modules/model' +import useApplicationStore from './modules/application' +import useDocumentStore from './modules/document' +import useProblemStore from './modules/problem' +import useLogStore from './modules/log' +import usePromptStore from './modules/prompt' + +const useStore = () => ({ + common: useCommonStore(), + user: useUserStore(), + dataset: useDatasetStore(), + paragraph: useParagraphStore(), + model: useModelStore(), 
+ application: useApplicationStore(), + document: useDocumentStore(), + problem: useProblemStore(), + log: useLogStore(), + prompt: usePromptStore(), +}) + +export default useStore diff --git a/src/MaxKB-1.7.2/ui/src/stores/modules/application.ts b/src/MaxKB-1.7.2/ui/src/stores/modules/application.ts new file mode 100644 index 0000000..40588fb --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/stores/modules/application.ts @@ -0,0 +1,141 @@ +import { defineStore } from 'pinia' +import applicationApi from '@/api/application' +import applicationXpackApi from '@/api/application-xpack' +import { type Ref } from 'vue' + +import useUserStore from './user' + +const useApplicationStore = defineStore({ + id: 'application', + state: () => ({ + location: `${window.location.origin}/ui/chat/` + }), + actions: { + async asyncGetAllApplication() { + return new Promise((resolve, reject) => { + applicationApi + .getAllAppilcation() + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + + async asyncGetApplicationDetail(id: string, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .getApplicationDetail(id, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + + async asyncGetApplicationDataset(id: string, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .getApplicationDataset(id, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + + async asyncGetAccessToken(id: string, loading?: Ref) { + return new Promise((resolve, reject) => { + const user = useUserStore() + if (user.isEnterprise()) { + applicationXpackApi + .getAccessToken(id, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + } else { + applicationApi + .getAccessToken(id, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + } + }) + }, + + async 
asyncGetAppProfile(loading?: Ref) { + return new Promise((resolve, reject) => { + const user = useUserStore() + applicationApi + .getAppProfile(loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + + async asyncAppAuthentication( + token: string, + loading?: Ref, + authentication_value?: any + ) { + return new Promise((resolve, reject) => { + applicationApi + .postAppAuthentication(token, loading, authentication_value) + .then((res) => { + localStorage.setItem('accessToken', res.data) + sessionStorage.setItem('accessToken', res.data) + resolve(res) + }) + .catch((error) => { + reject(error) + }) + }) + }, + async refreshAccessToken(token: string) { + this.asyncAppAuthentication(token) + }, + // 修改应用 + async asyncPutApplication(id: string, data: any, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .putApplication(id, data, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + async validatePassword(id: string, password: string, loading?: Ref) { + return new Promise((resolve, reject) => { + applicationApi + .validatePassword(id, password, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + } + } +}) + +export default useApplicationStore diff --git a/src/MaxKB-1.7.2/ui/src/stores/modules/common.ts b/src/MaxKB-1.7.2/ui/src/stores/modules/common.ts new file mode 100644 index 0000000..636a36d --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/stores/modules/common.ts @@ -0,0 +1,53 @@ +import { defineStore } from 'pinia' +import { DeviceType, ValidType } from '@/enums/common' +import type { Ref } from 'vue' +import userApi from '@/api/user' + +export interface commonTypes { + breadcrumb: any + paginationConfig: any | null + search: any + device: string +} + +const useCommonStore = defineStore({ + id: 'common', + state: (): commonTypes => ({ + breadcrumb: null, + // 搜索和分页缓存 + paginationConfig: {}, + 
search: {}, + device: DeviceType.Desktop + }), + actions: { + saveBreadcrumb(data: any) { + this.breadcrumb = data + }, + savePage(val: string, data: any) { + this.paginationConfig[val] = data + }, + saveCondition(val: string, data: any) { + this.search[val] = data + }, + toggleDevice(value: DeviceType) { + this.device = value + }, + isMobile() { + return this.device === DeviceType.Mobile + }, + async asyncGetValid(valid_type: ValidType, valid_count: number, loading?: Ref) { + return new Promise((resolve, reject) => { + userApi + .getValid(valid_type, valid_count, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + } + } +}) + +export default useCommonStore diff --git a/src/MaxKB-1.7.2/ui/src/stores/modules/dataset.ts b/src/MaxKB-1.7.2/ui/src/stores/modules/dataset.ts new file mode 100644 index 0000000..b185d79 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/stores/modules/dataset.ts @@ -0,0 +1,74 @@ +import { defineStore } from 'pinia' +import type { datasetData } from '@/api/type/dataset' +import type { UploadUserFile } from 'element-plus' +import datasetApi from '@/api/dataset' +import { type Ref } from 'vue' + +export interface datasetStateTypes { + baseInfo: datasetData | null + webInfo: any + documentsType: string + documentsFiles: UploadUserFile[] +} + +const useDatasetStore = defineStore({ + id: 'dataset', + state: (): datasetStateTypes => ({ + baseInfo: null, + webInfo: null, + documentsType: '', + documentsFiles: [] + }), + actions: { + saveBaseInfo(info: datasetData | null) { + this.baseInfo = info + }, + saveWebInfo(info: any) { + this.webInfo = info + }, + saveDocumentsType(val: string) { + this.documentsType = val + }, + saveDocumentsFile(file: UploadUserFile[]) { + this.documentsFiles = file + }, + async asyncGetAllDataset(loading?: Ref) { + return new Promise((resolve, reject) => { + datasetApi + .getAllDataset(loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) 
+ }, + async asyncGetDatasetDetail(id: string, loading?: Ref) { + return new Promise((resolve, reject) => { + datasetApi + .getDatasetDetail(id, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + async asyncSyncDataset(id: string, sync_type: string, loading?: Ref) { + return new Promise((resolve, reject) => { + datasetApi + .putSyncWebDataset(id, sync_type, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + } + } +}) + +export default useDatasetStore diff --git a/src/MaxKB-1.7.2/ui/src/stores/modules/document.ts b/src/MaxKB-1.7.2/ui/src/stores/modules/document.ts new file mode 100644 index 0000000..0037f03 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/stores/modules/document.ts @@ -0,0 +1,36 @@ +import { defineStore } from 'pinia' +import documentApi from '@/api/document' +import { type Ref } from 'vue' + +const useDocumentStore = defineStore({ + id: 'document', + state: () => ({}), + actions: { + async asyncGetAllDocument(id: string, loading?: Ref) { + return new Promise((resolve, reject) => { + documentApi + .getAllDocument(id, loading) + .then((res) => { + resolve(res) + }) + .catch((error) => { + reject(error) + }) + }) + }, + async asyncPostDocument(datasetId: string, data: any, loading?: Ref) { + return new Promise((resolve, reject) => { + documentApi + .postDocument(datasetId, data, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + } + } +}) + +export default useDocumentStore diff --git a/src/MaxKB-1.7.2/ui/src/stores/modules/log.ts b/src/MaxKB-1.7.2/ui/src/stores/modules/log.ts new file mode 100644 index 0000000..d7a8d64 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/stores/modules/log.ts @@ -0,0 +1,67 @@ +import { defineStore } from 'pinia' +import logApi from '@/api/log' +import { type Ref } from 'vue' +import type { pageRequest } from '@/api/type/common' + +const useLogStore = defineStore({ + id: 'log', + 
state: () => ({}), + actions: { + async asyncGetChatLog(id: string, page: pageRequest, param: any, loading?: Ref) { + return new Promise((resolve, reject) => { + logApi + .getChatLog(id, page, param, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + async asyncChatRecordLog( + id: string, + chatId: string, + page: pageRequest, + loading?: Ref, + order_asc?: boolean + ) { + return new Promise((resolve, reject) => { + logApi + .getChatRecordLog(id, chatId, page, loading, order_asc) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + async asyncGetChatLogClient(id: string, page: pageRequest, loading?: Ref) { + return new Promise((resolve, reject) => { + logApi + .getChatLogClient(id, page, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + async asyncDelChatClientLog(id: string, chatId: string, loading?: Ref) { + return new Promise((resolve, reject) => { + logApi + .delChatClientLog(id, chatId, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + } + } +}) + +export default useLogStore diff --git a/src/MaxKB-1.7.2/ui/src/stores/modules/model.ts b/src/MaxKB-1.7.2/ui/src/stores/modules/model.ts new file mode 100644 index 0000000..0875e20 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/stores/modules/model.ts @@ -0,0 +1,35 @@ +import { defineStore } from 'pinia' +import modelApi from '@/api/model' +import type { ListModelRequest, Provider } from '@/api/type/model' +const useModelStore = defineStore({ + id: 'model', + state: () => ({}), + actions: { + async asyncGetModel(data?: ListModelRequest) { + return new Promise((resolve, reject) => { + modelApi + .getModel(data) + .then((res) => { + resolve(res) + }) + .catch((error) => { + reject(error) + }) + }) + }, + async asyncGetProvider() { + return new Promise((resolve, reject) => { + modelApi + .getProvider() + .then((res) => { + 
resolve(res) + }) + .catch((error) => { + reject(error) + }) + }) + } + } +}) + +export default useModelStore diff --git a/src/MaxKB-1.7.2/ui/src/stores/modules/paragraph.ts b/src/MaxKB-1.7.2/ui/src/stores/modules/paragraph.ts new file mode 100644 index 0000000..6c7c519 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/stores/modules/paragraph.ts @@ -0,0 +1,48 @@ +import { defineStore } from 'pinia' +import paragraphApi from '@/api/paragraph' +import type { Ref } from 'vue' + +const useParagraphStore = defineStore({ + id: 'paragraph', + state: () => ({}), + actions: { + async asyncPutParagraph( + datasetId: string, + documentId: string, + paragraphId: string, + data: any, + loading?: Ref + ) { + return new Promise((resolve, reject) => { + paragraphApi + .putParagraph(datasetId, documentId, paragraphId, data, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + + async asyncDelParagraph( + datasetId: string, + documentId: string, + paragraphId: string, + loading?: Ref + ) { + return new Promise((resolve, reject) => { + paragraphApi + .delParagraph(datasetId, documentId, paragraphId, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + } + } +}) + +export default useParagraphStore diff --git a/src/MaxKB-1.7.2/ui/src/stores/modules/problem.ts b/src/MaxKB-1.7.2/ui/src/stores/modules/problem.ts new file mode 100644 index 0000000..f8dd7b1 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/stores/modules/problem.ts @@ -0,0 +1,79 @@ +import { defineStore } from 'pinia' +import { type Ref } from 'vue' +import problemApi from '@/api/problem' +import paragraphApi from '@/api/paragraph' +import type { pageRequest } from '@/api/type/common' + +const useProblemStore = defineStore({ + id: 'problem', + state: () => ({}), + actions: { + async asyncPostProblem(datasetId: string, data: any, loading?: Ref) { + return new Promise((resolve, reject) => { + problemApi + .postProblems(datasetId, data, 
loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + async asyncGetProblem( + datasetId: string, + page: pageRequest, + param: any, + loading?: Ref + ) { + return new Promise((resolve, reject) => { + problemApi + .getProblems(datasetId, page, param, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + async asyncDisassociationProblem( + datasetId: string, + documentId: string, + paragraphId: string, + problemId: string, + loading?: Ref + ) { + return new Promise((resolve, reject) => { + paragraphApi + .disassociationProblem(datasetId, documentId, paragraphId, problemId, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + }, + async asyncAssociationProblem( + datasetId: string, + documentId: string, + paragraphId: string, + problemId: string, + loading?: Ref + ) { + return new Promise((resolve, reject) => { + paragraphApi + .associationProblem(datasetId, documentId, paragraphId, problemId, loading) + .then((data) => { + resolve(data) + }) + .catch((error) => { + reject(error) + }) + }) + } + } +}) + +export default useProblemStore diff --git a/src/MaxKB-1.7.2/ui/src/stores/modules/prompt.ts b/src/MaxKB-1.7.2/ui/src/stores/modules/prompt.ts new file mode 100644 index 0000000..ec4be97 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/stores/modules/prompt.ts @@ -0,0 +1,40 @@ +import { defineStore } from 'pinia' + +export interface promptTypes { + user: string + formValue: { model_id: string, prompt: string } +} + +const usePromptStore = defineStore({ + id: 'prompt', + state: (): promptTypes[] => (JSON.parse(localStorage.getItem('PROMPT_CACHE') || '[]')), + actions: { + save(user: string, formValue: any) { + this.$state.forEach((item: any, index: number) => { + if (item.user === user) { + this.$state.splice(index, 1) + } + }) + this.$state.push({ user, formValue }) + localStorage.setItem('PROMPT_CACHE', 
JSON.stringify(this.$state)) + }, + get(user: string) { + for (let i = 0; i < this.$state.length; i++) { + if (this.$state[i].user === user) { + return this.$state[i].formValue + } + } + return { + model_id: '', + prompt: '内容:{data}\n' + + '\n' + + '请总结上面的内容,并根据内容总结生成 5 个问题。\n' + + '回答要求:\n' + + '- 请只输出问题;\n' + + '- 请将每个问题放置标签中。' + } + } + } +}) + +export default usePromptStore \ No newline at end of file diff --git a/src/MaxKB-1.7.2/ui/src/stores/modules/user.ts b/src/MaxKB-1.7.2/ui/src/stores/modules/user.ts new file mode 100644 index 0000000..af4f558 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/stores/modules/user.ts @@ -0,0 +1,166 @@ +import { defineStore } from 'pinia' +import { type Ref } from 'vue' +import type { User } from '@/api/type/user' +import { cloneDeep } from 'lodash' +import UserApi from '@/api/user' +import ThemeApi from '@/api/theme' +import { useElementPlusTheme } from 'use-element-plus-theme' +import { defaultPlatformSetting } from '@/utils/theme' + +export interface userStateTypes { + userType: number // 1 系统操作者 2 对话用户 + userInfo: User | null + token: any + version?: string + accessToken?: string + XPACK_LICENSE_IS_VALID: false + isXPack: false + themeInfo: any +} + +const useUserStore = defineStore({ + id: 'user', + state: (): userStateTypes => ({ + userType: 1, + userInfo: null, + token: '', + version: '', + XPACK_LICENSE_IS_VALID: false, + isXPack: false, + themeInfo: null + }), + actions: { + showXpack() { + return this.isXPack + }, + isDefaultTheme() { + return !this.themeInfo?.theme || this.themeInfo?.theme === '#3370FF' + }, + setTheme(data: any) { + const { changeTheme } = useElementPlusTheme(this.themeInfo?.theme) + changeTheme(data?.['theme']) + this.themeInfo = cloneDeep(data) + }, + isExpire() { + return this.isXPack && !this.XPACK_LICENSE_IS_VALID + }, + isEnterprise() { + return this.isXPack && this.XPACK_LICENSE_IS_VALID + }, + getToken(): String | null { + if (this.token) { + return this.token + } + return this.userType === 1 ? 
localStorage.getItem('token') : this.getAccessToken() + }, + getAccessToken() { + const accessToken = sessionStorage.getItem('accessToken') + if (accessToken) { + return accessToken + } + return localStorage.getItem('accessToken') + }, + + getPermissions() { + if (this.userInfo) { + return this.isXPack && this.XPACK_LICENSE_IS_VALID + ? [...this.userInfo?.permissions, 'x-pack'] + : this.userInfo?.permissions + } else { + return [] + } + }, + getRole() { + if (this.userInfo) { + return this.userInfo?.role + } else { + return '' + } + }, + changeUserType(num: number) { + this.userType = num + }, + + async asyncGetProfile() { + return new Promise((resolve, reject) => { + UserApi.getProfile() + .then(async (ok) => { + this.version = ok.data?.version || '-' + this.isXPack = ok.data?.IS_XPACK + this.XPACK_LICENSE_IS_VALID = ok.data?.XPACK_LICENSE_IS_VALID + + if (this.isEnterprise()) { + await this.theme() + } else { + this.themeInfo = { + ...defaultPlatformSetting + } + } + resolve(ok) + }) + .catch((error) => { + reject(error) + }) + }) + }, + + async theme(loading?: Ref) { + return await ThemeApi.getThemeInfo(loading).then((ok) => { + this.setTheme(ok.data) + // window.document.title = this.themeInfo['title'] || 'MaxKB' + // const link = document.querySelector('link[rel="icon"]') as any + // if (link) { + // link['href'] = this.themeInfo['icon'] || '/favicon.ico' + // } + }) + }, + + async profile() { + return UserApi.profile().then(async (ok) => { + this.userInfo = ok.data + return this.asyncGetProfile() + }) + }, + + async login(auth_type: string, username: string, password: string) { + return UserApi.login(auth_type, { username, password }).then((ok) => { + this.token = ok.data + localStorage.setItem('token', ok.data) + return this.profile() + }) + }, + async dingCallback(code: string) { + return UserApi.getDingCallback(code).then((ok) => { + this.token = ok.data + localStorage.setItem('token', ok.data) + return this.profile() + }) + }, + async wecomCallback(code: 
string) { + return UserApi.getWecomCallback(code).then((ok) => { + this.token = ok.data + localStorage.setItem('token', ok.data) + return this.profile() + }) + }, + + async logout() { + return UserApi.logout().then(() => { + localStorage.removeItem('token') + return true + }) + }, + async getAuthType() { + return UserApi.getAuthType().then((ok) => { + return ok.data + }) + }, + async getQrType() { + return UserApi.getQrType().then((ok) => { + return ok.data + }) + } + } +}) + +export default useUserStore diff --git a/src/MaxKB-1.7.2/ui/src/styles/app.scss b/src/MaxKB-1.7.2/ui/src/styles/app.scss new file mode 100644 index 0000000..f868dec --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/styles/app.scss @@ -0,0 +1,749 @@ +@font-face { + font-family: AlibabaPuHuiTi; + src: + url('./font/AlibabaPuHuiTi-3-55-Regular.woff') format('woff'), + url('./font/AlibabaPuHuiTi-3-55-Regular.ttf') format('truetype'), + url('./font/AlibabaPuHuiTi-3-55-Regular.eot') format('eot'), + url('./font/AlibabaPuHuiTi-3-55-Regular.otf') format('opentype'), + url('./font/AlibabaPuHuiTi-3-55-Regular.woff2') format('woff2'); +} +* { + margin: 0; + padding: 0; +} + +html { + height: 100%; + box-sizing: border-box; +} + +body { + -moz-osx-font-smoothing: grayscale; + -webkit-font-smoothing: antialiased; + font-family: 'PingFang SC', AlibabaPuHuiTi !important; + font-size: 14px; + font-style: normal; + font-weight: 500; + height: 100%; + margin: 0; + padding: 0; + color: var(--app-text-color); +} + +#app { + height: 100%; +} + +:focus { + outline: none; +} + +a:active { + outline: none; +} + +a, +a:focus, +a:hover { + cursor: pointer; + color: inherit; + text-decoration: none; +} + +div:focus { + outline: none; +} + +ul { + list-style: none; + margin: 0; + padding: 0; +} + +/* 滚动条整体部分 */ +::-webkit-scrollbar { + width: 6px; // 纵向滚动条宽度 + height: 6px; // 横向滚动条高度 +} + +/* 滑块 */ +::-webkit-scrollbar-thumb { + border-radius: 5px; +} + +/* 轨道 */ +::-webkit-scrollbar-track { + border-radius: 5px; + 
background-color: transparent; +} + +.clearfix:after { + content: ''; + display: block; + clear: both; +} + +h1 { + font-size: 24px; +} + +h2 { + font-size: 20px; + font-weight: 500; +} + +h3 { + font-size: 18px; +} + +h4 { + font-size: 16px; +} + +h5 { + font-size: 14px; + font-weight: 500; +} + +.bold { + font-weight: 600; +} +.lighter { + font-weight: 400; +} + +.w-full { + width: 100%; +} +.h-full { + height: 100%; +} +.w-120 { + width: 120px; +} +.w-240 { + width: 240px; +} +.w-280 { + width: 280px; +} +.w-500 { + width: 500px; +} +.max-w-200 { + max-width: 200px; +} + +.mt-4 { + margin-top: calc(var(--app-base-px) - 4px); +} + +.mt-8 { + margin-top: var(--app-base-px); +} +.mt-12 { + margin-top: calc(var(--app-base-px) + 4px); +} +.mt-16 { + margin-top: calc(var(--app-base-px) * 2); +} +.mt-20 { + margin-top: calc(var(--app-base-px) * 2 + 4px); +} +.mt-24 { + margin-top: calc(var(--app-base-px) * 3); +} + +.mb-4 { + margin-bottom: calc(var(--app-base-px) - 4px); +} +.mb-8 { + margin-bottom: var(--app-base-px); +} +.mb-12 { + margin-bottom: calc(var(--app-base-px) + 4px); +} +.mb-16 { + margin-bottom: calc(var(--app-base-px) * 2); +} +.mb-24 { + margin-bottom: calc(var(--app-base-px) * 3); +} +.ml-4 { + margin-left: calc(var(--app-base-px) - 4px); +} +.ml-8 { + margin-left: var(--app-base-px); +} +.ml-12 { + margin-left: calc(var(--app-base-px) + 4px); +} +.ml-16 { + margin-left: calc(var(--app-base-px) * 2); +} +.ml-24 { + margin-left: calc(var(--app-base-px) * 3); +} +.mr-4 { + margin-right: calc(var(--app-base-px) - 4px); +} +.mr-8 { + margin-right: var(--app-base-px); +} +.mr-12 { + margin-right: calc(var(--app-base-px) + 4px); +} +.mr-16 { + margin-right: calc(var(--app-base-px) * 2); +} +.mr-24 { + margin-right: calc(var(--app-base-px) * 3); +} +.p-8 { + padding: var(--app-base-px); +} +.p-16 { + padding: calc(var(--app-base-px) * 2); +} +.p-24 { + padding: calc(var(--app-base-px) * 3); +} +.p-8-12 { + padding: calc(var(--app-base-px)) 
calc(var(--app-base-px) + 4px); +} +.p-12-24 { + padding: calc(var(--app-base-px) + 4px) calc(var(--app-base-px) * 3); +} +.p-16-24 { + padding: calc(var(--app-base-px) * 2) calc(var(--app-base-px) * 3); +} + +.pt-0 { + padding-top: 0; +} +.pb-0 { + padding-bottom: 0; +} + +.float-right { + float: right; +} + +.flex { + display: flex; +} + +.flex-center { + display: flex; + align-items: center; + justify-content: center; +} + +.flex-between { + display: flex; + justify-content: space-between; + align-items: center; +} + +.flex-wrap { + display: flex; + flex-wrap: wrap; + align-content: space-between; +} + +.align-center { + align-items: center; +} + +.align-baseline { + align-items: baseline; +} + +.justify-center { + justify-content: center; +} + +.text-left { + text-align: left; +} +.text-center { + text-align: center; +} +.text-right { + text-align: right; +} + +.vertical-middle { + vertical-align: middle; +} + +.border { + border: 1px solid var(--el-border-color); +} + +.border-l { + border-left: 1px solid var(--el-border-color); +} + +.border-b { + border-bottom: 1px solid var(--el-border-color); +} +.border-r { + border-right: 1px solid var(--el-border-color); +} +.border-t { + border-top: 1px solid var(--el-border-color); +} + +.border-b-light { + border-bottom: 1px solid var(--el-border-color-lighter); +} +.border-r-4 { + border-radius: 4px; +} + +.border-t-dashed { + border-top: 1px dashed var(--el-border-color); +} +.border-primary { + border: 1px solid var(--el-color-primary); + color: var(--el-color-primary); +} + +.border-none { + border: none; +} + +.cursor { + cursor: pointer; +} +.notAllowed { + cursor: not-allowed; +} + +/* + 超出省略号 +*/ + +.ellipsis { + display: inline-block; + max-width: 130px; + text-overflow: ellipsis; + white-space: nowrap; + overflow: hidden; +} + +/* + 双行超出省略号,其他行数自定义 -webkit-line-clamp +*/ +.ellipsis-2 { + display: -webkit-box; + -webkit-box-orient: vertical; + -webkit-line-clamp: 2; + overflow: hidden; +} + +.ellipsis-1 { + 
display: -webkit-box; + -webkit-box-orient: vertical; + -webkit-line-clamp: 1; + overflow: hidden; +} + +.break-all { + word-break: break-all; +} + +.pre-wrap { + white-space: pre-wrap; +} + +/* + 内容部分 自适应高度 +*/ +.main-calc-height { + height: var(--app-main-height); + box-sizing: border-box; +} + +/* + 标题前带竖线样式 +*/ +.title-decoration-1 { + position: relative; + padding-left: 12px; + &:before { + position: absolute; + left: 2px; + top: 50%; + transform: translate(-50%, -50%); + width: 2px; + height: 80%; + content: ''; + background: var(--el-color-primary); + } +} + +/* tag */ +.default-tag { + background: var(--el-color-primary-light-8); + color: var(--el-color-primary); + border: none; +} +.danger-tag { + background: var(--tag-danger-bg); + color: #d03f3b; + border: none; +} +.success-tag { + background: var(--tag-success-bg); + color: var(--el-color-success); + border: none; +} +.warning-tag { + background: var(--tag-warning-bg); + color: var(--el-color-warning); + border: none; +} + +.info-tag { + background: var(--app-text-color-light-1); + color: var(--app-text-color-secondary); + border: none; +} + +.purple-tag { + background: #f2ebfe; + color: #7f3bf5; + border-color: #e0d7f0; +} + +.blue-tag { + background: #ebf1ff; + color: #3370ff; + border-color: #d6e2ff; +} + +/* + card 无边框无阴影 灰色背景 +*/ +.card-never { + background: var(--app-layout-bg-color); + border: none; +} + +/* + 图标旋转90度 +*/ +.rotate-90 { + transform: rotateZ(90deg); +} +.rotate-180 { + transform: rotateZ(180deg); +} +/* + 表格第一行插入自定义行 +*/ +.table-quick-append { + background: #ffffff; + .el-table__append-wrapper { + position: absolute; + top: 0; + border-bottom: var(--el-table-border); + width: 100%; + height: 49px; + box-sizing: border-box; + align-items: center; + display: flex; + padding: 0 12px; + background: #ffffff; + cursor: pointer; + z-index: 2; + &:hover { + background: var(--el-color-primary-light-9); + z-index: 1; + } + } + .el-table__body { + margin-top: 49px; + } +} + +// 
checkbox-group 文字在左 input在右 +.app-custom-checkbox-group { + line-height: normal; + .el-checkbox__label { + display: none; + } +} + +/* + 头像渐变背景 +*/ +.avatar-gradient { + background: var(--app-avatar-gradient-color); +} + +.avatar-light { + background: var(--el-color-primary-light-4); +} + +.avatar-purple { + background: #7f3bf5; +} +.avatar-blue { + background: #3370ff; +} + +.avatar-green { + background: #34c724; +} +.avatar-grey { + background: #bbbfc4; +} + +.success { + color: var(--el-color-success); +} +.danger { + color: var(--el-color-danger); +} +.warning { + color: var(--el-color-warning); +} +.primary { + color: var(--el-color-primary); +} +.info { + color: var(--el-color-info); +} + +.color-secondary { + color: var(--app-text-color-secondary); +} + +.layout-bg { + background: var(--app-layout-bg-color); +} + +.white-bg { + background: #ffffff; +} + +.app-warning-icon { + font-size: 16px; + color: var(--app-text-color-secondary); +} + +.dotting { + display: inline-block; + width: 10px; + min-height: 2px; + padding-right: 2px; + margin-left: 2px; + padding-left: 2px; + border-left: 2px solid currentColor; + border-right: 2px solid currentColor; + background-color: currentColor; + background-clip: content-box; + box-sizing: border-box; + -webkit-animation: dot 0.8s infinite step-start both; + animation: dot 0.8s infinite step-start both; + &:before { + content: '...'; + } + &::before { + content: ''; + } +} + +@-webkit-keyframes dot { + 25% { + border-color: transparent; + background-color: transparent; + } + 50% { + border-right-color: transparent; + background-color: transparent; + } + 75% { + border-right-color: transparent; + } +} +@keyframes dot { + 25% { + border-color: transparent; + background-color: transparent; + } + 50% { + border-right-color: transparent; + background-color: transparent; + } + 75% { + border-right-color: transparent; + } +} + +.file-List-card { + border-radius: 4px; + .el-card__body { + padding: 8px 16px 8px 12px; + } +} + 
// Card "selected" state: primary-coloured border plus a corner badge with a check mark.
// NOTE(review): both pseudo-elements are absolutely positioned, so the host element
// presumably provides its own positioning context elsewhere; confirm.
.selected {
  border: 1px solid var(--el-color-primary) !important;
  // Corner badge: an empty box whose 14px borders are coloured only on the top and
  // right sides, which leaves a triangle pinned to the top-right corner.
  &:before {
    content: '';
    position: absolute;
    right: 0;
    top: 0;
    border: 14px solid var(--el-color-primary);
    border-bottom-color: transparent;
    border-left-color: transparent;
  }

  // Check mark: a 3x6 box showing only its right and bottom white borders,
  // rotated 35 degrees and placed on top of the badge.
  &:after {
    content: '';
    width: 3px;
    height: 6px;
    position: absolute;
    right: 5px;
    top: 2px;
    border: 2px solid #fff;
    border-top-color: transparent;
    border-left-color: transparent;
    transform: rotate(35deg);
  }
  // Keep the primary border while hovering.
  &:hover {
    border: 1px solid var(--el-color-primary);
  }
}
// Project-wide overrides for Element Plus buttons.
.el-button {
  --el-button-font-weight: 400;
  padding: 5px 12px;
  // Text buttons: compact 4px padding, capped at 24px tall.
  &.is-text {
    padding: 4px !important;
    font-size: 16px;
    max-height: 24px;
    // Only enabled text buttons get the light hover background.
    &:not(.is-disabled):hover {
      background: var(--app-text-color-light-1);
    }
  }
  // On focus, reuse the button's own base colour variables so a focused button
  // looks the same as an unfocused one.
  &:focus {
    color: var(--el-button-text-color);
    background-color: var(--el-button-bg-color);
    border-color: var(--el-button-border-color);
  }
  // Link buttons show neither background nor border when focused.
  &.is-link:focus {
    background: none;
    border: none;
  }
}
// Button row of the Element Plus message box.
.el-message-box__btns {
  padding: 0;
  button {
    min-width: 80px;
    // Gap between the first and the second button.
    &:nth-child(2) {
      margin-left: 12px;
    }
  }
  // Destructive confirm button: danger background with white text.
  button.danger {
    background: var(--el-color-danger);
    // NOTE(review): `border` is the shorthand, so supplying only a colour also resets
    // border-style to `none`. If a danger-coloured outline was intended this should
    // probably be `border-color`; confirm before changing.
    border: var(--el-color-danger);
    color: #ffffff;
  }
}
// Project-wide overrides for Element Plus tables.
.el-table {
  // Header background and body text colour come from the app-level variables.
  --el-table-header-bg-color: var(--app-layout-bg-color);
  --el-table-text-color: var(--app-text-color);
  font-weight: 400;
  // Header row: secondary text colour, medium weight.
  thead {
    color: var(--app-text-color-secondary);
    th {
      font-weight: 500;
    }
  }

  // Draw a top border on header cells using the table's own border variable.
  th.el-table__cell {
    border-top: var(--el-table-border);
  }
  .el-table__cell {
    padding: 12px 0;
  }
  .el-checkbox {
    height: 23px;
  }
  // Rows tagged `highlight` reuse the "current row" background colour.
  tr.highlight {
    background: var(--el-table-current-row-bg-color);
  }
}
// Autocomplete dropdown with an added "no data" placeholder entry.
.platform-auto-complete {
  .el-autocomplete-suggestion__wrap {
    padding: 5px 0;
    ul li {
      pointer-events: none; // make the entry non-clickable
      // Placeholder text: centred and grey.
      .default {
        text-align: center;
        color: #999;
      }
      // Keep the plain white background on hover.
      &:hover {
        background-color: #fff;
      }
    }
  }
}
// Light-theme success alert: pale green banner with a green icon.
.el-alert--success.is-light {
  background-color: #d6f4d3;
  .el-alert__icon {
    color: #34c724;
  }
}
// Application-level design tokens, exposed as CSS custom properties on :root.
:root {
  --app-base-px: 8px;
  --app-layout-bg-color: #f5f6f7;
  --app-text-color: #1f2329;
  --app-text-color-light-1: rgba(31, 35, 41, 0.1);
  --app-text-color-secondary: #646a73;
  --app-text-color-disable: #bbbfc4;
  --app-input-color-placeholder: #8f959e;
  --app-view-padding: 24px;
  --app-view-bg-color: #ffffff;
  --app-border-color-dark: #bbbfc4;
  --md-bk-hover-color:var(--el-border-color-hover);
  /** header component */
  --app-header-height: 56px;
  --app-header-padding: 0 20px;
  --app-header-bg-color: linear-gradient(90deg, #ebf1ff 24.34%, #e5fbf8 56.18%, #f2ebfe 90.18%);
  --app-logo-color: linear-gradient(180deg, #3370FF 0%, #7f3bf5 100%);
  --app-avatar-gradient-color: linear-gradient(270deg, #9258f7 0%, #3370FF 100%);

  /* computed height: viewport minus the header, twice the view padding, and a fixed 40px */
  --app-main-height: calc(100vh - var(--app-header-height) - var(--app-view-padding) * 2 - 40px);

  /** sidebar component */
  --sidebar-bg-color: #ffffff;
  --sidebar-width: 240px;
  /** tag */
  --tag-default-bg: rgba(51, 112, 255, 0.2);
  --tag-default-color: #2b5fd9;
  --tag-success-bg: rgba(52, 199, 36, 0.2);
  --tag-success-color: #2ca91f;
  --tag-warning-bg: rgba(255, 136, 0, 0.2);
  --tag-warning-color: #d97400;
  --tag-danger-bg: rgba(245, 74, 69, 0.2);

  /** card */
  --card-width: 330px;
  --card-min-height: 160px;
  --card-min-width: 220px;

  /** setting */
  --setting-left-width: 280px;

  /** dataset: derived from the main content height, minus 70px */
  --create-dataset-height: calc(var(--app-main-height) - 70px);

  /** ai-chat: translucent gradient layered over a solid #eff0f1 base */
  --dialog-bg-gradient-color: linear-gradient(
      188deg,
      rgba(235, 241, 255, 0.2) 39.6%,
      rgba(231, 249, 255, 0.2) 94.3%
    ),
    #eff0f1;
}
/**
 * Splits an array into consecutive chunks of at most `splitNum` items each.
 *
 * The last chunk holds the remainder and may be shorter; an empty input yields
 * an empty result. The input array is not modified.
 *
 * @param sourceDataList source items
 * @param splitNum maximum number of items per chunk; must be a positive integer
 * @returns the chunks, in original order
 * @throws RangeError when `splitNum` is not a positive integer
 */
export function splitArray<T>(sourceDataList: Array<T>, splitNum: number): Array<Array<T>> {
  // A zero, negative or fractional chunk size has no sensible meaning. The previous
  // count-based loop never terminated for 0, so reject such sizes explicitly.
  if (!Number.isInteger(splitNum) || splitNum < 1) {
    throw new RangeError(`splitNum must be a positive integer, received ${splitNum}`)
  }

  const arrayList: Array<Array<T>> = []
  // Stepping by the chunk size avoids computing a chunk count up front. The old
  // `length / splitNum + 1` was fractional and appended an extra empty chunk
  // whenever the length was not an exact multiple of splitNum.
  for (let start = 0; start < sourceDataList.length; start += splitNum) {
    arrayList.push(sourceDataList.slice(start, start + splitNum))
  }
  return arrayList
}
/**
 * Checks whether the current user satisfies one permission requirement.
 *
 * - falsy requirement: always satisfied
 * - string / Permission: the user's permission list must contain it
 * - Role: the user's role must equal it
 * - ComplexPermission: combines "any listed permission is held" with
 *   "the role is listed", using AND or OR as given by `compare`
 *
 * @param permission the requirement to check
 * @returns true when the requirement is met, false otherwise
 */
const hasPermissionChild = (permission: Role | string | Permission | ComplexPermission) => {
  const { user } = useStore()
  const grantedPermissions = user.getPermissions()
  const currentRole = user.getRole()

  // No requirement means no restriction.
  if (!permission) {
    return true
  }

  if (typeof permission === 'string') {
    return grantedPermissions.includes(permission)
  }
  if (permission instanceof Role) {
    return currentRole === permission.role
  }
  if (permission instanceof Permission) {
    return grantedPermissions.includes(permission.permission)
  }
  if (permission instanceof ComplexPermission) {
    const anyPermissionGranted = permission.permissionList.some((item) =>
      grantedPermissions.includes(item)
    )
    const roleMatches = permission.roleList.includes(currentRole)
    if (permission.compare === 'AND') {
      return anyPermissionGranted && roleMatches
    }
    return anyPermissionGranted || roleMatches
  }

  return false
}
/**
 * Composite permission requirement: a list of roles and a list of permissions,
 * combined with either OR or AND.
 */
export class ComplexPermission {
  /** Role names that count as a role match. */
  roleList: Array<string>

  /** Permission identifiers, any one of which counts as a permission match. */
  permissionList: Array<string>

  /** How the role match and the permission match are combined. */
  compare: 'OR' | 'AND'

  /**
   * @param roleList role names that count as a role match
   * @param permissionList permission identifiers, any one of which counts as a match
   * @param compare 'AND' requires both a role and a permission match; 'OR' requires either
   */
  constructor(roleList: Array<string>, permissionList: Array<string>, compare: 'OR' | 'AND') {
    // `Array` is generic and must be given an element type; a bare `Array` does not type-check.
    this.roleList = roleList
    this.permissionList = permissionList
    this.compare = compare
  }
}
/**
 * Converts a hexadecimal colour to a CSS `rgba(...)` string.
 *
 * Accepts `#RRGGBB` and the `#RGB` shorthand; the leading `#` is optional.
 * Characters after the first six hex digits are ignored.
 *
 * @param hex hexadecimal colour, e.g. `#3370FF` or `#fff`
 * @param alpha opacity written into the result; defaults to 1 (fully opaque)
 * @returns `rgba(r, g, b, alpha)`, or an empty string when `hex` is missing or unparsable
 */
export function hexToRgba(hex?: string, alpha: number = 1): string {
  if (!hex) {
    return ''
  }

  const digits = hex.startsWith('#') ? hex.slice(1) : hex
  // Expand the three-digit shorthand (#abc -> aabbcc) so each channel has two digits.
  const normalized =
    digits.length === 3 ? Array.from(digits, (char) => char + char).join('') : digits

  // Convert each pair of hex characters to its decimal channel value.
  const r = parseInt(normalized.slice(0, 2), 16)
  const g = parseInt(normalized.slice(2, 4), 16)
  const b = parseInt(normalized.slice(4, 6), 16)

  // Never emit `rgba(NaN, ...)`: treat an unparsable colour like a missing one.
  if (Number.isNaN(r) || Number.isNaN(g) || Number.isNaN(b)) {
    return ''
  }

  return `rgba(${r}, ${g}, ${b}, ${alpha})`
}
/**
 * Formats a value with thousands separators, e.g. 1234567 -> "1,234,567".
 *
 * Only the first run of digits (the integer part) is grouped; a sign before it
 * and any fractional part after it are left as they are.
 *
 * @param num value to format; its `toString()` representation is used
 * @returns the grouped string, or undefined when `num` is null or undefined
 */
export function toThousands(num: any) {
  if (num === null || num === undefined) {
    return undefined
  }
  const text = num.toString()
  // Within the integer part, insert a comma after every digit that is followed
  // by a whole number of three-digit groups.
  return text.replace(/\d+/, (integerPart: any) =>
    integerPart.replace(/(\d)(?=(?:\d{3})+$)/g, '$1,')
  )
}
/**
 * Renders a byte count as a short human-readable size.
 *
 * Values below 1024 are shown as whole bytes ("512B"). Larger values are scaled
 * by powers of 1024 and shown with two decimals and a K, M, G or T suffix;
 * T is the largest unit used.
 *
 * @param size size in bytes
 * @returns the formatted size, or an empty string when `size` is falsy (0, NaN)
 */
export function filesize(size: number) {
  if (!size) return ''

  const base = 1024.0
  if (size < base) return size + 'B'

  const units = ['K', 'M', 'G', 'T']
  // Find the largest unit whose threshold the size has reached, stopping at T.
  let exponent = 1
  while (exponent < units.length && size >= Math.pow(base, exponent + 1)) {
    exponent++
  }
  return (size / Math.pow(base, exponent)).toFixed(2) + units[exponent - 1]
}
/**
 * Sums the numbers in an array.
 *
 * @param array numbers to add up
 * @returns the total, or 0 for an empty array
 */
export function getSum(array: Array<number>): number {
  // `Array` must be given an element type; with `number` the reducer is fully typed.
  return array.reduce((total, item) => total + item, 0)
}
0000000..be37645 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application-overview/component/EmbedDialog.vue @@ -0,0 +1,128 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application-overview/component/LimitDialog.vue b/src/MaxKB-1.7.2/ui/src/views/application-overview/component/LimitDialog.vue new file mode 100644 index 0000000..5de87f4 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application-overview/component/LimitDialog.vue @@ -0,0 +1,198 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application-overview/component/SettingAPIKeyDialog.vue b/src/MaxKB-1.7.2/ui/src/views/application-overview/component/SettingAPIKeyDialog.vue new file mode 100644 index 0000000..dc580f9 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application-overview/component/SettingAPIKeyDialog.vue @@ -0,0 +1,105 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application-overview/component/StatisticsCharts.vue b/src/MaxKB-1.7.2/ui/src/views/application-overview/component/StatisticsCharts.vue new file mode 100644 index 0000000..dbcc2e8 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application-overview/component/StatisticsCharts.vue @@ -0,0 +1,163 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application-overview/component/XPackDisplaySettingDialog.vue b/src/MaxKB-1.7.2/ui/src/views/application-overview/component/XPackDisplaySettingDialog.vue new file mode 100644 index 0000000..83e2910 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application-overview/component/XPackDisplaySettingDialog.vue @@ -0,0 +1,512 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application-overview/index.vue b/src/MaxKB-1.7.2/ui/src/views/application-overview/index.vue new file mode 100644 index 0000000..6c380c9 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application-overview/index.vue @@ -0,0 +1,414 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application-workflow/component/DropdownMenu.vue b/src/MaxKB-1.7.2/ui/src/views/application-workflow/component/DropdownMenu.vue new 
file mode 100644 index 0000000..e9e03e4 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application-workflow/component/DropdownMenu.vue @@ -0,0 +1,168 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application-workflow/component/PublishHistory.vue b/src/MaxKB-1.7.2/ui/src/views/application-workflow/component/PublishHistory.vue new file mode 100644 index 0000000..d66d7fb --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application-workflow/component/PublishHistory.vue @@ -0,0 +1,144 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application-workflow/index.vue b/src/MaxKB-1.7.2/ui/src/views/application-workflow/index.vue new file mode 100644 index 0000000..5df4f22 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application-workflow/index.vue @@ -0,0 +1,442 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application/ApplicationAccess.vue b/src/MaxKB-1.7.2/ui/src/views/application/ApplicationAccess.vue new file mode 100644 index 0000000..0461d6b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application/ApplicationAccess.vue @@ -0,0 +1,174 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application/ApplicationSetting.vue b/src/MaxKB-1.7.2/ui/src/views/application/ApplicationSetting.vue new file mode 100644 index 0000000..53f0f0c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application/ApplicationSetting.vue @@ -0,0 +1,910 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application/component/AIModeParamSettingDialog.vue b/src/MaxKB-1.7.2/ui/src/views/application/component/AIModeParamSettingDialog.vue new file mode 100644 index 0000000..02f65fb --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application/component/AIModeParamSettingDialog.vue @@ -0,0 +1,114 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application/component/AccessSettingDrawer.vue b/src/MaxKB-1.7.2/ui/src/views/application/component/AccessSettingDrawer.vue new file mode 100644 index 0000000..b5d90b7 --- /dev/null +++ 
b/src/MaxKB-1.7.2/ui/src/views/application/component/AccessSettingDrawer.vue @@ -0,0 +1,259 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application/component/AddDatasetDialog.vue b/src/MaxKB-1.7.2/ui/src/views/application/component/AddDatasetDialog.vue new file mode 100644 index 0000000..2c7ecaa --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application/component/AddDatasetDialog.vue @@ -0,0 +1,168 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application/component/CopyApplicationDialog.vue b/src/MaxKB-1.7.2/ui/src/views/application/component/CopyApplicationDialog.vue new file mode 100644 index 0000000..e49f3e1 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application/component/CopyApplicationDialog.vue @@ -0,0 +1,178 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application/component/CreateApplicationDialog.vue b/src/MaxKB-1.7.2/ui/src/views/application/component/CreateApplicationDialog.vue new file mode 100644 index 0000000..1ddb32b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application/component/CreateApplicationDialog.vue @@ -0,0 +1,219 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application/component/ParamSettingDialog.vue b/src/MaxKB-1.7.2/ui/src/views/application/component/ParamSettingDialog.vue new file mode 100644 index 0000000..123ddc5 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application/component/ParamSettingDialog.vue @@ -0,0 +1,348 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application/component/TTSModeParamSettingDialog.vue b/src/MaxKB-1.7.2/ui/src/views/application/component/TTSModeParamSettingDialog.vue new file mode 100644 index 0000000..0b1b406 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/application/component/TTSModeParamSettingDialog.vue @@ -0,0 +1,173 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/application/index.vue b/src/MaxKB-1.7.2/ui/src/views/application/index.vue new file mode 100644 index 0000000..b6addf2 --- /dev/null +++ 
b/src/MaxKB-1.7.2/ui/src/views/application/index.vue @@ -0,0 +1,262 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/authentication/component/CAS.vue b/src/MaxKB-1.7.2/ui/src/views/authentication/component/CAS.vue new file mode 100644 index 0000000..4f790db --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/authentication/component/CAS.vue @@ -0,0 +1,97 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/authentication/component/EditModal.vue b/src/MaxKB-1.7.2/ui/src/views/authentication/component/EditModal.vue new file mode 100644 index 0000000..8ea3afa --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/authentication/component/EditModal.vue @@ -0,0 +1,182 @@ +template + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/authentication/component/LDAP.vue b/src/MaxKB-1.7.2/ui/src/views/authentication/component/LDAP.vue new file mode 100644 index 0000000..f4c2692 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/authentication/component/LDAP.vue @@ -0,0 +1,143 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/authentication/component/OIDC.vue b/src/MaxKB-1.7.2/ui/src/views/authentication/component/OIDC.vue new file mode 100644 index 0000000..079660c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/authentication/component/OIDC.vue @@ -0,0 +1,144 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/authentication/component/SCAN.vue b/src/MaxKB-1.7.2/ui/src/views/authentication/component/SCAN.vue new file mode 100644 index 0000000..d0ccead --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/authentication/component/SCAN.vue @@ -0,0 +1,254 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/authentication/index.vue b/src/MaxKB-1.7.2/ui/src/views/authentication/index.vue new file mode 100644 index 0000000..484f915 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/authentication/index.vue @@ -0,0 +1,70 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/chat/auth/component/password.vue b/src/MaxKB-1.7.2/ui/src/views/chat/auth/component/password.vue new file mode 100644 index 
0000000..4139258 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/chat/auth/component/password.vue @@ -0,0 +1,83 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/chat/auth/index.vue b/src/MaxKB-1.7.2/ui/src/views/chat/auth/index.vue new file mode 100644 index 0000000..fd4e936 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/chat/auth/index.vue @@ -0,0 +1,70 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/chat/base/index.vue b/src/MaxKB-1.7.2/ui/src/views/chat/base/index.vue new file mode 100644 index 0000000..49871a7 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/chat/base/index.vue @@ -0,0 +1,110 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/chat/embed/index.vue b/src/MaxKB-1.7.2/ui/src/views/chat/embed/index.vue new file mode 100644 index 0000000..b935ee4 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/chat/embed/index.vue @@ -0,0 +1,356 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/chat/index.vue b/src/MaxKB-1.7.2/ui/src/views/chat/index.vue new file mode 100644 index 0000000..36b8e44 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/chat/index.vue @@ -0,0 +1,99 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/chat/pc/index.vue b/src/MaxKB-1.7.2/ui/src/views/chat/pc/index.vue new file mode 100644 index 0000000..5ec014e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/chat/pc/index.vue @@ -0,0 +1,473 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/dataset/DatasetSetting.vue b/src/MaxKB-1.7.2/ui/src/views/dataset/DatasetSetting.vue new file mode 100644 index 0000000..c1caa4a --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/dataset/DatasetSetting.vue @@ -0,0 +1,201 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/dataset/UploadDocumentDataset.vue b/src/MaxKB-1.7.2/ui/src/views/dataset/UploadDocumentDataset.vue new file mode 100644 index 0000000..6c13325 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/dataset/UploadDocumentDataset.vue @@ -0,0 +1,194 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/dataset/component/BaseForm.vue 
b/src/MaxKB-1.7.2/ui/src/views/dataset/component/BaseForm.vue new file mode 100644 index 0000000..43a00b3 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/dataset/component/BaseForm.vue @@ -0,0 +1,191 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/dataset/component/CreateDatasetDialog.vue b/src/MaxKB-1.7.2/ui/src/views/dataset/component/CreateDatasetDialog.vue new file mode 100644 index 0000000..96b49cd --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/dataset/component/CreateDatasetDialog.vue @@ -0,0 +1,171 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/dataset/component/EditParagraphDialog.vue b/src/MaxKB-1.7.2/ui/src/views/dataset/component/EditParagraphDialog.vue new file mode 100644 index 0000000..f62212e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/dataset/component/EditParagraphDialog.vue @@ -0,0 +1,127 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/dataset/component/ParagraphList.vue b/src/MaxKB-1.7.2/ui/src/views/dataset/component/ParagraphList.vue new file mode 100644 index 0000000..3741f93 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/dataset/component/ParagraphList.vue @@ -0,0 +1,93 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/dataset/component/ParagraphPreview.vue b/src/MaxKB-1.7.2/ui/src/views/dataset/component/ParagraphPreview.vue new file mode 100644 index 0000000..3ad15ce --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/dataset/component/ParagraphPreview.vue @@ -0,0 +1,68 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/dataset/component/ResultSuccess.vue b/src/MaxKB-1.7.2/ui/src/views/dataset/component/ResultSuccess.vue new file mode 100644 index 0000000..d8ed40c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/dataset/component/ResultSuccess.vue @@ -0,0 +1,83 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/dataset/component/SetRules.vue b/src/MaxKB-1.7.2/ui/src/views/dataset/component/SetRules.vue new file mode 100644 index 0000000..593ac67 --- /dev/null +++ 
b/src/MaxKB-1.7.2/ui/src/views/dataset/component/SetRules.vue @@ -0,0 +1,256 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/dataset/component/SyncWebDialog.vue b/src/MaxKB-1.7.2/ui/src/views/dataset/component/SyncWebDialog.vue new file mode 100644 index 0000000..60dda86 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/dataset/component/SyncWebDialog.vue @@ -0,0 +1,85 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/dataset/component/UploadComponent.vue b/src/MaxKB-1.7.2/ui/src/views/dataset/component/UploadComponent.vue new file mode 100644 index 0000000..53f2d81 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/dataset/component/UploadComponent.vue @@ -0,0 +1,296 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/dataset/index.vue b/src/MaxKB-1.7.2/ui/src/views/dataset/index.vue new file mode 100644 index 0000000..3771fbe --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/dataset/index.vue @@ -0,0 +1,236 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/document/component/GenerateRelatedDialog.vue b/src/MaxKB-1.7.2/ui/src/views/document/component/GenerateRelatedDialog.vue new file mode 100644 index 0000000..a5908a5 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/document/component/GenerateRelatedDialog.vue @@ -0,0 +1,236 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/document/component/ImportDocumentDialog.vue b/src/MaxKB-1.7.2/ui/src/views/document/component/ImportDocumentDialog.vue new file mode 100644 index 0000000..92d20e1 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/document/component/ImportDocumentDialog.vue @@ -0,0 +1,211 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/document/component/SelectDatasetDialog.vue b/src/MaxKB-1.7.2/ui/src/views/document/component/SelectDatasetDialog.vue new file mode 100644 index 0000000..774cce7 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/document/component/SelectDatasetDialog.vue @@ -0,0 +1,137 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/document/index.vue 
b/src/MaxKB-1.7.2/ui/src/views/document/index.vue new file mode 100644 index 0000000..cd30160 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/document/index.vue @@ -0,0 +1,715 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/email/index.vue b/src/MaxKB-1.7.2/ui/src/views/email/index.vue new file mode 100644 index 0000000..2b92280 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/email/index.vue @@ -0,0 +1,122 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/first/index.vue b/src/MaxKB-1.7.2/ui/src/views/first/index.vue new file mode 100644 index 0000000..83075ab --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/first/index.vue @@ -0,0 +1,19 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/function-lib/component/FieldFormDialog.vue b/src/MaxKB-1.7.2/ui/src/views/function-lib/component/FieldFormDialog.vue new file mode 100644 index 0000000..c16c627 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/function-lib/component/FieldFormDialog.vue @@ -0,0 +1,111 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/function-lib/component/FunctionDebugDrawer.vue b/src/MaxKB-1.7.2/ui/src/views/function-lib/component/FunctionDebugDrawer.vue new file mode 100644 index 0000000..1f80347 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/function-lib/component/FunctionDebugDrawer.vue @@ -0,0 +1,139 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/function-lib/component/FunctionFormDrawer.vue b/src/MaxKB-1.7.2/ui/src/views/function-lib/component/FunctionFormDrawer.vue new file mode 100644 index 0000000..6d9928b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/function-lib/component/FunctionFormDrawer.vue @@ -0,0 +1,317 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/function-lib/index.vue b/src/MaxKB-1.7.2/ui/src/views/function-lib/index.vue new file mode 100644 index 0000000..e543112 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/function-lib/index.vue @@ -0,0 +1,232 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/hit-test/index.vue 
b/src/MaxKB-1.7.2/ui/src/views/hit-test/index.vue new file mode 100644 index 0000000..2a1845e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/hit-test/index.vue @@ -0,0 +1,391 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/log/component/ChatRecordDrawer.vue b/src/MaxKB-1.7.2/ui/src/views/log/component/ChatRecordDrawer.vue new file mode 100644 index 0000000..b3709b9 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/log/component/ChatRecordDrawer.vue @@ -0,0 +1,137 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/log/component/EditContentDialog.vue b/src/MaxKB-1.7.2/ui/src/views/log/component/EditContentDialog.vue new file mode 100644 index 0000000..17731a0 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/log/component/EditContentDialog.vue @@ -0,0 +1,273 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/log/component/EditMarkDialog.vue b/src/MaxKB-1.7.2/ui/src/views/log/component/EditMarkDialog.vue new file mode 100644 index 0000000..06fa99e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/log/component/EditMarkDialog.vue @@ -0,0 +1,160 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/log/index.vue b/src/MaxKB-1.7.2/ui/src/views/log/index.vue new file mode 100644 index 0000000..b7b21b9 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/log/index.vue @@ -0,0 +1,455 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/login/components/QrCodeTab.vue b/src/MaxKB-1.7.2/ui/src/views/login/components/QrCodeTab.vue new file mode 100644 index 0000000..896c919 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/login/components/QrCodeTab.vue @@ -0,0 +1,76 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/login/components/dingtalkQrCode.vue b/src/MaxKB-1.7.2/ui/src/views/login/components/dingtalkQrCode.vue new file mode 100644 index 0000000..9878410 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/login/components/dingtalkQrCode.vue @@ -0,0 +1,143 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/login/components/larkQrCode.vue 
b/src/MaxKB-1.7.2/ui/src/views/login/components/larkQrCode.vue new file mode 100644 index 0000000..72e6d9b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/login/components/larkQrCode.vue @@ -0,0 +1,59 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/login/components/wecomQrCode.vue b/src/MaxKB-1.7.2/ui/src/views/login/components/wecomQrCode.vue new file mode 100644 index 0000000..5487ef9 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/login/components/wecomQrCode.vue @@ -0,0 +1,72 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/login/forgot-password/index.vue b/src/MaxKB-1.7.2/ui/src/views/login/forgot-password/index.vue new file mode 100644 index 0000000..681e1b8 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/login/forgot-password/index.vue @@ -0,0 +1,130 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/login/index.vue b/src/MaxKB-1.7.2/ui/src/views/login/index.vue new file mode 100644 index 0000000..dbdba29 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/login/index.vue @@ -0,0 +1,275 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/login/register/index.vue b/src/MaxKB-1.7.2/ui/src/views/login/register/index.vue new file mode 100644 index 0000000..e5acf2a --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/login/register/index.vue @@ -0,0 +1,216 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/login/reset-password/index.vue b/src/MaxKB-1.7.2/ui/src/views/login/reset-password/index.vue new file mode 100644 index 0000000..ac2111c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/login/reset-password/index.vue @@ -0,0 +1,134 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/paragraph/component/GenerateRelatedDialog.vue b/src/MaxKB-1.7.2/ui/src/views/paragraph/component/GenerateRelatedDialog.vue new file mode 100644 index 0000000..c76144d --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/paragraph/component/GenerateRelatedDialog.vue @@ -0,0 +1,236 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/paragraph/component/ParagraphDialog.vue 
b/src/MaxKB-1.7.2/ui/src/views/paragraph/component/ParagraphDialog.vue new file mode 100644 index 0000000..8874956 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/paragraph/component/ParagraphDialog.vue @@ -0,0 +1,153 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/paragraph/component/ParagraphForm.vue b/src/MaxKB-1.7.2/ui/src/views/paragraph/component/ParagraphForm.vue new file mode 100644 index 0000000..43739d2 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/paragraph/component/ParagraphForm.vue @@ -0,0 +1,173 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/paragraph/component/ProblemComponent.vue b/src/MaxKB-1.7.2/ui/src/views/paragraph/component/ProblemComponent.vue new file mode 100644 index 0000000..b940385 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/paragraph/component/ProblemComponent.vue @@ -0,0 +1,202 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/paragraph/component/SelectDocumentDialog.vue b/src/MaxKB-1.7.2/ui/src/views/paragraph/component/SelectDocumentDialog.vue new file mode 100644 index 0000000..61b2b49 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/paragraph/component/SelectDocumentDialog.vue @@ -0,0 +1,168 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/paragraph/index.vue b/src/MaxKB-1.7.2/ui/src/views/paragraph/index.vue new file mode 100644 index 0000000..cb8f68a --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/paragraph/index.vue @@ -0,0 +1,411 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/problem/component/CreateProblemDialog.vue b/src/MaxKB-1.7.2/ui/src/views/problem/component/CreateProblemDialog.vue new file mode 100644 index 0000000..2546f8a --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/problem/component/CreateProblemDialog.vue @@ -0,0 +1,92 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/problem/component/DetailProblemDrawer.vue b/src/MaxKB-1.7.2/ui/src/views/problem/component/DetailProblemDrawer.vue new file mode 100644 index 0000000..b3aa32d --- /dev/null +++ 
b/src/MaxKB-1.7.2/ui/src/views/problem/component/DetailProblemDrawer.vue @@ -0,0 +1,192 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/problem/component/RelateProblemDialog.vue b/src/MaxKB-1.7.2/ui/src/views/problem/component/RelateProblemDialog.vue new file mode 100644 index 0000000..c91103d --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/problem/component/RelateProblemDialog.vue @@ -0,0 +1,294 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/problem/index.vue b/src/MaxKB-1.7.2/ui/src/views/problem/index.vue new file mode 100644 index 0000000..2b6c59d --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/problem/index.vue @@ -0,0 +1,362 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/team/component/CreateMemberDialog.vue b/src/MaxKB-1.7.2/ui/src/views/team/component/CreateMemberDialog.vue new file mode 100644 index 0000000..7aaaefd --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/team/component/CreateMemberDialog.vue @@ -0,0 +1,118 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/team/component/PermissionSetting.vue b/src/MaxKB-1.7.2/ui/src/views/team/component/PermissionSetting.vue new file mode 100644 index 0000000..e5ed837 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/team/component/PermissionSetting.vue @@ -0,0 +1,151 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/team/index.vue b/src/MaxKB-1.7.2/ui/src/views/team/index.vue new file mode 100644 index 0000000..4a35c59 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/team/index.vue @@ -0,0 +1,279 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/template/component/AddParamDrawer.vue b/src/MaxKB-1.7.2/ui/src/views/template/component/AddParamDrawer.vue new file mode 100644 index 0000000..f1f6bd8 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/template/component/AddParamDrawer.vue @@ -0,0 +1,77 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/template/component/CreateModelDialog.vue b/src/MaxKB-1.7.2/ui/src/views/template/component/CreateModelDialog.vue new file mode 100644 index 
0000000..8f4bf72 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/template/component/CreateModelDialog.vue @@ -0,0 +1,307 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/template/component/EditModel.vue b/src/MaxKB-1.7.2/ui/src/views/template/component/EditModel.vue new file mode 100644 index 0000000..333f801 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/template/component/EditModel.vue @@ -0,0 +1,289 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/template/component/ModelCard.vue b/src/MaxKB-1.7.2/ui/src/views/template/component/ModelCard.vue new file mode 100644 index 0000000..1ab1118 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/template/component/ModelCard.vue @@ -0,0 +1,269 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/template/component/ParamSettingDialog.vue b/src/MaxKB-1.7.2/ui/src/views/template/component/ParamSettingDialog.vue new file mode 100644 index 0000000..5adfd90 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/template/component/ParamSettingDialog.vue @@ -0,0 +1,148 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/template/component/SelectProviderDialog.vue b/src/MaxKB-1.7.2/ui/src/views/template/component/SelectProviderDialog.vue new file mode 100644 index 0000000..ce102dd --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/template/component/SelectProviderDialog.vue @@ -0,0 +1,94 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/template/index.vue b/src/MaxKB-1.7.2/ui/src/views/template/index.vue new file mode 100644 index 0000000..0a33c8e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/template/index.vue @@ -0,0 +1,342 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/theme/LoginPreview.vue b/src/MaxKB-1.7.2/ui/src/views/theme/LoginPreview.vue new file mode 100644 index 0000000..ed2aa00 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/theme/LoginPreview.vue @@ -0,0 +1,104 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/theme/index.vue b/src/MaxKB-1.7.2/ui/src/views/theme/index.vue new file mode 100644 index 
0000000..df25ff4 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/theme/index.vue @@ -0,0 +1,375 @@ + + + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/user-manage/component/UserDialog.vue b/src/MaxKB-1.7.2/ui/src/views/user-manage/component/UserDialog.vue new file mode 100644 index 0000000..3624ba6 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/user-manage/component/UserDialog.vue @@ -0,0 +1,161 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/user-manage/component/UserPwdDialog.vue b/src/MaxKB-1.7.2/ui/src/views/user-manage/component/UserPwdDialog.vue new file mode 100644 index 0000000..2b419ea --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/user-manage/component/UserPwdDialog.vue @@ -0,0 +1,131 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/views/user-manage/index.vue b/src/MaxKB-1.7.2/ui/src/views/user-manage/index.vue new file mode 100644 index 0000000..9a81222 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/views/user-manage/index.vue @@ -0,0 +1,214 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/common/CustomLine.vue b/src/MaxKB-1.7.2/ui/src/workflow/common/CustomLine.vue new file mode 100644 index 0000000..48146e5 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/common/CustomLine.vue @@ -0,0 +1,37 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/common/NodeCascader.vue b/src/MaxKB-1.7.2/ui/src/workflow/common/NodeCascader.vue new file mode 100644 index 0000000..272bc11 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/common/NodeCascader.vue @@ -0,0 +1,120 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/common/NodeContainer.vue b/src/MaxKB-1.7.2/ui/src/workflow/common/NodeContainer.vue new file mode 100644 index 0000000..5973b24 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/common/NodeContainer.vue @@ -0,0 +1,266 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/common/NodeControl.vue b/src/MaxKB-1.7.2/ui/src/workflow/common/NodeControl.vue new file mode 100644 index 0000000..ac14269 --- /dev/null +++ 
b/src/MaxKB-1.7.2/ui/src/workflow/common/NodeControl.vue @@ -0,0 +1,34 @@ + + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/common/app-node.ts b/src/MaxKB-1.7.2/ui/src/workflow/common/app-node.ts new file mode 100644 index 0000000..f720ad4 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/common/app-node.ts @@ -0,0 +1,261 @@ +import Components from '@/components' +import ElementPlus from 'element-plus' +import * as ElementPlusIcons from '@element-plus/icons-vue' +import zhCn from 'element-plus/dist/locale/zh-cn.mjs' +import { HtmlResize } from '@logicflow/extension' + +import { h as lh } from '@logicflow/core' +import { createApp, h } from 'vue' +import directives from '@/directives' +import i18n from '@/locales' +import { WorkflowType } from '@/enums/workflow' +import { nodeDict } from '@/workflow/common/data' +class AppNode extends HtmlResize.view { + isMounted + r + app + + constructor(props: any, VueNode: any) { + super(props) + this.isMounted = false + this.r = h(VueNode, { + properties: props.model.properties, + nodeModel: props.model + }) + + this.app = createApp({ + render: () => this.r + }) + this.app.use(ElementPlus, { + locale: zhCn + }) + this.app.use(Components) + this.app.use(directives) + this.app.use(i18n) + for (const [key, component] of Object.entries(ElementPlusIcons)) { + this.app.component(key, component) + } + + if (props.model.properties.noRender) { + delete props.model.properties.noRender + } else { + const filterNodes = props.graphModel.nodes.filter((v: any) => v.type === props.model.type) + if (filterNodes.length - 1 > 0) { + props.model.properties.stepName = props.model.properties.stepName + (filterNodes.length - 1) + } + } + props.model.properties.config = nodeDict[props.model.type].properties.config + if (props.model.properties.height) { + props.model.height = props.model.properties.height + } + } + + getAnchorShape(anchorData: any) { + const { x, y, type } = anchorData + let isConnect = false + + if (type == 'left') { + isConnect 
= this.props.graphModel.edges.some((edge) => edge.targetAnchorId == anchorData.id) + } else { + isConnect = this.props.graphModel.edges.some((edge) => edge.sourceAnchorId == anchorData.id) + } + + return lh( + 'foreignObject', + { + ...anchorData, + x: x - 10, + y: y - 12, + width: 30, + height: 30 + }, + [ + lh('div', { + style: { zindex: 0 }, + onClick: () => { + if (type == 'right') { + this.props.model.openNodeMenu(anchorData) + } + }, + dangerouslySetInnerHTML: { + __html: isConnect + ? ` + + + + + + + + + + + + + + + + + + ` + : ` + + + + + + + + + + + + + + + + + ` + } + }) + ] + ) + } + + setHtml(rootEl: HTMLElement) { + if (!this.isMounted) { + this.isMounted = true + const node = document.createElement('div') + rootEl.appendChild(node) + this.app?.mount(node) + } else { + if (this.r && this.r.component) { + this.r.component.props.properties = this.props.model.getProperties() + } + } + } +} + +class AppNodeModel extends HtmlResize.model { + getResizeOutlineStyle() { + const style = super.getResizeOutlineStyle() + style.stroke = 'none' + return style + } + getControlPointStyle() { + const style = super.getControlPointStyle() + style.stroke = 'none' + style.fill = 'none' + return style + } + getNodeStyle() { + return { + overflow: 'visible' + } + } + getOutlineStyle() { + const style = super.getOutlineStyle() + style.stroke = 'none' + if (style.hover) { + style.hover.stroke = 'none' + } + return style + } + // 如果不用修改锚地形状,可以重写颜色相关样式 + getAnchorStyle(anchorInfo: any) { + const style = super.getAnchorStyle(anchorInfo) + if (anchorInfo.type === 'left') { + style.fill = 'red' + style.hover.fill = 'transparent' + style.hover.stroke = 'transpanrent' + style.className = 'lf-hide-default' + } else { + style.fill = 'green' + } + return style + } + + setHeight(height: number) { + const sourceHeight = this.height + const targetHeight = height + 100 + this.height = targetHeight + this.properties['height'] = targetHeight + this.move(0, (targetHeight - sourceHeight) / 2) + 
this.outgoing.edges.forEach((edge: any) => { + // 调用自定义的更新方案 + edge.updatePathByAnchor() + }) + this.incoming.edges.forEach((edge: any) => { + // 调用自定义的更新方案 + edge.updatePathByAnchor() + }) + } + get_width() { + return this.properties?.width || 340 + } + + setAttributes() { + this.width = this.get_width() + const isLoop = (node_id: string, target_node_id: string) => { + const up_node_list = this.graphModel.getNodeIncomingNode(node_id) + for (const index in up_node_list) { + const item = up_node_list[index] + if (item.id === target_node_id) { + return true + } else { + const result = isLoop(item.id, target_node_id) + if (result) { + return true + } + } + } + return false + } + const circleOnlyAsTarget = { + message: '只允许从右边的锚点连出', + validate: (sourceNode: any, targetNode: any, sourceAnchor: any) => { + return sourceAnchor.type === 'right' + } + } + this.sourceRules.push({ + message: '不可循环连线', + validate: (sourceNode: any, targetNode: any, sourceAnchor: any, targetAnchor: any) => { + return !isLoop(sourceNode.id, targetNode.id) + } + }) + + this.sourceRules.push(circleOnlyAsTarget) + this.targetRules.push({ + message: '只允许连接左边的锚点', + validate: (sourceNode: any, targetNode: any, sourceAnchor: any, targetAnchor: any) => { + return targetAnchor.type === 'left' + } + }) + } + getDefaultAnchor() { + const { id, x, y, width } = this + const showNode = this.properties.showNode === undefined ? true : this.properties.showNode + const anchors: any = [] + + if (this.type !== WorkflowType.Base) { + if (this.type !== WorkflowType.Start) { + anchors.push({ + x: x - width / 2 + 10, + y: showNode ? y : y - 15, + id: `${id}_left`, + edgeAddable: false, + type: 'left' + }) + } + anchors.push({ + x: x + width / 2 - 10, + y: showNode ? 
y : y - 15, + id: `${id}_right`, + type: 'right' + }) + } + + return anchors + } +} +export { AppNodeModel, AppNode } diff --git a/src/MaxKB-1.7.2/ui/src/workflow/common/data.ts b/src/MaxKB-1.7.2/ui/src/workflow/common/data.ts new file mode 100644 index 0000000..8cca9bc --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/common/data.ts @@ -0,0 +1,259 @@ +import { WorkflowType } from '@/enums/workflow' +import { t } from '@/locales' + +export const startNode = { + id: WorkflowType.Start, + type: WorkflowType.Start, + x: 180, + y: 720, + properties: { + height: 200, + stepName: '开始', + config: { + fields: [ + { + label: '用户问题', + value: 'question' + } + ], + globalFields: [ + { + value: 'time', + label: '当前时间' + } + ] + } + } +} +export const baseNode = { + id: WorkflowType.Base, + type: WorkflowType.Base, + x: 200, + y: 270, + text: '', + properties: { + width: 420, + height: 200, + stepName: '基本信息', + input_field_list: [], + node_data: { + name: '', + desc: '', + prologue: t('views.application.prompt.defaultPrologue') + }, + config: {} + } +} +/** + * 说明 + * type 与 nodes 文件对应 + */ +export const baseNodes = [baseNode, startNode] +/** + * ai对话节点配置数据 + */ +export const aiChatNode = { + type: WorkflowType.AiChat, + text: '与 AI 大模型进行对话', + label: 'AI 对话', + height: 340, + properties: { + stepName: 'AI 对话', + config: { + fields: [ + { + label: 'AI 回答内容', + value: 'answer' + } + ] + } + } +} +/** + * 知识库检索配置数据 + */ +export const searchDatasetNode = { + type: WorkflowType.SearchDataset, + text: '关联知识库,查找与问题相关的分段', + label: '知识库检索', + height: 355, + properties: { + stepName: '知识库检索', + config: { + fields: [ + { label: '检索结果的分段列表', value: 'paragraph_list' }, + { label: '满足直接回答的分段列表', value: 'is_hit_handling_method_list' }, + { + label: '检索结果', + value: 'data' + }, + { + label: '满足直接回答的分段内容', + value: 'directly_return' + } + ] + } + } +} +export const questionNode = { + type: WorkflowType.Question, + text: '根据历史聊天记录优化完善当前问题,更利于匹配知识库分段', + label: '问题优化', + height: 345, + 
properties: { + stepName: '问题优化', + config: { + fields: [ + { + label: '问题优化结果', + value: 'answer' + } + ] + } + } +} +export const conditionNode = { + type: WorkflowType.Condition, + text: '根据不同条件执行不同的节点', + label: '判断器', + height: 175, + properties: { + width: 600, + stepName: '判断器', + config: { + fields: [ + { + label: '分支名称', + value: 'branch_name' + } + ] + } + } +} +export const replyNode = { + type: WorkflowType.Reply, + text: '指定回复内容,引用变量会转换为字符串进行输出', + label: '指定回复', + height: 210, + properties: { + stepName: '指定回复', + config: { + fields: [ + { + label: '内容', + value: 'answer' + } + ] + } + } +} +export const rerankerNode = { + type: WorkflowType.RrerankerNode, + text: '使用重排模型对多个知识库的检索结果进行二次召回', + label: '多路召回', + height: 252, + properties: { + stepName: '多路召回', + config: { + fields: [ + { + label: '重排结果列表', + value: 'result_list' + }, + { + label: '重排结果', + value: 'result' + } + ] + } + } +} +export const menuNodes = [ + aiChatNode, + searchDatasetNode, + questionNode, + conditionNode, + replyNode, + rerankerNode +] + +/** + * 自定义函数配置数据 + */ +export const functionNode = { + type: WorkflowType.FunctionLibCustom, + text: '通过执行自定义脚本,实现数据处理', + label: '自定义函数', + height: 260, + properties: { + stepName: '自定义函数', + config: { + fields: [ + { + label: '结果', + value: 'result' + } + ] + } + } +} +export const functionLibNode = { + type: WorkflowType.FunctionLib, + text: '通过执行自定义脚本,实现数据处理', + label: '自定义函数', + height: 170, + properties: { + stepName: '自定义函数', + config: { + fields: [ + { + label: '结果', + value: 'result' + } + ] + } + } +} + +export const compareList = [ + { value: 'is_null', label: '为空' }, + { value: 'is_not_null', label: '不为空' }, + { value: 'contain', label: '包含' }, + { value: 'not_contain', label: '不包含' }, + { value: 'eq', label: '等于' }, + { value: 'ge', label: '大于等于' }, + { value: 'gt', label: '大于' }, + { value: 'le', label: '小于等于' }, + { value: 'lt', label: '小于' }, + { value: 'len_eq', label: '长度等于' }, + { value: 'len_ge', label: '长度大于等于' }, + { 
value: 'len_gt', label: '长度大于' }, + { value: 'len_le', label: '长度小于等于' }, + { value: 'len_lt', label: '长度小于' } +] + +export const nodeDict: any = { + [WorkflowType.AiChat]: aiChatNode, + [WorkflowType.SearchDataset]: searchDatasetNode, + [WorkflowType.Question]: questionNode, + [WorkflowType.Condition]: conditionNode, + [WorkflowType.Base]: baseNode, + [WorkflowType.Start]: startNode, + [WorkflowType.Reply]: replyNode, + [WorkflowType.FunctionLib]: functionLibNode, + [WorkflowType.FunctionLibCustom]: functionNode, + [WorkflowType.RrerankerNode]: rerankerNode +} +export function isWorkFlow(type: string | undefined) { + return type === 'WORK_FLOW' +} + +export function isLastNode(nodeModel: any) { + const incoming = nodeModel.graphModel.getNodeIncomingNode(nodeModel.id) + const outcomming = nodeModel.graphModel.getNodeOutgoingNode(nodeModel.id) + if (incoming.length > 0 && outcomming.length === 0) { + return true + } else { + return false + } +} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/common/edge.ts b/src/MaxKB-1.7.2/ui/src/workflow/common/edge.ts new file mode 100644 index 0000000..cf4fdc1 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/common/edge.ts @@ -0,0 +1,184 @@ +import { BezierEdge, BezierEdgeModel, h } from '@logicflow/core' +import { createApp, h as vh } from 'vue' + +import CustomLine from './CustomLine.vue' +function isMouseInElement(element: any, e: any) { + const rect = element.getBoundingClientRect() + return ( + e.clientX >= rect.left && + e.clientX <= rect.right && + e.clientY >= rect.top && + e.clientY <= rect.bottom + ) +} +const DEFAULT_WIDTH = 32 +const DEFAULT_HEIGHT = 32 +class CustomEdge2 extends BezierEdge { + isMounted + + constructor() { + super() + this.isMounted = false + this.handleMouseUp = (e: any) => { + this.props.graphModel.clearSelectElements() + this.props.model.isSelected = true + const element = e.target.parentNode.parentNode.querySelector('.lf-custom-edge-wrapper') + if (isMouseInElement(element, e)) { + 
this.props.model.graphModel.deleteEdgeById(this.props.model.id) + } + } + } + + getEdge() { + const { model } = this.props + const id = model.id + const { customWidth = DEFAULT_WIDTH, customHeight = DEFAULT_HEIGHT } = model.getProperties() + const { startPoint, endPoint, path, isAnimation, arrowConfig } = model + const animationStyle = model.getEdgeAnimationStyle() + const { + strokeDasharray, + stroke, + strokeDashoffset, + animationName, + animationDuration, + animationIterationCount, + animationTimingFunction, + animationDirection + } = animationStyle + const positionData = { + x: (startPoint.x + endPoint.x - customWidth) / 2, + y: (startPoint.y + endPoint.y - customHeight) / 2, + width: customWidth, + height: customHeight + } + const style = model.getEdgeStyle() + const wrapperStyle = { + width: customWidth, + height: customHeight + } + + const app = createApp({ + render: () => vh(CustomLine, { model: this.props.model }) + }) + setTimeout(() => { + const s = document.getElementById(id) + if (s && !this.isMounted) { + app.mount(s) + this.isMounted = true + } + }, 0) + + delete style.stroke + + return h('g', {}, [ + h('style', { type: 'text/css' }, '.lf-edge{stroke:#afafaf}.lf-edge:hover{stroke: #3370FF;}'), + h('path', { + d: path, + ...style, + ...arrowConfig, + ...(isAnimation + ? 
{ + strokeDasharray, + stroke, + style: { + strokeDashoffset, + animationName, + animationDuration, + animationIterationCount, + animationTimingFunction, + animationDirection + } + } + : {}) + }), + h( + 'foreignObject', + { + ...positionData, + y: positionData.y + 5, + x: positionData.x + 5, + style: {} + }, + [ + h('div', { + id, + style: { ...wrapperStyle }, + className: 'lf-custom-edge-wrapper' + }) + ] + ) + ]) + } +} + +class CustomEdgeModel2 extends BezierEdgeModel { + getArrowStyle() { + const arrowStyle = super.getArrowStyle() + arrowStyle.offset = 1 + arrowStyle.verticalLength = 0 + return arrowStyle + } + + getEdgeStyle() { + const style = super.getEdgeStyle() + // svg属性 + style.strokeWidth = 2 + style.stroke = '#BBBFC4' + style.offset = 0 + return style + } + /** + * 重写此方法,使保存数据是能带上锚点数据。 + */ + getData() { + const data: any = super.getData() + if (data) { + data.sourceAnchorId = this.sourceAnchorId + data.targetAnchorId = this.targetAnchorId + } + return data + } + /** + * 给边自定义方案,使其支持基于锚点的位置更新边的路径 + */ + updatePathByAnchor() { + // TODO + const sourceNodeModel = this.graphModel.getNodeModelById(this.sourceNodeId) + const sourceAnchor = sourceNodeModel + .getDefaultAnchor() + .find((anchor: any) => anchor.id === this.sourceAnchorId) + + const targetNodeModel = this.graphModel.getNodeModelById(this.targetNodeId) + const targetAnchor = targetNodeModel + .getDefaultAnchor() + .find((anchor: any) => anchor.id === this.targetAnchorId) + if (sourceAnchor && targetAnchor) { + const startPoint = { + x: sourceAnchor.x, + y: sourceAnchor.y + } + this.updateStartPoint(startPoint) + const endPoint = { + x: targetAnchor.x, + y: targetAnchor.y + } + + this.updateEndPoint(endPoint) + } + + // 这里需要将原有的pointsList设置为空,才能触发bezier的自动计算control点。 + this.pointsList = [] + this.initPoints() + } + setAttributes(): void { + super.setAttributes() + this.isHitable = true + this.zIndex = 0 + } +} + +export default { + type: 'app-edge', + view: CustomEdge2, + model: CustomEdgeModel2 
+} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/common/shortcut.ts b/src/MaxKB-1.7.2/ui/src/workflow/common/shortcut.ts new file mode 100644 index 0000000..b2fc43f --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/common/shortcut.ts @@ -0,0 +1,135 @@ +import type LogicFlow from '@logicflow/core' +import { type GraphModel } from '@logicflow/core' +import { MsgSuccess, MsgError, MsgConfirm } from '@/utils/message' +import { WorkflowType } from '@/enums/workflow' +let selected: any | null = null + +function translationNodeData(nodeData: any, distance: any) { + nodeData.x += distance + nodeData.y += distance + if (nodeData.text) { + nodeData.text.x += distance + nodeData.text.y += distance + } + return nodeData +} + +function translationEdgeData(edgeData: any, distance: any) { + if (edgeData.startPoint) { + edgeData.startPoint.x += distance + edgeData.startPoint.y += distance + } + if (edgeData.endPoint) { + edgeData.endPoint.x += distance + edgeData.endPoint.y += distance + } + if (edgeData.pointsList && edgeData.pointsList.length > 0) { + edgeData.pointsList.forEach((point: any) => { + point.x += distance + point.y += distance + }) + } + if (edgeData.text) { + edgeData.text.x += distance + edgeData.text.y += distance + } + return edgeData +} + +const TRANSLATION_DISTANCE = 40 +let CHILDREN_TRANSLATION_DISTANCE = 40 + +export function initDefaultShortcut(lf: LogicFlow, graph: GraphModel) { + const { keyboard } = lf + const { + options: { keyboard: keyboardOptions } + } = keyboard + const copy_node = () => { + CHILDREN_TRANSLATION_DISTANCE = TRANSLATION_DISTANCE + if (!keyboardOptions?.enabled) return true + if (graph.textEditElement) return true + const { guards } = lf.options + const elements = graph.getSelectElements(false) + const enabledClone = guards && guards.beforeClone ? 
guards.beforeClone(elements) : true + if (!enabledClone || (elements.nodes.length === 0 && elements.edges.length === 0)) { + selected = null + return true + } + const base_nodes = elements.nodes.filter( + (node: any) => node.type === WorkflowType.Start || node.type === WorkflowType.Base + ) + if (base_nodes.length > 0) { + MsgError(base_nodes[0]?.properties?.stepName + '不能被复制') + return + } + selected = elements + selected.nodes.forEach((node: any) => translationNodeData(node, TRANSLATION_DISTANCE)) + selected.edges.forEach((edge: any) => translationEdgeData(edge, TRANSLATION_DISTANCE)) + MsgSuccess('已复制节点') + return false + } + const paste_node = () => { + if (!keyboardOptions?.enabled) return true + if (graph.textEditElement) return true + if (selected && (selected.nodes || selected.edges)) { + lf.clearSelectElements() + const addElements = lf.addElements(selected, CHILDREN_TRANSLATION_DISTANCE) + if (!addElements) return true + addElements.nodes.forEach((node) => lf.selectElementById(node.id, true)) + addElements.edges.forEach((edge) => lf.selectElementById(edge.id, true)) + selected.nodes.forEach((node: any) => translationNodeData(node, TRANSLATION_DISTANCE)) + selected.edges.forEach((edge: any) => translationEdgeData(edge, TRANSLATION_DISTANCE)) + CHILDREN_TRANSLATION_DISTANCE = CHILDREN_TRANSLATION_DISTANCE + TRANSLATION_DISTANCE + } + return false + } + const delete_node = () => { + const elements = graph.getSelectElements(true) + lf.clearSelectElements() + if (elements.nodes.length == 0 && elements.edges.length == 0) { + return + } + if (elements.edges.length > 0 && elements.nodes.length == 0) { + elements.edges.forEach((edge: any) => lf.deleteEdge(edge.id)) + return + } + const nodes = elements.nodes.filter((node) => ['start-node', 'base-node'].includes(node.type)) + if (nodes.length > 0) { + MsgError(`${nodes[0].properties?.stepName}节点不允许删除`) + return + } + MsgConfirm(`提示`, `确定删除该节点?`, { + confirmButtonText: '删除', + confirmButtonClass: 'danger' + 
}).then(() => { + if (!keyboardOptions?.enabled) return true + if (graph.textEditElement) return true + + elements.edges.forEach((edge: any) => lf.deleteEdge(edge.id)) + elements.nodes.forEach((node: any) => lf.deleteNode(node.id)) + }) + + return false + } + graph.eventCenter.on('copy_node', copy_node) + // 复制 + keyboard.on(['cmd + c', 'ctrl + c'], copy_node) + // 粘贴 + keyboard.on(['cmd + v', 'ctrl + v'], paste_node) + // undo + keyboard.on(['cmd + z', 'ctrl + z'], () => { + // if (!keyboardOptions?.enabled) return true + // if (graph.textEditElement) return true + // lf.undo() + // return false + }) + // redo + keyboard.on(['cmd + y', 'ctrl + y'], () => { + if (!keyboardOptions?.enabled) return true + if (graph.textEditElement) return true + lf.redo() + return false + }) + // delete + keyboard.on(['backspace'], delete_node) +} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/common/validate.ts b/src/MaxKB-1.7.2/ui/src/workflow/common/validate.ts new file mode 100644 index 0000000..000d13c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/common/validate.ts @@ -0,0 +1,144 @@ +import { WorkflowType } from '@/enums/workflow' + +const end_nodes: Array = [ + WorkflowType.AiChat, + WorkflowType.Reply, + WorkflowType.FunctionLib, + WorkflowType.FunctionLibCustom +] +export class WorkFlowInstance { + nodes + edges + workFlowNodes: Array + constructor(workflow: { nodes: Array; edges: Array }) { + this.nodes = workflow.nodes + this.edges = workflow.edges + this.workFlowNodes = [] + } + /** + * 校验开始节点 + */ + private is_valid_start_node() { + const start_node_list = this.nodes.filter((item) => item.id === WorkflowType.Start) + if (start_node_list.length == 0) { + throw '开始节点必填' + } else if (start_node_list.length > 1) { + throw '开始节点只能有一个' + } + } + /** + * 校验基本信息节点 + */ + private is_valid_base_node() { + const start_node_list = this.nodes.filter((item) => item.id === WorkflowType.Base) + if (start_node_list.length == 0) { + throw '基本信息节点必填' + } else if (start_node_list.length 
> 1) { + throw '基本信息节点只能有一个' + } + } + /** + * 校验节点 + */ + is_valid() { + this.is_valid_start_node() + this.is_valid_base_node() + this.is_valid_work_flow() + this.is_valid_nodes() + } + + /** + * 获取开始节点 + * @returns + */ + get_start_node() { + const start_node_list = this.nodes.filter((item) => item.id === WorkflowType.Start) + return start_node_list[0] + } + /** + * 获取基本节点 + * @returns 基本节点 + */ + get_base_node() { + const base_node_list = this.nodes.filter((item) => item.id === WorkflowType.Base) + return base_node_list[0] + } + + /** + * 校验工作流 + * @param up_node 上一个节点 + */ + private _is_valid_work_flow(up_node?: any) { + if (!up_node) { + up_node = this.get_start_node() + } + this.workFlowNodes.push(up_node) + this.is_valid_node(up_node) + const next_nodes = this.get_next_nodes(up_node) + for (const next_node of next_nodes) { + this._is_valid_work_flow(next_node) + } + } + private is_valid_work_flow() { + this.workFlowNodes = [] + this._is_valid_work_flow() + const notInWorkFlowNodes = this.nodes + .filter((node: any) => node.id !== WorkflowType.Start && node.id !== WorkflowType.Base) + .filter((node) => !this.workFlowNodes.includes(node)) + if (notInWorkFlowNodes.length > 0) { + throw `未在流程中的节点:${notInWorkFlowNodes.map((node) => node.properties.stepName).join(',')}` + } + this.workFlowNodes = [] + } + /** + * 获取流程下一个节点列表 + * @param node 节点 + * @returns 节点列表 + */ + private get_next_nodes(node: any) { + const edge_list = this.edges.filter((edge) => edge.sourceNodeId == node.id) + const node_list = edge_list + .map((edge) => this.nodes.filter((node) => node.id == edge.targetNodeId)) + .reduce((x, y) => [...x, ...y], []) + if (node_list.length == 0 && !end_nodes.includes(node.type)) { + throw '不存在的下一个节点' + } + return node_list + } + private is_valid_nodes() { + for (const node of this.nodes) { + if (node.type !== WorkflowType.Base && node.type !== WorkflowType.Start) { + if (!this.edges.some((edge) => edge.targetNodeId === node.id)) { + throw 
`未在流程中的节点:${node.properties.stepName}` + } + } + } + } + /** + * 校验节点 + * @param node 节点 + */ + private is_valid_node(node: any) { + if (node.properties.status && node.properties.status === 500) { + throw `${node.properties.stepName} 节点不可用` + } + if (node.type === WorkflowType.Condition) { + const branch_list = node.properties.node_data.branch + for (const branch of branch_list) { + const source_anchor_id = `${node.id}_${branch.id}_right` + const edge_list = this.edges.filter((edge) => edge.sourceAnchorId == source_anchor_id) + if (edge_list.length == 0) { + throw `${node.properties.stepName} 节点的${branch.type}分支需要连接` + } + } + } else { + const edge_list = this.edges.filter((edge) => edge.sourceNodeId == node.id) + if (edge_list.length == 0 && !end_nodes.includes(node.type)) { + throw `${node.properties.stepName} 节点不能当做结束节点` + } + } + if (node.properties.status && node.properties.status !== 200) { + throw `${node.properties.stepName} 节点不可用` + } + } +} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/icons/ai-chat-node-icon.vue b/src/MaxKB-1.7.2/ui/src/workflow/icons/ai-chat-node-icon.vue new file mode 100644 index 0000000..24e6d46 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/icons/ai-chat-node-icon.vue @@ -0,0 +1,6 @@ + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/icons/base-node-icon.vue b/src/MaxKB-1.7.2/ui/src/workflow/icons/base-node-icon.vue new file mode 100644 index 0000000..a0b5d27 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/icons/base-node-icon.vue @@ -0,0 +1,6 @@ + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/icons/condition-node-icon.vue b/src/MaxKB-1.7.2/ui/src/workflow/icons/condition-node-icon.vue new file mode 100644 index 0000000..6deed31 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/icons/condition-node-icon.vue @@ -0,0 +1,6 @@ + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/icons/function-lib-node-icon.vue b/src/MaxKB-1.7.2/ui/src/workflow/icons/function-lib-node-icon.vue new file mode 100644 index 0000000..e6e84a8 --- 
/dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/icons/function-lib-node-icon.vue @@ -0,0 +1,6 @@ + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/icons/function-node-icon.vue b/src/MaxKB-1.7.2/ui/src/workflow/icons/function-node-icon.vue new file mode 100644 index 0000000..e6e84a8 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/icons/function-node-icon.vue @@ -0,0 +1,6 @@ + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/icons/global-icon.vue b/src/MaxKB-1.7.2/ui/src/workflow/icons/global-icon.vue new file mode 100644 index 0000000..5d476dc --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/icons/global-icon.vue @@ -0,0 +1,4 @@ + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/icons/question-node-icon.vue b/src/MaxKB-1.7.2/ui/src/workflow/icons/question-node-icon.vue new file mode 100644 index 0000000..74ab30d --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/icons/question-node-icon.vue @@ -0,0 +1,6 @@ + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/icons/reply-node-icon.vue b/src/MaxKB-1.7.2/ui/src/workflow/icons/reply-node-icon.vue new file mode 100644 index 0000000..07b2ed5 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/icons/reply-node-icon.vue @@ -0,0 +1,6 @@ + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/icons/reranker-node-icon.vue b/src/MaxKB-1.7.2/ui/src/workflow/icons/reranker-node-icon.vue new file mode 100644 index 0000000..70c8f48 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/icons/reranker-node-icon.vue @@ -0,0 +1,6 @@ + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/icons/search-dataset-node-icon.vue b/src/MaxKB-1.7.2/ui/src/workflow/icons/search-dataset-node-icon.vue new file mode 100644 index 0000000..d2b2302 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/icons/search-dataset-node-icon.vue @@ -0,0 +1,6 @@ + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/icons/start-node-icon.vue b/src/MaxKB-1.7.2/ui/src/workflow/icons/start-node-icon.vue new file mode 100644 index 0000000..9a01ac9 --- /dev/null +++ 
b/src/MaxKB-1.7.2/ui/src/workflow/icons/start-node-icon.vue @@ -0,0 +1,6 @@ + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/icons/utils.ts b/src/MaxKB-1.7.2/ui/src/workflow/icons/utils.ts new file mode 100644 index 0000000..b667cdc --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/icons/utils.ts @@ -0,0 +1,5 @@ +const icons: any = import.meta.glob('./**.vue', { eager: true }) +export function iconComponent(name: string) { + const url = `./${name}.vue` + return icons[url]?.default || null +} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/index.vue b/src/MaxKB-1.7.2/ui/src/workflow/index.vue new file mode 100644 index 0000000..3d9daaa --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/index.vue @@ -0,0 +1,169 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/ai-chat-node/index.ts b/src/MaxKB-1.7.2/ui/src/workflow/nodes/ai-chat-node/index.ts new file mode 100644 index 0000000..b226719 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/ai-chat-node/index.ts @@ -0,0 +1,12 @@ +import ChatNodeVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +class ChatNode extends AppNode { + constructor(props: any) { + super(props, ChatNodeVue) + } +} +export default { + type: 'ai-chat-node', + model: AppNodeModel, + view: ChatNode +} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/ai-chat-node/index.vue b/src/MaxKB-1.7.2/ui/src/workflow/nodes/ai-chat-node/index.vue new file mode 100644 index 0000000..77d0a73 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/ai-chat-node/index.vue @@ -0,0 +1,326 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/component/ApiFieldFormDialog.vue b/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/component/ApiFieldFormDialog.vue new file mode 100644 index 0000000..474cae9 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/component/ApiFieldFormDialog.vue @@ -0,0 +1,128 @@ + + + diff --git 
a/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/component/ApiInputFieldTable.vue b/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/component/ApiInputFieldTable.vue new file mode 100644 index 0000000..f67642f --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/component/ApiInputFieldTable.vue @@ -0,0 +1,122 @@ + + + + + + \ No newline at end of file diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/component/UserFieldFormDialog.vue b/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/component/UserFieldFormDialog.vue new file mode 100644 index 0000000..b64aef0 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/component/UserFieldFormDialog.vue @@ -0,0 +1,157 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/component/UserInputFieldTable.vue b/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/component/UserInputFieldTable.vue new file mode 100644 index 0000000..e0c87a8 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/component/UserInputFieldTable.vue @@ -0,0 +1,149 @@ + + + + + + \ No newline at end of file diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/index.ts b/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/index.ts new file mode 100644 index 0000000..69ede8e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/index.ts @@ -0,0 +1,22 @@ +import BaseNodeVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' + +class BaseNode extends AppNode { + constructor(props: any) { + super(props, BaseNodeVue) + } +} + +class BaseModel extends AppNodeModel { + constructor(data: any, graphModel: any) { + super(data, graphModel) + } + get_width() { + return 600 + } +} +export default { + type: 'base-node', + model: BaseModel, + view: BaseNode +} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/index.vue b/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/index.vue new file mode 100644 index 0000000..eb5df11 --- /dev/null +++ 
b/src/MaxKB-1.7.2/ui/src/workflow/nodes/base-node/index.vue @@ -0,0 +1,367 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/condition-node/index.ts b/src/MaxKB-1.7.2/ui/src/workflow/nodes/condition-node/index.ts new file mode 100644 index 0000000..275c63a --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/condition-node/index.ts @@ -0,0 +1,68 @@ +import ConditioNodeVue from './index.vue' +import { cloneDeep, set } from 'lodash' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +class ConditioNode extends AppNode { + constructor(props: any) { + super(props, ConditioNodeVue) + } +} +const get_up_index_height = (condition_list: Array, index: number) => { + return condition_list + .filter((item, i) => i < index) + .map((item) => item.height + 8) + .reduce((x, y) => x + y, 0) +} +class ConditionModel extends AppNodeModel { + refreshBranch() { + // 更新节点连接边的path + this.incoming.edges.forEach((edge: any) => { + // 调用自定义的更新方案 + edge.updatePathByAnchor() + }) + this.outgoing.edges.forEach((edge: any) => { + edge.updatePathByAnchor() + }) + } + getDefaultAnchor() { + const { + id, + x, + y, + width, + height, + properties: { branch_condition_list } + } = this + if (this.height === undefined) { + this.height = 200 + } + const showNode = this.properties.showNode === undefined ? true : this.properties.showNode + const anchors: any = [] + anchors.push({ + x: x - width / 2 + 10, + y: showNode ? y : y - 15, + id: `${id}_left`, + edgeAddable: false, + type: 'left' + }) + + if (branch_condition_list) { + for (let index = 0; index < branch_condition_list.length; index++) { + const element = branch_condition_list[index] + const h = get_up_index_height(branch_condition_list, index) + anchors.push({ + x: x + width / 2 - 10, + y: showNode ? 
y - height / 2 + 75 + h + element.height / 2 : y - 15, + id: `${id}_${element.id}_right`, + type: 'right' + }) + } + } + + return anchors + } +} +export default { + type: 'condition-node', + model: ConditionModel, + view: ConditioNode +} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/condition-node/index.vue b/src/MaxKB-1.7.2/ui/src/workflow/nodes/condition-node/index.vue new file mode 100644 index 0000000..2a2346a --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/condition-node/index.vue @@ -0,0 +1,296 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/function-lib-node/index.ts b/src/MaxKB-1.7.2/ui/src/workflow/nodes/function-lib-node/index.ts new file mode 100644 index 0000000..475818c --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/function-lib-node/index.ts @@ -0,0 +1,12 @@ +import FunctionLibNodeVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +class FunctionLibNode extends AppNode { + constructor(props: any) { + super(props, FunctionLibNodeVue) + } +} +export default { + type: 'function-lib-node', + model: AppNodeModel, + view: FunctionLibNode +} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/function-lib-node/index.vue b/src/MaxKB-1.7.2/ui/src/workflow/nodes/function-lib-node/index.vue new file mode 100644 index 0000000..53ac5a4 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/function-lib-node/index.vue @@ -0,0 +1,144 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/function-node/index.ts b/src/MaxKB-1.7.2/ui/src/workflow/nodes/function-node/index.ts new file mode 100644 index 0000000..ab3f36e --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/function-node/index.ts @@ -0,0 +1,12 @@ +import FunctionNodeVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +class FunctionLibCustomNode extends AppNode { + constructor(props: any) { + super(props, FunctionNodeVue) + } +} +export default { + type: 'function-node', + model: 
AppNodeModel, + view: FunctionLibCustomNode +} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/function-node/index.vue b/src/MaxKB-1.7.2/ui/src/workflow/nodes/function-node/index.vue new file mode 100644 index 0000000..db08024 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/function-node/index.vue @@ -0,0 +1,233 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/question-node/index.ts b/src/MaxKB-1.7.2/ui/src/workflow/nodes/question-node/index.ts new file mode 100644 index 0000000..324c246 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/question-node/index.ts @@ -0,0 +1,12 @@ +import QuestionNodeVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +class QuestionNode extends AppNode { + constructor(props: any) { + super(props, QuestionNodeVue) + } +} +export default { + type: 'question-node', + model: AppNodeModel, + view: QuestionNode +} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/question-node/index.vue b/src/MaxKB-1.7.2/ui/src/workflow/nodes/question-node/index.vue new file mode 100644 index 0000000..c6a9918 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/question-node/index.vue @@ -0,0 +1,317 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/reply-node/index.ts b/src/MaxKB-1.7.2/ui/src/workflow/nodes/reply-node/index.ts new file mode 100644 index 0000000..e3bd9d9 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/reply-node/index.ts @@ -0,0 +1,12 @@ +import ReplyNodeVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +class ReplyNode extends AppNode { + constructor(props: any) { + super(props, ReplyNodeVue) + } +} +export default { + type: 'reply-node', + model: AppNodeModel, + view: ReplyNode +} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/reply-node/index.vue b/src/MaxKB-1.7.2/ui/src/workflow/nodes/reply-node/index.vue new file mode 100644 index 0000000..77ce289 --- /dev/null +++ 
b/src/MaxKB-1.7.2/ui/src/workflow/nodes/reply-node/index.vue @@ -0,0 +1,130 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/reranker-node/ParamSettingDialog.vue b/src/MaxKB-1.7.2/ui/src/workflow/nodes/reranker-node/ParamSettingDialog.vue new file mode 100644 index 0000000..f968ffd --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/reranker-node/ParamSettingDialog.vue @@ -0,0 +1,142 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/reranker-node/index.ts b/src/MaxKB-1.7.2/ui/src/workflow/nodes/reranker-node/index.ts new file mode 100644 index 0000000..9b3afc5 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/reranker-node/index.ts @@ -0,0 +1,12 @@ +import RerankerNodeVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +class RerankerNode extends AppNode { + constructor(props: any) { + super(props, RerankerNodeVue) + } +} +export default { + type: 'reranker-node', + model: AppNodeModel, + view: RerankerNode +} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/reranker-node/index.vue b/src/MaxKB-1.7.2/ui/src/workflow/nodes/reranker-node/index.vue new file mode 100644 index 0000000..731c6d7 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/reranker-node/index.vue @@ -0,0 +1,319 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/search-dataset-node/index.ts b/src/MaxKB-1.7.2/ui/src/workflow/nodes/search-dataset-node/index.ts new file mode 100644 index 0000000..316854d --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/search-dataset-node/index.ts @@ -0,0 +1,12 @@ +import SearchDatasetVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +class SearchDatasetNode extends AppNode { + constructor(props: any) { + super(props, SearchDatasetVue) + } +} +export default { + type: 'search-dataset-node', + model: AppNodeModel, + view: SearchDatasetNode +} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/search-dataset-node/index.vue 
b/src/MaxKB-1.7.2/ui/src/workflow/nodes/search-dataset-node/index.vue new file mode 100644 index 0000000..6401367 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/search-dataset-node/index.vue @@ -0,0 +1,214 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/start-node/index.ts b/src/MaxKB-1.7.2/ui/src/workflow/nodes/start-node/index.ts new file mode 100644 index 0000000..5adfb77 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/start-node/index.ts @@ -0,0 +1,12 @@ +import StartNodeVue from './index.vue' +import { AppNode, AppNodeModel } from '@/workflow/common/app-node' +class StartNode extends AppNode { + constructor(props: any) { + super(props, StartNodeVue) + } +} +export default { + type: 'start-node', + model: AppNodeModel, + view: StartNode +} diff --git a/src/MaxKB-1.7.2/ui/src/workflow/nodes/start-node/index.vue b/src/MaxKB-1.7.2/ui/src/workflow/nodes/start-node/index.vue new file mode 100644 index 0000000..883495b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/src/workflow/nodes/start-node/index.vue @@ -0,0 +1,68 @@ + + + diff --git a/src/MaxKB-1.7.2/ui/tsconfig.app.json b/src/MaxKB-1.7.2/ui/tsconfig.app.json new file mode 100644 index 0000000..8601394 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/tsconfig.app.json @@ -0,0 +1,15 @@ +{ + "extends": "@vue/tsconfig/tsconfig.dom.json", + "include": ["env.d.ts", "src/**/*", "src/**/*.vue"], + "exclude": ["src/**/__tests__/*"], + "compilerOptions": { + "composite": true, + "moduleResolution": "node", + "baseUrl": ".", + "target": "esnext", // 使用ES最新语法 + "module": "esnext", // 使用ES模块语法 + "paths": { + "@/*": ["./src/*"] + } + } +} diff --git a/src/MaxKB-1.7.2/ui/tsconfig.json b/src/MaxKB-1.7.2/ui/tsconfig.json new file mode 100644 index 0000000..248dd5b --- /dev/null +++ b/src/MaxKB-1.7.2/ui/tsconfig.json @@ -0,0 +1,16 @@ +{ + + "files": [], + + "references": [ + { + "path": "./tsconfig.node.json" + }, + { + "path": "./tsconfig.app.json" + }, + { + "path": "./tsconfig.vitest.json" + } + ] +} diff --git 
a/src/MaxKB-1.7.2/ui/tsconfig.node.json b/src/MaxKB-1.7.2/ui/tsconfig.node.json new file mode 100644 index 0000000..b4a8795 --- /dev/null +++ b/src/MaxKB-1.7.2/ui/tsconfig.node.json @@ -0,0 +1,19 @@ +{ + "extends": "@tsconfig/node18/tsconfig.json", + "include": [ + "vite.config.*", + "vitest.config.*", + "cypress.config.*", + "nightwatch.conf.*", + "playwright.config.*" + ], + "compilerOptions": { + "composite": true, + "module": "ESNext", + "moduleResolution": "node", + "skipLibCheck": true, // 跳过node依赖包语法检查 + "types": [ + "node" + ] + } +} \ No newline at end of file diff --git a/src/MaxKB-1.7.2/ui/tsconfig.vitest.json b/src/MaxKB-1.7.2/ui/tsconfig.vitest.json new file mode 100644 index 0000000..940b2df --- /dev/null +++ b/src/MaxKB-1.7.2/ui/tsconfig.vitest.json @@ -0,0 +1,11 @@ +{ + "extends": "./tsconfig.app.json", + "exclude": [], + "compilerOptions": { + "composite": true, + "target": "esnext", // 使用ES最新语法 + "module": "esnext", // 使用ES模块语法 + "lib": [], + "types": ["node", "jsdom"] + } +} diff --git a/src/MaxKB-1.7.2/ui/vite.config.ts b/src/MaxKB-1.7.2/ui/vite.config.ts new file mode 100644 index 0000000..6c442ab --- /dev/null +++ b/src/MaxKB-1.7.2/ui/vite.config.ts @@ -0,0 +1,41 @@ +import { fileURLToPath, URL } from 'node:url' +import type { ProxyOptions } from 'vite' +import { defineConfig, loadEnv } from 'vite' + +import vue from '@vitejs/plugin-vue' +import DefineOptions from 'unplugin-vue-define-options/vite' + +const envDir = './env' +// https://vitejs.dev/config/ +export default defineConfig(({ mode }) => { + const ENV = loadEnv(mode, envDir) + const prefix = process.env.VITE_DYNAMIC_PREFIX || ENV.VITE_BASE_PATH; + const proxyConf: Record = {} + proxyConf['/api'] = { + target: 'http://127.0.0.1:8080', + changeOrigin: true, + rewrite: (path) => path.replace(ENV.VITE_BASE_PATH, '/') + } + return { + preflight: false, + lintOnSave: false, + base: prefix, + envDir: envDir, + plugins: [vue(), DefineOptions()], + server: { + cors: true, + host: '0.0.0.0', + 
port: Number(ENV.VITE_APP_PORT), + strictPort: true, + proxy: proxyConf + }, + build: { + outDir: 'dist/ui' + }, + resolve: { + alias: { + '@': fileURLToPath(new URL('./src', import.meta.url)) + } + } + } +}) diff --git a/src/MaxKB-1.7.2/ui/vitest.config.ts b/src/MaxKB-1.7.2/ui/vitest.config.ts new file mode 100644 index 0000000..7c37fae --- /dev/null +++ b/src/MaxKB-1.7.2/ui/vitest.config.ts @@ -0,0 +1,14 @@ +import { fileURLToPath } from 'node:url' +import { mergeConfig, defineConfig, configDefaults } from 'vitest/config' +import viteConfig from './vite.config' + +export default mergeConfig( + viteConfig as never, + defineConfig({ + test: { + environment: 'jsdom', + exclude: [...configDefaults.exclude, 'e2e/*'], + root: fileURLToPath(new URL('./', import.meta.url)) + } + }) +)