diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index fc4233504b2..29fbc588679 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -158,7 +158,7 @@ jobs: - name: Run unit test run: | - uv sync --python 3.12 --group test --frozen + uv sync --python 3.13 --group test --frozen source .venv/bin/activate which pytest || echo "pytest not in PATH" echo "Start to run unit test" @@ -222,7 +222,7 @@ jobs: # Patch entrypoint.sh for coverage sed -i '/"\$PY" api\/ragflow_server.py \${INIT_SUPERUSER_ARGS} &/c\ echo "Ensuring coverage is installed..."\n "$PY" -m pip install coverage -i https://mirrors.aliyun.com/pypi/simple\n export COVERAGE_FILE=/ragflow/logs/.coverage\n echo "Starting ragflow_server with coverage..."\n "$PY" -m coverage run --source=./api/apps --omit="*/tests/*,*/migrations/*" -a api/ragflow_server.py ${INIT_SUPERUSER_ARGS} &' ./entrypoint.sh cd .. - uv sync --python 3.12 --group test --frozen && uv pip install -e sdk/python + uv sync --python 3.13 --group test --frozen && uv pip install -e sdk/python - name: Start ragflow:nightly for Infinity @@ -240,23 +240,14 @@ jobs: echo "Start to run test sdk on Infinity" source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} --junitxml=pytest-infinity-sdk.xml --cov=sdk/python/ragflow_sdk --cov-branch --cov-report=xml:coverage-infinity-sdk.xml test/testcases/test_sdk_api 2>&1 | tee infinity_sdk_test.log - - name: Run web api tests against Infinity + - name: Run New RESTFUL api tests against Infinity run: | export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null 2>&1; do echo "Waiting for service to be available... (last exit code: $?)" sleep 5 done - source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api/test_chunk_feedback 2>&1 | tee infinity_web_api_test.log - - - name: Run http api tests against Infinity - run: | - export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" - until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null 2>&1; do - echo "Waiting for service to be available... (last exit code: $?)" - sleep 5 - done - source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee infinity_http_api_test.log + source .venv/bin/activate && set -o pipefail; DOC_ENGINE=infinity pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/restful_api 2>&1 | tee infinity_restful_api_test.log - name: RAGFlow CLI retrieval test Infinity env: @@ -432,24 +423,15 @@ jobs: done echo "Start to run test sdk on Elasticsearch" source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} --junitxml=pytest-infinity-sdk.xml --cov=sdk/python/ragflow_sdk --cov-branch --cov-report=xml:coverage-es-sdk.xml test/testcases/test_sdk_api 2>&1 | tee es_sdk_test.log - - - name: Run web api tests against Elasticsearch - run: | - export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" - until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null 2>&1; do - echo "Waiting for service to be available... (last exit code: $?)" - sleep 5 - done - source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_web_api 2>&1 | tee es_web_api_test.log - - - name: Run http api tests against Elasticsearch + + - name: Run New RESTFUL api tests against Elasticsearch run: | export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY="" until sudo docker exec ${RAGFLOW_CONTAINER} curl -s --connect-timeout 5 ${HOST_ADDRESS}/v1/system/ping > /dev/null 2>&1; do echo "Waiting for service to be available... (last exit code: $?)" sleep 5 done - source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/test_http_api 2>&1 | tee es_http_api_test.log + source .venv/bin/activate && set -o pipefail; pytest -s --tb=short --level=${HTTP_API_TEST_LEVEL} test/testcases/restful_api 2>&1 | tee es_restful_api_test.log - name: RAGFlow CLI retrieval test Elasticsearch env: diff --git a/.gitignore b/.gitignore index f65d204fb24..097a885152b 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ Cargo.lock .idea/ .vscode/ +.cursor/settings.json # Exclude Mac generated files .DS_Store diff --git a/CLAUDE.md b/CLAUDE.md index 81888ba3d71..7cb61ad1266 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -51,7 +51,7 @@ RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on d ```bash # Install Python dependencies -uv sync --python 3.12 --all-extras +uv sync --python 3.13 --all-extras uv run python3 download_deps.py pre-commit install @@ -118,7 +118,7 @@ RAGFlow supports switching between Elasticsearch (default) and Infinity: ## Development Environment Requirements -- Python 3.10-3.12 +- Python 3.10-3.13 - Node.js >=18.20.4 - Docker & Docker Compose - uv package manager diff --git a/Dockerfile b/Dockerfile index fdc5f4c4bba..dd7fcfa8730 100644 --- a/Dockerfile +++ b/Dockerfile @@ -78,7 +78,7 @@ RUN --mount=type=bind,from=infiniflow/ragflow_deps:latest,source=/,target=/deps tar xzf "/deps/uv-${uv_arch}-unknown-linux-gnu.tar.gz" \ && cp "uv-${uv_arch}-unknown-linux-gnu/"* /usr/local/bin/ \ && rm -rf "uv-${uv_arch}-unknown-linux-gnu" \ - && uv python install 3.12 + && uv python install 3.13 ENV PYTHONDONTWRITEBYTECODE=1 DOTNET_SYSTEM_GLOBALIZATION_INVARIANT=1 \ UV_HTTP_TIMEOUT=200 \ @@ -147,7 +147,7 @@ RUN --mount=type=cache,id=ragflow_uv,target=/root/.cache/uv,sharing=locked \ else \ sed -i 's|mirrors.aliyun.com/pypi|pypi.org|g' uv.lock; \ fi; \ - uv sync --python 3.12 --frozen && \ + uv sync --python 3.13 --frozen && \ # Ensure pip is available in the venv for runtime package installation (fixes #12651) .venv/bin/python3 -m ensurepip --upgrade diff --git a/README.md b/README.md index 5f8bed3db16..2a86b2490fb 100644 --- a/README.md +++ b/README.md @@ -25,7 +25,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.25.2 + docker pull infiniflow/ragflow:v0.25.5 Latest Release @@ -192,12 +192,12 @@ releases! 🌟 > All Docker images are built for x86 platforms. We don't currently offer Docker images for ARM64. > If you are on an ARM64 platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a Docker image compatible with your system. -> The command below downloads the `v0.25.2` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.25.2`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server. +> The command below downloads the `v0.25.5` edition of the RAGFlow Docker image. See the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.25.5`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server. ```bash $ cd ragflow/docker - # git checkout v0.25.2 + # git checkout v0.25.5 # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases) # This step ensures the **entrypoint.sh** file in the code matches the Docker image version. diff --git a/README_ar.md b/README_ar.md index a02003d8342..1f7393945e7 100644 --- a/README_ar.md +++ b/README_ar.md @@ -25,7 +25,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.25.2 + docker pull infiniflow/ragflow:v0.25.5 Latest Release @@ -192,12 +192,12 @@ > جميع الصور Docker مصممة لمنصات x86. لا نعرض حاليًا صور Docker لـ ARM64. > إذا كنت تستخدم نظامًا أساسيًا ARM64، فاتبع [هذا الدليل](https://ragflow.io/docs/dev/build_docker_image) لإنشاء صورة Docker متوافقة مع نظامك. -> يقوم الأمر أدناه بتنزيل إصدار `v0.25.2` من الصورة RAGFlow Docker. راجع الجدول التالي للحصول على أوصاف لإصدارات RAGFlow المختلفة. لتنزيل إصدار RAGFlow مختلف عن `v0.25.2`، قم بتحديث المتغير `RAGFLOW_IMAGE` وفقًا لذلك في **docker/.env** قبل استخدام `docker compose` لبدء تشغيل الخادم. +> يقوم الأمر أدناه بتنزيل إصدار `v0.25.5` من الصورة RAGFlow Docker. راجع الجدول التالي للحصول على أوصاف لإصدارات RAGFlow المختلفة. لتنزيل إصدار RAGFlow مختلف عن `v0.25.5`، قم بتحديث المتغير `RAGFLOW_IMAGE` وفقًا لذلك في **docker/.env** قبل استخدام `docker compose` لبدء تشغيل الخادم. ```bash $ cd ragflow/docker - # git checkout v0.25.2 + # git checkout v0.25.5 # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases) # This step ensures the **entrypoint.sh** file in the code matches the Docker image version. diff --git a/README_fr.md b/README_fr.md index 37253de7e60..b21d380ec34 100644 --- a/README_fr.md +++ b/README_fr.md @@ -25,7 +25,7 @@ Badge statique - docker pull infiniflow/ragflow:v0.25.2 + docker pull infiniflow/ragflow:v0.25.5 Dernière version @@ -189,12 +189,12 @@ Essayez notre service cloud sur [https://cloud.ragflow.io](https://cloud.ragflow > Toutes les images Docker sont construites pour les plateformes x86. Nous ne proposons pas actuellement d'images Docker pour ARM64. > Si vous êtes sur une plateforme ARM64, suivez [ce guide](https://ragflow.io/docs/dev/build_docker_image) pour construire une image Docker compatible avec votre système. -> La commande ci-dessous télécharge l'édition `v0.25.2` de l'image Docker RAGFlow. Consultez le tableau suivant pour les descriptions des différentes éditions de RAGFlow. Pour télécharger une édition de RAGFlow différente de `v0.25.2`, mettez à jour la variable `RAGFLOW_IMAGE` dans **docker/.env** avant d'utiliser `docker compose` pour démarrer le serveur. +> La commande ci-dessous télécharge l'édition `v0.25.5` de l'image Docker RAGFlow. Consultez le tableau suivant pour les descriptions des différentes éditions de RAGFlow. Pour télécharger une édition de RAGFlow différente de `v0.25.5`, mettez à jour la variable `RAGFLOW_IMAGE` dans **docker/.env** avant d'utiliser `docker compose` pour démarrer le serveur. ```bash $ cd ragflow/docker - # git checkout v0.25.2 + # git checkout v0.25.5 # Optionnel : utiliser un tag stable (voir les versions : https://github.com/infiniflow/ragflow/releases) # Cette étape garantit que le fichier **entrypoint.sh** dans le code correspond à la version de l'image Docker. diff --git a/README_id.md b/README_id.md index d2cecfcfc5a..a9d45317feb 100644 --- a/README_id.md +++ b/README_id.md @@ -25,7 +25,7 @@ Lencana Daring - docker pull infiniflow/ragflow:v0.25.2 + docker pull infiniflow/ragflow:v0.25.5 Rilis Terbaru @@ -192,12 +192,12 @@ Coba layanan cloud kami di [https://cloud.ragflow.io](https://cloud.ragflow.io). > Semua gambar Docker dibangun untuk platform x86. Saat ini, kami tidak menawarkan gambar Docker untuk ARM64. > Jika Anda menggunakan platform ARM64, [silakan gunakan panduan ini untuk membangun gambar Docker yang kompatibel dengan sistem Anda](https://ragflow.io/docs/dev/build_docker_image). -> Perintah di bawah ini mengunduh edisi v0.25.2 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.25.2, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server. +> Perintah di bawah ini mengunduh edisi v0.25.5 dari gambar Docker RAGFlow. Silakan merujuk ke tabel berikut untuk deskripsi berbagai edisi RAGFlow. Untuk mengunduh edisi RAGFlow yang berbeda dari v0.25.5, perbarui variabel RAGFLOW_IMAGE di docker/.env sebelum menggunakan docker compose untuk memulai server. ```bash $ cd ragflow/docker - # git checkout v0.25.2 + # git checkout v0.25.5 # Opsional: gunakan tag stabil (lihat releases: https://github.com/infiniflow/ragflow/releases) # This steps ensures the **entrypoint.sh** file in the code matches the Docker image version. diff --git a/README_ja.md b/README_ja.md index 1d4100d2eda..185d6e9c360 100644 --- a/README_ja.md +++ b/README_ja.md @@ -25,7 +25,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.25.2 + docker pull infiniflow/ragflow:v0.25.5 Latest Release @@ -172,12 +172,12 @@ > 現在、公式に提供されているすべての Docker イメージは x86 アーキテクチャ向けにビルドされており、ARM64 用の Docker イメージは提供されていません。 > ARM64 アーキテクチャのオペレーティングシステムを使用している場合は、[このドキュメント](https://ragflow.io/docs/dev/build_docker_image)を参照して Docker イメージを自分でビルドしてください。 -> 以下のコマンドは、RAGFlow Docker イメージの v0.25.2 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.25.2 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。 +> 以下のコマンドは、RAGFlow Docker イメージの v0.25.5 エディションをダウンロードします。異なる RAGFlow エディションの説明については、以下の表を参照してください。v0.25.5 とは異なるエディションをダウンロードするには、docker/.env ファイルの RAGFLOW_IMAGE 変数を適宜更新し、docker compose を使用してサーバーを起動してください。 ```bash $ cd ragflow/docker - # git checkout v0.25.2 + # git checkout v0.25.5 # 任意: 安定版タグを利用 (一覧: https://github.com/infiniflow/ragflow/releases) # この手順は、コード内の entrypoint.sh ファイルが Docker イメージのバージョンと一致していることを確認します。 diff --git a/README_ko.md b/README_ko.md index 2d293a44f72..ee243344eab 100644 --- a/README_ko.md +++ b/README_ko.md @@ -25,7 +25,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.25.2 + docker pull infiniflow/ragflow:v0.25.5 Latest Release @@ -174,12 +174,12 @@ > 모든 Docker 이미지는 x86 플랫폼을 위해 빌드되었습니다. 우리는 현재 ARM64 플랫폼을 위한 Docker 이미지를 제공하지 않습니다. > ARM64 플랫폼을 사용 중이라면, [시스템과 호환되는 Docker 이미지를 빌드하려면 이 가이드를 사용해 주세요](https://ragflow.io/docs/dev/build_docker_image). - > 아래 명령어는 RAGFlow Docker 이미지의 v0.25.2 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.25.2와 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오. + > 아래 명령어는 RAGFlow Docker 이미지의 v0.25.5 버전을 다운로드합니다. 다양한 RAGFlow 버전에 대한 설명은 다음 표를 참조하십시오. v0.25.5와 다른 RAGFlow 버전을 다운로드하려면, docker/.env 파일에서 RAGFLOW_IMAGE 변수를 적절히 업데이트한 후 docker compose를 사용하여 서버를 시작하십시오. ```bash $ cd ragflow/docker - # git checkout v0.25.2 + # git checkout v0.25.5 # Optional: use a stable tag (see releases: https://github.com/infiniflow/ragflow/releases) # 이 단계는 코드의 entrypoint.sh 파일이 Docker 이미지 버전과 일치하도록 보장합니다. diff --git a/README_pt_br.md b/README_pt_br.md index c830f1facd8..2bebd2c2611 100644 --- a/README_pt_br.md +++ b/README_pt_br.md @@ -25,7 +25,7 @@ Badge Estático - docker pull infiniflow/ragflow:v0.25.2 + docker pull infiniflow/ragflow:v0.25.5 Última Versão @@ -192,12 +192,12 @@ Experimente o nosso serviço na nuvem em [https://cloud.ragflow.io](https://clou > Todas as imagens Docker são construídas para plataformas x86. Atualmente, não oferecemos imagens Docker para ARM64. > Se você estiver usando uma plataforma ARM64, por favor, utilize [este guia](https://ragflow.io/docs/dev/build_docker_image) para construir uma imagem Docker compatível com o seu sistema. - > O comando abaixo baixa a edição`v0.25.2` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.25.2`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor. + > O comando abaixo baixa a edição`v0.25.5` da imagem Docker do RAGFlow. Consulte a tabela a seguir para descrições de diferentes edições do RAGFlow. Para baixar uma edição do RAGFlow diferente da `v0.25.5`, atualize a variável `RAGFLOW_IMAGE` conforme necessário no **docker/.env** antes de usar `docker compose` para iniciar o servidor. ```bash $ cd ragflow/docker - # git checkout v0.25.2 + # git checkout v0.25.5 # Opcional: use uma tag estável (veja releases: https://github.com/infiniflow/ragflow/releases) # Esta etapa garante que o arquivo entrypoint.sh no código corresponda à versão da imagem do Docker. diff --git a/README_tr.md b/README_tr.md index c022dcbf7a1..778d4122fbc 100644 --- a/README_tr.md +++ b/README_tr.md @@ -25,7 +25,7 @@ Çevrimiçi Demo - docker pull infiniflow/ragflow:v0.25.2 + docker pull infiniflow/ragflow:v0.25.5 Son Sürüm @@ -190,12 +190,12 @@ Bulut hizmetimizi [https://cloud.ragflow.io](https://cloud.ragflow.io) adresinde > Tüm Docker imajları x86 platformları için oluşturulmuştur. Şu anda ARM64 için Docker imajı sunmuyoruz. > ARM64 platformundaysanız, sisteminizle uyumlu bir Docker imajı oluşturmak için [bu kılavuzu](https://ragflow.io/docs/dev/build_docker_image) takip edin. -> Aşağıdaki komut RAGFlow Docker imajının `v0.25.2` sürümünü indirir. Farklı RAGFlow sürümleri için aşağıdaki tabloya bakın. `v0.25.2` dışında bir sürüm indirmek için, `docker compose` ile sunucuyu başlatmadan önce **docker/.env** dosyasındaki `RAGFLOW_IMAGE` değişkenini güncelleyin. +> Aşağıdaki komut RAGFlow Docker imajının `v0.25.5` sürümünü indirir. Farklı RAGFlow sürümleri için aşağıdaki tabloya bakın. `v0.25.5` dışında bir sürüm indirmek için, `docker compose` ile sunucuyu başlatmadan önce **docker/.env** dosyasındaki `RAGFLOW_IMAGE` değişkenini güncelleyin. ```bash $ cd ragflow/docker - # git checkout v0.25.2 + # git checkout v0.25.5 # İsteğe bağlı: Kararlı bir etiket kullanın (sürümler: https://github.com/infiniflow/ragflow/releases) # Bu adım, koddaki **entrypoint.sh** dosyasının Docker imaj sürümüyle eşleşmesini sağlar. diff --git a/README_tzh.md b/README_tzh.md index 172c54a2955..0d6c95af03f 100644 --- a/README_tzh.md +++ b/README_tzh.md @@ -25,7 +25,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.25.2 + docker pull infiniflow/ragflow:v0.25.5 Latest Release @@ -191,12 +191,12 @@ > 所有 Docker 映像檔都是為 x86 平台建置的。目前,我們不提供 ARM64 平台的 Docker 映像檔。 > 如果您使用的是 ARM64 平台,請使用 [這份指南](https://ragflow.io/docs/dev/build_docker_image) 來建置適合您系統的 Docker 映像檔。 -> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.25.2`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.25.2` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。 +> 執行以下指令會自動下載 RAGFlow Docker 映像 `v0.25.5`。請參考下表查看不同 Docker 發行版的說明。如需下載不同於 `v0.25.5` 的 Docker 映像,請在執行 `docker compose` 啟動服務之前先更新 **docker/.env** 檔案內的 `RAGFLOW_IMAGE` 變數。 ```bash $ cd ragflow/docker - # git checkout v0.25.2 + # git checkout v0.25.5 # 可選:使用穩定版標籤(查看發佈:https://github.com/infiniflow/ragflow/releases) # 此步驟確保程式碼中的 entrypoint.sh 檔案與 Docker 映像版本一致。 diff --git a/README_zh.md b/README_zh.md index 72de8935d49..c2222a95ab4 100644 --- a/README_zh.md +++ b/README_zh.md @@ -25,7 +25,7 @@ Static Badge - docker pull infiniflow/ragflow:v0.25.2 + docker pull infiniflow/ragflow:v0.25.5 Latest Release @@ -192,12 +192,12 @@ > 请注意,目前官方提供的所有 Docker 镜像均基于 x86 架构构建,并不提供基于 ARM64 的 Docker 镜像。 > 如果你的操作系统是 ARM64 架构,请参考[这篇文档](https://ragflow.io/docs/dev/build_docker_image)自行构建 Docker 镜像。 - > 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.25.2`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.25.2` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。 + > 运行以下命令会自动下载 RAGFlow Docker 镜像 `v0.25.5`。请参考下表查看不同 Docker 发行版的描述。如需下载不同于 `v0.25.5` 的 Docker 镜像,请在运行 `docker compose` 启动服务之前先更新 **docker/.env** 文件内的 `RAGFLOW_IMAGE` 变量。 ```bash $ cd ragflow/docker - # git checkout v0.25.2 + # git checkout v0.25.5 # 可选:使用稳定版本标签(查看发布:https://github.com/infiniflow/ragflow/releases) # 这一步确保代码中的 entrypoint.sh 文件与 Docker 镜像的版本保持一致。 diff --git a/admin/client/README.md b/admin/client/README.md index cac7425aad8..964cbcc6fcb 100644 --- a/admin/client/README.md +++ b/admin/client/README.md @@ -48,7 +48,7 @@ It consists of a server-side Service and a command-line client (CLI), both imple 1. Ensure the Admin Service is running. 2. Install ragflow-cli. ```bash - pip install ragflow-cli==0.25.2 + pip install ragflow-cli==0.25.5 ``` 3. Launch the CLI client: ```bash diff --git a/admin/client/parser.py b/admin/client/parser.py index cdb20b491dd..7e668c4e299 100644 --- a/admin/client/parser.py +++ b/admin/client/parser.py @@ -264,7 +264,7 @@ list_keys: LIST KEYS OF quoted_string ";" drop_key: DROP KEY quoted_string OF quoted_string ";" -set_variable: SET VAR identifier identifier ";" +set_variable: SET VAR identifier variable_value ";" show_variable: SHOW VAR identifier ";" list_variables: LIST VARS ";" list_configs: LIST CONFIGS ";" @@ -378,6 +378,7 @@ identifier_list: identifier (COMMA identifier)* identifier: WORD +variable_value: WORD | NUMBER | QUOTED_STRING quoted_string: QUOTED_STRING status: ON | WORD diff --git a/admin/client/pyproject.toml b/admin/client/pyproject.toml index 5f70bb1b188..bb450c330ab 100644 --- a/admin/client/pyproject.toml +++ b/admin/client/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ragflow-cli" -version = "0.25.2" +version = "0.25.5" description = "Admin Service's client of [RAGFlow](https://github.com/infiniflow/ragflow). The Admin Service provides user management and system monitoring. " authors = [{ name = "Lynn", email = "lynn_inf@hotmail.com" }] license = { text = "Apache License, Version 2.0" } diff --git a/admin/client/ragflow_client.py b/admin/client/ragflow_client.py index 148af4b45fe..71a5541bbae 100644 --- a/admin/client/ragflow_client.py +++ b/admin/client/ragflow_client.py @@ -43,6 +43,12 @@ def encrypt(input_string): return base64.b64encode(cipher_text).decode("utf-8") +def _strip_tree_value(value): + if isinstance(value, Tree): + value = value.children[0] + return str(value).strip("'\"") + + class RAGFlowClient: def __init__(self, http_client: HttpClient, server_type: str): self.http_client = http_client @@ -526,10 +532,8 @@ def set_variable(self, command): if self.server_type != "admin": print("This command is only allowed in ADMIN mode") - var_name_tree: Tree = command["var_name"] - var_name = var_name_tree.children[0].strip("'\"") - var_value_tree: Tree = command["var_value"] - var_value = var_value_tree.children[0].strip("'\"") + var_name = _strip_tree_value(command["var_name"]) + var_value = _strip_tree_value(command["var_value"]) response = self.http_client.request("PUT", "/admin/variables", json_body={"var_name": var_name, "var_value": var_value}, use_api_base=True, auth_kind="admin") @@ -544,8 +548,7 @@ def show_variable(self, command): if self.server_type != "admin": print("This command is only allowed in ADMIN mode") - var_name_tree: Tree = command["var_name"] - var_name = var_name_tree.children[0].strip("'\"") + var_name = _strip_tree_value(command["var_name"]) response = self.http_client.request(method="GET", path="/admin/variables", json_body={"var_name": var_name}, use_api_base=True, auth_kind="admin") res_json = response.json() diff --git a/admin/client/uv.lock b/admin/client/uv.lock index 0bf404a2308..17b8554c90e 100644 --- a/admin/client/uv.lock +++ b/admin/client/uv.lock @@ -188,7 +188,7 @@ wheels = [ [[package]] name = "ragflow-cli" -version = "0.25.2" +version = "0.25.5" source = { virtual = "." } dependencies = [ { name = "beartype" }, diff --git a/admin/server/routes.py b/admin/server/routes.py index 658cec48c09..0313d8230be 100644 --- a/admin/server/routes.py +++ b/admin/server/routes.py @@ -421,7 +421,7 @@ def get_user_permission(user_name: str): def set_variable(): try: data = request.get_json() - if not data and "var_name" not in data: + if not data or "var_name" not in data: return error_response("Var name is required", 400) if "var_value" not in data: @@ -449,7 +449,7 @@ def get_variable(): # get var data = request.get_json() - if not data and "var_name" not in data: + if not data or "var_name" not in data: return error_response("Var name is required", 400) var_name: str = data["var_name"] res = SettingsMgr.get_by_name(var_name) diff --git a/admin/server/services.py b/admin/server/services.py index 43646d7918a..dc0a41c6c9d 100644 --- a/admin/server/services.py +++ b/admin/server/services.py @@ -330,36 +330,65 @@ def restart_service(service_id: int): class SettingsMgr: + @staticmethod + def _format_setting(setting): + return { + "data_type": setting.data_type, + "name": setting.name, + "setting_type": "config", + "value": setting.value, + } + + @staticmethod + def _validate_value(name: str, data_type: str, value: str): + data_type = data_type.lower() + value = str(value) + if data_type == "string": + return + if data_type == "integer": + try: + int(value) + except ValueError: + raise AdminException(f"Invalid integer value for {name}: {value}") + return + if data_type in {"bool", "boolean"}: + if value not in {"true", "false"}: + raise AdminException(f"Invalid bool value for {name}: expected true or false") + return + if data_type == "json": + try: + json.loads(value) + except json.JSONDecodeError: + raise AdminException(f"Invalid JSON value for {name}") + return + raise AdminException(f"Unsupported data type for {name}: {data_type}") + + @staticmethod + def _infer_data_type(name: str): + if name.startswith("sandbox."): + return "json" + if name.endswith(".enabled"): + return "bool" + return "string" + @staticmethod def get_all(): - settings = SystemSettingsService.get_all() + settings = SystemSettingsService.get_all(reverse=False, order_by="name") result = [] for setting in settings: - result.append( - { - "name": setting.name, - "source": setting.source, - "data_type": setting.data_type, - "value": setting.value, - } - ) + result.append(SettingsMgr._format_setting(setting)) return result @staticmethod def get_by_name(name: str): settings = SystemSettingsService.get_by_name(name) if len(settings) == 0: - raise AdminException(f"Can't get setting: {name}") + settings = SystemSettingsService.get_by_name_prefix(name) + if len(settings) == 0: + raise AdminException(f"Can't get setting: {name}") result = [] for setting in settings: - result.append( - { - "name": setting.name, - "source": setting.source, - "data_type": setting.data_type, - "value": setting.value, - } - ) + result.append(SettingsMgr._format_setting(setting)) return result @staticmethod @@ -367,6 +396,7 @@ def update_by_name(name: str, value: str): settings = SystemSettingsService.get_by_name(name) if len(settings) == 1: setting = settings[0] + SettingsMgr._validate_value(name, setting.data_type, value) setting.value = value setting_dict = setting.to_dict() SystemSettingsService.update_by_name(name, setting_dict) @@ -376,12 +406,8 @@ def update_by_name(name: str, value: str): # Create new setting if it doesn't exist # Determine data_type based on name and value - if name.startswith("sandbox."): - data_type = "json" - elif name.endswith(".enabled"): - data_type = "boolean" - else: - data_type = "string" + data_type = SettingsMgr._infer_data_type(name) + SettingsMgr._validate_value(name, data_type, value) new_setting = { "name": name, @@ -431,11 +457,21 @@ class SandboxMgr: # Provider registry with metadata PROVIDER_REGISTRY = { + "local": { + "name": "Local", + "description": "Execute code directly on the current host process.", + "tags": ["local", "host", "minimal"], + }, "self_managed": { "name": "Self-Managed", "description": "On-premise deployment using Daytona/Docker", "tags": ["self-hosted", "low-latency", "secure"], }, + "ssh": { + "name": "SSH", + "description": "Execute code on a remote machine over SSH.", + "tags": ["remote", "ssh", "custom-runtime"], + }, "aliyun_codeinterpreter": { "name": "Aliyun Code Interpreter", "description": "Aliyun Function Compute Code Interpreter - Code execution in serverless microVMs", @@ -463,13 +499,17 @@ def list_providers(): def get_provider_config_schema(provider_id: str): """Get configuration schema for a specific provider.""" from agent.sandbox.providers import ( + LocalProvider, SelfManagedProvider, + SSHProvider, AliyunCodeInterpreterProvider, E2BProvider, ) schemas = { + "local": LocalProvider.get_config_schema(), "self_managed": SelfManagedProvider.get_config_schema(), + "ssh": SSHProvider.get_config_schema(), "aliyun_codeinterpreter": AliyunCodeInterpreterProvider.get_config_schema(), "e2b": E2BProvider.get_config_schema(), } @@ -486,7 +526,6 @@ def get_config(): # Get active provider type provider_type_settings = SystemSettingsService.get_by_name("sandbox.provider_type") if not provider_type_settings: - # Return default config if not set provider_type = "self_managed" else: provider_type = provider_type_settings[0].value @@ -501,6 +540,15 @@ def get_config(): except json.JSONDecodeError: provider_config = {} + if not provider_config: + schema = SandboxMgr.get_provider_config_schema(provider_type) + provider_config = {} + for field_name, field_schema in schema.items(): + if field_schema.get("readonly"): + continue + if field_schema.get("default") is not None: + provider_config[field_name] = field_schema["default"] + return { "provider_type": provider_type, "config": provider_config, @@ -524,7 +572,9 @@ def set_config(provider_type: str, config: dict, set_active: bool = True): Dictionary with updated provider_type and config """ from agent.sandbox.providers import ( + LocalProvider, SelfManagedProvider, + SSHProvider, AliyunCodeInterpreterProvider, E2BProvider, ) @@ -551,7 +601,7 @@ def set_config(provider_type: str, config: dict, set_active: bool = True): elif field_type == "string": if not isinstance(config[field_name], str): raise AdminException(f"Field '{field_name}' must be a string") - elif field_type == "bool": + elif field_type == "boolean": if not isinstance(config[field_name], bool): raise AdminException(f"Field '{field_name}' must be a boolean") @@ -566,7 +616,9 @@ def set_config(provider_type: str, config: dict, set_active: bool = True): # Provider-specific custom validation provider_classes = { + "local": LocalProvider, "self_managed": SelfManagedProvider, + "ssh": SSHProvider, "aliyun_codeinterpreter": AliyunCodeInterpreterProvider, "e2b": E2BProvider, } @@ -582,6 +634,8 @@ def set_config(provider_type: str, config: dict, set_active: bool = True): # Always update the provider config config_json = json.dumps(config) SettingsMgr.update_by_name(f"sandbox.{provider_type}", config_json) + from agent.sandbox.client import reload_provider + reload_provider() return {"provider_type": provider_type, "config": config} except AdminException: @@ -608,14 +662,18 @@ def test_connection(provider_type: str, config: dict): """ try: from agent.sandbox.providers import ( + LocalProvider, SelfManagedProvider, + SSHProvider, AliyunCodeInterpreterProvider, E2BProvider, ) # Instantiate provider based on type provider_classes = { + "local": LocalProvider, "self_managed": SelfManagedProvider, + "ssh": SSHProvider, "aliyun_codeinterpreter": AliyunCodeInterpreterProvider, "e2b": E2BProvider, } @@ -631,59 +689,40 @@ def test_connection(provider_type: str, config: dict): # Create a temporary sandbox instance for testing instance = provider.create_instance(template="python") + if not instance: + raise AdminException("Failed to create sandbox instance.") - if not instance or instance.status != "READY": - raise AdminException(f"Failed to create sandbox instance. Status: {instance.status if instance else 'None'}") - - # Simple test code that exercises basic Python functionality - test_code = """ -# Test basic Python functionality -import sys + try: + # Simple test code that exercises provider wrapping via main(). + test_code = """ import json import math +import sys -print("Python version:", sys.version) -print("Platform:", sys.platform) - -# Test basic calculations -result = 2 + 2 -print(f"2 + 2 = {result}") - -# Test JSON operations -data = {"test": "data", "value": 123} -print(f"JSON dump: {json.dumps(data)}") - -# Test math operations -print(f"Math.sqrt(16) = {math.sqrt(16)}") - -# Test error handling -try: - x = 1 / 1 - print("Division test: OK") -except Exception as e: - print(f"Error: {e}") -# Return success indicator -print("TEST_PASSED") +def main() -> dict: + print("Python version:", sys.version) + print("Platform:", sys.platform) + print(f"2 + 2 = {2 + 2}") + print(f"JSON dump: {json.dumps({'test': 'data', 'value': 123})}") + print(f"Math.sqrt(16) = {math.sqrt(16)}") + print("TEST_PASSED") + return {"ok": True, "provider_test": "TEST_PASSED"} """ - # Execute test code with timeout - execution_result = provider.execute_code( - instance_id=instance.instance_id, - code=test_code, - language="python", - timeout=10 # 10 seconds timeout - ) - - # Clean up the test instance (if provider supports it) - try: - if hasattr(provider, 'terminate_instance'): - provider.terminate_instance(instance.instance_id) + # Execute test code with timeout + execution_result = provider.execute_code( + instance_id=instance.instance_id, + code=test_code, + language="python", + timeout=10, + ) + finally: + try: + provider.destroy_instance(instance.instance_id) logging.info(f"Cleaned up test instance {instance.instance_id}") - else: - logging.warning(f"Provider {provider_type} does not support terminate_instance, test instance may leak") - except Exception as cleanup_error: - logging.warning(f"Failed to cleanup test instance {instance.instance_id}: {cleanup_error}") + except Exception as cleanup_error: + logging.warning(f"Failed to cleanup test instance {instance.instance_id}: {cleanup_error}") # Build detailed result message success = execution_result.exit_code == 0 and "TEST_PASSED" in execution_result.stdout diff --git a/agent/canvas.py b/agent/canvas.py index ab6d0ba9ff1..3421d207ed2 100644 --- a/agent/canvas.py +++ b/agent/canvas.py @@ -17,7 +17,6 @@ import base64 import datetime import inspect -import binascii import json import logging import re @@ -39,6 +38,7 @@ from common.exceptions import TaskCanceledException from rag.prompts.generator import chunks_format from rag.utils.redis_conn import REDIS_CONN +from rag.utils.tts_cache import synthesize_with_cache class Graph: """ @@ -263,7 +263,7 @@ def set_variable_param_value(self, obj: Any, path: str, value) -> Any: keys = path.split('.') if not path: return value - for key in keys: + for key in keys[:-1]: if key not in cur or not isinstance(cur[key], dict): cur[key] = {} cur = cur[key] @@ -714,14 +714,7 @@ def clean_tts_text(text: str) -> str: text = clean_tts_text(text) if not text: return None - bin = b"" - try: - for chunk in tts_mdl.tts(text): - bin += chunk - except Exception as e: - logging.error(f"TTS failed: {e}, text={text!r}") - return None - return binascii.hexlify(bin).decode("utf-8") + return synthesize_with_cache(tts_mdl, text) def get_history(self, window_size): convs = [] diff --git a/agent/component/agent_with_tools.py b/agent/component/agent_with_tools.py index 859064046d6..83c3e27e531 100644 --- a/agent/component/agent_with_tools.py +++ b/agent/component/agent_with_tools.py @@ -32,7 +32,7 @@ from api.db.services.mcp_server_service import MCPServerService from api.db.services.tenant_llm_service import TenantLLMService from common.connection_utils import timeout -from common.mcp_tool_call_conn import MCPToolCallSession, mcp_tool_metadata_to_openai_tool +from common.mcp_tool_call_conn import MCPToolBinding, MCPToolCallSession, mcp_tool_metadata_to_openai_tool from rag.prompts.generator import citation_plus, citation_prompt, full_question, kb_prompt, message_fit_in, structured_output_prompt @@ -97,13 +97,16 @@ def __init__(self, canvas, id, param: LLMParam): indexed_meta["function"]["name"] = indexed_name self.tool_meta.append(indexed_meta) + tool_idx = len(self.tools) for mcp in self._param.mcp: _, mcp_server = MCPServerService.get_by_id(mcp["mcp_id"]) custom_header = self._param.custom_header tool_call_session = MCPToolCallSession(mcp_server, mcp_server.variables, custom_header) for tnm, meta in mcp["tools"].items(): - self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta)) - self.tools[tnm] = tool_call_session + indexed_name = f"{tnm}_{tool_idx}" + tool_idx += 1 + self.tool_meta.append(mcp_tool_metadata_to_openai_tool(meta, function_name=indexed_name)) + self.tools[indexed_name] = MCPToolBinding(tool_call_session, tnm) self.callback = partial(self._canvas.tool_use_callback, id) self.toolcall_session = LLMToolPluginCallSession(self.tools, self.callback) if self.tool_meta: diff --git a/agent/component/base.py b/agent/component/base.py index 9bceb4ce6d9..299adcd4532 100644 --- a/agent/component/base.py +++ b/agent/component/base.py @@ -366,6 +366,7 @@ class ComponentBase(ABC): component_name: str thread_limiter = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT_CHATS", 10))) variable_ref_patt = r"\{* *\{([a-zA-Z:0-9]+@[A-Za-z0-9_.-]+|sys\.[A-Za-z0-9_.]+|env\.[A-Za-z0-9_.]+)\} *\}*" + iteration_alias_patt = r"\{* *\{(item|index|result)\} *\}*" def __str__(self): """ @@ -486,6 +487,10 @@ def get_input(self, key: str = None) -> Union[Any, dict[str, Any]]: continue if isinstance(v, str) and self._canvas.is_reff(v): self.set_input_value(var, self._canvas.get_variable_value(v)) + elif isinstance(v, str) and re.search(self.variable_ref_patt, v): + elements = self.get_input_elements_from_text(v) + kv = {k: e.get('value', '') for k, e in elements.items()} + self.set_input_value(var, self.string_format(v, kv)) else: self.set_input_value(var, v) res[var] = self.get_input_value(var) @@ -497,6 +502,23 @@ def get_input_values(self) -> Union[Any, dict[str, Any]]: return {var: self.get_input_value(var) for var, o in self.get_input_elements().items()} + def _resolve_iteration_alias_ref(self, exp: str) -> str | None: + if exp not in {"item", "index", "result"}: + return None + + parent = self.get_parent() + if not parent or parent.component_name.lower() != "iteration": + return None + + for cid, cpn in self._canvas.components.items(): + if cpn.get("parent_id") != parent._id: + continue + if cpn["obj"].component_name.lower() != "iterationitem": + continue + return f"{cid}@{exp}" + + return None + def get_input_elements_from_text(self, txt: str) -> dict[str, dict[str, str]]: res = {} for r in re.finditer(self.variable_ref_patt, txt, flags=re.IGNORECASE | re.DOTALL): @@ -508,6 +530,20 @@ def get_input_elements_from_text(self, txt: str) -> dict[str, dict[str, str]]: "_retrieval": self._canvas.get_variable_value(f"{cpn_id}@_references") if cpn_id else None, "_cpn_id": cpn_id } + for r in re.finditer(self.iteration_alias_patt, txt, flags=re.IGNORECASE | re.DOTALL): + exp = r.group(1) + if exp in res: + continue + ref = self._resolve_iteration_alias_ref(exp) + if not ref: + continue + cpn_id, var_nm = ref.split("@", 1) + res[exp] = { + "name": (self._canvas.get_component_name(cpn_id) + f"@{var_nm}"), + "value": self._canvas.get_variable_value(ref), + "_retrieval": self._canvas.get_variable_value(f"{cpn_id}@_references"), + "_cpn_id": cpn_id + } return res def get_input_elements(self) -> dict[str, Any]: diff --git a/agent/component/data_operations.py b/agent/component/data_operations.py index 60e65f88121..9cf5c55335b 100644 --- a/agent/component/data_operations.py +++ b/agent/component/data_operations.py @@ -73,7 +73,7 @@ def _invoke(self, **kwargs): continue if self._param.operations == "select_keys": self._select_keys() - elif self._param.operations == "recursive_eval": + elif self._param.operations == "literal_eval": self._literal_eval() elif self._param.operations == "combine": self._combine() diff --git a/agent/component/iterationitem.py b/agent/component/iterationitem.py index fad4a44e989..c9134e7c777 100644 --- a/agent/component/iterationitem.py +++ b/agent/component/iterationitem.py @@ -54,7 +54,11 @@ def _invoke(self, **kwargs): if self.check_if_canceled("IterationItem processing"): return - self.set_output("item", arr[self._idx]) + current_item = arr[self._idx] + self.set_output("item", current_item) + # Keep `result` as a compatibility alias because existing DSL examples + # and downstream references may still consume IterationItem via `@result`. + self.set_output("result", current_item) self.set_output("index", self._idx) self._idx += 1 diff --git a/agent/component/loop.py b/agent/component/loop.py index 484dfae8256..9558e1001ef 100644 --- a/agent/component/loop.py +++ b/agent/component/loop.py @@ -56,7 +56,7 @@ def _invoke(self, **kwargs): for item in self._param.loop_variables: if any([not item.get("variable"), not item.get("input_mode"), not item.get("value"),not item.get("type")]): - assert "Loop Variable is not complete." + raise ValueError("Loop Variable is not complete.") if item["input_mode"]=="variable": self.set_output(item["variable"],self._canvas.get_variable_value(item["value"])) elif item["input_mode"]=="constant": diff --git a/agent/component/loopitem.py b/agent/component/loopitem.py index b656ea78948..0cfb500850d 100644 --- a/agent/component/loopitem.py +++ b/agent/component/loopitem.py @@ -64,6 +64,16 @@ def evaluate_condition(self,var, operator, value): elif operator == "not empty": return var != "" + elif isinstance(var, bool): + if operator == "is": + return var is value + elif operator == "is not": + return var is not value + elif operator == "empty": + return var is None + elif operator == "not empty": + return var is not None + elif isinstance(var, (int, float)): if operator == "=": return var == value @@ -82,16 +92,6 @@ def evaluate_condition(self,var, operator, value): elif operator == "not empty": return var is not None - elif isinstance(var, bool): - if operator == "is": - return var is value - elif operator == "is not": - return var is not value - elif operator == "empty": - return var is None - elif operator == "not empty": - return var is not None - elif isinstance(var, dict): if operator == "empty": return len(var) == 0 diff --git a/agent/component/message.py b/agent/component/message.py index a52741f6b36..5ab7c6ef526 100644 --- a/agent/component/message.py +++ b/agent/component/message.py @@ -161,7 +161,7 @@ def get_kwargs( if k in kwargs: continue v = v["value"] - if not v: + if v is None: v = "" ans = "" if isinstance(v, partial): diff --git a/agent/component/string_transform.py b/agent/component/string_transform.py index d298e5a1b8a..0b152f8f013 100644 --- a/agent/component/string_transform.py +++ b/agent/component/string_transform.py @@ -105,7 +105,7 @@ def _merge(self, kwargs:dict[str, str] = {}): pass for k,v in kwargs.items(): - if not v: + if v is None: v = "" script = re.sub(k, lambda match: v, script) diff --git a/agent/component/variable_assigner.py b/agent/component/variable_assigner.py index dd6182c7ce0..5b5e39a8259 100644 --- a/agent/component/variable_assigner.py +++ b/agent/component/variable_assigner.py @@ -48,7 +48,7 @@ def _invoke(self, **kwargs): else: for item in self._param.variables: if any([not item.get("variable"), not item.get("operator"), not item.get("parameter")]): - assert "Variable is not complete." + raise ValueError("Variable is not complete.") variable=item["variable"] operator=item["operator"] parameter=item["parameter"] @@ -92,12 +92,12 @@ def _clear(self,variable): return "" elif isinstance(variable,dict): return {} + elif isinstance(variable,bool): + return False elif isinstance(variable,int): return 0 elif isinstance(variable,float): return 0.0 - elif isinstance(variable,bool): - return False else: return None diff --git a/agent/sandbox/client.py b/agent/sandbox/client.py index 9ca51cc8e3a..daafb0d07f1 100644 --- a/agent/sandbox/client.py +++ b/agent/sandbox/client.py @@ -23,7 +23,6 @@ import json import logging -import os from typing import Dict, Any, Optional from api.db.services.system_settings_service import SystemSettingsService @@ -49,7 +48,6 @@ def get_provider_manager() -> ProviderManager: if _provider_manager is not None: return _provider_manager - # Initialize provider manager with system settings _provider_manager = ProviderManager() _load_provider_from_settings() @@ -61,7 +59,7 @@ def _load_provider_from_settings() -> None: Load sandbox provider from system settings and configure the provider manager. This function resolves the active provider type, then loads configuration - from system settings with environment overrides for that provider. + from system settings. """ global _provider_manager @@ -69,7 +67,7 @@ def _load_provider_from_settings() -> None: return try: - provider_type, provider_type_from_env = _resolve_provider_type() + provider_type = _resolve_provider_type() config = _load_provider_config(provider_type) # Import and instantiate the provider @@ -78,6 +76,7 @@ def _load_provider_from_settings() -> None: AliyunCodeInterpreterProvider, E2BProvider, LocalProvider, + SSHProvider, ) provider_classes = { @@ -85,11 +84,10 @@ def _load_provider_from_settings() -> None: "aliyun_codeinterpreter": AliyunCodeInterpreterProvider, "e2b": E2BProvider, "local": LocalProvider, + "ssh": SSHProvider, } if provider_type not in provider_classes: - if provider_type_from_env: - raise SandboxProviderConfigError(f"Unknown sandbox provider type: {provider_type}") logger.error(f"Unknown provider type: {provider_type}") return @@ -99,7 +97,7 @@ def _load_provider_from_settings() -> None: # Initialize the provider if not provider.initialize(config): message = f"Failed to initialize sandbox provider: {provider_type}. Config keys: {list(config.keys())}" - if provider_type == "local" or provider_type_from_env: + if provider_type in {"local", "ssh"}: raise SandboxProviderConfigError(message) logger.error(message) return @@ -114,8 +112,6 @@ def _load_provider_from_settings() -> None: logger.error(f"Failed to load sandbox provider from settings: {e}") import traceback traceback.print_exc() - - def _load_provider_config_from_settings(provider_type: str) -> Dict[str, Any]: provider_config_settings = SystemSettingsService.get_by_name(f"sandbox.{provider_type}") if not provider_config_settings: @@ -129,64 +125,15 @@ def _load_provider_config_from_settings(provider_type: str) -> Dict[str, Any]: return {} -def _resolve_provider_type() -> tuple[str, bool]: - provider_type = os.environ.get("SANDBOX_PROVIDER_TYPE", "").strip() - if provider_type: - return provider_type, True - +def _resolve_provider_type() -> str: provider_type_settings = SystemSettingsService.get_by_name("sandbox.provider_type") if not provider_type_settings: - raise RuntimeError( - "Sandbox provider type not configured. Please set 'sandbox.provider_type' in system settings." - ) - return provider_type_settings[0].value, False + return "self_managed" + return provider_type_settings[0].value def _load_provider_config(provider_type: str) -> Dict[str, Any]: - config = _load_provider_config_from_settings(provider_type) - env_config = _load_provider_config_from_env(provider_type) - if env_config: - config.update(env_config) - return config - - -def _load_provider_config_from_env(provider_type: str) -> Dict[str, Any]: - if provider_type == "local": - return _load_local_provider_config_from_env() - if provider_type == "self_managed": - return _load_self_managed_provider_config_from_env() - return {} - - -def _load_local_provider_config_from_env() -> Dict[str, Any]: - env_to_config = { - "SANDBOX_LOCAL_PYTHON_BIN": "python_bin", - "SANDBOX_LOCAL_NODE_BIN": "node_bin", - "SANDBOX_LOCAL_WORK_DIR": "work_dir", - "SANDBOX_LOCAL_TIMEOUT": "timeout", - "SANDBOX_LOCAL_MAX_MEMORY_MB": "max_memory_mb", - "SANDBOX_LOCAL_MAX_OUTPUT_BYTES": "max_output_bytes", - "SANDBOX_LOCAL_MAX_ARTIFACTS": "max_artifacts", - "SANDBOX_LOCAL_MAX_ARTIFACT_BYTES": "max_artifact_bytes", - } - config = {} - for env_name, config_name in env_to_config.items(): - if env_name in os.environ: - config[config_name] = os.environ[env_name] - return config - - -def _load_self_managed_provider_config_from_env() -> Dict[str, Any]: - host = os.environ.get("SANDBOX_HOST", "").strip() - port = os.environ.get("SANDBOX_EXECUTOR_MANAGER_PORT", "").strip() - pool_size = os.environ.get("SANDBOX_EXECUTOR_MANAGER_POOL_SIZE", "").strip() - - config = {} - if host: - config["endpoint"] = f"http://{host}:{port or '9385'}" - if pool_size: - config["pool_size"] = pool_size - return config + return _load_provider_config_from_settings(provider_type) def reload_provider() -> None: @@ -231,6 +178,14 @@ def execute_code( ) provider = provider_manager.get_provider() + provider_name = provider_manager.get_provider_name() or getattr(provider, "__class__", type(provider)).__name__ + + logger.info( + "CodeExec using sandbox provider '%s' (language=%s, timeout=%ss)", + provider_name, + language, + timeout, + ) # Create a sandbox instance instance = provider.create_instance(template=language) diff --git a/agent/sandbox/executor_manager/Dockerfile b/agent/sandbox/executor_manager/Dockerfile index 9444a848763..56c83384018 100644 --- a/agent/sandbox/executor_manager/Dockerfile +++ b/agent/sandbox/executor_manager/Dockerfile @@ -1,6 +1,10 @@ FROM python:3.11-slim-bookworm -RUN grep -rl 'deb.debian.org' /etc/apt/ | xargs sed -i 's|http[s]*://deb.debian.org|https://mirrors.tuna.tsinghua.edu.cn|g' && \ +ARG NEED_MIRROR=1 + +RUN if [ "$NEED_MIRROR" = 1 ]; then \ + grep -rl 'deb.debian.org' /etc/apt/ | xargs sed -i 's|http[s]*://deb.debian.org|https://mirrors.tuna.tsinghua.edu.cn|g'; \ + fi; \ apt-get update && \ apt-get install -y curl gcc && \ rm -rf /var/lib/apt/lists/* @@ -27,11 +31,11 @@ RUN set -eux; \ ln -sf /usr/local/bin/docker /usr/bin/docker COPY --from=ghcr.io/astral-sh/uv:0.7.5 /uv /uvx /bin/ -ENV UV_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple WORKDIR /app COPY . . -RUN uv pip install --system -r requirements.txt +RUN if [ "$NEED_MIRROR" = 1 ]; then export UV_INDEX_URL="https://pypi.tuna.tsinghua.edu.cn/simple"; else export UV_INDEX_URL="https://pypi.org/simple"; fi && \ + uv pip install --system -r requirements.txt CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "9385"] diff --git a/agent/sandbox/executor_manager/services/security.py b/agent/sandbox/executor_manager/services/security.py index 13a02ced2eb..f0323e747a2 100644 --- a/agent/sandbox/executor_manager/services/security.py +++ b/agent/sandbox/executor_manager/services/security.py @@ -26,7 +26,7 @@ class SecurePythonAnalyzer(ast.NodeVisitor): An AST-based analyzer for detecting unsafe Python code patterns. """ - DANGEROUS_IMPORTS = {"os", "subprocess", "sys", "shutil", "socket", "ctypes", "pickle", "threading", "multiprocessing", "asyncio", "http.client", "ftplib", "telnetlib"} + DANGEROUS_IMPORTS = {"os", "subprocess", "sys", "shutil", "socket", "ctypes", "pickle", "threading", "multiprocessing", "asyncio", "http.client", "ftplib", "telnetlib", "builtins"} DANGEROUS_CALLS = { "eval", @@ -77,6 +77,16 @@ def visit_Call(self, node: ast.Call): """Check for dangerous function calls.""" if isinstance(node.func, ast.Name) and node.func.id in self.DANGEROUS_CALLS: self.unsafe_items.append((f"Call: {node.func.id}", node.lineno)) + elif isinstance(node.func, ast.Attribute) and node.func.attr in self.DANGEROUS_CALLS: + # Surface the attribute-style match in the analyzer log so that + # incident response can grep for it just like the other unsafe-item + # findings; the bare append is invisible to operators. + logger.warning( + "[SafeCheck] Attribute-style dangerous call detected: %s (line %s)", + node.func.attr, + node.lineno, + ) + self.unsafe_items.append((f"Call: {node.func.attr}", node.lineno)) self.generic_visit(node) def visit_Attribute(self, node: ast.Attribute): @@ -154,9 +164,9 @@ def visit_Yield(self, node: ast.Yield): class SecureJavaScriptAnalyzer: DANGEROUS_PATTERNS = [ - (re.compile(r"""require\s*\(\s*['"]child_process['"]\s*\)"""), "Require: child_process"), - (re.compile(r"""require\s*\(\s*['"]fs['"]\s*\)"""), "Require: fs"), - (re.compile(r"""require\s*\(\s*['"]worker_threads['"]\s*\)"""), "Require: worker_threads"), + (re.compile(r"""require\s*\(\s*['"`]child_process['"`]\s*\)"""), "Require: child_process"), + (re.compile(r"""require\s*\(\s*['"`]fs['"`]\s*\)"""), "Require: fs"), + (re.compile(r"""require\s*\(\s*['"`]worker_threads['"`]\s*\)"""), "Require: worker_threads"), (re.compile(r"""\beval\s*\("""), "Call: eval"), (re.compile(r"""\bFunction\s*\("""), "Call: Function"), (re.compile(r"""\bprocess\s*\.\s*binding\s*\("""), "Call: process.binding"), diff --git a/agent/sandbox/providers/__init__.py b/agent/sandbox/providers/__init__.py index e7cfc2ddc9c..b67a982f3ec 100644 --- a/agent/sandbox/providers/__init__.py +++ b/agent/sandbox/providers/__init__.py @@ -25,6 +25,7 @@ Official Documentation: https://help.aliyun.com/zh/functioncompute/fc/sandbox-sandbox-code-interepreter - e2b.py: E2B provider implementation - local.py: Local process provider implementation +- ssh.py: Remote SSH provider implementation """ from .base import SandboxProvider, SandboxInstance, ExecutionResult, SandboxProviderConfigError @@ -33,6 +34,7 @@ from .aliyun_codeinterpreter import AliyunCodeInterpreterProvider from .e2b import E2BProvider from .local import LocalProvider +from .ssh import SSHProvider __all__ = [ "SandboxProvider", @@ -44,4 +46,5 @@ "AliyunCodeInterpreterProvider", "E2BProvider", "LocalProvider", + "SSHProvider", ] diff --git a/agent/sandbox/providers/local.py b/agent/sandbox/providers/local.py index b8057fa5b43..ed37cc57d00 100644 --- a/agent/sandbox/providers/local.py +++ b/agent/sandbox/providers/local.py @@ -41,11 +41,14 @@ ".svg", } - -def _env_enabled(name: str) -> bool: - return os.environ.get(name, "").strip().lower() in {"1", "true", "yes", "on"} - - +LOCAL_PYTHON_THREAD_ENV_VARS = ( + "OPENBLAS_NUM_THREADS", + "OMP_NUM_THREADS", + "MKL_NUM_THREADS", + "NUMEXPR_NUM_THREADS", + "BLIS_NUM_THREADS", + "VECLIB_MAXIMUM_THREADS", +) class LocalProvider(SandboxProvider): """ Execute code as a local child process. @@ -67,17 +70,14 @@ def __init__(self): self._instances: dict[str, Path] = {} def initialize(self, config: Dict[str, Any]) -> bool: - if not _env_enabled("SANDBOX_LOCAL_ENABLED"): - raise SandboxProviderConfigError("Local code execution is disabled. Set SANDBOX_LOCAL_ENABLED=true to enable it.") - - self.python_bin = str(self._resolve_config_value(config, "python_bin", "SANDBOX_LOCAL_PYTHON_BIN", "python3")) - self.node_bin = str(self._resolve_config_value(config, "node_bin", "SANDBOX_LOCAL_NODE_BIN", "node")) - self.work_dir = Path(self._resolve_config_value(config, "work_dir", "SANDBOX_LOCAL_WORK_DIR", "/tmp/ragflow-codeexec")).resolve() - self.timeout = int(self._resolve_config_value(config, "timeout", "SANDBOX_LOCAL_TIMEOUT", 30)) - self.max_memory_mb = int(self._resolve_config_value(config, "max_memory_mb", "SANDBOX_LOCAL_MAX_MEMORY_MB", 512)) - self.max_output_bytes = int(self._resolve_config_value(config, "max_output_bytes", "SANDBOX_LOCAL_MAX_OUTPUT_BYTES", 1024 * 1024)) - self.max_artifacts = int(self._resolve_config_value(config, "max_artifacts", "SANDBOX_LOCAL_MAX_ARTIFACTS", 20)) - self.max_artifact_bytes = int(self._resolve_config_value(config, "max_artifact_bytes", "SANDBOX_LOCAL_MAX_ARTIFACT_BYTES", 10 * 1024 * 1024)) + self.python_bin = str(config.get("python_bin", "python3")) + self.node_bin = str(config.get("node_bin", "node")) + self.work_dir = Path(str(config.get("work_dir", "/tmp/ragflow-codeexec"))).resolve() + self.timeout = int(config.get("timeout", 30)) + self.max_memory_mb = int(config.get("max_memory_mb", 512)) + self.max_output_bytes = int(config.get("max_output_bytes", 1024 * 1024)) + self.max_artifacts = int(config.get("max_artifacts", 20)) + self.max_artifact_bytes = int(config.get("max_artifact_bytes", 10 * 1024 * 1024)) self._validate_limits() self.work_dir.mkdir(parents=True, exist_ok=True, mode=0o700) @@ -185,14 +185,72 @@ def get_supported_languages(self) -> List[str]: @staticmethod def get_config_schema() -> Dict[str, Dict]: return { - "python_bin": {"type": "string", "required": False, "default": "python3"}, - "node_bin": {"type": "string", "required": False, "default": "node"}, - "work_dir": {"type": "string", "required": False, "default": "/tmp/ragflow-codeexec"}, - "timeout": {"type": "integer", "required": False, "default": 30}, - "max_memory_mb": {"type": "integer", "required": False, "default": 512}, - "max_output_bytes": {"type": "integer", "required": False, "default": 1048576}, - "max_artifacts": {"type": "integer", "required": False, "default": 20}, - "max_artifact_bytes": {"type": "integer", "required": False, "default": 10485760}, + "python_bin": { + "type": "string", + "required": False, + "default": "python3", + "label": "Python Binary", + "description": "Python executable used for local code execution.", + }, + "node_bin": { + "type": "string", + "required": False, + "default": "node", + "label": "Node.js Binary", + "description": "Node.js executable used for local JavaScript execution.", + }, + "work_dir": { + "type": "string", + "required": False, + "default": "/tmp/ragflow-codeexec", + "label": "Working Directory", + "description": "Directory used to store temporary scripts and artifacts on the current host.", + }, + "timeout": { + "type": "integer", + "required": False, + "default": 30, + "label": "Timeout (seconds)", + "description": "Maximum execution time for each local run. Unit: seconds.", + "min": 1, + "max": 600, + }, + "max_memory_mb": { + "type": "integer", + "required": False, + "default": 512, + "label": "Max Memory (MB)", + "description": "Address-space memory limit for the local child process. Unit: MB.", + "min": 1, + "max": 65536, + }, + "max_output_bytes": { + "type": "integer", + "required": False, + "default": 1048576, + "label": "Max Output (bytes)", + "description": "Maximum combined stdout and stderr size. Unit: bytes.", + "min": 1024, + "max": 10485760, + }, + "max_artifacts": { + "type": "integer", + "required": False, + "default": 20, + "label": "Max Artifacts", + "description": "Maximum number of files collected from the artifacts directory.", + "min": 0, + "max": 100, + }, + "max_artifact_bytes": { + "type": "integer", + "required": False, + "default": 10485760, + "label": "Max Artifact Size (bytes)", + "description": "Maximum size of a single artifact file. Unit: bytes.", + "min": 1024, + "max": 104857600, + }, } def _validate_limits(self) -> None: @@ -218,21 +276,19 @@ def _prepare_script(self, instance_dir: Path, language: str, code: str, args_jso return [self.node_bin, str(script_path)], script_path raise RuntimeError(f"Unsupported language for local provider: {language}") - @staticmethod - def _resolve_config_value(config: Dict[str, Any], key: str, env_name: str, default: Any) -> Any: - value = config.get(key) - if value is not None: - return value - return os.environ.get(env_name, default) - def _build_child_env(self, instance_dir: Path) -> dict[str, str]: - return { + env = { "HOME": str(instance_dir), "MPLBACKEND": "Agg", "PATH": os.environ.get("PATH", ""), "PYTHONUNBUFFERED": "1", "TMPDIR": str(instance_dir), } + for name in LOCAL_PYTHON_THREAD_ENV_VARS: + value = os.environ.get(name) + if value is not None: + env[name] = value + return env def _limit_child_process(self) -> None: import resource diff --git a/agent/sandbox/providers/self_managed.py b/agent/sandbox/providers/self_managed.py index 0e73e2f9e17..8b92d0b2c45 100644 --- a/agent/sandbox/providers/self_managed.py +++ b/agent/sandbox/providers/self_managed.py @@ -22,6 +22,7 @@ """ import base64 +import os import time import uuid from typing import Dict, Any, List, Optional @@ -40,10 +41,10 @@ class SelfManagedProvider(SandboxProvider): """ def __init__(self): - self.endpoint: str = "http://localhost:9385" + self.endpoint: str = "http://sandbox-executor-manager:9385" self.timeout: int = 30 self.max_retries: int = 3 - self.pool_size: int = 10 + self.pool_size: int = 3 self._initialized: bool = False def initialize(self, config: Dict[str, Any]) -> bool: @@ -52,7 +53,7 @@ def initialize(self, config: Dict[str, Any]) -> bool: Args: config: Configuration dictionary with keys: - - endpoint: HTTP endpoint (default: "http://localhost:9385") + - endpoint: HTTP endpoint (default: "http://sandbox-executor-manager:9385") - timeout: Request timeout in seconds (default: 30) - max_retries: Maximum retry attempts (default: 3) - pool_size: Container pool size for info (default: 10) @@ -60,30 +61,13 @@ def initialize(self, config: Dict[str, Any]) -> bool: Returns: True if initialization successful, False otherwise """ - self.endpoint = config.get("endpoint", "http://localhost:9385") + self.endpoint = config.get("endpoint", "http://sandbox-executor-manager:9385") self.timeout = config.get("timeout", 30) self.max_retries = config.get("max_retries", 3) - self.pool_size = config.get("pool_size", 10) + self.pool_size = config.get("executor_manager_pool_size", config.get("pool_size", 3)) # Validate endpoint is accessible if not self.health_check(): - # Try to fall back to SANDBOX_HOST from settings if we are using localhost - if "localhost" in self.endpoint or "127.0.0.1" in self.endpoint: - try: - from common import settings - if settings.SANDBOX_HOST and settings.SANDBOX_HOST not in self.endpoint: - original_endpoint = self.endpoint - self.endpoint = f"http://{settings.SANDBOX_HOST}:9385" - if self.health_check(): - import logging - logging.warning(f"Sandbox self_managed: Connected using settings.SANDBOX_HOST fallback: {self.endpoint} (original: {original_endpoint})") - self._initialized = True - return True - else: - self.endpoint = original_endpoint # Restore if fallback also fails - except ImportError: - pass - return False self._initialized = True @@ -270,9 +254,11 @@ def get_config_schema() -> Dict[str, Dict]: "type": "string", "required": True, "label": "Executor Manager Endpoint", - "placeholder": "http://localhost:9385", - "default": "http://localhost:9385", - "description": "HTTP endpoint of the executor_manager service" + "placeholder": "http://sandbox-executor-manager:9385", + "default": "http://sandbox-executor-manager:9385", + "description": "HTTP endpoint used by RAGFlow to call sandbox-executor-manager.", + "scope": "runtime", + "readonly": False, }, "timeout": { "type": "integer", @@ -281,26 +267,86 @@ def get_config_schema() -> Dict[str, Dict]: "default": 30, "min": 5, "max": 300, - "description": "HTTP request timeout for code execution" + "description": "Maximum request time for a single code execution call. Unit: seconds.", + "scope": "runtime", + "readonly": False, }, - "max_retries": { - "type": "integer", + "executor_manager_image": { + "type": "string", "required": False, - "label": "Max Retries", - "default": 3, - "min": 0, - "max": 10, - "description": "Maximum number of retry attempts for failed requests" + "label": "Executor Manager Image", + "default": os.getenv("SANDBOX_EXECUTOR_MANAGER_IMAGE", "infiniflow/sandbox-executor-manager:latest"), + "description": "Docker image used by sandbox-executor-manager.", + "scope": "deployment", + "readonly": True, }, - "pool_size": { + "executor_manager_pool_size": { "type": "integer", "required": False, "label": "Container Pool Size", - "default": 10, + "default": int(os.getenv("SANDBOX_EXECUTOR_MANAGER_POOL_SIZE", "3")), "min": 1, "max": 100, - "description": "Size of the container pool (configured in executor_manager)" - } + "description": "Container pool size used by sandbox-executor-manager.", + "scope": "deployment", + "readonly": True, + }, + "base_python_image": { + "type": "string", + "required": False, + "label": "Base Python Image", + "default": os.getenv("SANDBOX_BASE_PYTHON_IMAGE", "infiniflow/sandbox-base-python:latest"), + "description": "Python runtime image used by executor-managed containers.", + "scope": "deployment", + "readonly": True, + }, + "base_nodejs_image": { + "type": "string", + "required": False, + "label": "Base Node.js Image", + "default": os.getenv("SANDBOX_BASE_NODEJS_IMAGE", "infiniflow/sandbox-base-nodejs:latest"), + "description": "Node.js runtime image used by executor-managed containers.", + "scope": "deployment", + "readonly": True, + }, + "executor_manager_port": { + "type": "integer", + "required": False, + "label": "Executor Manager Port", + "default": int(os.getenv("SANDBOX_EXECUTOR_MANAGER_PORT", "9385")), + "min": 1, + "max": 65535, + "description": "Host port exposed by sandbox-executor-manager.", + "scope": "deployment", + "readonly": True, + }, + "enable_seccomp": { + "type": "boolean", + "required": False, + "label": "Enable Seccomp", + "default": os.getenv("SANDBOX_ENABLE_SECCOMP", "false").lower() == "true", + "description": "Whether sandbox-executor-manager starts containers with seccomp enabled.", + "scope": "deployment", + "readonly": True, + }, + "max_memory": { + "type": "string", + "required": False, + "label": "Max Memory", + "default": os.getenv("SANDBOX_MAX_MEMORY", "256m"), + "description": "Memory limit applied to each sandbox container. Common format: 256m or 1g.", + "scope": "deployment", + "readonly": True, + }, + "sandbox_timeout": { + "type": "string", + "required": False, + "label": "Sandbox Timeout", + "default": os.getenv("SANDBOX_TIMEOUT", "10s"), + "description": "Executor-manager container timeout for each sandbox run. Common format: 10s or 1m.", + "scope": "deployment", + "readonly": True, + }, } def _normalize_language(self, language: str) -> str: @@ -347,7 +393,7 @@ def validate_config(self, config: dict) -> tuple[bool, Optional[str]]: return False, f"Invalid endpoint format: {endpoint}. Must start with http:// or https://" # Validate pool_size is positive - pool_size = config.get("pool_size", 10) + pool_size = config.get("executor_manager_pool_size", config.get("pool_size", 3)) if isinstance(pool_size, int) and pool_size <= 0: return False, "Pool size must be greater than 0" diff --git a/agent/sandbox/providers/ssh.py b/agent/sandbox/providers/ssh.py new file mode 100644 index 00000000000..131e4ae8c05 --- /dev/null +++ b/agent/sandbox/providers/ssh.py @@ -0,0 +1,664 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from __future__ import annotations + +import base64 +import io +import json +import mimetypes +import os +import posixpath +import shlex +import stat +import time +import uuid +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +from agent.sandbox.result_protocol import ( + build_javascript_wrapper, + build_python_wrapper, + extract_structured_result, +) +from .base import ( + ExecutionResult, + SandboxInstance, + SandboxProvider, + SandboxProviderConfigError, +) + +if TYPE_CHECKING: + import paramiko + + +ALLOWED_ARTIFACT_EXTENSIONS = { + ".csv", + ".html", + ".jpeg", + ".jpg", + ".json", + ".pdf", + ".png", + ".svg", +} + + +class SSHProvider(SandboxProvider): + """Execute code on a remote host through SSH.""" + + def __init__(self): + self.host = "" + self.port = 22 + self.username = "" + self.password = "" + self.private_key = "" + self.passphrase = "" + self.python_bin = "python3" + self.node_bin = "node" + self.work_dir = "/tmp" + self.timeout = 30 + self.max_output_bytes = 1024 * 1024 + self.max_artifacts = 20 + self.max_artifact_bytes = 10 * 1024 * 1024 + self._initialized = False + self._instances: dict[str, dict[str, Any]] = {} + + def initialize(self, config: Dict[str, Any]) -> bool: + self.host = str(config.get("host", "")).strip() + self.port = int(config.get("port", 22) or 22) + self.username = str(config.get("username", "")).strip() + self.password = str(config.get("password", "") or "") + self.private_key = str(config.get("private_key", "") or "") + self.passphrase = str(config.get("passphrase", "") or "") + self.python_bin = str(config.get("python_bin", "python3") or "python3").strip() or "python3" + self.node_bin = str(config.get("node_bin", "node") or "node").strip() or "node" + self.work_dir = str(config.get("work_dir", "/tmp") or "/tmp").strip() or "/tmp" + self.timeout = int(config.get("timeout", 30) or 30) + self.max_output_bytes = int(config.get("max_output_bytes", 1024 * 1024) or 1024 * 1024) + self.max_artifacts = int(config.get("max_artifacts", 20) or 20) + self.max_artifact_bytes = int(config.get("max_artifact_bytes", 10 * 1024 * 1024) or 10 * 1024 * 1024) + + is_valid, error_message = self.validate_config( + { + "host": self.host, + "port": self.port, + "username": self.username, + "password": self.password, + "private_key": self.private_key, + "passphrase": self.passphrase, + "python_bin": self.python_bin, + "node_bin": self.node_bin, + "work_dir": self.work_dir, + "timeout": self.timeout, + "max_output_bytes": self.max_output_bytes, + "max_artifacts": self.max_artifacts, + "max_artifact_bytes": self.max_artifact_bytes, + } + ) + if not is_valid: + raise SandboxProviderConfigError(error_message or "Invalid SSH provider configuration.") + + self._assert_connectivity() + + self._initialized = True + return True + + def create_instance(self, template: str = "python") -> SandboxInstance: + if not self._initialized: + raise RuntimeError("Provider not initialized. Call initialize() first.") + + language = self._normalize_language(template) + client = self._create_ssh_client() + sftp = client.open_sftp() + + try: + remote_work_dir = self._create_remote_workspace(client) + stdout, stderr, exit_code = self._run_remote_command( + client, + f"mkdir -p {shlex.quote(posixpath.join(remote_work_dir, 'artifacts'))}", + timeout=min(self.timeout, 10), + ) + if exit_code != 0: + raise RuntimeError( + f"Failed to create remote artifacts directory: {stderr or stdout or 'unknown error'}" + ) + except Exception: + sftp.close() + client.close() + raise + + instance_id = str(uuid.uuid4()) + self._instances[instance_id] = { + "client": client, + "sftp": sftp, + "remote_work_dir": remote_work_dir, + "language": language, + } + + return SandboxInstance( + instance_id=instance_id, + provider="ssh", + status="running", + metadata={"language": language, "remote_work_dir": remote_work_dir}, + ) + + def execute_code( + self, + instance_id: str, + code: str, + language: str, + timeout: int = 10, + arguments: Optional[Dict[str, Any]] = None, + ) -> ExecutionResult: + if not self._initialized: + raise RuntimeError("Provider not initialized. Call initialize() first.") + if instance_id not in self._instances: + raise RuntimeError(f"Unknown SSH sandbox instance: {instance_id}") + + normalized_lang = self._normalize_language(language) + instance = self._instances[instance_id] + client: paramiko.SSHClient = instance["client"] + sftp: paramiko.SFTPClient = instance["sftp"] + remote_work_dir: str = instance["remote_work_dir"] + + args_json = json.dumps(arguments or {}, ensure_ascii=False) + remote_script_path, command = self._upload_script( + sftp=sftp, + remote_work_dir=remote_work_dir, + language=normalized_lang, + code=code, + args_json=args_json, + ) + + requested_timeout = self.timeout if timeout is None else int(timeout) + if requested_timeout <= 0: + raise RuntimeError(f"Execution timeout must be greater than 0 seconds, got {requested_timeout}.") + exec_timeout = min(requested_timeout, self.timeout) + + start_time = time.time() + stdout, stderr, exit_code = self._run_remote_command(client, command, timeout=exec_timeout) + execution_time = time.time() - start_time + + self._validate_output_size(stdout, stderr) + stdout, structured_result = extract_structured_result(stdout) + + return ExecutionResult( + stdout=stdout, + stderr=stderr, + exit_code=exit_code, + execution_time=execution_time, + metadata={ + "instance_id": instance_id, + "language": normalized_lang, + "script_path": remote_script_path, + "remote_work_dir": remote_work_dir, + "status": "ok" if exit_code == 0 else "error", + "timeout": exec_timeout, + "command": command, + "artifacts": self._collect_artifacts( + sftp, posixpath.join(remote_work_dir, "artifacts") + ), + "result_present": structured_result.get("present", False), + "result_value": structured_result.get("value"), + "result_type": structured_result.get("type"), + }, + ) + + def destroy_instance(self, instance_id: str) -> bool: + if not self._initialized: + raise RuntimeError("Provider not initialized. Call initialize() first.") + if instance_id not in self._instances: + return True + + instance = self._instances.pop(instance_id) + client: paramiko.SSHClient = instance["client"] + sftp: paramiko.SFTPClient = instance["sftp"] + remote_work_dir: str = instance["remote_work_dir"] + + cleanup_error: Optional[Exception] = None + try: + stdout, stderr, exit_code = self._run_remote_command( + client, + f"rm -rf {shlex.quote(remote_work_dir)}", + timeout=min(self.timeout, 10), + ) + if exit_code != 0: + raise RuntimeError(stderr or stdout or "unknown error") + except Exception as exc: + cleanup_error = exc + finally: + try: + sftp.close() + finally: + client.close() + + if cleanup_error is not None: + raise RuntimeError(f"Failed to clean remote workspace {remote_work_dir}: {cleanup_error}") + return True + + def health_check(self) -> bool: + try: + self._assert_connectivity() + return True + except Exception: + return False + + def _assert_connectivity(self) -> None: + try: + client = self._create_ssh_client() + try: + _, stderr, exit_code = self._run_remote_command( + client, + "true", + timeout=min(self.timeout, 10), + ) + if exit_code != 0: + raise SandboxProviderConfigError( + f"SSH connectivity check failed on {self.username}@{self.host}:{self.port}: " + f"{stderr or 'remote command returned non-zero exit status'}" + ) + finally: + client.close() + except SandboxProviderConfigError: + raise + except Exception as exc: + raise SandboxProviderConfigError( + f"Failed to connect to SSH host {self.username}@{self.host}:{self.port}: {exc}" + ) from exc + + def get_supported_languages(self) -> List[str]: + return ["python", "javascript", "nodejs"] + + @staticmethod + def get_config_schema() -> Dict[str, Dict]: + return { + "host": { + "type": "string", + "required": True, + "label": "SSH Host", + "placeholder": "192.168.1.10", + "description": "Remote host that will execute generated code.", + }, + "port": { + "type": "integer", + "required": True, + "label": "SSH Port", + "default": 22, + "min": 1, + "max": 65535, + "description": "SSH port on the remote host.", + }, + "username": { + "type": "string", + "required": True, + "label": "SSH Username", + "placeholder": "ragflow", + "description": "Username used to connect to the remote host.", + }, + "password": { + "type": "string", + "required": False, + "label": "SSH Password", + "secret": True, + "placeholder": "Optional when using a private key", + "description": "Password-based SSH authentication.", + }, + "private_key": { + "type": "string", + "required": False, + "label": "SSH Private Key", + "secret": True, + "multiline": True, + "placeholder": "Paste PEM content or enter a local file path", + "description": "Private key PEM content or a readable private key path on the RAGFlow host.", + }, + "passphrase": { + "type": "string", + "required": False, + "label": "Private Key Passphrase", + "secret": True, + "placeholder": "Optional", + "description": "Passphrase for the private key if it is encrypted.", + }, + "python_bin": { + "type": "string", + "required": False, + "default": "python3", + "label": "Python Binary", + "description": "Python executable used for remote code execution.", + }, + "node_bin": { + "type": "string", + "required": False, + "default": "node", + "label": "Node.js Binary", + "description": "Node.js executable used for remote JavaScript execution.", + }, + "work_dir": { + "type": "string", + "required": False, + "label": "Remote Workspace Root", + "default": "/tmp", + "placeholder": "/tmp", + "description": "Writable remote directory used to create a temporary workspace.", + }, + "timeout": { + "type": "integer", + "required": False, + "label": "Timeout (seconds)", + "default": 30, + "min": 1, + "max": 600, + "description": "Maximum SSH execution time for a single run.", + }, + "max_output_bytes": { + "type": "integer", + "required": False, + "label": "Max Output Bytes", + "default": 1048576, + "min": 1024, + "max": 10485760, + "description": "Maximum combined stdout and stderr size.", + }, + "max_artifacts": { + "type": "integer", + "required": False, + "label": "Max Artifacts", + "default": 20, + "min": 0, + "max": 100, + "description": "Maximum number of files collected from the remote artifacts directory.", + }, + "max_artifact_bytes": { + "type": "integer", + "required": False, + "label": "Max Artifact Bytes", + "default": 10485760, + "min": 1024, + "max": 104857600, + "description": "Maximum size of a single artifact file in bytes.", + }, + } + + def validate_config(self, config: Dict[str, Any]) -> tuple[bool, Optional[str]]: + host = str(config.get("host", "") or "").strip() + username = str(config.get("username", "") or "").strip() + password = str(config.get("password", "") or "") + private_key = str(config.get("private_key", "") or "") + python_bin = str(config.get("python_bin", "python3") or "python3").strip() + node_bin = str(config.get("node_bin", "node") or "node").strip() + + if not host: + return False, "SSH host is required" + if not username: + return False, "SSH username is required" + if not password and not private_key: + return False, "Either password or private_key must be provided" + if not python_bin: + return False, "Python binary is required" + if not node_bin: + return False, "Node.js binary is required" + + try: + port = int(config.get("port", 22) or 22) + except (TypeError, ValueError): + return False, "SSH port must be an integer" + if port <= 0 or port > 65535: + return False, "SSH port must be between 1 and 65535" + + for key in ("timeout", "max_output_bytes", "max_artifacts", "max_artifact_bytes"): + try: + value = int(config.get(key, 0) or 0) + except (TypeError, ValueError): + return False, f"{key} must be an integer" + if key == "max_artifacts": + if value < 0: + return False, "max_artifacts must be greater than or equal to 0" + elif value <= 0: + return False, f"{key} must be greater than 0" + + return True, None + + def _create_ssh_client(self) -> paramiko.SSHClient: + paramiko = _get_paramiko_module() + client = paramiko.SSHClient() + client.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + + connect_kwargs: dict[str, Any] = { + "hostname": self.host, + "port": self.port, + "username": self.username, + "timeout": self.timeout, + "banner_timeout": self.timeout, + "auth_timeout": self.timeout, + "look_for_keys": False, + "allow_agent": False, + } + if self.private_key: + connect_kwargs["pkey"] = self._load_private_key() + if self.password: + connect_kwargs["password"] = self.password + + client.connect(**connect_kwargs) + return client + + def _load_private_key(self) -> paramiko.PKey: + paramiko = _get_paramiko_module() + loaders = ( + paramiko.RSAKey, + paramiko.Ed25519Key, + paramiko.ECDSAKey, + paramiko.DSSKey, + ) + errors: list[str] = [] + private_key_value = self.private_key.strip() + passphrase = self.passphrase or None + + if os.path.exists(private_key_value): + for key_cls in loaders: + try: + return key_cls.from_private_key_file(private_key_value, password=passphrase) + except Exception as exc: + errors.append(str(exc)) + else: + for key_cls in loaders: + try: + return key_cls.from_private_key(io.StringIO(private_key_value), password=passphrase) + except Exception as exc: + errors.append(str(exc)) + + raise SandboxProviderConfigError( + "Failed to load SSH private key. " + "; ".join(error for error in errors if error) + ) + + def _create_remote_workspace(self, client: paramiko.SSHClient) -> str: + base_dir = self.work_dir.rstrip("/") or "/tmp" + template = posixpath.join(base_dir, "ragflow-codeexec.XXXXXX") + stdout, stderr, exit_code = self._run_remote_command( + client, + f"mkdir -p {shlex.quote(base_dir)} && mktemp -d {shlex.quote(template)}", + timeout=min(self.timeout, 10), + ) + if exit_code != 0: + raise RuntimeError( + f"Failed to create remote workspace on {self.host}: {stderr or stdout or 'unknown error'}" + ) + + remote_work_dir = stdout.strip().splitlines()[-1] if stdout.strip() else "" + if not remote_work_dir: + raise RuntimeError("Remote workspace creation did not return a path.") + return remote_work_dir + + def _upload_script( + self, + sftp: paramiko.SFTPClient, + remote_work_dir: str, + language: str, + code: str, + args_json: str, + ) -> tuple[str, str]: + if language == "python": + script_name = "main.py" + script_content = build_python_wrapper(code, args_json) + elif language in {"javascript", "nodejs"}: + script_name = "main.js" + script_content = build_javascript_wrapper(code, args_json) + else: + raise RuntimeError(f"Unsupported language for SSH provider: {language}") + + remote_script_path = posixpath.join(remote_work_dir, script_name) + with sftp.file(remote_script_path, "w") as remote_file: + remote_file.write(script_content) + + command = self._build_execution_command(remote_work_dir, remote_script_path, language) + return remote_script_path, command + + def _build_execution_command(self, remote_work_dir: str, remote_script_path: str, language: str) -> str: + normalized_lang = self._normalize_language(language) + if normalized_lang == "python": + executable = self.python_bin + elif normalized_lang == "nodejs": + executable = self.node_bin + else: + raise RuntimeError(f"Unsupported language for SSH provider: {language}") + + return ( + f"cd {shlex.quote(remote_work_dir)} && " + f"{shlex.quote(executable)} {shlex.quote(remote_script_path)}" + ) + + def _run_remote_command( + self, + client: paramiko.SSHClient, + command: str, + timeout: int, + ) -> tuple[str, str, int]: + stdin, stdout_stream, stderr_stream = client.exec_command(command, timeout=timeout) + stdin.close() + channel = stdout_stream.channel + + stdout_chunks: list[bytes] = [] + stderr_chunks: list[bytes] = [] + deadline = time.time() + timeout + + while True: + while channel.recv_ready(): + stdout_chunks.append(channel.recv(65536)) + while channel.recv_stderr_ready(): + stderr_chunks.append(channel.recv_stderr(65536)) + + if channel.exit_status_ready(): + break + if time.time() > deadline: + channel.close() + raise TimeoutError(f"Execution timed out after {timeout} seconds") + time.sleep(0.1) + + while channel.recv_ready(): + stdout_chunks.append(channel.recv(65536)) + while channel.recv_stderr_ready(): + stderr_chunks.append(channel.recv_stderr(65536)) + + exit_code = channel.recv_exit_status() + stdout = b"".join(stdout_chunks).decode("utf-8", errors="replace") + stderr = b"".join(stderr_chunks).decode("utf-8", errors="replace") + return stdout, stderr, exit_code + + def _validate_output_size(self, stdout: str, stderr: str) -> None: + output_size = len((stdout or "").encode("utf-8")) + len((stderr or "").encode("utf-8")) + if output_size > self.max_output_bytes: + raise RuntimeError(f"SSH execution output exceeded {self.max_output_bytes} bytes.") + + def _collect_artifacts( + self, + sftp: paramiko.SFTPClient, + artifacts_dir: str, + ) -> list[dict[str, Any]]: + artifacts: list[dict[str, Any]] = [] + self._collect_artifacts_recursive(sftp, artifacts_dir, "", artifacts) + return artifacts + + def _collect_artifacts_recursive( + self, + sftp: paramiko.SFTPClient, + current_dir: str, + relative_dir: str, + artifacts: list[dict[str, Any]], + ) -> None: + try: + entries = sftp.listdir_attr(current_dir) + except FileNotFoundError: + return + + for entry in sorted(entries, key=lambda item: item.filename): + name = entry.filename + remote_path = posixpath.join(current_dir, name) + relative_path = posixpath.join(relative_dir, name) if relative_dir else name + mode = entry.st_mode + if mode is None: + mode = sftp.lstat(remote_path).st_mode + if mode is None: + raise RuntimeError(f"Unable to determine artifact entry type: {relative_path}") + + if stat.S_ISLNK(mode): + raise RuntimeError(f"Artifact symlinks are not allowed: {relative_path}") + if stat.S_ISDIR(mode): + self._collect_artifacts_recursive(sftp, remote_path, relative_path, artifacts) + continue + if not stat.S_ISREG(mode): + raise RuntimeError(f"Unsupported artifact entry: {relative_path}") + + if len(artifacts) >= self.max_artifacts: + raise RuntimeError(f"SSH execution produced more than {self.max_artifacts} artifacts.") + + size = int(entry.st_size or 0) + if size > self.max_artifact_bytes: + raise RuntimeError(f"Artifact exceeds {self.max_artifact_bytes} bytes: {relative_path}") + + ext = os.path.splitext(name)[1].lower() + if ext not in ALLOWED_ARTIFACT_EXTENSIONS: + raise RuntimeError(f"Unsupported artifact type: {relative_path}") + + with sftp.file(remote_path, "rb") as artifact_file: + content = artifact_file.read() + + artifacts.append( + { + "name": relative_path, + "content_b64": base64.b64encode(content).decode("ascii"), + "mime_type": mimetypes.guess_type(name)[0] or "application/octet-stream", + "size": size, + } + ) + + @staticmethod + def _normalize_language(language: str) -> str: + lang_lower = (language or "python").lower() + if lang_lower in {"python", "python3"}: + return "python" + if lang_lower in {"javascript", "nodejs"}: + return "nodejs" + return lang_lower + + +def _get_paramiko_module(): + try: + import paramiko + except ImportError as exc: + raise SandboxProviderConfigError( + "paramiko is required for the SSH sandbox provider. Install the project dependencies to enable it." + ) from exc + return paramiko diff --git a/agent/sandbox/sandbox_base_image/nodejs/Dockerfile b/agent/sandbox/sandbox_base_image/nodejs/Dockerfile index fe7b19f7733..21432b818aa 100644 --- a/agent/sandbox/sandbox_base_image/nodejs/Dockerfile +++ b/agent/sandbox/sandbox_base_image/nodejs/Dockerfile @@ -1,6 +1,12 @@ FROM node:24.13-bookworm-slim -RUN npm config set registry https://registry.npmmirror.com +ARG NEED_MIRROR=1 + +RUN if [ "$NEED_MIRROR" = 1 ]; then \ + npm config set registry https://registry.npmmirror.com; \ + else \ + npm config set registry https://registry.npmjs.org; \ + fi # RUN grep -rl 'deb.debian.org' /etc/apt/ | xargs sed -i 's|http[s]*://deb.debian.org|https://mirrors.ustc.edu.cn|g' && \ # apt-get update && \ diff --git a/agent/sandbox/sandbox_base_image/python/Dockerfile b/agent/sandbox/sandbox_base_image/python/Dockerfile index 410aad8d15a..585d5c26768 100644 --- a/agent/sandbox/sandbox_base_image/python/Dockerfile +++ b/agent/sandbox/sandbox_base_image/python/Dockerfile @@ -1,7 +1,8 @@ FROM python:3.11-slim-bookworm +ARG NEED_MIRROR=1 + COPY --from=ghcr.io/astral-sh/uv:0.7.5 /uv /uvx /bin/ -ENV UV_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple ENV MPLBACKEND=Agg ENV MPLCONFIGDIR=/tmp/matplotlib ENV MATPLOTLIBRC=/usr/local/etc/matplotlibrc @@ -9,12 +10,18 @@ ENV MATPLOTLIBRC=/usr/local/etc/matplotlibrc COPY requirements.txt . COPY matplotlibrc /usr/local/etc/matplotlibrc -RUN grep -rl 'deb.debian.org' /etc/apt/ | xargs sed -i 's|http[s]*://deb.debian.org|https://mirrors.tuna.tsinghua.edu.cn|g' && \ +RUN if [ "$NEED_MIRROR" = 1 ]; then \ + grep -rl 'deb.debian.org' /etc/apt/ | xargs sed -i 's|http[s]*://deb.debian.org|https://mirrors.tuna.tsinghua.edu.cn|g'; \ + export UV_INDEX_URL="https://pypi.tuna.tsinghua.edu.cn/simple"; \ + else \ + export UV_INDEX_URL="https://pypi.org/simple"; \ + fi; \ apt-get update && \ - apt-get install -y curl gcc && \ + apt-get install -y --no-install-recommends curl gcc && \ mkdir -p /tmp/matplotlib && \ - uv pip install --system -r requirements.txt + uv pip install --system -r requirements.txt && \ + rm -rf /var/lib/apt/lists/* WORKDIR /workspace -CMD ["sleep", "infinity"] +CMD ["sleep", "infinity"] \ No newline at end of file diff --git a/agent/sandbox/tests/test_security.py b/agent/sandbox/tests/test_security.py index ed096894e44..dc8d9f80630 100644 --- a/agent/sandbox/tests/test_security.py +++ b/agent/sandbox/tests/test_security.py @@ -45,6 +45,60 @@ def test_javascript_eval_is_rejected(): assert any("eval" in issue.lower() for issue, _ in issues) +def test_javascript_child_process_template_literal_is_rejected(): + """Template literal backticks bypass single/double-quote regex patterns.""" + is_safe, issues = analyze_code_security( + "const cp = require(`child_process`); async function main() { return 'ok'; }", + SupportLanguage.NODEJS, + ) + + assert is_safe is False + assert any("child_process" in issue for issue, _ in issues) + + +def test_javascript_fs_template_literal_is_rejected(): + is_safe, issues = analyze_code_security( + "const fs = require(`fs`); async function main() { return fs.readFileSync('/etc/passwd', 'utf8'); }", + SupportLanguage.NODEJS, + ) + + assert is_safe is False + assert any("fs" in issue for issue, _ in issues) + + +def test_python_builtins_import_is_rejected(): + """builtins module gives access to eval/exec and must be blocked.""" + is_safe, issues = analyze_code_security( + "import builtins\ndef main():\n builtins.eval('1+1')", + SupportLanguage.PYTHON, + ) + + assert is_safe is False + # Pin the specific reason: rejection must come from the new ``builtins`` + # entry in ``DANGEROUS_IMPORTS``, not from some unrelated parse error. + assert any("builtins" in issue for issue, _ in issues), ( + f"expected an issue mentioning 'builtins', got {issues!r}" + ) + + +def test_python_attribute_eval_call_is_rejected(): + """Attribute-style dangerous calls (builtins.eval) must be caught.""" + is_safe, issues = analyze_code_security( + "import builtins\ndef main():\n builtins.exec('import os')", + SupportLanguage.PYTHON, + ) + + assert is_safe is False + # Pin the specific reason: rejection must come from the new + # ``ast.Attribute`` branch in ``visit_Call`` flagging the ``exec`` call, + # not from the ``import builtins`` line above. We assert ``exec`` is in at + # least one finding so the test fails if visit_Call's attribute branch is + # ever reverted. + assert any("exec" in issue for issue, _ in issues), ( + f"expected an issue mentioning 'exec', got {issues!r}" + ) + + def test_javascript_safe_code_still_passes(): is_safe, issues = analyze_code_security( "async function main(args) { return { answer: args.value ?? null }; }", diff --git a/agent/sandbox/uv.lock b/agent/sandbox/uv.lock index 77e39f36ae3..10ceb268a23 100644 --- a/agent/sandbox/uv.lock +++ b/agent/sandbox/uv.lock @@ -383,11 +383,11 @@ wheels = [ [[package]] name = "urllib3" -version = "2.6.3" +version = "2.7.0" source = { registry = "https://pypi.tuna.tsinghua.edu.cn/simple" } -sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/c7/24/5f1b3bdffd70275f6661c76461e25f024d5a38a46f04aaca912426a2b1d3/urllib3-2.6.3.tar.gz", hash = "sha256:1b62b6884944a57dbe321509ab94fd4d3b307075e0c2eae991ac71ee15ad38ed", size = 435556, upload-time = "2026-01-07T16:24:43.925Z" } +sdist = { url = "https://pypi.tuna.tsinghua.edu.cn/packages/53/0c/06f8b233b8fd13b9e5ee11424ef85419ba0d8ba0b3138bf360be2ff56953/urllib3-2.7.0.tar.gz", hash = "sha256:231e0ec3b63ceb14667c67be60f2f2c40a518cb38b03af60abc813da26505f4c", size = 433602, upload-time = "2026-05-07T16:13:18.596Z" } wheels = [ - { url = "https://pypi.tuna.tsinghua.edu.cn/packages/39/08/aaaad47bc4e9dc8c725e68f9d04865dbcb2052843ff09c97b08904852d84/urllib3-2.6.3-py3-none-any.whl", hash = "sha256:bf272323e553dfb2e87d9bfd225ca7b0f467b919d7bbd355436d3fd37cb0acd4", size = 131584, upload-time = "2026-01-07T16:24:42.685Z" }, + { url = "https://pypi.tuna.tsinghua.edu.cn/packages/7f/3e/5db95bcf282c52709639744ca2a8b149baccf648e39c8cc87553df9eae0c/urllib3-2.7.0-py3-none-any.whl", hash = "sha256:9fb4c81ebbb1ce9531cce37674bbc6f1360472bc18ca9a553ede278ef7276897", size = 131087, upload-time = "2026-05-07T16:13:17.151Z" }, ] [[package]] diff --git a/agent/tools/base.py b/agent/tools/base.py index 194b47fceec..71cf2c593e9 100644 --- a/agent/tools/base.py +++ b/agent/tools/base.py @@ -19,11 +19,12 @@ from copy import deepcopy import asyncio from functools import partial +from collections.abc import Mapping from typing import TypedDict, List, Any from agent.component.base import ComponentParamBase, ComponentBase from common.misc_utils import hash_str2int from rag.prompts.generator import kb_prompt -from common.mcp_tool_call_conn import MCPToolCallSession, ToolCallSession +from common.mcp_tool_call_conn import MCPToolBinding, MCPToolCallSession, ToolCallSession from timeit import default_timer as timer @@ -52,16 +53,20 @@ def __init__(self, tools_map: dict[str, object], callback: partial): self.tools_map = tools_map self.callback = callback - def tool_call(self, name: str, arguments: dict[str, Any]) -> Any: - return asyncio.run(self.tool_call_async(name, arguments)) + def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> Any: + return asyncio.run(self.tool_call_async(name, arguments, request_timeout=timeout)) - async def tool_call_async(self, name: str, arguments: dict[str, Any]) -> Any: + async def tool_call_async(self, name: str, arguments: dict[str, Any], request_timeout: float | int = 10) -> Any: assert name in self.tools_map, f"LLM tool {name} does not exist" logging.info(f"[ToolCall] invoke name={name} arguments={str(arguments)[:200]}") + if not isinstance(arguments, Mapping): + raise TypeError(f"Tool arguments for {name} must be an object, got {type(arguments).__name__}") st = timer() tool_obj = self.tools_map[name] - if isinstance(tool_obj, MCPToolCallSession): - resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, 60) + if isinstance(tool_obj, MCPToolBinding): + resp = await thread_pool_exec(tool_obj.session.tool_call, tool_obj.original_name, arguments, request_timeout) + elif isinstance(tool_obj, MCPToolCallSession): + resp = await thread_pool_exec(tool_obj.tool_call, name, arguments, request_timeout) elif hasattr(tool_obj, "invoke_async") and asyncio.iscoroutinefunction(tool_obj.invoke_async): resp = await tool_obj.invoke_async(**arguments) else: diff --git a/agent/tools/code_exec.py b/agent/tools/code_exec.py index ece67d97fc9..3133784e21c 100644 --- a/agent/tools/code_exec.py +++ b/agent/tools/code_exec.py @@ -24,7 +24,7 @@ from typing import Optional from pydantic import BaseModel, Field, field_validator -from strenum import StrEnum +from enum import StrEnum from agent.tools.base import ToolBase, ToolMeta, ToolParamBase from api.db.services.file_service import FileService @@ -37,6 +37,7 @@ { "content", "actual_type", + "attachments", "_ERROR", "_ARTIFACTS", "_ATTACHMENT_CONTENT", @@ -312,7 +313,10 @@ def main() -> dict: self.lang = Language.PYTHON.value self.script = 'def main(arg1: str, arg2: str) -> dict: return {"result": arg1 + arg2}' self.arguments = {} - self.outputs = {"result": {"value": "", "type": "object"}} + self.outputs = { + "result": {"value": "", "type": "object"}, + "attachments": {"value": [], "type": "Array"}, + } def check(self): self.check_valid_value(self.lang, "Support languages", ["python", "python3", "nodejs", "javascript"]) @@ -357,11 +361,21 @@ def _execute_code(self, language: str, code: str, arguments: dict): # Try using the new sandbox provider system first try: from agent.sandbox.client import execute_code as sandbox_execute_code + from agent.sandbox.client import get_provider_info + from agent.sandbox.client import reload_provider from agent.sandbox.providers.base import SandboxProviderConfigError if self.check_if_canceled("CodeExec execution"): return + reload_provider() + provider_info = get_provider_info() + provider_type = provider_info.get("provider_type") or "unknown" + logging.info( + f"[CodeExec]: dispatching execution to sandbox provider '{provider_type}' " + f"(language={language}, timeout={timeout_seconds}s)" + ) + # Execute code using the provider system result = sandbox_execute_code(code=code, language=language, timeout=timeout_seconds, arguments=arguments) @@ -372,7 +386,7 @@ def _execute_code(self, language: str, code: str, arguments: dict): return self._process_execution_result( result.stdout, result.stderr, - "Provider system", + f"Provider system ({provider_type})", artifacts, execution_metadata=result.metadata, ) @@ -384,10 +398,8 @@ def _execute_code(self, language: str, code: str, arguments: dict): # Provider modules are unavailable, fall back to legacy HTTP sandbox. logging.info(f"[CodeExec]: Provider system not available, using HTTP fallback: {provider_error}") except RuntimeError as provider_error: - if not self._should_fallback_to_http(provider_error): - self.set_output("_ERROR", f"Provider system execution failed: {provider_error}") - return self.output() - logging.info(f"[CodeExec]: Provider system not available, using HTTP fallback: {provider_error}") + self.set_output("_ERROR", f"Provider system execution failed: {provider_error}") + return self.output() # Fallback to direct HTTP request code_b64 = self._encode_code(code) @@ -468,11 +480,13 @@ def _process_execution_result( self.set_output("_ARTIFACTS", artifact_urls or None) attachment_text = self._build_attachment_content(artifacts, artifact_urls) self.set_output("_ATTACHMENT_CONTENT", attachment_text) + self.set_output("attachments", self._build_attachment_markdown_list(artifact_urls)) if attachment_text: content_parts.append(attachment_text) else: self.set_output("_ARTIFACTS", None) self.set_output("_ATTACHMENT_CONTENT", "") + self.set_output("attachments", []) self.set_output("content", "\n\n".join([part for part in content_parts if part]).strip()) @@ -496,15 +510,6 @@ def _resolve_execution_result_value(self, stdout: str, execution_metadata: Mappi return metadata.get("result_value"), False return self._deserialize_stdout(stdout), True - @staticmethod - def _should_fallback_to_http(provider_error: RuntimeError) -> bool: - message = str(provider_error).lower() - fallback_markers = ( - "no sandbox provider configured", - "sandbox provider type not configured", - ) - return any(marker in message for marker in fallback_markers) - @classmethod def _ensure_bucket_lifecycle(cls): if cls._lifecycle_configured: @@ -641,6 +646,23 @@ def _build_attachment_content(self, artifacts: list, artifact_urls: list[dict] | return f"attachment_count: {len(sections)}\n\n" + "\n\n".join(sections) return "attachment_count: 0" + def _build_attachment_markdown_list(self, artifact_urls: list[dict]) -> list[str]: + markdown_items = [] + for art in artifact_urls: + name = _art_field(art, "name") + url = _art_field(art, "url") + mime_type = str(_art_field(art, "mime_type") or "").strip().lower() + if not name: + continue + + if mime_type.startswith("image/") and url: + markdown_items.append(f"![{name}]({url})") + elif url: + markdown_items.append(f"[Download {name}]({url})") + else: + markdown_items.append(name) + return markdown_items + def _normalize_attachment_type(self, name: str, mime_type: str) -> str: mime_type = str(mime_type or "").strip().lower() if mime_type.startswith("image/"): diff --git a/agent/tools/exesql.py b/agent/tools/exesql.py index ea4ca34b837..e1b586af98a 100644 --- a/agent/tools/exesql.py +++ b/agent/tools/exesql.py @@ -64,9 +64,9 @@ def check(self): self.check_positive_integer(self.max_records, "Maximum number of records") if self.database == "rag_flow": if self.host == "ragflow-mysql": - raise ValueError("For the security reason, it dose not support database named rag_flow.") + raise ValueError("For the security reason, it does not support database named rag_flow.") if self.password == "infini_rag_flow": - raise ValueError("For the security reason, it dose not support database named rag_flow.") + raise ValueError("For the security reason, it does not support database named rag_flow.") def get_input_form(self) -> dict[str, dict]: return { diff --git a/agent/tools/retrieval.py b/agent/tools/retrieval.py index 4496f497aef..02cb3e2ce6d 100644 --- a/agent/tools/retrieval.py +++ b/agent/tools/retrieval.py @@ -142,6 +142,11 @@ def _load_metas() -> dict: return DocMetadataService.get_flatted_meta_by_kbs(kb_ids) def _resolve_manual_filter(flt: dict) -> dict: + # Return a new dict instead of mutating `flt` in place. The + # caller passes filters straight out of self._param.meta_data_filter, + # so mutating them would replace the variable reference with its + # resolved value and every subsequent invocation (e.g. inside an + # Iteration component) would reuse that stale value. pat = re.compile(self.variable_ref_patt) s = flt.get("value", "") out_parts = [] @@ -167,8 +172,9 @@ def _resolve_manual_filter(flt: dict) -> dict: last = m.end() out_parts.append(s[last:]) - flt["value"] = "".join(out_parts) - return flt + resolved = dict(flt) + resolved["value"] = "".join(out_parts) + return resolved chat_mdl = None if self._param.meta_data_filter.get("method") in ["auto", "semi_auto"]: @@ -201,6 +207,7 @@ def _resolve_manual_filter(flt: dict) -> dict: self._param.top_n, self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight, + top=self._param.top_k, doc_ids=doc_ids, aggs=True, rerank_mdl=rerank_mdl, diff --git a/api/apps/__init__.py b/api/apps/__init__.py index e05bbb03d42..6df12f47a83 100644 --- a/api/apps/__init__.py +++ b/api/apps/__init__.py @@ -56,6 +56,7 @@ def _unauthorized_message(error): except Exception: return UNAUTHORIZED_MESSAGE + app = Quart(__name__) app = cors(app, allow_origin="*") @@ -92,19 +93,53 @@ def _unauthorized_message(error): P = ParamSpec("P") +def _load_user_from_session(): + """Resolve the current user from the session cookie set by ``login_user()``. + + OAuth/OIDC callbacks call ``login_user(user)`` which writes ``_user_id`` + into the session. The frontend's response interceptor wipes the + Authorization header from localStorage on the first 401, so post-redirect + requests can arrive with no header at all — we still want to honour the + server-side session in that window. + + The same access-token validity rules used by the JWT path are applied + here so that tokens revoked by ``logout`` (which rewrites the column to + ``INVALID_``) or shortened by data corruption can't keep a stale + session authenticated. + """ + user_id = session.get("_user_id") + if not user_id: + return None + try: + users = UserService.query(id=user_id, status=StatusEnum.VALID.value) + except Exception: + logging.exception("load_user from session failed") + return None + if not users: + return None + user = users[0] + access_token = str(user.access_token or "").strip() + if not access_token or len(access_token) < 32 or access_token.startswith("INVALID_"): + return None + logging.debug("Authenticated request via session fallback for user_id=%s", user_id) + g.user = user + return user + + def _load_user(): jwt = Serializer(secret_key=settings.get_secret_key()) authorization = request.headers.get("Authorization") g.user = None + g.auth_via_api_token = False if not authorization: - return None + return _load_user_from_session() # Extract auth_token based on whether Authorization starts with "bearer" (case-insensitive) if authorization.lower().startswith("bearer "): parts = authorization.split(maxsplit=1) if len(parts) < 2: logging.warning("Authorization header has invalid bearer format") - return None + return _load_user_from_session() auth_token = parts[1] else: auth_token = authorization @@ -115,20 +150,20 @@ def _load_user(): if not access_token or not access_token.strip(): logging.warning("Authentication attempt with empty access token") - return None + return _load_user_from_session() if len(access_token.strip()) < 32: logging.warning(f"Authentication attempt with invalid token format: {len(access_token)} chars") - return None + return _load_user_from_session() user = UserService.query(access_token=access_token, status=StatusEnum.VALID.value) if user: if not user[0].access_token or not user[0].access_token.strip(): logging.warning(f"User {user[0].email} has empty access_token in database") - return None + return _load_user_from_session() g.user = user[0] return user[0] - return None + return _load_user_from_session() except Exception as e_jwt: logging.warning(f"load_user from jwt got exception {e_jwt}") @@ -140,7 +175,8 @@ def _load_user(): if user: if not user[0].access_token or not user[0].access_token.strip(): logging.warning(f"User {user[0].email} has empty access_token in database") - return None + return _load_user_from_session() + g.auth_via_api_token = True g.user = user[0] return user[0] logging.warning(f"load_user: No user found for tenant_id={objs[0].tenant_id} from APIToken") @@ -149,7 +185,7 @@ def _load_user(): except Exception as e_api_token: logging.warning(f"load_user from api token got exception {e_api_token}") - return None + return _load_user_from_session() current_user = LocalProxy(_load_user) @@ -251,16 +287,10 @@ def logout_user(): def search_pages_path(page_path): - app_path_list = [ - path for path in page_path.glob("*_app.py") if not path.name.startswith(".") - ] - api_path_list = [ - path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".") - ] + app_path_list = [path for path in page_path.glob("*_app.py") if not path.name.startswith(".")] + api_path_list = [path for path in page_path.glob("*sdk/*.py") if not path.name.startswith(".")] app_path_list.extend(api_path_list) - restful_api_path_list = [ - path for path in page_path.glob("*restful_apis/*.py") if not path.name.startswith(".") - ] + restful_api_path_list = [path for path in page_path.glob("*restful_apis/*.py") if not path.name.startswith(".")] app_path_list.extend(restful_api_path_list) return app_path_list @@ -269,9 +299,7 @@ def register_page(page_path): path = f"{page_path}" page_name = page_path.stem.removesuffix("_app") - module_name = ".".join( - page_path.parts[page_path.parts.index("api"): -1] + (page_name,) - ) + module_name = ".".join(page_path.parts[page_path.parts.index("api") : -1] + (page_name,)) spec = spec_from_file_location(module_name, page_path) page = module_from_spec(spec) @@ -282,9 +310,7 @@ def register_page(page_path): page_name = getattr(page, "page_name", page_name) sdk_path = "\\sdk\\" if sys.platform.startswith("win") else "/sdk/" restful_api_path = "\\restful_apis\\" if sys.platform.startswith("win") else "/restful_apis/" - url_prefix = ( - f"/api/{API_VERSION}" if sdk_path in path or restful_api_path in path else f"/{API_VERSION}/{page_name}" - ) + url_prefix = f"/api/{API_VERSION}" if sdk_path in path or restful_api_path in path else f"/{API_VERSION}/{page_name}" app.register_blueprint(page.manager, url_prefix=url_prefix) return url_prefix @@ -297,12 +323,11 @@ def register_page(page_path): Path(__file__).parent.parent / "api" / "apps" / "sdk", ] -client_urls_prefix = [ - register_page(path) for directory in pages_dir for path in search_pages_path(directory) -] +client_urls_prefix = [register_page(path) for directory in pages_dir for path in search_pages_path(directory)] # Register backward compatibility routes for deprecated APIs from api.apps.backward_compat import register_backward_compat_routes + register_backward_compat_routes(app) @@ -336,6 +361,7 @@ async def unauthorized_werkzeug(error): logging.warning("Unauthorized request (werkzeug)") return get_json_result(code=error.code, message=error.description), RetCode.UNAUTHORIZED + @app.teardown_request def _db_close(exception): if exception: diff --git a/api/apps/backward_compat.py b/api/apps/backward_compat.py index a2c950158e6..feaedc6d60e 100644 --- a/api/apps/backward_compat.py +++ b/api/apps/backward_compat.py @@ -22,8 +22,15 @@ Deprecated APIs and their replacements: - POST /api/v1/agents/{agent_id}/completions -> POST /api/v1/agents/chat/completion +- POST /api/v1/agents_openai/{agent_id}/chat/completions -> POST /api/v1/agents/chat/completions - POST /api/v1/chats/{chat_id}/completions -> POST /api/v1/chat/completions - POST /api/v1/chats_openai/{chat_id}/chat/completions -> POST /api/v1/openai/{chat_id}/chat/completions +- GET /api/v1/datasets/{dataset_id}/knowledge_graph -> GET /api/v1/datasets/{dataset_id}/graph +- DELETE /api/v1/datasets/{dataset_id}/knowledge_graph -> DELETE /api/v1/datasets/{dataset_id}/graph +- POST /api/v1/datasets/{dataset_id}/run_graphrag -> POST /api/v1/datasets/{dataset_id}/index?type=graph +- GET /api/v1/datasets/{dataset_id}/trace_graphrag -> GET /api/v1/datasets/{dataset_id}/index?type=graph +- POST /api/v1/datasets/{dataset_id}/run_raptor -> POST /api/v1/datasets/{dataset_id}/index?type=raptor +- GET /api/v1/datasets/{dataset_id}/trace_raptor -> GET /api/v1/datasets/{dataset_id}/index?type=raptor - PUT /api/v1/chats/{chat_id}/sessions/{session_id} -> PATCH /api/v1/chats/{chat_id}/sessions/{session_id} - DELETE /api/v1/chats -> DELETE /api/v1/chats/{chat_id} (with body) - POST /api/v1/file/convert -> POST /api/v1/files/link-to-datasets @@ -41,16 +48,21 @@ from quart import Blueprint, jsonify, request from api.apps import login_required -from api.apps.restful_apis import chat_api, file_api, file2document_api, chunk_api, openai_api, document_api +from api.apps.restful_apis import agent_api, chat_api, chunk_api, dataset_api, document_api, file2document_api, file_api, openai_api from api.apps.restful_apis.system_api import run_health_checks -from api.apps.restful_apis import agent_api -from api.apps.services import file_api_service -from api.utils.api_utils import get_data_error_result, get_json_result, add_tenant_id_to_kwargs +from api.apps.services import dataset_api_service, file_api_service +from api.utils.api_utils import add_tenant_id_to_kwargs, get_data_error_result, get_json_result, get_request_json manager = Blueprint("backward_compat", __name__) legacy_v1_manager = Blueprint("backward_compat_legacy_v1", __name__) +def _index_result(success, result): + if success: + return get_json_result(data=result) + return get_data_error_result(message=result) + + # ============================================================================= # System APIs # ============================================================================= @@ -110,6 +122,137 @@ async def deprecated_openai_chat_completions(chat_id): return await openai_api.openai_chat_completions(chat_id) +@manager.route("/agents_openai//chat/completions", methods=["POST"]) +@login_required +@add_tenant_id_to_kwargs +async def deprecated_agents_openai_chat_completions(agent_id, tenant_id=None): + """ + Deprecated: Use POST /api/v1/agents/chat/completions with openai-compatible=true instead. + + Old path: POST /api/v1/agents_openai/{agent_id}/chat/completions + New path: POST /api/v1/agents/chat/completions + """ + logging.warning( + "API endpoint /api/v1/agents_openai/%s/chat/completions is deprecated. " + "Please use /api/v1/agents/chat/completions with `openai-compatible` instead.", + agent_id, + ) + req = dict(await get_request_json()) + req["openai-compatible"] = True + request._cached_payload = req + return await agent_api.agent_chat_completion(tenant_id=tenant_id, agent_id=agent_id) + + +# ============================================================================= +# Dataset Graph and Index APIs +# ============================================================================= + +@manager.route("/datasets//knowledge_graph", methods=["GET"]) +@login_required +async def deprecated_get_knowledge_graph(dataset_id): + """ + Deprecated: Use GET /api/v1/datasets/{dataset_id}/graph instead. + + Old path: GET /api/v1/datasets/{dataset_id}/knowledge_graph + New path: GET /api/v1/datasets/{dataset_id}/graph + """ + logging.warning( + "API endpoint /api/v1/datasets/%s/knowledge_graph is deprecated. " + "Please use /api/v1/datasets/%s/graph instead.", + dataset_id, dataset_id, + ) + return await dataset_api.get_knowledge_graph(dataset_id=dataset_id) + + +@manager.route("/datasets//knowledge_graph", methods=["DELETE"]) +@login_required +async def deprecated_delete_knowledge_graph(dataset_id): + """ + Deprecated: Use DELETE /api/v1/datasets/{dataset_id}/graph instead. + + Old path: DELETE /api/v1/datasets/{dataset_id}/knowledge_graph + New path: DELETE /api/v1/datasets/{dataset_id}/graph + """ + logging.warning( + "API endpoint DELETE /api/v1/datasets/%s/knowledge_graph is deprecated. " + "Please use DELETE /api/v1/datasets/%s/graph instead.", + dataset_id, dataset_id, + ) + return await dataset_api.delete_knowledge_graph(dataset_id=dataset_id) + + +@manager.route("/datasets//run_graphrag", methods=["POST"]) +@login_required +@add_tenant_id_to_kwargs +async def deprecated_run_graphrag(dataset_id, tenant_id=None): + """ + Deprecated: Use POST /api/v1/datasets/{dataset_id}/index?type=graph instead. + + Old path: POST /api/v1/datasets/{dataset_id}/run_graphrag + New path: POST /api/v1/datasets/{dataset_id}/index?type=graph + """ + logging.warning( + "API endpoint /api/v1/datasets/%s/run_graphrag is deprecated. " + "Please use /api/v1/datasets/%s/index?type=graph instead.", + dataset_id, dataset_id, + ) + return _index_result(*dataset_api_service.run_index(dataset_id, tenant_id, "graph")) + + +@manager.route("/datasets//trace_graphrag", methods=["GET"]) +@login_required +@add_tenant_id_to_kwargs +async def deprecated_trace_graphrag(dataset_id, tenant_id=None): + """ + Deprecated: Use GET /api/v1/datasets/{dataset_id}/index?type=graph instead. + + Old path: GET /api/v1/datasets/{dataset_id}/trace_graphrag + New path: GET /api/v1/datasets/{dataset_id}/index?type=graph + """ + logging.warning( + "API endpoint /api/v1/datasets/%s/trace_graphrag is deprecated. " + "Please use /api/v1/datasets/%s/index?type=graph instead.", + dataset_id, dataset_id, + ) + return _index_result(*dataset_api_service.trace_index(dataset_id, tenant_id, "graph")) + + +@manager.route("/datasets//run_raptor", methods=["POST"]) +@login_required +@add_tenant_id_to_kwargs +async def deprecated_run_raptor(dataset_id, tenant_id=None): + """ + Deprecated: Use POST /api/v1/datasets/{dataset_id}/index?type=raptor instead. + + Old path: POST /api/v1/datasets/{dataset_id}/run_raptor + New path: POST /api/v1/datasets/{dataset_id}/index?type=raptor + """ + logging.warning( + "API endpoint /api/v1/datasets/%s/run_raptor is deprecated. " + "Please use /api/v1/datasets/%s/index?type=raptor instead.", + dataset_id, dataset_id, + ) + return _index_result(*dataset_api_service.run_index(dataset_id, tenant_id, "raptor")) + + +@manager.route("/datasets//trace_raptor", methods=["GET"]) +@login_required +@add_tenant_id_to_kwargs +async def deprecated_trace_raptor(dataset_id, tenant_id=None): + """ + Deprecated: Use GET /api/v1/datasets/{dataset_id}/index?type=raptor instead. + + Old path: GET /api/v1/datasets/{dataset_id}/trace_raptor + New path: GET /api/v1/datasets/{dataset_id}/index?type=raptor + """ + logging.warning( + "API endpoint /api/v1/datasets/%s/trace_raptor is deprecated. " + "Please use /api/v1/datasets/%s/index?type=raptor instead.", + dataset_id, dataset_id, + ) + return _index_result(*dataset_api_service.trace_index(dataset_id, tenant_id, "raptor")) + + # ============================================================================= # Chat Session APIs # ============================================================================= diff --git a/api/apps/llm_app.py b/api/apps/llm_app.py index 583e05af7c9..430b7d8dc36 100644 --- a/api/apps/llm_app.py +++ b/api/apps/llm_app.py @@ -25,8 +25,6 @@ from api.utils.api_utils import get_allowed_llm_factories, get_data_error_result, get_json_result, get_request_json, server_error_response, validate_request from common.constants import StatusEnum, LLMType from api.db.db_models import TenantLLM -from rag.utils.base64_image import test_image -from rag.llm import EmbeddingModel, ChatModel, RerankModel, CvModel, TTSModel, OcrModel, Seq2txtModel def _resolve_my_llm_is_tools(o_dict: dict) -> bool: @@ -78,6 +76,8 @@ def factories(): @validate_request("llm_factory", "api_key") async def set_api_key(): req = await get_request_json() + from rag.llm import ChatModel, EmbeddingModel, RerankModel + # test if api key works chat_passed, embd_passed, rerank_passed = False, False, False factory = req["llm_factory"] @@ -178,21 +178,68 @@ async def check_streamly(): @validate_request("llm_factory") async def add_llm(): req = await get_request_json() + from rag.llm import ChatModel, CvModel, EmbeddingModel, OcrModel, RerankModel, Seq2txtModel, TTSModel + factory = req["llm_factory"] - api_key = req.get("api_key", "x") llm_name = req.get("llm_name") timeout_seconds = int(os.environ.get("LLM_TIMEOUT_SECONDS", 10)) if factory not in [f.name for f in get_allowed_llm_factories()]: return get_data_error_result(message=f"LLM factory {factory} is not allowed") + # When editing an existing model the frontend leaves the api_key input blank + # and strips it from the payload, so req["api_key"] is missing. Without a + # fallback the validation below would run with the "x" placeholder and the + # upstream provider would return "Your API key is invalid" — recover the + # saved key from DB. Use only the *decoded* api_key (never the raw JSON + # payload) so factories that pack extra fields into api_key + # (OpenRouter, Bedrock, …) can rebuild their JSON correctly with whatever + # new fields the user did provide via apikey_json. + if req.get("api_key") is None and llm_name: + _LLM_NAME_SUFFIX = { + "LocalAI": "___LocalAI", + "HuggingFace": "___HuggingFace", + "OpenAI-API-Compatible": "___OpenAI-API", + "VLLM": "___VLLM", + } + saved_llm_name = llm_name + _LLM_NAME_SUFFIX.get(factory, "") + logging.debug( + "add_llm: attempting api_key recovery factory=%s llm_name=%s saved_llm_name=%s tenant_id=%s", + factory, llm_name, saved_llm_name, current_user.id, + ) + existing_llms = TenantLLMService.query( + tenant_id=current_user.id, + llm_factory=factory, + llm_name=saved_llm_name, + ) + logging.debug( + "add_llm: api_key recovery query matched=%d factory=%s saved_llm_name=%s", + len(existing_llms) if existing_llms else 0, factory, saved_llm_name, + ) + if existing_llms: + existing_api_key, _, _ = TenantLLMService._decode_api_key_config( + existing_llms[0].api_key + ) + logging.debug( + "add_llm: api_key recovery decoded=%s factory=%s saved_llm_name=%s", + "present" if existing_api_key else "absent", factory, saved_llm_name, + ) + if existing_api_key: + req["api_key"] = existing_api_key + logging.info( + "add_llm: recovered saved api_key from existing record factory=%s saved_llm_name=%s tenant_id=%s", + factory, saved_llm_name, current_user.id, + ) + + api_key = req.get("api_key", "x") + def apikey_json(keys): nonlocal req return json.dumps({k: req.get(k, "") for k in keys}) if factory == "VolcEngine": # For VolcEngine, due to its special authentication method - # Assemble ark_api_key endpoint_id into api_key + # Assemble ark_api_key model_id into api_key; keep endpoint_id in backend payload for compatibility api_key = apikey_json(["ark_api_key", "endpoint_id"]) elif factory == "Tencent Cloud": @@ -248,19 +295,6 @@ def apikey_json(keys): elif factory == "OpenDataLoader": api_key = apikey_json(["api_key", "provider_order"]) - existing_llm = None - existing_api_key = None - if req.get("api_key") is None: - existing_llms = TenantLLMService.query(tenant_id=current_user.id, llm_factory=factory, llm_name=llm_name) - if existing_llms: - existing_llm = existing_llms[0] - existing_api_key, _, existing_api_key_payload = TenantLLMService._decode_api_key_config(existing_llm.api_key) - if existing_api_key_payload is not None: - existing_api_key = existing_api_key_payload - - if req.get("api_key") is None: - api_key = existing_api_key if existing_api_key is not None else "x" - llm = { "tenant_id": current_user.id, "llm_factory": factory, @@ -326,11 +360,13 @@ async def check_streamly(): if len(arr) == 0: raise Exception("Not known.") except KeyError: - msg += f"{factory} dose not support this model({factory}/{mdl_nm})" + msg += f"{factory} does not support this model({factory}/{mdl_nm})" except Exception as e: msg += f"\nFail to access model({factory}/{mdl_nm})." + str(e) case LLMType.IMAGE2TEXT.value: + from rag.utils.base64_image import test_image + assert factory in CvModel, f"Image to text model from {factory} is not supported yet." mdl = CvModel[factory](key=model_api_key, model_name=mdl_nm, base_url=model_base_url) try: diff --git a/api/apps/restful_apis/agent_api.py b/api/apps/restful_apis/agent_api.py index c0c6c604af7..3e6226158d9 100644 --- a/api/apps/restful_apis/agent_api.py +++ b/api/apps/restful_apis/agent_api.py @@ -29,9 +29,6 @@ import jwt from quart import Response, jsonify, request -from agent.canvas import Canvas -from agent.component import LLM -from agent.dsl_migration import normalize_chunker_dsl from api.apps import current_user, login_required from api.apps.services.canvas_replica_service import CanvasReplicaService from api.db import CanvasCategory @@ -60,12 +57,10 @@ validate_request, ) from common import settings +from common.ssrf_guard import assert_host_is_safe from common.constants import RetCode from common.misc_utils import get_uuid, thread_pool_exec from peewee import MySQLDatabase, PostgresqlDatabase -from rag.flow.pipeline import Pipeline -from rag.nlp import search -from rag.utils.redis_conn import REDIS_CONN def _require_canvas_access_sync(func): @@ -113,9 +108,23 @@ def _build_sse_response(body): return resp +def _normalize_agent_reference_entry(reference): + if not isinstance(reference, dict): + return {"chunks": [], "doc_aggs": []} + if "chunks" in reference or "doc_aggs" in reference: + return { + "chunks": reference.get("chunks", []), + "doc_aggs": reference.get("doc_aggs", []), + } + return { + "chunks": reference.get("reference", reference.get("chunks", [])) or [], + "doc_aggs": reference.get("doc_aggs", []) or [], + } + + def _normalize_agent_session(conv): - conv["messages"] = conv.pop("message") - for info in conv["messages"]: + conv["message"] = conv.get("message", []) + for info in conv["message"]: if "prompt" in info: info.pop("prompt") conv["agent_id"] = conv.pop("dialog_id") @@ -124,11 +133,15 @@ def _normalize_agent_session(conv): conv["reference"] = [conv["reference"]] else: conv["reference"] = [value for _, value in sorted(conv["reference"].items(), key=lambda item: int(item[0]))] + elif isinstance(conv["reference"], list): + conv["reference"] = [_normalize_agent_reference_entry(reference) for reference in conv["reference"]] + else: + conv["reference"] = [] if conv["reference"]: - messages = [message for i, message in enumerate(conv["messages"]) if i != 0 and message["role"] != "user"] + messages = [message for i, message in enumerate(conv["message"]) if i != 0 and message["role"] != "user"] for message, reference in zip(messages, conv["reference"]): - chunks = reference["chunks"] + chunks = reference.get("chunks", []) message["reference"] = [ { "id": chunk.get("chunk_id", chunk.get("id")), @@ -149,6 +162,171 @@ def _agent_session_list_result(data, total): return jsonify({"code": RetCode.SUCCESS, "message": "success", "data": data, "total": total}) +async def _run_workflow_session( + tenant_id, + agent_id, + workflow_conv, + canvas, + query, + files, + inputs, + user_id, + session_id, + custom_header, + canvas_title, + canvas_category, + return_trace, + stream, +): + async def commit_runtime_replica(): + commit_ok = CanvasReplicaService.commit_after_run( + canvas_id=agent_id, + tenant_id=str(tenant_id), + runtime_user_id=user_id, + dsl=json.loads(str(canvas)), + canvas_category=canvas_category, + title=canvas_title, + ) + if not commit_ok: + logging.error( + "Canvas runtime replica commit failed: canvas_id=%s tenant_id=%s runtime_user_id=%s", + agent_id, + tenant_id, + user_id, + ) + + workflow_conv.setdefault("message", []) + if isinstance(workflow_conv.get("reference"), dict): + if "chunks" in workflow_conv["reference"]: + workflow_conv["reference"] = [workflow_conv["reference"]] + else: + workflow_conv["reference"] = [ + value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0])) + ] + elif not isinstance(workflow_conv.get("reference"), list): + workflow_conv["reference"] = [] + workflow_conv["reference"] = [_normalize_agent_reference_entry(reference) for reference in workflow_conv["reference"]] + + turn_id = workflow_conv["message"][-1].get("id") if workflow_conv["message"] else get_uuid() + full_content = "" + reference = {} + final_ans = {} + trace_items = [] + structured_output = {} + + async def persist_workflow_session(): + if not final_ans: + return + workflow_conv["message"].append( + { + "role": "assistant", + "content": full_content, + "created_at": time.time(), + "id": turn_id, + } + ) + workflow_conv["reference"].append(_normalize_agent_reference_entry(reference)) + workflow_conv["dsl"] = json.loads(str(canvas)) + workflow_conv["source"] = workflow_conv.get("source") or "workflow" + await thread_pool_exec(API4ConversationService.append_message, session_id, workflow_conv) + await commit_runtime_replica() + + if stream: + + async def sse(): + nonlocal full_content, reference, final_ans, trace_items, structured_output + done_sent = False + try: + async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): + ans["session_id"] = session_id + if ans.get("event") == "message": + full_content += ans.get("data", {}).get("content", "") + if ans.get("data", {}).get("reference", None): + reference.update(ans["data"]["reference"]) + if ans.get("event") == "node_finished": + data = ans.get("data", {}) + node_out = data.get("outputs", {}) + component_id = data.get("component_id") + if component_id is not None and "structured" in node_out: + structured_output[component_id] = copy.deepcopy(node_out["structured"]) + if return_trace: + trace_items.append( + { + "component_id": data.get("component_id"), + "trace": [copy.deepcopy(data)], + } + ) + final_ans = ans + yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n" + + if final_ans: + if "data" not in final_ans or not isinstance(final_ans["data"], dict): + final_ans["data"] = {} + final_ans["data"]["content"] = full_content + final_ans["data"]["reference"] = reference + if structured_output: + final_ans["data"]["structured"] = structured_output + if trace_items: + final_ans["data"]["trace"] = trace_items + await persist_workflow_session() + except Exception as exc: + logging.exception(exc) + canvas.cancel_task() + yield ( + "data:" + + json.dumps({"code": 500, "message": str(exc), "data": False}, ensure_ascii=False) + + "\n\n" + ) + finally: + if not done_sent: + done_sent = True + yield "data:[DONE]\n\n" + + return _build_sse_response(sse()) + + try: + async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): + ans["session_id"] = session_id + if ans.get("event") == "message": + full_content += ans.get("data", {}).get("content", "") + if ans.get("data", {}).get("reference", None): + reference.update(ans["data"]["reference"]) + if ans.get("event") == "node_finished": + data = ans.get("data", {}) + node_out = data.get("outputs", {}) + component_id = data.get("component_id") + if component_id is not None and "structured" in node_out: + structured_output[component_id] = copy.deepcopy(node_out["structured"]) + if return_trace: + trace_items.append( + { + "component_id": data.get("component_id"), + "trace": [copy.deepcopy(data)], + } + ) + final_ans = ans + except Exception as exc: + logging.exception(exc) + canvas.cancel_task() + return get_result(data=f"**ERROR**: {str(exc)}") + + if not final_ans: + await commit_runtime_replica() + return get_result(data={}) + + if "data" not in final_ans or not isinstance(final_ans["data"], dict): + final_ans["data"] = {} + final_ans["data"]["content"] = full_content + final_ans["data"]["reference"] = reference + if structured_output: + final_ans["data"]["structured"] = structured_output + if trace_items: + final_ans["data"]["trace"] = trace_items + + await persist_workflow_session() + return get_result(data=final_ans) + + @manager.route("/agents//sessions", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs @@ -194,6 +372,8 @@ def list_agent_sessions(agent_id, tenant_id): @add_tenant_id_to_kwargs @_require_canvas_access_async async def create_agent_session(agent_id, tenant_id): + from agent.canvas import Canvas + req = await get_request_json() user_id = req.get("user_id") or request.args.get("user_id", tenant_id) release_mode = bool(req.get("release", request.args.get("release", False))) @@ -245,10 +425,12 @@ def delete_agent_session_item(agent_id, session_id, tenant_id): @manager.route("/agents/download", methods=["GET"]) # noqa: F821 -async def download_agent_file(): +@login_required +@add_tenant_id_to_kwargs +async def download_agent_file(tenant_id): id = request.args.get("id") - created_by = request.args.get("created_by") - blob = FileService.get_blob(created_by, id) + logging.info("Agent file download requested: tenant_id=%s file_id=%s", tenant_id, id) + blob = await thread_pool_exec(FileService.get_blob, tenant_id, id) return Response(blob) @@ -316,6 +498,7 @@ def list_agents(tenant_id): keywords = request.args.get("keywords", "") canvas_category = request.args.get("canvas_category") owner_ids = [item for item in request.args.get("owner_ids", "").strip().split(",") if item] + tags = [item for item in request.args.get("tags", "").strip().split(",") if item] page_number = int(request.args.get("page", 0)) items_per_page = int(request.args.get("page_size", 0)) @@ -347,11 +530,71 @@ def list_agents(tenant_id): desc, keywords, canvas_category, + tags, ) return get_json_result(data={"canvas": canvas, "total": total}) +@manager.route("/agents/tags", methods=["GET"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +def list_agent_tags(tenant_id): + """Aggregate tag usage counts across agents visible to the caller.""" + canvas_category = request.args.get("canvas_category") + tenants = TenantService.get_joined_tenants_by_user_id(tenant_id) + joined_ids = list({member["tenant_id"] for member in tenants} | {tenant_id}) + counts = UserCanvasService.list_tags(joined_ids, tenant_id, canvas_category) + logging.info( + "list_agent_tags tenant=%s canvas_category=%s tags_count=%d", + tenant_id, + canvas_category, + len(counts), + ) + return get_json_result(data=[{"tag": k, "count": v} for k, v in sorted(counts.items(), key=lambda x: (-x[1], x[0]))]) + + +@manager.route("/agents//tags", methods=["PUT"]) # noqa: F821 +@login_required +@add_tenant_id_to_kwargs +async def update_agent_tags(tenant_id, canvas_id): + if not UserCanvasService.accessible(canvas_id, tenant_id): + logging.info( + "update_agent_tags denied tenant=%s canvas_id=%s reason=no_permission", + tenant_id, + canvas_id, + ) + return get_json_result( + data=False, + message="Agent not found or no permission.", + code=RetCode.OPERATING_ERROR, + ) + req = await get_request_json() + tags = req.get("tags", "") + incoming = tags if isinstance(tags, (list, tuple)) else [t for t in str(tags).split(",") if t.strip()] + rows_affected = UserCanvasService.update_tags(canvas_id, tags) + if rows_affected == 0: + logging.info( + "update_agent_tags miss tenant=%s canvas_id=%s incoming_count=%d rows=0", + tenant_id, + canvas_id, + len(incoming), + ) + return get_json_result( + data=False, + message="Agent not found or no permission.", + code=RetCode.OPERATING_ERROR, + ) + logging.info( + "update_agent_tags ok tenant=%s canvas_id=%s incoming_count=%d rows=%d", + tenant_id, + canvas_id, + len(incoming), + rows_affected, + ) + return get_json_result(data=True) + + @manager.route("/agents", methods=["POST"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs @@ -421,22 +664,34 @@ async def create_agent(tenant_id): @manager.route("/agents//upload", methods=["POST"]) # noqa: F821 -async def upload_agent_file(agent_id): - exists, canvas = UserCanvasService.get_by_canvas_id(agent_id) - if not exists: - return get_data_error_result(message="canvas not found.") - - user_id = canvas["user_id"] +@login_required +@add_tenant_id_to_kwargs +@_require_canvas_access_async +async def upload_agent_file(agent_id, tenant_id): files = await request.files file_objs = files.getlist("file") if files and files.get("file") else [] + logging.info( + "Agent file upload requested: tenant_id=%s agent_id=%s file_count=%s", + tenant_id, + agent_id, + len(file_objs), + ) try: if len(file_objs) == 1: - return get_json_result( - data=FileService.upload_info(user_id, file_objs[0], request.args.get("url")) + uploaded = await thread_pool_exec( + FileService.upload_info, tenant_id, file_objs[0], request.args.get("url") ) - results = [FileService.upload_info(user_id, file_obj) for file_obj in file_objs] + return get_json_result(data=uploaded) + results = await asyncio.gather( + *(thread_pool_exec(FileService.upload_info, tenant_id, file_obj) for file_obj in file_objs) + ) return get_json_result(data=results) except Exception as exc: + logging.exception( + "Agent file upload failed: tenant_id=%s agent_id=%s", + tenant_id, + agent_id, + ) return server_error_response(exc) @@ -446,6 +701,8 @@ async def upload_agent_file(agent_id): @_require_canvas_access_sync def get_agent_component_input_form(agent_id, component_id, tenant_id): try: + from agent.canvas import Canvas + exists, user_canvas = UserCanvasService.get_by_id(agent_id) if not exists: return get_data_error_result(message="canvas not found.") @@ -463,6 +720,9 @@ def get_agent_component_input_form(agent_id, component_id, tenant_id): async def debug_agent_component(agent_id, component_id, tenant_id): req = await get_request_json() try: + from agent.canvas import Canvas + from agent.component import LLM + _, user_canvas = UserCanvasService.get_by_id(agent_id) canvas = Canvas(json.dumps(user_canvas.dsl), tenant_id, canvas_id=user_canvas.id) canvas.reset() @@ -521,6 +781,8 @@ def get_agent(agent_id, tenant_id): released_versions.sort(key=lambda version: version.update_time, reverse=True) last_publish_time = released_versions[0].update_time + from agent.dsl_migration import normalize_chunker_dsl + canvas["dsl"] = normalize_chunker_dsl(canvas.get("dsl", {})) canvas["last_publish_time"] = last_publish_time @@ -563,14 +825,17 @@ def get_agent_version(agent_id, version_id, tenant_id): @manager.route("/agents//logs/", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -@_require_canvas_access_sync -def get_agent_logs(agent_id, message_id, tenant_id): +@_require_canvas_access_async +async def get_agent_logs(agent_id, message_id, tenant_id): try: - binary = REDIS_CONN.get(f"{agent_id}-{message_id}-logs") + from rag.utils.redis_conn import REDIS_CONN + + binary = await thread_pool_exec(REDIS_CONN.get, f"{agent_id}-{message_id}-logs") if not binary: return get_json_result(data={}) - return get_json_result(data=json.loads(binary.encode("utf-8"))) + payload = binary.decode("utf-8") if isinstance(binary, bytes) else binary + return get_json_result(data=json.loads(payload)) except Exception as exc: logging.exception(exc) return server_error_response(exc) @@ -642,6 +907,8 @@ async def update_agent(agent_id, tenant_id): @_require_canvas_access_async async def reset_agent(agent_id, tenant_id): try: + from agent.canvas import Canvas + exists, user_canvas = UserCanvasService.get_by_id(agent_id) if not exists: return get_data_error_result(message="canvas not found.") @@ -670,6 +937,8 @@ async def reset_agent(agent_id, tenant_id): @login_required @add_tenant_id_to_kwargs async def rerun_agent(tenant_id): + from rag.nlp import search + req = await get_request_json() doc = PipelineOperationLogService.get_documents_info(req["id"]) if not doc: @@ -706,52 +975,78 @@ async def rerun_agent(tenant_id): @login_required async def test_db_connection(): req = await get_request_json() + try: + safe_host = assert_host_is_safe(req["host"]) + except ValueError as exc: + logging.warning( + "Rejected test_db_connection: unsafe host %r (db_type=%s, user=%s): %s", + req.get("host"), req.get("db_type"), current_user.id, exc, + ) + return get_data_error_result(message=str(exc)) + except OSError as exc: + logging.warning( + "Rejected test_db_connection: cannot resolve host %r (db_type=%s, user=%s): %s", + req.get("host"), req.get("db_type"), current_user.id, exc, + ) + logging.debug("Full resolver exception for host %r", req.get("host"), exc_info=True) + return get_data_error_result(message=f"Could not resolve host {req.get('host')!r}.") try: if req["db_type"] in ["mysql", "mariadb"]: db = MySQLDatabase( req["database"], user=req["username"], - host=req["host"], + host=safe_host, port=req["port"], password=req["password"], ) + with db.connection_context(): + db.execute_sql("SELECT 1") elif req["db_type"] == "oceanbase": db = MySQLDatabase( req["database"], user=req["username"], - host=req["host"], + host=safe_host, port=req["port"], password=req["password"], charset="utf8mb4", ) + with db.connection_context(): + db.execute_sql("SELECT 1") elif req["db_type"] == "postgres": db = PostgresqlDatabase( req["database"], user=req["username"], - host=req["host"], + host=safe_host, port=req["port"], password=req["password"], ) + with db.connection_context(): + db.execute_sql("SELECT 1") elif req["db_type"] == "mssql": import pyodbc connection_string = ( f"DRIVER={{ODBC Driver 17 for SQL Server}};" - f"SERVER={req['host']},{req['port']};" + f"SERVER={safe_host},{req['port']};" f"DATABASE={req['database']};" f"UID={req['username']};" f"PWD={req['password']};" ) db = pyodbc.connect(connection_string) - cursor = db.cursor() - cursor.execute("SELECT 1") - cursor.close() + try: + cursor = db.cursor() + try: + cursor.execute("SELECT 1") + finally: + cursor.close() + finally: + db.close() elif req["db_type"] == "IBM DB2": import ibm_db conn_str = ( f"DATABASE={req['database']};" - f"HOSTNAME={req['host']};" + f"HOSTNAME={safe_host};" f"PORT={req['port']};" f"PROTOCOL=TCPIP;" f"UID={req['username']};" @@ -760,7 +1055,7 @@ async def test_db_connection(): logging.info( "DATABASE=%s;HOSTNAME=%s;PORT=%s;PROTOCOL=TCPIP;UID=%s;PWD=****;", req["database"], - req["host"], + safe_host, req["port"], req["username"], ) @@ -768,7 +1063,6 @@ async def test_db_connection(): stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1") ibm_db.fetch_assoc(stmt) ibm_db.close(conn) - return get_json_result(data="Database Connection Successful!") elif req["db_type"] == "trino": import os import trino @@ -787,7 +1081,7 @@ async def test_db_connection(): auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"]) conn = trino.dbapi.connect( - host=req["host"], + host=safe_host, port=int(req["port"] or 8080), user=req["username"] or "ragflow", catalog=catalog, @@ -795,18 +1089,18 @@ async def test_db_connection(): http_scheme=http_scheme, auth=auth, ) - cur = conn.cursor() - cur.execute("SELECT 1") - cur.fetchall() - cur.close() - conn.close() - return get_json_result(data="Database Connection Successful!") + try: + cur = conn.cursor() + try: + cur.execute("SELECT 1") + cur.fetchall() + finally: + cur.close() + finally: + conn.close() else: return server_error_response("Unsupported database type.") - if req["db_type"] != "mssql": - db.connect() - db.close() return get_json_result(data="Database Connection Successful!") except Exception as exc: return server_error_response(exc) @@ -846,6 +1140,8 @@ async def agent_chat_completion(tenant_id, agent_id=None): req.pop("agent_id", None) req.pop("openai-compatible", None) session_id = req.get("session_id") + workflow_session = False + workflow_conv = None if session_id: exists, conv = API4ConversationService.get_by_id(session_id) if not exists: @@ -862,6 +1158,9 @@ async def agent_chat_completion(tenant_id, agent_id=None): message="Only authorized users can access this agent session.", code=RetCode.OPERATING_ERROR, ) + workflow_session = getattr(conv, "source", "") == "workflow" + if workflow_session: + workflow_conv = conv.to_dict() if openai_compatible: # OpenAI-compatible mode uses a different wire format, keep it separate from regular agent events. @@ -894,8 +1193,7 @@ async def agent_chat_completion(tenant_id, agent_id=None): return jsonify(response) return None - if not session_id: - # Without session state, run against the runtime replica that tracks draft edits. + if workflow_session: query = req.get("query", "") or req.get("question", "") files = req.get("files", []) inputs = req.get("inputs", {}) @@ -903,6 +1201,64 @@ async def agent_chat_completion(tenant_id, agent_id=None): user_id = str(runtime_user_id) custom_header = req.get("custom_header", "") + _, cvs = await thread_pool_exec(UserCanvasService.get_by_id, agent_id) + if not cvs: + return get_data_error_result(message="canvas not found.") + + if not isinstance(workflow_conv.get("message"), list): + workflow_conv["message"] = [] + if isinstance(workflow_conv.get("reference"), dict): + if "chunks" in workflow_conv["reference"]: + workflow_conv["reference"] = [workflow_conv["reference"]] + else: + workflow_conv["reference"] = [ + value for _, value in sorted(workflow_conv["reference"].items(), key=lambda item: int(item[0])) + ] + elif not isinstance(workflow_conv.get("reference"), list): + workflow_conv["reference"] = [] + workflow_conv["reference"] = [_normalize_agent_reference_entry(reference) for reference in workflow_conv["reference"]] + turn_id = get_uuid() + workflow_conv["message"].append( + { + "role": "user", + "content": query, + "id": turn_id, + "files": files, + "created_at": time.time(), + } + ) + await thread_pool_exec(API4ConversationService.update_by_id, session_id, workflow_conv) + + try: + from agent.canvas import Canvas + + workflow_dsl = workflow_conv.get("dsl", {}) + if isinstance(workflow_dsl, str): + dsl_str = workflow_dsl + else: + dsl_str = json.dumps(workflow_dsl, ensure_ascii=False) + canvas = Canvas(dsl_str, str(tenant_id), canvas_id=agent_id, custom_header=custom_header) + except Exception as exc: + return server_error_response(exc) + + return await _run_workflow_session( + tenant_id=tenant_id, + agent_id=agent_id, + workflow_conv=workflow_conv, + canvas=canvas, + query=query, + files=files, + inputs=inputs, + user_id=user_id, + session_id=session_id, + custom_header=custom_header, + canvas_title=getattr(cvs, "title", ""), + canvas_category=getattr(cvs, "canvas_category", CanvasCategory.Agent), + return_trace=bool(req.get("return_trace", False)), + stream=req.get("stream", True), + ) + + if not session_id: if not UserCanvasService.accessible(agent_id, tenant_id): return get_json_result( data=False, @@ -910,6 +1266,16 @@ async def agent_chat_completion(tenant_id, agent_id=None): code=RetCode.OPERATING_ERROR, ) + # Keep the original workflow execution path, but assign a session_id so the + # response shape stays closer to the older agent completion contract. + query = req.get("query", "") or req.get("question", "") + files = req.get("files", []) + inputs = req.get("inputs", {}) + runtime_user_id = req.get("user_id") or tenant_id + user_id = str(runtime_user_id) + custom_header = req.get("custom_header", "") + session_id = get_uuid() + _, cvs = await thread_pool_exec(UserCanvasService.get_by_id, agent_id) if not cvs: return get_data_error_result(message="canvas not found.") @@ -940,7 +1306,34 @@ async def agent_chat_completion(tenant_id, agent_id=None): dsl_str = json.dumps(replica_dsl, ensure_ascii=False) if cvs.canvas_category == CanvasCategory.DataFlow: + from rag.flow.pipeline import Pipeline + task_id = get_uuid() + workflow_conv = { + "id": session_id, + "dialog_id": cvs.id, + "user_id": user_id, + "exp_user_id": user_id, + "name": req.get("name", ""), + "message": [ + { + "role": "user", + "content": query, + "id": task_id, + "files": files, + "created_at": time.time(), + } + ], + "reference": [], + "source": "workflow", + "dsl": replica_dsl, + "version_title": await thread_pool_exec( + UserCanvasVersionService.get_latest_version_title, + cvs.id, + release_mode=False, + ), + } + await thread_pool_exec(API4ConversationService.save, **workflow_conv) Pipeline( dsl_str, tenant_id=str(tenant_id), @@ -959,94 +1352,57 @@ async def agent_chat_completion(tenant_id, agent_id=None): ) if not ok: return get_data_error_result(message=error_message) - return get_json_result(data={"message_id": task_id}) + return get_json_result(data={"message_id": task_id, "session_id": session_id}) try: + from agent.canvas import Canvas + canvas = Canvas(dsl_str, str(tenant_id), canvas_id=agent_id, custom_header=custom_header) except Exception as exc: return server_error_response(exc) - - async def commit_runtime_replica(): - commit_ok = CanvasReplicaService.commit_after_run( - canvas_id=agent_id, - tenant_id=str(tenant_id), - runtime_user_id=user_id, - dsl=json.loads(str(canvas)), - canvas_category=canvas_category, - title=canvas_title, - ) - if not commit_ok: - logging.error( - "Canvas runtime replica commit failed: canvas_id=%s tenant_id=%s runtime_user_id=%s", - agent_id, - tenant_id, - user_id, - ) - - if req.get("stream", True): - async def sse(): - nonlocal canvas - try: - async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): - yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n" - - await commit_runtime_replica() - except Exception as exc: - logging.exception(exc) - canvas.cancel_task() - yield ( - "data:" - + json.dumps({"code": 500, "message": str(exc), "data": False}, ensure_ascii=False) - + "\n\n" - ) - - return _build_sse_response(sse()) - - full_content = "" - reference = {} - final_ans = {} - trace_items = [] - structured_output = {} - try: - async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs): - if ans.get("event") == "message": - full_content += ans.get("data", {}).get("content", "") - if ans.get("data", {}).get("reference", None): - reference.update(ans["data"]["reference"]) - if ans.get("event") == "node_finished": - data = ans.get("data", {}) - node_out = data.get("outputs", {}) - component_id = data.get("component_id") - if component_id is not None and "structured" in node_out: - structured_output[component_id] = copy.deepcopy(node_out["structured"]) - if req.get("return_trace", False): - trace_items.append( - { - "component_id": data.get("component_id"), - "trace": [copy.deepcopy(data)], - } - ) - final_ans = ans - except Exception as exc: - logging.exception(exc) - canvas.cancel_task() - return get_result(data=f"**ERROR**: {str(exc)}") - - if not final_ans: - await commit_runtime_replica() - return get_result(data={}) - - if "data" not in final_ans or not isinstance(final_ans["data"], dict): - final_ans["data"] = {} - final_ans["data"]["content"] = full_content - final_ans["data"]["reference"] = reference - if structured_output: - final_ans["data"]["structured"] = structured_output - if trace_items: - final_ans["data"]["trace"] = trace_items - - await commit_runtime_replica() - return get_result(data=final_ans) + turn_id = get_uuid() + workflow_conv = { + "id": session_id, + "dialog_id": cvs.id, + "user_id": user_id, + "exp_user_id": user_id, + "name": req.get("name", ""), + "message": [ + { + "role": "user", + "content": query, + "id": turn_id, + "files": files, + "created_at": time.time(), + } + ], + "reference": [], + "source": "workflow", + "dsl": replica_dsl, + "version_title": await thread_pool_exec( + UserCanvasVersionService.get_latest_version_title, + cvs.id, + release_mode=False, + ), + } + workflow_conv["reference"] = [_normalize_agent_reference_entry(reference) for reference in workflow_conv["reference"]] + await thread_pool_exec(API4ConversationService.save, **workflow_conv) + return await _run_workflow_session( + tenant_id=tenant_id, + agent_id=agent_id, + workflow_conv=workflow_conv, + canvas=canvas, + query=query, + files=files, + inputs=inputs, + user_id=user_id, + session_id=session_id, + custom_header=custom_header, + canvas_title=canvas_title, + canvas_category=canvas_category, + return_trace=bool(req.get("return_trace", False)), + stream=req.get("stream", True), + ) return_trace = bool(req.get("return_trace", False)) if req.get("stream", True): @@ -1247,6 +1603,8 @@ def _validate_rate_limit(security_cfg): now = time.time() try: + from rag.utils.redis_conn import REDIS_CONN + res = REDIS_CONN.lua_token_bucket( keys=[key], args=[capacity, rate, now, cost], @@ -1354,6 +1712,8 @@ def _validate_jwt_auth(security_cfg): if not isinstance(cvs.dsl, str): dsl = json.dumps(cvs.dsl, ensure_ascii=False) try: + from agent.canvas import Canvas + canvas = Canvas(dsl, cvs.user_id, agent_id, canvas_id=agent_id) except Exception as e: resp=get_data_error_result(code=RetCode.BAD_REQUEST,message=str(e)) @@ -1607,6 +1967,8 @@ def validate_type(value, t): response_cfg = webhook_cfg.get("response", {}) def append_webhook_trace(agent_id: str, start_ts: float,event: dict, ttl=600): + from rag.utils.redis_conn import REDIS_CONN + key = f"webhook-trace-{agent_id}-logs" raw = REDIS_CONN.get(key) @@ -1806,6 +2168,8 @@ def decode_webhook_id(enc_id: str, webhooks: dict) -> str | None: webhook_id = request.args.get("webhook_id") key = f"webhook-trace-{agent_id}-logs" + from rag.utils.redis_conn import REDIS_CONN + raw = REDIS_CONN.get(key) if since_ts is None: diff --git a/api/apps/restful_apis/chat_api.py b/api/apps/restful_apis/chat_api.py index fab74f5c62a..baba98ae288 100644 --- a/api/apps/restful_apis/chat_api.py +++ b/api/apps/restful_apis/chat_api.py @@ -47,7 +47,8 @@ ) from api.utils.tenant_utils import ensure_tenant_model_id_for_params from common.constants import LLMType, RetCode, StatusEnum -from common.misc_utils import get_uuid +from common import settings +from common.misc_utils import get_uuid, thread_pool_exec from rag.prompts.generator import chunks_format from rag.prompts.template import load_prompt @@ -128,8 +129,9 @@ def _build_session_response(conv: dict) -> dict: return conv -def _ensure_owned_chat(chat_id): - return DialogService.query( +async def _ensure_owned_chat(chat_id): + return await thread_pool_exec( + DialogService.query, tenant_id=current_user.id, id=chat_id, status=StatusEnum.VALID.value ) @@ -151,7 +153,7 @@ def _build_default_completion_dialog(): ) -def _create_session_for_completion(chat_id, dialog, user_id): +async def _create_session_for_completion(chat_id, dialog, user_id): conv = { "id": get_uuid(), "dialog_id": chat_id, @@ -160,14 +162,14 @@ def _create_session_for_completion(chat_id, dialog, user_id): "user_id": user_id, "reference": [], } - ConversationService.save(**conv) - ok, conv_obj = ConversationService.get_by_id(conv["id"]) + await thread_pool_exec(ConversationService.save, **conv) + ok, conv_obj = await thread_pool_exec(ConversationService.get_by_id, conv["id"]) if not ok: raise LookupError("Fail to create a session!") return conv_obj -def _validate_llm_id(llm_id, tenant_id, llm_setting=None): +async def _validate_llm_id(llm_id, tenant_id, llm_setting=None): if not llm_id: return None @@ -176,7 +178,8 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None): if model_type not in {"chat", "image2text"}: model_type = "chat" - if not TenantLLMService.query( + if not await thread_pool_exec( + TenantLLMService.query, tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, @@ -186,13 +189,14 @@ def _validate_llm_id(llm_id, tenant_id, llm_setting=None): return None -def _validate_rerank_id(rerank_id, tenant_id): +async def _validate_rerank_id(rerank_id, tenant_id): if not rerank_id: return None llm_name, llm_factory = TenantLLMService.split_model_name_and_factory(rerank_id) if llm_name in _DEFAULT_RERANK_MODELS: return None - if TenantLLMService.query( + if await thread_pool_exec( + TenantLLMService.query, tenant_id=tenant_id, llm_name=llm_name, llm_factory=llm_factory, @@ -211,7 +215,7 @@ def _validate_rerank_id(rerank_id, tenant_id): # return None -def _validate_dataset_ids(dataset_ids, tenant_id): +async def _validate_dataset_ids(dataset_ids, tenant_id): if dataset_ids is None: return [] if not isinstance(dataset_ids, list): @@ -220,9 +224,9 @@ def _validate_dataset_ids(dataset_ids, tenant_id): normalized_ids = [dataset_id for dataset_id in dataset_ids if dataset_id] kbs = [] for dataset_id in normalized_ids: - if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): + if not await thread_pool_exec(KnowledgebaseService.accessible, kb_id=dataset_id, user_id=tenant_id): return f"You don't own the dataset {dataset_id}" - matches = KnowledgebaseService.query(id=dataset_id) + matches = await thread_pool_exec(KnowledgebaseService.query, id=dataset_id) if not matches: return f"You don't own the dataset {dataset_id}" kb = matches[0] @@ -268,19 +272,19 @@ async def create(): req["name"] = name if "dataset_ids" in req: - kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id) + kb_ids = await _validate_dataset_ids(req.get("dataset_ids"), current_user.id) if isinstance(kb_ids, str): return get_data_error_result(message=kb_ids) req["kb_ids"] = kb_ids req.pop("dataset_ids", None) if "llm_id" in req: - err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) + err = await _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) if err: return get_data_error_result(message=err) if "rerank_id" in req: - err = _validate_rerank_id(req.get("rerank_id"), current_user.id) + err = await _validate_rerank_id(req.get("rerank_id"), current_user.id) if err: return get_data_error_result(message=err) @@ -335,7 +339,7 @@ async def create(): @manager.route("/chats", methods=["GET"]) # noqa: F821 @login_required -def list_chats(): +async def list_chats(): chat_id = request.args.get("id") name = request.args.get("name") keywords = request.args.get("keywords", "") @@ -350,19 +354,32 @@ def list_chats(): page_number = int(request.args.get("page", 0)) items_per_page = int(request.args.get("page_size", 0)) + tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) + authorized_owner_ids = {member["tenant_id"] for member in tenants} + authorized_owner_ids.add(current_user.id) + if owner_ids: - chats, total = DialogService.get_by_tenant_ids( - owner_ids, current_user.id, 0, 0, orderby, desc, keywords, **exact_filters - ) - chats = [chat for chat in chats if chat["tenant_id"] in owner_ids] - total = len(chats) - if page_number and items_per_page: - start = (page_number - 1) * items_per_page - chats = chats[start : start + items_per_page] + requested_owner_ids = set(owner_ids) + unauthorized_owner_ids = requested_owner_ids - authorized_owner_ids + if unauthorized_owner_ids: + logging.warning( + "Rejected list_chats request: user=%s attempted unauthorized owner_ids=%s", + current_user.id, + sorted(unauthorized_owner_ids), + ) + return get_json_result( + data=False, + message="Only authorized owner_ids can be queried.", + code=RetCode.OPERATING_ERROR, + ) + effective_owner_ids = list(requested_owner_ids) else: - chats, total = DialogService.get_by_tenant_ids( - [], current_user.id, page_number, items_per_page, orderby, desc, keywords, **exact_filters - ) + effective_owner_ids = list(authorized_owner_ids) + + chats, total = await thread_pool_exec( + DialogService.get_by_tenant_ids, + effective_owner_ids, current_user.id, page_number, items_per_page, orderby, desc, keywords, **exact_filters, + ) return get_json_result( data={"chats": [_build_chat_response(chat) for chat in chats], "total": total} @@ -373,12 +390,13 @@ def list_chats(): @manager.route("/chats/", methods=["GET"]) # noqa: F821 @login_required -def get_chat(chat_id): +async def get_chat(chat_id): try: - tenants = UserTenantService.query(user_id=current_user.id) + tenants = await thread_pool_exec(UserTenantService.query, user_id=current_user.id) for tenant in tenants: - if DialogService.query( - tenant_id=tenant.tenant_id, id=chat_id, status=StatusEnum.VALID.value + if await thread_pool_exec( + DialogService.query, + tenant_id=tenant.tenant_id, id=chat_id, status=StatusEnum.VALID.value, ): break else: @@ -388,7 +406,7 @@ def get_chat(chat_id): code=RetCode.AUTHENTICATION_ERROR, ) - ok, chat = DialogService.get_by_id(chat_id) + ok, chat = await thread_pool_exec(DialogService.get_by_id, chat_id) if not ok: return get_data_error_result(message="Chat not found!") return get_json_result(data=_build_chat_response(chat)) @@ -399,7 +417,7 @@ def get_chat(chat_id): @manager.route("/chats/", methods=["PUT"]) # noqa: F821 @login_required async def update_chat(chat_id): - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result( data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR ) @@ -425,19 +443,19 @@ async def update_chat(chat_id): req["name"] = name if "dataset_ids" in req: - kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id) + kb_ids = await _validate_dataset_ids(req.get("dataset_ids"), current_user.id) if isinstance(kb_ids, str): return get_data_error_result(message=kb_ids) req["kb_ids"] = kb_ids req.pop("dataset_ids", None) if "llm_id" in req: - err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) + err = await _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) if err: return get_data_error_result(message=err) if "rerank_id" in req: - err = _validate_rerank_id(req.get("rerank_id"), current_user.id) + err = await _validate_rerank_id(req.get("rerank_id"), current_user.id) if err: return get_data_error_result(message=err) @@ -485,7 +503,7 @@ async def update_chat(chat_id): @manager.route("/chats/", methods=["PATCH"]) # noqa: F821 @login_required async def patch_chat(chat_id): - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result( data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR ) @@ -509,19 +527,19 @@ async def patch_chat(chat_id): req["name"] = name if "dataset_ids" in req: - kb_ids = _validate_dataset_ids(req.get("dataset_ids"), current_user.id) + kb_ids = await _validate_dataset_ids(req.get("dataset_ids"), current_user.id) if isinstance(kb_ids, str): return get_data_error_result(message=kb_ids) req["kb_ids"] = kb_ids req.pop("dataset_ids", None) if "llm_id" in req: - err = _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) + err = await _validate_llm_id(req.get("llm_id"), current_user.id, req.get("llm_setting")) if err: return get_data_error_result(message=err) if "rerank_id" in req: - err = _validate_rerank_id(req.get("rerank_id"), current_user.id) + err = await _validate_rerank_id(req.get("rerank_id"), current_user.id) if err: return get_data_error_result(message=err) @@ -575,8 +593,8 @@ async def patch_chat(chat_id): @manager.route("/chats/", methods=["DELETE"]) # noqa: F821 @login_required -def delete_chat(chat_id): - if not _ensure_owned_chat(chat_id): +async def delete_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result( data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR ) @@ -624,7 +642,7 @@ async def bulk_delete_chats(): unique_ids, duplicate_messages = check_duplicate_ids(ids, "chat") for chat_id in unique_ids: - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): errors.append(f"Chat({chat_id}) not found.") continue success_count += DialogService.update_by_id(chat_id, {"status": StatusEnum.INVALID.value}) @@ -644,7 +662,7 @@ async def bulk_delete_chats(): @manager.route("/chats//sessions", methods=["POST"]) # noqa: F821 @login_required async def create_session(chat_id): - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: req = await get_request_json() @@ -674,9 +692,9 @@ async def create_session(chat_id): @manager.route("/chats//sessions", methods=["GET"]) # noqa: F821 @login_required -def list_sessions(chat_id): +async def list_sessions(chat_id): try: - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result( data=False, message="No authorization.", @@ -702,15 +720,15 @@ def list_sessions(chat_id): @manager.route("/chats//sessions/", methods=["GET"]) # noqa: F821 @login_required async def get_session(chat_id, session_id): - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: - ok, conv = ConversationService.get_by_id(session_id) + ok, conv = await thread_pool_exec(ConversationService.get_by_id, session_id) if not ok: return get_data_error_result(message="Session not found!") if conv.dialog_id != chat_id: return get_data_error_result(message="Session does not belong to this chat!") - dialog = _ensure_owned_chat(chat_id) + dialog = await _ensure_owned_chat(chat_id) avatar = dialog[0].icon if dialog else "" for ref in conv.reference: if isinstance(ref, list): @@ -726,7 +744,7 @@ async def get_session(chat_id, session_id): @manager.route("/chats//sessions/", methods=["PATCH"]) # noqa: F821 @login_required async def update_session(chat_id, session_id): - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: req = await get_request_json() @@ -755,7 +773,7 @@ async def update_session(chat_id, session_id): @manager.route("/chats//sessions", methods=["DELETE"]) # noqa: F821 @login_required async def delete_sessions(chat_id): - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: req = await get_request_json() @@ -777,6 +795,17 @@ async def delete_sessions(chat_id): if not ConversationService.query(id=sid, dialog_id=chat_id): errors.append(f"The chat doesn't own the session {sid}") continue + ok, conv = ConversationService.get_by_id(sid) + if ok: + for msg in conv.message or []: + for file in msg.get("files") or []: + file_id = file.get("id") + if not file_id: + continue + try: + settings.STORAGE_IMPL.rm(f"{current_user.id}-downloads", file_id) + except Exception: + logging.warning("Failed to delete chat upload blob %s/%s", current_user.id, file_id) ConversationService.delete_by_id(sid) success_count += 1 all_errors = errors + duplicate_messages @@ -795,7 +824,7 @@ async def delete_sessions(chat_id): @manager.route("/chats//sessions//messages/", methods=["DELETE"]) # noqa: F821 @login_required async def delete_session_message(chat_id, session_id, msg_id): - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: ok, conv = ConversationService.get_by_id(session_id) @@ -819,7 +848,7 @@ async def delete_session_message(chat_id, session_id, msg_id): @manager.route("/chats//sessions//messages//feedback", methods=["PUT"]) # noqa: F821 @login_required async def update_message_feedback(chat_id, session_id, msg_id): - owned = _ensure_owned_chat(chat_id) + owned = await _ensure_owned_chat(chat_id) if not owned: return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) try: @@ -857,12 +886,14 @@ async def update_message_feedback(chat_id, session_id, msg_id): reference = conv_dict["reference"][ref_index] if reference: if isinstance(prior_thumb, bool) and prior_thumb != thumb_raw: - ChunkFeedbackService.apply_feedback( + await thread_pool_exec( + ChunkFeedbackService.apply_feedback, tenant_id=current_user.id, reference=reference, is_positive=not prior_thumb, ) - feedback_result = ChunkFeedbackService.apply_feedback( + feedback_result = await thread_pool_exec( + ChunkFeedbackService.apply_feedback, tenant_id=current_user.id, reference=reference, is_positive=thumb_raw is True, @@ -875,7 +906,7 @@ async def update_message_feedback(chat_id, session_id, msg_id): except Exception as e: logging.warning("Failed to apply chunk feedback: %s", e) - ConversationService.update_by_id(conv_dict["id"], conv_dict) + await thread_pool_exec(ConversationService.update_by_id, conv_dict["id"], conv_dict) return get_json_result(data=_build_session_response(conv_dict)) except Exception as ex: return server_error_response(ex) @@ -1053,23 +1084,23 @@ async def session_completion(chat_id_in_arg=""): return get_data_error_result(message="`chat_id` is required when `session_id` is provided.") if chat_id: - if not _ensure_owned_chat(chat_id): + if not await _ensure_owned_chat(chat_id): return get_json_result( data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR, ) - e, dia = DialogService.get_by_id(chat_id) + e, dia = await thread_pool_exec(DialogService.get_by_id, chat_id) if not e: return get_data_error_result(message="Chat not found!") if session_id: - e, conv = ConversationService.get_by_id(session_id) + e, conv = await thread_pool_exec(ConversationService.get_by_id, session_id) if not e: return get_data_error_result(message="Session not found!") if conv.dialog_id != chat_id: return get_data_error_result(message="Session does not belong to this chat!") else: - conv = _create_session_for_completion(chat_id, dia, req.get("user_id", current_user.id)) + conv = await _create_session_for_completion(chat_id, dia, req.get("user_id", current_user.id)) session_id = conv.id conv.message = deepcopy(req["messages"]) else: @@ -1085,7 +1116,7 @@ async def session_completion(chat_id_in_arg=""): conv.reference.append({"chunks": [], "doc_aggs": []}) if chat_model_id: - if not TenantLLMService.get_api_key(tenant_id=dia.tenant_id, model_name=chat_model_id): + if not await thread_pool_exec(TenantLLMService.get_api_key, tenant_id=dia.tenant_id, model_name=chat_model_id): return get_data_error_result(message=f"Cannot use specified model {chat_model_id}.") dia.llm_id = chat_model_id dia.llm_setting = chat_model_config @@ -1105,7 +1136,7 @@ async def stream(): ans = _format_answer(ans) yield "data:" + json.dumps({"code": 0, "message": "", "data": ans}, ensure_ascii=False) + "\n\n" if conv is not None: - ConversationService.update_by_id(conv.id, conv.to_dict()) + await thread_pool_exec(ConversationService.update_by_id, conv.id, conv.to_dict()) except Exception as ex: logging.exception(ex) yield "data:" + json.dumps({"code": 500, "message": str(ex), "data": {"answer": "**ERROR**: " + str(ex), "reference": []}}, ensure_ascii=False) + "\n\n" @@ -1123,7 +1154,7 @@ async def stream(): async for ans in async_chat(dia, msg, **req): answer = _format_answer(ans) if conv is not None: - ConversationService.update_by_id(conv.id, conv.to_dict()) + await thread_pool_exec(ConversationService.update_by_id, conv.id, conv.to_dict()) break return get_json_result(data=answer) except Exception as ex: diff --git a/api/apps/restful_apis/chunk_api.py b/api/apps/restful_apis/chunk_api.py index 13b5cb5801e..fe45209dd01 100644 --- a/api/apps/restful_apis/chunk_api.py +++ b/api/apps/restful_apis/chunk_api.py @@ -43,8 +43,6 @@ from common.misc_utils import thread_pool_exec from common.string_utils import is_content_empty, remove_redundant_spaces from common.tag_feature_utils import validate_tag_features -from rag.app.qa import beAdoc, rmPrefix -from rag.nlp import rag_tokenizer, search class Chunk(BaseModel): @@ -96,12 +94,24 @@ def _strip_chunk_runtime_fields(chunk): return chunk +def _get_dataset_tenant_id(dataset_id): + ok, kb = KnowledgebaseService.get_by_id(dataset_id) + if not ok: + return None + return kb.tenant_id + + @manager.route("/datasets//documents//chunks", methods=["GET"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs async def list_chunks(tenant_id, dataset_id, document_id): + from rag.nlp import search + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + dataset_tenant_id = _get_dataset_tenant_id(dataset_id) + if not dataset_tenant_id: + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") doc = DocumentService.query(id=document_id, kb_id=dataset_id) if not doc: return get_error_data_result(message=f"You don't own the document {document_id}.") @@ -122,7 +132,7 @@ async def list_chunks(tenant_id, dataset_id, document_id): res = {"total": 0, "chunks": [], "doc": _map_doc(doc)} if req.get("id"): - chunk = settings.docStoreConn.get(req.get("id"), search.index_name(tenant_id), [dataset_id]) + chunk = settings.docStoreConn.get(req.get("id"), search.index_name(dataset_tenant_id), [dataset_id]) if not chunk: return get_result(message=f"Chunk not found: {dataset_id}/{req.get('id')}", code=RetCode.DATA_ERROR) if str(chunk.get("doc_id", chunk.get("document_id"))) != str(document_id): @@ -145,10 +155,10 @@ async def list_chunks(tenant_id, dataset_id, document_id): } res["chunks"].append(final_chunk) _ = Chunk(**final_chunk) - elif settings.docStoreConn.index_exist(search.index_name(tenant_id), dataset_id): + elif settings.docStoreConn.index_exist(search.index_name(dataset_tenant_id), dataset_id): sres = await settings.retriever.search( query, - search.index_name(tenant_id), + search.index_name(dataset_tenant_id), [dataset_id], emb_mdl=None, highlight=True, @@ -181,13 +191,18 @@ async def list_chunks(tenant_id, dataset_id, document_id): @login_required @add_tenant_id_to_kwargs async def get_chunk(tenant_id, dataset_id, document_id, chunk_id): + from rag.nlp import search + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + dataset_tenant_id = _get_dataset_tenant_id(dataset_id) + if not dataset_tenant_id: + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") doc = DocumentService.query(id=document_id, kb_id=dataset_id) if not doc: return get_error_data_result(message=f"You don't own the document {document_id}.") try: - chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id]) + chunk = settings.docStoreConn.get(chunk_id, search.index_name(dataset_tenant_id), [dataset_id]) if chunk is None or str(chunk.get("doc_id", chunk.get("document_id"))) != str(document_id): return get_result(data=False, message="Chunk not found!", code=RetCode.DATA_ERROR) return get_result(data=_strip_chunk_runtime_fields(chunk)) @@ -201,8 +216,13 @@ async def get_chunk(tenant_id, dataset_id, document_id, chunk_id): @login_required @add_tenant_id_to_kwargs async def add_chunk(tenant_id, dataset_id, document_id): + from rag.nlp import rag_tokenizer, search + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + dataset_tenant_id = _get_dataset_tenant_id(dataset_id) + if not dataset_tenant_id: + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") doc = DocumentService.query(id=document_id, kb_id=dataset_id) if not doc: return get_error_data_result(message=f"You don't own the document {document_id}.") @@ -254,12 +274,12 @@ async def add_chunk(tenant_id, dataset_id, document_id): model_config = get_model_config_by_id(tenant_embd_id) else: embd_id = DocumentService.get_embd_id(document_id) - model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING.value, embd_id) + model_config = get_model_config_by_type_and_name(dataset_tenant_id, LLMType.EMBEDDING.value, embd_id) embd_mdl = TenantLLMService.model_instance(model_config) v, c = embd_mdl.encode([doc.name, req["content"] if not d["question_kwd"] else "\n".join(d["question_kwd"])]) v = 0.1 * v[0] + 0.9 * v[1] d[f"q_{len(v)}_vec"] = v.tolist() - settings.docStoreConn.insert([d], search.index_name(tenant_id), dataset_id) + settings.docStoreConn.insert([d], search.index_name(dataset_tenant_id), dataset_id) if image_base64: store_chunk_image(dataset_id, chunk_id, base64.b64decode(image_base64)) @@ -287,8 +307,13 @@ async def add_chunk(tenant_id, dataset_id, document_id): @login_required @add_tenant_id_to_kwargs async def rm_chunk(tenant_id, dataset_id, document_id): + from rag.nlp import search + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + dataset_tenant_id = _get_dataset_tenant_id(dataset_id) + if not dataset_tenant_id: + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") docs = DocumentService.query(id=document_id, kb_id=dataset_id) if not docs: return get_error_data_result(message=f"You don't own the document {document_id}.") @@ -300,8 +325,8 @@ async def rm_chunk(tenant_id, dataset_id, document_id): if not chunk_ids: if req.get("delete_all") is True: doc = docs[0] - DocumentService.delete_chunk_images(doc, tenant_id) - chunk_number = settings.docStoreConn.delete({"doc_id": document_id}, search.index_name(tenant_id), dataset_id) + DocumentService.delete_chunk_images(doc, dataset_tenant_id) + chunk_number = settings.docStoreConn.delete({"doc_id": document_id}, search.index_name(dataset_tenant_id), dataset_id) if chunk_number != 0: DocumentService.decrement_chunk_num(document_id, dataset_id, 1, chunk_number, 0) return get_result(message=f"deleted {chunk_number} chunks") @@ -310,7 +335,7 @@ async def rm_chunk(tenant_id, dataset_id, document_id): unique_chunk_ids, duplicate_messages = check_duplicate_ids(chunk_ids, "chunk") chunk_number = settings.docStoreConn.delete( {"doc_id": document_id, "id": unique_chunk_ids}, - search.index_name(tenant_id), + search.index_name(dataset_tenant_id), dataset_id, ) if chunk_number != 0: @@ -331,13 +356,19 @@ async def rm_chunk(tenant_id, dataset_id, document_id): @login_required @add_tenant_id_to_kwargs async def update_chunk(tenant_id, dataset_id, document_id, chunk_id): + from rag.app.qa import beAdoc, rmPrefix + from rag.nlp import rag_tokenizer, search + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + dataset_tenant_id = _get_dataset_tenant_id(dataset_id) + if not dataset_tenant_id: + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") doc = DocumentService.query(id=document_id, kb_id=dataset_id) if not doc: return get_error_data_result(message=f"You don't own the document {document_id}.") doc = doc[0] - chunk = settings.docStoreConn.get(chunk_id, search.index_name(tenant_id), [dataset_id]) + chunk = settings.docStoreConn.get(chunk_id, search.index_name(dataset_tenant_id), [dataset_id]) if chunk is None or str(chunk.get("doc_id", chunk.get("document_id"))) != str(document_id): return get_error_data_result(f"Can't find this chunk {chunk_id}") req = await get_request_json() @@ -387,7 +418,7 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id): model_config = get_model_config_by_id(tenant_embd_id) else: embd_id = DocumentService.get_embd_id(document_id) - model_config = get_model_config_by_type_and_name(tenant_id, LLMType.EMBEDDING.value, embd_id) + model_config = get_model_config_by_type_and_name(dataset_tenant_id, LLMType.EMBEDDING.value, embd_id) embd_mdl = TenantLLMService.model_instance(model_config) if doc.parser_id == ParserType.QA: arr = [t for t in re.split(r"[\n\t]", d["content_with_weight"]) if len(t) > 1] @@ -404,7 +435,7 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id): ) v = 0.1 * v[0] + 0.9 * v[1] if doc.parser_id != ParserType.QA else v[1] d[f"q_{len(v)}_vec"] = v.tolist() - settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(tenant_id), dataset_id) + settings.docStoreConn.update({"id": chunk_id}, d, search.index_name(dataset_tenant_id), dataset_id) if image_base64: store_chunk_image(dataset_id, chunk_id, base64.b64decode(image_base64)) return get_result() @@ -414,8 +445,13 @@ async def update_chunk(tenant_id, dataset_id, document_id, chunk_id): @login_required @add_tenant_id_to_kwargs async def switch_chunks(tenant_id, dataset_id, document_id): + from rag.nlp import search + if not KnowledgebaseService.accessible(kb_id=dataset_id, user_id=tenant_id): return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") + dataset_tenant_id = _get_dataset_tenant_id(dataset_id) + if not dataset_tenant_id: + return get_error_data_result(message=f"You don't own the dataset {dataset_id}.") req = await get_request_json() if not req.get("chunk_ids"): return get_error_data_result(message="`chunk_ids` is required.") @@ -434,7 +470,7 @@ def _switch_sync(): if not settings.docStoreConn.update( {"id": cid}, {"available_int": available_int}, - search.index_name(tenant_id), + search.index_name(dataset_tenant_id), doc.kb_id, ): return get_error_data_result(message="Index updating failure") diff --git a/api/apps/restful_apis/connector_api.py b/api/apps/restful_apis/connector_api.py index 99a58930211..89287a706d0 100644 --- a/api/apps/restful_apis/connector_api.py +++ b/api/apps/restful_apis/connector_api.py @@ -35,21 +35,52 @@ from api.apps import login_required, current_user from box_sdk_gen import BoxOAuth, OAuthConfig, GetAuthorizeUrlOptions + +LOGGER = logging.getLogger(__name__) + + +def _connector_auth_error(connector_id: str, user_id: str): + """Return the connector authorization failure response and log the denial.""" + LOGGER.warning("connector access denied: connector_id=%s user_id=%s", connector_id, user_id) + return get_json_result(data=False, message="No authorization.", code=RetCode.AUTHENTICATION_ERROR) + + @manager.route("/connectors/", methods=["PATCH"]) # noqa: F821 @login_required async def update_connector(connector_id): + """Update an accessible connector's polling configuration.""" + if not ConnectorService.accessible(connector_id, current_user.id): + return _connector_auth_error(connector_id, current_user.id) + req = await get_request_json() + if isinstance(req, dict) and isinstance(req.get("data"), dict): + req = req["data"] + e, conn = ConnectorService.get_by_id(connector_id) if not e: return get_data_error_result(message="Can't find this Connector!") + should_sleep = False if req: - conn = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req} - conn["id"] = connector_id - ConnectorService.update_by_id(connector_id, conn) - - await asyncio.sleep(1) + update_fields = {fld: req[fld] for fld in ["prune_freq", "refresh_freq", "config", "timeout_secs"] if fld in req} + if update_fields: + update_fields["id"] = connector_id + ConnectorService.update_by_id(connector_id, update_fields) + should_sleep = True + + if req.get("reschedule"): + ConnectorService.cancel_tasks(connector_id) + ConnectorService.schedule_tasks(connector_id) + elif req.get("status") in [TaskStatus.CANCEL, "CANCEL"]: + ConnectorService.cancel_tasks(connector_id) + elif req.get("status") in [TaskStatus.SCHEDULE, "SCHEDULE"]: + ConnectorService.schedule_tasks(connector_id) + + if should_sleep: + await asyncio.sleep(1) e, conn = ConnectorService.get_by_id(connector_id) + if not e: + return get_data_error_result(message="Can't find this Connector!") return get_json_result(data=conn.to_dict()) @@ -57,6 +88,7 @@ async def update_connector(connector_id): @manager.route("/connectors", methods=["POST"]) # noqa: F821 @login_required async def create_connector(): + """Create a connector owned by the current tenant.""" req = await get_request_json() if req: req["id"] = get_uuid() @@ -68,9 +100,9 @@ async def create_connector(): "input_type": InputType.POLL, "config": req["config"], "refresh_freq": int(req.get("refresh_freq", 5)), - "prune_freq": int(req.get("prune_freq", 720)), + "prune_freq": int(req.get("prune_freq", 5)), "timeout_secs": int(req.get("timeout_secs", 60 * 29)), - "status": TaskStatus.SCHEDULE, + "status": TaskStatus.UNSTART, } ConnectorService.save(**conn) @@ -83,12 +115,17 @@ async def create_connector(): @manager.route("/connectors", methods=["GET"]) # noqa: F821 @login_required def list_connector(): + """List connectors owned by the current tenant.""" return get_json_result(data=ConnectorService.list(current_user.id)) @manager.route("/connectors/", methods=["GET"]) # noqa: F821 @login_required def get_connector(connector_id): + """Return connector details when the current user can access it.""" + if not ConnectorService.accessible(connector_id, current_user.id): + return _connector_auth_error(connector_id, current_user.id) + e, conn = ConnectorService.get_by_id(connector_id) if not e: return get_data_error_result(message="Can't find this Connector!") @@ -98,27 +135,26 @@ def get_connector(connector_id): @manager.route("/connectors//logs", methods=["GET"]) # noqa: F821 @login_required def list_logs(connector_id): + """List sync logs for a connector the current user can access.""" + if not ConnectorService.accessible(connector_id, current_user.id): + return _connector_auth_error(connector_id, current_user.id) + req = request.args.to_dict(flat=True) arr, total = SyncLogsService.list_sync_tasks(connector_id, int(req.get("page", 1)), int(req.get("page_size", 15))) return get_json_result(data={"total": total, "logs": arr}) -@manager.route("/connectors//resume", methods=["POST"]) # noqa: F821 -@login_required -async def resume(connector_id): - req = await get_request_json() - if req.get("resume"): - ConnectorService.resume(connector_id, TaskStatus.SCHEDULE) - else: - ConnectorService.resume(connector_id, TaskStatus.CANCEL) - return get_json_result(data=True) - - @manager.route("/connectors//rebuild", methods=["POST"]) # noqa: F821 @login_required -@validate_request("kb_id") async def rebuild(connector_id): + """Schedule a rebuild for an accessible connector and knowledge base.""" + if not ConnectorService.accessible(connector_id, current_user.id): + return _connector_auth_error(connector_id, current_user.id) + req = await get_request_json() + if "kb_id" not in req: + return get_json_result(code=RetCode.ARGUMENT_ERROR, message="required argument is missing: kb_id") + err = ConnectorService.rebuild(req["kb_id"], connector_id, current_user.id) if err: return get_json_result(data=False, message=err, code=RetCode.SERVER_ERROR) @@ -128,11 +164,66 @@ async def rebuild(connector_id): @manager.route("/connectors/", methods=["DELETE"]) # noqa: F821 @login_required def rm_connector(connector_id): - ConnectorService.resume(connector_id, TaskStatus.CANCEL) + """Delete an accessible connector after canceling its sync tasks.""" + if not ConnectorService.accessible(connector_id, current_user.id): + return _connector_auth_error(connector_id, current_user.id) + + ConnectorService.cancel_tasks(connector_id) ConnectorService.delete_by_id(connector_id) return get_json_result(data=True) +@manager.route("/connectors//test", methods=["POST"]) # noqa: F821 +@login_required +async def test_connector(connector_id): + """Validate connector configuration without persisting changes or triggering sync. + + For the REST API connector, this uses `RestAPIConnector.validate_config` + against the existing saved configuration. + """ + if not ConnectorService.accessible(connector_id, current_user.id): + return _connector_auth_error(connector_id, current_user.id) + + from common.data_source.rest_api_connector import RestAPIConnector + from common.data_source.exceptions import ConnectorMissingCredentialError, ConnectorValidationError + + ok, conn = ConnectorService.get_by_id(connector_id) + if not ok: + return get_data_error_result(message="Can't find this Connector!") + + if conn.source != DocumentSource.REST_API: + return get_json_result( + code=RetCode.ARGUMENT_ERROR, + message="Test endpoint currently supports only REST API connectors.", + data=False, + ) + + config = conn.config or {} + credentials = config.get("credentials") or {} + + try: + await asyncio.to_thread( + RestAPIConnector.validate_config, + config=config, + credentials=credentials, + ) + except (ConnectorValidationError, ConnectorMissingCredentialError) as exc: + return get_json_result( + code=RetCode.DATA_ERROR, + message=str(exc), + data=False, + ) + except Exception as exc: + logging.exception("REST API connector validation failed: %s", exc) + return get_json_result( + code=RetCode.SERVER_ERROR, + message="REST API connector validation failed, please check logs.", + data=False, + ) + + return get_json_result(data=True) + + WEB_FLOW_TTL_SECS = 15 * 60 diff --git a/api/apps/restful_apis/dataset_api.py b/api/apps/restful_apis/dataset_api.py index 55ded90e028..df1862592cf 100644 --- a/api/apps/restful_apis/dataset_api.py +++ b/api/apps/restful_apis/dataset_api.py @@ -559,21 +559,6 @@ async def get_knowledge_graph(tenant_id, dataset_id): return get_error_data_result(message="Internal server error") -@manager.route("/datasets//graph", methods=["DELETE"]) # noqa: F821 -@login_required -@add_tenant_id_to_kwargs -def delete_knowledge_graph(tenant_id, dataset_id): - try: - success, result = dataset_api_service.delete_knowledge_graph(dataset_id, tenant_id) - if success: - return get_result(data=result) - else: - return get_result(data=False, message=result, code=RetCode.AUTHENTICATION_ERROR) - except Exception as e: - logging.exception(e) - return get_error_data_result(message="Internal server error") - - @manager.route("/datasets//index", methods=["POST"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs @@ -613,14 +598,15 @@ def trace_index(tenant_id, dataset_id): @manager.route("/datasets//", methods=["DELETE"]) # noqa: F821 +@manager.route("/datasets//index", methods=["DELETE"]) # noqa: F821 @login_required @add_tenant_id_to_kwargs -def delete_index(tenant_id, dataset_id, index_type): - index_type = index_type.lower() +def delete_index(tenant_id, dataset_id, index_type=None): + index_type = (index_type or request.args.get("type", "")).lower() if index_type not in dataset_api_service._VALID_INDEX_TYPES: return get_error_argument_result(f"Invalid index type '{index_type}'") # `wipe` controls whether the persisted index artefacts (graph rows / - # raptor summaries) are removed. Default true preserves historical + # raptor summaries) are removed. Default true preserves historical # behaviour; pass wipe=false to cancel the running task while keeping # prior progress so it can be resumed later. wipe_arg = (request.args.get("wipe", "true") or "true").strip().lower() diff --git a/api/apps/restful_apis/document_api.py b/api/apps/restful_apis/document_api.py index 7300a55a9f7..57215080d35 100644 --- a/api/apps/restful_apis/document_api.py +++ b/api/apps/restful_apis/document_api.py @@ -995,7 +995,7 @@ def _parse_doc_id_filter_with_metadata(req, kb_id): if not doc_ids_filter: return RetCode.SUCCESS, "", [], return_empty_metadata - return RetCode.SUCCESS, "", list(doc_ids_filter) if doc_ids_filter is not None else [], return_empty_metadata + return RetCode.SUCCESS, "", list(doc_ids_filter) if doc_ids_filter is not None else None, return_empty_metadata @manager.route("/datasets//documents", methods=["DELETE"]) # noqa: F821 @@ -1881,8 +1881,6 @@ async def download_attachment(tenant_id=None, doc_id=None, attachment_id=None): # Keep backward compatibility with older callers and unit tests that still # pass `attachment_id` instead of the route parameter name. doc_id = doc_id or attachment_id - if not DocumentService.accessible(doc_id, current_user.id): - return get_data_error_result(message="Document not found!") ext = request.args.get("ext", "markdown") data = await thread_pool_exec(settings.STORAGE_IMPL.get, tenant_id, doc_id) response = await make_response(data) diff --git a/api/apps/restful_apis/memory_api.py b/api/apps/restful_apis/memory_api.py index c361d816b60..1be67b8a70b 100644 --- a/api/apps/restful_apis/memory_api.py +++ b/api/apps/restful_apis/memory_api.py @@ -17,7 +17,7 @@ import os import time -from quart import request +from quart import request, g from common.constants import LLMType, RetCode from common.exceptions import ArgumentException, NotFoundException from api.apps import login_required, current_user @@ -188,8 +188,18 @@ async def add_message(): req = await get_request_json() memory_ids = req["memory_id"] + # JWT / session users cannot spoof attribution; API-key callers may supply an external subject id. + try: + trust_client_subject = bool(getattr(g, "auth_via_api_token", False)) + except RuntimeError: + trust_client_subject = False + if trust_client_subject: + effective_user_id = req.get("user_id", "") + else: + effective_user_id = current_user.id + message_dict = { - "user_id": req.get("user_id"), + "user_id": effective_user_id, "agent_id": req["agent_id"], "session_id": req["session_id"], "user_input": req["user_input"], diff --git a/api/apps/restful_apis/search_api.py b/api/apps/restful_apis/search_api.py index c56d0ff8344..7755704e4d2 100644 --- a/api/apps/restful_apis/search_api.py +++ b/api/apps/restful_apis/search_api.py @@ -15,6 +15,7 @@ # import json +import logging from quart import Response, request from api.db.services.dialog_service import async_ask @@ -75,15 +76,31 @@ def list_searches(): owner_ids = request.args.getlist("owner_ids") try: - if not owner_ids: - tenants = [] - search_apps, total = SearchService.get_by_tenant_ids(tenants, current_user.id, page_number, items_per_page, orderby, desc, keywords) + tenants = TenantService.get_joined_tenants_by_user_id(current_user.id) + authorized_owner_ids = {member["tenant_id"] for member in tenants} + authorized_owner_ids.add(current_user.id) + + if owner_ids: + requested_owner_ids = set(owner_ids) + unauthorized_owner_ids = requested_owner_ids - authorized_owner_ids + if unauthorized_owner_ids: + logging.warning( + "Rejected list_searches request: user=%s attempted unauthorized owner_ids=%s", + current_user.id, + sorted(unauthorized_owner_ids), + ) + return get_json_result( + data=False, + message="Only authorized owner_ids can be queried.", + code=RetCode.OPERATING_ERROR, + ) + effective_owner_ids = list(requested_owner_ids) else: - search_apps, total = SearchService.get_by_tenant_ids(owner_ids, current_user.id, 0, 0, orderby, desc, keywords) - search_apps = [s for s in search_apps if s["tenant_id"] in owner_ids] - total = len(search_apps) - if page_number and items_per_page: - search_apps = search_apps[(page_number - 1) * items_per_page: page_number * items_per_page] + effective_owner_ids = list(authorized_owner_ids) + + search_apps, total = SearchService.get_by_tenant_ids( + effective_owner_ids, current_user.id, page_number, items_per_page, orderby, desc, keywords + ) return get_json_result(data={"search_apps": search_apps, "total": total}) except Exception as e: return server_error_response(e) diff --git a/api/apps/restful_apis/user_api.py b/api/apps/restful_apis/user_api.py index 714453ac6fa..7ae99163d81 100644 --- a/api/apps/restful_apis/user_api.py +++ b/api/apps/restful_apis/user_api.py @@ -806,15 +806,15 @@ async def forget_reset_password(): new_pwd = req.get("new_password") new_pwd2 = req.get("confirm_new_password") - new_pwd_base64 = decrypt(new_pwd) - new_pwd_string = base64.b64decode(new_pwd_base64).decode('utf-8') - new_pwd2_string = base64.b64decode(decrypt(new_pwd2)).decode('utf-8') + if not all([email, new_pwd, new_pwd2]): + return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email and passwords are required") if not REDIS_CONN.get(_verified_key(email)): return get_json_result(data=False, code=RetCode.AUTHENTICATION_ERROR, message="email not verified") - if not all([email, new_pwd, new_pwd2]): - return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="email and passwords are required") + new_pwd_base64 = decrypt(new_pwd) + new_pwd_string = base64.b64decode(new_pwd_base64).decode('utf-8') + new_pwd2_string = base64.b64decode(decrypt(new_pwd2)).decode('utf-8') if new_pwd_string != new_pwd2_string: return get_json_result(data=False, code=RetCode.ARGUMENT_ERROR, message="passwords do not match") diff --git a/api/apps/sdk/dify_retrieval.py b/api/apps/sdk/dify_retrieval.py index e85a1d439c5..05885c380b2 100644 --- a/api/apps/sdk/dify_retrieval.py +++ b/api/apps/sdk/dify_retrieval.py @@ -15,7 +15,13 @@ # import logging -from quart import jsonify +from quart import jsonify, request +from werkzeug.exceptions import BadRequest as WerkzeugBadRequest + +try: + from quart.exceptions import BadRequest as QuartBadRequest +except ImportError: # pragma: no cover - optional dependency + QuartBadRequest = None from api.db.services.document_service import DocumentService from api.db.services.doc_metadata_service import DocMetadataService @@ -23,14 +29,86 @@ from api.db.services.llm_service import LLMBundle from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type from common.metadata_utils import meta_filter, convert_conditions -from api.utils.api_utils import apikey_required, build_error_result, get_request_json, validate_request +from api.utils.api_utils import apikey_required, build_error_result, get_request_json, get_json_result from rag.app.tag import label_question from common.constants import RetCode, LLMType from common import settings -@manager.route('/dify/retrieval', methods=['POST']) # noqa: F821 +logger = logging.getLogger(__name__) + + +async def _read_retrieval_request(): + try: + method = request.method + except RuntimeError: + # Unit tests may call the handler directly without a request context. + method = "POST" + if method == "GET": + query_args = request.args + retrieval_setting = {} + knowledge_id = query_args.get("knowledge_id") + query = query_args.get("query") + use_kg = str(query_args.get("use_kg", "")).lower() in {"1", "true", "yes", "on"} + top_k = query_args.get("top_k") + score_threshold = query_args.get("score_threshold") + try: + if top_k not in (None, ""): + retrieval_setting["top_k"] = int(top_k) + if score_threshold not in (None, ""): + retrieval_setting["score_threshold"] = float(score_threshold) + except (TypeError, ValueError): + raise ValueError("top_k must be integer and score_threshold must be numeric") + safe_query = f"len={len(query)}" if isinstance(query, str) else "len=0" + logger.debug( + "Dify retrieval GET normalization: knowledge_id=%s query=%s use_kg=%s top_k=%s score_threshold=%s", + knowledge_id, + safe_query, + use_kg, + retrieval_setting.get("top_k"), + retrieval_setting.get("score_threshold"), + ) + + req = { + "knowledge_id": knowledge_id, + "query": query, + "use_kg": use_kg, + "retrieval_setting": retrieval_setting, + } + return req + req = await get_request_json() + knowledge_id = req.get("knowledge_id") if isinstance(req, dict) else None + query = req.get("query") if isinstance(req, dict) else None + use_kg = req.get("use_kg", False) if isinstance(req, dict) else False + retrieval_setting = req.get("retrieval_setting", {}) if isinstance(req, dict) else {} + if not isinstance(retrieval_setting, dict): + retrieval_setting = {} + safe_query = f"len={len(query)}" if isinstance(query, str) else "len=0" + logger.debug( + "Dify retrieval GET normalization: knowledge_id=%s query=%s use_kg=%s top_k=%s score_threshold=%s", + knowledge_id, + safe_query, + use_kg, + retrieval_setting.get("top_k"), + retrieval_setting.get("score_threshold"), + ) + return req + + +def _parse_retrieval_options(retrieval_setting): + if retrieval_setting is None: + retrieval_setting = {} + if not isinstance(retrieval_setting, dict): + raise ValueError("retrieval_setting must be an object") + try: + similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0)) + top = int(retrieval_setting.get("top_k", 1024)) + except (TypeError, ValueError): + raise ValueError("top_k must be integer and score_threshold must be numeric") + return retrieval_setting, similarity_threshold, top + + +@manager.route('/dify/retrieval', methods=['POST', 'GET']) # noqa: F821 @apikey_required -@validate_request("knowledge_id", "query") async def retrieval(tenant_id): """ Dify-compatible retrieval API @@ -40,9 +118,34 @@ async def retrieval(tenant_id): security: - ApiKeyAuth: [] parameters: + - in: query + name: knowledge_id + required: false + type: string + description: Knowledge base ID (for GET requests) + - in: query + name: query + required: false + type: string + description: Query text (for GET requests) + - in: query + name: use_kg + required: false + type: boolean + description: Whether to use knowledge graph (for GET requests) + - in: query + name: top_k + required: false + type: integer + description: Number of results to return (for GET requests) + - in: query + name: score_threshold + required: false + type: number + description: Similarity threshold (for GET requests) - in: body name: body - required: true + required: false schema: type: object required: @@ -115,15 +218,32 @@ async def retrieval(tenant_id): 404: description: Knowledge base or document not found """ - req = await get_request_json() + parse_exception_types = (AttributeError, TypeError, ValueError, WerkzeugBadRequest) + if QuartBadRequest is not None: + parse_exception_types = parse_exception_types + (QuartBadRequest,) + try: + req = await _read_retrieval_request() + except parse_exception_types as e: + return build_error_result( + message=f"invalid or malformed arguments: {str(e)}; ", + code=RetCode.ARGUMENT_ERROR, + ) + missing = [field for field in ("knowledge_id", "query") if not req.get(field)] + if missing: + return build_error_result( + message=f"required arguments are missing: {','.join(missing)}; ", + code=RetCode.ARGUMENT_ERROR, + ) question = req["query"] kb_id = req["knowledge_id"] use_kg = req.get("use_kg", False) - retrieval_setting = req.get("retrieval_setting", {}) - similarity_threshold = float(retrieval_setting.get("score_threshold", 0.0)) - top = int(retrieval_setting.get("top_k", 1024)) - if top <= 0: - return build_error_result(message="`top_k` must be greater than 0", code=RetCode.DATA_ERROR) + try: + _, similarity_threshold, top = _parse_retrieval_options(req.get("retrieval_setting", {})) + except ValueError as e: + return build_error_result( + message=f"invalid or malformed arguments: {str(e)}; ", + code=RetCode.ARGUMENT_ERROR, + ) metadata_condition = req.get("metadata_condition", {}) or {} metas = DocMetadataService.get_flatted_meta_by_kbs([kb_id]) @@ -191,3 +311,10 @@ async def retrieval(tenant_id): ) logging.exception(e) return build_error_result(message=str(e), code=RetCode.SERVER_ERROR) + + +@manager.route('/dify/retrieval', methods=['GET']) # noqa: F821 +async def retrieval_health_check(): + """Health check endpoint for Dify external knowledge base connectivity verification.""" + return get_json_result(data=True) + diff --git a/api/apps/sdk/doc.py b/api/apps/sdk/doc.py index cf297c4b250..4498b5f5de9 100644 --- a/api/apps/sdk/doc.py +++ b/api/apps/sdk/doc.py @@ -16,9 +16,10 @@ import logging from io import BytesIO -from quart import request, send_file +from quart import send_file -from api.db.db_models import APIToken, Document, Task +from api.apps import login_required +from api.db.db_models import Document, Task from api.db.joint_services.tenant_model_service import get_model_config_by_id, get_model_config_by_type_and_name, get_tenant_default_model_by_type from api.db.services.doc_metadata_service import DocMetadataService from api.db.services.document_service import DocumentService @@ -51,8 +52,8 @@ def _enrich_chunks_with_document_metadata(chunks: list[dict], metadata_fields=No @manager.route("/datasets//documents/", methods=["GET"]) # noqa: F821 -@token_required -async def download(tenant_id, dataset_id, document_id): +@login_required +async def download(dataset_id, document_id): """ Download a document from a dataset. --- @@ -90,8 +91,6 @@ async def download(tenant_id, dataset_id, document_id): """ if not document_id: return get_error_data_result(message="Specify document_id please.") - if not KnowledgebaseService.query(id=dataset_id, tenant_id=tenant_id): - return get_error_data_result(message=f"You do not own the dataset {dataset_id}.") doc = DocumentService.query(kb_id=dataset_id, id=document_id) if not doc: return get_error_data_result(message=f"The dataset not own the document {document_id}.") @@ -110,36 +109,52 @@ async def download(tenant_id, dataset_id, document_id): ) -@manager.route("/documents/", methods=["GET"]) # noqa: F821 -async def download_doc(document_id): - token = request.headers.get("Authorization").split() - if len(token) != 2: - return get_error_data_result(message="Authorization is not valid!") - token = token[1] - logging.info("Beta API token lookup attempted for document download") - objs = APIToken.query(beta=token) - if not objs: - logging.warning("Beta API token lookup failed for document download: invalid API key") - return get_error_data_result(message='Authentication error: API key is invalid!"') - if len(objs) > 1: - logging.error("Beta API token lookup is ambiguous for document download: matches=%s", len(objs)) - return get_error_data_result(message="Authentication error: API key configuration is ambiguous.") - tenant_id = objs[0].tenant_id - logging.info("Beta API token authorized for document download: tenant_id=%s", tenant_id) +DOC_STOP_PARSING_INVALID_STATE_MESSAGE = "Can't stop parsing document that has not started or already completed" +DOC_STOP_PARSING_INVALID_STATE_ERROR_CODE = "DOC_STOP_PARSING_INVALID_STATE" +@manager.route("/documents/", methods=["GET"]) # noqa: F821 +@login_required +async def download_document(document_id): + """ + Download a document. + --- + tags: + - Documents + security: + - ApiKeyAuth: [] + produces: + - application/octet-stream + parameters: + - in: path + name: dataset_id + type: string + required: true + description: ID of the dataset. + - in: path + name: document_id + type: string + required: true + description: ID of the document to download. + - in: header + name: Authorization + type: string + required: true + description: Bearer token for authentication. + responses: + 200: + description: Document file stream. + schema: + type: file + 400: + description: Error message. + schema: + type: object + """ if not document_id: return get_error_data_result(message="Specify document_id please.") doc = DocumentService.query(id=document_id) if not doc: return get_error_data_result(message=f"The dataset not own the document {document_id}.") - if not KnowledgebaseService.query(id=doc[0].kb_id, tenant_id=tenant_id): - logging.warning( - "cross-tenant access denied for document download: tenant_id=%s kb_id=%s document_id=%s", - tenant_id, - doc[0].kb_id, - document_id, - ) - return get_error_data_result(message="You do not have access to this document.") # The process of downloading doc_id, doc_location = File2DocumentService.get_storage_address(doc_id=document_id) # minio address file_stream = settings.STORAGE_IMPL.get(doc_id, doc_location) @@ -154,11 +169,6 @@ async def download_doc(document_id): mimetype="application/octet-stream", # Set a default MIME type ) - -DOC_STOP_PARSING_INVALID_STATE_MESSAGE = "Can't stop parsing document that has not started or already completed" -DOC_STOP_PARSING_INVALID_STATE_ERROR_CODE = "DOC_STOP_PARSING_INVALID_STATE" - - @manager.route("/datasets//chunks", methods=["POST"]) # noqa: F821 @token_required async def parse(tenant_id, dataset_id): @@ -492,7 +502,12 @@ async def retrieval_test(tenant_id): rerank_mdl = None if req.get("tenant_rerank_id"): - rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"]) + allowed_rerank_tenant_ids = {tenant_id, *[dataset.tenant_id for dataset in kbs]} + rerank_model_config = get_model_config_by_id( + req["tenant_rerank_id"], + allowed_tenant_ids=allowed_rerank_tenant_ids, + requester_tenant_id=tenant_id, + ) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) elif req.get("rerank_id"): rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK, req["rerank_id"]) diff --git a/api/apps/sdk/session.py b/api/apps/sdk/session.py index 11960dcf65c..394ba71e905 100644 --- a/api/apps/sdk/session.py +++ b/api/apps/sdk/session.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import copy import json import re @@ -36,7 +37,7 @@ from api.db.services.user_service import UserTenantService from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_by_id, \ get_model_config_by_type_and_name -from common.misc_utils import get_uuid +from common.misc_utils import get_uuid, thread_pool_exec from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_json_result, \ get_result, get_request_json, server_error_response, token_required, validate_request from rag.app.tag import label_question @@ -58,11 +59,11 @@ async def create_agent_session(tenant_id, agent_id): user_id = req.get("user_id") or request.args.get("user_id", tenant_id) release_mode = bool(req.get("release", request.args.get("release", False))) - if not UserCanvasService.query(user_id=tenant_id, id=agent_id): + if not await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id): return get_error_data_result("You cannot access the agent.") try: - cvs, dsl = UserCanvasService.get_agent_dsl_with_release(agent_id, release_mode, tenant_id) + cvs, dsl = await thread_pool_exec(UserCanvasService.get_agent_dsl_with_release, agent_id, release_mode, tenant_id) except LookupError: return get_error_data_result("Agent not found.") except PermissionError as e: @@ -74,7 +75,7 @@ async def create_agent_session(tenant_id, agent_id): cvs.dsl = json.loads(str(canvas)) # Get the version title based on release_mode - version_title = UserCanvasVersionService.get_latest_version_title(cvs.id, release_mode=release_mode) + version_title = await thread_pool_exec(UserCanvasVersionService.get_latest_version_title, cvs.id, release_mode=release_mode) conv = { "id": session_id, "dialog_id": cvs.id, @@ -84,7 +85,7 @@ async def create_agent_session(tenant_id, agent_id): "dsl": cvs.dsl, "version_title": version_title } - API4ConversationService.save(**conv) + await thread_pool_exec(API4ConversationService.save, **conv) conv["agent_id"] = conv.pop("dialog_id") return get_result(data=conv) @@ -95,7 +96,7 @@ async def delete_agent_session(tenant_id, agent_id): errors = [] success_count = 0 req = await get_request_json() - cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id) + cvs = await thread_pool_exec(UserCanvasService.query, user_id=tenant_id, id=agent_id) if not cvs: return get_error_data_result(f"You don't own the agent {agent_id}") @@ -105,7 +106,7 @@ async def delete_agent_session(tenant_id, agent_id): ids = req.get("ids") if not ids: if req.get("delete_all") is True: - ids = [conv.id for conv in API4ConversationService.query(dialog_id=agent_id)] + ids = [conv.id for conv in await thread_pool_exec(API4ConversationService.query, dialog_id=agent_id)] if not ids: return get_result() else: @@ -117,11 +118,11 @@ async def delete_agent_session(tenant_id, agent_id): conv_list = unique_conv_ids for session_id in conv_list: - conv = API4ConversationService.query(id=session_id, dialog_id=agent_id) + conv = await thread_pool_exec(API4ConversationService.query, id=session_id, dialog_id=agent_id) if not conv: errors.append(f"The agent doesn't own the session {session_id}") continue - API4ConversationService.delete_by_id(session_id) + await thread_pool_exec(API4ConversationService.delete_by_id, session_id) success_count += 1 if errors: @@ -151,7 +152,7 @@ async def chatbot_completions(dialog_id): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') tenant_id = objs[0].tenant_id @@ -226,11 +227,11 @@ async def chatbots_inputs(dialog_id): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') tenant_id = objs[0].tenant_id - exists, dialog = DialogService.get_by_id(dialog_id) + exists, dialog = await thread_pool_exec(DialogService.get_by_id, dialog_id) if (not exists or getattr(dialog, "tenant_id", None) != tenant_id or str(getattr(dialog, "status", "")) != StatusEnum.VALID.value): @@ -264,7 +265,7 @@ async def agent_bot_completions(agent_id): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') @@ -293,13 +294,56 @@ async def stream(): return resp try: + full_content = "" + reference = {} + structured_output = {} + final_ans = {} async for answer in agent_completion(objs[0].tenant_id, agent_id, **req): - return get_result(data=answer) + # agent_completion yields SSE-formatted strings. A single yielded + # chunk can contain multiple "data:..." frames separated by "\n\n" + # plus blank or comment lines, so parse line-by-line rather than + # assuming one frame per chunk. + if not isinstance(answer, str): + continue + for line in answer.splitlines(): + line = line.strip() + if not line.startswith("data:"): + continue + payload = line[len("data:"):].strip() + if not payload: + continue + try: + ans = json.loads(payload) + except Exception as e: + logging.debug("agent_bot_completions: skipping malformed SSE frame: %s", e) + continue + event = ans.get("event") + if event == "message": + full_content += ans.get("data", {}).get("content", "") or "" + if ans.get("data", {}).get("reference"): + reference.update(ans["data"]["reference"]) + if event == "node_finished": + data = ans.get("data", {}) + node_out = data.get("outputs") or {} + component_id = data.get("component_id") + if component_id is not None and "structured" in node_out: + structured_output[component_id] = copy.deepcopy(node_out["structured"]) + final_ans = ans + + if not final_ans: + return get_result(data={}) + + if "data" not in final_ans or not isinstance(final_ans["data"], dict): + final_ans["data"] = {} + final_ans["data"]["content"] = full_content + final_ans["data"]["reference"] = reference + if structured_output: + final_ans["data"]["structured"] = structured_output + return get_result(data=final_ans) except Exception as e: logging.exception(e) return get_error_data_result(message=str(e) or "Unknown error") - return None @manager.route("/agentbots//inputs", methods=["GET"]) # noqa: F821 async def begin_inputs(agent_id): @@ -307,11 +351,11 @@ async def begin_inputs(agent_id): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') - e, cvs = UserCanvasService.get_by_id(agent_id) + e, cvs = await thread_pool_exec(UserCanvasService.get_by_id, agent_id) if not e: return get_error_data_result(f"Can't find agent by ID: {agent_id}") @@ -328,7 +372,7 @@ async def ask_about_embedded(): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') @@ -338,7 +382,7 @@ async def ask_about_embedded(): search_id = req.get("search_id", "") search_config = {} if search_id: - if search_app := SearchService.get_detail(search_id): + if search_app := await thread_pool_exec(SearchService.get_detail, search_id): search_config = search_app.get("search_config", {}) async def stream(): @@ -367,7 +411,7 @@ async def retrieval_test_embedded(): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') @@ -406,16 +450,16 @@ async def _retrieval(): chat_mdl = None if req.get("search_id", ""): nonlocal search_config - detail = SearchService.get_detail(req.get("search_id", "")) + detail = await thread_pool_exec(SearchService.get_detail, req.get("search_id", "")) if detail: search_config = detail.get("search_config", {}) meta_data_filter = search_config.get("meta_data_filter", {}) if meta_data_filter.get("method") in ["auto", "semi_auto"]: chat_id = search_config.get("chat_id", "") if chat_id: - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id) + chat_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.CHAT, chat_id) else: - chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(tenant_id, chat_model_config) # Apply search_config settings if not explicitly provided in request if not req.get("similarity_threshold"): @@ -429,7 +473,7 @@ async def _retrieval(): else: meta_data_filter = req.get("meta_data_filter") or {} if meta_data_filter.get("method") in ["auto", "semi_auto"]: - chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(tenant_id, chat_model_config) if meta_data_filter: @@ -443,38 +487,44 @@ async def _retrieval(): metas_loader=lambda: DocMetadataService.get_flatted_meta_by_kbs(kb_ids), ) - tenants = UserTenantService.query(user_id=tenant_id) + tenants = await thread_pool_exec(UserTenantService.query, user_id=tenant_id) for kb_id in kb_ids: for tenant in tenants: - if KnowledgebaseService.query(tenant_id=tenant.tenant_id, id=kb_id): + if await thread_pool_exec(KnowledgebaseService.query, tenant_id=tenant.tenant_id, id=kb_id): tenant_ids.append(tenant.tenant_id) break else: return get_json_result(data=False, message="Only owner of dataset authorized for this operation.", code=RetCode.OPERATING_ERROR) - e, kb = KnowledgebaseService.get_by_id(kb_ids[0]) + e, kb = await thread_pool_exec(KnowledgebaseService.get_by_id, kb_ids[0]) if not e: return get_error_data_result(message="Knowledgebase not found!") if langs: _question = await cross_languages(kb.tenant_id, None, _question, langs) if kb.tenant_embd_id: - embd_model_config = get_model_config_by_id(kb.tenant_embd_id) + embd_model_config = await thread_pool_exec(get_model_config_by_id, kb.tenant_embd_id) else: - embd_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) + embd_model_config = await thread_pool_exec(get_model_config_by_type_and_name, kb.tenant_id, LLMType.EMBEDDING, kb.embd_id) embd_mdl = LLMBundle(kb.tenant_id, embd_model_config) rerank_mdl = None if tenant_rerank_id: - rerank_model_config = get_model_config_by_id(tenant_rerank_id) + allowed_rerank_tenant_ids = {tenant_id, *tenant_ids} + rerank_model_config = await thread_pool_exec( + get_model_config_by_id, + tenant_rerank_id, + allowed_rerank_tenant_ids, + tenant_id, + ) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) elif rerank_id: - rerank_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.RERANK, rerank_id) + rerank_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.RERANK, rerank_id) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) if req.get("keyword", False): - default_chat_model = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) + default_chat_model = await thread_pool_exec(get_tenant_default_model_by_type, kb.tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(kb.tenant_id, default_chat_model) _question += await keyword_extraction(chat_mdl, _question) @@ -484,7 +534,7 @@ async def _retrieval(): local_doc_ids, rerank_mdl=rerank_mdl, highlight=req.get("highlight"), rank_feature=labels ) if use_kg: - default_chat_model = get_tenant_default_model_by_type(kb.tenant_id, LLMType.CHAT) + default_chat_model = await thread_pool_exec(get_tenant_default_model_by_type, kb.tenant_id, LLMType.CHAT) ck = await settings.kg_retriever.retrieval(_question, tenant_ids, kb_ids, embd_mdl, LLMBundle(kb.tenant_id, default_chat_model)) if ck["content_with_weight"]: @@ -517,7 +567,7 @@ async def related_questions_embedded(): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') @@ -529,16 +579,16 @@ async def related_questions_embedded(): search_id = req.get("search_id", "") search_config = {} if search_id: - if search_app := SearchService.get_detail(search_id): + if search_app := await thread_pool_exec(SearchService.get_detail, search_id): search_config = search_app.get("search_config", {}) question = req["question"] chat_id = search_config.get("chat_id", "") if chat_id: - chat_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, chat_id) + chat_model_config = await thread_pool_exec(get_model_config_by_type_and_name, tenant_id, LLMType.CHAT, chat_id) else: - chat_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.CHAT) + chat_model_config = await thread_pool_exec(get_tenant_default_model_by_type, tenant_id, LLMType.CHAT) chat_mdl = LLMBundle(tenant_id, chat_model_config) gen_conf = search_config.get("llm_setting", {"temperature": 0.9}) @@ -565,7 +615,7 @@ async def detail_share_embedded(): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') @@ -574,15 +624,15 @@ async def detail_share_embedded(): if not tenant_id: return get_error_data_result(message="permission denined.") try: - tenants = UserTenantService.query(user_id=tenant_id) + tenants = await thread_pool_exec(UserTenantService.query, user_id=tenant_id) for tenant in tenants: - if SearchService.query(tenant_id=tenant.tenant_id, id=search_id): + if await thread_pool_exec(SearchService.query, tenant_id=tenant.tenant_id, id=search_id): break else: return get_json_result(data=False, message="Has no permission for this operation.", code=RetCode.OPERATING_ERROR) - search = SearchService.get_detail(search_id) + search = await thread_pool_exec(SearchService.get_detail, search_id) if not search: return get_error_data_result(message="Can't find this Search App!") return get_json_result(data=search) @@ -597,7 +647,7 @@ async def mindmap(): if len(token) != 2: return get_error_data_result(message='Authorization is not valid!') token = token[1] - objs = APIToken.query(beta=token) + objs = await thread_pool_exec(APIToken.query, beta=token) if not objs: return get_error_data_result(message='Authentication error: API key is invalid!"') @@ -605,7 +655,7 @@ async def mindmap(): req = await get_request_json() search_id = req.get("search_id", "") - search_app = SearchService.get_detail(search_id) if search_id else {} + search_app = await thread_pool_exec(SearchService.get_detail, search_id) if search_id else {} mind_map =await gen_mindmap(req["question"], req["kb_ids"], tenant_id, search_app.get("search_config", {})) if "error" in mind_map: diff --git a/api/apps/services/dataset_api_service.py b/api/apps/services/dataset_api_service.py index 9e49596539c..8d5f512a358 100644 --- a/api/apps/services/dataset_api_service.py +++ b/api/apps/services/dataset_api_service.py @@ -452,6 +452,10 @@ def delete_knowledge_graph(dataset_id: str, tenant_id: str): # Wiping the graph invalidates any phase-completion markers used to # short-circuit resolution / community detection on resume. clear_phase_markers(dataset_id) + KnowledgebaseService.update_by_id( + kb.id, + {"graphrag_task_id": "", "graphrag_task_finish_at": None}, + ) return True, True @@ -594,9 +598,8 @@ def aggregate_tags(dataset_ids: list[str], tenant_id: str): merged = {} for kb_tenant_id, kb_ids in dataset_ids_by_tenant.items(): - for bucket in settings.retriever.all_tags(kb_tenant_id, kb_ids): - tag = bucket["value"] - merged[tag] = merged.get(tag, 0) + bucket["count"] + for tag, count in settings.retriever.all_tags(kb_tenant_id, kb_ids): + merged[tag] = merged.get(tag, 0) + count return True, [{"value": tag, "count": count} for tag, count in merged.items()] @@ -1006,7 +1009,12 @@ async def search(dataset_id: str, tenant_id: str, req: dict): rerank_mdl = None if req.get("tenant_rerank_id"): - rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"]) + allowed_rerank_tenant_ids = {tenant_id, kb.tenant_id} + rerank_model_config = get_model_config_by_id( + req["tenant_rerank_id"], + allowed_tenant_ids=allowed_rerank_tenant_ids, + requester_tenant_id=tenant_id, + ) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) elif req.get("rerank_id"): rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"]) @@ -1335,6 +1343,7 @@ async def search_datasets(tenant_id: str, req: dict): chat_mdl = LLMBundle(tenant_id, chat_model_config) if meta_data_filter: + logging.debug(f"Metadata filter: {meta_data_filter}, question: {question}, chat_mdl={'None' if chat_mdl is None else chat_mdl.llm_name}") local_doc_ids = await apply_meta_data_filter( meta_data_filter, None, @@ -1368,7 +1377,12 @@ async def search_datasets(tenant_id: str, req: dict): rerank_mdl = None if req.get("tenant_rerank_id"): - rerank_model_config = get_model_config_by_id(req["tenant_rerank_id"]) + allowed_rerank_tenant_ids = {tenant_id, *[dataset.tenant_id for dataset in kbs]} + rerank_model_config = get_model_config_by_id( + req["tenant_rerank_id"], + allowed_tenant_ids=allowed_rerank_tenant_ids, + requester_tenant_id=tenant_id, + ) rerank_mdl = LLMBundle(kb.tenant_id, rerank_model_config) elif req.get("rerank_id"): rerank_model_config = get_model_config_by_type_and_name(kb.tenant_id, LLMType.RERANK.value, req["rerank_id"]) diff --git a/api/apps/services/document_api_service.py b/api/apps/services/document_api_service.py index 59abbd25072..63c71ff4a24 100644 --- a/api/apps/services/document_api_service.py +++ b/api/apps/services/document_api_service.py @@ -122,13 +122,16 @@ def reset_document_for_reparse(doc, tenant_id, parser_id=None, pipeline_id=None) # Delete chunks from document store if doc.token_num > 0: - e = DocumentService.increment_chunk_num( - doc.id, - doc.kb_id, - doc.token_num * -1, - doc.chunk_num * -1, - doc.process_duration * -1, - ) + try: + e = DocumentService.increment_chunk_num( + doc.id, + doc.kb_id, + doc.token_num * -1, + doc.chunk_num * -1, + doc.process_duration * -1, + ) + except LookupError: + return get_error_data_result(message="Document not found!") if not e: return get_error_data_result(message="Document not found!") settings.docStoreConn.delete({"doc_id": doc.id}, search.index_name(tenant_id), doc.kb_id) diff --git a/api/apps/services/memory_api_service.py b/api/apps/services/memory_api_service.py index 9040f0ce445..53cff623ceb 100644 --- a/api/apps/services/memory_api_service.py +++ b/api/apps/services/memory_api_service.py @@ -21,10 +21,11 @@ from api.db.services.task_service import TaskService from api.db.joint_services.memory_message_service import get_memory_size_cache, judge_system_prompt_is_default, queue_save_to_memory_task, query_message from api.utils.memory_utils import format_ret_data_from_memory, get_memory_type_human +from api.utils.tenant_utils import ensure_tenant_model_id_for_params from api.constants import MEMORY_NAME_LIMIT, MEMORY_SIZE_LIMIT from memory.services.messages import MessageService from memory.utils.prompt_util import PromptAssembler -from common.constants import MemoryType, ForgettingPolicy +from common.constants import MemoryType, ForgettingPolicy, LLMType from common.exceptions import ArgumentException, NotFoundException from common.time_utils import current_timestamp, timestamp_to_date @@ -131,6 +132,9 @@ async def update_memory(memory_id: str, new_memory_setting: dict): "user_prompt": str } """ + current_memory = _require_memory_access(memory_id) + owner_tenant_id = current_memory.tenant_id + update_dict = {} # check name length if "name" in new_memory_setting: @@ -146,14 +150,32 @@ async def update_memory(memory_id: str, new_memory_setting: dict): if new_memory_setting["permissions"] not in [e.value for e in TenantPermission]: raise ArgumentException(f"Unknown permission '{new_memory_setting['permissions']}'.") update_dict["permissions"] = new_memory_setting["permissions"] - if new_memory_setting.get("llm_id"): - update_dict["llm_id"] = new_memory_setting["llm_id"] - if new_memory_setting.get("embd_id"): - update_dict["embd_id"] = new_memory_setting["embd_id"] - if new_memory_setting.get("tenant_llm_id"): - update_dict["tenant_llm_id"] = new_memory_setting["tenant_llm_id"] - if new_memory_setting.get("tenant_embd_id"): - update_dict["tenant_embd_id"] = new_memory_setting["tenant_embd_id"] + if ("tenant_llm_id" in new_memory_setting or "tenant_embd_id" in new_memory_setting) and not ( + new_memory_setting.get("llm_id") or new_memory_setting.get("embd_id") + ): + raise ArgumentException( + "Do not set tenant_llm_id or tenant_embd_id directly; update llm_id and/or embd_id instead." + ) + if new_memory_setting.get("llm_id") or new_memory_setting.get("embd_id"): + merged = { + "llm_id": new_memory_setting.get("llm_id") or current_memory.llm_id, + "embd_id": new_memory_setting.get("embd_id") or current_memory.embd_id, + } + merged = ensure_tenant_model_id_for_params(owner_tenant_id, merged) + if not merged.get("tenant_llm_id"): + raise ArgumentException( + f"Tenant Model with name {merged['llm_id']} and type {LLMType.CHAT.value} not found" + ) + if new_memory_setting.get("embd_id") and not merged.get("tenant_embd_id"): + raise ArgumentException( + f"Tenant Model with name {merged['embd_id']} and type {LLMType.EMBEDDING.value} not found" + ) + if new_memory_setting.get("llm_id"): + update_dict["llm_id"] = merged["llm_id"] + if new_memory_setting.get("embd_id"): + update_dict["embd_id"] = merged["embd_id"] + update_dict["tenant_llm_id"] = merged["tenant_llm_id"] + update_dict["tenant_embd_id"] = merged.get("tenant_embd_id") if new_memory_setting.get("memory_type"): memory_type = set(new_memory_setting["memory_type"]) invalid_type = memory_type - {e.name.lower() for e in MemoryType} @@ -180,7 +202,6 @@ async def update_memory(memory_id: str, new_memory_setting: dict): for field in ["avatar", "description", "system_prompt", "user_prompt"]: if field in new_memory_setting: update_dict[field] = new_memory_setting[field] - current_memory = _require_memory_access(memory_id) memory_dict = current_memory.to_dict() memory_dict.update({"memory_type": get_memory_type_human(current_memory.memory_type)}) diff --git a/api/db/__init__.py b/api/db/__init__.py index 6d7ed9fcb97..6aa7c5bbf07 100644 --- a/api/db/__init__.py +++ b/api/db/__init__.py @@ -15,7 +15,7 @@ # from enum import IntEnum -from strenum import StrEnum +from enum import StrEnum class UserTenantRole(StrEnum): diff --git a/api/db/db_models.py b/api/db/db_models.py index 5fe64586c04..a207b00788f 100644 --- a/api/db/db_models.py +++ b/api/db/db_models.py @@ -1051,6 +1051,7 @@ class UserCanvas(DataBaseModel): description = TextField(null=True, help_text="Canvas description") canvas_type = CharField(max_length=32, null=True, help_text="Canvas type", index=True) canvas_category = CharField(max_length=32, null=False, default="agent_canvas", help_text="Canvas category: agent_canvas|dataflow_canvas", index=True) + tags = CharField(max_length=512, null=False, default="", help_text="Comma-separated tags for organizing agents", index=True) dsl = JSONField(null=True, default={}) class Meta: @@ -1223,6 +1224,7 @@ def python_value(self, value: str|None) -> datetime|None: class SyncLogs(DataBaseModel): id = CharField(max_length=32, primary_key=True) connector_id = CharField(max_length=32, index=True) + task_type = CharField(max_length=32, null=False, default="sync", index=True) status = CharField(max_length=128, null=False, help_text="Processing status", index=True) from_beginning = CharField(max_length=1, null=True, help_text="", default="0", index=False) new_docs_indexed = IntegerField(default=0, index=False) @@ -1631,6 +1633,7 @@ def migrate_db(): alter_db_add_column(migrator, "llm_factories", "rank", IntegerField(default=0, index=False)) alter_db_add_column(migrator, "api_4_conversation", "name", CharField(max_length=255, null=True, help_text="conversation name", index=False)) alter_db_add_column(migrator, "api_4_conversation", "exp_user_id", CharField(max_length=255, null=True, help_text="exp_user_id", index=True)) + alter_db_add_column(migrator, "sync_logs", "task_type", CharField(max_length=32, null=False, default="sync", index=True)) # Migrate system_settings.value from CharField to TextField for longer sandbox configs alter_db_column_type(migrator, "system_settings", "value", TextField(null=False, help_text="Configuration value (JSON, string, etc.)")) alter_db_add_column(migrator, "document", "content_hash", CharField(max_length=32, null=True, help_text="xxhash128 of document content for change detection", default="", index=True)) @@ -1647,6 +1650,7 @@ def migrate_db(): alter_db_add_column(migrator, "memory", "tenant_embd_id", IntegerField(null=True, help_text="id in tenant_llm", index=True)) alter_db_add_column(migrator, "memory", "tenant_llm_id", IntegerField(null=True, help_text="id in tenant_llm", index=True)) alter_db_add_column(migrator, "user_canvas_version", "release", BooleanField(null=False, help_text="is released", default=False, index=True)) + alter_db_add_column(migrator, "user_canvas", "tags", CharField(max_length=512, null=False, default="", help_text="Comma-separated tags for organizing agents", index=True)) alter_db_add_column(migrator, "api_4_conversation", "version_title", CharField(max_length=255, null=True, help_text="canvas version title when session created", index=False)) alter_db_column_type(migrator, "document", "size", BigIntegerField(default=0, index=True)) alter_db_column_type(migrator, "file", "size", BigIntegerField(default=0, index=True)) diff --git a/api/db/joint_services/memory_message_service.py b/api/db/joint_services/memory_message_service.py index 4765b2bdbb6..1a6da3a8d6b 100644 --- a/api/db/joint_services/memory_message_service.py +++ b/api/db/joint_services/memory_message_service.py @@ -154,7 +154,11 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict, else: user_prompts.append({"role": "user", "content": PromptAssembler.assemble_user_prompt(conversation_content, conversation_time, conversation_time)}) if tenant_llm_id: - llm_config = get_model_config_by_id(tenant_llm_id) + llm_config = get_model_config_by_id( + tenant_llm_id, + allowed_tenant_ids=tenant_id, + requester_tenant_id=tenant_id, + ) else: llm_config = get_model_config_by_type_and_name(tenant_id, LLMType.CHAT, llm_id) llm = LLMBundle(tenant_id, llm_config) @@ -174,7 +178,11 @@ async def extract_by_llm(tenant_id: str, tenant_llm_id: int, extract_conf: dict, async def embed_and_save(memory, message_list: list[dict], task_id: str=None): if memory.tenant_embd_id: - embd_model_config = get_model_config_by_id(memory.tenant_embd_id) + embd_model_config = get_model_config_by_id( + memory.tenant_embd_id, + allowed_tenant_ids=memory.tenant_id, + requester_tenant_id=memory.tenant_id, + ) else: embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) embedding_model = LLMBundle(memory.tenant_id, embd_model_config) @@ -248,7 +256,11 @@ def query_message(filter_dict: dict, params: dict): question = question.strip() memory = memory_list[0] if memory.tenant_embd_id: - embd_model_config = get_model_config_by_id(memory.tenant_embd_id) + embd_model_config = get_model_config_by_id( + memory.tenant_embd_id, + allowed_tenant_ids=memory.tenant_id, + requester_tenant_id=memory.tenant_id, + ) else: embd_model_config = get_model_config_by_type_and_name(memory.tenant_id, LLMType.EMBEDDING, memory.embd_id) embd_model = LLMBundle(memory.tenant_id, embd_model_config) diff --git a/api/db/joint_services/tenant_model_service.py b/api/db/joint_services/tenant_model_service.py index 645d7563812..677bfcaaafc 100644 --- a/api/db/joint_services/tenant_model_service.py +++ b/api/db/joint_services/tenant_model_service.py @@ -24,10 +24,29 @@ logger = logging.getLogger(__name__) -def get_model_config_by_id(tenant_model_id: int) -> dict: +def get_model_config_by_id( + tenant_model_id: int, + allowed_tenant_ids: str | list[str] | set[str] | tuple[str, ...] | None = None, + requester_tenant_id: str | None = None, +) -> dict: found, model_config = TenantLLMService.get_by_id(tenant_model_id) if not found: raise LookupError(f"Tenant Model with id {tenant_model_id} not found") + if allowed_tenant_ids is not None: + if isinstance(allowed_tenant_ids, str): + allowed_tenant_ids = {allowed_tenant_ids} + else: + allowed_tenant_ids = {str(tenant_id) for tenant_id in allowed_tenant_ids if tenant_id} + if str(model_config.tenant_id) not in allowed_tenant_ids: + logger.warning( + "Denied tenant model access: tenant_model_id=%s model_tenant_id=%s " + "allowed_tenant_ids=%s requester_tenant_id=%s", + tenant_model_id, + model_config.tenant_id, + sorted(allowed_tenant_ids), + requester_tenant_id, + ) + raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") config_dict = model_config.to_dict() api_key, is_tools, api_key_payload = TenantLLMService._decode_api_key_config(config_dict.get("api_key", "")) config_dict["api_key"] = api_key diff --git a/api/db/services/canvas_service.py b/api/db/services/canvas_service.py index 4a5734e155d..8c7fe4748ff 100644 --- a/api/db/services/canvas_service.py +++ b/api/db/services/canvas_service.py @@ -16,6 +16,8 @@ import json import logging import time +from functools import reduce +from operator import or_ from uuid import uuid4 from agent.canvas import Canvas from api.db import CanvasCategory, TenantPermission @@ -23,7 +25,7 @@ from api.db.services.api_service import API4ConversationService from api.db.services.common_service import CommonService from api.db.services.user_canvas_version import UserCanvasVersionService -from common.misc_utils import get_uuid +from common.misc_utils import get_uuid, thread_pool_exec from api.utils.api_utils import get_data_openai import tiktoken from peewee import fn @@ -149,6 +151,7 @@ def get_by_tenant_ids( desc, keywords, canvas_category=None, + tags=None, ): fields = [ cls.model.id, @@ -161,6 +164,7 @@ def get_by_tenant_ids( User.avatar.alias('tenant_avatar'), cls.model.update_time, cls.model.canvas_category, + cls.model.tags, ] if keywords: agents = cls.model.select(*fields).join(User, on=(cls.model.user_id == User.id)).where( @@ -173,6 +177,13 @@ def get_by_tenant_ids( ) if canvas_category: agents = agents.where(cls.model.canvas_category == canvas_category) + if tags: + tag_list = [t.strip() for t in tags if t and t.strip()] if isinstance(tags, (list, tuple)) else [t.strip() for t in str(tags).split(",") if t.strip()] + if tag_list: + # Wrap value with commas so 'ml' doesn't match 'ml-ops'. + wrapped = fn.CONCAT(",", cls.model.tags, ",") + clauses = [wrapped.contains(f",{t},") for t in tag_list] + agents = agents.where(reduce(or_, clauses)) if desc: agents = agents.order_by(cls.model.getter_by(orderby).desc()) else: @@ -199,6 +210,69 @@ def get_by_tenant_ids( return agents_list, count + @classmethod + @DB.connection_context() + def list_tags(cls, joined_tenant_ids, user_id, canvas_category=None): + """Return {tag: agent_count} aggregated across agents visible to the user.""" + query = cls.model.select(cls.model.tags).where( + ((cls.model.user_id.in_(joined_tenant_ids)) & (cls.model.permission == TenantPermission.TEAM.value)) | (cls.model.user_id == user_id) + ) + if canvas_category: + query = query.where(cls.model.canvas_category == canvas_category) + + counts: dict[str, int] = {} + for row in query.dicts(): + for t in (row.get("tags") or "").split(","): + t = t.strip() + if t: + counts[t] = counts.get(t, 0) + 1 + logging.info( + "UserCanvasService.list_tags user=%s canvas_category=%s tags_count=%d", + user_id, + canvas_category, + len(counts), + ) + return counts + + # Tag storage is a single comma-separated CharField(max_length=512); + # commas inside a tag would corrupt the encoding, so strip them on write. + TAGS_FIELD_MAX = 512 + TAG_MAX_LEN = 64 + + @classmethod + @DB.connection_context() + def update_tags(cls, canvas_id, tags): + """Persist a normalized comma-separated tag string for the given canvas.""" + if isinstance(tags, (list, tuple)): + cleaned = [str(t).replace(",", " ").strip() for t in tags if t and str(t).strip()] + else: + cleaned = [t.strip() for t in str(tags or "").split(",") if t.strip()] + # Dedupe (case-insensitive, preserve order), cap individual tag length, + # then truncate the joined value so it always fits the column. + seen = set() + normalized = [] + used = 0 + for t in cleaned: + t = t[: cls.TAG_MAX_LEN] + key = t.lower() + if key in seen: + continue + extra = len(t) + (1 if normalized else 0) + if used + extra > cls.TAGS_FIELD_MAX: + break + seen.add(key) + normalized.append(t) + used += extra + value = ",".join(normalized) + rows_affected = cls.model.update(tags=value).where(cls.model.id == canvas_id).execute() + logging.info( + "UserCanvasService.update_tags canvas_id=%s tags_count=%d rows=%d", + canvas_id, + len(normalized), + rows_affected, + ) + return rows_affected + @classmethod @DB.connection_context() def accessible(cls, canvas_id, tenant_id): @@ -245,7 +319,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs): release_mode = str(kwargs.get("release", "")).strip().lower() if session_id: - e, conv = API4ConversationService.get_by_id(session_id) + e, conv = await thread_pool_exec(API4ConversationService.get_by_id, session_id) if not e: raise LookupError("Session not found!") if not conv.message: @@ -254,15 +328,15 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs): conv.dsl = json.dumps(conv.dsl, ensure_ascii=False) canvas = Canvas(conv.dsl, tenant_id, agent_id, canvas_id=agent_id, custom_header=custom_header) else: - cvs, dsl = UserCanvasService.get_agent_dsl_with_release(agent_id, release_mode=release_mode == "true", tenant_id=tenant_id) + cvs, dsl = await thread_pool_exec(UserCanvasService.get_agent_dsl_with_release, agent_id, release_mode=release_mode == "true", tenant_id=tenant_id) session_id = get_uuid() canvas = Canvas(dsl, tenant_id, agent_id, canvas_id=cvs.id, custom_header=custom_header) canvas.reset() # Get the version title based on release_mode - version_title = UserCanvasVersionService.get_latest_version_title(cvs.id, release_mode=release_mode == "true") + version_title = await thread_pool_exec(UserCanvasVersionService.get_latest_version_title, cvs.id, release_mode=release_mode == "true") conv = {"id": session_id, "dialog_id": cvs.id, "user_id": user_id, "message": [], "source": "agent", "dsl": dsl, "reference": [], "version_title": version_title} - API4ConversationService.save(**conv) + await thread_pool_exec(API4ConversationService.save, **conv) conv = API4Conversation(**conv) message_id = str(uuid4()) @@ -288,7 +362,7 @@ async def completion(tenant_id, agent_id, session_id=None, **kwargs): conv.errors = canvas.error conv.dsl = str(canvas) conv = conv.to_dict() - API4ConversationService.append_message(conv["id"], conv) + await thread_pool_exec(API4ConversationService.append_message, conv["id"], conv) async def completion_openai(tenant_id, agent_id, question, session_id=None, stream=True, **kwargs): diff --git a/api/db/services/connector_service.py b/api/db/services/connector_service.py index 9f7b0e6ded1..9fa868c6038 100644 --- a/api/db/services/connector_service.py +++ b/api/db/services/connector_service.py @@ -16,45 +16,114 @@ import logging from datetime import datetime import os -from typing import Tuple, List +from typing import Optional, Tuple, List from anthropic import BaseModel from peewee import SQL, fn from api.db import InputType -from api.db.db_models import Connector, SyncLogs, Connector2Kb, Knowledgebase +from api.db.db_models import DB, Connector, SyncLogs, Connector2Kb, Knowledgebase from api.db.services.common_service import CommonService from api.db.services.document_service import DocumentService from api.db.services.document_service import DocMetadataService from api.utils.common import hash128 from common.misc_utils import get_uuid -from common.constants import TaskStatus +from common.constants import ConnectorTaskType, TaskStatus from common.settings import TIMEZONE from common.time_utils import current_timestamp, timestamp_to_date +LOGGER = logging.getLogger(__name__) + + class ConnectorService(CommonService): model = Connector @classmethod - def resume(cls, connector_id, status): + def cancel_tasks(cls, connector_id): + e, conn = cls.get_by_id(connector_id) + if not e: + return + + logging.info( + "[Connector] stop connector=%s(%s)", + conn.name, + connector_id, + ) for c2k in Connector2KbService.query(connector_id=connector_id): - task = SyncLogsService.get_latest_task(connector_id, c2k.kb_id) - if not task: - if status == TaskStatus.SCHEDULE: - SyncLogsService.schedule(connector_id, c2k.kb_id) - ConnectorService.update_by_id(connector_id, {"status": status}) - return - - if task.status == TaskStatus.DONE: - if status == TaskStatus.SCHEDULE: - SyncLogsService.schedule(connector_id, c2k.kb_id, task.poll_range_end, total_docs_indexed=task.total_docs_indexed) - ConnectorService.update_by_id(connector_id, {"status": status}) - return - - task = task.to_dict() - task["status"] = status - SyncLogsService.update_by_id(task["id"], task) - ConnectorService.update_by_id(connector_id, {"status": status}) + SyncLogsService.filter_update( + [ + SyncLogs.connector_id == connector_id, + SyncLogs.kb_id == c2k.kb_id, + SyncLogs.status.in_([TaskStatus.SCHEDULE, TaskStatus.RUNNING]), + ], + {"status": TaskStatus.CANCEL}, + ) + ConnectorService.update_by_id(connector_id, {"status": TaskStatus.CANCEL}) + logging.info( + "[Connector] connector=%s status updated to %s", + connector_id, + TaskStatus.CANCEL, + ) + + @classmethod + @DB.connection_context() + def accessible(cls, connector_id: str, user_id: str) -> bool: + """Return whether the user can access the connector's tenant.""" + e, connector = cls.get_by_id(connector_id) + if not e: + LOGGER.warning("connector access denied: connector not found connector_id=%s user_id=%s", connector_id, user_id) + return False + + if connector.tenant_id == user_id: + return True + + from api.db.services.user_service import TenantService + + joined_tenants = TenantService.get_joined_tenants_by_user_id(user_id) + has_access = any(tenant["tenant_id"] == connector.tenant_id for tenant in joined_tenants) + if not has_access: + LOGGER.warning( + "connector access denied: tenant mismatch connector_id=%s user_id=%s tenant_id=%s", + connector_id, + user_id, + connector.tenant_id, + ) + return has_access + + @classmethod + def schedule_tasks(cls, connector_id): + e, conn = cls.get_by_id(connector_id) + if not e: + return + + logging.info("[Connector] schedule connector=%s(%s)", conn.name, connector_id) + prune_enabled = bool((conn.config or {}).get("sync_deleted_files")) + for c2k in Connector2KbService.query(connector_id=connector_id): + sync_task = SyncLogsService.get_latest_task( + connector_id, + c2k.kb_id, + ConnectorTaskType.SYNC, + ) + poll_range_start = None + total_docs_indexed = 0 + if sync_task and sync_task.status == TaskStatus.DONE: + poll_range_start = sync_task.poll_range_end + total_docs_indexed = sync_task.total_docs_indexed + + SyncLogsService.schedule( + connector_id, + c2k.kb_id, + poll_range_start, + total_docs_indexed=total_docs_indexed, + task_type=ConnectorTaskType.SYNC, + ) + + if prune_enabled: + SyncLogsService.schedule( + connector_id, + c2k.kb_id, + task_type=ConnectorTaskType.PRUNE, + ) @classmethod def list(cls, tenant_id): @@ -77,7 +146,9 @@ def rebuild(cls, kb_id:str, connector_id: str, tenant_id:str): SyncLogsService.filter_delete([SyncLogs.connector_id==connector_id, SyncLogs.kb_id==kb_id]) docs = DocumentService.query(source_type=f"{conn.source}/{conn.id}", kb_id=kb_id) err = FileService.delete_docs([d.id for d in docs], tenant_id) - SyncLogsService.schedule(connector_id, kb_id, reindex=True) + SyncLogsService.schedule(connector_id, kb_id, reindex=True, task_type=ConnectorTaskType.SYNC) + if (conn.config or {}).get("sync_deleted_files"): + SyncLogsService.schedule(connector_id, kb_id, task_type=ConnectorTaskType.PRUNE) return err @classmethod @@ -142,30 +213,25 @@ def cleanup_stale_documents_for_task( class SyncLogsService(CommonService): model = SyncLogs + @classmethod def list_sync_tasks(cls, connector_id=None, page_number=None, items_per_page=15) -> Tuple[List[dict], int]: fields = [ cls.model.id, cls.model.connector_id, + cls.model.task_type, cls.model.kb_id, cls.model.update_date, - cls.model.poll_range_start, - cls.model.poll_range_end, cls.model.new_docs_indexed, cls.model.total_docs_indexed, + cls.model.docs_removed_from_index, cls.model.error_msg, - cls.model.full_exception_trace, cls.model.error_count, - Connector.name, - Connector.source, - Connector.tenant_id, - Connector.timeout_secs, + cls.model.time_started.alias("time_started"), + Connector.refresh_freq.alias("refresh_freq"), + Connector.prune_freq.alias("prune_freq"), Knowledgebase.name.alias("kb_name"), - Knowledgebase.avatar.alias("kb_avatar"), - Connector2Kb.auto_parse, - cls.model.from_beginning.alias("reindex"), cls.model.status, - cls.model.update_time ] if not connector_id: fields.append(Connector.config) @@ -197,6 +263,80 @@ def list_sync_tasks(cls, connector_id=None, page_number=None, items_per_page=15) return list(query.dicts()), total + @classmethod + def list_due_sync_tasks(cls) -> List[dict]: + return cls._list_due_tasks_for_freq( + ConnectorTaskType.SYNC, + "refresh_freq", + ) + + @classmethod + def list_due_prune_tasks(cls) -> List[dict]: + tasks = cls._list_due_tasks_for_freq( + ConnectorTaskType.PRUNE, + "prune_freq", + ) + return [ + task for task in tasks + # Prune is opt-in at the connector config level; keep the scheduler + # blind to prune_freq until the flag is enabled. + if bool((task.get("config") or {}).get("sync_deleted_files")) + and int(task.get("prune_freq") or 0) > 0 + ] + + @classmethod + def _list_due_tasks_for_freq(cls, task_type: str, freq_field: str) -> List[dict]: + fields = [ + cls.model.id, + cls.model.connector_id, + cls.model.task_type, + cls.model.kb_id, + cls.model.update_date, + cls.model.poll_range_start, + cls.model.poll_range_end, + cls.model.new_docs_indexed, + cls.model.total_docs_indexed, + cls.model.error_msg, + cls.model.full_exception_trace, + cls.model.error_count, + Connector.name, + Connector.source, + Connector.tenant_id, + Connector.timeout_secs, + Connector.config, + Connector.refresh_freq, + Connector.prune_freq, + Knowledgebase.name.alias("kb_name"), + Knowledgebase.avatar.alias("kb_avatar"), + Connector2Kb.auto_parse, + cls.model.from_beginning.alias("reindex"), + cls.model.status, + cls.model.update_time, + ] + + query = cls.model.select(*fields)\ + .join(Connector, on=(cls.model.connector_id==Connector.id))\ + .join(Connector2Kb, on=(cls.model.kb_id==Connector2Kb.kb_id))\ + .join(Knowledgebase, on=(cls.model.kb_id==Knowledgebase.id)) + + query = query.where( + Connector.input_type == InputType.POLL, + Connector.status == TaskStatus.SCHEDULE, + cls.model.status == TaskStatus.SCHEDULE, + cls.model.task_type == task_type, + ) + + database_type = os.getenv("DB_TYPE", "mysql") + if "postgres" in database_type.lower(): + expr = SQL( + f"NOW() AT TIME ZONE '{TIMEZONE}' - make_interval(mins => t2.{freq_field})" + ) + else: + expr = SQL(f"NOW() - INTERVAL `t2`.`{freq_field}` MINUTE") + query = query.where(cls.model.update_date < expr) + + return list(query.distinct().order_by(cls.model.update_time.desc()).dicts()) + @classmethod def start(cls, id, connector_id): cls.update_by_id(id, {"status": TaskStatus.RUNNING, "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') }) @@ -208,7 +348,15 @@ def done(cls, id, connector_id): ConnectorService.update_by_id(connector_id, {"status": TaskStatus.DONE}) @classmethod - def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, total_docs_indexed=0): + def schedule( + cls, + connector_id, + kb_id, + poll_range_start=None, + reindex=False, + total_docs_indexed=0, + task_type=ConnectorTaskType.SYNC, + ): try: if cls.model.select().where(cls.model.kb_id == kb_id, cls.model.connector_id == connector_id).count() > 100: rm_ids = [m.id for m in cls.model.select(cls.model.id).where(cls.model.kb_id == kb_id, cls.model.connector_id == connector_id).order_by(cls.model.update_time.asc()).limit(70)] @@ -218,21 +366,33 @@ def schedule(cls, connector_id, kb_id, poll_range_start=None, reindex=False, tot logging.exception(e) try: - e = cls.query(kb_id=kb_id, connector_id=connector_id, status=TaskStatus.SCHEDULE) + e = cls.query( + kb_id=kb_id, + connector_id=connector_id, + status=TaskStatus.SCHEDULE, + task_type=task_type, + ) if e: - logging.warning(f"{kb_id}--{connector_id} has already had a scheduling sync task which is abnormal.") + logging.warning( + "%s--%s already has a scheduled %s task.", + kb_id, + connector_id, + task_type, + ) return None reindex = "1" if reindex else "0" ConnectorService.update_by_id(connector_id, {"status": TaskStatus.SCHEDULE}) return cls.save(**{ "id": get_uuid(), "kb_id": kb_id, "status": TaskStatus.SCHEDULE, "connector_id": connector_id, + "task_type": task_type, "poll_range_start": poll_range_start, "from_beginning": reindex, - "total_docs_indexed": total_docs_indexed + "total_docs_indexed": total_docs_indexed, + "time_started": datetime.now().strftime('%Y-%m-%d %H:%M:%S') }) except Exception as e: logging.exception(e) - task = cls.get_latest_task(connector_id, kb_id) + task = cls.get_latest_task(connector_id, kb_id, task_type) if task: cls.model.update(status=TaskStatus.SCHEDULE, poll_range_start=poll_range_start, @@ -276,12 +436,13 @@ class FileObj(BaseModel): id: str filename: str blob: bytes + fingerprint: Optional[str] = None def read(self) -> bytes: return self.blob errs = [] - files = [FileObj(id=d["id"], filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"]) for d in docs] + files = [FileObj(id=d["id"], filename=d["semantic_identifier"]+(f"{d['extension']}" if d["semantic_identifier"][::-1].find(d['extension'][::-1])<0 else ""), blob=d["blob"], fingerprint=d.get("fingerprint")) for d in docs] doc_ids = [] err, doc_blob_pairs = FileService.upload_document(kb, files, tenant_id, src) errs.extend(err) @@ -308,11 +469,14 @@ def read(self) -> bytes: return errs, doc_ids @classmethod - def get_latest_task(cls, connector_id, kb_id): - return cls.model.select().where( + def get_latest_task(cls, connector_id, kb_id, task_type=None): + query = cls.model.select().where( cls.model.connector_id==connector_id, cls.model.kb_id == kb_id - ).order_by(cls.model.update_time.desc()).first() + ) + if task_type is not None: + query = query.where(cls.model.task_type == task_type) + return query.order_by(cls.model.update_time.desc()).first() class Connector2KbService(CommonService): @@ -335,7 +499,10 @@ def link_connectors(cls, kb_id:str, connectors: list[dict], tenant_id:str): "kb_id": kb_id, "auto_parse": conn.get("auto_parse", "1") }) - SyncLogsService.schedule(conn_id, kb_id, reindex=True) + SyncLogsService.schedule(conn_id, kb_id, reindex=True, task_type=ConnectorTaskType.SYNC) + e, full_conn = ConnectorService.get_by_id(conn_id) + if e and (full_conn.config or {}).get("sync_deleted_files"): + SyncLogsService.schedule(conn_id, kb_id, task_type=ConnectorTaskType.PRUNE) errs = [] for conn_id in old_conn_ids: @@ -369,4 +536,3 @@ def list_connectors(cls, kb_id): cls.model.kb_id==kb_id ).dicts() ) - diff --git a/api/db/services/dialog_service.py b/api/db/services/dialog_service.py index 6f981efb5e6..66a6060d2a8 100644 --- a/api/db/services/dialog_service.py +++ b/api/db/services/dialog_service.py @@ -14,7 +14,6 @@ # limitations under the License. # import asyncio -import binascii import logging import re import time @@ -51,6 +50,7 @@ from rag.prompts.generator import chunks_format, citation_prompt, cross_languages, full_question, kb_prompt, keyword_extraction, message_fit_in, PROMPT_JINJA_ENV, ASK_SUMMARY from common.token_utils import num_tokens_from_string from rag.utils.tavily_conn import Tavily +from rag.utils.tts_cache import synthesize_with_cache from common.string_utils import remove_redundant_spaces from common import settings @@ -65,6 +65,53 @@ def _chunk_kb_id_for_doc(row_dict, kb_ids, doc_id): return kb_ids[0] return row_dict.get("kb_id") or row_dict.get("kb_id_kwd") + +async def _hydrate_chunk_vectors(retriever, chunks, tenant_ids, kb_ids): + """ + Citation prep: on the ES backend the main retrieval call deliberately + skips fetching the chunk embedding. insert_citations needs it, so we + pull the vectors for just the candidate chunks right before computing + answer-vs-chunk similarity. Chunks without an ES chunk_id (e.g. web + search results) keep whatever placeholder they were given. Other + backends still carry vectors in the chunk, so we skip the round-trip. + """ + if settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE: + return + if not chunks: + return + dim = 0 + for ck in chunks: + v = ck.get("vector") + if isinstance(v, list) and v: + dim = len(v) + break + if not dim: + return + # Skip chunks that already have a non-zero vector (e.g. parent chunks + # produced by retrieval_by_children copy the child vector inline). + chunk_ids = [] + for ck in chunks: + cid = ck.get("chunk_id") + if not cid: + continue + v = ck.get("vector") or [] + if any(x for x in v): + continue + chunk_ids.append(cid) + if not chunk_ids: + return + try: + vectors = await retriever.fetch_chunk_vectors(chunk_ids, tenant_ids, kb_ids, dim) + except Exception as e: # noqa: BLE001 - degrade gracefully on hydrate failure + logger.warning("fetch_chunk_vectors failed; citations will use placeholders: %s", e) + return + if not vectors: + return + for ck in chunks: + cid = ck.get("chunk_id") + if cid and cid in vectors: + ck["vector"] = vectors[cid] + def _normalize_internet_flag(value): if isinstance(value, bool): return value @@ -525,6 +572,7 @@ async def async_chat(dialog, messages, stream=True, **kwargs): check_llm_ts = timer() langfuse_tracer = None + langfuse_generation = None trace_context = {} langfuse_keys = TenantLangfuseService.filter_by_tenant(tenant_id=dialog.tenant_id) if langfuse_keys: @@ -735,8 +783,8 @@ async def callback(msg: str): if "max_tokens" in gen_conf: gen_conf["max_tokens"] = min(gen_conf["max_tokens"], max_tokens - used_token_count) - def decorate_answer(answer): - nonlocal embd_mdl, prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_tracer + async def decorate_answer(answer): + nonlocal embd_mdl, prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions, langfuse_generation refs = [] ans = answer.split("") @@ -749,6 +797,9 @@ def decorate_answer(answer): idx = set([]) normalized_answer = normalize_arabic_digits(answer) or "" if embd_mdl and not CITATION_MARKER_PATTERN.search(normalized_answer): + # Main retrieval no longer ships chunk vectors back from ES. + # Pull them on demand for the chunks we are about to cite. + await _hydrate_chunk_vectors(retriever, kbinfos.get("chunks", []), tenant_ids, dialog.kb_ids) answer, idx = retriever.insert_citations( answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], @@ -805,19 +856,35 @@ def decorate_answer(answer): f" - Token speed: {int(tk_num / (generate_result_time_cost / 1000.0))}/s" ) - # Add a condition check to call the end method only if langfuse_tracer exists - if langfuse_tracer and "langfuse_generation" in locals(): + # Add a condition check to call the end method only if langfuse_generation exists + if langfuse_generation is not None: langfuse_output = "\n" + re.sub(r"^.*?(### Query:.*)", r"\1", prompt, flags=re.DOTALL) langfuse_output = {"time_elapsed:": re.sub(r"\n", " \n", langfuse_output), "created_at": time.time()} - langfuse_generation.update(output=langfuse_output) + langfuse_generation.update( + output=langfuse_output, + usage_details={ + "input": used_token_count, + "output": tk_num, + "total": used_token_count + tk_num, + }, + ) langfuse_generation.end() return {"answer": think + answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()} if langfuse_tracer: - langfuse_generation = langfuse_tracer.start_generation( - trace_context=trace_context, name="chat", model=llm_model_config["llm_name"], input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg} - ) + try: + langfuse_generation = langfuse_tracer.start_observation( + as_type="generation", + trace_context=trace_context, + name="chat", + model=llm_model_config["llm_name"], + input={"prompt": prompt, "prompt4citation": prompt4citation, "messages": msg}, + ) + except Exception as e: # noqa: BLE001 - tracing must not break chat flow + logger.warning("Langfuse start_observation failed; continuing without tracing: %s", e) + langfuse_tracer = None + langfuse_generation = None if stream: if llm_type == "chat": @@ -834,7 +901,7 @@ def decorate_answer(answer): yield {"answer": value, "reference": {}, "audio_binary": tts(tts_mdl, value), "final": False} full_answer = last_state.full_text if last_state else "" if full_answer: - final = decorate_answer(_extract_visible_answer(thought + full_answer)) + final = await decorate_answer(_extract_visible_answer(thought + full_answer)) final["final"] = True final["audio_binary"] = None yield final @@ -845,7 +912,7 @@ def decorate_answer(answer): answer = await chat_mdl.async_chat(prompt + prompt4citation, msg[1:], gen_conf, images=image_files) user_content = msg[-1].get("content", "[content not available]") logging.debug("User: {}|Assistant: {}".format(user_content, answer)) - res = decorate_answer(answer) + res = await decorate_answer(answer) res["audio_binary"] = tts(tts_mdl, answer) yield res @@ -1370,14 +1437,7 @@ def tts(tts_mdl, text): text = clean_tts_text(text) if not text: return None - bin = b"" - try: - for chunk in tts_mdl.tts(text): - bin += chunk - except Exception as e: - logging.error(f"TTS failed: {e}, text={text!r}") - return None - return binascii.hexlify(bin).decode("utf-8") + return synthesize_with_cache(tts_mdl, text) class _ThinkStreamState: @@ -1535,8 +1595,11 @@ async def async_ask(question, kb_ids, tenant_id, chat_llm_name=None, search_conf msg = [{"role": "user", "content": question}] - def decorate_answer(answer): + async def decorate_answer(answer): nonlocal knowledges, kbinfos, sys_prompt + # Main retrieval no longer ships chunk vectors back from ES. Pull + # them on demand for the chunks we are about to cite. + await _hydrate_chunk_vectors(retriever, kbinfos.get("chunks", []), tenant_ids, kb_ids) answer, idx = retriever.insert_citations(answer, [ck["content_ltks"] for ck in kbinfos["chunks"]], [ck["vector"] for ck in kbinfos["chunks"]], embd_mdl, tkweight=0.7, vtweight=0.3) idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx]) recall_docs = [d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx] @@ -1563,7 +1626,7 @@ def decorate_answer(answer): continue yield {"answer": value, "reference": {}, "final": False} full_answer = last_state.full_text if last_state else "" - final = decorate_answer(_extract_visible_answer(full_answer)) + final = await decorate_answer(_extract_visible_answer(full_answer)) final["final"] = True yield final diff --git a/api/db/services/doc_metadata_service.py b/api/db/services/doc_metadata_service.py index 1cf887c2d3f..fbe32f9e5b7 100644 --- a/api/db/services/doc_metadata_service.py +++ b/api/db/services/doc_metadata_service.py @@ -385,14 +385,26 @@ def insert_document_metadata(cls, doc_id: str, meta_fields: Dict) -> bool: if result: logging.error(f"Failed to insert metadata for document {doc_id}: {result}") return False - # Force ES refresh to make metadata immediately available for search + # Force refresh so metadata is immediately searchable. + # Both Elasticsearch and OpenSearch backends expose refresh_idx; + # Infinity does not need a manual refresh. if not settings.DOC_ENGINE_INFINITY: - try: - settings.docStoreConn.es.indices.refresh(index=index_name) - logging.debug(f"Refreshed metadata index: {index_name}") - except Exception as e: - logging.warning(f"Failed to refresh metadata index {index_name}: {e}") - + refresh_idx = getattr(settings.docStoreConn, "refresh_idx", None) + if callable(refresh_idx): + if refresh_idx(index_name): + logging.debug(f"Refreshed metadata index: {index_name}") + else: + # A failed refresh can leave just-inserted metadata + # invisible to subsequent reads; surface it so operators + # can correlate stale-read complaints with the cause. + logging.warning( + f"Failed to refresh metadata index {index_name} on backend " + f"{type(settings.docStoreConn).__name__}; " + f"metadata may not be immediately searchable" + ) + else: + logging.debug(f"Backend {type(settings.docStoreConn).__name__} has no refresh_idx; skipping") + logging.debug(f"Successfully inserted metadata for document {doc_id}") return True @@ -436,7 +448,8 @@ def update_document_metadata(cls, doc_id: str, meta_fields: Dict) -> bool: # Post-process to split combined values processed_meta = cls._split_combined_values(meta_fields) - logging.debug(f"[update_document_metadata] Updating doc_id: {doc_id}, kb_id: {kb_id}, meta_fields: {processed_meta}") + logging.debug( + f"[update_document_metadata] Updating doc_id: {doc_id}, kb_id: {kb_id}, meta_fields: {processed_meta}") # For Elasticsearch, use efficient partial update if not settings.DOC_ENGINE_INFINITY and not settings.DOC_ENGINE_OCEANBASE: @@ -444,7 +457,8 @@ def update_document_metadata(cls, doc_id: str, meta_fields: Dict) -> bool: index_exists = settings.docStoreConn.index_exist(index_name, "") if not index_exists: # Index doesn't exist - create it and insert directly - logging.debug(f"[update_document_metadata] Index {index_name} does not exist, creating and inserting") + logging.debug( + f"[update_document_metadata] Index {index_name} does not exist, creating and inserting") result = settings.docStoreConn.create_doc_meta_idx(index_name) if result is False: logging.error(f"Failed to create metadata index {index_name}") @@ -459,23 +473,24 @@ def update_document_metadata(cls, doc_id: str, meta_fields: Dict) -> bool: [kb_id] ) if doc_exists: - # Document exists - replace meta_fields entirely - # Use upsert to fully replace the meta_fields field - # (ES update with doc parameter does deep merge on object fields, - # which would retain old keys that should be removed) - settings.docStoreConn.es.update( - index=index_name, - id=doc_id, - refresh=True, - body={ - "script": { - "source": "ctx._source.meta_fields = params.meta_fields", - "params": {"meta_fields": processed_meta} - } - } + # Document exists - replace meta_fields entirely. + # Using update with a `doc` body would deep-merge the meta_fields + # object and retain old keys that should be removed, so we delegate + # to a backend-provided scripted assignment that fully overwrites it. + replace_meta_fields = getattr(settings.docStoreConn, "replace_meta_fields", None) + if callable(replace_meta_fields) and replace_meta_fields(index_name, doc_id, processed_meta): + logging.debug( + f"Successfully updated metadata for document {doc_id} via {type(settings.docStoreConn).__name__}.replace_meta_fields") + return True + logging.warning( + f"replace_meta_fields unavailable or failed on backend " + f"{type(settings.docStoreConn).__name__}; falling back to delete+insert" ) - logging.debug(f"Successfully updated metadata for document {doc_id} using ES script update") - return True + # Mirror the Infinity fallback below so a failed scripted + # replace still guarantees full overwrite semantics rather + # than leaking through the "document not found" branch. + cls.delete_document_metadata(doc_id, kb_id, tenant_id) + return cls.insert_document_metadata(doc_id, processed_meta) except Exception as e: logging.debug(f"Document {doc_id} not found in index, will insert: {e}") @@ -525,7 +540,8 @@ def delete_document_metadata(cls, doc_id: str, kb_id: str, tenant_id: str = None # Check if metadata table exists before attempting deletion # This is the key optimization - no table = no metadata = nothing to delete if not settings.docStoreConn.index_exist(index_name, ""): - logging.debug(f"Metadata table {index_name} does not exist, skipping metadata deletion for document {doc_id}") + logging.debug( + f"Metadata table {index_name} does not exist, skipping metadata deletion for document {doc_id}") return True # No metadata to delete is considered success # Try to get the metadata to confirm it exists before deleting @@ -582,13 +598,18 @@ def _drop_empty_metadata_table(cls, index_name: str, tenant_id: str) -> None: logging.debug(f"[DROP EMPTY TABLE] Table {index_name} exists, checking if empty...") - # Use ES count API for accurate count - # Note: No need to refresh since delete operation already uses refresh=True + # Use the backend-native count primitive when available (ES + OS). + # No need to refresh since delete operation already uses refresh=True. + # The invocation lives inside the try/except so a future backend + # whose count_idx raises (instead of returning the -1 sentinel) + # still falls through to the search-based empty-table check. + count_idx = getattr(settings.docStoreConn, "count_idx", None) try: - count_response = settings.docStoreConn.es.count(index=index_name) - total_count = count_response['count'] - logging.debug(f"[DROP EMPTY TABLE] ES count API result: {total_count} documents") - is_empty = (total_count == 0) + count_value = count_idx(index_name) if callable(count_idx) else -1 + if count_value < 0: + raise RuntimeError("native count_idx unavailable or failed") + logging.debug(f"[DROP EMPTY TABLE] count_idx API result: {count_value} documents") + is_empty = (count_value == 0) except Exception as e: logging.warning(f"[DROP EMPTY TABLE] Count API failed, falling back to search: {e}") # Fallback to search if count fails @@ -610,7 +631,8 @@ def _drop_empty_metadata_table(cls, index_name: str, tenant_id: str) -> None: if isinstance(results, tuple) and len(results) == 2: # Infinity returns (DataFrame, int) df, total = results - logging.debug(f"[DROP EMPTY TABLE] Infinity format - total: {total}, df length: {len(df) if hasattr(df, '__len__') else 'N/A'}") + logging.debug( + f"[DROP EMPTY TABLE] Infinity format - total: {total}, df length: {len(df) if hasattr(df, '__len__') else 'N/A'}") is_empty = (total == 0 or (hasattr(df, '__len__') and len(df) == 0)) elif hasattr(results, 'get') and 'hits' in results: # ES format - MUST check this before hasattr(results, '__len__') @@ -774,52 +796,33 @@ def get_flatted_meta_by_kbs(cls, kb_ids: List[str]) -> Dict: @classmethod def filter_doc_ids_by_meta_pushdown( - cls, - kb_ids: List[str], - filters: List[Dict], - logic: str = "and", - limit: int = 10000, + cls, + kb_ids: List[str], + filters: List[Dict], + logic: str = "and", + limit: int = 10000, ) -> Optional[List[str]]: - """Run a metadata filter directly against ES, returning matching doc IDs. + """Run a metadata filter directly against ES or Infinity, returning matching doc IDs. Returns ``None`` to signal "push-down not viable, use the in-memory ``meta_filter`` fallback". Reasons for ``None``: - - Active doc store is not Elasticsearch (Infinity / OceanBase have - different filter semantics for the JSON ``meta_fields`` column). - - One of the user filters cannot be expressed in ES DSL. - - The ES request itself failed (network, mapping, missing index). + - kb_ids or filters is empty + - One of the user filters cannot be expressed in ES DSL or Infinity SQL + - The request itself failed (network, mapping, missing index) On success returns the deduplicated, ordered list of document IDs the - ES query matched. Callers can union or intersect this with their own + query matched. Callers can union or intersect this with their own base ``doc_ids`` rather than fetching the entire metadata table. """ - from common.metadata_es_filter import ( - UnsupportedMetaFilter, - build_meta_filter_query, - extract_doc_ids, - is_pushdown_supported, - ) - - if not kb_ids: - return [] - - if settings.DOC_ENGINE_INFINITY: - # Infinity stores ``meta_fields`` as a JSON column without dotted - # field access; the in-memory path is still the reliable answer. - return None - - es_client = getattr(settings.docStoreConn, "es", None) - if es_client is None: - return None - - if not is_pushdown_supported(filters): + if not kb_ids or not filters: + logging.debug("Metadata filter skipped: empty kb_ids or filters") return None try: kb = Knowledgebase.get_by_id(kb_ids[0]) except Exception as e: - logging.warning(f"[meta_pushdown] cannot resolve tenant for kb {kb_ids[0]}: {e}") + logging.warning(f"Metadata filter cannot resolve tenant for kb {kb_ids[0]}: {e}") return None if not kb: return None @@ -827,24 +830,48 @@ def filter_doc_ids_by_meta_pushdown( tenant_id = kb.tenant_id index_name = cls._get_doc_meta_index_name(tenant_id) - try: - if not settings.docStoreConn.index_exist(index_name, ""): - # No metadata index → no metadata-filtered docs. Returning an - # empty list (rather than ``None``) so callers don't bounce - # back to the in-memory path and re-query MySQL for nothing. - return [] - except Exception as e: - logging.warning(f"[meta_pushdown] index_exist check failed for {index_name}: {e}") + if not settings.docStoreConn.index_exist(index_name, ""): + return [] + + if settings.DOC_ENGINE_INFINITY: + return cls._filter_doc_ids_by_metadata_infinity( + index_name, kb_ids, filters, logic + ) + else: + return cls._filter_doc_ids_by_metadata_es( + index_name, kb_ids, filters, logic, limit + ) + + @classmethod + def _filter_doc_ids_by_metadata_es( + cls, + index_name: str, + kb_ids: List[str], + filters: List[Dict], + logic: str, + limit: int, + ) -> Optional[List[str]]: + """ES push-down path for metadata filtering.""" + from common.metadata_es_filter import ( + UnsupportedMetaFilter, + build_meta_filter_query, + extract_doc_ids, + is_pushdown_supported, + ) + + es_client = getattr(settings.docStoreConn, "es", None) + if es_client is None: + return None + + if not is_pushdown_supported(filters): return None try: query_body = build_meta_filter_query(filters, logic, kb_ids) except UnsupportedMetaFilter as e: - logging.debug(f"[meta_pushdown] falling back to in-memory: {e.reason}") + logging.error(f"ES build query failed: {e.reason}, filters={filters}") return None - # Only the doc id is needed downstream; trimming ``_source`` keeps the - # response small when the metadata blob is large. request_body = { **query_body, "size": limit, @@ -854,12 +881,10 @@ def filter_doc_ids_by_meta_pushdown( try: response = es_client.search(index=index_name, body=request_body) except Exception as e: - logging.warning(f"[meta_pushdown] ES query failed for {index_name}: {e}") + logging.error(f"ES metadata filter failed for {index_name}: {e}") return None doc_ids = extract_doc_ids(response if isinstance(response, dict) else dict(response)) - # Preserve order while removing duplicates so caller-side de-dupe stays - # cheap. seen: set[str] = set() unique: List[str] = [] for did in doc_ids: @@ -870,12 +895,52 @@ def filter_doc_ids_by_meta_pushdown( if len(unique) >= limit: logging.warning( - f"[meta_pushdown] hit limit {limit} for KBs {kb_ids}; some matches may be missing" + f"ES metadata filter hit limit {limit} for KBs {kb_ids}" ) - logging.debug(f"[meta_pushdown] {len(unique)} matches for KBs {kb_ids}") + logging.debug(f"ES metadata filter returned {len(unique)} matches for KBs {kb_ids}") return unique + @classmethod + def _filter_doc_ids_by_metadata_infinity( + cls, + index_name: str, + kb_ids: List[str], + filters: List[Dict], + logic: str, + ) -> Optional[List[str]]: + """Infinity push-down path for metadata filtering.""" + from common.metadata_infinity_filter import ( + build_infinity_filter, + extract_doc_ids, + is_pushdown_supported, + ) + + if not is_pushdown_supported(filters): + return None + + try: + sql_filter = build_infinity_filter(filters, logic) + escaped_kb_ids = [k.replace("'", "''") for k in kb_ids] + kb_filter = "kb_id IN (" + ", ".join([f"'{k}'" for k in escaped_kb_ids]) + ")" + where_clause = f"{kb_filter} AND {sql_filter}" + logging.debug(f"Infinity metadata filter: {where_clause}") + + inf_conn = settings.docStoreConn.connPool.get_conn() + try: + db_instance = inf_conn.get_database(settings.docStoreConn.dbName) + table_instance = db_instance.get_table(index_name) + df, _ = table_instance.output(["id"]).filter(where_clause).to_df() + doc_ids = extract_doc_ids(df) + logging.debug( + f"Infinity metadata filter returned {len(doc_ids)} doc IDs for kb_ids={kb_ids}, logic={logic}") + return doc_ids + finally: + settings.docStoreConn.connPool.release_conn(inf_conn) + except Exception: + logging.warning("Metadata filter push-down failed; falling back to in-memory filter", exc_info=True) + return None + @classmethod def get_metadata_keys_by_kbs(cls, kb_ids: List[str]) -> List[str]: """ @@ -938,7 +1003,8 @@ def get_metadata_for_documents(cls, doc_ids: Optional[List[str]], kb_id: str) -> if doc_meta: meta_mapping[doc_id] = doc_meta - logging.debug(f"[get_metadata_for_documents] Found metadata for {len(meta_mapping)}/{len(doc_ids) if doc_ids else 'all'} documents") + logging.debug( + f"[get_metadata_for_documents] Found metadata for {len(meta_mapping)}/{len(doc_ids) if doc_ids else 'all'} documents") return meta_mapping except Exception as e: @@ -964,6 +1030,7 @@ def get_metadata_summary(cls, kb_id: str, doc_ids=None) -> Dict: } } """ + def _is_time_string(value: str) -> bool: """Check if a string value is an ISO 8601 datetime (e.g., '2026-02-03T00:00:00').""" if not isinstance(value, str): @@ -1203,7 +1270,8 @@ def _apply_deletes(meta): doc_ids_set = set(doc_ids) missing_doc_ids = doc_ids_set - found_doc_ids if missing_doc_ids and updates: - logging.debug(f"[batch_update_metadata] Inserting new metadata for documents without metadata rows: {missing_doc_ids}") + logging.debug( + f"[batch_update_metadata] Inserting new metadata for documents without metadata rows: {missing_doc_ids}") for doc_id in missing_doc_ids: # Apply updates to create new metadata meta = {} diff --git a/api/db/services/document_service.py b/api/db/services/document_service.py index 7992cdb6105..348d8a3a604 100644 --- a/api/db/services/document_service.py +++ b/api/db/services/document_service.py @@ -88,7 +88,7 @@ def get_list(cls, kb_id, page_number, items_per_page, orderby, desc, keywords, i docs = docs.where(cls.model.name == name) if keywords: docs = docs.where(fn.LOWER(cls.model.name).contains(keywords.lower())) - if doc_ids: + if doc_ids is not None: docs = docs.where(cls.model.id.in_(doc_ids)) if suffix: docs = docs.where(cls.model.suffix.in_(suffix)) @@ -143,7 +143,7 @@ def get_by_kb_id(cls, kb_id, page_number, items_per_page, orderby, desc, keyword .join(User, on=(cls.model.created_by == User.id), join_type=JOIN.LEFT_OUTER) .where(cls.model.kb_id == kb_id) ) - if doc_ids: + if doc_ids is not None: docs = docs.where(cls.model.id.in_(doc_ids)) if run_status: docs = docs.where(cls.model.run.in_(run_status)) @@ -388,6 +388,35 @@ def list_doc_headers_by_kb_and_source_type(cls, kb_id, source_type, page_size=50 offset += page_size return res + @classmethod + @DB.connection_context() + def list_id_content_hash_map_by_kb_and_source_type(cls, kb_id, source_type, page_size=500): + """Return {doc_id: content_hash} for the connector's existing docs. + + Used by the fingerprint-bypass path to decide which keys can skip a + re-fetch -- if the connector's listing fingerprint equals content_hash, + the body hasn't changed since the last sync. + + Ordered by create_time so LIMIT/OFFSET pagination is stable under + concurrent writes; without this, page boundaries can drop or duplicate + rows and the resulting map would silently miss entries. + """ + fields = [cls.model.id, cls.model.content_hash] + docs = cls.model.select(*fields).where( + cls.model.kb_id == kb_id, + cls.model.source_type == source_type, + ).order_by(cls.model.create_time.asc()) + offset = 0 + result: dict[str, str] = {} + while True: + batch = list(docs.offset(offset).limit(page_size).dicts()) + if not batch: + break + for row in batch: + result[row["id"]] = row.get("content_hash") or "" + offset += page_size + return result + @classmethod @DB.connection_context() def get_all_docs_by_creator_id(cls, creator_id): @@ -426,7 +455,7 @@ def remove_document(cls, doc, tenant_id): chunk_index_name = search.index_name(tenant_id) chunk_index_exists = settings.docStoreConn.index_exist(chunk_index_name, doc.kb_id) - # Cancel all running tasks first Using preset function in task_service.py --- set cancel flag in Redis + # Cancel all running tasks first using preset function in task_service.py --- set cancel flag in Redis try: cancel_all_task_of(doc.id) logging.info(f"Cancelled all tasks for document {doc.id}") @@ -562,27 +591,84 @@ def get_unfinished_docs(cls): @classmethod @DB.connection_context() def increment_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration): - num = ( - cls.model.update(token_num=cls.model.token_num + token_num, chunk_num=cls.model.chunk_num + chunk_num, process_duration=cls.model.process_duration + duration) - .where(cls.model.id == doc_id) - .execute() - ) - if num == 0: - logging.warning("Document not found which is supposed to be there") - num = Knowledgebase.update(token_num=Knowledgebase.token_num + token_num, chunk_num=Knowledgebase.chunk_num + chunk_num).where(Knowledgebase.id == kb_id).execute() + """Atomically add chunk/token counters on the document and its knowledge base.""" + with DB.atomic(): + num = ( + cls.model.update( + token_num=cls.model.token_num + token_num, + chunk_num=cls.model.chunk_num + chunk_num, + process_duration=cls.model.process_duration + duration, + ) + .where((cls.model.id == doc_id) & (cls.model.kb_id == kb_id)) + .execute() + ) + if num == 0: + logging.error( + "increment_chunk_num: no document matched doc_id=%s kb_id=%s " + "token_num=%s chunk_num=%s duration=%s", + doc_id, + kb_id, + token_num, + chunk_num, + duration, + ) + raise LookupError("Document not found which is supposed to be there") + num = ( + Knowledgebase.update( + token_num=Knowledgebase.token_num + token_num, + chunk_num=Knowledgebase.chunk_num + chunk_num, + ) + .where(Knowledgebase.id == kb_id) + .execute() + ) + if num == 0: + logging.error( + "increment_chunk_num: no knowledgebase matched kb_id=%s for doc_id=%s " + "token_num=%s chunk_num=%s duration=%s", + kb_id, + doc_id, + token_num, + chunk_num, + duration, + ) + raise LookupError("Knowledgebase not found which is supposed to be there") return num @classmethod @DB.connection_context() def decrement_chunk_num(cls, doc_id, kb_id, token_num, chunk_num, duration): - num = ( - cls.model.update(token_num=cls.model.token_num - token_num, chunk_num=cls.model.chunk_num - chunk_num, process_duration=cls.model.process_duration + duration) - .where(cls.model.id == doc_id) - .execute() - ) - if num == 0: - raise LookupError("Document not found which is supposed to be there") - num = Knowledgebase.update(token_num=Knowledgebase.token_num - token_num, chunk_num=Knowledgebase.chunk_num - chunk_num).where(Knowledgebase.id == kb_id).execute() + """Atomically subtract chunk/token counters on the document and its knowledge base.""" + with DB.atomic(): + num = ( + cls.model.update( + token_num=cls.model.token_num - token_num, + chunk_num=cls.model.chunk_num - chunk_num, + process_duration=cls.model.process_duration + duration, + ) + .where((cls.model.id == doc_id) & (cls.model.kb_id == kb_id)) + .execute() + ) + if num == 0: + raise LookupError("Document not found which is supposed to be there") + num = ( + Knowledgebase.update( + token_num=Knowledgebase.token_num - token_num, + chunk_num=Knowledgebase.chunk_num - chunk_num, + ) + .where(Knowledgebase.id == kb_id) + .execute() + ) + if num == 0: + logging.error( + "decrement_chunk_num: no knowledgebase matched kb_id=%s for doc_id=%s " + "token_num=%s chunk_num=%s duration=%s", + kb_id, + doc_id, + token_num, + chunk_num, + duration, + ) + raise LookupError("Knowledgebase not found which is supposed to be there") return num @classmethod @@ -623,7 +709,7 @@ def delete_document_and_update_kb_counts(cls, doc_id) -> bool: def clear_chunk_num(cls, doc_id): """Deprecated: use delete_document_and_update_kb_counts instead.""" doc = cls.model.get_by_id(doc_id) - assert doc, "Can't fine document in database." + assert doc, "Can't find document in database." num = ( Knowledgebase.update(token_num=Knowledgebase.token_num - doc.token_num, chunk_num=Knowledgebase.chunk_num - doc.chunk_num, doc_num=Knowledgebase.doc_num - 1) @@ -636,7 +722,7 @@ def clear_chunk_num(cls, doc_id): @DB.connection_context() def clear_chunk_num_when_rerun(cls, doc_id): doc = cls.model.get_by_id(doc_id) - assert doc, "Can't fine document in database." + assert doc, "Can't find document in database." num = ( Knowledgebase.update( @@ -978,11 +1064,13 @@ def run(cls, tenant_id: str, doc: dict, kb_table_num_map: dict): queue_tasks(doc, bucket, name, 0) -def queue_raptor_o_graphrag_tasks(sample_doc, ty, priority, fake_doc_id="", doc_ids=[]): +def queue_raptor_o_graphrag_tasks(sample_doc, ty, priority, fake_doc_id="", doc_ids=None): """ You can provide a fake_doc_id to bypass the restriction of tasks at the knowledgebase level. Optionally, specify a list of doc_ids to determine which documents participate in the task. """ + if doc_ids is None: + doc_ids = [] assert ty in ["graphrag", "raptor", "mindmap"], "type should be graphrag, raptor or mindmap" chunking_config = DocumentService.get_chunking_config(sample_doc["id"]) diff --git a/api/db/services/file_service.py b/api/db/services/file_service.py index db8ae4b72f5..7c5945d8afd 100644 --- a/api/db/services/file_service.py +++ b/api/db/services/file_service.py @@ -482,7 +482,12 @@ def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str err.append(file.filename + ": " + user_msg) continue blob = file.read() - new_hash = xxhash.xxh128(blob).hexdigest() + # Connector-supplied fingerprint (e.g. xxhash128(S3 ETag)) + # takes precedence: for connector-sourced docs the bypass + # path uses the fingerprint as content_hash, so reverting + # to xxhash128(blob) here would defeat it. + incoming_fp = getattr(file, "fingerprint", None) + new_hash = incoming_fp or xxhash.xxh128(blob).hexdigest() old_hash = doc.content_hash or "" settings.STORAGE_IMPL.put(kb.id, doc.location, blob, kb.tenant_id) doc.size = len(blob) @@ -518,6 +523,7 @@ def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str thumbnail_location = f"thumbnail_{doc_id}.png" settings.STORAGE_IMPL.put(kb.id, thumbnail_location, img) + incoming_fp = getattr(file, "fingerprint", None) doc = { "id": doc_id, "kb_id": kb.id, @@ -532,7 +538,7 @@ def upload_document(self, kb, file_objs, user_id, src="local", parent_path: str "location": location, "size": len(blob), "thumbnail": thumbnail_location, - "content_hash": xxhash.xxh128(blob).hexdigest(), + "content_hash": incoming_fp or xxhash.xxh128(blob).hexdigest(), } DocumentService.insert(doc) @@ -555,14 +561,14 @@ def list_all_files_by_parent_id(cls, parent_id): @staticmethod def parse_docs(file_objs, user_id): - exe = ThreadPoolExecutor(max_workers=12) - threads = [] - for file in file_objs: - threads.append(exe.submit(FileService.parse, file.filename, file.read(), False)) + with ThreadPoolExecutor(max_workers=12) as exe: + threads = [] + for file in file_objs: + threads.append(exe.submit(FileService.parse, file.filename, file.read(), False)) - res = [] - for th in threads: - res.append(th.result()) + res = [] + for th in threads: + res.append(th.result()) return "\n\n".join(res) @@ -699,7 +705,7 @@ def structured(filename, filetype, blob, content_type): # Pre-resolve the full redirect chain so that AsyncWebCrawler never # follows a server-sent redirect to an unvalidated (potentially - # internal) host. Each hop is SSRF-checked before being followed; + # internal) host. Each hop is SSRF-checked before being followed; # the validated (hostname, ip) pairs are pinned via Chromium's # --host-resolver-rules so the browser cannot re-resolve any of them # through a fresh DNS query. @@ -735,7 +741,7 @@ def structured(filename, filetype, blob, content_type): ) # Build a single MAP rule string covering every validated hostname - # in the redirect chain. Chromium uses the pinned IP for each, + # in the redirect chain. Chromium uses the pinned IP for each, # skipping DNS entirely and eliminating the rebinding window. _map_rules = ",".join(f"MAP {h} {ip}" for h, ip in host_pins.items()) @@ -787,19 +793,19 @@ def get_files(files: Union[None, list[dict]], raw: bool = False, layout_recogniz def image_to_base64(file): return "data:{};base64,{}".format(file["mime_type"], base64.b64encode(FileService.get_blob(file["created_by"], file["id"])).decode("utf-8")) - exe = ThreadPoolExecutor(max_workers=5) - threads = [] - imgs = [] - for file in files: - if file["mime_type"].find("image") >=0: - if raw: - imgs.append(FileService.get_blob(file["created_by"], file["id"])) - else: - threads.append(exe.submit(image_to_base64, file)) - continue - threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"], layout_recognize)) - - if raw: - return [th.result() for th in threads], imgs - else: - return [th.result() for th in threads] + with ThreadPoolExecutor(max_workers=5) as exe: + threads = [] + imgs = [] + for file in files: + if file["mime_type"].find("image") >=0: + if raw: + imgs.append(FileService.get_blob(file["created_by"], file["id"])) + else: + threads.append(exe.submit(image_to_base64, file)) + continue + threads.append(exe.submit(FileService.parse, file["name"], FileService.get_blob(file["created_by"], file["id"]), True, file["created_by"], layout_recognize)) + + if raw: + return [th.result() for th in threads], imgs + else: + return [th.result() for th in threads] diff --git a/api/db/services/knowledgebase_service.py b/api/db/services/knowledgebase_service.py index a164287fa4e..d6bb9e1db13 100644 --- a/api/db/services/knowledgebase_service.py +++ b/api/db/services/knowledgebase_service.py @@ -48,6 +48,25 @@ class KnowledgebaseService(CommonService): """ model = Knowledgebase + @classmethod + def _visibility_and_status_filter(cls, joined_tenant_ids, user_id): + """ + Build a Peewee filter expression representing knowledgebase visibility + for a given user, combined with a valid-status constraint. + + Visibility rules: + - Team KBs (`permission == TenantPermission.TEAM`) owned by any tenant in `joined_tenant_ids` + - KBs owned by the current user (`tenant_id == user_id`) + Always constrained to `StatusEnum.VALID`. + """ + return ( + ( + (cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == TenantPermission.TEAM.value)) + | (cls.model.tenant_id == user_id) + ) + & (cls.model.status == StatusEnum.VALID.value) + ) + @classmethod @DB.connection_context() def accessible4deletion(cls, kb_id, user_id): @@ -169,18 +188,12 @@ def get_by_tenant_ids(cls, joined_tenant_ids, user_id, ] if keywords: kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where( - ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == - TenantPermission.TEAM.value)) | ( - cls.model.tenant_id == user_id)) - & (cls.model.status == StatusEnum.VALID.value), - (fn.LOWER(cls.model.name).contains(keywords.lower())) + cls._visibility_and_status_filter(joined_tenant_ids, user_id), + fn.LOWER(cls.model.name).contains(keywords.lower()), ) else: kbs = cls.model.select(*fields).join(User, on=(cls.model.tenant_id == User.id)).where( - ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == - TenantPermission.TEAM.value)) | ( - cls.model.tenant_id == user_id)) - & (cls.model.status == StatusEnum.VALID.value) + cls._visibility_and_status_filter(joined_tenant_ids, user_id), ) if parser_id: kbs = kbs.where(cls.model.parser_id == parser_id) @@ -213,11 +226,7 @@ def get_all_kb_by_tenant_ids(cls, tenant_ids, user_id): cls.model.update_date ] # find team kb and owned kb - kbs = cls.model.select(*fields).where( - (cls.model.tenant_id.in_(tenant_ids) & (cls.model.permission ==TenantPermission.TEAM.value)) | ( - cls.model.tenant_id == user_id - ) - ) + kbs = cls.model.select(*fields).where(cls._visibility_and_status_filter(tenant_ids, user_id)) # sort by create_time asc kbs.order_by(cls.model.create_time.asc()) # maybe cause slow query by deep paginate, optimize later. @@ -459,12 +468,7 @@ def get_list(cls, joined_tenant_ids, user_id, if parser_id: kbs = kbs.where(cls.model.parser_id == parser_id) - kbs = kbs.where( - ((cls.model.tenant_id.in_(joined_tenant_ids) & (cls.model.permission == - TenantPermission.TEAM.value)) | ( - cls.model.tenant_id == user_id)) - & (cls.model.status == StatusEnum.VALID.value) - ) + kbs = kbs.where(cls._visibility_and_status_filter(joined_tenant_ids, user_id)) if desc: kbs = kbs.order_by(cls.model.getter_by(orderby).desc()) diff --git a/api/db/services/llm_service.py b/api/db/services/llm_service.py index 60090bb0409..9b6b5bd4f1e 100644 --- a/api/db/services/llm_service.py +++ b/api/db/services/llm_service.py @@ -97,7 +97,24 @@ def encode(self, texts: list): generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="encode", model=self.model_config["llm_name"], input={"texts": texts}) safe_texts = [] - for text in texts: + for idx, text in enumerate(texts): + # Embedding APIs (OpenAI-compatible, Zhipu, etc.) reject empty or + # whitespace-only inputs with errors like "Input at index N cannot + # be empty or whitespace only". Upstream parsers can produce such + # chunks — e.g. when OCR/vision on an embedded DOCX image returns + # nothing, or a table has only empty cells — so coerce to a safe + # placeholder here, at the single boundary every embedding path + # funnels through. + if text is None or not str(text).strip(): + marker = "None" if text is None else "whitespace-only" + logging.warning( + "LLMBundle.encode: empty input at index %d (%s) coerced to placeholder 'None' for model %s", + idx, + marker, + self.model_config["llm_name"], + ) + safe_texts.append("None") + continue token_size = num_tokens_from_string(text) if token_size > self.max_length: target_len = int(self.max_length * 0.95) @@ -121,6 +138,14 @@ def encode_queries(self, query: str): if self.langfuse: generation = self.langfuse.start_observation(trace_context=self.trace_context, as_type="generation", name="encode_queries", model=self.model_config["llm_name"], input={"query": query}) + if query is None or not str(query).strip(): + marker = "None" if query is None else "whitespace-only" + logging.warning( + "LLMBundle.encode_queries: empty query (%s) coerced to placeholder 'None' for model %s", + marker, + self.model_config["llm_name"], + ) + query = "None" emd, used_tokens = self.mdl.encode_queries(query) if self.model_config["llm_factory"] == "Builtin": logging.info("LLMBundle.encode_queries query: {}, emd len: {}, used_tokens: {}. Builtin model don't need to update token usage".format(query, len(emd), used_tokens)) diff --git a/api/db/services/system_settings_service.py b/api/db/services/system_settings_service.py index eac7019e6a1..0b0bde80242 100644 --- a/api/db/services/system_settings_service.py +++ b/api/db/services/system_settings_service.py @@ -26,7 +26,13 @@ class SystemSettingsService(CommonService): @classmethod @DB.connection_context() def get_by_name(cls, name): - objs = cls.model.select().where(cls.model.name == name) + objs = cls.model.select().where(cls.model.name == name).order_by(cls.model.name.asc()) + return objs + + @classmethod + @DB.connection_context() + def get_by_name_prefix(cls, name_prefix): + objs = cls.model.select().where(cls.model.name.startswith(name_prefix)).order_by(cls.model.name.asc()) return objs @classmethod diff --git a/api/db/services/tenant_llm_service.py b/api/db/services/tenant_llm_service.py index ee2eab6648a..f14f97fcef6 100644 --- a/api/db/services/tenant_llm_service.py +++ b/api/db/services/tenant_llm_service.py @@ -24,7 +24,6 @@ from api.db.services.common_service import CommonService from api.db.services.langfuse_service import TenantLangfuseService from api.db.services.user_service import TenantService -from rag.llm import ChatModel, CvModel, EmbeddingModel, OcrModel, RerankModel, Seq2txtModel, TTSModel class LLMFactoriesService(CommonService): @@ -183,6 +182,8 @@ def get_model_config(cls, tenant_id, llm_type, llm_name=None): def model_instance(cls, model_config: dict, lang="Chinese", **kwargs): if not model_config: raise LookupError("Model config is required") + from rag.llm import ChatModel, CvModel, EmbeddingModel, OcrModel, RerankModel, Seq2txtModel, TTSModel + kwargs.update({"provider": model_config["llm_factory"]}) api_key = model_config.get("api_key_payload", model_config["api_key"]) if model_config["model_type"] == LLMType.EMBEDDING.value: diff --git a/api/ragflow_server.py b/api/ragflow_server.py index af4720218fc..5c850092f37 100644 --- a/api/ragflow_server.py +++ b/api/ragflow_server.py @@ -19,8 +19,13 @@ import time start_ts = time.time() -import logging import os + +# LiteLLM fetches a model cost map from GitHub during import unless this is set. +# The API server should not block startup on external network access. +os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True") + +import logging import signal import sys import threading diff --git a/api/utils/api_utils.py b/api/utils/api_utils.py index a041ee0819f..5e034f0c509 100644 --- a/api/utils/api_utils.py +++ b/api/utils/api_utils.py @@ -439,6 +439,7 @@ def get_parser_config(chunk_method, parser_config): "category", ], "method": "light", + "batch_chunk_token_size": 4096, }, "parent_child": { "use_parent_child": False, diff --git a/api/utils/configs.py b/api/utils/configs.py index 91baa28e36e..c3abc13c37f 100644 --- a/api/utils/configs.py +++ b/api/utils/configs.py @@ -18,7 +18,6 @@ import base64 import pickle from api.utils.common import bytes_to_string, string_to_bytes -from common.config_utils import get_base_config safe_module = { 'numpy', @@ -54,8 +53,4 @@ def deserialize_b64(src): src = base64.b64decode( string_to_bytes(src) if isinstance( src, str) else src) - use_deserialize_safe_module = get_base_config( - 'use_deserialize_safe_module', False) - if use_deserialize_safe_module: - return restricted_loads(src) - return pickle.loads(src) + return restricted_loads(src) diff --git a/api/utils/file_utils.py b/api/utils/file_utils.py index 857cf17381d..21b746f8f18 100644 --- a/api/utils/file_utils.py +++ b/api/utils/file_utils.py @@ -107,23 +107,21 @@ def thumbnail_img(filename, blob): if re.match(r".*\.pdf$", filename): try: with sys.modules[LOCK_KEY_pdfplumber]: - pdf = pdfplumber.open(BytesIO(blob)) - if not pdf.pages: - pdf.close() - return None - buffered = BytesIO() - resolution = 32 - img = None - for _ in range(10): - pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png") - img = buffered.getvalue() - if len(img) >= 64000 and resolution >= 2: - resolution = resolution / 2 - buffered = BytesIO() - else: - break - pdf.close() - return img + with pdfplumber.open(BytesIO(blob)) as pdf: + if not pdf.pages: + return None + buffered = BytesIO() + resolution = 32 + img = None + for _ in range(10): + pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png") + img = buffered.getvalue() + if len(img) >= 64000 and resolution >= 2: + resolution = resolution / 2 + buffered = BytesIO() + else: + break + return img except Exception: return None diff --git a/api/utils/validation_utils.py b/api/utils/validation_utils.py index 94e0fa2ab83..861f94ee228 100644 --- a/api/utils/validation_utils.py +++ b/api/utils/validation_utils.py @@ -327,10 +327,14 @@ def validate_uuid1_hex(v: Any) -> str: class Base(BaseModel): + """Strict base model that rejects unknown request fields.""" + model_config = ConfigDict(extra="forbid", strict=True) class RaptorConfig(Base): + """Dataset parser configuration for RAPTOR summary generation.""" + use_raptor: Annotated[bool, Field(default=False)] prompt: Annotated[ str, @@ -344,19 +348,26 @@ class RaptorConfig(Base): max_cluster: Annotated[int, Field(default=64, ge=1, le=1024)] random_seed: Annotated[int, Field(default=0, ge=0)] scope: Annotated[Literal["file", "dataset"], Field(default="file")] + clustering_method: Annotated[Literal["gmm", "ahc"], Field(default="gmm")] + tree_builder: Annotated[Literal["raptor", "psi"], Field(default="raptor")] auto_disable_for_structured_data: Annotated[bool, Field(default=True)] ext: Annotated[dict, Field(default={})] class GraphragConfig(Base): + """Dataset parser configuration for GraphRAG generation.""" + use_graphrag: Annotated[bool, Field(default=False)] entity_types: Annotated[list[str], Field(default_factory=lambda: ["organization", "person", "geo", "event", "category"])] - method: Annotated[Literal["light", "general"], Field(default="light")] + method: Annotated[Literal["light", "general", "ner"], Field(default="light")] community: Annotated[bool, Field(default=False)] resolution: Annotated[bool, Field(default=False)] + batch_chunk_token_size: Annotated[int, Field(default=4096, ge=512, le=8196)] class ParentChildConfig(Base): + """Dataset parser configuration for parent-child chunking.""" + use_parent_child: Annotated[bool, Field(default=False)] children_delimiter: Annotated[str, Field(default=r"\n", min_length=1)] @@ -377,7 +388,12 @@ class AutoMetadataConfig(Base): built_in_metadata: Annotated[list[AutoMetadataField], Field(default_factory=list)] +TableColumnRole = Literal["indexing", "metadata", "both"] + + class ParserConfig(Base): + """Complete parser configuration accepted by dataset APIs.""" + auto_keywords: Annotated[int, Field(default=0, ge=0, le=32)] auto_questions: Annotated[int, Field(default=0, ge=0, le=10)] chunk_token_num: Annotated[int, Field(default=512, ge=1, le=2048)] @@ -393,6 +409,25 @@ class ParserConfig(Base): task_page_size: Annotated[int | None, Field(default=None, ge=1)] pages: Annotated[list[list[int]] | None, Field(default=None)] ext: Annotated[dict, Field(default={})] + # Table parser: column name -> "indexing" | "metadata" | "both". Absence => all columns "both". + # Table parser: "auto" = all columns both (default), "manual" = use table_column_roles. None → treated as "auto". + table_column_mode: Annotated[Literal["auto", "manual"] | None, Field(default=None)] + # Table parser: column name -> "indexing" | "metadata" | "both". Used only when table_column_mode == "manual". + table_column_roles: Annotated[dict[str, TableColumnRole] | None, Field(default=None)] + # Table parser: list of column names (set by backend after first parse; used by frontend for role selector). + table_column_names: Annotated[list[str] | None, Field(default=None)] + + @field_validator("table_column_roles", mode="before") + @classmethod + def legacy_vectorize_table_column_role(cls, v: Any) -> Any: + """Normalize legacy role value *vectorize* to *indexing* (chunk text + full-text search).""" + if v is None or not isinstance(v, dict): + return v + out: dict[str, Any] = {} + for key, val in v.items(): + k = key if isinstance(key, str) else str(key) + out[k] = "indexing" if val == "vectorize" else val + return out class UpdateDocumentReq(Base): @@ -417,6 +452,7 @@ class UpdateDocumentReq(Base): @field_validator("chunk_method", mode="after") @classmethod def validate_document_chunk_method(cls, chunk_method: str | None): + """Validate an optional document parser method.""" if chunk_method: # Validate chunk method if present valid_chunk_method = {"naive", "manual", "qa", "table", "paper", "book", "laws", "presentation", "picture", "one", "knowledge_graph", "email", "tag"} @@ -428,6 +464,7 @@ def validate_document_chunk_method(cls, chunk_method: str | None): @field_validator("enabled", mode="after") @classmethod def validate_document_enabled(cls, enabled: str | None): + """Validate the optional enabled flag.""" if enabled: converted = int(enabled) if converted < 0 or converted > 1: @@ -438,6 +475,7 @@ def validate_document_enabled(cls, enabled: str | None): @field_validator("meta_fields", mode="after") @classmethod def validate_document_meta_fields(cls, meta_fields: dict | None): + """Validate user-provided document metadata values.""" if meta_fields is None: return None @@ -453,6 +491,8 @@ def validate_document_meta_fields(cls, meta_fields: dict | None): class CreateDatasetReq(Base): + """Request model for creating a dataset.""" + name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(...)] avatar: Annotated[str | None, Field(default=None, max_length=65535)] description: Annotated[str | None, Field(default=None, max_length=65535)] @@ -468,6 +508,7 @@ class CreateDatasetReq(Base): @field_validator("pipeline_id", mode="before") @classmethod def handle_pipeline_id(cls, v: str | None, info: ValidationInfo): + """Drop pipeline_id when parse_type selects direct parser mode.""" if v is None: return v if info.data.get("parse_type", 0) == 1: @@ -721,6 +762,8 @@ def validate_chunk_method(cls, v: Any, handler, info: ValidationInfo) -> Any: class UpdateDatasetReq(CreateDatasetReq): + """Request model for updating a dataset.""" + dataset_id: Annotated[str, Field(...)] name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=DATASET_NAME_LIMIT), Field(default="")] pagerank: Annotated[int, Field(default=0, ge=0, le=100)] @@ -730,10 +773,13 @@ class UpdateDatasetReq(CreateDatasetReq): @field_validator("dataset_id", mode="before") @classmethod def validate_dataset_id(cls, v: Any) -> str: + """Validate and normalize the dataset id.""" return validate_uuid1_hex(v) class DeleteReq(Base): + """Base request model for batch delete APIs.""" + ids: Annotated[list[str] | None, Field(default=None)] delete_all: Annotated[bool, Field(default=False)] @@ -811,10 +857,15 @@ def validate_ids(cls, v_list: list[str] | None) -> list[str] | None: return ids_list -class DeleteDatasetReq(DeleteReq): ... +class DeleteDatasetReq(DeleteReq): + """Request model for deleting datasets.""" + + ... class DeleteDocumentReq(DeleteReq): + """Request model for deleting documents.""" + @field_validator("ids", mode="after") @classmethod def validate_ids(cls, v_list: list[str] | None) -> list[str] | None: @@ -840,6 +891,8 @@ def validate_ids(cls, v_list: list[str] | None) -> list[str] | None: class SearchDatasetReq(BaseModel): + """Request model for searching one dataset.""" + model_config = ConfigDict(extra="ignore") question: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1), Field(...)] @@ -859,6 +912,8 @@ class SearchDatasetReq(BaseModel): class SearchDatasetsReq(BaseModel): + """Request model for searching multiple datasets.""" + model_config = ConfigDict(extra="ignore") dataset_ids: Annotated[list[str], Field(..., min_length=1)] @@ -874,11 +929,13 @@ class SearchDatasetsReq(BaseModel): keyword: Annotated[bool, Field(default=False)] search_id: Annotated[str | None, Field(default=None)] rerank_id: Annotated[str | None, Field(default=None)] - tenant_rerank_id: Annotated[str | None, Field(default=None)] + tenant_rerank_id: Annotated[int | None, Field(default=None)] meta_data_filter: Annotated[dict | None, Field(default=None)] class BaseListReq(BaseModel): + """Shared pagination and sorting fields for list APIs.""" + model_config = ConfigDict(extra="forbid") id: Annotated[str | None, Field(default=None)] @@ -891,10 +948,13 @@ class BaseListReq(BaseModel): @field_validator("id", mode="before") @classmethod def validate_id(cls, v: Any) -> str: + """Validate and normalize an optional list filter id.""" return validate_uuid1_hex(v) class ListDatasetReq(BaseListReq): + """Request model for listing datasets.""" + include_parsing_status: Annotated[bool, Field(default=False)] ext: Annotated[dict, Field(default={})] @@ -903,22 +963,29 @@ class ListDatasetReq(BaseListReq): class CreateFolderReq(Base): + """Request model for creating a folder.""" + name: Annotated[str, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(...)] parent_id: Annotated[str | None, Field(default=None)] type: Annotated[str | None, Field(default=None)] class DeleteFileReq(Base): + """Request model for deleting files.""" + ids: Annotated[list[str], Field(min_length=1)] class MoveFileReq(Base): + """Request model for moving or renaming files.""" + src_file_ids: Annotated[list[str], Field(min_length=1)] dest_file_id: Annotated[str | None, Field(default=None)] new_name: Annotated[str | None, StringConstraints(strip_whitespace=True, min_length=1, max_length=255), Field(default=None)] @model_validator(mode="after") def check_operation(self): + """Require either a destination folder or a new file name.""" if not self.dest_file_id and not self.new_name: raise ValueError("At least one of dest_file_id or new_name must be provided") if self.new_name and len(self.src_file_ids) > 1: @@ -927,6 +994,8 @@ def check_operation(self): class ListFileReq(BaseModel): + """Request model for listing files.""" + model_config = ConfigDict(extra="forbid") parent_id: Annotated[str | None, Field(default=None)] diff --git a/api/utils/web_utils.py b/api/utils/web_utils.py index 23d2421862d..e7c1b48f513 100644 --- a/api/utils/web_utils.py +++ b/api/utils/web_utils.py @@ -173,6 +173,9 @@ def __get_pdf_from_html(path: str, timeout: int, install_driver: bool, print_opt try: WebDriverWait(driver, timeout).until(staleness_of(driver.find_element(by=By.TAG_NAME, value="html"))) except TimeoutException: + pass + + try: calculated_print_options = { "landscape": False, "displayHeaderFooter": False, @@ -181,8 +184,9 @@ def __get_pdf_from_html(path: str, timeout: int, install_driver: bool, print_opt } calculated_print_options.update(print_options) result = __send_devtools(driver, "Page.printToPDF", calculated_print_options) - driver.quit() return base64.b64decode(result["data"]) + finally: + driver.quit() def is_valid_url(url: str) -> bool: diff --git a/build.sh b/build.sh index 13cbb263431..349ac645fa1 100755 --- a/build.sh +++ b/build.sh @@ -16,6 +16,7 @@ CPP_DIR="$PROJECT_ROOT/internal/cpp" BUILD_DIR="$CPP_DIR/cmake-build-release" RAGFLOW_SERVER_BINARY="$PROJECT_ROOT/bin/server_main" ADMIN_SERVER_BINARY="$PROJECT_ROOT/bin/admin_server" +RAGFLOW_CLI_BINARY="$PROJECT_ROOT/bin/ragflow_cli" echo -e "${GREEN}=== RAGFlow Go Server Build Script ===${NC}" @@ -73,7 +74,7 @@ build_cpp() { # Build Go server build_go() { - print_section "Building Go server" + print_section "Building RAGFlow go" cd "$PROJECT_ROOT" @@ -91,9 +92,10 @@ build_go() { sudo apt -y install libpcre2-dev fi - echo "Building API server binary: $RAGFLOW_SERVER_BINARY and $ADMIN_SERVER_BINARY" - GOPROXY=${GOPROXY:-https://goproxy.cn,https://proxy.golang.org,direct} CGO_ENABLED=1 go build -o "$RAGFLOW_SERVER_BINARY" ./cmd/server_main.go - GOPROXY=${GOPROXY:-https://goproxy.cn,https://proxy.golang.org,direct} CGO_ENABLED=1 go build -o "$ADMIN_SERVER_BINARY" ./cmd/admin_server.go + echo "Building RAGFlow binary: $RAGFLOW_SERVER_BINARY, $ADMIN_SERVER_BINARY, and $RAGFLOW_CLI_BINARY" + GOPROXY=${GOPROXY:-https://goproxy.cn,https://proxy.golang.org,direct} CGO_ENABLED=1 go build -o "$RAGFLOW_SERVER_BINARY" cmd/server_main.go + GOPROXY=${GOPROXY:-https://goproxy.cn,https://proxy.golang.org,direct} CGO_ENABLED=1 go build -o "$ADMIN_SERVER_BINARY" cmd/admin_server.go + GOPROXY=${GOPROXY:-https://goproxy.cn,https://proxy.golang.org,direct} CGO_ENABLED=1 go build -o "$RAGFLOW_CLI_BINARY" cmd/ragflow_cli.go if [ ! -f "$RAGFLOW_SERVER_BINARY" ]; then echo -e "${RED}Error: Failed to build RAGFlow server binary${NC}" @@ -105,8 +107,9 @@ build_go() { exit 1 fi - echo -e "${GREEN}✓ Go server_main built successfully: $RAGFLOW_SERVER_BINARY${NC}" + echo -e "${GREEN}✓ Go ragflow_server built successfully: $RAGFLOW_SERVER_BINARY${NC}" echo -e "${GREEN}✓ Go admin_server built successfully: $ADMIN_SERVER_BINARY${NC}" + echo -e "${GREEN}✓ Go ragflow_cli built successfully: $RAGFLOW_CLI_BINARY${NC}" } # Clean build artifacts diff --git a/cmd/server_main.go b/cmd/server_main.go index e4a634e72af..ab14dacd873 100644 --- a/cmd/server_main.go +++ b/cmd/server_main.go @@ -166,8 +166,7 @@ func startServer(config *server.Config) { // Initialize service layer userService := service.NewUserService() documentService := service.NewDocumentService() - datasetsService := service.NewDatasetsService() - kbService := service.NewKnowledgebaseService() + datasetsService := service.NewDatasetService() chunkService := service.NewChunkService() llmService := service.NewLLMService() tenantService := service.NewTenantService() @@ -187,10 +186,9 @@ func startServer(config *server.Config) { authHandler := handler.NewAuthHandler() userHandler := handler.NewUserHandler(userService) tenantHandler := handler.NewTenantHandler(tenantService, userService) - documentHandler := handler.NewDocumentHandler(documentService) + documentHandler := handler.NewDocumentHandler(documentService, datasetsService) datasetsHandler := handler.NewDatasetsHandler(datasetsService) systemHandler := handler.NewSystemHandler(systemService) - kbHandler := handler.NewKnowledgebaseHandler(kbService, userService, documentService) chunkHandler := handler.NewChunkHandler(chunkService, userService) llmHandler := handler.NewLLMHandler(llmService, userService) chatHandler := handler.NewChatHandler(chatService, userService) @@ -203,7 +201,7 @@ func startServer(config *server.Config) { providerHandler := handler.NewProviderHandler(userService, modelProviderService) // Initialize router - r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, kbHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, skillSearchHandler, providerHandler) + r := router.NewRouter(authHandler, userHandler, tenantHandler, documentHandler, datasetsHandler, systemHandler, chunkHandler, llmHandler, chatHandler, chatSessionHandler, connectorHandler, searchHandler, fileHandler, memoryHandler, skillSearchHandler, providerHandler) // Create Gin engine ginEngine := gin.New() diff --git a/common/connection_utils.py b/common/connection_utils.py index 86ebc371d8c..0218d99a281 100644 --- a/common/connection_utils.py +++ b/common/connection_utils.py @@ -115,7 +115,6 @@ async def construct_response(code=RetCode.SUCCESS, message="success", data=None, response.headers["Access-Control-Allow-Origin"] = "*" response.headers["Access-Control-Allow-Method"] = "*" response.headers["Access-Control-Allow-Headers"] = "*" - response.headers["Access-Control-Allow-Headers"] = "*" response.headers["Access-Control-Expose-Headers"] = "Authorization" return response @@ -135,6 +134,5 @@ def sync_construct_response(code=RetCode.SUCCESS, message="success", data=None, response.headers["Access-Control-Allow-Origin"] = "*" response.headers["Access-Control-Allow-Method"] = "*" response.headers["Access-Control-Allow-Headers"] = "*" - response.headers["Access-Control-Allow-Headers"] = "*" response.headers["Access-Control-Expose-Headers"] = "Authorization" return response diff --git a/common/constants.py b/common/constants.py index 5ab9acaa502..c76dcdbb099 100644 --- a/common/constants.py +++ b/common/constants.py @@ -16,7 +16,7 @@ import os from enum import Enum, IntEnum -from strenum import StrEnum +from enum import StrEnum SERVICE_CONF = "service_conf.yaml" RAG_FLOW_SERVICE_NAME = "ragflow" @@ -93,6 +93,11 @@ class TaskStatus(StrEnum): VALID_TASK_STATUS = {TaskStatus.UNSTART, TaskStatus.RUNNING, TaskStatus.CANCEL, TaskStatus.DONE, TaskStatus.FAIL, TaskStatus.SCHEDULE} +class ConnectorTaskType(StrEnum): + SYNC = "sync" + PRUNE = "prune" + + class ParserType(StrEnum): PRESENTATION = "presentation" LAWS = "laws" @@ -117,6 +122,7 @@ class FileSource(StrEnum): RSS = "rss" S3 = "s3" NOTION = "notion" + REST_API = "rest_api" DISCORD = "discord" CONFLUENCE = "confluence" GMAIL = "gmail" diff --git a/common/data_source/__init__.py b/common/data_source/__init__.py index 301103652ce..34bb467d9f4 100644 --- a/common/data_source/__init__.py +++ b/common/data_source/__init__.py @@ -44,6 +44,7 @@ from .seafile_connector import SeaFileConnector from .rdbms_connector import RDBMSConnector from .webdav_connector import WebDAVConnector +from .rest_api_connector import RestAPIConnector from .config import BlobType, DocumentSource from .models import Document, TextSection, ImageSection, BasicExpertInfo from .exceptions import ( @@ -87,4 +88,5 @@ "RDBMSConnector", "WebDAVConnector", "DingTalkAITableConnector", + "RestAPIConnector", ] diff --git a/common/data_source/blob_connector.py b/common/data_source/blob_connector.py index 7505b878ba3..e183eb63aac 100644 --- a/common/data_source/blob_connector.py +++ b/common/data_source/blob_connector.py @@ -1,9 +1,12 @@ """Blob storage connector""" import logging import os +from collections.abc import Iterator from datetime import datetime, timezone from typing import Any, Optional +import xxhash + from common.data_source.utils import ( create_s3_client, detect_bucket_region, @@ -18,9 +21,14 @@ CredentialExpiredError, InsufficientPermissionsError ) -from common.data_source.interfaces import LoadConnector, PollConnector +from common.data_source.interfaces import ( + FingerprintConnector, + LoadConnector, + PollConnector, +) from common.data_source.models import ( Document, + KeyRecord, SecondsSinceUnixEpoch, GenerateDocumentsOutput, GenerateSlimDocumentOutput, @@ -28,7 +36,20 @@ ) -class BlobStorageConnector(LoadConnector, PollConnector): +def _normalize_etag(raw_etag: Optional[str]) -> Optional[str]: + """Return a 32-char hex fingerprint derived from an S3 ETag. + + S3 ETags are MD5 (32 hex chars) for single-part uploads and "-" + (34+ chars) for multipart. We always hash so the column format is uniform + regardless of upload type or provider quirks; equality of the hashed value + is sufficient for change detection. + """ + if not raw_etag: + return None + return xxhash.xxh128(raw_etag.strip('"').encode()).hexdigest() + + +class BlobStorageConnector(LoadConnector, PollConnector, FingerprintConnector): """Blob storage connector""" def __init__( @@ -48,6 +69,11 @@ def __init__( self.size_threshold: int | None = BLOB_STORAGE_SIZE_THRESHOLD self.bucket_region: Optional[str] = None self.european_residency: bool = european_residency + # Populated by list_keys() so a subsequent get_value(key) can find the + # raw S3 object metadata (LastModified, ETag, Key, Size) without a second + # head_object call. Lifetime is one list_keys() pass. + self._listing_cache: dict[str, dict[str, Any]] = {} + self._filename_counts: dict[str, int] = {} def set_allow_images(self, allow_images: bool) -> None: """Set whether to process images""" @@ -122,6 +148,44 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None return None + def _build_document_from_obj( + self, + obj: dict[str, Any], + filename_counts: dict[str, int], + ) -> Optional[Document]: + """Materialize a Document for one S3 object, downloading its body.""" + key = obj["Key"] + file_name = os.path.basename(key) + last_modified = obj["LastModified"].replace(tzinfo=timezone.utc) + + size_bytes = extract_size_bytes(obj) + if ( + self.size_threshold is not None + and isinstance(size_bytes, int) + and size_bytes > self.size_threshold + ): + logging.warning( + f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping." + ) + return None + + blob = download_object( + self.s3_client, self.bucket_name, key, self.size_threshold + ) + if blob is None: + return None + + return Document( + id=f"{self.bucket_type}:{self.bucket_name}:{key}", + blob=blob, + source=DocumentSource(self.bucket_type.value), + semantic_identifier=self._get_semantic_id(key, file_name, filename_counts), + extension=get_file_ext(file_name), + doc_updated_at=last_modified, + size_bytes=size_bytes if size_bytes else 0, + fingerprint=_normalize_etag(obj.get("ETag")), + ) + def _yield_blob_objects( self, start: datetime, @@ -132,51 +196,64 @@ def _yield_blob_objects( batch: list[Document] = [] for obj in all_objects: - last_modified = obj["LastModified"].replace(tzinfo=timezone.utc) - file_name = os.path.basename(obj["Key"]) - key = obj["Key"] - - size_bytes = extract_size_bytes(obj) - if ( - self.size_threshold is not None - and isinstance(size_bytes, int) - and size_bytes > self.size_threshold - ): - logging.warning( - f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping." - ) - continue - try: - blob = download_object( - self.s3_client, self.bucket_name, key, self.size_threshold - ) - if blob is None: + doc = self._build_document_from_obj(obj, filename_counts) + if doc is None: continue - - semantic_id = self._get_semantic_id(key, file_name, filename_counts) - - batch.append( - Document( - id=f"{self.bucket_type}:{self.bucket_name}:{key}", - blob=blob, - source=DocumentSource(self.bucket_type.value), - semantic_identifier=semantic_id, - extension=get_file_ext(file_name), - doc_updated_at=last_modified, - size_bytes=size_bytes if size_bytes else 0, - ) - ) + batch.append(doc) if len(batch) == self.batch_size: yield batch batch = [] - except Exception: - logging.exception(f"Error decoding object {key}") + logging.exception(f"Error decoding object {obj.get('Key')}") if batch: yield batch + def list_keys(self) -> Iterator[KeyRecord]: + """Enumerate the full bucket keyspace with per-object fingerprints. + + Cheap path: relies on list_objects_v2 which returns ETag in the listing, + so no GetObject call is needed. Caches each object's metadata so a + subsequent get_value(key) call can rebuild the Document without a second + round-trip to S3. + """ + if self.s3_client is None: + raise ConnectorMissingCredentialError("Blob storage") + + all_objects, filename_counts = self._collect_blob_objects( + start=datetime(1970, 1, 1, tzinfo=timezone.utc), + end=datetime.now(timezone.utc), + ) + self._filename_counts = filename_counts + self._listing_cache = {} + + for obj in all_objects: + doc_id = f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}" + self._listing_cache[doc_id] = obj + yield KeyRecord( + key=doc_id, + fingerprint=_normalize_etag(obj.get("ETag")), + ) + + def get_value(self, key: str) -> Document: + """Materialize the Document for a key previously yielded by list_keys(). + + Must be called within the same list_keys() pass that produced the key, + since the metadata cache lives on the connector instance and is reset + each list_keys() call. + """ + obj = self._listing_cache.get(key) + if obj is None: + raise KeyError( + f"get_value({key!r}) called before list_keys() yielded the key, " + "or after a subsequent list_keys() reset the cache" + ) + doc = self._build_document_from_obj(obj, self._filename_counts) + if doc is None: + raise RuntimeError(f"Failed to materialize Document for key {key!r}") + return doc + def _collect_blob_objects( self, start: datetime, diff --git a/common/data_source/config.py b/common/data_source/config.py index 2b512d4ce23..08d0209e03c 100644 --- a/common/data_source/config.py +++ b/common/data_source/config.py @@ -43,6 +43,7 @@ class DocumentSource(str, Enum): RSS = "rss" S3 = "s3" NOTION = "notion" + REST_API = "rest_api" R2 = "r2" GOOGLE_CLOUD_STORAGE = "google_cloud_storage" OCI_STORAGE = "oci_storage" diff --git a/common/data_source/discord_connector.py b/common/data_source/discord_connector.py index 83b2b562f0e..e047148f330 100644 --- a/common/data_source/discord_connector.py +++ b/common/data_source/discord_connector.py @@ -3,7 +3,9 @@ import asyncio import logging import os +from contextlib import suppress from datetime import datetime, timezone +from queue import Queue from typing import Any, AsyncIterable, Iterable from discord import Client, MessageType @@ -151,14 +153,14 @@ async def _fetch_documents_from_channel( yield thread_message -def _manage_async_retrieval( +async def _manage_async_retrieval( token: str, requested_start_date_string: str, channel_names: list[str], server_ids: list[int], start: datetime | None = None, -) -> Iterable[DiscordMessage]: - """Bridge the async Discord client into a synchronous iterator. +) -> AsyncIterable[DiscordMessage]: + """Fetch Discord messages with the async Discord client. `start` is only used as a lower bound for the underlying fetch. Callers that need a narrower time window should apply their own filtering while @@ -173,11 +175,11 @@ def _manage_async_retrieval( if proxy_url: logging.info(f"Using proxy for Discord: {proxy_url}") - async def _async_fetch() -> AsyncIterable[DiscordMessage]: - intents = Intents.default() - intents.message_content = True - async with Client(intents=intents, proxy=proxy_url) as cli: - asyncio.create_task(coro=cli.start(token)) + intents = Intents.default() + intents.message_content = True + async with Client(intents=intents, proxy=proxy_url) as cli: + client_task = asyncio.create_task(cli.start(token)) + try: await cli.wait_until_ready() filtered_channels: list[TextChannel] = await _fetch_filtered_channels( @@ -192,27 +194,41 @@ async def _async_fetch() -> AsyncIterable[DiscordMessage]: start_time=start_time, ): yield message + finally: + await cli.close() + client_task.cancel() + with suppress(asyncio.CancelledError): + await client_task + + +def _iterate_async_messages(async_messages: AsyncIterable[DiscordMessage]) -> Iterable[DiscordMessage]: + """Expose async Discord retrieval to the existing synchronous connector API.""" + item_queue: Queue[DiscordMessage | BaseException | None] = Queue() - def run_and_yield() -> Iterable[DiscordMessage]: - loop = asyncio.new_event_loop() + async def consume_messages() -> None: + async for message in async_messages: + item_queue.put(message) + + def run_consumer() -> None: try: - # Get the async generator - async_gen = _async_fetch() - # Convert to AsyncIterator - async_iter = async_gen.__aiter__() - while True: - try: - # Create a coroutine by calling anext with the async iterator - next_coro = anext(async_iter) - # Run the coroutine to get the next document - doc = loop.run_until_complete(next_coro) - yield doc - except StopAsyncIteration: - break + asyncio.run(consume_messages()) + except BaseException as exc: + item_queue.put(exc) finally: - loop.close() + item_queue.put(None) + + consumer_thread = Thread(target=run_consumer, name="discord-connector-retrieval", daemon=True) + consumer_thread.start() + + while True: + item = item_queue.get() + if item is None: + break + if isinstance(item, BaseException): + raise item + yield item - return run_and_yield() + consumer_thread.join() class DiscordConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): @@ -283,12 +299,14 @@ def merge_batch(): size_bytes=size_bytes, ) - for message in _manage_async_retrieval( - token=self.discord_bot_token, - requested_start_date_string=self.requested_start_date_string, - channel_names=self.channel_names, - server_ids=self.server_ids, - start=start, + for message in _iterate_async_messages( + _manage_async_retrieval( + token=self.discord_bot_token, + requested_start_date_string=self.requested_start_date_string, + channel_names=self.channel_names, + server_ids=self.server_ids, + start=start, + ) ): if not _is_in_window(message): continue @@ -321,7 +339,7 @@ def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None def validate_connector_settings(self) -> None: """Validate Discord connector settings""" - if not self.discord_client: + if not self.discord_bot_token: raise ConnectorMissingCredentialError("Discord") def poll_source(self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch) -> Any: @@ -344,12 +362,14 @@ def retrieve_all_slim_docs_perm_sync( full_scan_batch_size = 0 full_scan_batch_first_id: str | None = None - for message in _manage_async_retrieval( - token=self.discord_bot_token, - requested_start_date_string=self.requested_start_date_string, - channel_names=self.channel_names, - server_ids=self.server_ids, - start=None, + for message in _iterate_async_messages( + _manage_async_retrieval( + token=self.discord_bot_token, + requested_start_date_string=self.requested_start_date_string, + channel_names=self.channel_names, + server_ids=self.server_ids, + start=None, + ) ): if full_scan_batch_first_id is None: full_scan_batch_first_id = f"{_DISCORD_DOC_ID_PREFIX}{message.id}" diff --git a/common/data_source/interfaces.py b/common/data_source/interfaces.py index 324293baaba..fb547d7d928 100644 --- a/common/data_source/interfaces.py +++ b/common/data_source/interfaces.py @@ -2,7 +2,7 @@ import abc import uuid from abc import ABC, abstractmethod -from enum import IntFlag, auto +from enum import IntEnum, IntFlag, auto from types import TracebackType from typing import Any, Dict, Generator, TypeVar, Generic, Callable, TypeAlias from collections.abc import Iterator @@ -10,12 +10,26 @@ from common.data_source.models import ( Document, + KeyRecord, SlimDocument, ConnectorCheckpoint, ConnectorFailure, SecondsSinceUnixEpoch, GenerateSlimDocumentOutput ) + +class IncrementalCapability(IntEnum): + """How a connector handles incremental sync. + + FULL_RESYNC -- every sync re-pulls; no per-key state. + CURSOR -- "give me everything since cursor X"; opaque cursor persisted across syncs. + FINGERPRINT -- list_keys() returns (key, fingerprint) cheaply; bodies fetched lazily. + """ + FULL_RESYNC = 0 + CURSOR = 1 + FINGERPRINT = 2 + + GenerateDocumentsOutput = Iterator[list[Document]] class LoadConnector(ABC): @@ -415,3 +429,39 @@ def progress(self, tag: str, amount: int) -> None: just to act as a keep-alive. """ + +class FingerprintConnector(ABC): + """Tier 1 connector: cheap full listing with per-key fingerprint. + + Sources that can enumerate their entire keyspace via a metadata-only call + (e.g. S3 list_objects_v2 returning ETag + LastModified) implement this to + let the orchestrator skip GetObject for keys whose fingerprint hasn't + changed since the last sync. + + The fingerprint is an opaque equality token: two equal fingerprints mean + the content is unchanged from the orchestrator's point of view. Format is + a 32-char hex string so it fits the existing Document.content_hash column; + connectors are responsible for normalizing whatever the source exposes + (typically by hashing it with xxhash128). + """ + + INCREMENTAL_CAPABILITY: IncrementalCapability = IncrementalCapability.FINGERPRINT + + @abstractmethod + def list_keys(self) -> Iterator[KeyRecord]: + """Yield one KeyRecord per object currently in the source. + + Must enumerate the full current keyspace -- the orchestrator diffs the + result against persisted state to detect adds, updates, and deletes. + """ + raise NotImplementedError + + @abstractmethod + def get_value(self, key: str) -> Document: + """Fetch the body for a single key, returning a fully populated Document. + + Called only when list_keys()'s fingerprint differs from the persisted + content_hash for that key (or when no persisted fingerprint exists). + """ + raise NotImplementedError + diff --git a/common/data_source/models.py b/common/data_source/models.py index 71f8c27242f..29cb6bc251c 100644 --- a/common/data_source/models.py +++ b/common/data_source/models.py @@ -99,6 +99,25 @@ class Document(BaseModel): primary_owners: Optional[list] = None metadata: Optional[dict[str, Any]] = None doc_metadata: Optional[dict[str, Any]] = None + # Opaque, connector-supplied fingerprint stored in Document.content_hash for + # change-detection. 32-char hex string; format is per-source (xxhash128 of + # bytes for local uploads, xxhash128(ETag) for blob storage, etc.). When set + # on a yielded Document, the orchestrator persists it as content_hash and + # skips the post-download xxhash128(blob) recomputation. + fingerprint: Optional[str] = None + + +class KeyRecord(BaseModel): + """One entry returned by a FingerprintConnector.list_keys() call. + + A KeyRecord is the cheap-listing primitive: connector enumerates all keys + it has, attaches a fingerprint when the source exposes one, and the + orchestrator only fetches content when the fingerprint differs from what's + persisted. + """ + key: str + fingerprint: Optional[str] = None + deleted: bool = False class BasicExpertInfo(BaseModel): diff --git a/common/data_source/rest_api_connector.py b/common/data_source/rest_api_connector.py new file mode 100644 index 00000000000..8616be2730d --- /dev/null +++ b/common/data_source/rest_api_connector.py @@ -0,0 +1,1012 @@ +"""Generic, configuration-driven REST API data source connector. + +Connect any REST API as a RAGFlow data source without code changes. +All behaviour — URL, auth, pagination, field mapping — is controlled +via the ``RestAPIConnectorConfig`` schema exposed by the UI. +""" + +from __future__ import annotations + +import json +import logging +import re +import time +from datetime import datetime, timezone +from typing import Any, Dict, Generator, Iterable, List, Mapping, Optional +from urllib.parse import parse_qs, urlparse, urlunparse + +import ipaddress +import socket +import requests +from pydantic import BaseModel, ConfigDict, Field, HttpUrl, ValidationError, field_validator + +logger = logging.getLogger(__name__) + +from api.utils.common import hash128 +from common.data_source.config import INDEX_BATCH_SIZE, DocumentSource +from common.data_source.exceptions import ( + ConnectorMissingCredentialError, + ConnectorValidationError, +) +from common.data_source.interfaces import ( + LoadConnector, + PollConnector, + SecondsSinceUnixEpoch, +) +from common.data_source.models import Document +from common.data_source.utils import rl_requests, retry_builder + +try: + from jsonpath import jsonpath as _jsonpath # type: ignore[import] +except Exception: # pragma: no cover + _jsonpath = None + +_FIELD_SEGMENT_RE = re.compile(r'^(?P[^\[\]]+)(\[(?P\d+|\*)\])?$') +_DEFAULT_MAX_PAGES = 1000 + + +class AuthType: + NONE = "none" + API_KEY_HEADER = "api_key_header" + BEARER = "bearer" + BASIC = "basic" + + +class PaginationType: + NONE = "none" + PAGE = "page" + OFFSET = "offset" + CURSOR = "cursor" + + +def _text_to_dict(v: Any) -> Dict[str, str]: + """Parse a dict, JSON string, or ``key=value`` text (one per line) into a dict. + + This is module-level because Pydantic ``@field_validator`` classmethods + on ``RestAPIConnectorConfig`` need to call it before any instance exists. + """ + if v is None or v == "": + return {} + if isinstance(v, dict): + return {str(k): str(vv) for k, vv in v.items()} + if isinstance(v, str): + try: + parsed = json.loads(v) + if isinstance(parsed, dict): + return {str(k): str(vv) for k, vv in parsed.items()} + except Exception: + pass + result: Dict[str, str] = {} + for line in v.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + if "=" in line: + k, _, val = line.partition("=") + result[k.strip()] = val.strip() + return result + return {} + + +class RestAPIConnectorConfig(BaseModel): + """Validated schema for the REST API connector configuration.""" + + model_config = ConfigDict(extra="ignore") + + url: HttpUrl + method: str = "GET" + headers: Dict[str, str] = Field(default_factory=dict) + query_params: Dict[str, str] = Field(default_factory=dict) + + auth_type: str = AuthType.NONE + auth_config: Dict[str, Any] = Field(default_factory=dict) + + items_path: Optional[str] = None + id_field: Optional[str] = None + content_fields: List[str] = Field(default_factory=list) + metadata_fields: List[str] = Field(default_factory=list) + + pagination_type: str = PaginationType.NONE + pagination_config: Dict[str, Any] = Field(default_factory=dict) + + poll_timestamp_field: Optional[str] = None + request_body: Optional[Dict[str, Any]] = None + + field_type_hints: Dict[str, str] = Field(default_factory=dict) + field_default_values: Dict[str, Any] = Field(default_factory=dict) + content_template: Optional[str] = None + + batch_size: int = INDEX_BATCH_SIZE + max_pages: int = _DEFAULT_MAX_PAGES + request_delay: float = 0.5 + + @field_validator("headers", mode="before") + @classmethod + def _coerce_headers(cls, v: Any) -> Dict[str, str]: + return _text_to_dict(v) + + @field_validator("query_params", mode="before") + @classmethod + def _coerce_query_params(cls, v: Any) -> Dict[str, str]: + return _text_to_dict(v) + + @field_validator("content_fields", "metadata_fields", mode="before") + @classmethod + def _coerce_field_list(cls, v: Any) -> List[str]: + if v is None or v == "": + return [] + if isinstance(v, str): + return [p.strip() for p in v.split(",") if p.strip()] + if isinstance(v, list): + return [str(p).strip() for p in v if str(p).strip()] + return [] + + def normalized_method(self) -> str: + m = (self.method or "GET").upper() + if m not in {"GET", "POST"}: + raise ConnectorValidationError(f"Unsupported HTTP method '{m}'.") + return m + + def normalized_auth_type(self) -> str: + if self.auth_type not in {AuthType.NONE, AuthType.API_KEY_HEADER, AuthType.BEARER, AuthType.BASIC}: + raise ConnectorValidationError(f"Unsupported auth_type '{self.auth_type}'.") + return self.auth_type + + def normalized_pagination_type(self) -> str: + if self.pagination_type not in {PaginationType.NONE, PaginationType.PAGE, PaginationType.OFFSET, PaginationType.CURSOR}: + raise ConnectorValidationError(f"Unsupported pagination_type '{self.pagination_type}'.") + return self.pagination_type + + def ensure_required_fields(self) -> None: + if not self.content_fields: + raise ConnectorValidationError("At least one content field must be configured (content_fields).") + + +class RestAPIConnector(LoadConnector, PollConnector): + """Configuration-driven REST API connector. + + Implements ``LoadConnector`` and ``PollConnector`` to fetch documents + from any REST API using user-provided configuration (URL, auth, + pagination, field mapping). + """ + + @staticmethod + def _validate_url_for_ssrf(url: str) -> None: + """Validate that the URL does not point to localhost or private/internal networks. + + Raises: + ConnectorValidationError: If the URL is considered unsafe. + """ + parsed = urlparse(str(url)) + + if parsed.scheme not in ("http", "https"): + msg = f"Unsupported URL scheme for REST API connector: {parsed.scheme!r}. Only http/https are allowed." + logger.warning(msg) + raise ConnectorValidationError(msg) + + hostname = parsed.hostname + if not hostname: + msg = "REST API connector URL must include a hostname." + logger.warning(msg) + raise ConnectorValidationError(msg) + + # Quick checks for obvious localhost-style hostnames. + lower_host = hostname.lower() + if lower_host in ("localhost",): + msg = f"REST API connector URL hostname {hostname!r} is not allowed (localhost is blocked)." + logger.warning(msg) + raise ConnectorValidationError(msg) + + try: + addrinfo_list = socket.getaddrinfo(hostname, None) + except OSError as exc: + # If resolution fails, log and let higher-level validation (if any) decide. + # We do not treat this as an SSRF condition by itself. + logger.info("DNS resolution failed for REST API connector URL %r: %s", url, exc) + return + + for family, _, _, _, sockaddr in addrinfo_list: + ip_str = sockaddr[0] + try: + ip_obj = ipaddress.ip_address(ip_str) + except ValueError: + # Not an IP address we understand; skip. + logger.debug("Skipping non-IP address resolved from %r: %r", hostname, ip_str) + continue + + if ( + ip_obj.is_loopback + or ip_obj.is_private + or ip_obj.is_link_local + or ip_obj.is_reserved + or ip_obj.is_multicast + ): + msg = ( + f"REST API connector URL {url!r} resolves to disallowed address {ip_str} " + "(localhost, private, link-local, reserved, or multicast addresses are blocked)." + ) + logger.warning(msg) + raise ConnectorValidationError(msg) + + logger.debug("REST API connector URL %r passed SSRF safety validation.", url) + + def __init__( + self, + url: str, + method: str = "GET", + headers: Optional[Dict[str, str]] = None, + query_params: Optional[Dict[str, str]] = None, + auth_type: str = AuthType.NONE, + auth_config: Optional[Dict[str, Any]] = None, + items_path: Optional[str] = None, + id_field: Optional[str] = None, + content_fields: Optional[List[str]] = None, + metadata_fields: Optional[List[str]] = None, + pagination_type: str = PaginationType.NONE, + pagination_config: Optional[Dict[str, Any]] = None, + poll_timestamp_field: Optional[str] = None, + batch_size: int = INDEX_BATCH_SIZE, + max_pages: int = _DEFAULT_MAX_PAGES, + request_delay: float = 0.5, + request_body: Optional[Dict[str, Any]] = None, + field_type_hints: Optional[Dict[str, str]] = None, + field_default_values: Optional[Dict[str, Any]] = None, + content_template: Optional[str] = None, + ) -> None: + # Validate URL against SSRF-style targets (localhost, private/internal ranges, etc.) + self._validate_url_for_ssrf(url) + + parsed = urlparse(str(url)) + self._base_url = urlunparse((parsed.scheme, parsed.netloc, parsed.path, "", "", "")) + self._url_params: Dict[str, str] = {} + if parsed.query: + for k, v_list in parse_qs(parsed.query, keep_blank_values=True).items(): + self._url_params[k] = v_list[-1] + + self._explicit_query_params: Dict[str, str] = ( + _text_to_dict(query_params) if isinstance(query_params, str) else (query_params or {}) + ) + self.url = self._base_url + self.method = (method or "GET").upper() + self._base_headers: Dict[str, str] = ( + _text_to_dict(headers) if isinstance(headers, str) else (headers or {}) + ) + self.auth_type = auth_type or AuthType.NONE + self.auth_config: Dict[str, Any] = auth_config or {} + self.items_path = items_path + self.id_field = id_field + self.content_fields: List[str] = content_fields or [] + self.metadata_fields: List[str] = metadata_fields or [] + self.pagination_type = pagination_type or PaginationType.NONE + self.pagination_config: Dict[str, Any] = pagination_config or {} + self._static_request_body: Dict[str, Any] = ( + request_body if request_body is not None + else self.pagination_config.get("request_body") or {} + ) + self.poll_timestamp_field = poll_timestamp_field + self.batch_size = batch_size + self.max_pages = max_pages + self.request_delay = max(request_delay, 0.0) + self.field_type_hints: Dict[str, str] = field_type_hints or {} + self.field_default_values: Dict[str, Any] = field_default_values or {} + self.content_template = content_template + + self._credentials: Dict[str, Any] = {} + self._auth_headers: Dict[str, str] = {} + self._basic_auth: Optional[requests.auth.HTTPBasicAuth] = None + + # -- Credentials -------------------------------------------------------- + + def load_credentials(self, credentials: Dict[str, Any]) -> Dict[str, Any] | None: + """Apply authentication credentials (no network call). + + Use ``validate_config()`` to perform a live connectivity check. + """ + self._credentials = credentials or {} + self._build_auth() + return None + + def _build_auth(self) -> None: + """Derive auth headers / basic-auth object from credentials.""" + self._auth_headers = {} + self._basic_auth = None + + if self.auth_type == AuthType.NONE: + logging.info("REST API auth_type=none, no authentication configured.") + return + + if self.auth_type == AuthType.API_KEY_HEADER: + header_name = self.auth_config.get("header_name") + api_key = ( + self._credentials.get("api_key") + or self.auth_config.get("api_key_value") + or self.auth_config.get("api_key") + ) + if not header_name or not api_key: + logging.warning( + "REST API auth setup failed: header_name=%s, api_key present=%s, " + "credentials keys=%s, auth_config keys=%s", + header_name, bool(api_key), + list(self._credentials.keys()), list(self.auth_config.keys()), + ) + raise ConnectorMissingCredentialError( + "REST API (api_key_header) requires 'header_name' in auth_config and 'api_key' in credentials" + ) + self._auth_headers[header_name] = str(api_key) + logging.info("REST API auth configured: header '%s' set.", header_name) + return + + if self.auth_type == AuthType.BEARER: + token = self._credentials.get("token") or self.auth_config.get("token") + if not token: + raise ConnectorMissingCredentialError("REST API (bearer) requires 'token' in credentials") + self._auth_headers["Authorization"] = f"Bearer {token}" + logging.info("REST API auth configured: Bearer token set.") + return + + if self.auth_type == AuthType.BASIC: + username = self._credentials.get("username") or self.auth_config.get("username") + password = self._credentials.get("password") or self.auth_config.get("password") + if not username or password is None: + raise ConnectorMissingCredentialError("REST API (basic) requires 'username' and 'password'") + self._basic_auth = requests.auth.HTTPBasicAuth(str(username), str(password)) + logging.info("REST API auth configured: Basic auth for user '%s'.", username) + return + + raise ConnectorValidationError(f"Unsupported auth_type: {self.auth_type}") + + # -- Config validation (test connection) -------------------------------- + + @classmethod + def parse_storage_config(cls, raw: Dict[str, Any]) -> RestAPIConnectorConfig: + """Parse connector config as stored on the connector row (no network I/O). + + ``credentials`` live under ``raw`` but are excluded from the schema and + must be applied via ``load_credentials`` separately. + """ + body = {k: v for k, v in raw.items() if k != "credentials"} + try: + cfg = RestAPIConnectorConfig(**body) + except ValidationError as exc: + raise ConnectorValidationError(f"Invalid REST API config: {exc}") from exc + cfg.normalized_method() + cfg.normalized_auth_type() + cfg.normalized_pagination_type() + cfg.ensure_required_fields() + return cfg + + @classmethod + def from_parsed_config( + cls, + cfg: RestAPIConnectorConfig, + *, + max_pages: Optional[int] = None, + ) -> RestAPIConnector: + """Build a connector from validated config (``__init__`` runs SSRF validation).""" + return cls( + url=str(cfg.url), + method=cfg.normalized_method(), + headers=cfg.headers, + query_params=cfg.query_params, + auth_type=cfg.normalized_auth_type(), + auth_config=cfg.auth_config, + items_path=cfg.items_path, + id_field=cfg.id_field, + content_fields=cfg.content_fields, + metadata_fields=cfg.metadata_fields, + pagination_type=cfg.normalized_pagination_type(), + pagination_config=cfg.pagination_config, + poll_timestamp_field=cfg.poll_timestamp_field, + batch_size=cfg.batch_size, + max_pages=max_pages if max_pages is not None else cfg.max_pages, + request_delay=cfg.request_delay, + request_body=cfg.request_body, + field_type_hints=cfg.field_type_hints, + field_default_values=cfg.field_default_values, + content_template=cfg.content_template, + ) + + @classmethod + def validate_config( + cls, + config: Dict[str, Any], + credentials: Optional[Dict[str, Any]] = None, + ) -> RestAPIConnectorConfig: + """Validate config schema and optionally perform a live API call. + + Args: + config: Raw config dict from the UI / database. + credentials: Optional credentials dict; when provided a live + connectivity check is performed. + + Returns: + The validated ``RestAPIConnectorConfig`` instance. + + Raises: + ConnectorValidationError: On schema or connectivity failure. + """ + cfg = cls.parse_storage_config(config) + validation_cap = min(cfg.max_pages, 10) + connector = cls.from_parsed_config(cfg, max_pages=validation_cap) + + if credentials is None and cfg.auth_type != AuthType.NONE: + return cfg + + if credentials is not None: + connector.load_credentials(credentials) + else: + connector._credentials = {} + connector._build_auth() + + try: + logging.info("Validating REST API connector by fetching first page") + _ = next(connector._page_iter_for_validation()) + except StopIteration: + pass + + return cfg + + # -- LoadConnector / PollConnector interface ----------------------------- + + def load_from_state(self) -> Generator[List[Document], None, None]: + """Full fetch with pagination.""" + return self._yield_documents(time_window=None) + + def poll_source( + self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch + ) -> Generator[List[Document], None, None]: + """Incremental fetch; filters by ``poll_timestamp_field`` if configured.""" + if not self.poll_timestamp_field: + logging.warning( + "poll_source called without poll_timestamp_field; " + "falling back to full fetch with in-memory filtering." + ) + return self._yield_documents( + time_window=( + datetime.fromtimestamp(start, tz=timezone.utc), + datetime.fromtimestamp(end, tz=timezone.utc), + ) + ) + + # -- Document generation ------------------------------------------------ + + def _yield_documents( + self, + time_window: tuple[datetime, datetime] | None, + ) -> Generator[List[Document], None, None]: + batch: List[Document] = [] + for item in self._iter_items(): + try: + doc = self._item_to_document(item) + except Exception as exc: + logging.warning("Failed to convert REST API item to Document: %s", exc) + continue + + if time_window is not None and not self._doc_in_time_window(doc, *time_window): + continue + + batch.append(doc) + if len(batch) >= self.batch_size: + yield batch + batch = [] + + if batch: + yield batch + + # -- Pagination & page fetching ----------------------------------------- + + def _iter_items(self) -> Iterable[Mapping[str, Any]]: + """Iterate over raw items across all pages.""" + page_count = 0 + + page = int(self.pagination_config.get("start_page", 1)) + per_page = self._resolve_page_size() + + offset = int(self.pagination_config.get("start_offset", 0)) + limit = int(self.pagination_config.get("limit", per_page)) + if limit <= 0: + limit = per_page + + cursor: Optional[str] = self.pagination_config.get("initial_cursor") + + while True: + if page_count >= self.max_pages: + logging.warning("REST API connector reached max_pages=%d, stopping.", self.max_pages) + break + + params: Dict[str, Any] = {} + if self.pagination_type == PaginationType.PAGE: + self._apply_page_pagination(params, page, per_page) + elif self.pagination_type == PaginationType.OFFSET: + self._apply_offset_pagination(params, offset, limit) + elif self.pagination_type == PaginationType.CURSOR and cursor is not None: + self._apply_cursor_pagination(params, cursor) + + if page_count > 0 and self.request_delay > 0: + time.sleep(self.request_delay) + + try: + response_json = self._fetch_page(params) + except (ConnectorValidationError, ConnectorMissingCredentialError): + raise + except Exception as exc: + raise ConnectorValidationError(f"REST API page fetch failed: {exc}") from exc + + items = self._extract_items(response_json) + if not items: + break + + for item in items: + if isinstance(item, Mapping): + yield item + + page_count += 1 + + if self.pagination_type == PaginationType.NONE: + break + elif self.pagination_type == PaginationType.PAGE: + if len(items) < per_page: + break + page += 1 + elif self.pagination_type == PaginationType.OFFSET: + if len(items) < limit: + break + offset += limit + elif self.pagination_type == PaginationType.CURSOR: + next_cursor = self._extract_next_cursor(response_json) + if not next_cursor: + break + cursor = next_cursor + + def _page_iter_for_validation(self) -> Iterable[Mapping[str, Any]]: + """Single-page iterator used for connectivity checks.""" + params: Dict[str, Any] = {} + if self.pagination_type == PaginationType.PAGE: + page = int(self.pagination_config.get("start_page", 1)) + per_page = self._resolve_page_size() + self._apply_page_pagination(params, page, per_page) + elif self.pagination_type == PaginationType.OFFSET: + per_page = self._resolve_page_size() + offset = int(self.pagination_config.get("start_offset", 0)) + limit = int(self.pagination_config.get("limit", per_page)) + if limit <= 0: + limit = per_page + self._apply_offset_pagination(params, offset, limit) + elif self.pagination_type == PaginationType.CURSOR: + cursor = self.pagination_config.get("initial_cursor") + if cursor is not None: + self._apply_cursor_pagination(params, cursor) + + response_json = self._fetch_page(params=params) + for item in self._extract_items(response_json): + yield item + + @retry_builder( + tries=5, delay=1, max_delay=30, backoff=2, + exceptions=(requests.ConnectionError, requests.Timeout, requests.HTTPError), + ) + def _fetch_page(self, params: Dict[str, Any]) -> Any: + """Fetch a single page with retry and exponential backoff.""" + headers = {**self._base_headers, **self._auth_headers} + + merged: Dict[str, Any] = {**self._url_params} + merged.update(self._explicit_query_params) + merged.update(params) + + url, query_params = self._build_url_with_templates(merged) + + sensitive = {"authorization", "apikey", "api-key", "x-api-key"} + logging.debug( + "REST API request: %s %s | params=%s | headers=%s", + self.method, url, + {k: ("***" if k.lower() in sensitive else v) for k, v in query_params.items()}, + {k: ("***" if k.lower() in sensitive else v) for k, v in headers.items()}, + ) + + if self.method == "GET": + resp = rl_requests.get(url, headers=headers, params=query_params, auth=self._basic_auth, timeout=60) + elif self.method == "POST": + resp = rl_requests.post( + url, headers=headers, params=query_params, + json=self._static_request_body or {}, auth=self._basic_auth, timeout=60, + ) + else: + raise ConnectorValidationError(f"Unsupported HTTP method: {self.method}") + + try: + resp.raise_for_status() + except requests.HTTPError as exc: + status = exc.response.status_code if exc.response is not None else None + if status in (401, 403): + sensitive = {"authorization", "apikey", "api-key", "x-api-key"} + logging.warning( + "REST API %d for %s %s | auth_type=%s | " + "request header keys=%s | auth_header keys=%s", + status, self.method, resp.url, + self.auth_type, + [k for k in headers], + [k for k in self._auth_headers], + ) + raise ConnectorMissingCredentialError( + f"REST API authentication failed with status {status}" + ) from exc + if status is not None and 400 <= status < 500 and status != 429: + logging.warning( + "REST API client error %d for %s %s; not retrying.", + status, + self.method, + resp.url, + ) + raise ConnectorValidationError( + f"REST API request failed with non-retriable client error status {status}" + ) from exc + raise + + try: + return resp.json() + except ValueError as exc: + raise ConnectorValidationError("REST API response is not valid JSON") from exc + + def _build_url_with_templates(self, params: Dict[str, Any]) -> tuple[str, Dict[str, Any]]: + """Substitute ``{key}`` placeholders in the URL; return remaining query params.""" + url = self.url + query_params = dict(params) + used_keys: List[str] = [] + for key, value in list(query_params.items()): + placeholder = "{" + key + "}" + if placeholder in url: + url = url.replace(placeholder, str(value)) + used_keys.append(key) + for key in used_keys: + query_params.pop(key, None) + return url, query_params + + # -- Pagination helpers ------------------------------------------------- + + def _resolve_page_size(self) -> int: + """Determine per-page size from config, query params, or batch_size fallback. + + Priority: explicit ``page_size`` in pagination_config > value already + present in user query params for the same param name > batch_size. + """ + explicit = self.pagination_config.get("page_size") + if explicit is not None: + val = int(explicit) + if val > 0: + return val + + size_param = self.pagination_config.get("page_size_param") or self.pagination_config.get("limit_param") + if size_param: + for source in (self._explicit_query_params, self._url_params): + if size_param in source: + try: + val = int(source[size_param]) + if val > 0: + return val + except (ValueError, TypeError): + pass + + return self.batch_size + + def _apply_page_pagination(self, params: Dict[str, Any], page: int, per_page: int) -> None: + params[self.pagination_config.get("page_param", "page")] = page + size_param = self.pagination_config.get("page_size_param") + if size_param: + params[size_param] = per_page + + def _apply_offset_pagination(self, params: Dict[str, Any], offset: int, limit: int) -> None: + params[self.pagination_config.get("offset_param", "offset")] = offset + limit_param = self.pagination_config.get("limit_param") + if limit_param: + params[limit_param] = limit + + def _apply_cursor_pagination(self, params: Dict[str, Any], cursor: str) -> None: + params[self.pagination_config.get("cursor_param", "cursor")] = cursor + + # -- JSON extraction ---------------------------------------------------- + + def _extract_items(self, response_json: Any) -> List[Mapping[str, Any]]: + """Extract the items array from a JSON response.""" + if self.items_path and _jsonpath is not None: + try: + matches = _jsonpath(response_json, self.items_path) + except Exception as exc: + raise ConnectorValidationError( + f"Failed to apply items JSONPath '{self.items_path}': {exc}" + ) from exc + if not matches: + return [] + if len(matches) == 1 and isinstance(matches[0], list): + items = matches[0] + else: + items = matches + elif isinstance(response_json, list): + items = response_json + elif isinstance(response_json, dict): + items = [] + for key in ("items", "results", "data", "records"): + if key in response_json and isinstance(response_json[key], list): + items = response_json[key] + break + else: + for value in response_json.values(): + if isinstance(value, list): + items = value + break + else: + items = [] + + return [it for it in items if isinstance(it, Mapping)] + + def _extract_next_cursor(self, response_json: Any) -> Optional[str]: + """Extract cursor value for cursor-based pagination.""" + cursor_path = self.pagination_config.get("next_cursor_path") + if not cursor_path: + field = self.pagination_config.get("next_cursor_field") + if field and isinstance(response_json, Mapping): + value = response_json.get(field) + return str(value) if value is not None else None + return None + + if _jsonpath is None: + return None + + try: + matches = _jsonpath(response_json, cursor_path) + except Exception: + return None + + if not matches: + return None + return str(matches[0]) if matches[0] is not None else None + + # -- Item → Document mapping -------------------------------------------- + + def _item_to_document(self, item: Mapping[str, Any]) -> Document: + """Map a single API item to a ``Document``.""" + raw_id = self._get_typed_field_value(self.id_field, item) if self.id_field else None + if raw_id is None: + raw_id = hash128(f"rest_api_item:{repr(item)}") + doc_id = hash128(f"rest_api:{raw_id}") + + if self.content_template: + content_text = self._render_content_template(item) + else: + parts = [] + for field in self.content_fields: + val = self._get_typed_field_value(field, item) + if val is not None: + text = self._strip_html(self._coerce_to_text(val)) + if text: + parts.append(text) + content_text = "\n".join(parts) + blob = content_text.encode("utf-8") + + metadata: Dict[str, Any] = {} + for field in self.metadata_fields: + value = self._get_typed_field_value(field, item) + if value is not None: + metadata[field] = self._serialize_metadata_value(value) + + doc_updated_at = self._extract_timestamp(item) or datetime.now(timezone.utc) + + sem = str(self._extract_field(item, self.content_fields[0]) if self.content_fields else raw_id) + sem = self._strip_html(sem).replace("\n", " ").replace("\r", " ").strip()[:100] or str(doc_id) + + return Document( + id=doc_id, + source=DocumentSource.REST_API, + semantic_identifier=sem, + extension=".txt", + blob=blob, + doc_updated_at=doc_updated_at, + size_bytes=len(blob), + metadata=metadata or None, + ) + + # -- Field extraction --------------------------------------------------- + + def _extract_field(self, item: Mapping[str, Any], path: str) -> Any: + """Extract a value using dot-notation with optional array indexing. + + Examples: ``country.name``, ``tags[0].label``, ``tags[*].label`` + """ + values = self._extract_field_values(item, path) + if not values: + return None + return values[0] if len(values) == 1 else values + + def _extract_field_values(self, item: Mapping[str, Any], path: str) -> List[Any]: + """Return all raw values for a dot-notation field path with wildcards.""" + if not path: + return [] + + current_values: List[Any] = [item] + for segment in path.split("."): + if not segment: + return [] + + match = _FIELD_SEGMENT_RE.match(segment) + key = segment + index: Optional[str] = None + if match: + key = match.group("key") + index = match.group("index") + + next_values: List[Any] = [] + for value in current_values: + if not isinstance(value, Mapping): + continue + child = value.get(key) + if child is None: + continue + if index is None: + next_values.append(child) + elif not isinstance(child, list): + continue + elif index == "*": + next_values.extend(child) + else: + try: + idx = int(index) + except ValueError: + continue + if 0 <= idx < len(child): + next_values.append(child[idx]) + + current_values = next_values + if not current_values: + break + + return current_values + + def _get_typed_field_value(self, path: str, item: Mapping[str, Any]) -> Any: + """Extract a field value, applying type hints, defaults, and array joining.""" + values = self._extract_field_values(item, path) + if not values: + return self.field_default_values.get(path) + + hint = self.field_type_hints.get(path) + + def _convert(v: Any) -> Any: + if hint == "string": + return "" if v is None else str(v) + if hint == "number": + if v is None: + return None + try: + num = float(v) + return int(num) if num.is_integer() else num + except Exception: + return None + if hint == "date": + if isinstance(v, datetime): + return v.isoformat() + dt = self._parse_datetime(v) + if dt is not None: + return dt.isoformat() + return str(v) if v is not None else None + return v + + converted = [_convert(v) for v in values] + non_null = [v for v in converted if v is not None] + if not non_null: + return None + if len(non_null) == 1: + return non_null[0] + return ", ".join(self._coerce_to_text(v) for v in non_null) + + # -- Timestamp parsing -------------------------------------------------- + + def _extract_timestamp(self, item: Mapping[str, Any]) -> Optional[datetime]: + """Extract and normalise a timestamp from ``poll_timestamp_field``.""" + if not self.poll_timestamp_field: + return None + + value = self._extract_field(item, self.poll_timestamp_field) + if isinstance(value, list) and value: + value = value[0] + return self._parse_datetime(value) + + @staticmethod + def _parse_datetime(value: Any) -> Optional[datetime]: + """Parse a raw value into a UTC datetime, or return None.""" + if value is None: + return None + + if isinstance(value, datetime): + return (value if value.tzinfo else value.replace(tzinfo=timezone.utc)).astimezone(timezone.utc) + + if isinstance(value, (int, float)): + try: + return datetime.fromtimestamp(float(value), tz=timezone.utc) + except Exception: + return None + + if isinstance(value, str): + ts = value.strip() + for fmt in ("%Y-%m-%dT%H:%M:%S.%fZ", "%Y-%m-%dT%H:%M:%SZ", "%Y-%m-%d %H:%M:%S", "%Y-%m-%d"): + try: + return datetime.strptime(ts, fmt).replace(tzinfo=timezone.utc) + except Exception: + continue + try: + dt = datetime.fromisoformat(ts.replace("Z", "+00:00").replace(" ", "T")) + return (dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)).astimezone(timezone.utc) + except Exception: + return None + + return None + + # -- Content template rendering ----------------------------------------- + + class _SafeDict(dict): + """Dict subclass that returns empty string for missing keys in format_map.""" + def __missing__(self, key: str) -> str: + return "" + + def _render_content_template(self, item: Mapping[str, Any]) -> str: + """Render content using a user-provided template with ``{field}`` placeholders.""" + template = self.content_template or "" + values: Dict[str, str] = {} + for field_path in set(self.content_fields + self.metadata_fields): + val = self._get_typed_field_value(field_path, item) + if val is None: + continue + name = re.sub(r"\[\d+\]|\[\*\]", "", field_path).replace(".", "_") + values[name] = self._coerce_to_text(val) + + try: + rendered = template.format_map(self._SafeDict(values)) + except Exception as exc: + logging.warning("Failed to render content template: %s", exc) + parts = [self._coerce_to_text(self._get_typed_field_value(f, item)) for f in self.content_fields] + rendered = "\n".join(p for p in parts if p) + + return self._strip_html(rendered) + + # -- Static helpers ----------------------------------------------------- + + @staticmethod + def _strip_html(text: str) -> str: + """Remove basic HTML tags and normalise whitespace.""" + if "<" not in text or ">" not in text: + return text + cleaned = re.sub(r"<[^>]+>", " ", text) + return re.sub(r"\s+", " ", cleaned).strip() + + @staticmethod + def _coerce_to_text(value: Any) -> str: + """Convert any value to a plain-text string.""" + if value is None: + return "" + if isinstance(value, str): + return value + if isinstance(value, (int, float, bool)): + return str(value) + try: + return json.dumps(value, ensure_ascii=False) + except Exception: + return str(value) + + @staticmethod + def _serialize_metadata_value(value: Any) -> Any: + """Serialise a metadata value for storage.""" + if isinstance(value, datetime): + if value.tzinfo is None: + value = value.replace(tzinfo=timezone.utc) + return value.isoformat() + if isinstance(value, (int, float, bool, str)): + return value + try: + return json.dumps(value, ensure_ascii=False) + except Exception: + return str(value) + + @staticmethod + def _doc_in_time_window(doc: Document, start: datetime, end: datetime) -> bool: + if not doc.doc_updated_at: + return False + dt = doc.doc_updated_at + dt = (dt if dt.tzinfo else dt.replace(tzinfo=timezone.utc)).astimezone(timezone.utc) + return start <= dt < end diff --git a/common/doc_store/es_conn_base.py b/common/doc_store/es_conn_base.py index dccb8a2fe3d..daa5f17770e 100644 --- a/common/doc_store/es_conn_base.py +++ b/common/doc_store/es_conn_base.py @@ -159,6 +159,61 @@ def create_doc_meta_idx(self, index_name: str): except Exception as e: self.logger.exception(f"Error creating document metadata index {index_name}: {e}") + def refresh_idx(self, index_name: str) -> bool: + """ + Refresh an index so that recently inserted documents become searchable. + + Service layers should call this dispatch method instead of reaching + into ``self.es`` directly, so the OpenSearch and Elasticsearch + connections present a uniform abstract API. + """ + try: + self.es.indices.refresh(index=index_name) + return True + except NotFoundError: + return False + except Exception as e: + self.logger.warning(f"ESConnection.refresh_idx({index_name}) failed: {e}") + return False + + def count_idx(self, index_name: str) -> int: + """ + Return the document count for an index, or -1 if the call fails. + Used to decide whether a per-tenant metadata index is empty without + paying a full search. + """ + try: + response = self.es.count(index=index_name) + return int(response.get("count", 0)) + except NotFoundError: + return 0 + except Exception as e: + self.logger.warning(f"ESConnection.count_idx({index_name}) failed: {e}") + return -1 + + def replace_meta_fields(self, index_name: str, doc_id: str, meta_fields: dict) -> bool: + """ + Fully replace the ``meta_fields`` object on a single document. + + Using ES.update with a ``doc`` body would deep-merge object fields, + retaining old keys that should be removed. A scripted update assigns + the new meta_fields outright, matching delete-key semantics. + """ + body = { + "script": { + "source": "ctx._source.meta_fields = params.meta_fields", + "params": {"meta_fields": meta_fields}, + } + } + try: + self.es.update(index=index_name, id=doc_id, refresh=True, body=body) + return True + except NotFoundError: + return False + except Exception as e: + self.logger.warning(f"ESConnection.replace_meta_fields({index_name}, {doc_id}) failed: {e}") + return False + def delete_idx(self, index_name: str, dataset_id: str): if len(dataset_id) > 0: # The index need to be alive after any kb deletion since all kb under this tenant are in one index. @@ -247,6 +302,21 @@ def get_total(self, res): def get_doc_ids(self, res): return [d["_id"] for d in res["hits"]["hits"]] + def get_scores(self, res) -> dict[str, float]: + """ + Map hit `_id` to its raw `_score`. Used to recover the cosine + similarity returned by a KNN-only search without pulling the + chunk vectors out of the index. + """ + out = {} + for d in res.get("hits", {}).get("hits", []): + doc_id = d.get("_id") + if doc_id is None: + continue + score = d.get("_score") + out[doc_id] = float(score) if score is not None else 0.0 + return out + def _get_source(self, res): rr = [] for d in res["hits"]["hits"]: diff --git a/common/mcp_tool_call_conn.py b/common/mcp_tool_call_conn.py index 95e3581bb0b..676978d052e 100644 --- a/common/mcp_tool_call_conn.py +++ b/common/mcp_tool_call_conn.py @@ -20,6 +20,7 @@ import weakref from concurrent.futures import ThreadPoolExecutor from concurrent.futures import TimeoutError as FuturesTimeoutError +from dataclasses import dataclass from string import Template from typing import Any, Literal, Protocol @@ -36,7 +37,13 @@ class ToolCallSession(Protocol): - def tool_call(self, name: str, arguments: dict[str, Any]) -> str: ... + def tool_call(self, name: str, arguments: dict[str, Any], timeout: float | int = 10) -> str: ... + + +@dataclass(frozen=True) +class MCPToolBinding: + session: ToolCallSession + original_name: str class MCPToolCallSession(ToolCallSession): @@ -316,12 +323,12 @@ def shutdown_all_mcp_sessions(): logging.info("All MCPToolCallSession instances have been closed.") -def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]: +def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict, function_name: str | None = None) -> dict[str, Any]: if isinstance(mcp_tool, dict): return { "type": "function", "function": { - "name": mcp_tool["name"], + "name": function_name or mcp_tool["name"], "description": mcp_tool["description"], "parameters": mcp_tool["inputSchema"], }, @@ -330,7 +337,7 @@ def mcp_tool_metadata_to_openai_tool(mcp_tool: Tool | dict) -> dict[str, Any]: return { "type": "function", "function": { - "name": mcp_tool.name, + "name": function_name or mcp_tool.name, "description": mcp_tool.description, "parameters": mcp_tool.inputSchema, }, diff --git a/common/metadata_infinity_filter.py b/common/metadata_infinity_filter.py new file mode 100644 index 00000000000..076cc2e23e1 --- /dev/null +++ b/common/metadata_infinity_filter.py @@ -0,0 +1,296 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Translate RAGflow document-metadata filter lists into Infinity SQL filter expressions. +""" + +from __future__ import annotations + +import ast +import re +from typing import Any, Dict, List, Sequence + +_KEY_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") + + +def _validate_key(key: str, flt: Dict[str, Any]) -> None: + if not _KEY_PATTERN.match(key): + raise ValueError(f"invalid key format (must be identifier-like): {flt}") + +SUPPORTED_OPERATORS: frozenset[str] = frozenset( + { + "=", + "≠", + ">", + "<", + "≥", + "≤", + "in", + "not in", + "contains", + "not contains", + "start with", + "end with", + "empty", + "not empty", + } +) + +_RANGE_OPS: Dict[str, str] = { + ">": ">", + "<": "<", + "≥": ">=", + "≤": "<=", +} + +class MetaFilterTranslator: + """Translate one user filter clause at a time into Infinity SQL filter strings.""" + + def translate(self, flt: Dict[str, Any]) -> str: + op = flt.get("op") + key = flt.get("key") + value = flt.get("value") + + if not key or not isinstance(key, str): + raise ValueError(f"filter is missing a string key: {flt}") + _validate_key(key, flt) + if op not in SUPPORTED_OPERATORS: + raise ValueError(f"unknown operator: {op!r}, filter: {flt}") + + if op == "empty": + return self._translate_empty(key) + if op == "not empty": + return self._translate_not_empty(key) + if op == "=": + return self._translate_equal(key, value, flt) + if op == "≠": + return self._translate_not_equal(key, value, flt) + if op in _RANGE_OPS: + return self._translate_range(key, op, value, flt) + if op == "in": + return self._translate_in(key, value, flt) + if op == "not in": + return self._translate_not_in(key, value, flt) + if op == "contains": + return self._translate_contains(key, value, flt) + if op == "not contains": + return self._translate_not_contains(key, value, flt) + if op == "start with": + return self._translate_start_with(key, value, flt) + if op == "end with": + return self._translate_end_with(key, value, flt) + + raise ValueError(f"no handler for operator: {op!r}, filter: {flt}") + + def _translate_empty(self, key: str) -> str: + return f"JSON_EXTRACT_STRING(meta_fields, '$.{key}') = '\"\"'" + + def _translate_not_empty(self, key: str) -> str: + return f"JSON_EXTRACT_STRING(meta_fields, '$.{key}') != '\"\"'" + + def _translate_equal(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + coerced = _coerce_scalar(value, flt) + if isinstance(coerced, str): + escaped = _escape_sql_string(coerced) + return f"JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')" + return f"JSON_CONTAINS(meta_fields, '$.{key}', {coerced})" + + def _translate_not_equal(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + coerced = _coerce_scalar(value, flt) + if isinstance(coerced, str): + escaped = _escape_sql_string(coerced) + return f"NOT JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')" + return f"NOT JSON_CONTAINS(meta_fields, '$.{key}', {coerced})" + + def _translate_range(self, key: str, op: str, value: Any, flt: Dict[str, Any]) -> str: + coerced = _coerce_range_value(value, flt) + sql_op = _RANGE_OPS.get(op, op) + if isinstance(coerced, str): + escaped = _escape_sql_string(coerced) + return f"JSON_EXTRACT_STRING(meta_fields, '$.{key}') {sql_op} '{escaped}'" + return f"JSON_EXTRACT_DOUBLE(meta_fields, '$.{key}') {sql_op} {coerced}" + + def _translate_in(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + members = _csv_or_list(value, flt) + string_parts = [] + num_parts = [] + for m in members: + # Use same coercion as range operators to detect numeric values + coerced = _coerce_range_value(m, flt) + if isinstance(coerced, (int, float)): + num_parts.append(f"JSON_CONTAINS(meta_fields, '$.{key}', {coerced})") + else: + escaped = _escape_sql_string(coerced) + string_parts.append(f"JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')") + conditions = [] + if string_parts: + conditions.append("(" + " OR ".join(string_parts) + ")") + if num_parts: + conditions.append("(" + " OR ".join(num_parts) + ")") + return "(" + " OR ".join(conditions) + ")" + + def _translate_not_in(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + members = _csv_or_list(value, flt) + string_parts = [] + num_parts = [] + for m in members: + # Use same coercion as range operators to detect numeric values + coerced = _coerce_range_value(m, flt) + if isinstance(coerced, (int, float)): + num_parts.append(f"NOT JSON_CONTAINS(meta_fields, '$.{key}', {coerced})") + else: + escaped = _escape_sql_string(coerced) + string_parts.append(f"NOT JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')") + conditions = [] + if string_parts: + conditions.append("(" + " AND ".join(string_parts) + ")") + if num_parts: + conditions.append("(" + " AND ".join(num_parts) + ")") + return " AND ".join(conditions) + + def _translate_contains(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + if not value and value != 0: + raise ValueError(f"contains value is empty: {flt}") + # Use same coercion as range operators to detect numeric values + coerced = _coerce_range_value(value, flt) + if isinstance(coerced, (int, float)): + return f"JSON_CONTAINS(meta_fields, '$.{key}', {coerced})" + escaped = _escape_sql_string(str(value)) + return f"JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')" + + def _translate_not_contains(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + text = _coerce_string(value, flt) + escaped = _escape_sql_string(text) + # Use Infinity's JSON_CONTAINS to check if value does NOT exist in JSON array + return f"NOT JSON_CONTAINS(meta_fields, '$.{key}', '\"{escaped}\"')" + + def _translate_start_with(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + text = _coerce_string(value, flt) + escaped = _escape_sql_string(_escape_likeWildcards(text)) + return f"JSON_EXTRACT_STRING(meta_fields, '$.{key}') LIKE '{escaped}%'" + + def _translate_end_with(self, key: str, value: Any, flt: Dict[str, Any]) -> str: + text = _coerce_string(value, flt) + escaped = _escape_sql_string(_escape_likeWildcards(text)) + return f"JSON_EXTRACT_STRING(meta_fields, '$.{key}') LIKE '%{escaped}'" + + +def plan_pushdown(filters: Sequence[Dict[str, Any]], logic: str) -> List[str]: + if logic not in {"and", "or"}: + raise ValueError(f"unknown logic {logic!r}") + translator = MetaFilterTranslator() + return [translator.translate(flt) for flt in filters] + + +def build_infinity_filter(filters: Sequence[Dict[str, Any]], logic: str) -> str: + if not filters: + return "1=1" + fragments = plan_pushdown(filters, logic) + joiner = " AND " if logic == "and" else " OR " + result = "(" + joiner.join(fragments) + ")" + return result + + +def is_pushdown_supported(filters: Sequence[Dict[str, Any]]) -> bool: + for flt in filters: + op = flt.get("op") + if op not in SUPPORTED_OPERATORS: + return False + if not isinstance(flt.get("key"), str) or not flt.get("key"): + return False + return True + + +def extract_doc_ids(df) -> List[str]: + if df is None or not hasattr(df, "iterrows"): + return [] + return [str(row["id"]) for _, row in df.iterrows() if "id" in row] + + +# --------------------------------------------------------------------------- +# Value coercion helpers +# --------------------------------------------------------------------------- + + +def _coerce_scalar(value: Any, flt: Dict[str, Any]) -> Any: + if value is None: + raise ValueError(f"scalar comparison value is None: {flt}") + if isinstance(value, (list, dict)): + raise ValueError(f"scalar comparison value is non-scalar: {flt}") + try: + parsed = ast.literal_eval(str(value).strip()) + if isinstance(parsed, (int, float, bool)): + return parsed + except Exception: + pass + return str(value) + + +def _coerce_range_value(value: Any, flt: Dict[str, Any]) -> Any: + if value is None: + raise ValueError(f"range comparison value is None: {flt}") + try: + parsed = ast.literal_eval(str(value).strip()) + if isinstance(parsed, (int, float)): + return parsed + except Exception: + pass + return str(value) + + +def _coerce_string(value: Any, flt: Dict[str, Any]) -> str: + if value is None: + raise ValueError(f"string-operator value is None: {flt}") + if isinstance(value, (list, dict)): + raise ValueError(f"string-operator value must be a scalar: {flt}") + s = str(value) + if not s: + raise ValueError(f"string-operator value is empty: {flt}") + return s + + +def _csv_or_list(value: Any, flt: Dict[str, Any]) -> List[Any]: + if value is None: + raise ValueError(f"membership value is None: {flt}") + if isinstance(value, (list, tuple)): + members = list(value) + elif isinstance(value, str): + try: + parsed = ast.literal_eval(value) + except Exception: + parsed = value + if isinstance(parsed, (list, tuple)): + members = list(parsed) + else: + members = [m.strip() for m in value.split(",") if m.strip()] + else: + members = [value] + if not members: + raise ValueError(f"membership value resolved to empty list: {flt}") + normalised: List[Any] = [] + for m in members: + if isinstance(m, str): + normalised.append(m.lower().strip()) + else: + normalised.append(m) + return normalised + + +def _escape_sql_string(s: str) -> str: + return s.replace("'", "''") + + +def _escape_likeWildcards(text: str) -> str: + return text.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_") \ No newline at end of file diff --git a/common/metadata_utils.py b/common/metadata_utils.py index c2fc90b5414..53af2b4eaf3 100644 --- a/common/metadata_utils.py +++ b/common/metadata_utils.py @@ -19,6 +19,7 @@ import json_repair + def convert_conditions(metadata_condition): if metadata_condition is None: metadata_condition = {} @@ -60,21 +61,21 @@ def filter_out(v2docs, operator, value): # Strict date format detection: YYYY-MM-DD (must be 10 chars with correct format) is_input_date = ( - len(input_str) == 10 and - input_str[4] == '-' and - input_str[7] == '-' and - input_str[:4].isdigit() and - input_str[5:7].isdigit() and - input_str[8:10].isdigit() + len(input_str) == 10 and + input_str[4] == '-' and + input_str[7] == '-' and + input_str[:4].isdigit() and + input_str[5:7].isdigit() and + input_str[8:10].isdigit() ) is_value_date = ( - len(value_str) == 10 and - value_str[4] == '-' and - value_str[7] == '-' and - value_str[:4].isdigit() and - value_str[5:7].isdigit() and - value_str[8:10].isdigit() + len(value_str) == 10 and + value_str[4] == '-' and + value_str[7] == '-' and + value_str[:4].isdigit() and + value_str[5:7].isdigit() and + value_str[8:10].isdigit() ) if is_value_date: @@ -109,17 +110,23 @@ def filter_out(v2docs, operator, value): matched = False try: if operator == "contains": - matched = str(input).find(value) >= 0 if not isinstance(input, list) else any(str(i).find(value) >= 0 for i in input) + matched = str(input).find(value) >= 0 if not isinstance(input, list) else any( + str(i).find(value) >= 0 for i in input) elif operator == "not contains": - matched = str(input).find(value) == -1 if not isinstance(input, list) else all(str(i).find(value) == -1 for i in input) + matched = str(input).find(value) == -1 if not isinstance(input, list) else all( + str(i).find(value) == -1 for i in input) elif operator == "in": matched = input in value if not isinstance(input, list) else all(i in value for i in input) elif operator == "not in": matched = input not in value if not isinstance(input, list) else all(i not in value for i in input) elif operator == "start with": - matched = str(input).lower().startswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).startswith(str(value).lower()) + matched = str(input).lower().startswith(str(value).lower()) if not isinstance(input, + list) else "".join( + [str(i).lower() for i in input]).startswith(str(value).lower()) elif operator == "end with": - matched = str(input).lower().endswith(str(value).lower()) if not isinstance(input, list) else "".join([str(i).lower() for i in input]).endswith(str(value).lower()) + matched = str(input).lower().endswith(str(value).lower()) if not isinstance(input, + list) else "".join( + [str(i).lower() for i in input]).endswith(str(value).lower()) elif operator == "empty": matched = not input elif operator == "not empty": @@ -158,21 +165,23 @@ def filter_out(v2docs, operator, value): if logic == "and": doc_ids = doc_ids & set(ids) if not doc_ids: + logging.debug(f"meta_filter filters={filters}, logic={logic}, early return []") return [] else: doc_ids = doc_ids | set(ids) + logging.debug(f"meta_filter filters={filters}, logic={logic}, returning doc_ids={list(doc_ids)}") return list(doc_ids) async def apply_meta_data_filter( - meta_data_filter: dict | None, - metas: dict | None = None, - question: str = "", - chat_mdl: Any = None, - base_doc_ids: list[str] | None = None, - manual_value_resolver: Callable[[dict], dict] | None = None, - kb_ids: list[str] | None = None, - metas_loader: Callable[[], dict] | None = None, + meta_data_filter: dict | None, + metas: dict | None = None, + question: str = "", + chat_mdl: Any = None, + base_doc_ids: list[str] | None = None, + manual_value_resolver: Callable[[dict], dict] | None = None, + kb_ids: list[str] | None = None, + metas_loader: Callable[[], dict] | None = None, ) -> list[str] | None: """ Apply metadata filtering rules and return the filtered doc_ids. @@ -182,12 +191,11 @@ async def apply_meta_data_filter( - semi_auto: generate conditions using selected metadata keys only - manual: directly filter based on provided conditions - When ``kb_ids`` is supplied and the active doc store is Elasticsearch the - generated filter conditions are pushed down to ES via - ``DocMetadataService.filter_doc_ids_by_meta_pushdown`` instead of being - evaluated in Python over ``metas``. The in-memory ``meta_filter`` path - remains the fallback so callers without a KB scope, or backends without - push-down support, behave exactly as before. + When ``kb_ids`` is supplied, metadata filters are pushed down to the doc metadata + index (ES/Infinity) via ``DocMetadataService.filter_doc_ids_by_metadata`` instead + of being evaluated in Python over ``metas``. The in-memory ``meta_filter`` path + remains the fallback so callers without a KB scope, or backends without push-down + support, behave exactly as before. ``metas`` may be supplied eagerly or via ``metas_loader``. The loader is only invoked when the metadata dict is actually needed — i.e. for the LLM @@ -200,7 +208,7 @@ async def apply_meta_data_filter( list of doc_ids, ["-999"] when manual filters yield no result, or None when auto/semi_auto filters return empty. """ - from rag.prompts.generator import gen_meta_filter # move from the top of the file to avoid circular import + from rag.prompts.generator import gen_meta_filter # move from the top of the file to avoid circular import doc_ids = list(base_doc_ids) if base_doc_ids else [] @@ -220,17 +228,26 @@ def _get_metas() -> dict: cached_metas = metas_loader() if metas_loader else {} return cached_metas - def _evaluate(conditions: list[dict], logic: str) -> list[str]: - """Run conditions through ES push-down when possible, in-memory otherwise.""" + def _run_metadata_filter(conditions: list[dict], logic: str) -> list[str]: + """Run conditions through ES/Infinity push-down when possible, in-memory otherwise.""" if conditions and kb_ids: - pushed = _try_meta_pushdown(kb_ids, conditions, logic) - if pushed is not None: - return pushed + try: + from api.db.services.doc_metadata_service import DocMetadataService + doc_ids = DocMetadataService.filter_doc_ids_by_meta_pushdown(kb_ids, conditions, logic) + logging.debug(f"Doc ids filtered by metadata: {doc_ids}") + if doc_ids is not None: + return doc_ids + except Exception as e: + logging.error(f"Metadata filter push down errored: {e}") + + # In-memory fallback + logging.debug("Metadata filter falls back to in-memory filter") return meta_filter(_get_metas(), conditions, logic) if method == "auto": filters: dict = await gen_meta_filter(chat_mdl, _get_metas(), question) - doc_ids.extend(_evaluate(filters["conditions"], filters.get("logic", "and"))) + logging.debug(f"Metadata filter(auto) generated: {filters}") + doc_ids.extend(_run_metadata_filter(filters["conditions"], filters.get("logic", "and"))) if not doc_ids: return None elif method == "semi_auto": @@ -251,24 +268,27 @@ def _evaluate(conditions: list[dict], logic: str) -> list[str]: filtered_metas = {key: current_metas[key] for key in selected_keys if key in current_metas} if filtered_metas: filters: dict = await gen_meta_filter(chat_mdl, filtered_metas, question, constraints=constraints) - doc_ids.extend(_evaluate(filters["conditions"], filters.get("logic", "and"))) + logging.debug(f"Metadata filter(semi_auto) generated: {filters}") + doc_ids.extend(_run_metadata_filter(filters["conditions"], filters.get("logic", "and"))) if not doc_ids: return None elif method == "manual": filters = meta_data_filter.get("manual", []) if manual_value_resolver: filters = [manual_value_resolver(flt) for flt in filters] - doc_ids.extend(_evaluate(filters, meta_data_filter.get("logic", "and"))) + logging.debug(f"Metadata filter(manual): {filters}") + doc_ids.extend(_run_metadata_filter(filters, meta_data_filter.get("logic", "and"))) if filters and not doc_ids: doc_ids = ["-999"] + logging.debug(f"apply_meta_data_filter meta_filter={meta_data_filter}, returning doc_ids={doc_ids}") return doc_ids def _try_meta_pushdown( - kb_ids: list[str], - conditions: list[dict], - logic: str, + kb_ids: list[str], + conditions: list[dict], + logic: str, ) -> list[str] | None: """Attempt the ES push-down path; return ``None`` to fall back in-memory. @@ -335,7 +355,7 @@ def update_metadata_to(metadata, meta): return metadata -def metadata_schema(metadata: dict|list|None) -> Dict[str, Any]: +def metadata_schema(metadata: dict | list | None) -> Dict[str, Any]: if not metadata: return {} properties = {} diff --git a/common/ssrf_guard.py b/common/ssrf_guard.py index b60bcd4bc99..4f87b94d7b8 100644 --- a/common/ssrf_guard.py +++ b/common/ssrf_guard.py @@ -170,3 +170,42 @@ def assert_url_is_safe( raise ValueError(f"Hostname {hostname!r} resolved to no addresses.") return hostname, resolved_ip + + +def assert_host_is_safe(host: str) -> str: + """Raise ``ValueError`` if *host* resolves to a non-public IP (SSRF guard for raw host/port connections). + + This is the host-level counterpart of :func:`assert_url_is_safe`, intended + for callers that connect via database drivers or other non-HTTP protocols + where there is no URL to parse. + + Returns the first validated public IP string so the caller can pin it if needed. + """ + if not host: + raise ValueError("Host must not be empty.") + + try: + addr_infos = socket.getaddrinfo(host, None) + except socket.gaierror as exc: + logger.warning("SSRF guard could not resolve host=%r reason=%s", host, exc) + raise ValueError(f"Could not resolve host {host!r}: {exc}") from exc + + resolved_ip: str | None = None + for _family, _type, _proto, _canonname, sockaddr in addr_infos: + raw_ip = ipaddress.ip_address(sockaddr[0]) + eff_ip = _effective_ip(raw_ip) + if not eff_ip.is_global: + logger.warning( + "SSRF guard blocked host: host=%r resolved to non-public address=%s", + host, + raw_ip, + ) + raise ValueError(f"Host resolves to a non-public address ({raw_ip}), which is not allowed.") + if resolved_ip is None: + resolved_ip = str(raw_ip) + + if resolved_ip is None: + logger.warning("SSRF guard blocked host: host=%r resolved to no addresses", host) + raise ValueError(f"Host {host!r} resolved to no addresses.") + + return resolved_ip diff --git a/conf/llm_factories.json b/conf/llm_factories.json index 2fc12803d78..4a98f2ccc5f 100644 --- a/conf/llm_factories.json +++ b/conf/llm_factories.json @@ -8,6 +8,34 @@ "rank": "999", "url": "https://api.openai.com/v1", "llm": [ + { + "llm_name": "gpt-5.5", + "tags": "LLM,CHAT,400k,IMAGE2TEXT", + "max_tokens": 400000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gpt-5.4", + "tags": "LLM,CHAT,400k,IMAGE2TEXT", + "max_tokens": 400000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gpt-5.4-mini", + "tags": "LLM,CHAT,400k,IMAGE2TEXT", + "max_tokens": 400000, + "model_type": "chat", + "is_tools": true + }, + { + "llm_name": "gpt-5.4-nano", + "tags": "LLM,CHAT,400k,IMAGE2TEXT", + "max_tokens": 400000, + "model_type": "chat", + "is_tools": true + }, { "llm_name": "gpt-5.2-pro", "tags": "LLM,CHAT,400k,IMAGE2TEXT", @@ -2234,13 +2262,6 @@ "model_type": "chat", "is_tools": true }, - { - "llm_name": "qwen/qwen2.5-coder-32b-instruct", - "tags": "LLM,CHAT,32K", - "max_tokens": 32768, - "model_type": "chat", - "is_tools": true - }, { "llm_name": "rakuten/rakutenai-7b-chat", "tags": "LLM,CHAT,4K", @@ -2815,13 +2836,6 @@ "rank": "780", "url": "https://api.siliconflow.cn/v1", "llm": [ - { - "llm_name": "THUDM/GLM-4.1V-9B-Thinking", - "tags": "LLM,CHAT,IMAGE2TEXT, 64k", - "max_tokens": 64000, - "model_type": "chat", - "is_tools": false - }, { "llm_name": "Qwen/Qwen3-Embedding-8B", "tags": "TEXT EMBEDDING,TEXT RE-RANK,32k", @@ -2878,13 +2892,6 @@ "model_type": "chat", "is_tools": true }, - { - "llm_name": "Qwen/QVQ-72B-Preview", - "tags": "LLM,CHAT,IMAGE2TEXT,32k", - "max_tokens": 32000, - "model_type": "image2text", - "is_tools": false - }, { "llm_name": "Pro/deepseek-ai/DeepSeek-R1", "tags": "LLM,CHAT,64k", @@ -2955,20 +2962,6 @@ "model_type": "chat", "is_tools": true }, - { - "llm_name": "Qwen/Qwen2.5-VL-72B-Instruct", - "tags": "LLM,CHAT,IMAGE2TEXT,128k", - "max_tokens": 128000, - "model_type": "image2text", - "is_tools": true - }, - { - "llm_name": "Pro/Qwen/Qwen2.5-VL-7B-Instruct", - "tags": "LLM,CHAT,IMAGE2TEXT,32k", - "max_tokens": 32000, - "model_type": "image2text", - "is_tools": false - }, { "llm_name": "THUDM/GLM-Z1-32B-0414", "tags": "LLM,CHAT,32k", @@ -3018,20 +3011,6 @@ "model_type": "chat", "is_tools": true }, - { - "llm_name": "Qwen/Qwen2.5-Coder-32B-Instruct", - "tags": "LLM,CHAT,32k", - "max_tokens": 32000, - "model_type": "chat", - "is_tools": false - }, - { - "llm_name": "Qwen/Qwen2-VL-72B-Instruct", - "tags": "LLM,IMAGE2TEXT,32k", - "max_tokens": 32000, - "model_type": "image2text", - "is_tools": false - }, { "llm_name": "Qwen/Qwen2.5-72B-Instruct-128Kt", "tags": "LLM,IMAGE2TEXT,128k", @@ -3636,13 +3615,6 @@ "model_type": "chat", "is_tools": true }, - { - "llm_name": "Qwen/Qwen2.5-VL-32B-Instruct", - "tags": "LLM,CHAT,131k", - "max_tokens": 131000, - "model_type": "chat", - "is_tools": true - }, { "llm_name": "Qwen/QwQ-32B", "tags": "LLM,CHAT,131k", @@ -3650,20 +3622,6 @@ "model_type": "chat", "is_tools": true }, - { - "llm_name": "Qwen/Qwen2.5-VL-72B-Instruct", - "tags": "LLM,CHAT,131k", - "max_tokens": 131000, - "model_type": "chat", - "is_tools": true - }, - { - "llm_name": "Qwen/Qwen2.5-VL-7B-Instruct", - "tags": "LLM,CHAT,33k", - "max_tokens": 33000, - "model_type": "chat", - "is_tools": false - }, { "llm_name": "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B", "tags": "LLM,CHAT,131k", diff --git a/conf/models/302ai.json b/conf/models/302ai.json new file mode 100644 index 00000000000..8edb8d10c42 --- /dev/null +++ b/conf/models/302ai.json @@ -0,0 +1,222 @@ +{ + "name": "302.AI", + "url": { + "default": "https://api.302.ai" + }, + "url_suffix": { + "chat": "v1/chat/completions", + "models": "v1/models", + "embedding": "jina/v1/embeddings", + "rerank": "jina/v1/rerank", + "asr": "v1/audio/transcriptions", + "doc_parse": "mineru/api/v4/extract/task", + "task": "mineru/api/v4/extract/task", + "ocr": "mistral/v1/ocr" + }, + "class": "302.ai", + "models": [ + { + "name": "kimi-k2.6", + "max_tokens": 262144, + "model_types": [ + "chat", + "vision" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "gpt-5.5", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-5.4", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-5.4-mini", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-5.4-nano", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-5.2-pro", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-5.2", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-5.1", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-5.1-chat-latest", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-5", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-5-mini", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-5-nano", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-5-chat-latest", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-4.1", + "max_tokens": 1047576, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-4.1-mini", + "max_tokens": 1047576, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-4.1-nano", + "max_tokens": 1047576, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-4.5-preview", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "gpt-4o-mini", + "max_tokens": 128000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-4o", + "max_tokens": 128000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-3.5-turbo", + "max_tokens": 4096, + "model_types": [ + "chat" + ] + }, + { + "name": "gpt-3.5-turbo-16k-0613", + "max_tokens": 16385, + "model_types": [ + "chat" + ] + }, + { + "name": "whisper-v3-turbo", + "max_tokens": 8192, + "model_types": [ + "asr" + ] + }, + { + "name": "mistral-ocr-latest", + "max_tokens": 8192, + "model_types": [ + "ocr" + ] + }, + { + "name": "vlm", + "model_types": [ + "doc_parse" + ] + }, + { + "name": "jina-embeddings-v3", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + }, + { + "name": "jina-reranker-v2-base-multilingual", + "max_tokens": 134144, + "model_types": [ + "rerank" + ] + } + ] +} \ No newline at end of file diff --git a/conf/models/anthropic.json b/conf/models/anthropic.json new file mode 100644 index 00000000000..93cd6b27975 --- /dev/null +++ b/conf/models/anthropic.json @@ -0,0 +1,85 @@ +{ + "name": "Anthropic", + "url": { + "default": "https://api.anthropic.com" + }, + "url_suffix": { + "chat": "v1/messages", + "models": "v1/models" + }, + "class": "anthropic", + "models": [ + { + "name": "claude-opus-4-5-20251101", + "max_tokens": 128000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "claude-opus-4-1-20250805", + "max_tokens": 128000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "claude-opus-4-20250514", + "max_tokens": 128000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "claude-sonnet-4-5-20250929", + "max_tokens": 64000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "claude-sonnet-4-20250514", + "max_tokens": 64000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "claude-haiku-4-5-20251001", + "max_tokens": 64000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "claude-3-7-sonnet-20250219", + "max_tokens": 64000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "claude-3-5-sonnet-20241022", + "max_tokens": 8192, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "claude-3-5-haiku-20241022", + "max_tokens": 8192, + "model_types": [ + "chat", + "vision" + ] + } + ] +} diff --git a/conf/models/baichuan.json b/conf/models/baichuan.json new file mode 100644 index 00000000000..c7bc5f1c0d0 --- /dev/null +++ b/conf/models/baichuan.json @@ -0,0 +1,90 @@ +{ + "name": "Baichuan", + "url": { + "default": "https://api.baichuan-ai.com/v1" + }, + "url_suffix": { + "chat": "chat/completions", + "embedding": "embeddings" + }, + "class": "baichuan", + "models": [ + { + "name": "Baichuan4", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "Baichuan4-Air", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "Baichuan4-Turbo", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "Baichuan-M3", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "Baichuan-M3-plus", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "Baichuan-M2-plus", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "Baichuan-M2", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "Baichuan3-Turbo", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "Baichuan3-Turbo-128k", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "Baichuan2-Turbo", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "Baichuan-Text-Embedding", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + } + ] +} \ No newline at end of file diff --git a/conf/models/baidu.json b/conf/models/baidu.json new file mode 100644 index 00000000000..b1654652697 --- /dev/null +++ b/conf/models/baidu.json @@ -0,0 +1,87 @@ +{ + "name": "Baidu", + "url": { + "default": "https://qianfan.baidubce.com/v2" + }, + "url_suffix": { + "chat": "chat/completions", + "models": "models", + "embedding": "embeddings", + "rerank": "rerank", + "ocr": "ocr/paddleocr" + }, + "class": "baidu", + "models": [ + { + "name": "deepseek-v3.2", + "max_tokens": 98304, + "model_types": [ + "chat" + ] + }, + { + "name": "deepseek-v4-flash", + "max_tokens": 1048576, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "deepseek-v4-pro", + "max_tokens": 1048576, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "qwen3-32b", + "max_tokens": 30720, + "model_types":[ + "chat" + ] + }, + { + "name": "qwen3-4b", + "max_tokens": 30720, + "model_types": [ + "chat" + ] + }, + { + "name": "ernie-5.0", + "max_tokens": 121856, + "model_types": [ + "vision" + ] + }, + { + "name": "embedding-v1", + "max_tokens": 384, + "model_types": [ + "embedding" + ] + }, + { + "name": "qwen3-reranker-4b", + "max_tokens": 32768, + "model_types": [ + "rerank" + ] + }, + { + "name": "paddleocr-vl-0.9b", + "max_tokens": 8192, + "model_types": [ + "ocr" + ] + } + ] +} \ No newline at end of file diff --git a/conf/models/cohere.json b/conf/models/cohere.json new file mode 100644 index 00000000000..98242d49a9a --- /dev/null +++ b/conf/models/cohere.json @@ -0,0 +1,51 @@ +{ + "name": "CoHere", + "url": { + "default": "https://api.cohere.com" + }, + "url_suffix": { + "chat": "v2/chat", + "models": "v1/models", + "embeddings": "v2/embed", + "rerank": "v2/rerank", + "asr": "audio/transcriptions" + }, + "class": "cohere", + "models": [ + { + "name": "command-a-03-2025", + "max_tokens": 256000, + "model_types": [ + "chat" + ] + }, + { + "name": "command-a-reasoning-08-2025", + "max_tokens": 256000, + "model_types": [ + "chat" + ] + }, + { + "name": "rerank-v4.0-pro", + "max_tokens": 128000, + "model_types": [ + "rerank" + ] + }, + { + "name": "embed-v4.0", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + }, + { + "name": "cohere-transcribe-03-2026", + "max_tokens": 8192, + "model_types": [ + "asr" + ] + } + ] +} \ No newline at end of file diff --git a/conf/models/cometapi.json b/conf/models/cometapi.json new file mode 100644 index 00000000000..06cbbe1daf8 --- /dev/null +++ b/conf/models/cometapi.json @@ -0,0 +1,125 @@ +{ + "name": "CometAPI", + "url": { + "default": "https://api.cometapi.com" + }, + "url_suffix": { + "chat": "v1/chat/completions", + "models": "api/models", + "embedding": "v1/embeddings", + "balance": "https://query.cometapi.com/user/quota", + "tts": "v1/audio/speech", + "asr": "v1/audio/transcriptions" + }, + "class": "cometapi", + "models": [ + { + "name": "gpt-5.5", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "gpt-5.4-mini", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "gpt-5", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "gpt-4o", + "max_tokens": 128000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "claude-sonnet-4-6", + "max_tokens": 200000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gemini-3-pro-preview", + "max_tokens": 1048576, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "deepseek-v3.2", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "qwen3-235b-a22b", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "text-embedding-3-small", + "max_tokens": 8191, + "model_types": [ + "embedding" + ] + }, + { + "name": "text-embedding-3-large", + "max_tokens": 8191, + "model_types": [ + "embedding" + ] + }, + { + "name": "text-embedding-ada-002", + "max_tokens": 8191, + "model_types": [ + "embedding" + ] + }, + { + "name": "whisper-1", + "model_types": [ + "asr" + ] + }, + { + "name": "tts-1", + "max_tokens": 8192, + "model_types": [ + "tts" + ] + } + ] +} diff --git a/conf/models/deepinfra.json b/conf/models/deepinfra.json new file mode 100644 index 00000000000..a9277fc6e72 --- /dev/null +++ b/conf/models/deepinfra.json @@ -0,0 +1,49 @@ +{ + "name": "DeepInfra", + "url": { + "default": "https://api.deepinfra.com" + }, + "url_suffix": { + "chat": "v1/chat/completions", + "models": "models/list", + "balance": "payment/checklist", + "embedding": "v1/embeddings", + "tts": "v1/text-to-speech", + "asr": "v1/audio/transcriptions" + }, + "class": "deepinfra", + "models": [ + { + "name": "deepseek-ai/DeepSeek-V3.2", + "max_tokens": 32768, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "Qwen/Qwen3-Embedding-4B", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + }, + { + "name": "hexgrad/Kokoro-82M", + "max_tokens": 16384, + "model_types": [ + "tts" + ] + }, + { + "name": "bosonai/HiggsAudioV2.5", + "max_tokens": 8192, + "model_types": [ + "asr" + ] + } + ] +} \ No newline at end of file diff --git a/conf/models/fishaudio.json b/conf/models/fishaudio.json new file mode 100644 index 00000000000..aa6beda9647 --- /dev/null +++ b/conf/models/fishaudio.json @@ -0,0 +1,36 @@ +{ + "name": "FishAudio", + "url": { + "default": "https://api.fish.audio" + }, + "url_suffix": { + "models": "model", + "balance": "self/package", + "tts": "v1/tts", + "asr": "v1/asr" + }, + "class": "fishaudio", + "models": [ + { + "name": "s2-pro", + "max_tokens": 8192, + "model_types": [ + "tts" + ] + }, + { + "name": "s1", + "max_tokens": 8192, + "model_types": [ + "tts" + ] + }, + { + "name": "transcribe-1", + "max_tokens": 8192, + "model_types": [ + "asr" + ] + } + ] +} \ No newline at end of file diff --git a/conf/models/gitee.json b/conf/models/gitee.json index 630106592f2..6b1a0732e85 100644 --- a/conf/models/gitee.json +++ b/conf/models/gitee.json @@ -9,7 +9,11 @@ "status": "", "balance": "tokens/packages/balance", "embedding": "embedding", - "rerank": "rerank" + "rerank": "rerank", + "ocr": "images/ocr", + "doc_parse": "async/documents/parse", + "tasks": "tasks", + "task": "task" }, "models": [ { @@ -39,6 +43,43 @@ "model_types": [ "rerank" ] + }, + { + "name": "BAAI/bge-m3", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + }, + { + "name": "GOT-OCR2_0", + "model_types": [ + "ocr" + ] + }, + { + "name": "DeepSeek-OCR-2", + "model_types": [ + "ocr" + ] + }, + { + "name": "PaddleOCR-VL-1.5", + "model_types": [ + "ocr" + ] + }, + { + "name": "HunyuanOCR", + "model_types": [ + "ocr" + ] + }, + { + "name": "MinerU2.5", + "model_types": [ + "doc_parse" + ] } ] } \ No newline at end of file diff --git a/conf/models/google.json b/conf/models/google.json index 2e4cf30525f..a1d5f129f0b 100644 --- a/conf/models/google.json +++ b/conf/models/google.json @@ -18,6 +18,13 @@ "default_value": true, "clear_thinking": true } + }, + { + "name": "text-embedding-004", + "max_tokens": 2048, + "model_types": [ + "embedding" + ] } ], "features": { diff --git a/conf/models/huggingface.json b/conf/models/huggingface.json index c46ab4a46bd..f1a7d942fb9 100644 --- a/conf/models/huggingface.json +++ b/conf/models/huggingface.json @@ -1,7 +1,7 @@ { "name": "HuggingFace", "url": { - "default": "https://router.huggingface.co/v1/" + "default": "https://router.huggingface.co/v1" }, "url-suffix": { "chat": "chat/completions", diff --git a/conf/models/jiekouai.json b/conf/models/jiekouai.json new file mode 100644 index 00000000000..4591c4e5a99 --- /dev/null +++ b/conf/models/jiekouai.json @@ -0,0 +1,106 @@ +{ + "name": "JieKouAI", + "url": { + "default": "https://api.jiekou.ai" + }, + "url_suffix": { + "chat": "openai/v1/chat/completions", + "embedding": "openai/v1/embeddings", + "rerank": "openai/v1/rerank", + "models": "openai/v1/models" + }, + "class": "jiekouai", + "models": [ + { + "name": "deepseek-v4-flash", + "max_tokens": 1048576, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "deepseek-v4-pro", + "max_tokens": 1048576, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "zai-org/glm-4.5", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "zai-org/glm-4.5v", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "zai-org/glm-4.7", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "zai-org/glm-4.7-flash", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "zai-org/glm-5", + "max_tokens": 131072, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + }, + { + "name": "baai/bge-reranker-v2-m3", + "max_tokens": 8192, + "model_types": [ + "rerank" + ] + }, + { + "name": "text-embedding-3-large", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + } + ] +} \ No newline at end of file diff --git a/conf/models/jina.json b/conf/models/jina.json new file mode 100644 index 00000000000..97069a4b9a8 --- /dev/null +++ b/conf/models/jina.json @@ -0,0 +1,107 @@ +{ + "name": "Jina", + "url": { + "default": "https://api.jina.ai/v1", + "deepsearch": "https://deepsearch.jina.ai/v1" + }, + "url_suffix": { + "chat": "chat/completions", + "models": "models", + "embedding": "embeddings", + "rerank": "rerank" + }, + "class": "jina", + "models": [ + { + "name": "jina-vlm", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "jina-reranker-v3", + "max_tokens": 134144, + "model_types": [ + "rerank" + ] + }, + { + "name": "jina-reranker-m0", + "max_tokens": 134144, + "model_types": [ + "rerank" + ] + }, + { + "name": "jina-colbert-v2", + "max_tokens": 134144, + "model_types": [ + "rerank" + ] + }, + { + "name": "jina-reranker-v2-base-multilingual", + "max_tokens": 134144, + "model_types": [ + "rerank" + ] + }, + { + "name": "jina-embeddings-v3", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + }, + { + "name": "jina-embeddings-v4", + "max_tokens": 32768, + "model_types": [ + "embedding" + ] + }, + { + "name": "jina-embeddings-v5-text-small", + "max_tokens": 32768, + "model_types": [ + "embedding" + ] + }, + { + "name": "jina-embeddings-v5-text-nano", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + }, + { + "name": "jina-embeddings-v5-omni-small", + "max_tokens": 32768, + "model_types": [ + "embedding" + ] + }, + { + "name": "jina-embeddings-v5-omni-nano", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + }, + { + "name": "jina-clip-v2", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + }, + { + "name": "jina-embeddings-v2-base-en", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + } + ] +} diff --git a/conf/models/lmstudio.json b/conf/models/lmstudio.json index a22cbb982fe..a5293ffb9d5 100644 --- a/conf/models/lmstudio.json +++ b/conf/models/lmstudio.json @@ -2,7 +2,8 @@ "name": "lmstudio", "url_suffix": { "chat": "chat/completions", - "models": "models" + "models": "models", + "embedding": "embeddings" }, "class": "local" } \ No newline at end of file diff --git a/conf/models/localai.json b/conf/models/localai.json new file mode 100644 index 00000000000..9222a95218f --- /dev/null +++ b/conf/models/localai.json @@ -0,0 +1,10 @@ +{ + "name": "localai", + "url_suffix": { + "chat": "chat/completions", + "models": "models", + "embedding": "embeddings", + "rerank": "rerank" + }, + "class": "local" +} diff --git a/conf/models/longcat.json b/conf/models/longcat.json new file mode 100644 index 00000000000..ec3cf06302e --- /dev/null +++ b/conf/models/longcat.json @@ -0,0 +1,47 @@ +{ + "name": "LongCat", + "url": { + "default": "https://api.longcat.chat" + }, + "url_suffix": { + "chat": "openai/v1/chat/completions" + }, + "class": "longcat", + "models": [ + { + "name": "LongCat-Flash-Chat", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "LongCat-Flash-Lite", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "LongCat-Flash-Thinking-2601", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "LongCat-Flash-Omni-2603", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "LongCat-2.0-Preview", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + } + ] +} diff --git a/conf/models/mineru.json b/conf/models/mineru.json new file mode 100644 index 00000000000..f47accd954d --- /dev/null +++ b/conf/models/mineru.json @@ -0,0 +1,25 @@ +{ + "name": "MinerU", + "url": { + "default": "https://mineru.net" + }, + "url_suffix": { + "doc_parse": "v4/extract/task", + "tasks": "" + }, + "class": "mineru", + "models": [ + { + "name": "vlm", + "model_types": [ + "doc_parse" + ] + }, + { + "name": "MinerU-HTML", + "model_types": [ + "doc_parse" + ] + } + ] +} \ No newline at end of file diff --git a/conf/models/mineru_local.json b/conf/models/mineru_local.json new file mode 100644 index 00000000000..54bd46e391c --- /dev/null +++ b/conf/models/mineru_local.json @@ -0,0 +1,8 @@ +{ + "name": "mineru_local", + "url_suffix": { + "doc_parse": "file_parse", + "task": "tasks" + }, + "class": "local" +} \ No newline at end of file diff --git a/conf/models/minimax.json b/conf/models/minimax.json index 31760ac2597..49aa6700a2d 100644 --- a/conf/models/minimax.json +++ b/conf/models/minimax.json @@ -99,6 +99,13 @@ "default_value": true, "clear_thinking": true } + }, + { + "name": "speech-2.8-hd", + "max_tokens": 8192, + "model_types": [ + "tts" + ] } ] } \ No newline at end of file diff --git a/conf/models/mistral.json b/conf/models/mistral.json new file mode 100644 index 00000000000..be9cbb18613 --- /dev/null +++ b/conf/models/mistral.json @@ -0,0 +1,107 @@ +{ + "name": "Mistral", + "url": { + "default": "https://api.mistral.ai" + }, + "url_suffix": { + "chat": "v1/chat/completions", + "models": "v1/models", + "embedding": "v1/embeddings", + "ocr": "v1/ocr" + }, + "class": "mistral", + "models": [ + { + "name": "mistral-large-latest", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "mistral-medium-latest", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "mistral-small-latest", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "ministral-8b-latest", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "ministral-3b-latest", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "pixtral-large-latest", + "max_tokens": 128000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "codestral-latest", + "max_tokens": 256000, + "model_types": [ + "chat" + ] + }, + { + "name": "open-mistral-nemo", + "max_tokens": 128000, + "model_types": [ + "chat" + ] + }, + { + "name": "open-mistral-7b", + "max_tokens": 32000, + "model_types": [ + "chat" + ] + }, + { + "name": "open-mixtral-8x7b", + "max_tokens": 32000, + "model_types": [ + "chat" + ] + }, + { + "name": "open-mixtral-8x22b", + "max_tokens": 64000, + "model_types": [ + "chat" + ] + }, + { + "name": "mistral-embed", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + }, + { + "name": "mistral-ocr-2512", + "max_tokens": 8192, + "model_types": [ + "ocr" + ] + } + ] +} diff --git a/conf/models/novita.json b/conf/models/novita.json new file mode 100644 index 00000000000..f95e6849291 --- /dev/null +++ b/conf/models/novita.json @@ -0,0 +1,70 @@ +{ + "name": "Novita", + "url": { + "default": "https://api.novita.ai" + }, + "url_suffix": { + "chat": "openai/v1/chat/completions", + "models": "openai/v1/models", + "embedding": "openai/v1/embeddings" + }, + "class": "novita", + "models": [ + { + "name": "deepseek/deepseek-v4-pro", + "max_tokens": 65536, + "model_types": [ + "chat" + ] + }, + { + "name": "meta-llama/llama-3.3-70b-instruct", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "qwen/qwen3-30b-a3b-fp8", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "qwen/qwen3-235b-a22b-fp8", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "moonshotai/kimi-k2-instruct", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "google/gemma-3-27b-it", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "mistralai/mistral-nemo", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "baai/bge-m3", + "max_tokens": 8192, + "model_types": [ + "embedding" + ] + } + ] +} diff --git a/conf/models/nvidia.json b/conf/models/nvidia.json index 8ba81f1fd3f..b711b76145a 100644 --- a/conf/models/nvidia.json +++ b/conf/models/nvidia.json @@ -5,7 +5,9 @@ }, "url_suffix": { "chat": "chat/completions", - "models": "models" + "models": "models", + "embedding": "embeddings", + "rerank": "ranking" }, "class": "nvidia", "models": [ @@ -38,26 +40,11 @@ ] }, { - "name": "deepseek-ai/deepseek-v3.2", - "max_tokens": 131072, - "model_types": [ - "chat" - ], - "thinking": { - "default_value": true, - "clear_thinking": true - } - }, - { - "name": "deepseek-ai/deepseek-v3.1", - "max_tokens": 131072, + "name": "nvidia/nv-embed-v1", + "max_tokens": 8192, "model_types": [ - "chat" - ], - "thinking": { - "default_value": true, - "clear_thinking": true - } + "embedding" + ] }, { "name": "google/codegemma-7b", @@ -80,27 +67,6 @@ "chat" ] }, - { - "name": "google/gemma-7b", - "max_tokens": 8192, - "model_types": [ - "chat" - ] - }, - { - "name": "ibm/granite-3.3-8b-instruct", - "max_tokens": 131072, - "model_types": [ - "chat" - ] - }, - { - "name": "meta/llama-3.1-405b-instruct", - "max_tokens": 131072, - "model_types": [ - "chat" - ] - }, { "name": "meta/llama-3.2-90b-vision-instruct", "max_tokens": 131072, @@ -116,24 +82,6 @@ "chat" ] }, - { - "name": "microsoft/phi-4-mini-flash-reasoning", - "max_tokens": 131072, - "model_types": [ - "chat" - ], - "thinking": { - "default_value": true, - "clear_thinking": true - } - }, - { - "name": "minimaxai/minimax-m2.1", - "max_tokens": 204800, - "model_types": [ - "chat" - ] - }, { "name": "minimaxai/minimax-m2.5", "max_tokens": 204800, @@ -148,20 +96,6 @@ "chat" ] }, - { - "name": "mistralai/devstral-2-123b-instruct-2512", - "max_tokens": 131072, - "model_types": [ - "chat" - ] - }, - { - "name": "mistralai/magistral-small-2506", - "max_tokens": 131072, - "model_types": [ - "chat" - ] - }, { "name": "mistralai/mistral-7b-instruct-v0.3", "max_tokens": 32768, @@ -177,7 +111,7 @@ ] }, { - "name": "mistralai/mistral-medium-3-5-128b", + "name": "mistralai/mistral-medium-3.5-128b", "max_tokens": 131072, "model_types": [ "chat", @@ -191,24 +125,6 @@ "chat" ] }, - { - "name": "mistralai/mixtral-8x22b-instruct", - "max_tokens": 65536, - "model_types": [ - "chat" - ] - }, - { - "name": "moonshotai/kimi-k2.5", - "max_tokens": 262144, - "model_types": [ - "chat" - ], - "thinking": { - "default_value": true, - "clear_thinking": true - } - }, { "name": "moonshotai/kimi-k2.6", "max_tokens": 262144, @@ -224,13 +140,6 @@ "chat" ] }, - { - "name": "moonshotai/kimi-k2-instruct-0905", - "max_tokens": 131072, - "model_types": [ - "chat" - ] - }, { "name": "moonshotai/kimi-k2-thinking", "max_tokens": 131072, @@ -313,13 +222,6 @@ "clear_thinking": true } }, - { - "name": "nvidia/nemoguard-jailbreak-detect", - "max_tokens": 4096, - "model_types": [ - "chat" - ] - }, { "name": "nvidia/nemotron-3-nano-30b-a3b", "max_tokens": 131072, @@ -361,57 +263,67 @@ ] }, { - "name": "nvidia/nvidia-nemotron-nano-9b-v2", - "max_tokens": 131072, + "name": "nvidia/nv-embed-v1", + "max_tokens": 32768, "model_types": [ - "chat" + "embedding" + ] + }, + { + "name": "nvidia/nv-embedqa-e5-v5", + "max_tokens": 512, + "model_types": [ + "embedding" + ] + }, + { + "name": "nvidia/nv-embedqa-mistral-7b-v2", + "max_tokens": 512, + "model_types": [ + "embedding" ] }, { - "name": "nvidia/riva-translate-4b-instruct-v1_1", + "name": "nvidia/nv-rerankqa-mistral-4b-v3", "max_tokens": 4096, "model_types": [ - "chat" + "rerank" ] }, { - "name": "nvidia/usdcode", - "max_tokens": 8192, + "name": "nvidia/llama-3.2-nv-rerankqa-1b-v2", + "max_tokens": 4096, "model_types": [ - "chat" + "rerank" ] }, { - "name": "openai/gpt-oss-120b", + "name": "nvidia/nvidia-nemotron-nano-9b-v2", "max_tokens": 131072, "model_types": [ "chat" ] }, { - "name": "qwen/qwen2.5-coder-7b-instruct", - "max_tokens": 32768, + "name": "nvidia/riva-translate-4b-instruct-v1.1", + "max_tokens": 4096, "model_types": [ "chat" ] }, { - "name": "qwen/qwen3-5-122b-a10b", + "name": "openai/gpt-oss-120b", "max_tokens": 131072, "model_types": [ "chat" ] }, { - "name": "qwen/qwen3-235b-a22b", + "name": "qwen/qwen3.5-122b-a10b", "max_tokens": 131072, "model_types": [ "chat" - ], - "thinking": { - "default_value": true, - "clear_thinking": true - } + ] }, { "name": "qwen/qwen3-coder-480b-a35b-instruct", @@ -425,7 +337,7 @@ } }, { - "name": "z-ai/glm-5", + "name": "z-ai/glm5", "max_tokens": 131072, "model_types": [ "chat" @@ -447,7 +359,7 @@ } }, { - "name": "z-ai/glm-4.7", + "name": "z-ai/glm4.7", "max_tokens": 131072, "model_types": [ "chat" diff --git a/conf/models/ollama.json b/conf/models/ollama.json index ed0a1e011b9..58adb17efe9 100644 --- a/conf/models/ollama.json +++ b/conf/models/ollama.json @@ -2,7 +2,8 @@ "name": "ollama", "url_suffix": { "chat": "chat/completions", - "models": "models" + "models": "models", + "embedding": "embeddings" }, "class": "local" } \ No newline at end of file diff --git a/conf/models/openai.json b/conf/models/openai.json index 696c6f93b3c..33e4a105061 100644 --- a/conf/models/openai.json +++ b/conf/models/openai.json @@ -5,10 +5,43 @@ }, "url_suffix": { "chat": "chat/completions", - "models": "models" + "models": "models", + "embedding": "embeddings" }, "class": "gpt", "models": [ + { + "name": "gpt-5.5", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-5.4", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-5.4-mini", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "gpt-5.4-nano", + "max_tokens": 400000, + "model_types": [ + "chat", + "vision" + ] + }, { "name": "gpt-5.2-pro", "max_tokens": 400000, diff --git a/conf/models/openrouter.json b/conf/models/openrouter.json index 6af1e2d15df..33d0bdadbde 100644 --- a/conf/models/openrouter.json +++ b/conf/models/openrouter.json @@ -8,7 +8,8 @@ "models": "models", "embedding": "embeddings", "rerank": "rerank", - "balance": "credits" + "balance": "credits", + "tts": "audio/speech" }, "class": "openrouter", "models": [ @@ -44,6 +45,13 @@ "default_value": true, "clear_thinking": true } + }, + { + "name": "openai/gpt-audio-mini", + "max_tokens": 131072, + "model_types": [ + "tts" + ] } ] } \ No newline at end of file diff --git a/conf/models/paddleocr.json b/conf/models/paddleocr.json new file mode 100644 index 00000000000..043921b803d --- /dev/null +++ b/conf/models/paddleocr.json @@ -0,0 +1,33 @@ +{ + "name": "PaddleOCR", + "url": { + "default": "https://paddleocr.aistudio-app.com/api" + }, + "url_suffix": { + "ocr": "v2/ocr/jobs" + }, + "class": "paddleocr", + "models": [ + { + "name": "PaddleOCR-VL-1.5", + "max_tokens": 16384, + "model_types": [ + "ocr" + ] + }, + { + "name": "PP-OCRv5", + "max_tokens": 16384, + "model_types": [ + "ocr" + ] + }, + { + "name": "PP-StructureV3", + "max_tokens": 16384, + "model_types": [ + "ocr" + ] + } + ] +} \ No newline at end of file diff --git a/conf/models/replicate.json b/conf/models/replicate.json new file mode 100644 index 00000000000..91111351ad5 --- /dev/null +++ b/conf/models/replicate.json @@ -0,0 +1,27 @@ +{ + "name": "Replicate", + "url": { + "default": "https://api.replicate.com" + }, + "url_suffix": { + "chat": "v1/predictions", + "models": "v1/models" + }, + "class": "replicate", + "models": [ + { + "name": "meta/meta-llama-3-70b-instruct", + "max_tokens": 8192, + "model_types": [ + "chat" + ] + }, + { + "name": "meta/meta-llama-3-8b-instruct", + "max_tokens": 8192, + "model_types": [ + "chat" + ] + } + ] +} diff --git a/conf/models/siliconflow.json b/conf/models/siliconflow.json index 4da3e0dcab8..320d21aba58 100644 --- a/conf/models/siliconflow.json +++ b/conf/models/siliconflow.json @@ -8,7 +8,9 @@ "models": "models", "embedding": "embeddings", "rerank": "rerank", - "balance": "user/info" + "balance": "user/info", + "tts": "audio/speech", + "asr": "audio/transcriptions" }, "models": [ { @@ -45,6 +47,27 @@ "model_types": [ "embedding" ] + }, + { + "name": "fnlp/MOSS-TTSD-v0.5", + "max_tokens": 8192, + "model_types": [ + "tts" + ] + }, + { + "name": "FunAudioLLM/CosyVoice2-0.5B", + "max_tokens": 8192, + "model_types": [ + "tts" + ] + }, + { + "name": "FunAudioLLM/SenseVoiceSmall", + "max_tokens": 8192, + "model_types": [ + "asr" + ] } ] } diff --git a/conf/models/stepfun.json b/conf/models/stepfun.json new file mode 100644 index 00000000000..06ef52d6736 --- /dev/null +++ b/conf/models/stepfun.json @@ -0,0 +1,115 @@ +{ + "name": "StepFun", + "url": { + "default": "https://api.stepfun.ai/v1" + }, + "url_suffix": { + "chat": "chat/completions", + "models": "models", + "tts": "audio/speech" + }, + "class": "step", + "models": [ + { + "name": "step-3.5-flash", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "step-3.5-flash-paid", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "step-2-16k", + "max_tokens": 16384, + "model_types": [ + "chat" + ] + }, + { + "name": "step-1-256k", + "max_tokens": 262144, + "model_types": [ + "chat" + ] + }, + { + "name": "step-1-128k", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "step-1-32k", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "step-1-8k", + "max_tokens": 8192, + "model_types": [ + "chat" + ] + }, + { + "name": "step-1-flash", + "max_tokens": 8192, + "model_types": [ + "chat" + ] + }, + { + "name": "step-1v-32k", + "max_tokens": 32768, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "step-1v-8k", + "max_tokens": 8192, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "step-1o-vision-32k", + "max_tokens": 32768, + "model_types": [ + "chat", + "vision" + ] + }, + { + "name": "step-tts-2", + "max_tokens": 8192, + "model_types": [ + "tts" + ] + }, + { + "name": "stepaudio-2.5-tts", + "max_tokens": 8192, + "model_types": [ + "tts" + ] + }, + { + "name": "step-tts-mini", + "max_tokens": 8192, + "model_types": [ + "tts" + ] + } + ] +} diff --git a/conf/models/togetherai.json b/conf/models/togetherai.json new file mode 100644 index 00000000000..f8660e059bf --- /dev/null +++ b/conf/models/togetherai.json @@ -0,0 +1,34 @@ +{ + "name": "TogetherAI", + "url": { + "default": "https://api.together.ai/v1" + }, + "url_suffix": { + "chat": "chat/completions", + "models": "models" + }, + "class": "together", + "models": [ + { + "name": "openai/gpt-oss-20b", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "meta-llama/Llama-3.3-70B-Instruct-Turbo", + "max_tokens": 131072, + "model_types": [ + "chat" + ] + }, + { + "name": "Qwen/Qwen3-Coder-480B-A35B-Instruct-FP8", + "max_tokens": 262144, + "model_types": [ + "chat" + ] + } + ] +} diff --git a/conf/models/upstage.json b/conf/models/upstage.json new file mode 100644 index 00000000000..045bcaf6930 --- /dev/null +++ b/conf/models/upstage.json @@ -0,0 +1,56 @@ +{ + "name": "Upstage", + "url": { + "default": "https://api.upstage.ai/v1" + }, + "url_suffix": { + "chat": "chat/completions", + "models": "models", + "embedding": "embeddings" + }, + "class": "solar", + "models": [ + { + "name": "solar-pro3", + "max_tokens": 65536, + "model_types": [ + "chat" + ] + }, + { + "name": "solar-pro2", + "max_tokens": 65536, + "model_types": [ + "chat" + ] + }, + { + "name": "solar-pro", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "solar-mini", + "max_tokens": 32768, + "model_types": [ + "chat" + ] + }, + { + "name": "solar-embedding-1-large-query", + "max_tokens": 2000, + "model_types": [ + "embedding" + ] + }, + { + "name": "solar-embedding-1-large-passage", + "max_tokens": 2000, + "model_types": [ + "embedding" + ] + } + ] +} diff --git a/conf/models/vllm.json b/conf/models/vllm.json index 96ec1a2403b..d4f074330b3 100644 --- a/conf/models/vllm.json +++ b/conf/models/vllm.json @@ -2,7 +2,9 @@ "name": "vllm", "url_suffix": { "chat": "chat/completions", - "models": "models" + "models": "models", + "embedding": "embeddings", + "rerank": "rerank" }, "class": "local" } \ No newline at end of file diff --git a/conf/models/volcengine.json b/conf/models/volcengine.json index 96a6004097a..82535493703 100644 --- a/conf/models/volcengine.json +++ b/conf/models/volcengine.json @@ -22,7 +22,7 @@ } }, { - "name": "doubao-embedding-vision-250615", + "name": "doubao-embedding-vision-251215", "max_tokens": 131072, "model_types": [ "embedding" diff --git a/conf/models/voyage.json b/conf/models/voyage.json new file mode 100644 index 00000000000..65c2272d934 --- /dev/null +++ b/conf/models/voyage.json @@ -0,0 +1,69 @@ +{ + "name": "Voyage", + "url": { + "default": "https://api.voyageai.com" + }, + "url_suffix": { + "embedding": "v1/embeddings", + "rerank": "v1/rerank" + }, + "class": "voyage", + "models": [ + { + "name": "voyage-3.5", + "max_tokens": 327680, + "model_types": [ + "embedding" + ] + }, + { + "name": "voyage-3.5-lite", + "max_tokens": 1048576, + "model_types": [ + "embedding" + ] + }, + { + "name": "voyage-3-large", + "max_tokens": 122880, + "model_types": [ + "embedding" + ] + }, + { + "name": "voyage-code-3", + "max_tokens": 122880, + "model_types": [ + "embedding" + ] + }, + { + "name": "voyage-law-2", + "max_tokens": 122880, + "model_types": [ + "embedding" + ] + }, + { + "name": "voyage-finance-2", + "max_tokens": 122880, + "model_types": [ + "embedding" + ] + }, + { + "name": "rerank-2", + "max_tokens": 4000, + "model_types": [ + "rerank" + ] + }, + { + "name": "rerank-2-lite", + "max_tokens": 2000, + "model_types": [ + "rerank" + ] + } + ] +} diff --git a/conf/models/xinference.json b/conf/models/xinference.json new file mode 100644 index 00000000000..cf50dbc7313 --- /dev/null +++ b/conf/models/xinference.json @@ -0,0 +1,8 @@ +{ + "name": "xinference", + "url_suffix": { + "chat": "v1/chat/completions", + "models": "v1/models" + }, + "class": "local" +} diff --git a/conf/models/xunfei.json b/conf/models/xunfei.json new file mode 100644 index 00000000000..6ab0385b55f --- /dev/null +++ b/conf/models/xunfei.json @@ -0,0 +1,23 @@ +{ + "name": "XunFei", + "url": { + "default": "https://spark-api-open.xf-yun.com" + }, + "url_suffix": { + "chat": "v2/chat/completions" + }, + "class": "xunfei", + "models": [ + { + "name": "spark-x", + "max_tokens": 134144, + "model_types": [ + "chat" + ], + "thinking": { + "default_value": true, + "clear_thinking": true + } + } + ] +} \ No newline at end of file diff --git a/conf/models/zhipu-ai.json b/conf/models/zhipu-ai.json index d1bbac649fd..b10f18b5d44 100644 --- a/conf/models/zhipu-ai.json +++ b/conf/models/zhipu-ai.json @@ -9,7 +9,8 @@ "async_result": "async-result", "embedding": "embeddings", "rerank": "rerank", - "files": "files" + "files": "files", + "models": "models" }, "class": "glm", "models": [ diff --git a/deepdoc/parser/mineru_parser.py b/deepdoc/parser/mineru_parser.py index 2c3f63ae3fd..2c35ead98c6 100644 --- a/deepdoc/parser/mineru_parser.py +++ b/deepdoc/parser/mineru_parser.py @@ -32,7 +32,7 @@ import pdfplumber import requests from PIL import Image -from strenum import StrEnum +from enum import StrEnum from deepdoc.parser.pdf_parser import RAGFlowPdfParser from deepdoc.parser.utils import extract_pdf_outlines @@ -539,6 +539,21 @@ def _sanitize_filename(name: str) -> str: if nested_alt.exists(): subdir = nested_alt.parent json_file = nested_alt + else: + # Try vlm subdirectory (for vlm-http-client backend) + vlm_path = output_dir / "vlm" / f"{file_stem}_content_list.json" + self.logger.info(f"[MinerU] Trying vlm subdirectory: {vlm_path}") + attempted.append(vlm_path) + if vlm_path.exists(): + subdir = vlm_path.parent + json_file = vlm_path + else: + vlm_safe = output_dir / "vlm" / f"{safe_stem}_content_list.json" + self.logger.info(f"[MinerU] Trying vlm subdirectory with sanitized name: {vlm_safe}") + attempted.append(vlm_safe) + if vlm_safe.exists(): + subdir = vlm_safe.parent + json_file = vlm_safe if not json_file: parse_subdir = None @@ -629,6 +644,12 @@ def _transfer_to_sections(self, outputs: list[dict[str, Any]], parse_method: str case MinerUContentType.IMAGE: section = "".join(output.get("image_caption", [])) + "\n" + "".join( output.get("image_footnote", [])) + # If a vision model enriched this image with a semantic + # description (see _enhance_images_with_vlm), embed it in + # the chunk so it becomes searchable / retrievable. + vlm_description = (output.get("vlm_description") or "").strip() + if vlm_description: + section = (section.strip("\n") + "\n" + vlm_description).strip("\n") if section.strip() else vlm_description case MinerUContentType.EQUATION: section = output.get("text", "") case MinerUContentType.CODE: @@ -649,6 +670,49 @@ def _transfer_to_sections(self, outputs: list[dict[str, Any]], parse_method: str def _transfer_to_tables(self, outputs: list[dict[str, Any]]): return [] + def _enhance_images_with_vlm(self, outputs: list[dict[str, Any]], vision_model, callback: Optional[Callable] = None): + """Generate semantic descriptions for image blocks via the tenant's + IMAGE2TEXT model, mirroring deepdoc's VisionFigureParser. Each + IMAGE block with a readable img_path gets a ``vlm_description`` + field that ``_transfer_to_sections`` then folds into the chunk + text — closing issue #14869. + """ + from concurrent.futures import ThreadPoolExecutor, as_completed + from rag.app.picture import vision_llm_chunk + from rag.prompts.generator import vision_llm_figure_describe_prompt + + image_jobs = [ + (idx, item) + for idx, item in enumerate(outputs) + if item.get("type") == MinerUContentType.IMAGE + and item.get("img_path") + and os.path.exists(item["img_path"]) + ] + if not image_jobs: + return + + if callback: + callback(0.78, f"[MinerU] Generating VLM descriptions for {len(image_jobs)} images...") + + prompt = vision_llm_figure_describe_prompt() + + def worker(idx, item): + try: + with Image.open(item["img_path"]) as img: + img.load() + desc = vision_llm_chunk(binary=img, vision_model=vision_model, prompt=prompt) + return idx, (desc or "").strip() + except Exception as e: + logging.warning(f"[MinerU] VLM description failed for image #{idx}: {e}") + return idx, "" + + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(worker, idx, item) for idx, item in image_jobs] + for fut in as_completed(futures): + idx, desc = fut.result() + if desc: + outputs[idx]["vlm_description"] = desc + def parse_pdf( self, filepath: str | PathLike[str], @@ -729,6 +793,13 @@ def parse_pdf( if callback: callback(0.75, f"[MinerU] Parsed {len(outputs)} blocks from PDF.") + vision_model = kwargs.get("vision_model") + if vision_model is not None: + try: + self._enhance_images_with_vlm(outputs, vision_model, callback=callback) + except Exception as e: + self.logger.warning(f"[MinerU] VLM image enhancement failed: {e}. Continuing without descriptions.") + return self._transfer_to_sections(outputs, parse_method), self._transfer_to_tables(outputs) finally: if temp_pdf and temp_pdf.exists(): diff --git a/deepdoc/parser/pdf_parser.py b/deepdoc/parser/pdf_parser.py index 3a5bd16627b..e409d5556bd 100644 --- a/deepdoc/parser/pdf_parser.py +++ b/deepdoc/parser/pdf_parser.py @@ -77,8 +77,8 @@ def __init__(self, **kwargs): if layout_recognizer_type not in ["onnx", "ascend"]: raise RuntimeError("Unsupported layout recognizer type.") - if hasattr(self, "model_speciess"): - recognizer_domain = "layout." + self.model_speciess + if hasattr(self, "model_species"): + recognizer_domain = "layout." + self.model_species else: recognizer_domain = "layout" diff --git a/docker/.env b/docker/.env index da469287954..e4787d9ab2a 100644 --- a/docker/.env +++ b/docker/.env @@ -159,11 +159,11 @@ GO_ADMIN_PORT=9383 API_PROXY_SCHEME=python # use pure python server deployment # The RAGFlow Docker image to download. v0.22+ doesn't include embedding models. -RAGFLOW_IMAGE=infiniflow/ragflow:v0.25.2 +RAGFLOW_IMAGE=infiniflow/ragflow:v0.25.5 # If you cannot download the RAGFlow Docker image: -# RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:v0.25.2 -# RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:v0.25.2 +# RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:v0.25.5 +# RAGFLOW_IMAGE=registry.cn-hangzhou.aliyuncs.com/infiniflow/ragflow:v0.25.5 # # - For the `nightly` edition, uncomment either of the following: # RAGFLOW_IMAGE=swr.cn-north-4.myhuaweicloud.com/infiniflow/ragflow:nightly @@ -242,39 +242,23 @@ REGISTER_ENABLED=1 # ----------------------------------------------------------------------------- # Sandbox # ----------------------------------------------------------------------------- -# Sandbox settings are grouped by provider type. -# 1. Set `SANDBOX_ENABLED=1` to enable sandbox support. -# 2. Set `SANDBOX_PROVIDER_TYPE` to choose the active provider. -# 3. Only edit the section that matches the selected provider type. -# 4. If you do not use `self_managed`, remove `,sandbox` from `COMPOSE_PROFILES`. -# -# Naming convention for future providers: -# - `SANDBOX__*` -# Examples: -# - `SANDBOX_SELF_MANAGED_*` -# - `SANDBOX_LOCAL_*` -# - `SANDBOX_E2B_*` -# - `SANDBOX_ALIYUN_CODEINTERPRETER_*` +# Sandbox provider type and runtime settings are configured in Admin > Sandbox +# Settings. # Enable sandbox support. # SANDBOX_ENABLED=1 # COMPOSE_PROFILES=${COMPOSE_PROFILES},sandbox -# SANDBOX_PROVIDER_TYPE=${SANDBOX_PROVIDER_TYPE:-self_managed} # Shared sandbox settings -# `SANDBOX_HOST` is kept as the common endpoint name for legacy HTTP fallback -# and for the self-managed provider. -# Double check that `sandbox-executor-manager` resolves correctly in your -# Docker network or `/etc/hosts`. -# SANDBOX_HOST=${SANDBOX_HOST:-sandbox-executor-manager} # The MinIO bucket name for storing sandbox-generated artifacts. # SANDBOX_ARTIFACT_BUCKET=sandbox-artifacts + # Number of days before sandbox artifacts are automatically deleted. # SANDBOX_ARTIFACT_EXPIRE_DAYS=7 -# Provider: self_managed -# Use this provider when sandbox executors run as Docker services managed by -# RAGFlow. This is the default provider used by the `sandbox` compose profile. +# Self-managed deployment defaults +# These values are used by the `sandbox` compose profile and shown in Admin as +# deployment defaults for the self-managed provider. # Pull the required base images before running: # docker pull infiniflow/sandbox-base-nodejs:latest # docker pull infiniflow/sandbox-base-python:latest @@ -290,21 +274,9 @@ REGISTER_ENABLED=1 # SANDBOX_MAX_MEMORY=256m # b, k, m, g # SANDBOX_TIMEOUT=10s # s, m, 1m30s -# Provider: local -# Use this provider only in trusted development environments. It executes code -# on the local machine instead of inside Docker-managed sandbox containers. -# When `SANDBOX_PROVIDER_TYPE=local`, you usually do not need the `sandbox` -# compose profile. -# Uncomment and adjust only if you use the local provider. -# SANDBOX_LOCAL_ENABLED=true -# SANDBOX_LOCAL_PYTHON_BIN=python3 -# SANDBOX_LOCAL_NODE_BIN=node -# SANDBOX_LOCAL_WORK_DIR=/tmp/ragflow-codeexec -# SANDBOX_LOCAL_TIMEOUT=30 -# SANDBOX_LOCAL_MAX_MEMORY_MB=1024 -# SANDBOX_LOCAL_MAX_OUTPUT_BYTES=1048576 -# SANDBOX_LOCAL_MAX_ARTIFACTS=20 -# SANDBOX_LOCAL_MAX_ARTIFACT_BYTES=10485760 +# ----------------------------------------------------------------------------- +# Sandbox End +# ----------------------------------------------------------------------------- # Enable DocLing USE_DOCLING=false diff --git a/docker/README.md b/docker/README.md index 6a40db4d2a9..bfabefb9a43 100644 --- a/docker/README.md +++ b/docker/README.md @@ -79,7 +79,7 @@ The [.env](./.env) file contains important environment variables for Docker. - `SVR_HTTP_PORT` The port used to expose RAGFlow's HTTP API service to the host machine, allowing **external** access to the service running inside the Docker container. Defaults to `9380`. - `RAGFLOW_IMAGE` - The Docker image edition. Defaults to `infiniflow/ragflow:v0.25.2`. The RAGFlow Docker image does not include embedding models. + The Docker image edition. Defaults to `infiniflow/ragflow:v0.25.5`. The RAGFlow Docker image does not include embedding models. > [!TIP] diff --git a/docker/docker-compose-base.yml b/docker/docker-compose-base.yml index 1ceb7fb75ce..1122fe7a7c6 100644 --- a/docker/docker-compose-base.yml +++ b/docker/docker-compose-base.yml @@ -72,7 +72,7 @@ services: infinity: profiles: - infinity - image: infiniflow/infinity:v0.7.0-dev6 + image: infiniflow/infinity:v0.7.0 volumes: - infinity_data:/var/infinity - ./infinity_conf.toml:/infinity_conf.toml diff --git a/docs/administrator/admin/ragflow_cli.md b/docs/administrator/admin/ragflow_cli.md index a4a5d6b376e..7e60121cafd 100644 --- a/docs/administrator/admin/ragflow_cli.md +++ b/docs/administrator/admin/ragflow_cli.md @@ -16,7 +16,7 @@ The RAGFlow CLI is a command-line-based system administration tool that offers a 2. Install ragflow-cli. ```bash - pip install ragflow-cli==0.25.2 + pip install ragflow-cli==0.25.5 ``` 3. Launch the CLI client: @@ -439,7 +439,7 @@ show_version +-----------------------+ | version | +-----------------------+ -| v0.25.2-24-g6f60e9f9e | +| v0.25.4-24-g6f60e9f9e | +-----------------------+ ``` @@ -468,18 +468,18 @@ Revoke successfully! ``` ragflow> list vars; +-----------+---------------------+--------------+-----------+ -| data_type | name | source | value | +| data_type | name | setting_type | value | +-----------+---------------------+--------------+-----------+ -| string | default_role | variable | user | -| bool | enable_whitelist | variable | true | -| string | mail.default_sender | variable | | -| string | mail.password | variable | | -| integer | mail.port | variable | 15 | -| string | mail.server | variable | localhost | -| integer | mail.timeout | variable | 10 | -| bool | mail.use_ssl | variable | true | -| bool | mail.use_tls | variable | false | -| string | mail.username | variable | | +| string | default_role | config | user | +| bool | enable_whitelist | config | true | +| string | mail.default_sender | config | | +| string | mail.password | config | | +| integer | mail.port | config | 15 | +| string | mail.server | config | localhost | +| integer | mail.timeout | config | 10 | +| bool | mail.use_ssl | config | true | +| bool | mail.use_tls | config | false | +| string | mail.username | config | | +-----------+---------------------+--------------+-----------+ ``` @@ -490,9 +490,9 @@ ragflow> list vars; ``` ragflow> show var mail.server; +-----------+-------------+--------------+-----------+ -| data_type | name | source | value | +| data_type | name | setting_type | value | +-----------+-------------+--------------+-----------+ -| string | mail.server | variable | localhost | +| string | mail.server | config | localhost | +-----------+-------------+--------------+-----------+ ``` diff --git a/docs/administrator/configurations/configurations.md b/docs/administrator/configurations/configurations.md index cd9ab94e072..42733e8dd48 100644 --- a/docs/administrator/configurations/configurations.md +++ b/docs/administrator/configurations/configurations.md @@ -103,7 +103,7 @@ RAGFlow utilizes MinIO as its object storage solution, leveraging its scalabilit - `SVR_HTTP_PORT` The port used to expose RAGFlow's HTTP API service to the host machine, allowing **external** access to the service running inside the Docker container. Defaults to `9380`. - `RAGFLOW_IMAGE` - The Docker image edition. Defaults to `infiniflow/ragflow:v0.25.2` (the RAGFlow Docker image without embedding models). + The Docker image edition. Defaults to `infiniflow/ragflow:v0.25.5` (the RAGFlow Docker image without embedding models). :::tip NOTE If you cannot download the RAGFlow Docker image, try the following mirrors. diff --git a/docs/administrator/migration/database_schema_and_migration.md b/docs/administrator/migration/database_schema_and_migration.md index 32ae48c2851..021f8a5b7f0 100644 --- a/docs/administrator/migration/database_schema_and_migration.md +++ b/docs/administrator/migration/database_schema_and_migration.md @@ -43,7 +43,7 @@ The [db_schema_sync.py](https://github.com/infiniflow/ragflow/blob/main/tools/sc ### Key functions - **Change detection**: Compares Python model definitions in `api/db/db_models.py` against the live database to identify new tables, added fields, or type mismatches. -- **Migration generation**: Automatically creates Python migration files (containing `migrate()` and `rollback()` logic) in version-specific directories (e.g., `tools/migrate/v0_25_0/`). +- **Migration generation**: Automatically creates Python migration files (containing `migrate()` and `rollback()` logic) in version-specific directories (e.g., `tools/migrate/v0_25_5/`). - **Schema auditing**: Provides a `--diff` command to view structural discrepancies without applying changes. - **Execution management**: Applies pending migrations to the database to bring it up to date with the current software version. - **Safety controls**: Prevents accidental data loss by requiring an explicit `--drop` flag to generate `DROP COLUMN` statements for removed fields. diff --git a/docs/administrator/upgrade_ragflow.mdx b/docs/administrator/upgrade_ragflow.mdx index 9ecb6427f5d..7fa3fd24f28 100644 --- a/docs/administrator/upgrade_ragflow.mdx +++ b/docs/administrator/upgrade_ragflow.mdx @@ -62,16 +62,16 @@ To upgrade RAGFlow, you must upgrade **both** your code **and** your Docker imag git pull ``` -3. Switch to the latest, officially published release, e.g., `v0.25.2`: +3. Switch to the latest, officially published release, e.g., `v0.25.5`: ```bash - git checkout -f v0.25.2 + git checkout -f v0.25.5 ``` 4. Update **ragflow/docker/.env**: ```bash - RAGFLOW_IMAGE=infiniflow/ragflow:v0.25.2 + RAGFLOW_IMAGE=infiniflow/ragflow:v0.25.5 ``` 5. Update the RAGFlow image and restart RAGFlow: @@ -92,10 +92,10 @@ No, you do not need to. Upgrading RAGFlow in itself will *not* remove your uploa 1. From an environment with Internet access, pull the required Docker image. 2. Save the Docker image to a **.tar** file. ```bash - docker save -o ragflow.v0.25.2.tar infiniflow/ragflow:v0.25.2 + docker save -o ragflow.v0.25.5.tar infiniflow/ragflow:v0.25.5 ``` 3. Copy the **.tar** file to the target server. 4. Load the **.tar** file into Docker: ```bash - docker load -i ragflow.v0.25.2.tar + docker load -i ragflow.v0.25.5.tar ``` diff --git a/docs/develop/build_docker_image.mdx b/docs/develop/build_docker_image.mdx index bc106f57ccd..71578c3e265 100644 --- a/docs/develop/build_docker_image.mdx +++ b/docs/develop/build_docker_image.mdx @@ -49,7 +49,7 @@ After building the infiniflow/ragflow:nightly image, you are ready to launch a f 1. Edit Docker Compose Configuration -Open the `docker/.env` file. Find the `RAGFLOW_IMAGE` setting and change the image reference from `infiniflow/ragflow:v0.25.2` to `infiniflow/ragflow:nightly` to use the pre-built image. +Open the `docker/.env` file. Find the `RAGFLOW_IMAGE` setting and change the image reference from `infiniflow/ragflow:v0.25.5` to `infiniflow/ragflow:nightly` to use the pre-built image. 2. Launch the Service diff --git a/docs/faq.mdx b/docs/faq.mdx index ab2ec1af226..42f81e28244 100644 --- a/docs/faq.mdx +++ b/docs/faq.mdx @@ -147,12 +147,12 @@ When debugging your chat assistant, you can use AI search as a reference to veri --- -### Get a `Request error 404: undefined` when upgrading to v0.25.2 +### Get a `Request error 404: undefined` when upgrading to v0.25.5 To resolve this issue, do either of the following: -- Pull the latest source code from the [main branch](https://github.com/infiniflow/ragflow), then pull and start the v0.25.2 image. -- Update `RAGFLOW_IMAGE` from `infiniflow/ragflow:latest` to `infiniflow/ragflow:v0.25.2` in the [.env file](https://github.com/infiniflow/ragflow/blob/main/docker/.env), then restart the service. +- Pull the latest source code from the [main branch](https://github.com/infiniflow/ragflow), then pull and start the v0.25.5 image. +- Update `RAGFLOW_IMAGE` from `infiniflow/ragflow:latest` to `infiniflow/ragflow:v0.25.5` in the [.env file](https://github.com/infiniflow/ragflow/blob/main/docker/.env), then restart the service. ### How to build the RAGFlow image from scratch? @@ -692,3 +692,26 @@ http://localhost:8080/layout-parsing | `PADDLEOCR_ACCESS_TOKEN` | Access token for official API | `None` | Only when using official API | Environment variables can be used for auto-provisioning, but are not required if configuring via UI. When environment variables are set, these values are used to auto-provision a PaddleOCR model for the tenant on first use. + + +### How do I use Ollama with RAGFlow for local LLM inference? + +RAGFlow supports Ollama as a local model provider for private, offline inference. + +**Step 1: Start Ollama and pull a model** + +```bash +export OLLAMA_HOST=0.0.0.0 +ollama serve +ollama pull llama3 +``` + +**Step 2: Add Ollama in RAGFlow** + +1. Go to **Settings** > **Model providers** > **Ollama**. +2. Set the Base URL to `http://host.docker.internal:11434` (Docker) or `http://localhost:11434` (bare-metal). +3. Enter the model name (e.g., `llama3`) and click **Save**. + +**Step 3: Use Ollama in your assistant** + +- Open an assistant's **Configuration** page and select the Ollama model under **Chat model**. diff --git a/docs/guides/agent/agent_component_reference/code.mdx b/docs/guides/agent/agent_component_reference/code.mdx index d0af92cc184..fa1de1caabf 100644 --- a/docs/guides/agent/agent_component_reference/code.mdx +++ b/docs/guides/agent/agent_component_reference/code.mdx @@ -98,7 +98,49 @@ If you define output variables here, ensure they are also defined in your code i ### Output -The defined output variable(s) will be auto-populated here. +The output is split into two parts: + +- **Business**: the business output defined in **Return Value** +- **System**: runtime fields that are populated automatically, such as `content`, `actual_type`, and `attachments` + +For example, the following code generates a simple line chart: + +```Python +def main() -> dict: + from pathlib import Path + + import matplotlib + matplotlib.use("Agg") + import matplotlib.pyplot as plt + + artifacts_dir = Path("artifacts") + artifacts_dir.mkdir(parents=True, exist_ok=True) + + x = [1, 2, 3, 4, 5] + y = [2, 4, 6, 8, 10] + + output_path = artifacts_dir / "simple_plot.png" + + plt.figure(figsize=(6, 4)) + plt.plot(x, y, marker="o") + plt.title("Simple Line Chart") + plt.xlabel("X") + plt.ylabel("Y") + plt.grid(True) + plt.tight_layout() + plt.savefig(output_path) + plt.close() + + return { + "result": "plot generated successfully", + "file_path": str(output_path), + } +``` +![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/codeexec_output1.jpg) + +Business Output shows the return value you defined, while System Output shows the generated `content`, the inferred `actual_type`, and the collected `attachments`. + +![](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/codeexec_output2.jpg) ## Troubleshooting diff --git a/docs/guides/agent/agent_quickstarts/sandbox_quickstart.md b/docs/guides/agent/agent_quickstarts/sandbox_quickstart.md index 115ffe88823..7c44e9581a6 100644 --- a/docs/guides/agent/agent_quickstarts/sandbox_quickstart.md +++ b/docs/guides/agent/agent_quickstarts/sandbox_quickstart.md @@ -7,19 +7,37 @@ sidebar_custom_props: { --- # Sandbox quickstart -A secure, pluggable code execution backend designed for RAGFlow and other applications requiring isolated code execution environments. +RAGFlow's `CodeExec` agent component needs a sandbox provider to run Python and JavaScript code. -## Features: +The simplest setup flow is: -- Seamless RAGFlow Integration — Works out-of-the-box with the code component of RAGFlow. -- High Security — Uses gVisor for syscall-level sandboxing to isolate execution. -- Customisable Sandboxing — Modify seccomp profiles easily to tailor syscall restrictions. -- Pluggable Runtime Support — Extendable to support any programming language runtime. -- Developer Friendly — Quick setup with a convenient Makefile. +1. Start the required sandbox services. +2. Open the RAGFlow admin page. +3. Go to **Admin > Sandbox Settings**. +4. Choose a provider and save the configuration. +5. Test the connection in the same page. -## Architecture +## Admin page -The architecture consists of isolated Docker base images for each supported language runtime, managed by the executor manager service. The executor manager orchestrates sandboxed code execution using gVisor for syscall interception and optional seccomp profiles for enhanced syscall filtering. +Configure sandbox providers from the admin page: + +- `self_managed`: Uses the executor manager service. +- `local`: Runs code on the current machine. +- `ssh`: Runs code on a remote machine over SSH. +- `aliyun_codeinterpreter` and `e2b`: Cloud providers. + +admin-sandbox-settings + +## Provider options + + +RAGFlow supports multiple sandbox providers. Configure the active provider in +Admin > Sandbox Settings after the services are up. + +- `self_managed`: Runs code inside Docker-managed sandbox containers. This is the default provider. +- `local`: Runs code as local Python or Node.js subprocesses. Use this only in trusted development environments. +- `ssh`: Runs code on a remote machine over SSH. +- `aliyun_codeinterpreter` and `e2b`: Cloud-hosted providers that remain available in the admin provider list. ## Prerequisites @@ -31,14 +49,16 @@ The architecture consists of isolated Docker base images for each supported lang - (Optional) GNU Make for simplified command-line management. :::tip NOTE -The error message `client version 1.43 is too old. Minimum supported API version is 1.44` indicates that your executor manager image's built-in Docker CLI version is lower than `29.1.0` required by the Docker daemon in use. To solve this issue, pull the latest `infiniflow/sandbox-executor-manager:latest` from Docker Hub or rebuild it in `./sandbox/executor_manager`. +The error message `client version 1.43 is too old. Minimum supported API version is 1.44` indicates that your executor manager image's built-in Docker CLI version is lower than `29.1.0` required by the Docker daemon in use. ::: ## Build Docker base images -The sandbox uses isolated base images for secure containerised execution environments. +The sandbox uses isolated base images for secure containerized execution environments. -Build the base images manually: +### Option 1: Build from source + +Build the runtime base images: ```bash docker build -t sandbox-base-python:latest ./sandbox_base_image/python @@ -51,20 +71,43 @@ Alternatively, build all base images at once using the Makefile: make build ``` -Next, build the executor manager image: +Build the executor manager image: ```bash docker build -t sandbox-executor-manager:latest ./executor_manager ``` +### Option 2: Pull base images from Docker Hub + +If you do not need to customize runtime dependencies, pull the published base images and tag them with the names used by standalone Docker Compose: + +```bash +docker pull infiniflow/sandbox-base-python:latest +docker pull infiniflow/sandbox-base-nodejs:latest + +docker tag infiniflow/sandbox-base-python:latest sandbox-base-python:latest +docker tag infiniflow/sandbox-base-nodejs:latest sandbox-base-nodejs:latest +``` + +Then restart the standalone sandbox services: + +```bash +docker compose -f docker-compose.yml down +docker compose -f docker-compose.yml up -d +``` + ## Running with RAGFlow 1. Verify that gVisor is properly installed and operational. 2. Configure the .env file located at docker/.env: -- Uncomment sandbox-related environment variables. -- Enable the sandbox profile at the bottom of the file. +- Set `SANDBOX_ENABLED=1`. +- Include `sandbox` in `COMPOSE_PROFILES` if you want the default + `self_managed` executor-manager service. +- Keep the self-managed deployment defaults in `.env` if you need to change the + sandbox-executor-manager image, pool size, base images, seccomp, memory, or + timeout. 3. Add the following entry to your /etc/hosts file to resolve the executor manager service: @@ -73,6 +116,54 @@ docker build -t sandbox-executor-manager:latest ./executor_manager ``` 4. Start the RAGFlow service as usual. +5. Open **Admin > Sandbox Settings**. +6. Select a provider. +7. Fill in the required fields. +8. Click **Save**. +9. Click **Test Connection** if needed. + +## Environment variables + +The variables in `docker/.env` are grouped by scope. + +### System-level variables + +These variables apply to sandbox support in general: + +- `SANDBOX_ENABLED`: Enables sandbox support in RAGFlow. +- `COMPOSE_PROFILES`: Include `sandbox` to start the default self-managed executor-manager service. +- `SANDBOX_ARTIFACT_BUCKET`: MinIO bucket used for files generated by sandbox code. +- `SANDBOX_ARTIFACT_EXPIRE_DAYS`: Number of days before sandbox artifacts expire. + +### Self-managed deployment defaults + +These variables are shown in Admin as deployment defaults for `self_managed`. +Changing them requires restarting `sandbox-executor-manager`. + +- `SANDBOX_EXECUTOR_MANAGER_IMAGE`: Docker image for the executor manager service. +- `SANDBOX_EXECUTOR_MANAGER_POOL_SIZE`: Number of Python and Node.js sandbox containers kept in the pool. +- `SANDBOX_BASE_PYTHON_IMAGE`: Python runtime image used by executor-managed containers. +- `SANDBOX_BASE_NODEJS_IMAGE`: Node.js runtime image used by executor-managed containers. +- `SANDBOX_EXECUTOR_MANAGER_PORT`: Host port exposed by the executor manager. +- `SANDBOX_ENABLE_SECCOMP`: Enables the optional seccomp profile for sandbox containers. +- `SANDBOX_MAX_MEMORY`: Memory limit for each sandbox runtime container. +- `SANDBOX_TIMEOUT`: Default execution timeout. + +### Admin-managed runtime settings + +Provider selection and runtime settings are configured in **Admin > Sandbox Settings**. + +Examples: + +- Choose the active provider +- Configure `self_managed` runtime settings +- Configure all `local` settings +- Configure all `ssh` settings + +For `self_managed`: + +- Runtime settings are editable in Admin +- Deployment defaults come from `.env` and are shown as read-only values ## Running standalone diff --git a/docs/guides/dataset/add_data_source/add_bitbucket.md b/docs/guides/dataset/add_data_source/add_bitbucket.md new file mode 100644 index 00000000000..1c31ddec3f5 --- /dev/null +++ b/docs/guides/dataset/add_data_source/add_bitbucket.md @@ -0,0 +1,51 @@ +--- +sidebar_position: 16 +slug: /add_confluence +sidebar_custom_props: { + categoryIcon: SiGoogledrive +} +--- + +# Add Bitbucket + +Integrate Bitbucket as a data source. + +--- + +This guide outlines the integration of Bitbucket as a data source for RAGFlow. + +## Prerequisites + +Before starting, ensure you have the following: + +- **Bitbucket API token:** A Personal Access Token (PAT) with the appropriate scopes or permissions. +- **Repository URL:** The full URL of the repository you wish to index. +- **Workspace ID:** The unique identifier for your Bitbucket workspace. + +## Configuration steps + +### Define Bitbucket as an external data source + +Navigate to the **Connectors** or **External Data Source** section in the RAGFlow Admin Panel and select **Bitbucket**. Fill in the connector details in the popup window: + +- **Name**: A descriptive name for this connector. +- **Bitbucket Account Email**: The email address for your Bitbucket account. +- **Bitbucket API Token**: The API token with proper permissions created in the previous step. +- **Workspace** The `WORKSPACE_NAME` from your Bitbucket URL, e.g., `https://bitbucket.org/{WORKSPACE_NAME}/...` +- **Index Mode** + - **Workspace**: (Default) Indexes all repositories in the workspace. + - **Repositories**: Indexes specified repositories in the workspace. + - **Repository Slugs**: A comma-separated list of repository slugs, e.g., `repo2,repo2`. + - **Projects**: Indexes specified projects in the workspace. + - **Projects**: A comma-separated list of project keys, e.g., `PROJ1,PROJ2`. + +*RAGFlow validates the connection immediately and indexes all pull requests from the specified repos or projects.* + +### Link to a dataset + +Credentials alone do not trigger indexing. You must link the data source to a specific dataset: + +1. Navigate to the **Dataset** tab. +2. Select or create the target Dataset. +3. Navigate to the Dataset's **Configuration** page and select **Link data source**. +4. Choose the previously created Bitbucket connector in the popup window. \ No newline at end of file diff --git a/docs/guides/dataset/configure_knowledge_base.md b/docs/guides/dataset/configure_knowledge_base.md index bb8c87c33d0..6b02ca204e4 100644 --- a/docs/guides/dataset/configure_knowledge_base.md +++ b/docs/guides/dataset/configure_knowledge_base.md @@ -135,7 +135,7 @@ See [Run retrieval test](./run_retrieval_test.md) for details. ## Search for dataset -As of RAGFlow v0.25.2, the search feature is still in a rudimentary form, supporting only dataset search by name. +As of RAGFlow v0.25.5, the search feature is still in a rudimentary form, supporting only dataset search by name. ![search dataset](https://raw.githubusercontent.com/infiniflow/ragflow-docs/main/images/search_datasets.jpg) diff --git a/docs/guides/manage_files.md b/docs/guides/manage_files.md index ef53e9f162f..5efaed3c240 100644 --- a/docs/guides/manage_files.md +++ b/docs/guides/manage_files.md @@ -89,4 +89,4 @@ RAGFlow's file management allows you to download an uploaded file: ![download_file](https://github.com/infiniflow/ragflow/assets/93570324/cf3b297f-7d9b-4522-bf5f-4f45743e4ed5) -> As of RAGFlow v0.25.2, bulk download is not supported, nor can you download an entire folder. +> As of RAGFlow v0.25.5, bulk download is not supported, nor can you download an entire folder. diff --git a/docs/quickstart.mdx b/docs/quickstart.mdx index 6d3d7f09525..0beacb2f99c 100644 --- a/docs/quickstart.mdx +++ b/docs/quickstart.mdx @@ -48,7 +48,7 @@ This section provides instructions on setting up the RAGFlow server on Linux. If `vm.max_map_count`. This value sets the maximum number of memory map areas a process may have. Its default value is 65530. While most applications require fewer than a thousand maps, reducing this value can result in abnormal behaviors, and the system will throw out-of-memory errors when a process reaches the limitation. - RAGFlow v0.25.2 uses Elasticsearch or [Infinity](https://github.com/infiniflow/infinity) for multiple recall. Setting the value of `vm.max_map_count` correctly is crucial to the proper functioning of the Elasticsearch component. + RAGFlow v0.25.5 uses Elasticsearch or [Infinity](https://github.com/infiniflow/infinity) for multiple recall. Setting the value of `vm.max_map_count` correctly is crucial to the proper functioning of the Elasticsearch component. 1 { return NewAdminException("Can't update more than 1 setting: " + varName) } - // Create new setting if it doesn't exist - // Determine data_type based on name and value - dataType := "string" - if len(varName) >= 7 && varName[:7] == "sandbox" { - dataType = "json" - } else if len(varName) >= 9 && varName[len(varName)-9:] == ".enabled" { - dataType = "boolean" - } - + dataType := inferSystemSettingDataType(varName) newSetting := &entity.SystemSettings{ Name: varName, Value: varValue, Source: "admin", DataType: dataType, } + if err := validateSystemSettingValue(*newSetting, varValue); err != nil { + return err + } return s.systemSettingsDAO.Create(newSetting) } @@ -1730,8 +1718,6 @@ func (s *Service) InitDefaultAdmin() error { } if len(users) == 0 { - now := time.Now().Unix() - nowDate := time.Now().Truncate(time.Second) userID := utility.GenerateToken() accessToken := utility.GenerateToken() status := "1" @@ -1758,12 +1744,6 @@ func (s *Service) InitDefaultAdmin() error { IsAnonymous: "0", LoginChannel: &loginChannel, IsSuperuser: &isSuperuser, - BaseModel: entity.BaseModel{ - CreateTime: &now, - CreateDate: &nowDate, - UpdateTime: &now, - UpdateDate: &nowDate, - }, } if err := dao.DB.Create(user).Error; err != nil { @@ -1806,8 +1786,6 @@ func (s *Service) InitDefaultAdmin() error { // addTenantForAdmin add tenant for admin user func (s *Service) addTenantForAdmin(userID, nickname string) error { - now := time.Now().Unix() - nowDate := time.Now().Truncate(time.Second) status := "1" role := "owner" tenantName := nickname + "'s Kingdom" @@ -1815,12 +1793,6 @@ func (s *Service) addTenantForAdmin(userID, nickname string) error { tenant := &entity.Tenant{ ID: userID, Name: &tenantName, - BaseModel: entity.BaseModel{ - CreateTime: &now, - CreateDate: &nowDate, - UpdateTime: &now, - UpdateDate: &nowDate, - }, } if err := dao.DB.Create(tenant).Error; err != nil { @@ -1833,12 +1805,6 @@ func (s *Service) addTenantForAdmin(userID, nickname string) error { InvitedBy: userID, Role: role, Status: &status, - BaseModel: entity.BaseModel{ - CreateTime: &now, - CreateDate: &nowDate, - UpdateTime: &now, - UpdateDate: &nowDate, - }, } return dao.DB.Create(userTenant).Error diff --git a/internal/admin/service_variables_test.go b/internal/admin/service_variables_test.go new file mode 100644 index 00000000000..2b94a09088e --- /dev/null +++ b/internal/admin/service_variables_test.go @@ -0,0 +1,65 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package admin + +import ( + "ragflow/internal/entity" + "testing" +) + +func TestValidateSystemSettingValue(t *testing.T) { + tests := []struct { + name string + dataType string + value string + wantError bool + }{ + {name: "string accepts arbitrary text", dataType: "string", value: "local host"}, + {name: "integer accepts digits", dataType: "integer", value: "15"}, + {name: "integer rejects text", dataType: "integer", value: "localhost", wantError: true}, + {name: "bool accepts true", dataType: "bool", value: "true"}, + {name: "bool accepts false", dataType: "bool", value: "false"}, + {name: "bool rejects non bool", dataType: "bool", value: "yes", wantError: true}, + {name: "json accepts object", dataType: "json", value: `{"endpoint":"http://localhost:9385"}`}, + {name: "json rejects invalid", dataType: "json", value: "{", wantError: true}, + {name: "unknown type rejects", dataType: "float", value: "1.2", wantError: true}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + setting := entity.SystemSettings{Name: "test.setting", DataType: tt.dataType} + err := validateSystemSettingValue(setting, tt.value) + if (err != nil) != tt.wantError { + t.Fatalf("validateSystemSettingValue() error = %v, wantError %v", err, tt.wantError) + } + }) + } +} + +func TestInferSystemSettingDataType(t *testing.T) { + tests := map[string]string{ + "sandbox.self_managed": "json", + "mail.enabled": "bool", + "mail.server": "string", + } + + for name, want := range tests { + if got := inferSystemSettingDataType(name); got != want { + t.Fatalf("inferSystemSettingDataType(%q) = %q, want %q", name, got, want) + } + } +} diff --git a/internal/cli/admin_command.go b/internal/cli/admin_command.go index f6ab603af5c..d1c37636016 100644 --- a/internal/cli/admin_command.go +++ b/internal/cli/admin_command.go @@ -596,6 +596,142 @@ func (c *RAGFlowClient) ShowService(cmd *Command) (ResponseIf, error) { return &result, nil } +func normalizeVariableRows(rows []map[string]interface{}) { + for _, row := range rows { + if _, ok := row["setting_type"]; ok { + delete(row, "source") + continue + } + if _, ok := row["source"]; ok { + row["setting_type"] = "config" + delete(row, "source") + } + } +} + +// ListVariables lists all system variables (admin mode only). +func (c *RAGFlowClient) ListVariables(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + if iterations > 1 { + return c.HTTPClient.RequestWithIterations("GET", "/admin/variables", "admin", nil, nil, iterations) + } + + resp, err := c.HTTPClient.Request("GET", "/admin/variables", "admin", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list variables: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list variables: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list variables failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + normalizeVariableRows(result.Data) + result.Duration = resp.Duration + return &result, nil +} + +// ShowVariable shows system variables by exact name or name prefix (admin mode only). +func (c *RAGFlowClient) ShowVariable(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + varName, ok := cmd.Params["var_name"].(string) + if !ok { + return nil, fmt.Errorf("var_name not provided") + } + + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + payload := map[string]interface{}{"var_name": varName} + if iterations > 1 { + return c.HTTPClient.RequestWithIterations("GET", "/admin/variables", "admin", nil, payload, iterations) + } + + resp, err := c.HTTPClient.Request("GET", "/admin/variables", "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to show variable: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to show variable: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("show variable failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + normalizeVariableRows(result.Data) + result.Duration = resp.Duration + return &result, nil +} + +// SetVariable updates a system variable (admin mode only). +func (c *RAGFlowClient) SetVariable(cmd *Command) (ResponseIf, error) { + if c.ServerType != "admin" { + return nil, fmt.Errorf("this command is only allowed in ADMIN mode") + } + + varName, ok := cmd.Params["var_name"].(string) + if !ok { + return nil, fmt.Errorf("var_name not provided") + } + varValue, ok := cmd.Params["var_value"].(string) + if !ok { + return nil, fmt.Errorf("var_value not provided") + } + + payload := map[string]interface{}{ + "var_name": varName, + "var_value": varValue, + } + resp, err := c.HTTPClient.Request("PUT", "/admin/variables", "admin", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to set variable: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to set variable: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result MessageResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("set variable failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + + result.Duration = resp.Duration + return &result, nil +} + // ListUsers lists all users (admin mode only) // Returns (result_map, error) - result_map is non-nil for benchmark mode func (c *RAGFlowClient) ListUsers(cmd *Command) (ResponseIf, error) { diff --git a/internal/cli/admin_parser.go b/internal/cli/admin_parser.go index c1b2edab5a7..9f4c6228e88 100644 --- a/internal/cli/admin_parser.go +++ b/internal/cli/admin_parser.go @@ -1173,7 +1173,7 @@ func (p *Parser) parseAdminSetVariable() (*Command, error) { } p.nextToken() - varValue, err := p.parseIdentifier() + varValue, err := p.parseVariableValue() if err != nil { return nil, err } diff --git a/internal/cli/admin_variables_test.go b/internal/cli/admin_variables_test.go new file mode 100644 index 00000000000..c871bca18e8 --- /dev/null +++ b/internal/cli/admin_variables_test.go @@ -0,0 +1,235 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package cli + +import ( + "encoding/json" + "io" + "net" + "net/http" + "net/http/httptest" + "net/url" + "strconv" + "testing" +) + +func TestParseAdminVariableCommands(t *testing.T) { + tests := []struct { + name string + input string + command string + varName string + varValue string + hasValue bool + adminMode bool + }{ + { + name: "list variables", + input: "list vars;", + command: "list_variables", + adminMode: true, + }, + { + name: "show variables by prefix", + input: "show var mail;", + command: "show_variable", + varName: "mail", + adminMode: true, + }, + { + name: "set integer variable", + input: "set var mail.port 15;", + command: "set_variable", + varName: "mail.port", + varValue: "15", + hasValue: true, + adminMode: true, + }, + { + name: "set quoted string variable", + input: `set var mail.server "local host";`, + command: "set_variable", + varName: "mail.server", + varValue: "local host", + hasValue: true, + adminMode: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cmd, err := NewParser(tt.input).Parse(tt.adminMode) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + if cmd.Type != tt.command { + t.Fatalf("command type = %q, want %q", cmd.Type, tt.command) + } + if tt.varName != "" && cmd.Params["var_name"] != tt.varName { + t.Fatalf("var_name = %v, want %q", cmd.Params["var_name"], tt.varName) + } + if tt.hasValue && cmd.Params["var_value"] != tt.varValue { + t.Fatalf("var_value = %v, want %q", cmd.Params["var_value"], tt.varValue) + } + }) + } +} + +func newAdminTestClient(t *testing.T, handler http.HandlerFunc) (*RAGFlowClient, func()) { + t.Helper() + + server := httptest.NewServer(handler) + serverURL, err := url.Parse(server.URL) + if err != nil { + t.Fatalf("parse test server URL: %v", err) + } + host, portText, err := net.SplitHostPort(serverURL.Host) + if err != nil { + t.Fatalf("split host port: %v", err) + } + port, err := strconv.Atoi(portText) + if err != nil { + t.Fatalf("parse port: %v", err) + } + + client := NewRAGFlowClient("admin") + client.HTTPClient.Host = host + client.HTTPClient.Port = port + client.HTTPClient.client = server.Client() + client.HTTPClient.LoginToken = "test-token" + + return client, server.Close +} + +func TestListVariablesUsesAdminVariablesEndpoint(t *testing.T) { + client, closeServer := newAdminTestClient(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("method = %s, want GET", r.Method) + return + } + if r.URL.Path != "/api/v1/admin/variables" { + t.Errorf("path = %s, want /api/v1/admin/variables", r.URL.Path) + return + } + if r.Header.Get("Authorization") != "test-token" { + t.Errorf("Authorization header = %q", r.Header.Get("Authorization")) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "code": 0, + "message": "", + "data": []map[string]interface{}{ + {"data_type": "string", "name": "mail.server", "source": "variable", "value": "localhost"}, + }, + }) + }) + defer closeServer() + + resp, err := client.ListVariables(NewCommand("list_variables")) + if err != nil { + t.Fatalf("ListVariables() error = %v", err) + } + result := resp.(*CommonResponse) + if got := result.Data[0]["setting_type"]; got != "config" { + t.Fatalf("setting_type = %v, want config", got) + } + if _, ok := result.Data[0]["source"]; ok { + t.Fatalf("source column should be normalized away: %#v", result.Data[0]) + } +} + +func TestShowVariableSendsRequestedName(t *testing.T) { + client, closeServer := newAdminTestClient(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Errorf("method = %s, want GET", r.Method) + return + } + body, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + var request map[string]string + if err := json.Unmarshal(body, &request); err != nil { + t.Errorf("request body is not JSON: %v", err) + return + } + if request["var_name"] != "mail" { + t.Errorf("var_name = %q, want mail", request["var_name"]) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "code": 0, + "message": "", + "data": []map[string]interface{}{ + {"data_type": "string", "name": "mail.server", "setting_type": "config", "value": "localhost"}, + }, + }) + }) + defer closeServer() + + cmd := NewCommand("show_variable") + cmd.Params["var_name"] = "mail" + resp, err := client.ShowVariable(cmd) + if err != nil { + t.Fatalf("ShowVariable() error = %v", err) + } + result := resp.(*CommonResponse) + if got := result.Data[0]["name"]; got != "mail.server" { + t.Fatalf("name = %v, want mail.server", got) + } +} + +func TestSetVariableReturnsServerConfirmation(t *testing.T) { + client, closeServer := newAdminTestClient(t, func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPut { + t.Errorf("method = %s, want PUT", r.Method) + return + } + var request map[string]string + if err := json.NewDecoder(r.Body).Decode(&request); err != nil { + t.Errorf("request body is not JSON: %v", err) + return + } + if request["var_name"] != "mail.server" { + t.Errorf("var_name = %q, want mail.server", request["var_name"]) + return + } + if request["var_value"] != "localhost" { + t.Errorf("var_value = %q, want localhost", request["var_value"]) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "code": 0, + "message": "Set variable successfully", + "data": nil, + }) + }) + defer closeServer() + + cmd := NewCommand("set_variable") + cmd.Params["var_name"] = "mail.server" + cmd.Params["var_value"] = "localhost" + resp, err := client.SetVariable(cmd) + if err != nil { + t.Fatalf("SetVariable() error = %v", err) + } + result := resp.(*MessageResponse) + if result.Message != "Set variable successfully" { + t.Fatalf("message = %q, want Set variable successfully", result.Message) + } +} diff --git a/internal/cli/client.go b/internal/cli/client.go index 2bd50cb695b..eebfb33d0d9 100644 --- a/internal/cli/client.go +++ b/internal/cli/client.go @@ -155,6 +155,12 @@ func (c *RAGFlowClient) ExecuteAdminCommand(cmd *Command) (ResponseIf, error) { return c.ShowAdminVersion(cmd) case "show_user": return c.ShowUser(cmd) + case "list_variables": + return c.ListVariables(cmd) + case "show_variable": + return c.ShowVariable(cmd) + case "set_variable": + return c.SetVariable(cmd) case "list_user_datasets": return c.ListUserDatasets(cmd) case "list_agents": @@ -203,6 +209,8 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) { return c.RunBenchmark(cmd) case "list_datasets": return c.ListDatasets(cmd) + case "list_dataset_documents": + return c.ListDatasetDocumentUserCommand(cmd) case "search_on_datasets": return c.SearchOnDatasets(cmd) case "create_token": @@ -267,6 +275,14 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) { return c.EmbedUserText(cmd) case "rarank_user_document": return c.RerankUserDocument(cmd) + case "tts_user_command": + return c.TTSUserCommand(cmd) + case "asr_user_command": + return c.ASRUserCommand(cmd) + case "ocr_user_command": + return c.OCRUserCommand(cmd) + case "parse_file_user_command": + return c.ParseFileUserCommand(cmd) case "check_provider_connection": return c.CheckProviderConnection(cmd) case "use_model": @@ -279,6 +295,10 @@ func (c *RAGFlowClient) ExecuteUserCommand(cmd *Command) (ResponseIf, error) { return c.ResetDefaultModel(cmd) case "list_user_default_models": return c.ListDefaultModels(cmd) + case "list_tasks_user_command": + return c.ListTasksUserCommand(cmd) + case "show_task_user_command": + return c.ShowTaskUserCommand(cmd) // Dataset, metadata commands case "create_dataset_table": return c.CreateDatasetInDocEngine(cmd) diff --git a/internal/cli/lexer.go b/internal/cli/lexer.go index 5f2aadea14f..b3bd1c8b323 100644 --- a/internal/cli/lexer.go +++ b/internal/cli/lexer.go @@ -375,6 +375,8 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenDimension, Value: ident} case "OCR": return Token{Type: TokenOCR, Value: ident} + case "DOC_PARSE": + return Token{Type: TokenDocParse, Value: ident} case "ASYNC": return Token{Type: TokenAsync, Value: ident} case "SYNC": @@ -431,12 +433,16 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenChunks, Value: ident} case "DOCUMENT": return Token{Type: TokenDocument, Value: ident} + case "DOCUMENTS": + return Token{Type: TokenDocuments, Value: ident} case "TAGS": return Token{Type: TokenTag, Value: ident} case "REGION": return Token{Type: TokenRegion, Value: ident} case "URL": return Token{Type: TokenURL, Value: ident} + case "TASK": + return Token{Type: TokenTask, Value: ident} case "TASKS": return Token{Type: TokenTasks, Value: ident} case "LOG": @@ -455,6 +461,14 @@ func (l *Lexer) lookupIdent(ident string) Token { return Token{Type: TokenFatal, Value: ident} case "PANIC": return Token{Type: TokenPanic, Value: ident} + case "PARAM": + return Token{Type: TokenParam, Value: ident} + case "PLAY": + return Token{Type: TokenPlay, Value: ident} + case "FORMAT": + return Token{Type: TokenFormat, Value: ident} + case "SAVE": + return Token{Type: TokenSave, Value: ident} default: return Token{Type: TokenIdentifier, Value: ident} } diff --git a/internal/cli/parser.go b/internal/cli/parser.go index e373c5a8749..02723f5467b 100644 --- a/internal/cli/parser.go +++ b/internal/cli/parser.go @@ -201,6 +201,12 @@ func (p *Parser) parseUserCommand() (*Command, error) { return p.parseEmbedCommand() case TokenRerank: return p.parseRerankCommand() + case TokenASR: + return p.parseASRCommand() + case TokenTTS: + return p.parseTTSCommand() + case TokenOCR: + return p.parseOCRCommand() case TokenCheck: return p.parseCheckCommand() case TokenLS: @@ -213,6 +219,7 @@ func (p *Parser) parseUserCommand() (*Command, error) { return p.parseUpdateCommand() case TokenRemove: return p.parseRemoveCommand() + default: return nil, fmt.Errorf("unknown command: %s", p.curToken.Value) } @@ -278,6 +285,15 @@ func (p *Parser) parseIdentifier() (string, error) { return p.curToken.Value, nil } +func (p *Parser) parseVariableValue() (string, error) { + switch p.curToken.Type { + case TokenIdentifier, TokenQuotedString, TokenInteger, TokenFloat: + return p.curToken.Value, nil + default: + return "", fmt.Errorf("expected variable value, got %s", p.curToken.Value) + } +} + func (p *Parser) parseNumber() (int, error) { if p.curToken.Type != TokenInteger { return 0, fmt.Errorf("expected number, got %s", p.curToken.Value) diff --git a/internal/cli/response.go b/internal/cli/response.go index 4331a76adb2..19daffc9e66 100644 --- a/internal/cli/response.go +++ b/internal/cli/response.go @@ -85,6 +85,42 @@ func (r *CommonDataResponse) PrintOut() { } } +type ListDocumentsResponse struct { + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + Message string `json:"message"` + Duration float64 + OutputFormat OutputFormat +} + +func (r *ListDocumentsResponse) Type() string { + return "list_documents" +} + +func (r *ListDocumentsResponse) TimeCost() float64 { + return r.Duration +} + +func (r *ListDocumentsResponse) SetOutputFormat(format OutputFormat) { + r.OutputFormat = format +} + +func (r *ListDocumentsResponse) PrintOut() { + if r.Code == 0 { + total := r.Data["total"].(float64) + fmt.Printf("Total: %0.0f\n", total) + docs := r.Data["docs"].([]interface{}) + table := make([]map[string]interface{}, 0) + for _, doc := range docs { + table = append(table, doc.(map[string]interface{})) + } + PrintTableSimpleByFormat(table, r.OutputFormat) + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + type SimpleResponse struct { Code int `json:"code"` Message string `json:"message"` @@ -113,6 +149,34 @@ func (r *SimpleResponse) PrintOut() { } } +type MessageResponse struct { + Code int `json:"code"` + Message string `json:"message"` + Duration float64 + OutputFormat OutputFormat +} + +func (r *MessageResponse) Type() string { + return "message" +} + +func (r *MessageResponse) TimeCost() float64 { + return r.Duration +} + +func (r *MessageResponse) SetOutputFormat(format OutputFormat) { + r.OutputFormat = format +} + +func (r *MessageResponse) PrintOut() { + if r.Code == 0 { + fmt.Println(r.Message) + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + type NonStreamResponse struct { Code int `json:"code"` ReasoningContent string `json:"reasoning_content"` @@ -277,6 +341,86 @@ func (r *KeyValueResponse) PrintOut() { } } +type EmbeddingData struct { + Index int `json:"index"` + Embedding []float64 `json:"embedding"` +} + +type EmbeddingsResponse struct { + Code int `json:"code"` + Data []EmbeddingData `json:"data"` + Message string `json:"message"` + Duration float64 + OutputFormat OutputFormat +} + +func (r *EmbeddingsResponse) Type() string { + return "common" +} + +func (r *EmbeddingsResponse) TimeCost() float64 { + return r.Duration +} + +func (r *EmbeddingsResponse) SetOutputFormat(format OutputFormat) { + r.OutputFormat = format +} + +func (r *EmbeddingsResponse) PrintOut() { + var data []map[string]interface{} + for _, embedding := range r.Data { + data = append(data, map[string]interface{}{ + "index": formatValue(embedding.Index), + "dimension": len(embedding.Embedding), + }) + } + + if r.Code == 0 { + PrintTableSimpleByFormat(data, r.OutputFormat) + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + +type SegmentResponse struct { + Segments []map[string]interface{} `json:"segments"` +} + +type TaskResponse struct { + Code int `json:"code"` + Data map[string]interface{} `json:"data"` + Message string `json:"message"` + Duration float64 + OutputFormat OutputFormat +} + +func (r *TaskResponse) Type() string { + return "task" +} + +func (r *TaskResponse) TimeCost() float64 { + return r.Duration +} + +func (r *TaskResponse) SetOutputFormat(format OutputFormat) { + r.OutputFormat = format +} + +func (r *TaskResponse) PrintOut() { + if r.Code == 0 { + segmentsRaw := r.Data["segments"].([]interface{}) + segments := make([]map[string]interface{}, len(segmentsRaw)) + for i, v := range segmentsRaw { + segments[i] = v.(map[string]interface{}) + } + PrintTableSimpleByFormat(segments, r.OutputFormat) + } else { + fmt.Println("ERROR") + fmt.Printf("%d, %s\n", r.Code, r.Message) + } +} + // ==================== ContextEngine Commands ==================== // ContextListResponse represents the response for ls command @@ -325,9 +469,9 @@ func (r *ContextSearchResponse) PrintOut() { // ContextCatResponse represents the response for cat command type ContextCatResponse struct { - Code int `json:"code"` - Content string `json:"content"` - Message string `json:"message"` + Code int `json:"code"` + Content string `json:"content"` + Message string `json:"message"` Duration float64 OutputFormat OutputFormat } @@ -343,5 +487,3 @@ func (r *ContextCatResponse) PrintOut() { fmt.Printf("%d, %s\n", r.Code, r.Message) } } - - diff --git a/internal/cli/types.go b/internal/cli/types.go index a30f26c6ad8..3f3ef274259 100644 --- a/internal/cli/types.go +++ b/internal/cli/types.go @@ -102,9 +102,14 @@ const ( TokenASR TokenTTS TokenOCR + TokenDocParse TokenEmbed TokenText TokenQuery + TokenFormat + TokenParam + TokenPlay + TokenSave TokenTop TokenDimension TokenAsync @@ -145,9 +150,11 @@ const ( TokenChunk TokenChunks TokenDocument + TokenDocuments TokenTag TokenRegion TokenURL + TokenTask TokenTasks TokenLog TokenLevel diff --git a/internal/cli/user_command.go b/internal/cli/user_command.go index a8394e40a64..4960dbbb6c9 100644 --- a/internal/cli/user_command.go +++ b/internal/cli/user_command.go @@ -27,6 +27,8 @@ import ( "net" netUrl "net/url" "os" + "os/exec" + "path/filepath" ce "ragflow/internal/cli/filesystem" "strings" "time" @@ -389,7 +391,63 @@ func (c *RAGFlowClient) ListDatasets(cmd *Command) (ResponseIf, error) { var result CommonResponse if err = json.Unmarshal(resp.Body, &result); err != nil { - return nil, fmt.Errorf("list users failed: invalid JSON (%w)", err) + return nil, fmt.Errorf("list datasets failed: invalid JSON (%w)", err) + } + + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + + return &result, nil +} + +// ListDatasetDocumentUserCommand lists dataset documents +func (c *RAGFlowClient) ListDatasetDocumentUserCommand(cmd *Command) (ResponseIf, error) { + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + // Check for benchmark iterations + iterations := 1 + if val, ok := cmd.Params["iterations"].(int); ok && val > 1 { + iterations = val + } + + // Determine auth kind based on whether API token is being used + if c.HTTPClient.LoginToken == "" && !c.HTTPClient.useAPIToken { + return nil, fmt.Errorf("no authorization") + } + + datasetID, ok := cmd.Params["dataset_id"].(string) + if !ok { + return nil, fmt.Errorf("no dataset id") + } + + page := 1 + pageSize := 10 + keywords := "" + returnEmptyMetadata := "true" + url := fmt.Sprintf("/datasets/%s/documents?page=%d&page_size=%d&keywords=%s&return_empty_metadata=%s", datasetID, page, pageSize, keywords, returnEmptyMetadata) + + if iterations > 1 { + // Benchmark mode - return raw result for benchmark stats + return c.HTTPClient.RequestWithIterations("GET", url, "web", nil, nil, iterations) + } + + // Normal mode + resp, err := c.HTTPClient.Request("GET", url, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list documents: %w", err) + } + + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list documents: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var result ListDocumentsResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list documents failed: invalid JSON (%w)", err) } if result.Code != 0 { @@ -1622,16 +1680,35 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) { } } - //audios, ok := cmd.Params["audios"].([]string) - //if !ok { - // return nil, fmt.Errorf("images not provided") - //} + audios, ok := cmd.Params["audios"].([]string) + if !ok { + return nil, fmt.Errorf("images not provided") + } + if len(audios) > 0 { + if len(audios) != 1 { + return nil, fmt.Errorf("only one audio file is supported") + } + audioFile := audios[0] + audioContent, err := os.ReadFile(audioFile) + if err != nil { + return nil, fmt.Errorf("failed to read audio: %w", err) + } + // file type: wav or mp3 + format := filepath.Ext(audioFile) // file type: wav or mp3 + format = strings.TrimPrefix(format, ".") + contents = append(contents, map[string]interface{}{ + "type": "input_audio", + "input_audio": map[string]interface{}{ + "data": base64.StdEncoding.EncodeToString(audioContent), + "format": format, + }, + }) + } files, ok := cmd.Params["files"].([]string) if !ok { return nil, fmt.Errorf("images not provided") } - if len(files) > 0 { for _, file := range files { if isValidURL(file) { @@ -1660,21 +1737,6 @@ func (c *RAGFlowClient) ChatToModel(cmd *Command) (ResponseIf, error) { url := "/chat/completions" - //message = strings.TrimSpace(message) - //var content interface{} = message - //if strings.HasPrefix(message, "[") && strings.HasSuffix(message, "]") { - // var parts []map[string]interface{} - // if err := json.Unmarshal([]byte(message), &parts); err == nil { - // content = parts - // } - //} - //formattedMessage := []map[string]interface{}{ - // { - // "role": "user", - // "content": content, - // }, - //} - payload := map[string]interface{}{ "provider_name": providerName, "instance_name": instanceName, @@ -1838,7 +1900,7 @@ func (c *RAGFlowClient) EmbedUserText(cmd *Command) (ResponseIf, error) { if resp.StatusCode != 200 { return nil, fmt.Errorf("failed to embed text: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) } - var result CommonResponse + var result EmbeddingsResponse if err = json.Unmarshal(resp.Body, &result); err != nil { return nil, fmt.Errorf("embed text failed: invalid JSON (%w)", err) } @@ -1922,6 +1984,545 @@ func (c *RAGFlowClient) RerankUserDocument(cmd *Command) (ResponseIf, error) { return &result, nil } +func (c *RAGFlowClient) TTSUserCommand(cmd *Command) (ResponseIf, error) { + if c.HTTPClient.APIToken == "" && c.HTTPClient.LoginToken == "" { + return nil, fmt.Errorf("API token not set. Please login first") + } + + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + var providerName, instanceName, modelName string + + // Check if composite_model_name is provided in command + if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { + names := strings.Split(compositeModelName, "@") + if len(names) != 3 { + return nil, fmt.Errorf("model name must be in format 'model@instance@provider'") + } + providerName = names[2] + instanceName = names[1] + modelName = names[0] + } else if c.CurrentModel != nil { + // Use current model if set + providerName = c.CurrentModel.Provider + instanceName = c.CurrentModel.Instance + modelName = c.CurrentModel.Model + } else { + return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first") + } + + text, ok := cmd.Params["text"].(string) + if !ok { + return nil, fmt.Errorf("text not provided") + } + + //fileToSave, ok := cmd.Params["file"].(string) + //if !ok { + // return nil, fmt.Errorf("file not provided") + //} + + payload := map[string]interface{}{ + "provider_name": providerName, + "instance_name": instanceName, + "model_name": modelName, + "text": text, + } + + ttsConfigPayload := make(map[string]interface{}) + + explicitFormat, hasExplicitFormat := cmd.Params["format"].(string) + + if paramStr, ok := cmd.Params["param_str"].(string); ok && paramStr != "" { + var dynamicParams map[string]interface{} + if err := json.Unmarshal([]byte(paramStr), &dynamicParams); err != nil { + return nil, fmt.Errorf("param string must be valid JSON. Error: %w", err) + } + + ttsConfigPayload["params"] = dynamicParams + + if !hasExplicitFormat { + var findFormat func(map[string]interface{}) string + findFormat = func(m map[string]interface{}) string { + if val, ok := m["format"]; ok { + return fmt.Sprintf("%v", val) + } + if val, ok := m["response_format"]; ok { + return fmt.Sprintf("%v", val) + } + for _, v := range m { + if subMap, ok := v.(map[string]interface{}); ok { + if res := findFormat(subMap); res != "" { + return res + } + } + } + return "" + } + if ext := findFormat(dynamicParams); ext != "" { + explicitFormat = ext + } + } + } + + if explicitFormat != "" { + ttsConfigPayload["format"] = explicitFormat + } else { + ttsConfigPayload["format"] = "mp3" + } + + if len(ttsConfigPayload) > 0 { + payload["tts_config"] = ttsConfigPayload + } + + url := "/audio/speech" + + resp, err := c.HTTPClient.Request("POST", url, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to TTS document: %w", err) + } + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to TTS document: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + + var ttsResult struct { + Code int `json:"code"` + Message string `json:"message"` + Data struct { + Audio string `json:"audio"` + } `json:"data"` + } + + if err = json.Unmarshal(resp.Body, &ttsResult); err != nil { + return nil, fmt.Errorf("TTS document failed: invalid JSON (%w)", err) + } + + if ttsResult.Code != 0 { + return nil, fmt.Errorf("%s", ttsResult.Message) + } + + // Convert Base64 back to the original audio byte stream + audioBytes, err := base64.StdEncoding.DecodeString(ttsResult.Data.Audio) + if err != nil { + return nil, fmt.Errorf("failed to decode audio base64: %w", err) + } + + shouldPlay, _ := cmd.Params["play"].(bool) + shouldSave, _ := cmd.Params["save"].(bool) + saveDir, _ := cmd.Params["save_path"].(string) + + // format file name + safeModelName := strings.ReplaceAll(modelName, "/", "_") + safeModelName = strings.ReplaceAll(safeModelName, ":", "-") + fileName := fmt.Sprintf("%s_output.%s", safeModelName, explicitFormat) + + cwd, err := os.Getwd() + if err != nil { + cwd = "." + } + localPath := filepath.Join(cwd, fileName) + + if err := os.WriteFile(localPath, audioBytes, 0644); err != nil { + return nil, fmt.Errorf("failed to write local audio file: %w", err) + } + + if shouldPlay { + cmdExec := exec.Command("aplay", localPath) + if err := cmdExec.Run(); err != nil { + fmt.Printf("Play error: %v (Hint: did you use 'format: wav' in your params?)\n", err) + } + } + + var finalMessage string + if shouldSave { + if saveDir == "" { + saveDir = cwd + } else { + absSaveDir, err := filepath.Abs(saveDir) + if err == nil { + saveDir = absSaveDir + } + + if err := os.MkdirAll(saveDir, 0755); err != nil { + return nil, fmt.Errorf("failed to create save directory: %w", err) + } + + finalPath := filepath.Join(saveDir, fileName) + if err := os.WriteFile(finalPath, audioBytes, 0644); err != nil { + return nil, fmt.Errorf("failed to save file to target directory: %w", err) + } + + if saveDir != cwd { + os.Remove(localPath) + } + + finalMessage = fmt.Sprintf("Saved to directory: %s", finalPath) + } + } else { + defer os.Remove(localPath) + finalMessage = "TTS Task Completed (Audio not saved)" + } + + if finalMessage != "" && shouldSave { + fmt.Println(finalMessage) + } + + var result SimpleResponse + result.Code = 0 + result.Message = "SUCCESS" + result.Duration = resp.Duration + + return &result, nil +} + +func (c *RAGFlowClient) ASRUserCommand(cmd *Command) (ResponseIf, error) { + if c.HTTPClient.APIToken == "" && c.HTTPClient.LoginToken == "" { + return nil, fmt.Errorf("API token not set. Please login first") + } + + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + var providerName, instanceName, modelName string + + // Check if composite_model_name is provided in command + if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { + names := strings.Split(compositeModelName, "@") + if len(names) != 3 { + return nil, fmt.Errorf("model name must be in format 'model@instance@provider'") + } + providerName = names[2] + instanceName = names[1] + modelName = names[0] + } else if c.CurrentModel != nil { + // Use current model if set + providerName = c.CurrentModel.Provider + instanceName = c.CurrentModel.Instance + modelName = c.CurrentModel.Model + } else { + return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first") + } + + audioFile, ok := cmd.Params["audio_file"].(string) + if !ok { + return nil, fmt.Errorf("audio file not provided") + } + + payload := map[string]interface{}{ + "provider_name": providerName, + "instance_name": instanceName, + "model_name": modelName, + "file": audioFile, + } + + asrConfigPayload := make(map[string]interface{}) + if paramStr, ok := cmd.Params["param_str"].(string); ok && paramStr != "" { + var dynamicParams map[string]interface{} + if err := json.Unmarshal([]byte(paramStr), &dynamicParams); err != nil { + return nil, fmt.Errorf("param string must be valid JSON. Error: %w", err) + } + asrConfigPayload["params"] = dynamicParams + } + + if len(asrConfigPayload) > 0 { + payload["asr_config"] = asrConfigPayload + } + + url := "/audio/transcriptions" + + resp, err := c.HTTPClient.Request("POST", url, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to ASR document: %w", err) + } + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to ASR document: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + var rawResult struct { + Code int `json:"code"` + Message string `json:"message"` + Data map[string]interface{} `json:"data"` + } + + if err = json.Unmarshal(resp.Body, &rawResult); err != nil { + return nil, fmt.Errorf("ASR document failed: invalid JSON (%w)", err) + } + + if rawResult.Code != 0 { + return nil, fmt.Errorf("%s", rawResult.Message) + } + + var result CommonResponse + result.Code = rawResult.Code + result.Data = []map[string]interface{}{ + {"text": rawResult.Data["text"].(string)}, + } + result.Duration = resp.Duration + + return &result, nil +} + +func (c *RAGFlowClient) OCRUserCommand(cmd *Command) (ResponseIf, error) { + if c.HTTPClient.APIToken == "" && c.HTTPClient.LoginToken == "" { + return nil, fmt.Errorf("API token not set. Please login first") + } + + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + var providerName, instanceName, modelName string + + // Check if composite_model_name is provided in command + if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { + names := strings.Split(compositeModelName, "@") + if len(names) != 3 { + return nil, fmt.Errorf("model name must be in format 'model@instance@provider'") + } + providerName = names[2] + instanceName = names[1] + modelName = names[0] + } else if c.CurrentModel != nil { + // Use current model if set + providerName = c.CurrentModel.Provider + instanceName = c.CurrentModel.Instance + modelName = c.CurrentModel.Model + } else { + return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first") + } + + var filename string + var fileURL string + var ok bool + var fileContent []byte + + filename, ok = cmd.Params["file"].(string) + if ok { + // read file and convert to base64 + var err error + fileContent, err = os.ReadFile(filename) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + } else { + fileURL, ok = cmd.Params["url"].(string) + if !ok { + return nil, fmt.Errorf("file or url not provided") + } + } + + payload := map[string]interface{}{ + "provider_name": providerName, + "instance_name": instanceName, + "model_name": modelName, + } + + if fileContent != nil { + payload["content"] = fileContent + } else { + payload["url"] = fileURL + } + + url := "/file/ocr" + + resp, err := c.HTTPClient.Request("POST", url, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to OCR document: %w", err) + } + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to OCR document: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + var result CommonDataResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("OCR document failed: invalid JSON (%w)", err) + } + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + + return &result, nil +} + +func (c *RAGFlowClient) ParseFileUserCommand(cmd *Command) (ResponseIf, error) { + if c.HTTPClient.APIToken == "" && c.HTTPClient.LoginToken == "" { + return nil, fmt.Errorf("API token not set. Please login first") + } + + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + var providerName, instanceName, modelName string + + // Check if composite_model_name is provided in command + if compositeModelName, ok := cmd.Params["composite_model_name"].(string); ok && compositeModelName != "" { + names := strings.Split(compositeModelName, "@") + if len(names) != 3 { + return nil, fmt.Errorf("model name must be in format 'model@instance@provider'") + } + providerName = names[2] + instanceName = names[1] + modelName = names[0] + } else if c.CurrentModel != nil { + // Use current model if set + providerName = c.CurrentModel.Provider + instanceName = c.CurrentModel.Instance + modelName = c.CurrentModel.Model + } else { + return nil, fmt.Errorf("model name not provided and no current model set. Use 'use model' command first") + } + + var filename string + var fileURL string + var ok bool + var fileContent []byte + + filename, ok = cmd.Params["file"].(string) + if ok { + // For online file + if strings.HasPrefix(filename, "http://") || strings.HasPrefix(filename, "https://") { + fileURL = filename + } else { + // read file and convert to base64 + var err error + fileContent, err = os.ReadFile(filename) + if err != nil { + return nil, fmt.Errorf("failed to read file: %w", err) + } + } + } else { + fileURL, ok = cmd.Params["url"].(string) + if !ok { + return nil, fmt.Errorf("file or url not provided") + } + } + + payload := map[string]interface{}{ + "provider_name": providerName, + "instance_name": instanceName, + "model_name": modelName, + } + + if fileContent != nil { + payload["content"] = fileContent + } else { + payload["url"] = fileURL + } + + url := "/file/parse" + + resp, err := c.HTTPClient.Request("POST", url, "web", nil, payload) + if err != nil { + return nil, fmt.Errorf("failed to PARSE document: %w", err) + } + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to PARSE document: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + var result CommonDataResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("PARSE document failed: invalid JSON (%w)", err) + } + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + + return &result, nil +} + +func (c *RAGFlowClient) ListTasksUserCommand(cmd *Command) (ResponseIf, error) { + if c.HTTPClient.APIToken == "" && c.HTTPClient.LoginToken == "" { + return nil, fmt.Errorf("API token not set. Please login first") + } + + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + var providerName, instanceName string + + // Check if composite_instance_name is provided in command + if compositeModelName, ok := cmd.Params["composite_instance_name"].(string); ok && compositeModelName != "" { + names := strings.Split(compositeModelName, "@") + if len(names) != 2 { + return nil, fmt.Errorf("model name must be in format 'instance@provider'") + } + providerName = names[1] + instanceName = names[0] + } else { + return nil, fmt.Errorf("no provider name or instance name") + } + + url := fmt.Sprintf("/providers/%s/instances/%s/tasks", providerName, instanceName) + + resp, err := c.HTTPClient.Request("GET", url, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to list tasks: %w", err) + } + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to list tasks: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + var result CommonResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("list tasks failed: invalid JSON (%w)", err) + } + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + +func (c *RAGFlowClient) ShowTaskUserCommand(cmd *Command) (ResponseIf, error) { + if c.HTTPClient.APIToken == "" && c.HTTPClient.LoginToken == "" { + return nil, fmt.Errorf("API token not set. Please login first") + } + + if c.ServerType != "user" { + return nil, fmt.Errorf("this command is only allowed in USER mode") + } + + var providerName, instanceName string + + // Check if composite_instance_name is provided in command + if compositeModelName, ok := cmd.Params["composite_instance_name"].(string); ok && compositeModelName != "" { + names := strings.Split(compositeModelName, "@") + if len(names) != 2 { + return nil, fmt.Errorf("model name must be in format 'instance@provider'") + } + providerName = names[1] + instanceName = names[0] + } else { + return nil, fmt.Errorf("no provider name or instance name") + } + + taskID, ok := cmd.Params["task_id"].(string) + if !ok { + return nil, fmt.Errorf("task id not provided") + } + + url := fmt.Sprintf("/providers/%s/instances/%s/tasks/%s", providerName, instanceName, taskID) + + resp, err := c.HTTPClient.Request("GET", url, "web", nil, nil) + if err != nil { + return nil, fmt.Errorf("failed to get task: %w", err) + } + if resp.StatusCode != 200 { + return nil, fmt.Errorf("failed to get task: HTTP %d, body: %s", resp.StatusCode, string(resp.Body)) + } + var result TaskResponse + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("get task failed: invalid JSON (%w)", err) + } + if result.Code != 0 { + return nil, fmt.Errorf("%s", result.Message) + } + result.Duration = resp.Duration + return &result, nil +} + func (c *RAGFlowClient) CheckProviderConnection(cmd *Command) (ResponseIf, error) { if c.HTTPClient.APIToken == "" && c.HTTPClient.LoginToken == "" { return nil, fmt.Errorf("API token not set. Please login first") diff --git a/internal/cli/user_parser.go b/internal/cli/user_parser.go index c49eeee11a9..64357abeeac 100644 --- a/internal/cli/user_parser.go +++ b/internal/cli/user_parser.go @@ -136,6 +136,8 @@ func (p *Parser) parseListCommand() (*Command, error) { return NewCommand("list_environments"), nil case TokenDatasets: return p.parseListDatasets() + case TokenDocuments: + return p.parseListDatasetDocuments() case TokenAgents: return p.parseListAgents() case TokenTokens: @@ -163,6 +165,8 @@ func (p *Parser) parseListCommand() (*Command, error) { return NewCommand("list_user_chats"), nil case TokenFiles: return p.parseListFiles() + case TokenQuotedString: + return p.parseListQuotedStringCommand() default: return nil, fmt.Errorf("unknown LIST target: %s", p.curToken.Value) } @@ -179,6 +183,31 @@ func (p *Parser) parseListDatasets() (*Command, error) { return cmd, nil } +func (p *Parser) parseListDatasetDocuments() (*Command, error) { + p.nextToken() // consume DOCUMENTS + + if p.curToken.Type != TokenFrom { + return nil, fmt.Errorf("expected FROM") + } + p.nextToken() + + datasetID, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() + + cmd := NewCommand("list_dataset_documents") + cmd.Params["dataset_id"] = datasetID + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + + return cmd, nil +} + func (p *Parser) parseListAgents() (*Command, error) { p.nextToken() // consume AGENTS @@ -280,9 +309,57 @@ func (p *Parser) parseListFiles() (*Command, error) { return cmd, nil } +func (p *Parser) parseListQuotedStringCommand() (*Command, error) { + str, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() // consume str + switch p.curToken.Type { + case TokenTasks: + p.nextToken() // consume TASKS + cmd := NewCommand("list_tasks_user_command") + cmd.Params["composite_instance_name"] = str + return cmd, nil + default: + return nil, fmt.Errorf("unknown command: %s", str) + } +} + +func (p *Parser) parseShowQuotedStringCommand() (*Command, error) { + str, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() // consume str + switch p.curToken.Type { + case TokenTask: + p.nextToken() // consume TASK + + var taskID string + taskID, err = p.parseQuotedString() + if err != nil { + return nil, fmt.Errorf("expected string: %w", err) + } + p.nextToken() + + cmd := NewCommand("show_task_user_command") + cmd.Params["task_id"] = taskID + cmd.Params["composite_instance_name"] = str + p.nextToken() + + // Semicolon is optional + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil + default: + return nil, fmt.Errorf("unknown command: %s", str) + } +} + func (p *Parser) parseShowCommand() (*Command, error) { p.nextToken() // consume SHOW - switch p.curToken.Type { case TokenVersion: p.nextToken() @@ -333,6 +410,10 @@ func (p *Parser) parseShowCommand() (*Command, error) { return p.parseShowInstance() case TokenBalance: return p.parseShowBalance() + case TokenTask: + return p.parseShowTask() + case TokenQuotedString: + return p.parseShowQuotedStringCommand() default: return nil, fmt.Errorf("unknown SHOW target: %s", p.curToken.Value) } @@ -781,6 +862,9 @@ func (p *Parser) parseAddModel() (*Command, error) { case TokenOCR: p.nextToken() modelTypes = append(modelTypes, "ocr") + case TokenDocParse: + p.nextToken() + modelTypes = append(modelTypes, "doc_parse") case TokenTTS: p.nextToken() modelTypes = append(modelTypes, "tts") @@ -1454,6 +1538,27 @@ func (p *Parser) parseShowBalance() (*Command, error) { return cmd, nil } +// parseShowTask parses SHOW TASK +func (p *Parser) parseShowTask() (*Command, error) { + p.nextToken() // consume TASK + + taskID, err := p.parseQuotedString() + if err != nil { + return nil, fmt.Errorf("expected string: %w", err) + } + p.nextToken() + + cmd := NewCommand("show_task_user_command") + cmd.Params["task_id"] = taskID + p.nextToken() + + // Semicolon is optional + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + return cmd, nil +} + // parseAlterInstance parses ALTER INSTANCE NAME FROM PROVIDER command func (p *Parser) parseAlterInstance() (*Command, error) { p.nextToken() // consume INSTANCE @@ -1773,7 +1878,7 @@ func (p *Parser) parseSetVariable() (*Command, error) { } p.nextToken() - varValue, err := p.parseIdentifier() + varValue, err := p.parseVariableValue() if err != nil { return nil, err } @@ -2587,16 +2692,29 @@ func (p *Parser) parseStreamCommand() (*Command, error) { var command *Command var err error - if p.curToken.Type == TokenChat { + switch p.curToken.Type { + case TokenChat: command, err = p.parseChatCommand() if err != nil { return nil, err } - } else if p.curToken.Type == TokenThink { + case TokenThink: command, err = p.parseThinkCommand() if err != nil { return nil, err } + case TokenASR: + command, err = p.parseASRCommand() + if err != nil { + return nil, err + } + case TokenTTS: + command, err = p.parseTTSCommand() + if err != nil { + return nil, err + } + default: + return nil, fmt.Errorf("expected CHAT, THINK, ASR, or TTS after STREAM") } command.Params["stream"] = true @@ -2723,6 +2841,209 @@ documentLoop: return cmd, nil } +func (p *Parser) parseASRCommand() (*Command, error) { + p.nextToken() // consume ASR + + if p.curToken.Type != TokenWith { + return nil, fmt.Errorf("expected WITH after ASR") + } + p.nextToken() // consume WITH + + compositeModelName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() + + if p.curToken.Type != TokenAudio { + return nil, fmt.Errorf("expected AUDIO to ASR") + } + p.nextToken() // consume AUDIO + + audioFile, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() + + cmd := NewCommand("asr_user_command") + cmd.Params["composite_model_name"] = compositeModelName + cmd.Params["audio_file"] = audioFile + + for p.curToken.Type != TokenEOF && p.curToken.Type != TokenSemicolon { + switch p.curToken.Type { + case TokenParam: + p.nextToken() + if p.curToken.Type != TokenQuotedString { + return nil, fmt.Errorf("expect quoted string after 'param'") + } + cmd.Params["param_str"] = strings.Trim(p.curToken.Value, "\"'") + p.nextToken() + default: + return nil, fmt.Errorf("unexpected token in asr command: %s", p.curToken.Value) + } + } + + // Semicolon is optional for UNSET TOKEN + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + + return cmd, nil +} + +func (p *Parser) parseTTSCommand() (*Command, error) { + p.nextToken() + + cmd := NewCommand("tts_user_command") + + if p.curToken.Type != TokenWith { + return nil, fmt.Errorf("expect 'with' after tts") + } + p.nextToken() + + if p.curToken.Type != TokenQuotedString && p.curToken.Type != TokenIdentifier { + return nil, fmt.Errorf("expect model name after 'with'") + } + cmd.Params["composite_model_name"] = strings.Trim(p.curToken.Value, "\"'") + p.nextToken() + + if p.curToken.Type != TokenText { + return nil, fmt.Errorf("expect 'text' parameter") + } + p.nextToken() + + if p.curToken.Type != TokenQuotedString { + return nil, fmt.Errorf("expect quoted string after 'text'") + } + cmd.Params["text"] = strings.Trim(p.curToken.Value, "\"'") + p.nextToken() + + for p.curToken.Type != TokenEOF && p.curToken.Type != TokenSemicolon { + switch p.curToken.Type { + case TokenPlay: + p.nextToken() + cmd.Params["play"] = true + case TokenParam: + p.nextToken() + if p.curToken.Type != TokenQuotedString { + return nil, fmt.Errorf("expect quoted string after 'param'") + } + cmd.Params["param_str"] = strings.Trim(p.curToken.Value, "\"'") + p.nextToken() + p.nextToken() + case TokenSave: + p.nextToken() + + if p.curToken.Type != TokenQuotedString && p.curToken.Type != TokenIdentifier { + return nil, fmt.Errorf("expect directory path after 'save'") + } + + cmd.Params["save"] = true + cmd.Params["save_path"] = strings.Trim(p.curToken.Value, "\"'") + p.nextToken() + case TokenFormat: + p.nextToken() + if p.curToken.Type != TokenQuotedString && p.curToken.Type != TokenIdentifier { + return nil, fmt.Errorf("expect format string (e.g. 'wav') after 'format'") + } + cmd.Params["format"] = strings.Trim(p.curToken.Value, "\"'") + p.nextToken() + default: + return nil, fmt.Errorf("unexpected token: %s", p.curToken.Value) + } + } + + if p.curToken.Type == TokenSemicolon { + p.nextToken() + } + + return cmd, nil +} + +func (p *Parser) parseOCRCommand() (*Command, error) { + p.nextToken() // consume OCR + + if p.curToken.Type != TokenWith { + return nil, fmt.Errorf("expected WITH after OCR") + } + p.nextToken() // consume WITH + + compositeModelName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() + + cmd := NewCommand("ocr_user_command") + + switch p.curToken.Type { + case TokenFile: + p.nextToken() + var file string + file, err = p.parseQuotedString() + if err != nil { + return nil, err + } + cmd.Params["file"] = file + p.nextToken() + case TokenURL: + p.nextToken() + var url string + url, err = p.parseQuotedString() + if err != nil { + return nil, err + } + cmd.Params["url"] = url + p.nextToken() + default: + return nil, fmt.Errorf("expected FILE or URL") + } + + cmd.Params["composite_model_name"] = compositeModelName + + return cmd, nil +} + +func (p *Parser) parseModelParseCommand() (*Command, error) { + p.nextToken() // consume WITH + + compositeModelName, err := p.parseQuotedString() + if err != nil { + return nil, err + } + p.nextToken() + + cmd := NewCommand("parse_file_user_command") + + switch p.curToken.Type { + case TokenFile: + p.nextToken() + var file string + file, err = p.parseQuotedString() + if err != nil { + return nil, err + } + cmd.Params["file"] = file + p.nextToken() + case TokenURL: + p.nextToken() + var url string + url, err = p.parseQuotedString() + if err != nil { + return nil, err + } + cmd.Params["url"] = url + p.nextToken() + default: + return nil, fmt.Errorf("expected FILE or URL") + } + + cmd.Params["composite_model_name"] = compositeModelName + + return cmd, nil +} + func (p *Parser) parseCheckCommand() (*Command, error) { p.nextToken() // consume CHECK @@ -2787,11 +3108,14 @@ func (p *Parser) parseUseCommand() (*Command, error) { func (p *Parser) parseParseCommand() (*Command, error) { p.nextToken() // consume PARSE - if p.curToken.Type == TokenDataset { + switch p.curToken.Type { + case TokenDataset: return p.parseParseDataset() + case TokenWith: + return p.parseModelParseCommand() + default: + return p.parseParseDocs() } - - return p.parseParseDocs() } func (p *Parser) parseParseDataset() (*Command, error) { diff --git a/internal/cpp/opencc/config_reader.c b/internal/cpp/opencc/config_reader.c index 06f191e75b0..8271ff48c06 100644 --- a/internal/cpp/opencc/config_reader.c +++ b/internal/cpp/opencc/config_reader.c @@ -170,8 +170,9 @@ static char *parse_trim(char *str) { static int parse(config_desc *config, const char *filename, const char *home_path) { FILE *fp = fopen(filename, "rb"); if (!fp) { - char *pkg_filename = (char *)malloc(sizeof(char) * (strlen(filename) + strlen(home_path) + 2)); - sprintf(pkg_filename, "%s/%s", home_path, filename); + size_t pkg_filename_len = strlen(filename) + strlen(home_path) + 2; + char *pkg_filename = (char *)malloc(sizeof(char) * pkg_filename_len); + snprintf(pkg_filename, pkg_filename_len, "%s/%s", home_path, filename); printf("pkg_filename %s\n", pkg_filename); fp = fopen(pkg_filename, "rb"); if (!fp) { @@ -182,12 +183,26 @@ static int parse(config_desc *config, const char *filename, const char *home_pat free(pkg_filename); } - config->home_dir = (char *)malloc(sizeof(char) * (strlen(home_path) + 1)); - sprintf(config->home_dir, "%s", home_path); + size_t home_dir_len = strlen(home_path) + 1; + config->home_dir = (char *)malloc(sizeof(char) * home_dir_len); + snprintf(config->home_dir, home_dir_len, "%s", home_path); - static char buff[BUFFER_SIZE]; + char buff[BUFFER_SIZE]; while (fgets(buff, BUFFER_SIZE, fp) != NULL) { + /* Detect line truncation: if buffer is full and last char is not newline, + * the line was longer than BUFFER_SIZE-1 bytes. Drain the remainder and + * treat this as a parse error to avoid processing partial config lines. */ + size_t buff_len = strlen(buff); + if (buff_len == BUFFER_SIZE - 1 && buff[buff_len - 1] != '\n') { + int c; + while ((c = fgetc(fp)) != '\n' && c != EOF) + ; + fclose(fp); + errnum = CONFIG_ERROR_PARSE; + return -1; + } + char *trimed_buff = parse_trim(buff); if (*trimed_buff == ';' || *trimed_buff == '#' || *trimed_buff == '\0') { /* Comment Line or empty line */ diff --git a/internal/cpp/opencc/dictionary/text.c b/internal/cpp/opencc/dictionary/text.c index 41bcdbb45af..84a65167a93 100644 --- a/internal/cpp/opencc/dictionary/text.c +++ b/internal/cpp/opencc/dictionary/text.c @@ -20,7 +20,7 @@ #include "../encoding.h" #define INITIAL_DICTIONARY_SIZE 1024 -#define ENTRY_BUFF_SIZE 128 +#define ENTRY_BUFF_SIZE 4096 #define ENTRY_WBUFF_SIZE ENTRY_BUFF_SIZE / sizeof(size_t) struct _text_dictionary { @@ -69,10 +69,14 @@ int parse_entry(const char *buff, entry *entry_i) { if (ucs4_buff == (ucs4_t *)-1) { /* 發生錯誤 回退內存申請 */ ssize_t i; - for (i = value_i - 1; i >= 0; --i) + for (i = value_i - 1; i >= 0; --i) { free(entry_i->value[i]); + entry_i->value[i] = NULL; + } free(entry_i->value); + entry_i->value = NULL; free(entry_i->key); + entry_i->key = NULL; return -1; } @@ -95,7 +99,7 @@ dictionary_t dictionary_text_open(const char *filename) { text_dictionary->lexicon = (entry *)malloc(sizeof(entry) * text_dictionary->entry_count); text_dictionary->word_buff = NULL; - static char buff[ENTRY_BUFF_SIZE]; + char buff[ENTRY_BUFF_SIZE]; FILE *fp = fopen(filename, "rb"); if (fp == NULL) { @@ -105,6 +109,17 @@ dictionary_t dictionary_text_open(const char *filename) { size_t i = 0; while (fgets(buff, ENTRY_BUFF_SIZE, fp)) { + /* Detect line truncation: if buffer is full and last char is not newline, + * the line was longer than ENTRY_BUFF_SIZE-1 bytes. Drain the remainder + * and skip this malformed entry to prevent parsing partial data. */ + size_t buff_len = strlen(buff); + if (buff_len == ENTRY_BUFF_SIZE - 1 && buff[buff_len - 1] != '\n') { + int c; + while ((c = fgetc(fp)) != '\n' && c != EOF) + ; + continue; + } + if (i >= text_dictionary->entry_count) { text_dictionary->entry_count += text_dictionary->entry_count; text_dictionary->lexicon = (entry *)realloc(text_dictionary->lexicon, sizeof(entry) * text_dictionary->entry_count); diff --git a/internal/cpp/opencc/utils.c b/internal/cpp/opencc/utils.c index 9f93aae8f3f..733b1e5d191 100644 --- a/internal/cpp/opencc/utils.c +++ b/internal/cpp/opencc/utils.c @@ -23,8 +23,10 @@ void perr(const char *str) { fputs(str, stderr); } int qsort_int_cmp(const void *a, const void *b) { return *((int *)a) - *((int *)b); } char *mstrcpy(const char *str) { - char *strbuf = (char *)malloc(sizeof(char) * (strlen(str) + 1)); - strcpy(strbuf, str); + size_t len = strlen(str); + char *strbuf = (char *)malloc(sizeof(char) * (len + 1)); + strncpy(strbuf, str, len); + strbuf[len] = '\0'; return strbuf; } diff --git a/internal/dao/connector.go b/internal/dao/connector.go index 2f18e00b306..260e1596a92 100644 --- a/internal/dao/connector.go +++ b/internal/dao/connector.go @@ -36,6 +36,15 @@ type ConnectorListItem struct { Status string `json:"status"` } +// ConnectorDatasetListItem represents a connector linked to a dataset. +type ConnectorDatasetListItem struct { + ID string `json:"id" gorm:"column:id"` + Source string `json:"source" gorm:"column:source"` + Name string `json:"name" gorm:"column:name"` + AutoParse string `json:"auto_parse" gorm:"column:auto_parse"` + Status string `json:"status" gorm:"column:status"` +} + // ListByTenantID list connectors by tenant ID // Only selects id, name, source, status fields (matching Python implementation) func (dao *ConnectorDAO) ListByTenantID(tenantID string) ([]*ConnectorListItem, error) { @@ -53,6 +62,23 @@ func (dao *ConnectorDAO) ListByTenantID(tenantID string) ([]*ConnectorListItem, return connectors, nil } +// ListByDatasetID lists connectors linked to a dataset. +func (dao *ConnectorDAO) ListByDatasetID(datasetID string) ([]*ConnectorDatasetListItem, error) { + var connectors []*ConnectorDatasetListItem + + err := DB.Model(&entity.Connector2Kb{}). + Select("connector.id, connector.source, connector.name, connector2kb.auto_parse, connector.status"). + Joins("JOIN connector ON connector2kb.connector_id = connector.id"). + Where("connector2kb.kb_id = ?", datasetID). + Scan(&connectors).Error + + if err != nil { + return nil, err + } + + return connectors, nil +} + // GetByID get connector by ID func (dao *ConnectorDAO) GetByID(id string) (*entity.Connector, error) { var connector entity.Connector diff --git a/internal/dao/document.go b/internal/dao/document.go index e2e055a1189..2fa2fa5b0a2 100644 --- a/internal/dao/document.go +++ b/internal/dao/document.go @@ -86,15 +86,25 @@ func (dao *DocumentDAO) List(offset, limit int) ([]*entity.Document, int64, erro } // ListByKBID list documents by knowledge base ID -func (dao *DocumentDAO) ListByKBID(kbID string, offset, limit int) ([]*entity.Document, int64, error) { - var documents []*entity.Document +func (dao *DocumentDAO) ListByKBID(kbID string, offset, limit int) ([]*entity.DocumentListItem, int64, error) { + var documents []*entity.DocumentListItem var total int64 if err := DB.Model(&entity.Document{}).Where("kb_id = ?", kbID).Count(&total).Error; err != nil { return nil, 0, err } - err := DB.Where("kb_id = ?", kbID).Offset(offset).Limit(limit).Find(&documents).Error + err := DB.Table("document"). + Select(`document.*, user_canvas.title as pipeline_name, user.nickname`). + Joins("JOIN file2document ON file2document.document_id = document.id"). + Joins("JOIN file ON file.id = file2document.file_id"). + Joins("LEFT JOIN user_canvas ON document.pipeline_id = user_canvas.id"). + Joins("LEFT JOIN user ON document.created_by = user.id"). + Where("document.kb_id = ?", kbID). + Order("document.create_time DESC"). + Offset(offset). + Limit(limit). + Scan(&documents).Error return documents, total, err } @@ -138,3 +148,13 @@ func (dao *DocumentDAO) CountByTenantID(tenantID string) (int64, error) { err := DB.Model(&entity.Document{}).Where("created_by = ?", tenantID).Count(&count).Error return count, err } + +// SumSizeByDatasetID returns the total document size for a dataset. +func (dao *DocumentDAO) SumSizeByDatasetID(datasetID string) (int64, error) { + var total int64 + err := DB.Model(&entity.Document{}). + Select("COALESCE(SUM(size), 0)"). + Where("kb_id = ?", datasetID). + Scan(&total).Error + return total, err +} diff --git a/internal/dao/kb.go b/internal/dao/kb.go index d87051d983c..e025b5e7ce2 100644 --- a/internal/dao/kb.go +++ b/internal/dao/kb.go @@ -22,7 +22,6 @@ import ( "strconv" "strings" - "time" ) // KnowledgebaseDAO knowledge base data access object @@ -314,30 +313,22 @@ func splitNameCounter(name string) (string, int) { // AtomicIncreaseDocNumByID atomically increments the document count // This matches the Python atomic_increase_doc_num_by_id method func (dao *KnowledgebaseDAO) AtomicIncreaseDocNumByID(kbID string) error { - now := time.Now().Unix() - nowDate := time.Now().Truncate(time.Second) return DB.Model(&entity.Knowledgebase{}). Where("id = ?", kbID). Updates(map[string]interface{}{ "doc_num": DB.Raw("doc_num + 1"), - "update_time": now, - "update_date": nowDate, }).Error } // DecreaseDocumentNum decreases document, chunk, and token counts // This matches the Python decrease_document_num_in_delete method func (dao *KnowledgebaseDAO) DecreaseDocumentNum(kbID string, docNum, chunkNum, tokenNum int64) error { - now := time.Now().Unix() - nowDate := time.Now().Truncate(time.Second) return DB.Model(&entity.Knowledgebase{}). Where("id = ?", kbID). Updates(map[string]interface{}{ "doc_num": DB.Raw("doc_num - ?", docNum), "chunk_num": DB.Raw("chunk_num - ?", chunkNum), "token_num": DB.Raw("token_num - ?", tokenNum), - "update_time": now, - "update_date": nowDate, }).Error } diff --git a/internal/dao/migration.go b/internal/dao/migration.go index ca5bd3d06b7..23052f66476 100644 --- a/internal/dao/migration.go +++ b/internal/dao/migration.go @@ -345,7 +345,9 @@ func migrateSkillSearchTables(db *gorm.DB) error { index_version VARCHAR(32) DEFAULT '1.0.0', status VARCHAR(1) DEFAULT '1', create_time BIGINT, - update_time DATETIME, + create_date DATETIME, + update_time BIGINT, + update_date DATETIME, INDEX idx_tenant_id (tenant_id), INDEX idx_space_id (space_id), UNIQUE INDEX idx_tenant_space_embd (tenant_id, space_id, embd_id) @@ -367,6 +369,15 @@ func migrateSkillSearchTables(db *gorm.DB) error { if err := addColumnIfNotExists(db, "skill_search_configs", "space_id", "VARCHAR(128) NOT NULL DEFAULT 'default'"); err != nil { return fmt.Errorf("failed to add space_id column to skill_search_configs: %w", err) } + if err := addColumnIfNotExists(db, "skill_search_configs", "create_date", "DATETIME"); err != nil { + return fmt.Errorf("failed to add create_date column to skill_search_configs: %w", err) + } + if err := addColumnIfNotExists(db, "skill_search_configs", "update_date", "DATETIME"); err != nil { + return fmt.Errorf("failed to add update_date column to skill_search_configs: %w", err) + } + if err := db.Exec(`ALTER TABLE skill_search_configs MODIFY COLUMN update_time BIGINT`).Error; err != nil { + common.Warn("Failed to modify skill_search_configs.update_time", zap.Error(err)) + } // Drop legacy unique index (tenant_id, embd_id) to allow per-space configs. var legacyIndexExists int64 @@ -411,7 +422,9 @@ func migrateSkillSpaceTables(db *gorm.DB) error { top_k INT DEFAULT 10, status VARCHAR(1) DEFAULT '1', create_time BIGINT, - update_time DATETIME, + create_date DATETIME, + update_time BIGINT, + update_date DATETIME, INDEX idx_tenant_id (tenant_id), UNIQUE INDEX idx_tenant_name_status (tenant_id, name, status) ) @@ -433,6 +446,15 @@ func migrateSkillSpaceTables(db *gorm.DB) error { if err := addColumnIfNotExists(db, "skill_spaces", "status", "VARCHAR(1) NOT NULL DEFAULT '1'"); err != nil { return fmt.Errorf("failed to add status column to skill_spaces: %w", err) } + if err := addColumnIfNotExists(db, "skill_spaces", "create_date", "DATETIME"); err != nil { + return fmt.Errorf("failed to add create_date column to skill_spaces: %w", err) + } + if err := addColumnIfNotExists(db, "skill_spaces", "update_date", "DATETIME"); err != nil { + return fmt.Errorf("failed to add update_date column to skill_spaces: %w", err) + } + if err := db.Exec(`ALTER TABLE skill_spaces MODIFY COLUMN update_time BIGINT`).Error; err != nil { + common.Warn("Failed to modify skill_spaces.update_time", zap.Error(err)) + } // Migrate index after status column exists if err := migrateSkillSpaceIndex(db); err != nil { return fmt.Errorf("failed to migrate skill_space index: %w", err) diff --git a/internal/dao/skill_search_config.go b/internal/dao/skill_search_config.go index 6c19964bc21..16c01d56efe 100644 --- a/internal/dao/skill_search_config.go +++ b/internal/dao/skill_search_config.go @@ -19,7 +19,6 @@ package dao import ( "ragflow/internal/entity" "strings" - "time" "github.com/google/uuid" ) @@ -109,7 +108,6 @@ func (dao *SkillSearchConfigDAO) GetOrCreate(tenantID, spaceID, embdID string) ( // CreateWithTenantSpace creates a new config for tenant+space func (dao *SkillSearchConfigDAO) CreateWithTenantSpace(tenantID, spaceID, embdID string) (*entity.SkillSearchConfig, error) { spaceID = normalizeSpaceID(spaceID) - timestamp := time.Now().UnixMilli() defaultFieldConfig := entity.DefaultFieldConfig() fieldConfigMap := entity.JSONMap{ "name": map[string]interface{}{ @@ -140,7 +138,6 @@ func (dao *SkillSearchConfigDAO) CreateWithTenantSpace(tenantID, spaceID, embdID FieldConfig: fieldConfigMap, TopK: 10, Status: "1", - CreateTime: ×tamp, } if err := dao.Create(defaultConfig); err != nil { @@ -167,20 +164,17 @@ func (dao *SkillSearchConfigDAO) DeleteAllByTenantSpaceExceptID(tenantID, spaceI // Update updates a skill search config with the given updates map func (dao *SkillSearchConfigDAO) Update(id string, updates map[string]interface{}) error { - updates["update_time"] = time.Now() return DB.Model(&entity.SkillSearchConfig{}).Where("id = ? AND status = ?", id, "1").Updates(updates).Error } // UpdateByTenantID updates config by tenant ID func (dao *SkillSearchConfigDAO) UpdateByTenantID(tenantID, spaceID string, updates map[string]interface{}) error { - updates["update_time"] = time.Now() result := DB.Model(&entity.SkillSearchConfig{}).Where("tenant_id = ? AND space_id = ? AND status = ?", tenantID, normalizeSpaceID(spaceID), "1").Updates(updates) return result.Error } // UpdateByTenantAndEmbdID updates config by tenant ID and embedding ID func (dao *SkillSearchConfigDAO) UpdateByTenantAndEmbdID(tenantID, spaceID, embdID string, updates map[string]interface{}) error { - updates["update_time"] = time.Now() result := DB.Model(&entity.SkillSearchConfig{}).Where("tenant_id = ? AND space_id = ? AND embd_id = ? AND status = ?", tenantID, normalizeSpaceID(spaceID), embdID, "1").Updates(updates) return result.Error } diff --git a/internal/dao/skill_space.go b/internal/dao/skill_space.go index 2c0596f8a33..c8557521fee 100644 --- a/internal/dao/skill_space.go +++ b/internal/dao/skill_space.go @@ -19,7 +19,6 @@ package dao import ( "ragflow/internal/entity" "strings" - "time" "github.com/google/uuid" ) @@ -101,7 +100,6 @@ func (dao *SkillSpaceDAO) Update(space *entity.SkillSpace) error { // UpdateByID updates skills space by ID func (dao *SkillSpaceDAO) UpdateByID(id string, updates map[string]interface{}) error { - updates["update_time"] = time.Now() return DB.Model(&entity.SkillSpace{}).Where("id = ?", id).Updates(updates).Error } diff --git a/internal/dao/system_settings.go b/internal/dao/system_settings.go index 2e200ac0491..dd7762e42e0 100644 --- a/internal/dao/system_settings.go +++ b/internal/dao/system_settings.go @@ -19,7 +19,6 @@ package dao import ( "errors" "ragflow/internal/entity" - "time" "gorm.io/gorm" ) @@ -36,7 +35,7 @@ func NewSystemSettingsDAO() *SystemSettingsDAO { // Returns all system settings records from database func (d *SystemSettingsDAO) GetAll() ([]entity.SystemSettings, error) { var settings []entity.SystemSettings - err := DB.Find(&settings).Error + err := DB.Order("name ASC").Find(&settings).Error if err != nil { return nil, err } @@ -47,7 +46,18 @@ func (d *SystemSettingsDAO) GetAll() ([]entity.SystemSettings, error) { // Returns settings records that match the given name func (d *SystemSettingsDAO) GetByName(name string) ([]entity.SystemSettings, error) { var settings []entity.SystemSettings - err := DB.Where("name = ?", name).Find(&settings).Error + err := DB.Where("name = ?", name).Order("name ASC").Find(&settings).Error + if err != nil { + return nil, err + } + return settings, nil +} + +// GetByNamePrefix get system settings by name prefix +// Returns settings records whose names start with the given prefix. +func (d *SystemSettingsDAO) GetByNamePrefix(namePrefix string) ([]entity.SystemSettings, error) { + var settings []entity.SystemSettings + err := DB.Where("name LIKE ?", namePrefix+"%").Order("name ASC").Find(&settings).Error if err != nil { return nil, err } @@ -57,31 +67,18 @@ func (d *SystemSettingsDAO) GetByName(name string) ([]entity.SystemSettings, err // UpdateByName update system settings by name // Updates the setting with the given name using the provided data func (d *SystemSettingsDAO) UpdateByName(name string, setting *entity.SystemSettings) error { - now := time.Now().Unix() - nowDate := time.Now().Truncate(time.Second) - return DB.Model(&entity.SystemSettings{}). Where("name = ?", name). Updates(map[string]interface{}{ - "value": setting.Value, - "source": setting.Source, - "data_type": setting.DataType, - "update_time": now, - "update_date": nowDate, + "value": setting.Value, + "source": setting.Source, + "data_type": setting.DataType, }).Error } // Create create a new system setting // Inserts a new system setting record into database func (d *SystemSettingsDAO) Create(setting *entity.SystemSettings) error { - now := time.Now().Unix() - nowDate := time.Now().Truncate(time.Second) - - setting.CreateTime = &now - setting.CreateDate = &nowDate - setting.UpdateTime = &now - setting.UpdateDate = &nowDate - return DB.Create(setting).Error } @@ -96,6 +93,8 @@ func (d *SystemSettingsDAO) SaveOrCreate(name string, value string, source strin if len(settings) == 1 { setting := &settings[0] setting.Value = value + setting.Source = source + setting.DataType = dataType return d.UpdateByName(name, setting) } else if len(settings) > 1 { return errors.New("can't update more than 1 setting: " + name) @@ -159,29 +158,16 @@ func (d *SystemSettingsDAO) Transaction(fn func(tx *gorm.DB) error) error { // CreateWithTx create setting within transaction func (d *SystemSettingsDAO) CreateWithTx(tx *gorm.DB, setting *entity.SystemSettings) error { - now := time.Now().Unix() - nowDate := time.Now().Truncate(time.Second) - - setting.CreateTime = &now - setting.CreateDate = &nowDate - setting.UpdateTime = &now - setting.UpdateDate = &nowDate - return tx.Create(setting).Error } // UpdateByNameWithTx update setting within transaction func (d *SystemSettingsDAO) UpdateByNameWithTx(tx *gorm.DB, name string, setting *entity.SystemSettings) error { - now := time.Now().Unix() - nowDate := time.Now().Truncate(time.Second) - return tx.Model(&entity.SystemSettings{}). Where("name = ?", name). Updates(map[string]interface{}{ - "value": setting.Value, - "source": setting.Source, - "data_type": setting.DataType, - "update_time": now, - "update_date": nowDate, + "value": setting.Value, + "source": setting.Source, + "data_type": setting.DataType, }).Error } diff --git a/internal/development.md b/internal/development.md new file mode 100644 index 00000000000..ae2758cb51f --- /dev/null +++ b/internal/development.md @@ -0,0 +1,385 @@ +# RAGFlow Go Version - Startup Guide + +## 1. Start Dependencies + +```bash +docker compose -f docker/docker-compose-base.yml up -d +``` + +## 2. Build Go Version RAGFlow +- First build (includes C++ dependencies): + +```bash +./build.sh --cpp +``` + +- Subsequent builds (Go only): + +```bash +./build.sh --go +``` + +## 3. Run Go Version RAGFlow +Note: admin_server must be started first; otherwise, ragflow_server will encounter errors when sending heartbeats. + +```bash +# Start admin server +./bin/admin_server +``` + +```bash +# Start RAGFlow server +./bin/ragflow_server +``` +```bash +# Run CLI +./bin/ragflow_cli +``` + +## 4. Start Frontend +```bash +cd web && export API_PROXY_SCHEME=hybrid && npm run dev +``` + +## 5. Service Ports & API Routing +- ragflow_server listens on port 9384 +- admin_server listens on port 9383 + +After updating or implementing an API, update the frontend development environment routes in web/vite.config.ts under proxySchemes. + +### Proxy Schemes + +| Scheme | Description | +|--------|-------------| +| `python` | All API requests from the frontend are routed to the Python server | +| `hybrid` | API requests are partially routed to the Go server and partially to the Python server | +| `go` | All API requests from the frontend are routed to the Go server | + + +## 6. RAGFlow commands + +You can use the following CLI commands to test the corresponding API implementations. + +### 6.1. Run ragflow_cli, register user, login, and logout: + +``` +$ ./ragflow_cli +Welcome to RAGFlow CLI +Type \? for help, \q to quit + +RAGFlow(user)> REGISTER USER 'aaa@aaa.com' AS 'aaa' PASSWORD 'aaa'; +Register successfully +RAGFlow(user)> login user 'aaa@aaa.com'; +password for aaa@aaa.com: Password: +Login user aaa@aaa.com successfully +RAGFlow(user)> logout; +SUCCESS +``` + +### 6.2. List currently supported providers +``` +RAGFlow(user)> list available providers; +``` + +### 6.3. Add or delete a provider for the current tenant +``` +RAGFlow(user)> add provider 'openai'; +``` +``` +RAGFlow(user)> delete provider 'openai'; +``` +### 6.4. Create a model instance for a specific provider +``` +RAGFlow(user)> create provider 'openai' instance 'instance_name' key 'api-key'; +``` + +Note: The api-key is a valid API key that needs to be applied for. You can create multiple instances for the same model provider, each with a different API key. + +For locally deployed models (e.g., ollama, vLLM), use the following command to add a model instance: + +``` +RAGFlow(user)> create provider 'vllm' instance 'instance_name' key '' url 'http://192.168.1.96:8123/v1'; +``` +### 6.5. List and delete an instance +``` +RAGFlow(user)> list instances from 'openai'; +``` +``` +RAGFlow(user)> drop instance 'instance_name' from 'openai'; +``` +### 6.6. List models supported by a model instance +``` +RAGFlow(user)> list models from 'openai' 'instance_name'; +``` +### 6.7. Chat with LLM +- Chat +``` +RAGFlow(user)> chat with 'glm-4.5-flash@test@zhipu-ai' message '20 words introduce LLM'; +Answer: A large language model is an AI trained on vast text data to understand, generate, and refine human-like language. +Time: 1.052269 +``` +- Chat with Thinking (Reasoning) +``` +RAGFlow(user)> think chat with 'glm-4.5-flash@test@zhipu-ai' message '20 words introduce LLM'; +Thinking: I need to create a concise 20-word introduction to LLMs... +Answer: Large Language Models are AI systems trained on vast datasets, enabling human-like text generation, comprehension, and problem-solving across diverse applications. +Time: 11.592358 +``` +- Streaming Chat +``` +RAGFlow(user)> stream chat with 'glm-4.5-flash@test@zhipu-ai' message '20 words introduce LLM'; +Answer: Language Models are advanced AI systems. They process text to learn, generate human-like responses, and perform diverse tasks through machine learning. +Time: 2.615930 +``` +- Streaming Chat with Thinking +``` +RAGFlow(user)> stream think chat with 'glm-4.5-flash@test@zhipu-ai' message '20 words introduce LLM'; +Thinking: The user is asking for a very concise introduction to LLMs... +Answer: language models are AI systems trained on vast text datasets to understand and generate human-like text for diverse tasks. +Time: 11.958035 +``` +- Image Understanding +``` +RAGFlow(user)> chat with 'glm-4.6v-flash@test@zhipu-ai' message 'What are the pics talk about?' image 'https://cdn.bigmodel.cn/static/logo/register.png' 'https://cdn.bigmodel.cn/static/logo/api-key.png' +Answer: The first picture shows a login/register modal... The second picture displays the API keys management page... +Time: 31.600545 +``` +- Video Understanding +``` +RAGFlow(user)> chat with 'glm-4.6v-flash@test@zhipu-ai' message 'What are the video talk about?' video 'https://cdn.bigmodel.cn/agent-demos/lark/113123.mov' +Answer: Based on the sequence of frames provided, the video is a demonstration of a web search and navigation process... +Time: 76.582520 +``` +Note: Both image and video understanding support streaming and thinking modes as well. + +### 6.8. Generate Embeddings +``` +RAGFlow(user)> embed text 'what is rag' 'who are you' with 'embedding-3@test@zhipu-ai' dimension 16; +``` +### 6.9. Document Reranking +``` +RAGFlow(user)> rerank query 'what is rag' document 'rag is retrieval augment generation' 'rag need llm' 'famous rag project includes ragflow' with 'rerank@test@zhipu-ai' top 2; +``` + +### 6.10. Get supported models from provider API + +``` +RAGFlow(user)> list supported models from 'minimax' 'test'; ++------------------------+ +| model_name | ++------------------------+ +| MiniMax-M2.7 | +| MiniMax-M2.7-highspeed | +| MiniMax-M2.5 | +| MiniMax-M2.5-highspeed | +| MiniMax-M2.1 | +| MiniMax-M2.1-highspeed | +| MiniMax-M2 | ++------------------------+ +``` + +### 6.11. Get preset models of a provider + +``` +RAGFlow(user)> list models from 'minimax'; ++------------+-------------+------------------------+ +| max_tokens | model_types | name | ++------------+-------------+------------------------+ +| 204800 | [chat] | minimax-m2.7 | +| 204800 | [chat] | minimax-m2.7-highspeed | +| 204800 | [chat] | minimax-m2.5 | +| 204800 | [chat] | minimax-m2.5-highspeed | +| 204800 | [chat] | minimax-m2.1 | +| 204800 | [chat] | minimax-m2.1-highspeed | +| 204800 | [chat] | minimax-m2 | +| 65536 | [chat] | minimax-m2-her | ++------------+-------------+------------------------+ +``` + +### 6.12. List instances of a provider + +``` +RAGFlow(user)> list instances from 'zhipu-ai'; ++---------+----------------------+----------------------------------+--------------+----------------------------------+--------+ +| apiKey | extra | id | instanceName | providerID | status | ++---------+----------------------+----------------------------------+--------------+----------------------------------+--------+ +| api-key | {"region":"default"} | 19f620e73c7a11f1a51138a74640adcc | test | d21a3758398f11f1ab4838a74640adcc | enable | ++---------+----------------------+----------------------------------+--------------+----------------------------------+--------+ +``` + +### 6.13. Show instance of a provider +``` +RAGFlow(user)> show instance 'test' from 'zhipu-ai'; ++----------------------------------+--------------+----------------------------------+---------+--------+ +| id | instanceName | providerID | region | status | ++----------------------------------+--------------+----------------------------------+---------+--------+ +| 19f620e73c7a11f1a51138a74640adcc | test | d21a3758398f11f1ab4838a74640adcc | default | enable | ++----------------------------------+--------------+----------------------------------+---------+--------+ +``` + +### 6.14. List models of a specific instance + +``` +RAGFlow(user)> list models from 'minimax' 'test'; ++------------+-------------+------------------------+--------+ +| max_tokens | model_types | name | status | ++------------+-------------+------------------------+--------+ +| 204800 | [chat] | minimax-m2.7 | active | +| 204800 | [chat] | minimax-m2.7-highspeed | active | +| 204800 | [chat] | minimax-m2.5 | active | +| 204800 | [chat] | minimax-m2.5-highspeed | active | +| 204800 | [chat] | minimax-m2.1 | active | +| 204800 | [chat] | minimax-m2.1-highspeed | active | +| 204800 | [chat] | minimax-m2 | active | +| 65536 | [chat] | minimax-m2-her | active | ++------------+-------------+------------------------+--------+ +``` + +### 6.15. List added providers +``` +RAGFlow(user)> list providers; ++--------------------------------------------------------------------------+-------------+--------------+ +| base_url | name | total_models | ++--------------------------------------------------------------------------+-------------+--------------+ +| map[default:https://ark.cn-beijing.volces.com/api/v3] | VolcEngine | 2 | +| map[default:https://api.minimaxi.com/ global:https://api.minimax.io/] | MiniMax | 8 | +| map[default:https://api.moark.com/v1] | Gitee | 5 | ++--------------------------------------------------------------------------+-------------+--------------+ +``` + +### 6.16. Deactivate / activate a model + +``` +RAGFlow(user)> disable model 'deepseek-v4-pro' from 'deepseek' 'test'; +SUCCESS +RAGFlow(user)> list models from 'deepseek' 'test'; ++------------+-------------+-------------------+----------+ +| max_tokens | model_types | name | status | ++------------+-------------+-------------------+----------+ +| 1048576 | [chat] | deepseek-v4-flash | active | +| 1048576 | [chat] | deepseek-v4-pro | inactive | ++------------+-------------+-------------------+----------+ +RAGFlow(user)> enable model 'deepseek-v4-pro' from 'deepseek' 'test'; +SUCCESS +``` + +### 6.17. Set current model +``` +RAGFlow(user)> use model 'glm-4.5-flash@test@zhipu-ai'; +SUCCESS +RAGFlow(user)> chat message '20 words introduce LLM'; +Answer: Large language models are advanced AI systems. They process text to understand, generate, and refine human-like language for countless tasks. +Time: 1.680416 +``` + +### 6.18. Set, reset, and list default models +``` +RAGFlow(user)> set default chat model 'zhipu-ai/test/glm-4.5-flash'; +SUCCESS +RAGFlow(user)> set default vision model 'zhipu-ai/test/glm-4.5v'; +SUCCESS +RAGFlow(user)> set default embedding model 'zhipu-ai/test/embedding-2'; +SUCCESS +RAGFlow(user)> set default rerank model 'zhipu-ai/test/rerank'; +SUCCESS +RAGFlow(user)> set default ocr model 'zhipu-ai/test/glm-ocr'; +SUCCESS +RAGFlow(user)> set default tts model 'zhipu-ai/test/glm-tts'; +SUCCESS +RAGFlow(user)> set default asr model 'zhipu-ai/test/glm-asr-2512'; +SUCCESS +RAGFlow(user)> list default models; ++--------+----------------+---------------+----------------+------------+ +| enable | model_instance | model_name | model_provider | model_type | ++--------+----------------+---------------+----------------+------------+ +| true | test | glm-4.5-flash | zhipu-ai | chat | +| true | test | embedding-2 | zhipu-ai | embedding | +| true | test | rerank | zhipu-ai | rerank | +| true | test | glm-asr-2512 | zhipu-ai | asr | +| true | test | glm-4.5v | zhipu-ai | vision | +| true | test | glm-ocr | zhipu-ai | ocr | +| true | test | glm-tts | zhipu-ai | tts | ++--------+----------------+---------------+----------------+------------+ +RAGFlow(user)> reset default embedding model; +SUCCESS +RAGFlow(user)> reset default chat model +SUCCESS +RAGFlow(user)> list default models; ++--------+----------------+--------------+----------------+------------+ +| enable | model_instance | model_name | model_provider | model_type | ++--------+----------------+--------------+----------------+------------+ +| true | test | rerank | zhipu-ai | rerank | +| true | test | glm-asr-2512 | zhipu-ai | asr | +| true | test | glm-4.5v | zhipu-ai | vision | +| true | test | glm-ocr | zhipu-ai | ocr | +| true | test | glm-tts | zhipu-ai | tts | ++--------+----------------+--------------+----------------+------------+ +``` + +### 6.19. Show current balance of a provider instance +``` +RAGFlow(user)> show balance from 'gitee' 'test'; ++-------------+----------+ +| balance | currency | ++-------------+----------+ +| 82.49835029 | CNY | ++-------------+----------+ +``` + +### 6.20. Check provider instance availability +``` +RAGFlow(user)> check instance 'test' from 'zhipu-ai'; +SUCCESS +``` + +### 6.21. Add local model to RAGFlow, only for local deployed inference server, such as ollama +``` +RAGFlow(user)> add model 'Qwen/Qwen2.5-0.5B' to provider 'vllm' instance 'test' with tokens 131072 chat; +SUCCESS +RAGFlow(user)> list models from 'vllm' 'test'; ++-------------------+--------+ +| name | status | ++-------------------+--------+ +| Qwen/Qwen2.5-0.5B | active | ++-------------------+--------+ +RAGFlow(user)> drop model 'Qwen/Qwen2.5-0.5B' from 'vllm' 'test'; +SUCCESS +``` + +### 6.22. List datasets +``` +RAGFlow(user)> list datasets; ++-------------+--------------+----------------+----------------------+----------------------------------+----------+------+----------+------------+----------------------------------+-----------+---------------+ +| chunk_count | chunk_method | document_count | embedding_model | id | language | name | nickname | permission | tenant_id | token_num | update_time | ++-------------+--------------+----------------+----------------------+----------------------------------+----------+------+----------+------------+----------------------------------+-----------+---------------+ +| 492 | naive | 1 | embedding-2@ZHIPU-AI | e93ab2c04ad111f1b17438a74640adcc | English | aaa | aaa | me | 2ba4881420fa11f19e9c38a74640adcc | 74278 | 1778245825722 | +| 0 | naive | 1 | embedding-2@ZHIPU-AI | 0abe79f9423311f1ad8d38a74640adcc | English | ccc | aaa | me | 2ba4881420fa11f19e9c38a74640adcc | 0 | 1777375201933 | ++-------------+--------------+----------------+----------------------+----------------------------------+----------+------+----------+------------+----------------------------------+-----------+---------------+ +``` + +### 6.23 Text to Speech +``` +RAGFlow(user)> tts with 'speech-2.8-hd@test@minimax' text 'He who desires but acts not, breeds pestilence.' play format 'wav' save './internal' param '{"voice_setting": {"voice_id": "English_radiant_girl", "speed": 1, "vol": 1, "pitch": 0}, "audio_setting": {"sample_rate": 32000, "bitrate": 128000, "format": "wav", "channel": 1}, "output_format": "hex"}' +Saved to directory: /home/infiniflow/Documents/development/ragflow/internal/speech-2.8-hd_output.wav +SUCCESS +``` + +### 6.24 Audio to Speech +``` +RAGFlow(user)> asr with 'FunAudioLLM/SenseVoiceSmall@test@siliconflow' audio './internal/test.wav' param '' ++----------------------------------------------------------------------------------------------------------------------+ +| text | ++----------------------------------------------------------------------------------------------------------------------+ +| The examination and testimony of the experts enabled the commission to conclude that five shots may have been fired. | ++----------------------------------------------------------------------------------------------------------------------+ +``` + +### 6.25 Optical Character Recognition\ +``` +RAGFlow(user)> ocr with 'paddleocr-vl-0.9b@test@baidu' file './internal/text.jpg' ++------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| text | ++------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +| Parallel to these organizational innovations there were significant complementary technical innovations (e.g., improved methods of manufacturing cast-iron pipe and of coating interiors for pressure maintenance, and newer paving and construction material... | ++------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+ +``` \ No newline at end of file diff --git a/internal/engine/elasticsearch/chunk.go b/internal/engine/elasticsearch/chunk.go new file mode 100644 index 00000000000..26414d7e6a9 --- /dev/null +++ b/internal/engine/elasticsearch/chunk.go @@ -0,0 +1,1216 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package elasticsearch + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "os" + "strings" + + "github.com/elastic/go-elasticsearch/v8/esapi" + "ragflow/internal/common" + "ragflow/internal/engine/types" + + "go.uber.org/zap" +) + +// CreateChunkStore creates an index +func (e *elasticsearchEngine) CreateChunkStore(ctx context.Context, baseName, datasetID string, vectorSize int, parserID string) error { + if baseName == "" { + return fmt.Errorf("index name cannot be empty") + } + + // Check if index already exists + exists, err := e.indexExists(ctx, baseName) + if err != nil { + return fmt.Errorf("failed to check index existence: %w", err) + } + if exists { + return fmt.Errorf("index '%s' already exists", baseName) + } + + // Load mapping based on index type + var mapping map[string]interface{} + if datasetID == "skill" { + // Load skill-specific mapping + skillMapping, err := loadSkillMapping() + if err != nil { + return fmt.Errorf("failed to load skill mapping: %w", err) + } + mapping = skillMapping + } else { + // Default mapping for dataset + mapping = map[string]interface{}{ + "settings": map[string]interface{}{ + "number_of_shards": 1, + "number_of_replicas": 0, + }, + } + } + + // Prepare request body + var body io.Reader + if mapping != nil { + data, err := json.Marshal(mapping) + if err != nil { + return fmt.Errorf("failed to marshal mapping: %w", err) + } + body = bytes.NewReader(data) + } + + // Create index + req := esapi.IndicesCreateRequest{ + Index: baseName, + Body: body, + } + + res, err := req.Do(ctx, e.client) + if err != nil { + return fmt.Errorf("failed to create index: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + bodyBytes, _ := io.ReadAll(res.Body) + reason := extractErrorReason(bodyBytes) + if reason != "" { + return fmt.Errorf("elasticsearch error: %s", reason) + } + return fmt.Errorf("elasticsearch returned error: %s, body: %s", res.Status(), string(bodyBytes)) + } + + // Parse response + var result map[string]interface{} + if err := json.NewDecoder(res.Body).Decode(&result); err != nil { + return fmt.Errorf("failed to parse response: %w", err) + } + + acknowledged, ok := result["acknowledged"].(bool) + if !ok || !acknowledged { + return fmt.Errorf("index creation not acknowledged") + } + + return nil +} + +// InsertChunks inserts documents into a dataset index +func (e *elasticsearchEngine) InsertChunks(ctx context.Context, chunks []map[string]interface{}, baseName string, datasetID string) ([]string, error) { + fullIndexName := fmt.Sprintf("%s_%s", baseName, datasetID) + common.Info("Inserting chunks into Elasticsearch index", zap.String("index_name", fullIndexName), zap.String("dataset_id", datasetID), zap.Int("doc_count", len(chunks))) + + if len(chunks) == 0 { + return []string{}, nil + } + + if fullIndexName == "" { + return nil, fmt.Errorf("index name cannot be empty") + } + + // Build bulk request body + var buf bytes.Buffer + for _, doc := range chunks { + // Action line - index operation + action := map[string]interface{}{ + "index": map[string]interface{}{ + "_index": fullIndexName, + }, + } + actionBytes, err := json.Marshal(action) + if err != nil { + common.Error("Failed to marshal bulk action", err) + return nil, fmt.Errorf("failed to marshal bulk action: %w", err) + } + buf.Write(actionBytes) + buf.WriteByte('\n') + + // Document line + docBytes, err := json.Marshal(doc) + if err != nil { + common.Error("Failed to marshal document", err) + return nil, fmt.Errorf("failed to marshal document: %w", err) + } + buf.Write(docBytes) + buf.WriteByte('\n') + } + + // Execute bulk request + req := esapi.BulkRequest{ + Body: bytes.NewReader(buf.Bytes()), + Refresh: "false", + } + + res, err := req.Do(ctx, e.client) + if err != nil { + common.Error("Failed to execute bulk request", err) + return nil, fmt.Errorf("failed to execute bulk request: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + common.Sugar.Errorw("Elasticsearch bulk request returned error", "status", res.Status()) + return nil, fmt.Errorf("elasticsearch bulk request returned error: %s", res.Status()) + } + + // Parse bulk response to check for errors + var bulkResponse map[string]interface{} + if err := json.NewDecoder(res.Body).Decode(&bulkResponse); err != nil { + common.Error("Failed to parse bulk response", err) + return nil, fmt.Errorf("failed to parse bulk response: %w", err) + } + + // Check for errors in bulk response + if errors, ok := bulkResponse["errors"].(bool); ok && errors { + common.Warn("Bulk request had some errors") + // Could iterate through items to find specific errors if needed + } + + common.Info("Successfully inserted chunks into Elasticsearch index", zap.String("index_name", fullIndexName), zap.Int("doc_count", len(chunks))) + return []string{}, nil +} + +// UpdateChunks updates chunks by condition +func (e *elasticsearchEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, baseName string, datasetID string) error { + fullIndexName := fmt.Sprintf("%s_%s", baseName, datasetID) + common.Info("Updating chunks in Elasticsearch index", zap.String("index_name", fullIndexName), zap.Any("condition", condition), zap.Any("new_value", newValue)) + + if fullIndexName == "" { + return fmt.Errorf("index name cannot be empty") + } + + // Check if index exists + exists, err := e.indexExists(ctx, fullIndexName) + if err != nil { + common.Error("Failed to check index existence", err) + return fmt.Errorf("failed to check index existence: %w", err) + } + if !exists { + return fmt.Errorf("index '%s' does not exist", fullIndexName) + } + + // Build query from condition + query := e.buildQueryFromCondition(condition) + if query == nil { + query = map[string]interface{}{"match_all": map[string]interface{}{}} + } + + // Process remove operation if present + var removeOperations []map[string]interface{} + if removeData, ok := newValue["remove"].(map[string]interface{}); ok { + removeOperations = e.buildRemoveOperations(removeData, query, fullIndexName) + } + delete(newValue, "remove") + + // Build update body + updateBody := map[string]interface{}{ + "query": query, + } + + // Handle script-based update if needed (for remove operations or transformations) + if len(removeOperations) > 0 || e.needsScriptUpdate(newValue) { + script := e.buildUpdateScript(newValue, removeOperations) + updateBody["script"] = script + } else { + updateBody["doc"] = newValue + } + + bodyBytes, err := json.Marshal(updateBody) + if err != nil { + common.Error("Failed to marshal update body", err) + return fmt.Errorf("failed to marshal update body: %w", err) + } + + // Execute update by query + req := esapi.UpdateByQueryRequest{ + Index: []string{fullIndexName}, + Body: bytes.NewReader(bodyBytes), + } + + res, err := req.Do(ctx, e.client) + if err != nil { + common.Error("Failed to execute update by query", err) + return fmt.Errorf("failed to execute update by query: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + common.Sugar.Errorw("Elasticsearch update by query returned error", "status", res.Status()) + return fmt.Errorf("elasticsearch update by query returned error: %s", res.Status()) + } + + // Parse response + var result map[string]interface{} + if err := json.NewDecoder(res.Body).Decode(&result); err != nil { + common.Error("Failed to parse update response", err) + return fmt.Errorf("failed to parse update response: %w", err) + } + + if updated, ok := result["updated"].(float64); ok { + common.Info("Successfully updated chunks", zap.String("index_name", fullIndexName), zap.Float64("updated_count", updated)) + } + + return nil +} + +// DeleteChunks deletes chunks from a dataset index by condition +func (e *elasticsearchEngine) DeleteChunks(ctx context.Context, condition map[string]interface{}, indexName string, datasetID string) (int64, error) { + fullIndexName := fmt.Sprintf("%s_%s", indexName, datasetID) + common.Info("Deleting chunks from Elasticsearch index", zap.String("index_name", fullIndexName), zap.Any("condition", condition)) + + // Check if index exists + exists, err := e.indexExists(ctx, fullIndexName) + if err != nil { + return 0, fmt.Errorf("failed to check index existence: %w", err) + } + if !exists { + common.Warn(fmt.Sprintf("Index %s does not exist, skipping delete", fullIndexName)) + return 0, nil + } + + // Build query from condition + query := e.buildQueryFromCondition(condition) + if query == nil { + query = map[string]interface{}{"match_all": map[string]interface{}{}} + } + + // Build delete by query body + deleteBody := map[string]interface{}{ + "query": query, + } + + bodyBytes, err := json.Marshal(deleteBody) + if err != nil { + return 0, fmt.Errorf("failed to marshal delete body: %w", err) + } + + // Execute delete by query + req := esapi.DeleteByQueryRequest{ + Index: []string{fullIndexName}, + Body: bytes.NewReader(bodyBytes), + } + + res, err := req.Do(ctx, e.client) + if err != nil { + common.Error("Failed to execute delete by query", err) + return 0, fmt.Errorf("failed to execute delete by query: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + common.Sugar.Errorw("Elasticsearch delete by query returned error", "status", res.Status()) + return 0, fmt.Errorf("elasticsearch delete by query returned error: %s", res.Status()) + } + + // Parse response + var result map[string]interface{} + if err := json.NewDecoder(res.Body).Decode(&result); err != nil { + common.Error("Failed to parse delete response", err) + return 0, fmt.Errorf("failed to parse delete response: %w", err) + } + + deleted := int64(0) + if d, ok := result["deleted"].(float64); ok { + deleted = int64(d) + } + + common.Info("Successfully deleted chunks", zap.String("index_name", fullIndexName), zap.Int64("deleted_count", deleted)) + return deleted, nil +} + +// SearchResponse Elasticsearch search response +type SearchResponse struct { + Hits struct { + Total struct { + Value int64 `json:"value"` + } `json:"total"` + Hits []struct { + ID string `json:"_id"` + Score float64 `json:"_score"` + Source map[string]interface{} `json:"_source"` + } `json:"hits"` + } `json:"hits"` + Aggregations map[string]interface{} `json:"aggregations"` +} + +// Search executes search with unified types.SearchRequest +func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { + return e.searchUnified(ctx, req) +} + +// searchUnified handles the unified types.SearchRequest +func (e *elasticsearchEngine) searchUnified(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { + if len(req.IndexNames) == 0 { + return nil, fmt.Errorf("index names cannot be empty") + } + + // Build pagination parameters + offset := req.Offset + limit := req.Limit + if limit <= 0 { + limit = 30 // default ES size + } + + // Check if this is a skill index + isSkillIndex := len(req.IndexNames) > 0 && strings.HasPrefix(req.IndexNames[0], "skill_") + + // Build filter clauses + var filterClauses []map[string]interface{} + if isSkillIndex { + filterClauses = buildSkillFilterClauses() + } else { + filterClauses = buildFilterClauses(req.KbIDs, 1) + } + + // Add filters from req.Filter + if req.Filter != nil && len(req.Filter) > 0 { + filterClauses = append(filterClauses, buildFilterFromMap(req.Filter)...) + } + + // Build search query body + queryBody := make(map[string]interface{}) + + // Determine search type from MatchExprs + var matchText string + var matchDense *types.MatchDenseExpr + var hasVectorMatch bool + + for _, expr := range req.MatchExprs { + if expr == nil { + continue + } + switch e := expr.(type) { + case string: + matchText = e + case *types.MatchTextExpr: + matchText = e.MatchingText + case *types.MatchDenseExpr: + hasVectorMatch = true + matchDense = e + } + } + + var vectorFieldName string + if !hasVectorMatch || matchDense == nil { + // Keyword-only search + if isSkillIndex { + queryBody["query"] = buildSkillKeywordQuery(matchText, filterClauses, 1.0) + } else { + queryBody["query"] = buildESKeywordQuery(matchText, filterClauses, 1.0) + } + } else { + // Hybrid search: keyword + vector + textWeight := 0.7 // default: vector weight = 0.3 + vectorWeight := 0.3 + if matchDense.ExtraOptions != nil { + if vw, ok := matchDense.ExtraOptions["text_weight"].(float64); ok { + textWeight = vw + } + if vw, ok := matchDense.ExtraOptions["vector_weight"].(float64); ok { + vectorWeight = vw + } + } + + // Build boolean query for text match and filters + var boolQuery map[string]interface{} + if isSkillIndex { + boolQuery = buildSkillKeywordQuery(matchText, filterClauses, 1.0) + } else { + boolQuery = buildESKeywordQuery(matchText, filterClauses, 1.0) + } + // Add boost to the bool query (as in Python code) + if boolMap, ok := boolQuery["bool"].(map[string]interface{}); ok { + boolMap["boost"] = textWeight + } + + // Build kNN query + vectorData := matchDense.EmbeddingData + vectorFieldName = matchDense.VectorColumnName + k := matchDense.TopN + if k <= 0 { + k = req.Limit + } + if k <= 0 { + k = 1024 + } + numCandidates := k * 2 + + similarity := 0.0 + if matchDense.ExtraOptions != nil { + if sim, ok := matchDense.ExtraOptions["similarity"].(float64); ok { + similarity = sim + } + } + + knnQuery := map[string]interface{}{ + "field": vectorFieldName, + "query_vector": vectorData, + "k": k, + "num_candidates": numCandidates, + "similarity": similarity, + "boost": vectorWeight, + } + + queryBody["knn"] = knnQuery + queryBody["query"] = boolQuery + + // Add vector column to Source fields (matching Python ES: src.append(f"q_{len(q_vec)}_vec")) + // Only modify Source if it was explicitly set by the caller + if vectorFieldName != "" && len(req.SelectFields) > 0 { + sourceFields := req.SelectFields + found := false + for _, f := range sourceFields { + if f == vectorFieldName { + found = true + break + } + } + if !found { + sourceFields = append(sourceFields, vectorFieldName) + } + req.SelectFields = sourceFields + } + } + + queryBody["size"] = limit + queryBody["from"] = offset + + // Add sorting if specified + if req.OrderBy != nil { + sort := parseOrderByExpr(req.OrderBy) + if len(sort) > 0 { + queryBody["sort"] = sort + } + } + + // Serialize query + var buf bytes.Buffer + if err := json.NewEncoder(&buf).Encode(queryBody); err != nil { + return nil, fmt.Errorf("error encoding query: %w", err) + } + + // Log search details + common.Debug("Elasticsearch searching indices", zap.Strings("indices", req.IndexNames)) + common.Debug("Elasticsearch DSL", zap.Any("dsl", queryBody)) + + // Build search request + reqES := esapi.SearchRequest{ + Index: req.IndexNames, + Body: &buf, + } + + // Execute search + res, err := reqES.Do(ctx, e.client) + if err != nil { + return nil, fmt.Errorf("search failed: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + bodyBytes, err := io.ReadAll(res.Body) + if err != nil { + common.Error("Elasticsearch failed to read error response body", err) + } else { + common.Warn("Elasticsearch error response", zap.String("body", string(bodyBytes))) + } + return nil, fmt.Errorf("Elasticsearch returned error: %s", res.Status()) + } + + // Parse response + var esResp SearchResponse + if err := json.NewDecoder(res.Body).Decode(&esResp); err != nil { + return nil, fmt.Errorf("error parsing response: %w", err) + } + + // Convert to unified response + chunks := convertESResponse(&esResp, vectorFieldName) + return &types.SearchResult{ + Chunks: chunks, + Total: esResp.Hits.Total.Value, + }, nil +} + +// GetChunk gets a chunk by ID +func (e *elasticsearchEngine) GetChunk(ctx context.Context, baseName, chunkID string, datasetIDs []string) (interface{}, error) { + // Build unified search request to get the chunk by ID + searchReq := &types.SearchRequest{ + IndexNames: []string{baseName}, + Limit: 1, + Offset: 0, + Filter: map[string]interface{}{ + "id": chunkID, + }, + } + + // Execute search + searchResp, err := e.Search(ctx, searchReq) + if err != nil { + return nil, fmt.Errorf("failed to search: %w", err) + } + + if len(searchResp.Chunks) == 0 { + return nil, nil + } + + return searchResp.Chunks[0], nil +} + +// GetFields is not implemented for Elasticsearch +func (e *elasticsearchEngine) GetFields(chunks []map[string]interface{}, fields []string) map[string]map[string]interface{} { + common.Warn("GetFields not implemented for Elasticsearch") + return nil +} + +// GetAggregation is not implemented for Elasticsearch +func (e *elasticsearchEngine) GetAggregation(chunks []map[string]interface{}, fieldName string) []map[string]interface{} { + common.Warn("GetAggregation not implemented for Elasticsearch") + return nil +} + +// GetHighlight is not implemented for Elasticsearch +func (e *elasticsearchEngine) GetHighlight(chunks []map[string]interface{}, keywords []string, fieldName string) map[string]string { + common.Warn("GetHighlight not implemented for Elasticsearch") + return nil +} + +// DropChunkStore deletes a chunk index +func (e *elasticsearchEngine) DropChunkStore(ctx context.Context, baseName, datasetID string) error { + return e.dropIndex(ctx, baseName) +} + +// ChunkStoreExists checks if a chunk index exists +func (e *elasticsearchEngine) ChunkStoreExists(ctx context.Context, baseName, datasetID string) (bool, error) { + return e.indexExists(ctx, baseName) +} + +// buildQueryFromCondition builds an ES query from condition map +func (e *elasticsearchEngine) buildQueryFromCondition(condition map[string]interface{}) map[string]interface{} { + if len(condition) == 0 { + return nil + } + + var clauses []map[string]interface{} + + for k, v := range condition { + if v == nil { + continue + } + + switch k { + case "kb_id": + // Handle kb_id as terms query + if listVal, ok := v.([]interface{}); ok { + clauses = append(clauses, map[string]interface{}{ + "terms": map[string]interface{}{k: listVal}, + }) + } else { + clauses = append(clauses, map[string]interface{}{ + "term": map[string]interface{}{k: v}, + }) + } + case "id": + // Handle id as terms or term query + if listVal, ok := v.([]interface{}); ok { + clauses = append(clauses, map[string]interface{}{ + "terms": map[string]interface{}{k: listVal}, + }) + } else { + clauses = append(clauses, map[string]interface{}{ + "term": map[string]interface{}{k: v}, + }) + } + case "available_int": + // Handle available_int as term query + clauses = append(clauses, map[string]interface{}{ + "term": map[string]interface{}{k: v}, + }) + default: + // Default: treat as term query + clauses = append(clauses, map[string]interface{}{ + "term": map[string]interface{}{k: v}, + }) + } + } + + if len(clauses) == 0 { + return nil + } + + if len(clauses) == 1 { + return clauses[0] + } + + return map[string]interface{}{ + "bool": map[string]interface{}{ + "must": clauses, + }, + } +} + +// buildRemoveOperations builds ES script operations for remove +func (e *elasticsearchEngine) buildRemoveOperations(removeData map[string]interface{}, query map[string]interface{}, indexName string) []map[string]interface{} { + // For ES, we handle removals differently - they are typically done via separate update operations + // This is a simplified implementation + return nil +} + +// needsScriptUpdate checks if the update requires a script (more complex operations) +func (e *elasticsearchEngine) needsScriptUpdate(newValue map[string]interface{}) bool { + // Check if any values contain operations that need scripts + return false +} + +// buildUpdateScript builds an ES script for updates +func (e *elasticsearchEngine) buildUpdateScript(newValue map[string]interface{}, removeOperations []map[string]interface{}) map[string]interface{} { + script := map[string]interface{}{ + "source": "ctx._source.putAll(params.doc)", + "params": map[string]interface{}{ + "doc": newValue, + }, + } + return script +} + +// buildMetadataQueryFromCondition builds an ES query for metadata index +func (e *elasticsearchEngine) buildMetadataQueryFromCondition(condition map[string]interface{}) map[string]interface{} { + if len(condition) == 0 { + return nil + } + + var clauses []map[string]interface{} + + for k, v := range condition { + if v == nil { + continue + } + + switch k { + case "kb_id": + if listVal, ok := v.([]interface{}); ok { + clauses = append(clauses, map[string]interface{}{ + "terms": map[string]interface{}{k: listVal}, + }) + } else { + clauses = append(clauses, map[string]interface{}{ + "term": map[string]interface{}{k: v}, + }) + } + case "id": + if listVal, ok := v.([]interface{}); ok { + clauses = append(clauses, map[string]interface{}{ + "terms": map[string]interface{}{k: listVal}, + }) + } else { + clauses = append(clauses, map[string]interface{}{ + "term": map[string]interface{}{k: v}, + }) + } + default: + clauses = append(clauses, map[string]interface{}{ + "term": map[string]interface{}{k: v}, + }) + } + } + + if len(clauses) == 0 { + return nil + } + + if len(clauses) == 1 { + return clauses[0] + } + + return map[string]interface{}{ + "bool": map[string]interface{}{ + "must": clauses, + }, + } +} + +// loadSkillMapping loads the skill index mapping from config file +func loadSkillMapping() (map[string]interface{}, error) { + // Try multiple possible locations for the mapping file + possiblePaths := []string{ + "conf/skill_es_mapping.json", + "../conf/skill_es_mapping.json", + "/app/conf/skill_es_mapping.json", + } + + var data []byte + var err error + for _, path := range possiblePaths { + data, err = os.ReadFile(path) + if err == nil { + break + } + } + + if err != nil { + // Fallback to default skill mapping if file not found + return getDefaultSkillMapping(), nil + } + + var mapping map[string]interface{} + if err := json.Unmarshal(data, &mapping); err != nil { + return nil, fmt.Errorf("failed to parse skill mapping: %w", err) + } + + return mapping, nil +} + +// getDefaultSkillMapping returns the default skill index mapping +func getDefaultSkillMapping() map[string]interface{} { + return map[string]interface{}{ + "settings": map[string]interface{}{ + "index": map[string]interface{}{ + "number_of_shards": 1, + "number_of_replicas": 0, + "refresh_interval": "1000ms", + }, + }, + "mappings": map[string]interface{}{ + "dynamic": false, + "properties": map[string]interface{}{ + "skill_id": map[string]interface{}{ + "type": "keyword", + "store": true, + }, + "name": map[string]interface{}{ + "type": "text", + "index": false, + "store": true, + }, + "name_tks": map[string]interface{}{ + "type": "text", + "analyzer": "whitespace", + "store": true, + }, + "tags": map[string]interface{}{ + "type": "text", + "index": false, + "store": true, + }, + "tags_tks": map[string]interface{}{ + "type": "text", + "analyzer": "whitespace", + "store": true, + }, + "description": map[string]interface{}{ + "type": "text", + "index": false, + "store": true, + }, + "description_tks": map[string]interface{}{ + "type": "text", + "analyzer": "whitespace", + "store": true, + }, + "content": map[string]interface{}{ + "type": "text", + "index": false, + "store": true, + }, + "content_tks": map[string]interface{}{ + "type": "text", + "analyzer": "whitespace", + "store": true, + }, + "q_3072_vec": map[string]interface{}{ + "type": "dense_vector", + "dims": 3072, + "index": true, + "similarity": "cosine", + }, + "q_2560_vec": map[string]interface{}{ + "type": "dense_vector", + "dims": 2560, + "index": true, + "similarity": "cosine", + }, + "q_1536_vec": map[string]interface{}{ + "type": "dense_vector", + "dims": 1536, + "index": true, + "similarity": "cosine", + }, + "q_1024_vec": map[string]interface{}{ + "type": "dense_vector", + "dims": 1024, + "index": true, + "similarity": "cosine", + }, + "q_768_vec": map[string]interface{}{ + "type": "dense_vector", + "dims": 768, + "index": true, + "similarity": "cosine", + }, + "q_512_vec": map[string]interface{}{ + "type": "dense_vector", + "dims": 512, + "index": true, + "similarity": "cosine", + }, + "q_256_vec": map[string]interface{}{ + "type": "dense_vector", + "dims": 256, + "index": true, + "similarity": "cosine", + }, + "version": map[string]interface{}{ + "type": "keyword", + "store": true, + }, + "status": map[string]interface{}{ + "type": "keyword", + "store": true, + }, + "create_time": map[string]interface{}{ + "type": "long", + "store": true, + }, + "update_time": map[string]interface{}{ + "type": "long", + "store": true, + }, + }, + }, + } +} + +// calculatePagination calculates offset and limit based on page, size and topK +func calculatePagination(page, size, topK int) (int, int) { + if page < 1 { + page = 1 + } + if size <= 0 { + size = 30 + } + if topK <= 0 { + topK = 1024 + } + + RERANK_LIMIT := max(30, (64/size)*size) + if RERANK_LIMIT < size { + RERANK_LIMIT = size + } + if RERANK_LIMIT > topK { + RERANK_LIMIT = topK + } + + offset := (page - 1) * RERANK_LIMIT + if offset < 0 { + offset = 0 + } + + return offset, RERANK_LIMIT +} + +// buildFilterClauses builds ES filter clauses from kb_ids and available_int +// Reference: rag/utils/es_conn.py L60-L78 +// When available=0: available_int < 1 +// When available!=0: NOT (available_int < 1) +func buildFilterClauses(datasetIDs []string, available int) []map[string]interface{} { + var filters []map[string]interface{} + + if len(datasetIDs) > 0 { + filters = append(filters, map[string]interface{}{ + "terms": map[string]interface{}{"kb_id": datasetIDs}, + }) + } + + // Add available_int filter + // Reference: rag/utils/es_conn.py L63-L68 + if available == 0 { + // available_int < 1 + filters = append(filters, map[string]interface{}{ + "range": map[string]interface{}{ + "available_int": map[string]interface{}{ + "lt": 1, + }, + }, + }) + } else { + // must_not: available_int < 1 (i.e., available_int >= 1) + filters = append(filters, map[string]interface{}{ + "bool": map[string]interface{}{ + "must_not": []map[string]interface{}{ + { + "range": map[string]interface{}{ + "available_int": map[string]interface{}{ + "lt": 1, + }, + }, + }, + }, + }, + }) + } + + return filters +} + +// buildSkillFilterClauses builds ES filter clauses for skill index +// Skill index uses 'status' field instead of 'available_int' +func buildSkillFilterClauses() []map[string]interface{} { + // Filter for active skills (status = "1") + return []map[string]interface{}{ + { + "term": map[string]interface{}{ + "status": "1", + }, + }, + } +} + +// buildFilterFromMap converts a generic filter map to ES filter clauses +func buildFilterFromMap(filter map[string]interface{}) []map[string]interface{} { + var filters []map[string]interface{} + for field, value := range filter { + switch v := value.(type) { + case []string: + filters = append(filters, map[string]interface{}{ + "terms": map[string]interface{}{field: v}, + }) + case []interface{}: + filters = append(filters, map[string]interface{}{ + "terms": map[string]interface{}{field: v}, + }) + default: + filters = append(filters, map[string]interface{}{ + "term": map[string]interface{}{field: v}, + }) + } + } + return filters +} + +// buildESKeywordQuery builds keyword-only search query for ES +// Uses query_string if matchText is in query_string format, otherwise uses multi_match +// boost is applied to the text match clause (query_string or multi_match) +func buildESKeywordQuery(matchText string, filterClauses []map[string]interface{}, boost float64) map[string]interface{} { + var mustClause map[string]interface{} + + // Handle wildcard query (match all) + if matchText == "*" || matchText == "" { + mustClause = map[string]interface{}{ + "match_all": map[string]interface{}{}, + } + } else { + // Use query_string for complex queries + queryString := map[string]interface{}{ + "query": matchText, + "fields": []string{"title_tks^10", "title_sm_tks^5", "important_kwd^30", "important_tks^20", "question_tks^20", "content_ltks^2", "content_sm_ltks"}, + "type": "best_fields", + "minimum_should_match": "30%", + "boost": boost, + } + mustClause = map[string]interface{}{ + "query_string": queryString, + } + } + + return map[string]interface{}{ + "bool": map[string]interface{}{ + "must": mustClause, + "filter": filterClauses, + }, + } +} + +// buildSkillKeywordQuery builds keyword-only search query for skill index +// Skill index uses different field names: name_tks, tags_tks, description_tks, content_tks +func buildSkillKeywordQuery(matchText string, filterClauses []map[string]interface{}, boost float64) map[string]interface{} { + var mustClause map[string]interface{} + + // Handle wildcard query (match all) + if matchText == "*" || matchText == "" { + mustClause = map[string]interface{}{ + "match_all": map[string]interface{}{}, + } + } else { + // Use query_string for complex queries with skill-specific fields + queryString := map[string]interface{}{ + "query": matchText, + "fields": []string{"name_tks^10", "tags_tks^5", "description_tks^3", "content_tks^1"}, + "type": "best_fields", + "minimum_should_match": "30%", + "boost": boost, + } + mustClause = map[string]interface{}{ + "query_string": queryString, + } + } + + return map[string]interface{}{ + "bool": map[string]interface{}{ + "must": mustClause, + "filter": filterClauses, + }, + } +} + +// convertESResponse converts ES SearchResponse to unified chunks format +func convertESResponse(esResp *SearchResponse, vectorFieldName string) []map[string]interface{} { + if esResp == nil || esResp.Hits.Hits == nil { + return []map[string]interface{}{} + } + + chunks := make([]map[string]interface{}, len(esResp.Hits.Hits)) + for i, hit := range esResp.Hits.Hits { + chunks[i] = hit.Source + chunks[i]["_score"] = hit.Score + chunks[i]["_id"] = hit.ID + } + return chunks +} + +// parseOrderByExpr parses the OrderBy expression into ES sort format +func parseOrderByExpr(orderBy *types.OrderByExpr) []map[string]interface{} { + if orderBy == nil || len(orderBy.Fields) == 0 { + return nil + } + + var result []map[string]interface{} + for _, field := range orderBy.Fields { + direction := "asc" + if field.Type == types.SortDesc { + direction = "desc" + } + + if field.Field == "_score" || field.Field == "score" { + result = append(result, map[string]interface{}{ + "_score": direction, + }) + } else { + result = append(result, map[string]interface{}{ + field.Field: direction, + }) + } + } + + return result +} + +// Helper query builder functions (legacy) + +// BuildMatchTextQuery builds a text match query +func BuildMatchTextQuery(fields []string, text string, fuzziness string) map[string]interface{} { + query := map[string]interface{}{ + "multi_match": map[string]interface{}{ + "query": text, + "fields": fields, + }, + } + + if fuzziness != "" { + if multiMatch, ok := query["multi_match"].(map[string]interface{}); ok { + multiMatch["fuzziness"] = fuzziness + } + } + + return query +} + +// BuildTermQuery builds a term query +func BuildTermQuery(field string, value interface{}) map[string]interface{} { + return map[string]interface{}{ + "term": map[string]interface{}{ + field: value, + }, + } +} + +// BuildRangeQuery builds a range query +func BuildRangeQuery(field string, from, to interface{}) map[string]interface{} { + rangeQuery := make(map[string]interface{}) + if from != nil { + rangeQuery["gte"] = from + } + if to != nil { + rangeQuery["lte"] = to + } + + return map[string]interface{}{ + "range": map[string]interface{}{ + field: rangeQuery, + }, + } +} + +// BuildBoolQuery builds a bool query +func BuildBoolQuery() map[string]interface{} { + return map[string]interface{}{ + "bool": make(map[string]interface{}), + } +} + +// AddMust adds must clause to bool query +func AddMust(query map[string]interface{}, clauses ...map[string]interface{}) { + if boolQuery, ok := query["bool"].(map[string]interface{}); ok { + if _, exists := boolQuery["must"]; !exists { + boolQuery["must"] = []map[string]interface{}{} + } + if must, ok := boolQuery["must"].([]map[string]interface{}); ok { + boolQuery["must"] = append(must, clauses...) + } + } +} + +// AddShould adds should clause to bool query +func AddShould(query map[string]interface{}, clauses ...map[string]interface{}) { + if boolQuery, ok := query["bool"].(map[string]interface{}); ok { + if _, exists := boolQuery["should"]; !exists { + boolQuery["should"] = []map[string]interface{}{} + } + if should, ok := boolQuery["should"].([]map[string]interface{}); ok { + boolQuery["should"] = append(should, clauses...) + } + } +} + +// AddFilter adds filter clause to bool query +func AddFilter(query map[string]interface{}, clauses ...map[string]interface{}) { + if boolQuery, ok := query["bool"].(map[string]interface{}); ok { + if _, exists := boolQuery["filter"]; !exists { + boolQuery["filter"] = []map[string]interface{}{} + } + if filter, ok := boolQuery["filter"].([]map[string]interface{}); ok { + boolQuery["filter"] = append(filter, clauses...) + } + } +} + +// AddMustNot adds must_not clause to bool query +func AddMustNot(query map[string]interface{}, clauses ...map[string]interface{}) { + if boolQuery, ok := query["bool"].(map[string]interface{}); ok { + if _, exists := boolQuery["must_not"]; !exists { + boolQuery["must_not"] = []map[string]interface{}{} + } + if mustNot, ok := boolQuery["must_not"].([]map[string]interface{}); ok { + boolQuery["must_not"] = append(mustNot, clauses...) + } + } +} + +// GetDocIDs is not implemented for Elasticsearch +func (e *elasticsearchEngine) GetDocIDs(chunks []map[string]interface{}) []string { + common.Warn("GetDocIDs not implemented for Elasticsearch") + return nil +} \ No newline at end of file diff --git a/internal/engine/elasticsearch/common.go b/internal/engine/elasticsearch/common.go new file mode 100644 index 00000000000..e4bf5e1bed5 --- /dev/null +++ b/internal/engine/elasticsearch/common.go @@ -0,0 +1,98 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package elasticsearch + +import ( + "context" + "fmt" + "io" + + "github.com/elastic/go-elasticsearch/v8/esapi" +) + +// dropIndex deletes an index +func (e *elasticsearchEngine) dropIndex(ctx context.Context, indexName string) error { + if indexName == "" { + return fmt.Errorf("index name cannot be empty") + } + + // Check if index exists + exists, err := e.indexExists(ctx, indexName) + if err != nil { + return fmt.Errorf("failed to check index existence: %w", err) + } + if !exists { + return fmt.Errorf("index '%s' does not exist", indexName) + } + + // Delete index + req := esapi.IndicesDeleteRequest{ + Index: []string{indexName}, + } + + res, err := req.Do(ctx, e.client) + if err != nil { + return fmt.Errorf("failed to delete index: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + bodyBytes, _ := io.ReadAll(res.Body) + reason := extractErrorReason(bodyBytes) + if reason != "" { + return fmt.Errorf("elasticsearch error: %s", reason) + } + return fmt.Errorf("elasticsearch returned error: %s", res.Status()) + } + + return nil +} + +// indexExists checks if index exists +func (e *elasticsearchEngine) indexExists(ctx context.Context, indexName string) (bool, error) { + if indexName == "" { + return false, fmt.Errorf("index name cannot be empty") + } + + req := esapi.IndicesExistsRequest{ + Index: []string{indexName}, + } + + res, err := req.Do(ctx, e.client) + if err != nil { + return false, fmt.Errorf("failed to check index existence: %w", err) + } + defer res.Body.Close() + + if res.StatusCode == 200 { + return true, nil + } else if res.StatusCode == 404 { + return false, nil + } + + bodyBytes, _ := io.ReadAll(res.Body) + reason := extractErrorReason(bodyBytes) + if reason != "" { + return false, fmt.Errorf("elasticsearch error: %s", reason) + } + return false, fmt.Errorf("elasticsearch returned error: %s", res.Status()) +} + +// buildMetadataIndexName returns the metadata index name for a tenant +func buildMetadataIndexName(tenantID string) string { + return fmt.Sprintf("ragflow_doc_meta_%s", tenantID) +} \ No newline at end of file diff --git a/internal/engine/elasticsearch/get.go b/internal/engine/elasticsearch/get.go deleted file mode 100644 index 625bacdda70..00000000000 --- a/internal/engine/elasticsearch/get.go +++ /dev/null @@ -1,49 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package elasticsearch - -import ( - "context" - "fmt" - - "ragflow/internal/engine/types" -) - -// GetChunk gets a chunk by ID -func (e *elasticsearchEngine) GetChunk(ctx context.Context, indexName, chunkID string, kbIDs []string) (interface{}, error) { - // Build unified search request to get the chunk by ID - searchReq := &types.SearchRequest{ - IndexNames: []string{indexName}, - Limit: 1, - Offset: 0, - Filter: map[string]interface{}{ - "id": chunkID, - }, - } - - // Execute search - searchResp, err := e.Search(ctx, searchReq) - if err != nil { - return nil, fmt.Errorf("failed to search: %w", err) - } - - if len(searchResp.Chunks) == 0 { - return nil, nil - } - - return searchResp.Chunks[0], nil -} \ No newline at end of file diff --git a/internal/engine/elasticsearch/index.go b/internal/engine/elasticsearch/index.go deleted file mode 100644 index 7e601acae3f..00000000000 --- a/internal/engine/elasticsearch/index.go +++ /dev/null @@ -1,363 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package elasticsearch - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "os" - - "github.com/elastic/go-elasticsearch/v8/esapi" -) - -// CreateDataset creates an index -func (e *elasticsearchEngine) CreateDataset(ctx context.Context, indexName, datasetID string, vectorSize int, parserID string) error { - if indexName == "" { - return fmt.Errorf("index name cannot be empty") - } - - // Check if index already exists - exists, err := e.TableExists(ctx, indexName) - if err != nil { - return fmt.Errorf("failed to check index existence: %w", err) - } - if exists { - return fmt.Errorf("index '%s' already exists", indexName) - } - - // Load mapping based on index type - var mapping map[string]interface{} - if datasetID == "skill" { - // Load skill-specific mapping - skillMapping, err := loadSkillMapping() - if err != nil { - return fmt.Errorf("failed to load skill mapping: %w", err) - } - mapping = skillMapping - } else { - // Default mapping for dataset - mapping = map[string]interface{}{ - "settings": map[string]interface{}{ - "number_of_shards": 1, - "number_of_replicas": 0, - }, - } - } - - // Prepare request body - var body io.Reader - if mapping != nil { - data, err := json.Marshal(mapping) - if err != nil { - return fmt.Errorf("failed to marshal mapping: %w", err) - } - body = bytes.NewReader(data) - } - - // Create index - req := esapi.IndicesCreateRequest{ - Index: indexName, - Body: body, - } - - res, err := req.Do(ctx, e.client) - if err != nil { - return fmt.Errorf("failed to create index: %w", err) - } - defer res.Body.Close() - - if res.IsError() { - bodyBytes, _ := io.ReadAll(res.Body) - reason := extractErrorReason(bodyBytes) - if reason != "" { - return fmt.Errorf("elasticsearch error: %s", reason) - } - return fmt.Errorf("elasticsearch returned error: %s, body: %s", res.Status(), string(bodyBytes)) - } - - // Parse response - var result map[string]interface{} - if err := json.NewDecoder(res.Body).Decode(&result); err != nil { - return fmt.Errorf("failed to parse response: %w", err) - } - - acknowledged, ok := result["acknowledged"].(bool) - if !ok || !acknowledged { - return fmt.Errorf("index creation not acknowledged") - } - - return nil -} - -// loadSkillMapping loads the skill index mapping from config file -func loadSkillMapping() (map[string]interface{}, error) { - // Try multiple possible locations for the mapping file - possiblePaths := []string{ - "conf/skill_es_mapping.json", - "../conf/skill_es_mapping.json", - "/app/conf/skill_es_mapping.json", - } - - var data []byte - var err error - for _, path := range possiblePaths { - data, err = os.ReadFile(path) - if err == nil { - break - } - } - - if err != nil { - // Fallback to default skill mapping if file not found - return getDefaultSkillMapping(), nil - } - - var mapping map[string]interface{} - if err := json.Unmarshal(data, &mapping); err != nil { - return nil, fmt.Errorf("failed to parse skill mapping: %w", err) - } - - return mapping, nil -} - -// getDefaultSkillMapping returns the default skill index mapping -func getDefaultSkillMapping() map[string]interface{} { - return map[string]interface{}{ - "settings": map[string]interface{}{ - "index": map[string]interface{}{ - "number_of_shards": 1, - "number_of_replicas": 0, - "refresh_interval": "1000ms", - }, - }, - "mappings": map[string]interface{}{ - "dynamic": false, - "properties": map[string]interface{}{ - "skill_id": map[string]interface{}{ - "type": "keyword", - "store": true, - }, - "name": map[string]interface{}{ - "type": "text", - "index": false, - "store": true, - }, - "name_tks": map[string]interface{}{ - "type": "text", - "analyzer": "whitespace", - "store": true, - }, - "tags": map[string]interface{}{ - "type": "text", - "index": false, - "store": true, - }, - "tags_tks": map[string]interface{}{ - "type": "text", - "analyzer": "whitespace", - "store": true, - }, - "description": map[string]interface{}{ - "type": "text", - "index": false, - "store": true, - }, - "description_tks": map[string]interface{}{ - "type": "text", - "analyzer": "whitespace", - "store": true, - }, - "content": map[string]interface{}{ - "type": "text", - "index": false, - "store": true, - }, - "content_tks": map[string]interface{}{ - "type": "text", - "analyzer": "whitespace", - "store": true, - }, - "q_3072_vec": map[string]interface{}{ - "type": "dense_vector", - "dims": 3072, - "index": true, - "similarity": "cosine", - }, - "q_2560_vec": map[string]interface{}{ - "type": "dense_vector", - "dims": 2560, - "index": true, - "similarity": "cosine", - }, - "q_1536_vec": map[string]interface{}{ - "type": "dense_vector", - "dims": 1536, - "index": true, - "similarity": "cosine", - }, - "q_1024_vec": map[string]interface{}{ - "type": "dense_vector", - "dims": 1024, - "index": true, - "similarity": "cosine", - }, - "q_768_vec": map[string]interface{}{ - "type": "dense_vector", - "dims": 768, - "index": true, - "similarity": "cosine", - }, - "q_512_vec": map[string]interface{}{ - "type": "dense_vector", - "dims": 512, - "index": true, - "similarity": "cosine", - }, - "q_256_vec": map[string]interface{}{ - "type": "dense_vector", - "dims": 256, - "index": true, - "similarity": "cosine", - }, - "version": map[string]interface{}{ - "type": "keyword", - "store": true, - }, - "status": map[string]interface{}{ - "type": "keyword", - "store": true, - }, - "create_time": map[string]interface{}{ - "type": "long", - "store": true, - }, - "update_time": map[string]interface{}{ - "type": "long", - "store": true, - }, - }, - }, - } -} - -// DropTable deletes an index -func (e *elasticsearchEngine) DropTable(ctx context.Context, indexName string) error { - if indexName == "" { - return fmt.Errorf("index name cannot be empty") - } - - // Check if index exists - exists, err := e.TableExists(ctx, indexName) - if err != nil { - return fmt.Errorf("failed to check index existence: %w", err) - } - if !exists { - return fmt.Errorf("index '%s' does not exist", indexName) - } - - // Delete index - req := esapi.IndicesDeleteRequest{ - Index: []string{indexName}, - } - - res, err := req.Do(ctx, e.client) - if err != nil { - return fmt.Errorf("failed to delete index: %w", err) - } - defer res.Body.Close() - - if res.IsError() { - bodyBytes, _ := io.ReadAll(res.Body) - reason := extractErrorReason(bodyBytes) - if reason != "" { - return fmt.Errorf("elasticsearch error: %s", reason) - } - return fmt.Errorf("elasticsearch returned error: %s", res.Status()) - } - - return nil -} - -// TableExists checks if index exists -func (e *elasticsearchEngine) TableExists(ctx context.Context, indexName string) (bool, error) { - if indexName == "" { - return false, fmt.Errorf("index name cannot be empty") - } - - req := esapi.IndicesExistsRequest{ - Index: []string{indexName}, - } - - res, err := req.Do(ctx, e.client) - if err != nil { - return false, fmt.Errorf("failed to check index existence: %w", err) - } - defer res.Body.Close() - - if res.StatusCode == 200 { - return true, nil - } else if res.StatusCode == 404 { - return false, nil - } - - bodyBytes, _ := io.ReadAll(res.Body) - reason := extractErrorReason(bodyBytes) - if reason != "" { - return false, fmt.Errorf("elasticsearch error: %s", reason) - } - return false, fmt.Errorf("elasticsearch returned error: %s", res.Status()) -} - -// CreateMetadata creates the document metadata index -func (e *elasticsearchEngine) CreateMetadata(ctx context.Context, indexName string) error { - // TODO - return nil -} - -// InsertDataset inserts documents into a dataset index -func (e *elasticsearchEngine) InsertDataset(ctx context.Context, documents []map[string]interface{}, indexName string, knowledgebaseID string) ([]string, error) { - // TODO - return []string{}, nil -} - -// InsertMetadata inserts documents into tenant's metadata index -func (e *elasticsearchEngine) InsertMetadata(ctx context.Context, documents []map[string]interface{}, tenantID string) ([]string, error) { - // TODO - return []string{}, nil -} - - -// UpdateDataset updates a chunk by condition -func (e *elasticsearchEngine) UpdateDataset(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, tableNamePrefix string, knowledgebaseID string) error { - // TODO - return nil -} - -// UpdateMetadata updates document metadata in tenant's metadata index -func (e *elasticsearchEngine) UpdateMetadata(ctx context.Context, docID string, kbID string, metaFields map[string]interface{}, tenantID string) error { - // TODO - return nil -} - -// Delete deletes rows from either a dataset index or metadata index -func (e *elasticsearchEngine) Delete(ctx context.Context, condition map[string]interface{}, indexName string, datasetID string) (int64, error) { - // TODO - return 0, nil -} diff --git a/internal/engine/elasticsearch/metadata.go b/internal/engine/elasticsearch/metadata.go new file mode 100644 index 00000000000..27270868082 --- /dev/null +++ b/internal/engine/elasticsearch/metadata.go @@ -0,0 +1,275 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package elasticsearch + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/elastic/go-elasticsearch/v8/esapi" + "ragflow/internal/common" + + "go.uber.org/zap" +) + +// CreateMetadataStore creates the document metadata index +func (e *elasticsearchEngine) CreateMetadataStore(ctx context.Context, tenantID string) error { + indexName := buildMetadataIndexName(tenantID) + req := esapi.IndicesCreateRequest{ + Index: indexName, + } + res, err := req.Do(ctx, e.client) + if err != nil { + return fmt.Errorf("failed to create metadata index: %w", err) + } + defer res.Body.Close() + if res.IsError() { + return fmt.Errorf("elasticsearch returned error: %s", res.Status()) + } + return nil +} + +// InsertMetadata inserts documents into tenant's metadata index +func (e *elasticsearchEngine) InsertMetadata(ctx context.Context, metadata []map[string]interface{}, tenantID string) ([]string, error) { + indexName := buildMetadataIndexName(tenantID) + common.Info("Inserting metadata into Elasticsearch index", zap.String("index_name", indexName), zap.String("tenant_id", tenantID), zap.Int("doc_count", len(metadata))) + + if len(metadata) == 0 { + return []string{}, nil + } + + if indexName == "" { + return nil, fmt.Errorf("index name cannot be empty") + } + + // Check if index exists, create if not + exists, err := e.indexExists(ctx, indexName) + if err != nil { + common.Error("Failed to check index existence", err) + return nil, fmt.Errorf("failed to check index existence: %w", err) + } + if !exists { + // Create metadata index + if createErr := e.CreateMetadataStore(ctx, tenantID); createErr != nil { + return nil, fmt.Errorf("failed to create metadata index: %w", createErr) + } + } + + // Build bulk request body + var buf bytes.Buffer + for _, doc := range metadata { + // Action line - index operation + action := map[string]interface{}{ + "index": map[string]interface{}{ + "_index": indexName, + }, + } + actionBytes, err := json.Marshal(action) + if err != nil { + common.Error("Failed to marshal bulk action", err) + return nil, fmt.Errorf("failed to marshal bulk action: %w", err) + } + buf.Write(actionBytes) + buf.WriteByte('\n') + + // Document line - meta_fields is stored as-is (ES can handle nested objects) + docBytes, err := json.Marshal(doc) + if err != nil { + common.Error("Failed to marshal document", err) + return nil, fmt.Errorf("failed to marshal document: %w", err) + } + buf.Write(docBytes) + buf.WriteByte('\n') + } + + // Execute bulk request + req := esapi.BulkRequest{ + Body: bytes.NewReader(buf.Bytes()), + Refresh: "false", + } + + res, err := req.Do(ctx, e.client) + if err != nil { + common.Error("Failed to execute bulk request", err) + return nil, fmt.Errorf("failed to execute bulk request: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + common.Sugar.Errorw("Elasticsearch bulk request returned error", "status", res.Status()) + return nil, fmt.Errorf("elasticsearch bulk request returned error: %s", res.Status()) + } + + // Parse bulk response to check for errors + var bulkResponse map[string]interface{} + if err := json.NewDecoder(res.Body).Decode(&bulkResponse); err != nil { + common.Error("Failed to parse bulk response", err) + return nil, fmt.Errorf("failed to parse bulk response: %w", err) + } + + // Check for errors in bulk response + if errors, ok := bulkResponse["errors"].(bool); ok && errors { + common.Warn("Bulk request had some errors") + } + + common.Info("Successfully inserted metadata into Elasticsearch index", zap.String("index_name", indexName), zap.Int("doc_count", len(metadata))) + return []string{}, nil +} + +// UpdateMetadata updates document metadata in tenant's metadata index +func (e *elasticsearchEngine) UpdateMetadata(ctx context.Context, docID string, datasetID string, metaFields map[string]interface{}, tenantID string) error { + indexName := buildMetadataIndexName(tenantID) + common.Info("Updating metadata in Elasticsearch index", zap.String("index_name", indexName), zap.String("docID", docID), zap.String("datasetID", datasetID)) + + // Check if index exists + exists, err := e.indexExists(ctx, indexName) + if err != nil { + return fmt.Errorf("failed to check index existence: %w", err) + } + if !exists { + return fmt.Errorf("index '%s' does not exist", indexName) + } + + // Build the document ID for update + docID = strings.ReplaceAll(docID, "'", "''") + datasetIDStr := strings.ReplaceAll(datasetID, "'", "''") + + // Build update body - merge meta_fields with existing + query := map[string]interface{}{ + "bool": map[string]interface{}{ + "must": []map[string]interface{}{ + {"term": map[string]interface{}{"id": docID}}, + {"term": map[string]interface{}{"kb_id": datasetIDStr}}, + }, + }, + } + + updateReq := map[string]interface{}{ + "query": query, + "script": map[string]interface{}{ + "source": "ctx._source.meta_fields = params.meta_fields", + "params": map[string]interface{}{ + "meta_fields": metaFields, + }, + }, + } + + updateBytes, err := json.Marshal(updateReq) + if err != nil { + return fmt.Errorf("failed to marshal update request: %w", err) + } + + req := esapi.UpdateByQueryRequest{ + Index: []string{indexName}, + Body: bytes.NewReader(updateBytes), + } + + res, err := req.Do(ctx, e.client) + if err != nil { + common.Error("Failed to execute update by query", err) + return fmt.Errorf("failed to execute update by query: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + common.Sugar.Errorw("Elasticsearch update by query returned error", "status", res.Status()) + return fmt.Errorf("elasticsearch update by query returned error: %s", res.Status()) + } + + common.Info("Successfully updated metadata in Elasticsearch index", zap.String("index_name", indexName), zap.String("docID", docID)) + return nil +} + +// DeleteMetadata deletes metadata from tenant's metadata index by condition +func (e *elasticsearchEngine) DeleteMetadata(ctx context.Context, condition map[string]interface{}, tenantID string) (int64, error) { + indexName := buildMetadataIndexName(tenantID) + common.Info("Deleting metadata from Elasticsearch index", zap.String("index_name", indexName), zap.Any("condition", condition)) + + // Check if index exists + exists, err := e.indexExists(ctx, indexName) + if err != nil { + return 0, fmt.Errorf("failed to check index existence: %w", err) + } + if !exists { + common.Warn(fmt.Sprintf("Index %s does not exist, skipping delete", indexName)) + return 0, nil + } + + // Build query from condition + query := e.buildMetadataQueryFromCondition(condition) + if query == nil { + query = map[string]interface{}{"match_all": map[string]interface{}{}} + } + + // Build delete by query body + deleteBody := map[string]interface{}{ + "query": query, + } + + bodyBytes, err := json.Marshal(deleteBody) + if err != nil { + return 0, fmt.Errorf("failed to marshal delete body: %w", err) + } + + // Execute delete by query + req := esapi.DeleteByQueryRequest{ + Index: []string{indexName}, + Body: bytes.NewReader(bodyBytes), + } + + res, err := req.Do(ctx, e.client) + if err != nil { + common.Error("Failed to execute delete by query", err) + return 0, fmt.Errorf("failed to execute delete by query: %w", err) + } + defer res.Body.Close() + + if res.IsError() { + common.Sugar.Errorw("Elasticsearch delete by query returned error", "status", res.Status()) + return 0, fmt.Errorf("elasticsearch delete by query returned error: %s", res.Status()) + } + + // Parse response + var result map[string]interface{} + if err := json.NewDecoder(res.Body).Decode(&result); err != nil { + common.Error("Failed to parse delete response", err) + return 0, fmt.Errorf("failed to parse delete response: %w", err) + } + + deleted := int64(0) + if d, ok := result["deleted"].(float64); ok { + deleted = int64(d) + } + + common.Info("Successfully deleted metadata", zap.String("index_name", indexName), zap.Int64("deleted_count", deleted)) + return deleted, nil +} + +// DropMetadataStore drops a metadata index from Elasticsearch +func (e *elasticsearchEngine) DropMetadataStore(ctx context.Context, tenantID string) error { + indexName := buildMetadataIndexName(tenantID) + return e.dropIndex(ctx, indexName) +} + +// MetadataStoreExists checks if a metadata index exists in Elasticsearch +func (e *elasticsearchEngine) MetadataStoreExists(ctx context.Context, tenantID string) (bool, error) { + indexName := buildMetadataIndexName(tenantID) + return e.indexExists(ctx, indexName) +} \ No newline at end of file diff --git a/internal/engine/elasticsearch/search.go b/internal/engine/elasticsearch/search.go deleted file mode 100644 index b3c68fbc11b..00000000000 --- a/internal/engine/elasticsearch/search.go +++ /dev/null @@ -1,583 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package elasticsearch - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io" - "ragflow/internal/common" - "strings" - - "github.com/elastic/go-elasticsearch/v8/esapi" - "go.uber.org/zap" - - "ragflow/internal/engine/types" -) - -// SearchResponse Elasticsearch search response -type SearchResponse struct { - Hits struct { - Total struct { - Value int64 `json:"value"` - } `json:"total"` - Hits []struct { - ID string `json:"_id"` - Score float64 `json:"_score"` - Source map[string]interface{} `json:"_source"` - } `json:"hits"` - } `json:"hits"` - Aggregations map[string]interface{} `json:"aggregations"` -} - -// Search executes search with unified types.SearchRequest -func (e *elasticsearchEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { - return e.searchUnified(ctx, req) -} - -// searchUnified handles the unified types.SearchRequest -func (e *elasticsearchEngine) searchUnified(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { - if len(req.IndexNames) == 0 { - return nil, fmt.Errorf("index names cannot be empty") - } - - // Build pagination parameters - offset := req.Offset - limit := req.Limit - if limit <= 0 { - limit = 30 // default ES size - } - - // Check if this is a skill index - isSkillIndex := len(req.IndexNames) > 0 && strings.HasPrefix(req.IndexNames[0], "skill_") - - // Build filter clauses - var filterClauses []map[string]interface{} - if isSkillIndex { - filterClauses = buildSkillFilterClauses() - } else { - filterClauses = buildFilterClauses(req.KbIDs, 1) - } - - // Add filters from req.Filter - if req.Filter != nil && len(req.Filter) > 0 { - filterClauses = append(filterClauses, buildFilterFromMap(req.Filter)...) - } - - // Build search query body - queryBody := make(map[string]interface{}) - - // Determine search type from MatchExprs - var matchText string - var matchDense *types.MatchDenseExpr - var hasVectorMatch bool - - for _, expr := range req.MatchExprs { - if expr == nil { - continue - } - switch e := expr.(type) { - case string: - matchText = e - case *types.MatchTextExpr: - matchText = e.MatchingText - case *types.MatchDenseExpr: - hasVectorMatch = true - matchDense = e - } - } - - var vectorFieldName string - if !hasVectorMatch || matchDense == nil { - // Keyword-only search - if isSkillIndex { - queryBody["query"] = buildSkillKeywordQuery(matchText, filterClauses, 1.0) - } else { - queryBody["query"] = buildESKeywordQuery(matchText, filterClauses, 1.0) - } - } else { - // Hybrid search: keyword + vector - textWeight := 0.7 // default: vector weight = 0.3 - vectorWeight := 0.3 - if matchDense.ExtraOptions != nil { - if vw, ok := matchDense.ExtraOptions["text_weight"].(float64); ok { - textWeight = vw - } - if vw, ok := matchDense.ExtraOptions["vector_weight"].(float64); ok { - vectorWeight = vw - } - } - - // Build boolean query for text match and filters - var boolQuery map[string]interface{} - if isSkillIndex { - boolQuery = buildSkillKeywordQuery(matchText, filterClauses, 1.0) - } else { - boolQuery = buildESKeywordQuery(matchText, filterClauses, 1.0) - } - // Add boost to the bool query (as in Python code) - if boolMap, ok := boolQuery["bool"].(map[string]interface{}); ok { - boolMap["boost"] = textWeight - } - - // Build kNN query - vectorData := matchDense.EmbeddingData - vectorFieldName = matchDense.VectorColumnName - k := matchDense.TopN - if k <= 0 { - k = req.Limit - } - if k <= 0 { - k = 1024 - } - numCandidates := k * 2 - - similarity := 0.0 - if matchDense.ExtraOptions != nil { - if sim, ok := matchDense.ExtraOptions["similarity"].(float64); ok { - similarity = sim - } - } - - knnQuery := map[string]interface{}{ - "field": vectorFieldName, - "query_vector": vectorData, - "k": k, - "num_candidates": numCandidates, - "similarity": similarity, - "boost": vectorWeight, - } - - queryBody["knn"] = knnQuery - queryBody["query"] = boolQuery - - // Add vector column to Source fields (matching Python ES: src.append(f"q_{len(q_vec)}_vec")) - // Only modify Source if it was explicitly set by the caller - if vectorFieldName != "" && len(req.SelectFields) > 0 { - sourceFields := req.SelectFields - found := false - for _, f := range sourceFields { - if f == vectorFieldName { - found = true - break - } - } - if !found { - sourceFields = append(sourceFields, vectorFieldName) - } - req.SelectFields = sourceFields - } - } - - queryBody["size"] = limit - queryBody["from"] = offset - - // Add sorting if specified - if req.OrderBy != nil { - sort := parseOrderByExpr(req.OrderBy) - if len(sort) > 0 { - queryBody["sort"] = sort - } - } - - // Serialize query - var buf bytes.Buffer - if err := json.NewEncoder(&buf).Encode(queryBody); err != nil { - return nil, fmt.Errorf("error encoding query: %w", err) - } - - // Log search details - common.Debug("Elasticsearch searching indices", zap.Strings("indices", req.IndexNames)) - common.Debug("Elasticsearch DSL", zap.Any("dsl", queryBody)) - - // Build search request - reqES := esapi.SearchRequest{ - Index: req.IndexNames, - Body: &buf, - } - - // Execute search - res, err := reqES.Do(ctx, e.client) - if err != nil { - return nil, fmt.Errorf("search failed: %w", err) - } - defer res.Body.Close() - - if res.IsError() { - bodyBytes, err := io.ReadAll(res.Body) - if err != nil { - common.Error("Elasticsearch failed to read error response body", err) - } else { - common.Warn("Elasticsearch error response", zap.String("body", string(bodyBytes))) - } - return nil, fmt.Errorf("Elasticsearch returned error: %s", res.Status()) - } - - // Parse response - var esResp SearchResponse - if err := json.NewDecoder(res.Body).Decode(&esResp); err != nil { - return nil, fmt.Errorf("error parsing response: %w", err) - } - - // Convert to unified response - chunks := convertESResponse(&esResp, vectorFieldName) - return &types.SearchResult{ - Chunks: chunks, - Total: esResp.Hits.Total.Value, - }, nil -} - -// calculatePagination calculates offset and limit based on page, size and topK -func calculatePagination(page, size, topK int) (int, int) { - if page < 1 { - page = 1 - } - if size <= 0 { - size = 30 - } - if topK <= 0 { - topK = 1024 - } - - RERANK_LIMIT := max(30, (64/size)*size) - if RERANK_LIMIT < size { - RERANK_LIMIT = size - } - if RERANK_LIMIT > topK { - RERANK_LIMIT = topK - } - - offset := (page - 1) * RERANK_LIMIT - if offset < 0 { - offset = 0 - } - - return offset, RERANK_LIMIT -} - -// buildFilterClauses builds ES filter clauses from kb_ids and available_int -// Reference: rag/utils/es_conn.py L60-L78 -// When available=0: available_int < 1 -// When available!=0: NOT (available_int < 1) -func buildFilterClauses(kbIDs []string, available int) []map[string]interface{} { - var filters []map[string]interface{} - - if len(kbIDs) > 0 { - filters = append(filters, map[string]interface{}{ - "terms": map[string]interface{}{"kb_id": kbIDs}, - }) - } - - // Add available_int filter - // Reference: rag/utils/es_conn.py L63-L68 - if available == 0 { - // available_int < 1 - filters = append(filters, map[string]interface{}{ - "range": map[string]interface{}{ - "available_int": map[string]interface{}{ - "lt": 1, - }, - }, - }) - } else { - // must_not: available_int < 1 (i.e., available_int >= 1) - filters = append(filters, map[string]interface{}{ - "bool": map[string]interface{}{ - "must_not": []map[string]interface{}{ - { - "range": map[string]interface{}{ - "available_int": map[string]interface{}{ - "lt": 1, - }, - }, - }, - }, - }, - }) - } - - return filters -} - -// buildSkillFilterClauses builds ES filter clauses for skill index -// Skill index uses 'status' field instead of 'available_int' -func buildSkillFilterClauses() []map[string]interface{} { - // Filter for active skills (status = "1") - return []map[string]interface{}{ - { - "term": map[string]interface{}{ - "status": "1", - }, - }, - } -} - -// buildFilterFromMap converts a generic filter map to ES filter clauses -func buildFilterFromMap(filter map[string]interface{}) []map[string]interface{} { - var filters []map[string]interface{} - for field, value := range filter { - switch v := value.(type) { - case []string: - filters = append(filters, map[string]interface{}{ - "terms": map[string]interface{}{field: v}, - }) - case []interface{}: - filters = append(filters, map[string]interface{}{ - "terms": map[string]interface{}{field: v}, - }) - default: - filters = append(filters, map[string]interface{}{ - "term": map[string]interface{}{field: v}, - }) - } - } - return filters -} - -// buildESKeywordQuery builds keyword-only search query for ES -// Uses query_string if matchText is in query_string format, otherwise uses multi_match -// boost is applied to the text match clause (query_string or multi_match) -func buildESKeywordQuery(matchText string, filterClauses []map[string]interface{}, boost float64) map[string]interface{} { - var mustClause map[string]interface{} - - // Handle wildcard query (match all) - if matchText == "*" || matchText == "" { - mustClause = map[string]interface{}{ - "match_all": map[string]interface{}{}, - } - } else { - // Use query_string for complex queries - queryString := map[string]interface{}{ - "query": matchText, - "fields": []string{"title_tks^10", "title_sm_tks^5", "important_kwd^30", "important_tks^20", "question_tks^20", "content_ltks^2", "content_sm_ltks"}, - "type": "best_fields", - "minimum_should_match": "30%", - "boost": boost, - } - mustClause = map[string]interface{}{ - "query_string": queryString, - } - } - - return map[string]interface{}{ - "bool": map[string]interface{}{ - "must": mustClause, - "filter": filterClauses, - }, - } -} - -// buildSkillKeywordQuery builds keyword-only search query for skill index -// Skill index uses different field names: name_tks, tags_tks, description_tks, content_tks -func buildSkillKeywordQuery(matchText string, filterClauses []map[string]interface{}, boost float64) map[string]interface{} { - var mustClause map[string]interface{} - - // Handle wildcard query (match all) - if matchText == "*" || matchText == "" { - mustClause = map[string]interface{}{ - "match_all": map[string]interface{}{}, - } - } else { - // Use query_string for complex queries with skill-specific fields - queryString := map[string]interface{}{ - "query": matchText, - "fields": []string{"name_tks^10", "tags_tks^5", "description_tks^3", "content_tks^1"}, - "type": "best_fields", - "minimum_should_match": "30%", - "boost": boost, - } - mustClause = map[string]interface{}{ - "query_string": queryString, - } - } - - return map[string]interface{}{ - "bool": map[string]interface{}{ - "must": mustClause, - "filter": filterClauses, - }, - } -} - -// convertESResponse converts ES SearchResponse to unified chunks format -func convertESResponse(esResp *SearchResponse, vectorFieldName string) []map[string]interface{} { - if esResp == nil || esResp.Hits.Hits == nil { - return []map[string]interface{}{} - } - - chunks := make([]map[string]interface{}, len(esResp.Hits.Hits)) - for i, hit := range esResp.Hits.Hits { - chunks[i] = hit.Source - chunks[i]["_score"] = hit.Score - chunks[i]["_id"] = hit.ID - } - return chunks -} - -// parseOrderByExpr parses the OrderBy expression into ES sort format -func parseOrderByExpr(orderBy *types.OrderByExpr) []map[string]interface{} { - if orderBy == nil || len(orderBy.Fields) == 0 { - return nil - } - - var result []map[string]interface{} - for _, field := range orderBy.Fields { - direction := "asc" - if field.Type == types.SortDesc { - direction = "desc" - } - - if field.Field == "_score" || field.Field == "score" { - result = append(result, map[string]interface{}{ - "_score": direction, - }) - } else { - result = append(result, map[string]interface{}{ - field.Field: direction, - }) - } - } - - return result -} - -// Helper query builder functions (legacy) - -// BuildMatchTextQuery builds a text match query -func BuildMatchTextQuery(fields []string, text string, fuzziness string) map[string]interface{} { - query := map[string]interface{}{ - "multi_match": map[string]interface{}{ - "query": text, - "fields": fields, - }, - } - - if fuzziness != "" { - if multiMatch, ok := query["multi_match"].(map[string]interface{}); ok { - multiMatch["fuzziness"] = fuzziness - } - } - - return query -} - -// BuildTermQuery builds a term query -func BuildTermQuery(field string, value interface{}) map[string]interface{} { - return map[string]interface{}{ - "term": map[string]interface{}{ - field: value, - }, - } -} - -// BuildRangeQuery builds a range query -func BuildRangeQuery(field string, from, to interface{}) map[string]interface{} { - rangeQuery := make(map[string]interface{}) - if from != nil { - rangeQuery["gte"] = from - } - if to != nil { - rangeQuery["lte"] = to - } - - return map[string]interface{}{ - "range": map[string]interface{}{ - field: rangeQuery, - }, - } -} - -// BuildBoolQuery builds a bool query -func BuildBoolQuery() map[string]interface{} { - return map[string]interface{}{ - "bool": make(map[string]interface{}), - } -} - -// AddMust adds must clause to bool query -func AddMust(query map[string]interface{}, clauses ...map[string]interface{}) { - if boolQuery, ok := query["bool"].(map[string]interface{}); ok { - if _, exists := boolQuery["must"]; !exists { - boolQuery["must"] = []map[string]interface{}{} - } - if must, ok := boolQuery["must"].([]map[string]interface{}); ok { - boolQuery["must"] = append(must, clauses...) - } - } -} - -// AddShould adds should clause to bool query -func AddShould(query map[string]interface{}, clauses ...map[string]interface{}) { - if boolQuery, ok := query["bool"].(map[string]interface{}); ok { - if _, exists := boolQuery["should"]; !exists { - boolQuery["should"] = []map[string]interface{}{} - } - if should, ok := boolQuery["should"].([]map[string]interface{}); ok { - boolQuery["should"] = append(should, clauses...) - } - } -} - -// AddFilter adds filter clause to bool query -func AddFilter(query map[string]interface{}, clauses ...map[string]interface{}) { - if boolQuery, ok := query["bool"].(map[string]interface{}); ok { - if _, exists := boolQuery["filter"]; !exists { - boolQuery["filter"] = []map[string]interface{}{} - } - if filter, ok := boolQuery["filter"].([]map[string]interface{}); ok { - boolQuery["filter"] = append(filter, clauses...) - } - } -} - -// AddMustNot adds must_not clause to bool query -func AddMustNot(query map[string]interface{}, clauses ...map[string]interface{}) { - if boolQuery, ok := query["bool"].(map[string]interface{}); ok { - if _, exists := boolQuery["must_not"]; !exists { - boolQuery["must_not"] = []map[string]interface{}{} - } - if mustNot, ok := boolQuery["must_not"].([]map[string]interface{}); ok { - boolQuery["must_not"] = append(mustNot, clauses...) - } - } -} - -// GetFields is not implemented for Elasticsearch -func (e *elasticsearchEngine) GetFields(chunks []map[string]interface{}, fields []string) map[string]map[string]interface{} { - common.Warn("GetFields not implemented for Elasticsearch") - return nil -} - -// GetAggregation is not implemented for Elasticsearch -func (e *elasticsearchEngine) GetAggregation(chunks []map[string]interface{}, fieldName string) []map[string]interface{} { - common.Warn("GetAggregation not implemented for Elasticsearch") - return nil -} - -// GetHighlight is not implemented for Elasticsearch -func (e *elasticsearchEngine) GetHighlight(chunks []map[string]interface{}, keywords []string, fieldName string) map[string]string { - common.Warn("GetHighlight not implemented for Elasticsearch") - return nil -} - -// GetDocIDs is not implemented for Elasticsearch -func (e *elasticsearchEngine) GetDocIDs(chunks []map[string]interface{}) []string { - common.Warn("GetDocIDs not implemented for Elasticsearch") - return nil -} diff --git a/internal/engine/engine.go b/internal/engine/engine.go index 19112d0dd46..a37b5beaf14 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -32,26 +32,23 @@ const ( // DocEngine document storage engine interface type DocEngine interface { - // Search - Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) - - // Dataset operations - CreateDataset(ctx context.Context, indexName, datasetID string, vectorSize int, parserID string) error - InsertDataset(ctx context.Context, documents []map[string]interface{}, indexName string, knowledgebaseID string) ([]string, error) - UpdateDataset(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, tableNamePrefix string, knowledgebaseID string) error - // Chunk operations - GetChunk(ctx context.Context, indexName, chunkID string, kbIDs []string) (interface{}, error) + CreateChunkStore(ctx context.Context, baseName, datasetID string, vectorSize int, parserID string) error + InsertChunks(ctx context.Context, chunks []map[string]interface{}, baseName string, datasetID string) ([]string, error) + UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, baseName string, datasetID string) error + DeleteChunks(ctx context.Context, condition map[string]interface{}, baseName string, datasetID string) (int64, error) + Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) + GetChunk(ctx context.Context, baseName, chunkID string, datasetIDs []string) (interface{}, error) + DropChunkStore(ctx context.Context, baseName, datasetID string) error + ChunkStoreExists(ctx context.Context, baseName, datasetID string) (bool, error) // Document metadata operations - CreateMetadata(ctx context.Context, indexName string) error - InsertMetadata(ctx context.Context, documents []map[string]interface{}, tenantID string) ([]string, error) - UpdateMetadata(ctx context.Context, docID string, kbID string, metaFields map[string]interface{}, tenantID string) error - - // Operations for both dataset and metadata tables - Delete(ctx context.Context, condition map[string]interface{}, indexName string, datasetID string) (int64, error) - DropTable(ctx context.Context, indexName string) error - TableExists(ctx context.Context, indexName string) (bool, error) + CreateMetadataStore(ctx context.Context, tenantID string) error + InsertMetadata(ctx context.Context, metadata []map[string]interface{}, tenantID string) ([]string, error) + UpdateMetadata(ctx context.Context, docID string, datasetID string, metaFields map[string]interface{}, tenantID string) error + DeleteMetadata(ctx context.Context, condition map[string]interface{}, tenantID string) (int64, error) + DropMetadataStore(ctx context.Context, tenantID string) error + MetadataStoreExists(ctx context.Context, tenantID string) (bool, error) // Document operations (used by skill indexing) IndexDocument(ctx context.Context, indexName, docID string, doc interface{}) error diff --git a/internal/engine/infinity/chunk.go b/internal/engine/infinity/chunk.go new file mode 100644 index 00000000000..2532fef5c5c --- /dev/null +++ b/internal/engine/infinity/chunk.go @@ -0,0 +1,2038 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package infinity + +import ( + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "ragflow/internal/common" + "ragflow/internal/engine/types" + "ragflow/internal/utility" + "regexp" + "slices" + "sort" + "strconv" + "strings" + "unicode" + + infinity "github.com/infiniflow/infinity-go-sdk" + "go.uber.org/zap" +) + +// CreateChunkStore creates a chunk table in Infinity +// baseName is the table name prefix (e.g., "ragflow_") +// The full table name is built as "{baseName}_{datasetID}" +// For skill index (datasetID="skill"), tableName is just baseName and uses skill_infinity_mapping.json +func (e *infinityEngine) CreateChunkStore(ctx context.Context, baseName, datasetID string, vectorSize int, parserID string) error { + vecSize := vectorSize + + // Determine table name and mapping file based on index type + var tableName string + var mappingFile string + + tableName = buildChunkTableName(baseName, datasetID) + if datasetID == "skill" { + mappingFile = "skill_infinity_mapping.json" + common.Info("Creating skill index table", zap.String("tableName", tableName), zap.String("mappingFile", mappingFile)) + } else { + mappingFile = e.mappingFileName + common.Info("Creating regular index table", zap.String("tableName", tableName), zap.String("mappingFile", mappingFile)) + } + + // Use configured schema + fpMapping := filepath.Join(utility.GetProjectRoot(), "conf", mappingFile) + + schemaData, err := os.ReadFile(fpMapping) + if err != nil { + return fmt.Errorf("Failed to read mapping file: %w", err) + } + + var schema orderedFields + if err := json.Unmarshal(schemaData, &schema); err != nil { + return fmt.Errorf("Failed to parse mapping file: %w", err) + } + + // Get database + db, err := e.client.conn.GetDatabase(e.client.dbName) + if err != nil { + return fmt.Errorf("Failed to get database: %w", err) + } + + // Determine vector column name + vectorColName := fmt.Sprintf("q_%d_vec", vecSize) + + // Check if table already exists + exists, err := e.tableExists(ctx, tableName) + if err != nil { + return fmt.Errorf("Failed to check if table exists: %w", err) + } + + var table *infinity.Table + if exists { + // Table exists, open it and check if vector column needs to be added + common.Info("Table already exists, checking for vector column", zap.String("tableName", tableName)) + table, err = db.GetTable(tableName) + if err != nil { + return fmt.Errorf("Failed to open existing table %s: %w", tableName, err) + } + + // Check if vector column exists (for embedding model changes) + colExists, err := e.columnExists(table, vectorColName) + if err != nil { + common.Warn("Failed to check column existence", zap.String("column", vectorColName), zap.Error(err)) + } + + // Add new vector column if it doesn't exist (handles embedding model change) + if !colExists { + common.Info("Adding new vector column for embedding model change", zap.String("column", vectorColName), zap.Int("size", vecSize)) + addColSchema := infinity.TableSchema{ + &infinity.ColumnDefinition{ + Name: vectorColName, + DataType: fmt.Sprintf("vector,%d,float", vecSize), + }, + } + if _, err := table.AddColumns(addColSchema); err != nil { + common.Error("Failed to add vector column "+vectorColName, err) + return fmt.Errorf("Failed to add vector column %s: %w", vectorColName, err) + } + common.Info("Successfully added vector column", zap.String("column", vectorColName)) + } + } else { + // Table doesn't exist, create it with vector column in the initial schema + common.Info(fmt.Sprintf("Creating table with vector column: %s with dimension %d", vectorColName, vecSize)) + + // Build column definitions (preserving JSON order) + var columns infinity.TableSchema + for _, fieldName := range schema.Keys { + fieldInfo := schema.Fields[fieldName] + col := infinity.ColumnDefinition{ + Name: fieldName, + DataType: fieldInfo.Type, + Default: fieldInfo.Default, + } + columns = append(columns, &col) + } + + // Add vector column + columns = append(columns, &infinity.ColumnDefinition{ + Name: vectorColName, + DataType: fmt.Sprintf("vector,%d,float", vecSize), + }) + + // Add chunk_data column for table parser + if parserID == "table" { + columns = append(columns, &infinity.ColumnDefinition{ + Name: "chunk_data", + DataType: "json", + Default: "{}", + }) + } + + // Create table + table, err = db.CreateTable(tableName, columns, infinity.ConflictTypeIgnore) + if err != nil { + return fmt.Errorf("Failed to create table: %w", err) + } + common.Debug("Infinity created table", zap.String("tableName", tableName)) + } + + // Create HNSW index on vector column with unique name based on vector size + // Use unique index name to avoid conflict when embedding model changes + vectorIndexName := fmt.Sprintf("q_%d_vec_idx", vecSize) + _, err = table.CreateIndex( + vectorIndexName, + infinity.NewIndexInfo(vectorColName, infinity.IndexTypeHnsw, map[string]string{ + "M": "16", + "ef_construction": "50", + "metric": "cosine", + "encode": "lvq", + }), + infinity.ConflictTypeIgnore, + "", + ) + if err != nil { + return fmt.Errorf("Failed to create HNSW index %s: %w", vectorIndexName, err) + } + common.Info("Created vector index", zap.String("indexName", vectorIndexName), zap.String("column", vectorColName)) + + // Create full-text indexes for varchar fields with analyzers + for _, fieldName := range schema.Keys { + fieldInfo := schema.Fields[fieldName] + if fieldInfo.Type != "varchar" || fieldInfo.Analyzer == nil { + continue + } + + analyzers := []string{} + switch a := fieldInfo.Analyzer.(type) { + case string: + analyzers = []string{a} + case []interface{}: + for _, v := range a { + if s, ok := v.(string); ok { + analyzers = append(analyzers, s) + } + } + } + + for _, analyzer := range analyzers { + indexNameFt := fmt.Sprintf("ft_%s_%s", + regexp.MustCompile(`[^a-zA-Z0-9]`).ReplaceAllString(fieldName, "_"), + regexp.MustCompile(`[^a-zA-Z0-9]`).ReplaceAllString(analyzer, "_"), + ) + _, err = table.CreateIndex( + indexNameFt, + infinity.NewIndexInfo(fieldName, infinity.IndexTypeFullText, map[string]string{"ANALYZER": analyzer}), + infinity.ConflictTypeIgnore, + "", + ) + if err != nil { + return fmt.Errorf("Failed to create fulltext index %s: %w", indexNameFt, err) + } + } + } + + // Create secondary indexes for fields with index_type + for _, fieldName := range schema.Keys { + fieldInfo := schema.Fields[fieldName] + if fieldInfo.IndexType == nil { + continue + } + + indexTypeStr := "" + params := map[string]string{} + + switch it := fieldInfo.IndexType.(type) { + case string: + indexTypeStr = it + case map[string]interface{}: + if t, ok := it["type"].(string); ok { + indexTypeStr = t + } + if card, ok := it["cardinality"].(string); ok { + params["cardinality"] = card + } + } + + if indexTypeStr == "secondary" { + indexNameSec := fmt.Sprintf("sec_%s", fieldName) + _, err = table.CreateIndex( + indexNameSec, + infinity.NewIndexInfo(fieldName, infinity.IndexTypeSecondary, params), + infinity.ConflictTypeIgnore, + "", + ) + if err != nil { + return fmt.Errorf("Failed to create secondary index %s: %w", indexNameSec, err) + } + } + } + + return nil +} + +// InsertChunks inserts documents into a dataset table +// Table name format: {baseName}_{datasetID} +// Auto-create the table if it doesn't exist +// Delete existing rows with matching IDs before insert +func (e *infinityEngine) InsertChunks(ctx context.Context, chunks []map[string]interface{}, baseName string, datasetID string) ([]string, error) { + tableName := buildChunkTableName(baseName, datasetID) + common.Info("InfinityConnection.InsertChunks called", zap.String("tableName", tableName), zap.Int("chunkCount", len(chunks))) + + db, err := e.client.conn.GetDatabase(e.client.dbName) + if err != nil { + return nil, fmt.Errorf("Failed to get database: %w", err) + } + + table, err := db.GetTable(tableName) + if err != nil { + // Table doesn't exist, try to create it + errMsg := strings.ToLower(err.Error()) + if !strings.Contains(errMsg, "not found") && !strings.Contains(errMsg, "doesn't exist") { + return nil, fmt.Errorf("Failed to get table %s: %w", tableName, err) + } + + // Infer vector size from chunks + vectorSize := 0 + vectorPattern := regexp.MustCompile(`q_(\d+)_vec`) + for _, chunk := range chunks { + for key := range chunk { + matches := vectorPattern.FindStringSubmatch(key) + if len(matches) >= 2 { + vectorSize, _ = strconv.Atoi(matches[1]) + break + } + } + if vectorSize > 0 { + break + } + } + if vectorSize == 0 { + return nil, fmt.Errorf("cannot infer vector size from chunks") + } + + // Determine parser_id from chunk structure + parserID := "" + if chunkData, ok := chunks[0]["chunk_data"].(map[string]interface{}); ok && chunkData != nil { + parserID = "table" + } + + // Create table + if err := e.CreateChunkStore(ctx, baseName, datasetID, vectorSize, parserID); err != nil { + return nil, fmt.Errorf("Failed to create table: %w", err) + } + + table, err = db.GetTable(tableName) + if err != nil { + return nil, fmt.Errorf("Failed to get table after creation: %w", err) + } + } + + // Get embedding columns and their sizes + var embeddingCols [][2]interface{} + colsResp, err := table.ShowColumns() + if err != nil { + return nil, fmt.Errorf("Failed to get columns: %w", err) + } + result, ok := colsResp.(*infinity.QueryResult) + if !ok { + return nil, fmt.Errorf("unexpected response type: %T", colsResp) + } + + // ShowColumns returns a result set where Data contains arrays of column values + re := regexp.MustCompile(`Embedding\([a-z]+,(\d+)\)`) + if nameArr, ok := result.Data["name"]; ok { + if typeArr, ok := result.Data["type"]; ok { + for i := 0; i < len(nameArr); i++ { + colName, _ := nameArr[i].(string) + colType, _ := typeArr[i].(string) + matches := re.FindStringSubmatch(colType) + if len(matches) >= 2 { + size, _ := strconv.Atoi(matches[1]) + embeddingCols = append(embeddingCols, [2]interface{}{colName, size}) + } + } + } + } + + // Transform chunks using helper function + insertChunks := make([]map[string]interface{}, len(chunks)) + for i, chunk := range chunks { + insertChunks[i] = transformChunkFields(chunk, embeddingCols) + } + + // Delete existing rows with matching IDs + if len(insertChunks) > 0 { + idList := make([]string, len(insertChunks)) + for i, chunk := range insertChunks { + idList[i] = fmt.Sprintf("'%v'", chunk["id"]) + } + filter := fmt.Sprintf("id IN (%s)", strings.Join(idList, ", ")) + common.Debug(fmt.Sprintf("Deleting existing rows with filter: %s", filter)) + delResp, delErr := table.Delete(filter) + if delErr != nil { + common.Warn(fmt.Sprintf("Failed to delete existing rows: %v", delErr)) + } else { + common.Info(fmt.Sprintf("Deleted %d existing rows", delResp.DeletedRows)) + } + } + + // Insert chunks to dataset + _, err = table.Insert(insertChunks) + if err != nil { + return nil, fmt.Errorf("Failed to insert chunks to dataset: %w", err) + } + + common.Info("InfinityConnection.InsertChunks result", zap.String("tableName", tableName), zap.Int("count", len(insertChunks))) + return []string{}, nil +} + +// UpdateChunks updates chunks in a dataset table +// Table name format: {baseName}_{datasetID} +func (e *infinityEngine) UpdateChunks(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, baseName string, datasetID string) error { + tableName := buildChunkTableName(baseName, datasetID) + common.Info("InfinityConnection.UpdateChunks called", zap.String("tableName", tableName), zap.Any("condition", condition)) + + db, err := e.client.conn.GetDatabase(e.client.dbName) + if err != nil { + return fmt.Errorf("Failed to get database: %w", err) + } + + table, err := db.GetTable(tableName) + if err != nil { + return fmt.Errorf("Failed to get table %s: %w", tableName, err) + } + + // Get table columns + clmns := make(map[string]struct { + Type string + Default interface{} + }) + colsResp, err := table.ShowColumns() + if err != nil { + return fmt.Errorf("Failed to get columns: %w", err) + } + result, ok := colsResp.(*infinity.QueryResult) + if ok { + if nameArr, ok := result.Data["name"]; ok { + if typeArr, ok := result.Data["type"]; ok { + if defArr, ok := result.Data["default"]; ok { + for i := 0; i < len(nameArr); i++ { + colName, _ := nameArr[i].(string) + colType, _ := typeArr[i].(string) + var colDefault interface{} + if i < len(defArr) { + colDefault = defArr[i] + } + clmns[colName] = struct { + Type string + Default interface{} + }{colType, colDefault} + } + } + } + } + } + + // Build filter string from condition + filter := buildFilterFromCondition(condition, clmns) + + // Process remove operation first + removeValue := make(map[string]interface{}) + if removeData, ok := newValue["remove"].(map[string]interface{}); ok { + removeValue = removeData + } + delete(newValue, "remove") + + // Transform new_value fields using helper function (no embeddings needed for update) + transformed := transformChunkFields(newValue, nil) + for k, v := range transformed { + newValue[k] = v + } + + // Remove original fields that were transformed (they're now in transformed with new names/types) + // Also remove intermediate token fields that shouldn't be stored in Infinity + // This must match Python's delete list in infinity_conn.py + for _, key := range []string{"docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", + "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", + "question_kwd", "question_tks"} { + delete(newValue, key) + } + + // Handle remove operations if any + if len(removeValue) > 0 { + colToRemove := make([]string, 0, len(removeValue)) + for k := range removeValue { + colToRemove = append(colToRemove, k) + } + colToRemove = append(colToRemove, "id") + + // Query rows to be updated + queryResult, err := table.Output(colToRemove).Filter(filter).ToResult() + if err != nil { + common.Warn(fmt.Sprintf("Failed to query rows for remove operation: %v", err)) + } else { + qr, ok := queryResult.(*infinity.QueryResult) + if ok && len(qr.Data) > 0 { + // Get the id column and columns to remove + idCol := qr.Data["id"] + removeOpt := make(map[string]map[string][]string) // column -> value -> [ids] + + for colName, colData := range qr.Data { + if colName == "id" { + continue + } + removeVal := removeValue[colName] + for i, id := range idCol { + if i < len(colData) { + existingVal := colData[i] + if removeStr, ok := removeVal.(string); ok { + // Split existing value by ### and remove the target value + if existingStr, ok := existingVal.(string); ok { + parts := strings.Split(existingStr, "###") + var newParts []string + for _, p := range parts { + if p != removeStr { + newParts = append(newParts, p) + } + } + if len(newParts) != len(parts) { + idStr := fmt.Sprintf("'%s'", escapeFilterValue(fmt.Sprintf("%v", id))) + if removeOpt[colName] == nil { + removeOpt[colName] = make(map[string][]string) + } + removeOpt[colName][strings.Join(newParts, "###")] = append(removeOpt[colName][strings.Join(newParts, "###")], idStr) + } + } + } + } + } + } + + // Execute remove updates + for colName, valueToIDs := range removeOpt { + for newVal, ids := range valueToIDs { + idFilter := filter + " AND id IN (" + strings.Join(ids, ", ") + ")" + common.Info(fmt.Sprintf("INFINITY remove update: table=%s, idFilter=%s, column=%s, newValue=%v", tableName, idFilter, colName, newVal)) + _, err := table.Update(idFilter, map[string]interface{}{colName: newVal}) + if err != nil { + common.Warn(fmt.Sprintf("Failed to remove value from column %s: %v", colName, err)) + } + } + } + } + } + } + + // Execute the main update + common.Info(fmt.Sprintf("INFINITY update: table=%s, filter=%s, newValue=%v", tableName, filter, newValue)) + _, err = table.Update(filter, newValue) + if err != nil { + return fmt.Errorf("Failed to update chunks: %w", err) + } + + common.Info("InfinityConnection.UpdateChunks completes", zap.String("tableName", tableName)) + return nil +} + +// DeleteChunks deletes chunks from a dataset table +// Table name format: {baseName}_{datasetID} +// condition specifies which chunks to delete +func (e *infinityEngine) DeleteChunks(ctx context.Context, condition map[string]interface{}, baseName string, datasetID string) (int64, error) { + tableName := buildChunkTableName(baseName, datasetID) + + db, err := e.client.conn.GetDatabase(e.client.dbName) + if err != nil { + return 0, fmt.Errorf("failed to get database: %w", err) + } + + table, err := db.GetTable(tableName) + if err != nil { + common.Warn(fmt.Sprintf("Table %s does not exist, skipping delete", tableName)) + return 0, nil + } + + // Get table columns for building filter + clmns := make(map[string]struct { + Type string + Default interface{} + }) + colsResp, err := table.ShowColumns() + if err != nil { + return 0, fmt.Errorf("failed to get columns: %w", err) + } + result, ok := colsResp.(*infinity.QueryResult) + if ok { + if nameArr, ok := result.Data["name"]; ok { + if typeArr, ok := result.Data["type"]; ok { + if defArr, ok := result.Data["default"]; ok { + for i := 0; i < len(nameArr); i++ { + colName, _ := nameArr[i].(string) + colType, _ := typeArr[i].(string) + var colDefault interface{} + if i < len(defArr) { + colDefault = defArr[i] + } + clmns[colName] = struct { + Type string + Default interface{} + }{colType, colDefault} + } + } + } + } + } + + // Build filter from condition + filter := buildFilterFromCondition(condition, clmns) + + delResp, err := table.Delete(filter) + if err != nil { + return 0, fmt.Errorf("failed to delete: %w", err) + } + + return delResp.DeletedRows, nil +} + +// Search searches the Infinity engine for matching chunks. +// It supports three matching types: MatchTextExpr (full-text), MatchDenseExpr (vector), and FusionExpr (combined). +// If no match expressions are provided, Search relies solely on filter (e.g., doc_id, available_int) to find results. +func (e *infinityEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { + common.Debug("Search in Infinity started", zap.Any("indexNames", req.IndexNames)) + if common.IsDebugEnabled() { + // Format match expressions for logging + var matchExprsStr string + for i, expr := range req.MatchExprs { + switch e := expr.(type) { + case *types.MatchTextExpr: + matchExprsStr += fmt.Sprintf(" [%d] MatchTextExpr: fields=%v, matchingText=%s, topN=%d, extraOptions=%v\n", i, e.Fields, e.MatchingText, e.TopN, e.ExtraOptions) + case *types.MatchDenseExpr: + matchExprsStr += fmt.Sprintf(" [%d] MatchDenseExpr: vectorColumn=%s, vectorSize=%d, topN=%d, extraOptions=%v\n", i, e.VectorColumnName, len(e.EmbeddingData), e.TopN, e.ExtraOptions) + case *types.FusionExpr: + matchExprsStr += fmt.Sprintf(" [%d] FusionExpr: method=%s, topN=%d, fusionParams=%v\n", i, e.Method, e.TopN, e.FusionParams) + default: + matchExprsStr += fmt.Sprintf(" [%d] unknown type\n", i) + } + } + common.Debug(fmt.Sprintf("Search request:\n"+ + " indexNames=%v\n"+ + " KbIDs=%v\n"+ + " offset=%d, limit=%d\n"+ + " SelectFields=%v\n"+ + " Filter=%v\n"+ + " MatchExprs:\n%s orderBy=%v\n"+ + " RankFeature=%v", + req.IndexNames, req.KbIDs, req.Offset, req.Limit, req.SelectFields, req.Filter, matchExprsStr, req.OrderBy, req.RankFeature)) + } + + if len(req.IndexNames) == 0 { + return nil, fmt.Errorf("index names cannot be empty") + } + + // Get retrieval parameters with defaults + pageSize := req.Limit + if pageSize <= 0 { + pageSize = 30 + } + + offset := req.Offset + if offset < 0 { + offset = 0 + } + + db, err := e.client.conn.GetDatabase(e.client.dbName) + if err != nil { + return nil, fmt.Errorf("failed to get database: %w", err) + } + + isMetadataTable := false + isSkillIndex := false + for _, idx := range req.IndexNames { + if strings.HasPrefix(idx, "ragflow_doc_meta_") { + isMetadataTable = true + break + } + if strings.HasPrefix(idx, "skill_") { + isSkillIndex = true + break + } + } + + var outputColumns []string + if isMetadataTable { + outputColumns = []string{"id", "kb_id", "meta_fields"} + } else if isSkillIndex { + outputColumns = []string{ + "skill_id", "space_id", "folder_id", "name", "tags", "description", "content", + "version", "status", "create_time", "update_time", + } + outputColumns = convertSelectFields(outputColumns, true) + } else { + outputColumns = []string{ + "id", "doc_id", "kb_id", "content_ltks", "content_with_weight", + "title_tks", "docnm_kwd", "img_id", "available_int", "important_kwd", + "position_int", "page_num_int", "top_int", "chunk_order_int", + "create_timestamp_flt", "knowledge_graph_kwd", "question_kwd", "question_tks", + "doc_type_kwd", "mom_id", "tag_kwd", "pagerank_fea", "tag_feas", + } + outputColumns = convertSelectFields(outputColumns) + } + + hasTextMatch := false + hasVectorMatch := false + var matchText *types.MatchTextExpr + var matchDense *types.MatchDenseExpr + if req.MatchExprs != nil && len(req.MatchExprs) > 0 { + for _, expr := range req.MatchExprs { + if expr == nil { + continue + } + switch e := expr.(type) { + case string: + if e != "" { + hasTextMatch = true + matchText = &types.MatchTextExpr{ + MatchingText: e, + TopN: pageSize, + } + } + case *types.MatchTextExpr: + if e.MatchingText != "" { + hasTextMatch = true + matchText = e + } + case *types.MatchDenseExpr: + if len(e.EmbeddingData) > 0 { + hasVectorMatch = true + matchDense = e + } + } + } + } + + if hasTextMatch || hasVectorMatch { + if hasTextMatch { + outputColumns = append(outputColumns, "score()") + } + // similarity() is only allowed by Infinity when there is ONLY MATCH VECTOR. + // When both text and vector matches exist (hybrid search with Fusion), + // only score() is valid — Fusion produces a unified SCORE column. + if hasVectorMatch && !hasTextMatch { + outputColumns = append(outputColumns, "similarity()") + } + // Skill index does not have pagerank_fea and tag_feas columns + if !isSkillIndex { + if !slices.Contains(outputColumns, common.PAGERANK_FLD) { + outputColumns = append(outputColumns, common.PAGERANK_FLD) + } + if !slices.Contains(outputColumns, common.TAG_FLD) { + outputColumns = append(outputColumns, common.TAG_FLD) + } + } + } + + if !slices.Contains(outputColumns, "row_id") && !slices.Contains(outputColumns, "row_id()") { + outputColumns = append(outputColumns, "row_id()") + } + + outputColumns = convertSelectFields(outputColumns, isSkillIndex) + if hasVectorMatch && matchDense != nil && matchDense.VectorColumnName != "" { + outputColumns = append(outputColumns, matchDense.VectorColumnName) + } + + var filterParts []string + if isMetadataTable && len(req.KbIDs) > 0 && req.KbIDs[0] != "" { + kbIDs := req.KbIDs + if len(kbIDs) == 1 { + filterParts = append(filterParts, fmt.Sprintf("kb_id = '%s'", kbIDs[0])) + } else { + kbIDStr := strings.Join(kbIDs, "', '") + filterParts = append(filterParts, fmt.Sprintf("kb_id IN ('%s')", kbIDStr)) + } + } + + if !isMetadataTable && (hasTextMatch || hasVectorMatch) { + if req.Filter != nil { + if availInt, ok := req.Filter["available_int"]; ok { + filterParts = append(filterParts, fmt.Sprintf("available_int=%v", availInt)) + } else if status, ok := req.Filter["status"]; ok { + filterParts = append(filterParts, fmt.Sprintf("status='%s'", status)) + } else { + if isSkillIndex { + filterParts = append(filterParts, "status='1'") + } else { + filterParts = append(filterParts, "available_int=1") + } + } + } else { + if isSkillIndex { + filterParts = append(filterParts, "status='1'") + } else { + filterParts = append(filterParts, "available_int=1") + } + } + } + + // Build filter string from req.Filter + if req.Filter != nil { + filterCopy := req.Filter + if !isMetadataTable { + filterCopy = make(map[string]interface{}) + for k, v := range req.Filter { + if k != "kb_id" { + filterCopy[k] = v + } + } + } + + condStr := equivalentConditionToStr(filterCopy) + if condStr != "" { + filterParts = append(filterParts, condStr) + } + } + filterStr := strings.Join(filterParts, " AND ") + + orderBy := req.OrderBy + var rankFeature map[string]float64 + if req.RankFeature != nil { + rankFeature = req.RankFeature + } + + var fusionExpr *types.FusionExpr + if len(req.MatchExprs) > 2 { + if fe, ok := req.MatchExprs[2].(*types.FusionExpr); ok { + fusionExpr = fe + } + } + + var allResults []map[string]interface{} + totalHits := int64(0) + + for _, indexName := range req.IndexNames { + var tableNames []string + if strings.HasPrefix(indexName, "ragflow_doc_meta_") { + tableNames = []string{indexName} + } else { + kbIDs := req.KbIDs + if len(kbIDs) == 0 { + kbIDs = []string{""} + } + for _, kbID := range kbIDs { + if kbID == "" { + tableNames = append(tableNames, indexName) + } else { + tableNames = append(tableNames, fmt.Sprintf("%s_%s", indexName, kbID)) + } + } + } + + minMatch := 0.3 + + var questionText string + var vectorData []float64 + textTopN := pageSize + var originalQuery string + if matchText != nil { + questionText = matchText.MatchingText + textTopN = int(matchText.TopN) + if matchText.ExtraOptions != nil { + if oq, ok := matchText.ExtraOptions["original_query"].(string); ok { + originalQuery = oq + } + } + } + if matchDense != nil { + vectorData = matchDense.EmbeddingData + } + + for _, tableName := range tableNames { + tbl, err := db.GetTable(tableName) + if err != nil { + continue + } + table := tbl.Output(outputColumns) + + var textFields []string + if matchText != nil && len(matchText.Fields) > 0 { + textFields = matchText.Fields + } else if isSkillIndex { + textFields = []string{ + "name^10", + "tags^5", + "description^3", + "content^1", + } + } else { + textFields = []string{ + "title_tks^10", + "title_sm_tks^5", + "important_kwd^30", + "important_tks^20", + "question_tks^20", + "content_ltks^2", + "content_sm_ltks", + } + } + + // Convert field names for Infinity + var convertedFields []string + for _, f := range textFields { + cf := convertMatchingField(f) + convertedFields = append(convertedFields, cf) + } + fields := strings.Join(convertedFields, ",") + + hasTextMatch := questionText != "" + hasVectorMatch := len(vectorData) > 0 + // Add text match if question is provided + if hasTextMatch { + extraOptions := map[string]string{ + "minimum_should_match": fmt.Sprintf("%d%%", int(minMatch*100)), + } + + if filterStr != "" { + extraOptions["filter"] = filterStr + } + + if rankFeature != nil { + var rankFeaturesList []string + for featureName, weight := range rankFeature { + rankFeaturesList = append(rankFeaturesList, fmt.Sprintf("%s^%s^%.0f", common.TAG_FLD, featureName, weight)) + } + if len(rankFeaturesList) > 0 { + extraOptions["rank_features"] = strings.Join(rankFeaturesList, ",") + } + } + + if originalQuery != "" { + extraOptions["original_query"] = originalQuery + } + + table = table.MatchText(fields, questionText, textTopN, extraOptions) + + common.Debug(fmt.Sprintf( + "MatchTextExpr:\n"+ + " fields=%s\n"+ + " matching_text=%s\n"+ + " topn=%d\n"+ + " extra_options=%v", + fields, questionText, textTopN, extraOptions, + )) + } + + // Add vector match if provided + if hasVectorMatch { + vecFieldName := fmt.Sprintf("q_%d_vec", len(vectorData)) + dataType := "float" + distanceType := "cosine" + + if matchDense != nil { + if matchDense.VectorColumnName != "" { + vecFieldName = matchDense.VectorColumnName + } + if matchDense.EmbeddingDataType != "" { + dataType = matchDense.EmbeddingDataType + } + if matchDense.DistanceType != "" { + distanceType = matchDense.DistanceType + } + } + + vectorTopN := pageSize + if matchDense != nil && matchDense.TopN > 0 { + vectorTopN = int(matchDense.TopN) + } + + denseFilterStr := filterStr + if denseFilterStr == "" { + if isSkillIndex { + denseFilterStr = "status='1'" + } else { + denseFilterStr = "available_int=1" + } + } + + if hasTextMatch && fusionExpr == nil { + fieldsStr := strings.Join(convertedFields, ",") + filterFulltext := fmt.Sprintf("filter_fulltext('%s', '%s')", fieldsStr, questionText) + denseFilterStr = fmt.Sprintf("(%s) AND %s", denseFilterStr, filterFulltext) + } + extraOptions := map[string]string{ + "threshold": utility.FloatToString(0.0), + "filter": denseFilterStr, + } + + common.Debug("MatchDense for hybrid search", + zap.String("fieldName", vecFieldName), + zap.String("distanceType", distanceType), + zap.Int("topN", vectorTopN), + zap.Bool("hasFusion", fusionExpr != nil)) + + table = table.MatchDense(vecFieldName, vectorData, dataType, distanceType, vectorTopN, extraOptions) + } + + // Add fusion (for text + vector combination) + if hasTextMatch && hasVectorMatch && fusionExpr != nil { + fusionMethod := fusionExpr.Method + fusionTopK := fusionExpr.TopN + if fusionTopK == 0 { + fusionTopK = pageSize + } + fusionParams := map[string]interface{}{ + "normalize": "atan", + } + if fusionExpr.FusionParams != nil { + for k, v := range fusionExpr.FusionParams { + fusionParams[k] = v + } + } + + common.Debug("Applying Fusion for hybrid search", + zap.String("method", fusionMethod), + zap.Int("topN", fusionTopK), + zap.Any("params", fusionParams)) + + table = table.Fusion(fusionMethod, fusionTopK, fusionParams) + } + + // Add order_by if provided + if orderBy != nil && len(orderBy.Fields) > 0 { + var sortFields [][2]interface{} + for _, orderField := range orderBy.Fields { + sortType := infinity.SortTypeAsc + if orderField.Type == types.SortDesc { + sortType = infinity.SortTypeDesc + } + sortFields = append(sortFields, [2]interface{}{orderField.Field, sortType}) + } + table = table.Sort(sortFields) + } + + // Add filter when there's no text/vector match (like metadata queries) + if !hasTextMatch && !hasVectorMatch && filterStr != "" { + common.Debug(fmt.Sprintf("Adding filter for no-match query: %s", filterStr)) + table = table.Filter(filterStr) + } + + // Set limit and offset + table = table.Limit(pageSize) + if offset > 0 { + table = table.Offset(offset) + } + + // Request total_hits_count from Infinity + table = table.Option(map[string]interface{}{"total_hits_count": true}) + + // Execute query + df, err := table.ToDataFrame() + if err != nil { + common.Warn("Infinity query failed", + zap.String("tableName", tableName), + zap.Bool("hasTextMatch", hasTextMatch), + zap.Bool("hasVectorMatch", hasVectorMatch), + zap.Bool("hasFusion", fusionExpr != nil), + zap.Error(err)) + continue + } + + // Convert DataFrame to chunks format (column-oriented to row-oriented) + searchChunks := make([]map[string]interface{}, 0) + for colName, colData := range df.ColumnData { + for i, val := range colData { + for len(searchChunks) <= i { + searchChunks = append(searchChunks, make(map[string]interface{})) + } + searchChunks[i][colName] = val + } + } + + // Apply field name mapping and row_id handling + // Skill index uses different schema + // so we skip the document-specific field mappings + if !isSkillIndex { + GetFields(searchChunks, nil) + } else { + // For skill index, only handle ROW_ID -> row_id() mapping + for _, chunk := range searchChunks { + if val, ok := chunk["ROW_ID"]; ok { + chunk["row_id()"] = val + delete(chunk, "ROW_ID") + } + } + } + + // Parse total_hits_count from ExtraInfo + var tableTotal int64 + if df.ExtraInfo != "" { + var extraResult map[string]interface{} + if err := json.Unmarshal([]byte(df.ExtraInfo), &extraResult); err == nil { + if count, ok := extraResult["total_hits_count"].(float64); ok { + tableTotal = int64(count) + } + } + } + + searchResult := &types.SearchResult{ + Chunks: searchChunks, + Total: tableTotal, + } + + allResults = append(allResults, searchResult.Chunks...) + totalHits += searchResult.Total + } + } + + if hasTextMatch || hasVectorMatch { + scoreColumn := "" + if hasTextMatch && hasVectorMatch { + scoreColumn = "SCORE" + } else if hasTextMatch { + scoreColumn = "SCORE" + } else if hasVectorMatch { + scoreColumn = "SIMILARITY" + } + pagerankField := common.PAGERANK_FLD + if isSkillIndex { + pagerankField = "" // Skill index has no pagerank field + } + + allResults = calculateScores(allResults, scoreColumn, pagerankField) + allResults = sortByScore(allResults, len(allResults)) + } + + if len(allResults) > pageSize { + allResults = allResults[:pageSize] + } + + common.Debug("Search in Infinity completed", zap.Int("returnedRows", len(allResults)), zap.Int64("totalHits", totalHits)) + + return &types.SearchResult{ + Chunks: allResults, + Total: totalHits, + }, nil +} + +// GetChunk gets a chunk by ID +func (e *infinityEngine) GetChunk(ctx context.Context, tableName, chunkID string, datasetIDs []string) (interface{}, error) { + if e.client == nil || e.client.conn == nil { + return nil, fmt.Errorf("Infinity client not initialized") + } + + // Build list of table names to search + var tableNames []string + if strings.HasPrefix(tableName, "ragflow_doc_meta_") { + tableNames = []string{tableName} + } else { + // Search in tables like _ for each datasetID + if len(datasetIDs) > 0 { + for _, datasetID := range datasetIDs { + tableNames = append(tableNames, fmt.Sprintf("%s_%s", tableName, datasetID)) + } + } + // Also try the base tableName + tableNames = append(tableNames, tableName) + } + + // Try each table and collect results from all tables + db, err := e.client.conn.GetDatabase(e.client.dbName) + if err != nil { + return nil, fmt.Errorf("failed to get database: %w", err) + } + + // Collect chunks from all tables (same as Python's concat_dataframes) + allChunks := make(map[string]map[string]interface{}) + + for _, tblName := range tableNames { + table, err := db.GetTable(tblName) + if err != nil { + continue + } + + // Query with filter for the specific chunk ID + filter := fmt.Sprintf("id = '%s'", chunkID) + result, err := table.Output([]string{"*"}).Filter(filter).ToResult() + if err != nil { + continue + } + + qr, ok := result.(*infinity.QueryResult) + if !ok { + continue + } + + if len(qr.Data) == 0 { + continue + } + + // Convert to chunk format + chunks := make([]map[string]interface{}, 0) + for colName, colData := range qr.Data { + for i, val := range colData { + for len(chunks) <= i { + chunks = append(chunks, make(map[string]interface{})) + } + chunks[i][colName] = val + } + } + + // Merge chunks into allChunks (by id), keeping first non-empty value + for _, chunk := range chunks { + if idVal, ok := chunk["id"].(string); ok { + if existing, exists := allChunks[idVal]; exists { + // Merge: keep first non-empty value for each field + for k, v := range chunk { + if _, has := existing[k]; !has || (utility.IsEmpty(existing[k]) && !utility.IsEmpty(v)) { + existing[k] = v + } + } + } else { + allChunks[idVal] = chunk + } + } + } + } + + // Get the chunk by chunkID + chunk, found := allChunks[chunkID] + if !found { + return nil, nil + } + + common.Debug("infinity get chunk", zap.String("chunkID", chunkID), zap.Any("tables", tableNames)) + + // Apply field mappings (same as in GetFields) + // docnm -> docnm_kwd, title_tks, title_sm_tks + if val, ok := chunk["docnm"].(string); ok { + chunk["docnm_kwd"] = val + chunk["title_tks"] = val + chunk["title_sm_tks"] = val + } + + // content -> content_with_weight, content_ltks, content_sm_ltks + if val, ok := chunk["content"].(string); ok { + chunk["content_with_weight"] = val + chunk["content_ltks"] = val + chunk["content_sm_ltks"] = val + } + + // important_keywords -> important_kwd (split by comma), important_tks + if val, ok := chunk["important_keywords"].(string); ok { + if val == "" { + chunk["important_kwd"] = []interface{}{} + } else { + parts := strings.Split(val, ",") + chunk["important_kwd"] = parts + } + chunk["important_tks"] = val + } else { + chunk["important_kwd"] = []interface{}{} + chunk["important_tks"] = []interface{}{} + } + + // questions -> question_kwd (split by newline), question_tks + if val, ok := chunk["questions"].(string); ok { + if val == "" { + chunk["question_kwd"] = []interface{}{} + } else { + parts := strings.Split(val, "\n") + chunk["question_kwd"] = parts + } + chunk["question_tks"] = val + } else { + chunk["question_kwd"] = []interface{}{} + chunk["question_tks"] = []interface{}{} + } + + if posVal, ok := chunk["position_int"].(string); ok { + chunk["position_int"] = utility.ConvertHexToPositionIntArray(posVal) + } else { + chunk["position_int"] = []interface{}{} + } + + return chunk, nil +} + +// GetFields applies field mappings to chunks and returns a dict keyed by chunk ID. +// Equivalent to Python's get_fields() in infinity_conn.py. +// When fields is nil/empty, returns all fields from chunks. +func GetFields(chunks []map[string]interface{}, fields []string) map[string]map[string]interface{} { + result := make(map[string]map[string]interface{}) + if len(chunks) == 0 { + return result + } + + // If fields is provided, create a set for lookup + fieldSet := make(map[string]bool) + for _, f := range fields { + fieldSet[f] = true + } + + for _, chunk := range chunks { + // Apply field mappings + // docnm -> docnm_kwd, title_tks, title_sm_tks + if val, ok := chunk["docnm"].(string); ok { + chunk["docnm_kwd"] = val + chunk["title_tks"] = val + chunk["title_sm_tks"] = val + } + + // important_keywords -> important_kwd (split by comma), important_tks + if val, ok := chunk["important_keywords"].(string); ok { + if val == "" { + chunk["important_kwd"] = []interface{}{} + } else { + parts := strings.Split(val, ",") + chunk["important_kwd"] = parts + } + chunk["important_tks"] = val + } else { + chunk["important_kwd"] = []interface{}{} + chunk["important_tks"] = []interface{}{} + } + + // questions -> question_kwd (split by newline), question_tks + if val, ok := chunk["questions"].(string); ok { + if val == "" { + chunk["question_kwd"] = []interface{}{} + } else { + parts := strings.Split(val, "\n") + chunk["question_kwd"] = parts + } + chunk["question_tks"] = val + } else { + chunk["question_kwd"] = []interface{}{} + chunk["question_tks"] = []interface{}{} + } + + // content -> content_with_weight, content_ltks, content_sm_ltks + if val, ok := chunk["content"].(string); ok { + chunk["content_with_weight"] = val + chunk["content_ltks"] = val + chunk["content_sm_ltks"] = val + } + + // authors -> authors_tks, authors_sm_tks + if val, ok := chunk["authors"].(string); ok { + chunk["authors_tks"] = val + chunk["authors_sm_tks"] = val + } + + // position_int: convert from hex string to array format (grouped by 5) + if val, ok := chunk["position_int"].(string); ok { + chunk["position_int"] = utility.ConvertHexToPositionIntArray(val) + } + + // Convert page_num_int and top_int from hex string to array + for _, colName := range []string{"page_num_int", "top_int"} { + if val, ok := chunk[colName].(string); ok && val != "" { + chunk[colName] = utility.ConvertHexToIntArray(val) + } + } + + // Post-process: convert nil/empty values to empty slices for array-like fields + // and split _kwd fields by "###" (except knowledge_graph_kwd, docnm_kwd, important_kwd, question_kwd) + kwdNoSplit := map[string]bool{ + "knowledge_graph_kwd": true, "docnm_kwd": true, + "important_kwd": true, "question_kwd": true, + } + arrayFields := []string{ + "doc_type_kwd", "important_kwd", "important_tks", "question_tks", + "question_kwd", "authors_tks", "authors_sm_tks", "title_tks", + "title_sm_tks", "content_ltks", "content_sm_ltks", "tag_kwd", + } + for _, colName := range arrayFields { + val, ok := chunk[colName] + if !ok || val == nil || val == "" { + chunk[colName] = []interface{}{} + } else if !kwdNoSplit[colName] { + // Split by "###" for _kwd fields + if strVal, ok := val.(string); ok && strings.Contains(strVal, "###") { + parts := strings.Split(strVal, "###") + var filtered []interface{} + for _, p := range parts { + if p != "" { + filtered = append(filtered, p) + } + } + chunk[colName] = filtered + } + } + } + + // Handle row_id mapping - Infinity returns "ROW_ID" but we use "row_id()" + if val, ok := chunk["ROW_ID"]; ok { + chunk["row_id()"] = val + delete(chunk, "ROW_ID") + } + + // Build result map keyed by id + if id, ok := chunk["id"].(string); ok { + fieldMap := make(map[string]interface{}) + for field, value := range chunk { + if len(fieldSet) == 0 || fieldSet[field] { + fieldMap[field] = value + } + } + result[id] = fieldMap + } + } + + return result +} + +// GetFields is a method wrapper for infinityEngine to satisfy DocEngine interface +func (e *infinityEngine) GetFields(chunks []map[string]interface{}, fields []string) map[string]map[string]interface{} { + return GetFields(chunks, fields) +} + +// GetAggregation aggregates chunk values by field name. +// Input: [{"docnm_kwd": "docA"}, {"docnm_kwd": "docA"}, {"docnm_kwd": "docB"}] +// +// GetAggregation(chunks, "docnm_kwd") returns: +// +// [{"key": "docA", "count": 2}, {"key": "docB", "count": 1}] +// +// For tag_kwd field, splits values by "###" separator. +// For other fields, uses comma separation. +func (e *infinityEngine) GetAggregation(chunks []map[string]interface{}, fieldName string) []map[string]interface{} { + if len(chunks) == 0 { + return []map[string]interface{}{} + } + + // Check if field exists in first chunk + hasField := false + for _, chunk := range chunks { + if _, ok := chunk[fieldName]; ok { + hasField = true + break + } + } + if !hasField { + return []map[string]interface{}{} + } + + // Count occurrences + tagCounts := make(map[string]int) + for _, chunk := range chunks { + value, ok := chunk[fieldName] + if !ok || value == nil { + continue + } + + // Handle string value + if valueStr, ok := value.(string); ok { + if valueStr == "" { + continue + } + + var tags []string + // Split by "###" for tag_kwd field + if fieldName == "tag_kwd" && strings.Contains(valueStr, "###") { + for _, tag := range strings.Split(valueStr, "###") { + tag = strings.TrimSpace(tag) + if tag != "" { + tags = append(tags, tag) + } + } + } else { + // Fallback to comma separation + for _, tag := range strings.Split(valueStr, ",") { + tag = strings.TrimSpace(tag) + if tag != "" { + tags = append(tags, tag) + } + } + } + + for _, tag := range tags { + tagCounts[tag]++ + } + continue + } + + // Handle list value + if valueList, ok := value.([]interface{}); ok { + for _, item := range valueList { + if itemStr, ok := item.(string); ok { + tag := strings.TrimSpace(itemStr) + if tag != "" { + tagCounts[tag]++ + } + } + } + } + } + + if len(tagCounts) == 0 { + return []map[string]interface{}{} + } + + // Convert to slice and sort by count descending + type tagCountPair struct { + tag string + count int + } + pairs := make([]tagCountPair, 0, len(tagCounts)) + for tag, count := range tagCounts { + pairs = append(pairs, tagCountPair{tag, count}) + } + sort.Slice(pairs, func(i, j int) bool { + return pairs[i].count > pairs[j].count + }) + + // Convert to []map[string]interface{} directly + result := make([]map[string]interface{}, len(pairs)) + for i, p := range pairs { + result[i] = map[string]interface{}{"key": p.tag, "count": p.count} + } + + return result +} + +// GetDocIDs extracts document IDs from search results. +// Extracts "id" field from each chunk and returns as a list. +func (e *infinityEngine) GetDocIDs(chunks []map[string]interface{}) []string { + if len(chunks) == 0 { + return nil + } + ids := make([]string, 0, len(chunks)) + for _, chunk := range chunks { + if id, ok := chunk["id"].(string); ok { + ids = append(ids, id) + } + } + return ids +} + +// GetHighlight generates highlighted text snippets for search results. +// Matches keywords in text and wraps them with tags. +func (e *infinityEngine) GetHighlight(chunks []map[string]interface{}, keywords []string, fieldName string) map[string]string { + result := make(map[string]string) + if len(chunks) == 0 || len(keywords) == 0 { + return result + } + + // Check if field exists + hasField := false + for _, chunk := range chunks { + if _, ok := chunk[fieldName]; ok { + hasField = true + break + } + } + if !hasField { + // Try alternative field names + if fieldName == "content_with_weight" { + if _, ok := chunks[0]["content"]; ok { + fieldName = "content" + hasField = true + } + } + } + if !hasField { + return result + } + + emTag := regexp.MustCompile(`[^<>]+`) + + for _, chunk := range chunks { + id := "" + if idVal, ok := chunk["id"].(string); ok { + id = idVal + } + + txt, ok := chunk[fieldName].(string) + if !ok || txt == "" { + continue + } + + // Check if already highlighted + if emTag.MatchString(txt) { + result[id] = txt + continue + } + + // Replace newlines with spaces + txt = regexp.MustCompile(`[\r\n]`).ReplaceAllString(txt, " ") + + // Split by sentence delimiters + delimiters := regexp.MustCompile(`[.?!;\n]`) + segments := delimiters.Split(txt, -1) + + var highlightedSegments []string + for _, segment := range segments { + // Check if segment is English or contains keywords + englishCount := 0 + totalCount := 0 + for _, r := range segment { + if unicode.IsLetter(r) { + totalCount++ + if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { + englishCount++ + } + } + } + isEnglish := totalCount > 0 && float64(englishCount)/float64(totalCount) > 0.5 + segmentToCheck := segment + if isEnglish { + // For English: match whole words with boundaries + for _, kw := range keywords { + re := regexp.MustCompile(`(^|[ .?/'\"\(\)!,:;-])` + regexp.QuoteMeta(kw) + `([ .?/'\"\(\)!,:;-]|$)`) + segmentToCheck = re.ReplaceAllString(segmentToCheck, "$1"+kw+"$2") + } + } else { + // For non-English: simple substring match + for _, kw := range keywords { + segmentToCheck = strings.ReplaceAll(segmentToCheck, kw, ""+kw+"") + } + } + if strings.Contains(segmentToCheck, "") { + highlightedSegments = append(highlightedSegments, segmentToCheck) + } + } + + if len(highlightedSegments) > 0 { + result[id] = strings.Join(highlightedSegments, "...") + } + } + + return result +} + +// convertSelectFields converts field names to Infinity format +// isSkillIndex indicates if this is a skill index (uses skill_id instead of id) +func convertSelectFields(output []string, isSkillIndex ...bool) []string { + fieldMapping := map[string]string{ + "docnm_kwd": "docnm", + "title_tks": "docnm", + "title_sm_tks": "docnm", + "important_kwd": "important_keywords", + "important_tks": "important_keywords", + "question_kwd": "questions", + "question_tks": "questions", + "content_with_weight": "content", + "content_ltks": "content", + "content_sm_ltks": "content", + "authors_tks": "authors", + "authors_sm_tks": "authors", + } + + skillIndex := false + if len(isSkillIndex) > 0 { + skillIndex = isSkillIndex[0] + } + + needEmptyCount := false + for i, field := range output { + if field == "important_kwd" { + needEmptyCount = true + } + if newField, ok := fieldMapping[field]; ok { + output[i] = newField + } + } + + // Remove duplicates + seen := make(map[string]bool) + result := []string{} + for _, f := range output { + if f != "" && !seen[f] { + seen[f] = true + result = append(result, f) + } + } + + // Add id and empty count if needed + // For skill index, use skill_id instead of id + hasID := false + idField := "id" + if skillIndex { + idField = "skill_id" + } + for _, f := range result { + if f == idField { + hasID = true + break + } + } + if !hasID { + result = append([]string{idField}, result...) + } + + if needEmptyCount { + result = append(result, "important_kwd_empty_count") + } + + return result +} + +// convertMatchingField converts field names for matching +// For regular document indices: maps _tks/_kwd fields to column@index_name format +// For skill indices: maps raw field names to column@index_name format +// Infinity requires column@index_name when a column has multiple full-text indexes +func convertMatchingField(fieldWeightStr string) string { + // Split on ^ to get field name + parts := strings.Split(fieldWeightStr, "^") + field := parts[0] + + // Field name conversion + fieldMapping := map[string]string{ + "docnm_kwd": "docnm@ft_docnm_rag_coarse", + "title_tks": "docnm@ft_docnm_rag_coarse", + "title_sm_tks": "docnm@ft_docnm_rag_fine", + "important_kwd": "important_keywords@ft_important_keywords_rag_coarse", + "important_tks": "important_keywords@ft_important_keywords_rag_fine", + "question_kwd": "questions@ft_questions_rag_coarse", + "question_tks": "questions@ft_questions_rag_fine", + "content_with_weight": "content@ft_content_rag_coarse", + "content_ltks": "content@ft_content_rag_coarse", + "content_sm_ltks": "content@ft_content_rag_fine", + "authors_tks": "authors@ft_authors_rag_coarse", + "authors_sm_tks": "authors@ft_authors_rag_fine", + "tag_kwd": "tag_kwd@ft_tag_kwd_whitespace__", + // Skill index fields + "name": "name@ft_name_rag_coarse", + "tags": "tags@ft_tags_rag_coarse", + "description": "description@ft_description_rag_coarse", + "content": "content@ft_content_rag_coarse", + } + + if newField, ok := fieldMapping[field]; ok { + parts[0] = newField + } + + return strings.Join(parts, "^") +} + +// escapeFilterValue escapes single quotes for filter values +func escapeFilterValue(s string) string { + return strings.ReplaceAll(s, "'", "''") +} + +// equivalentConditionToStr converts a condition map to an Infinity filter string +func equivalentConditionToStr(condition map[string]interface{}) string { + if len(condition) == 0 { + return "" + } + + var cond []string + + for k, v := range condition { + if k == "_id" || utility.IsEmpty(v) { + continue + } + + // Handle must_not specially + if k == "must_not" { + if m, ok := v.(map[string]interface{}); ok { + for kk, vv := range m { + if kk == "exists" { + // For must_not exists, use !='' since we don't have table schema + cond = append(cond, fmt.Sprintf("NOT (%v!='')", vv)) + } + } + } + continue + } + + // Handle exists specially (without table schema, use string comparison) + if k == "exists" { + cond = append(cond, fmt.Sprintf("%v!=''", v)) + continue + } + + // Handle keyword fields (using full-text filter) + if fieldKeyword(k) { + // For keyword fields, values are always treated as strings for filter_fulltext + switch val := v.(type) { + case []string: + var inCond []string + for _, item := range val { + inCond = append(inCond, fmt.Sprintf("filter_fulltext('%s', '%s')", + convertMatchingField(k), escapeFilterValue(item))) + } + if len(inCond) > 0 { + cond = append(cond, "("+strings.Join(inCond, " or ")+")") + } + case []interface{}: + var inCond []string + for _, item := range val { + if s, ok := item.(string); ok { + inCond = append(inCond, fmt.Sprintf("filter_fulltext('%s', '%s')", + convertMatchingField(k), escapeFilterValue(s))) + } else { + inCond = append(inCond, fmt.Sprintf("filter_fulltext('%s', '%s')", + convertMatchingField(k), escapeFilterValue(fmt.Sprintf("%v", item)))) + } + } + if len(inCond) > 0 { + cond = append(cond, "("+strings.Join(inCond, " or ")+")") + } + case string: + cond = append(cond, fmt.Sprintf("filter_fulltext('%s', '%s')", + convertMatchingField(k), escapeFilterValue(val))) + default: + cond = append(cond, fmt.Sprintf("filter_fulltext('%s', '%s')", + convertMatchingField(k), escapeFilterValue(fmt.Sprintf("%v", v)))) + } + continue + } + + // Handle list values (mixed types - strings get quotes, numbers don't) + if list, ok := v.([]interface{}); ok && len(list) > 0 { + var strItems, numItems []string + for _, item := range list { + if s, ok := item.(string); ok { + strItems = append(strItems, fmt.Sprintf("'%s'", escapeFilterValue(s))) + } else if n, ok := item.(int); ok { + numItems = append(numItems, strconv.Itoa(n)) + } else if n, ok := item.(int64); ok { + numItems = append(numItems, strconv.FormatInt(n, 10)) + } else if f, ok := item.(float64); ok { + numItems = append(numItems, strconv.FormatFloat(f, 'f', -1, 64)) + } else if s, ok := item.(fmt.Stringer); ok { + strItems = append(strItems, fmt.Sprintf("'%s'", escapeFilterValue(s.String()))) + } else { + strItems = append(strItems, fmt.Sprintf("'%s'", escapeFilterValue(fmt.Sprintf("%v", item)))) + } + } + if len(strItems) > 0 { + if len(strItems) == 1 { + cond = append(cond, fmt.Sprintf("%s=%s", k, strItems[0])) + } else { + cond = append(cond, fmt.Sprintf("%s IN (%s)", k, strings.Join(strItems, ", "))) + } + } + if len(numItems) > 0 { + if len(numItems) == 1 { + cond = append(cond, fmt.Sprintf("%s=%s", k, numItems[0])) + } else { + cond = append(cond, fmt.Sprintf("%s IN (%s)", k, strings.Join(numItems, ", "))) + } + } + continue + } + + if list, ok := v.([]string); ok && len(list) > 0 { + if len(list) == 1 { + cond = append(cond, fmt.Sprintf("%s='%s'", k, escapeFilterValue(list[0]))) + } else { + var items []string + for _, item := range list { + items = append(items, fmt.Sprintf("'%s'", escapeFilterValue(item))) + } + cond = append(cond, fmt.Sprintf("%s IN (%s)", k, strings.Join(items, ", "))) + } + continue + } + + if list, ok := v.([]int); ok && len(list) > 0 { + if len(list) == 1 { + cond = append(cond, fmt.Sprintf("%s=%d", k, list[0])) + } else { + var strs []string + for _, n := range list { + strs = append(strs, strconv.Itoa(n)) + } + cond = append(cond, fmt.Sprintf("%s IN (%s)", k, strings.Join(strs, ", "))) + } + continue + } + + // Handle numeric values (no quotes) + if utility.IsNumericValue(v) { + cond = append(cond, fmt.Sprintf("%s=%v", k, v)) + continue + } + + // Handle string values (with quotes and escaping) + if str, ok := v.(string); ok { + cond = append(cond, fmt.Sprintf("%s='%s'", k, escapeFilterValue(str))) + continue + } + + // Fallback: treat as string + cond = append(cond, fmt.Sprintf("%s='%s'", k, escapeFilterValue(fmt.Sprintf("%v", v)))) + } + + if len(cond) == 0 { + return "" + } + return strings.Join(cond, " AND ") +} + +// calculateScores calculates _score = score_column + pagerank +func calculateScores(chunks []map[string]interface{}, scoreColumn, pagerankField string) []map[string]interface{} { + for i := range chunks { + score := 0.0 + if scoreVal, ok := chunks[i][scoreColumn]; ok { + if f, ok := utility.ToFloat64(scoreVal); ok { + score += f + } + } + if pagerankField != "" { + if prVal, ok := chunks[i][pagerankField]; ok { + if f, ok := utility.ToFloat64(prVal); ok { + score += f + } + } + } + chunks[i]["_score"] = score + } + return chunks +} + +// sortByScore sorts by _score descending and limits +func sortByScore(chunks []map[string]interface{}, limit int) []map[string]interface{} { + if len(chunks) == 0 { + return chunks + } + + // Sort by _score descending + sort.Slice(chunks, func(i, j int) bool { + scoreI := getChunkScore(chunks[i]) + scoreJ := getChunkScore(chunks[j]) + return scoreI > scoreJ + }) + + // Limit + if len(chunks) > limit && limit > 0 { + chunks = chunks[:limit] + } + + return chunks +} + +// getChunkScore extracts the score from a chunk +func getChunkScore(chunk map[string]interface{}) float64 { + if v, ok := chunk["_score"].(float64); ok { + return v + } + if v, ok := chunk["SCORE"].(float64); ok { + return v + } + if v, ok := chunk["SIMILARITY"].(float64); ok { + return v + } + return 0.0 +} + +// transformChunkFields converts chunk field names to Infinity format. +// Converts internal field names (like docnm_kwd) to Infinity column names (docnm). +// Also handles: +// - kb_id: extracts first element if it's a list +// - position_int, page_num_int, top_int: converts arrays to hex strings +// - tag_kwd: joins with ### separator +// - question_kwd: joins with newline separator +// - chunk_data: dict -> JSON string +// - Missing embeddings filled with zeros if embeddingCols provided +func transformChunkFields(chunk map[string]interface{}, embeddingCols [][2]interface{}) map[string]interface{} { + d := make(map[string]interface{}) + + for k, v := range chunk { + switch k { + case "docnm_kwd": + d["docnm"] = v + case "title_kwd": + if _, exists := chunk["docnm_kwd"]; !exists { + d["docnm"] = utility.ConvertToString(v) + } + case "title_sm_tks": + if _, exists := chunk["docnm_kwd"]; !exists { + d["docnm"] = utility.ConvertToString(v) + } + case "important_kwd": + if list, ok := v.([]interface{}); ok { + emptyCount := 0 + tokens := make([]string, 0) + for _, item := range list { + if str, ok := item.(string); ok { + if str == "" { + emptyCount++ + } else { + tokens = append(tokens, str) + } + } + } + d["important_keywords"] = strings.Join(tokens, ",") + d["important_kwd_empty_count"] = emptyCount + } else { + d["important_keywords"] = utility.ConvertToString(v) + } + case "important_tks": + if _, exists := chunk["important_kwd"]; !exists { + d["important_keywords"] = v + } + case "content_with_weight": + d["content"] = v + case "content_ltks": + if _, exists := chunk["content_with_weight"]; !exists { + d["content"] = v + } + case "content_sm_ltks": + if _, exists := chunk["content_with_weight"]; !exists { + d["content"] = v + } + case "authors_tks": + d["authors"] = v + case "authors_sm_tks": + if _, exists := chunk["authors_tks"]; !exists { + d["authors"] = v + } + case "question_kwd": + d["questions"] = strings.Join(utility.ConvertToStringSlice(v), "\n") + case "tag_kwd": + d["tag_kwd"] = strings.Join(utility.ConvertToStringSlice(v), "###") + case "question_tks": + if _, exists := chunk["question_kwd"]; !exists { + d["questions"] = utility.ConvertToString(v) + } + case "kb_id": + if list, ok := v.([]interface{}); ok && len(list) > 0 { + d["kb_id"] = list[0] + } else { + d["kb_id"] = v + } + case "position_int": + if list, ok := v.([]interface{}); ok { + d["position_int"] = utility.ConvertPositionIntArrayToHex(list) + } else { + d["position_int"] = v + } + case "page_num_int", "top_int": + if list, ok := v.([]interface{}); ok { + d[k] = utility.ConvertIntArrayToHex(list) + } else { + d[k] = v + } + case "chunk_data": + d["chunk_data"] = utility.ConvertMapToJSONString(v) + default: + // Check for *_feas fields + if strings.HasSuffix(k, "_feas") { + jsonBytes, _ := json.Marshal(v) + d[k] = string(jsonBytes) + } else if fieldKeyword(k) { + // keyword fields with list values -> ### joined + if list, ok := v.([]interface{}); ok { + d[k] = strings.Join(utility.ConvertToStringSlice(list), "###") + } else { + d[k] = v + } + } else { + d[k] = v + } + } + } + + // Remove intermediate token fields + for _, key := range []string{"docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", + "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", + "question_kwd", "question_tks"} { + delete(d, key) + } + + // Fill missing embedding columns with zeros if embedding info provided + for _, ec := range embeddingCols { + name, ok1 := ec[0].(string) + size, ok2 := ec[1].(int) + if !ok1 || !ok2 { + continue + } + if _, exists := d[name]; !exists { + zeros := make([]float64, size) + for i := range zeros { + zeros[i] = 0 + } + d[name] = zeros + } + } + + return d +} + +// DropChunkStore drops a chunk table from Infinity +func (e *infinityEngine) DropChunkStore(ctx context.Context, baseName, datasetID string) error { + return e.dropTable(ctx, buildChunkTableName(baseName, datasetID)) +} + +// ChunkStoreExists checks if a chunk table exists in Infinity +func (e *infinityEngine) ChunkStoreExists(ctx context.Context, baseName, datasetID string) (bool, error) { + return e.tableExists(ctx, buildChunkTableName(baseName, datasetID)) +} \ No newline at end of file diff --git a/internal/engine/infinity/common.go b/internal/engine/infinity/common.go index 47bd09a0b7f..ec4b4e4bbcd 100644 --- a/internal/engine/infinity/common.go +++ b/internal/engine/infinity/common.go @@ -25,101 +25,58 @@ import ( "strings" infinity "github.com/infiniflow/infinity-go-sdk" + + "go.uber.org/zap" ) -// Delete deletes rows from either a dataset table or metadata table. -// If indexName starts with "ragflow_doc_meta_", it's a metadata table. -// Otherwise, it's a dataset table: {indexName}_{datasetID} -func (e *infinityEngine) Delete(ctx context.Context, condition map[string]interface{}, indexName string, datasetID string) (int64, error) { - var tableName string - if strings.HasPrefix(indexName, "ragflow_doc_meta_") { - tableName = indexName - } else { - tableName = fmt.Sprintf("%s_%s", indexName, datasetID) +// dropTable drops a table from Infinity +func (e *infinityEngine) dropTable(ctx context.Context, tableName string) error { + if tableName == "" { + return fmt.Errorf("table name cannot be empty") } - db, err := e.client.conn.GetDatabase(e.client.dbName) + // Check if table exists + exists, err := e.tableExists(ctx, tableName) if err != nil { - return 0, fmt.Errorf("failed to get database: %w", err) + return fmt.Errorf("failed to check table existence: %w", err) } - - table, err := db.GetTable(tableName) - if err != nil { - common.Warn(fmt.Sprintf("Table %s does not exist, skipping delete", tableName)) - return 0, nil + if !exists { + return fmt.Errorf("table '%s' does not exist", tableName) } - // Get table columns for building filter - clmns := make(map[string]struct { - Type string - Default interface{} - }) - colsResp, err := table.ShowColumns() + db, err := e.client.conn.GetDatabase(e.client.dbName) if err != nil { - return 0, fmt.Errorf("failed to get columns: %w", err) - } - result, ok := colsResp.(*infinity.QueryResult) - if ok { - if nameArr, ok := result.Data["name"]; ok { - if typeArr, ok := result.Data["type"]; ok { - if defArr, ok := result.Data["default"]; ok { - for i := 0; i < len(nameArr); i++ { - colName, _ := nameArr[i].(string) - colType, _ := typeArr[i].(string) - var colDefault interface{} - if i < len(defArr) { - colDefault = defArr[i] - } - clmns[colName] = struct { - Type string - Default interface{} - }{colType, colDefault} - } - } - } - } + return fmt.Errorf("failed to get database: %w", err) } - // Build filter from condition - filter := buildFilterFromCondition(condition, clmns) - - delResp, err := table.Delete(filter) + _, err = db.DropTable(tableName, infinity.ConflictTypeError) if err != nil { - return 0, fmt.Errorf("failed to delete: %w", err) + return fmt.Errorf("failed to drop table: %w", err) } - return delResp.DeletedRows, nil + common.Info("Infinity dropped table", zap.String("tableName", tableName)) + return nil } -// DropTable deletes a table/index -func (e *infinityEngine) DropTable(ctx context.Context, indexName string) error { - db, err := e.client.conn.GetDatabase(e.client.dbName) - if err != nil { - return fmt.Errorf("Failed to get database: %w", err) +// tableExists checks if a table exists in Infinity +func (e *infinityEngine) tableExists(ctx context.Context, tableName string) (bool, error) { + if tableName == "" { + return false, fmt.Errorf("table name cannot be empty") } - _, err = db.DropTable(indexName, infinity.ConflictTypeIgnore) - if err != nil { - return fmt.Errorf("Failed to drop table: %w", err) - } - return nil -} - -// TableExists checks if table/index exists -func (e *infinityEngine) TableExists(ctx context.Context, indexName string) (bool, error) { db, err := e.client.conn.GetDatabase(e.client.dbName) if err != nil { - return false, fmt.Errorf("Failed to get database: %w", err) + return false, fmt.Errorf("failed to get database: %w", err) } - _, err = db.GetTable(indexName) + // Try to get the table - if it exists, no error + _, err = db.GetTable(tableName) if err != nil { - // Check if error is "table not found" - errLower := strings.ToLower(err.Error()) - if strings.Contains(errLower, "not found") || strings.Contains(errLower, "notexist") || strings.Contains(errLower, "doesn't exist") { + errMsg := strings.ToLower(err.Error()) + if strings.Contains(errMsg, "not found") || strings.Contains(errMsg, "doesn't exist") { return false, nil } - return false, err + return false, fmt.Errorf("failed to check table existence: %w", err) } return true, nil } @@ -335,3 +292,18 @@ func (e *infinityEngine) columnExists(table *infinity.Table, columnName string) } return false, nil } + +// buildChunkTableName returns the chunk table name for a dataset +// Skill Table: table name is just baseName (e.g., "skill_abc123_def456") +// Regular chunk Table: table name is {baseName}_{datasetID} +func buildChunkTableName(baseName, datasetID string) string { + if datasetID == "skill" { + return baseName + } + return fmt.Sprintf("%s_%s", baseName, datasetID) +} + +// buildMetadataTableName returns the metadata table name for a tenant +func buildMetadataTableName(tenantID string) string { + return fmt.Sprintf("ragflow_doc_meta_%s", tenantID) +} diff --git a/internal/engine/infinity/dataset.go b/internal/engine/infinity/dataset.go deleted file mode 100644 index 8fb80ab5724..00000000000 --- a/internal/engine/infinity/dataset.go +++ /dev/null @@ -1,655 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package infinity - -import ( - "context" - "encoding/json" - "fmt" - "os" - "path/filepath" - "ragflow/internal/common" - "regexp" - "strconv" - "strings" - - "ragflow/internal/utility" - - infinity "github.com/infiniflow/infinity-go-sdk" - - "go.uber.org/zap" -) - -// CreateDataset creates a table in Infinity -// indexName is the table name prefix (e.g., "ragflow_") -// The full table name is built as "{indexName}_{datasetID}" -// For skill index (datasetID="skill"), tableName is just indexName and uses skill_infinity_mapping.json -func (e *infinityEngine) CreateDataset(ctx context.Context, indexName, datasetID string, vectorSize int, parserID string) error { - vecSize := vectorSize - - // Determine table name and mapping file based on index type - var tableName string - var mappingFile string - - if datasetID == "skill" { - // Skill index: table name is just indexName (e.g., "skill_abc123_def456") - tableName = indexName - mappingFile = "skill_infinity_mapping.json" - common.Info("Creating skill index table", zap.String("tableName", tableName), zap.String("mappingFile", mappingFile)) - } else { - // Regular document index: table name is {indexName}_{datasetID} - tableName = fmt.Sprintf("%s_%s", indexName, datasetID) - mappingFile = e.mappingFileName - common.Info("Creating regular index table", zap.String("tableName", tableName), zap.String("mappingFile", mappingFile)) - } - - // Use configured schema - fpMapping := filepath.Join(utility.GetProjectRoot(), "conf", mappingFile) - - schemaData, err := os.ReadFile(fpMapping) - if err != nil { - return fmt.Errorf("Failed to read mapping file: %w", err) - } - - var schema orderedFields - if err := json.Unmarshal(schemaData, &schema); err != nil { - return fmt.Errorf("Failed to parse mapping file: %w", err) - } - - // Get database - db, err := e.client.conn.GetDatabase(e.client.dbName) - if err != nil { - return fmt.Errorf("Failed to get database: %w", err) - } - - // Determine vector column name - vectorColName := fmt.Sprintf("q_%d_vec", vecSize) - - // Check if table already exists - exists, err := e.TableExists(ctx, tableName) - if err != nil { - return fmt.Errorf("Failed to check if table exists: %w", err) - } - - var table *infinity.Table - if exists { - // Table exists, open it and check if vector column needs to be added - common.Info("Table already exists, checking for vector column", zap.String("tableName", tableName)) - table, err = db.GetTable(tableName) - if err != nil { - return fmt.Errorf("Failed to open existing table %s: %w", tableName, err) - } - - // Check if vector column exists (for embedding model changes) - colExists, err := e.columnExists(table, vectorColName) - if err != nil { - common.Warn("Failed to check column existence", zap.String("column", vectorColName), zap.Error(err)) - } - - // Add new vector column if it doesn't exist (handles embedding model change) - if !colExists { - common.Info("Adding new vector column for embedding model change", zap.String("column", vectorColName), zap.Int("size", vecSize)) - addColSchema := infinity.TableSchema{ - &infinity.ColumnDefinition{ - Name: vectorColName, - DataType: fmt.Sprintf("vector,%d,float", vecSize), - }, - } - if _, err := table.AddColumns(addColSchema); err != nil { - common.Error("Failed to add vector column "+vectorColName, err) - return fmt.Errorf("Failed to add vector column %s: %w", vectorColName, err) - } - common.Info("Successfully added vector column", zap.String("column", vectorColName)) - } - } else { - // Table doesn't exist, create it with vector column in the initial schema - common.Info(fmt.Sprintf("Creating table with vector column: %s with dimension %d", vectorColName, vecSize)) - - // Build column definitions (preserving JSON order) - var columns infinity.TableSchema - for _, fieldName := range schema.Keys { - fieldInfo := schema.Fields[fieldName] - col := infinity.ColumnDefinition{ - Name: fieldName, - DataType: fieldInfo.Type, - Default: fieldInfo.Default, - // Comment: fieldInfo.Comment, - } - columns = append(columns, &col) - } - - // Add vector column - columns = append(columns, &infinity.ColumnDefinition{ - Name: vectorColName, - DataType: fmt.Sprintf("vector,%d,float", vecSize), - }) - - // Add chunk_data column for table parser - if parserID == "table" { - columns = append(columns, &infinity.ColumnDefinition{ - Name: "chunk_data", - DataType: "json", - Default: "{}", - }) - } - - // Create table - table, err = db.CreateTable(tableName, columns, infinity.ConflictTypeIgnore) - if err != nil { - return fmt.Errorf("Failed to create table: %w", err) - } - common.Debug("Infinity created table", zap.String("tableName", tableName)) - } - - // Create HNSW index on vector column with unique name based on vector size - // Use unique index name to avoid conflict when embedding model changes - vectorIndexName := fmt.Sprintf("q_%d_vec_idx", vecSize) - _, err = table.CreateIndex( - vectorIndexName, - infinity.NewIndexInfo(vectorColName, infinity.IndexTypeHnsw, map[string]string{ - "M": "16", - "ef_construction": "50", - "metric": "cosine", - "encode": "lvq", - }), - infinity.ConflictTypeIgnore, - "", - ) - if err != nil { - return fmt.Errorf("Failed to create HNSW index %s: %w", vectorIndexName, err) - } - common.Info("Created vector index", zap.String("indexName", vectorIndexName), zap.String("column", vectorColName)) - - // Create full-text indexes for varchar fields with analyzers - for _, fieldName := range schema.Keys { - fieldInfo := schema.Fields[fieldName] - if fieldInfo.Type != "varchar" || fieldInfo.Analyzer == nil { - continue - } - - analyzers := []string{} - switch a := fieldInfo.Analyzer.(type) { - case string: - analyzers = []string{a} - case []interface{}: - for _, v := range a { - if s, ok := v.(string); ok { - analyzers = append(analyzers, s) - } - } - } - - for _, analyzer := range analyzers { - indexNameFt := fmt.Sprintf("ft_%s_%s", - regexp.MustCompile(`[^a-zA-Z0-9]`).ReplaceAllString(fieldName, "_"), - regexp.MustCompile(`[^a-zA-Z0-9]`).ReplaceAllString(analyzer, "_"), - ) - _, err = table.CreateIndex( - indexNameFt, - infinity.NewIndexInfo(fieldName, infinity.IndexTypeFullText, map[string]string{"ANALYZER": analyzer}), - infinity.ConflictTypeIgnore, - "", - ) - if err != nil { - return fmt.Errorf("Failed to create fulltext index %s: %w", indexNameFt, err) - } - } - } - - // Create secondary indexes for fields with index_type - for _, fieldName := range schema.Keys { - fieldInfo := schema.Fields[fieldName] - if fieldInfo.IndexType == nil { - continue - } - - indexTypeStr := "" - params := map[string]string{} - - switch it := fieldInfo.IndexType.(type) { - case string: - indexTypeStr = it - case map[string]interface{}: - if t, ok := it["type"].(string); ok { - indexTypeStr = t - } - if card, ok := it["cardinality"].(string); ok { - params["cardinality"] = card - } - } - - if indexTypeStr == "secondary" { - indexNameSec := fmt.Sprintf("sec_%s", fieldName) - _, err = table.CreateIndex( - indexNameSec, - infinity.NewIndexInfo(fieldName, infinity.IndexTypeSecondary, params), - infinity.ConflictTypeIgnore, - "", - ) - if err != nil { - return fmt.Errorf("Failed to create secondary index %s: %w", indexNameSec, err) - } - } - } - - _ = table // suppress unused variable warning - return nil -} - -// InsertDataset inserts chunks into a dataset table -// Table name format: {tableNamePrefix}_{knowledgebaseID} -// Auto-create the table if it doesn't exist -// Delete existing rows with matching IDs before insert -func (e *infinityEngine) InsertDataset(ctx context.Context, chunks []map[string]interface{}, tableNamePrefix string, knowledgebaseID string) ([]string, error) { - tableName := fmt.Sprintf("%s_%s", tableNamePrefix, knowledgebaseID) - common.Info("InfinityConnection.InsertDataset called", zap.String("tableName", tableName), zap.Int("chunkCount", len(chunks))) - - db, err := e.client.conn.GetDatabase(e.client.dbName) - if err != nil { - return nil, fmt.Errorf("Failed to get database: %w", err) - } - - table, err := db.GetTable(tableName) - if err != nil { - // Table doesn't exist, try to create it - errMsg := strings.ToLower(err.Error()) - if !strings.Contains(errMsg, "not found") && !strings.Contains(errMsg, "doesn't exist") { - return nil, fmt.Errorf("Failed to get table %s: %w", tableName, err) - } - - // Infer vector size from chunks - vectorSize := 0 - vectorPattern := regexp.MustCompile(`q_(\d+)_vec`) - for _, chunk := range chunks { - for key := range chunk { - matches := vectorPattern.FindStringSubmatch(key) - if len(matches) >= 2 { - vectorSize, _ = strconv.Atoi(matches[1]) - break - } - } - if vectorSize > 0 { - break - } - } - if vectorSize == 0 { - return nil, fmt.Errorf("cannot infer vector size from chunks") - } - - // Determine parser_id from chunk structure - parserID := "" - if chunkData, ok := chunks[0]["chunk_data"].(map[string]interface{}); ok && chunkData != nil { - parserID = "table" - } - - // Create table - if err := e.CreateDataset(ctx, tableNamePrefix, knowledgebaseID, vectorSize, parserID); err != nil { - return nil, fmt.Errorf("Failed to create table: %w", err) - } - - table, err = db.GetTable(tableName) - if err != nil { - return nil, fmt.Errorf("Failed to get table after creation: %w", err) - } - } - - // Get embedding columns and their sizes - var embeddingCols [][2]interface{} - colsResp, err := table.ShowColumns() - if err != nil { - return nil, fmt.Errorf("Failed to get columns: %w", err) - } - result, ok := colsResp.(*infinity.QueryResult) - if !ok { - return nil, fmt.Errorf("unexpected response type: %T", colsResp) - } - - // ShowColumns returns a result set where Data contains arrays of column values - re := regexp.MustCompile(`Embedding\([a-z]+,(\d+)\)`) - if nameArr, ok := result.Data["name"]; ok { - if typeArr, ok := result.Data["type"]; ok { - for i := 0; i < len(nameArr); i++ { - colName, _ := nameArr[i].(string) - colType, _ := typeArr[i].(string) - matches := re.FindStringSubmatch(colType) - if len(matches) >= 2 { - size, _ := strconv.Atoi(matches[1]) - embeddingCols = append(embeddingCols, [2]interface{}{colName, size}) - } - } - } - } - - // Transform chunks using helper function - insertChunks := make([]map[string]interface{}, len(chunks)) - for i, chunk := range chunks { - insertChunks[i] = TransformChunkFields(chunk, embeddingCols) - } - - // Delete existing rows with matching IDs - if len(insertChunks) > 0 { - idList := make([]string, len(insertChunks)) - for i, chunk := range insertChunks { - idList[i] = fmt.Sprintf("'%v'", chunk["id"]) - } - filter := fmt.Sprintf("id IN (%s)", strings.Join(idList, ", ")) - common.Debug(fmt.Sprintf("Deleting existing rows with filter: %s", filter)) - delResp, delErr := table.Delete(filter) - if delErr != nil { - common.Warn(fmt.Sprintf("Failed to delete existing rows: %v", delErr)) - } else { - common.Info(fmt.Sprintf("Deleted %d existing rows", delResp.DeletedRows)) - } - } - - // Insert chunks to dataset - _, err = table.Insert(insertChunks) - if err != nil { - return nil, fmt.Errorf("Failed to insert chunks to dataset: %w", err) - } - - common.Info("InfinityConnection.InsertDataset result", zap.String("tableName", tableName), zap.Int("count", len(insertChunks))) - return []string{}, nil -} - -// UpdateDataset updates chunks in a dataset table -// Table name format: {tableNamePrefix}_{knowledgebaseID} -func (e *infinityEngine) UpdateDataset(ctx context.Context, condition map[string]interface{}, newValue map[string]interface{}, tableNamePrefix string, knowledgebaseID string) error { - tableName := fmt.Sprintf("%s_%s", tableNamePrefix, knowledgebaseID) - common.Info("InfinityConnection.UpdateDataset called", zap.String("tableName", tableName), zap.Any("condition", condition)) - - db, err := e.client.conn.GetDatabase(e.client.dbName) - if err != nil { - return fmt.Errorf("Failed to get database: %w", err) - } - - table, err := db.GetTable(tableName) - if err != nil { - return fmt.Errorf("Failed to get table %s: %w", tableName, err) - } - - // Get table columns - clmns := make(map[string]struct { - Type string - Default interface{} - }) - colsResp, err := table.ShowColumns() - if err != nil { - return fmt.Errorf("Failed to get columns: %w", err) - } - result, ok := colsResp.(*infinity.QueryResult) - if ok { - if nameArr, ok := result.Data["name"]; ok { - if typeArr, ok := result.Data["type"]; ok { - if defArr, ok := result.Data["default"]; ok { - for i := 0; i < len(nameArr); i++ { - colName, _ := nameArr[i].(string) - colType, _ := typeArr[i].(string) - var colDefault interface{} - if i < len(defArr) { - colDefault = defArr[i] - } - clmns[colName] = struct { - Type string - Default interface{} - }{colType, colDefault} - } - } - } - } - } - - // Build filter string from condition - filter := buildFilterFromCondition(condition, clmns) - - // Process remove operation first - removeValue := make(map[string]interface{}) - if removeData, ok := newValue["remove"].(map[string]interface{}); ok { - removeValue = removeData - } - delete(newValue, "remove") - - // Transform new_value fields using helper function (no embeddings needed for update) - transformed := TransformChunkFields(newValue, nil) - for k, v := range transformed { - newValue[k] = v - } - - // Remove original fields that were transformed (they're now in transformed with new names/types) - // Also remove intermediate token fields that shouldn't be stored in Infinity - // This must match Python's delete list in infinity_conn.py - for _, key := range []string{"docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", - "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", - "question_kwd", "question_tks"} { - delete(newValue, key) - } - - // Handle remove operations if any - if len(removeValue) > 0 { - colToRemove := make([]string, 0, len(removeValue)) - for k := range removeValue { - colToRemove = append(colToRemove, k) - } - colToRemove = append(colToRemove, "id") - - // Query rows to be updated - queryResult, err := table.Output(colToRemove).Filter(filter).ToResult() - if err != nil { - common.Warn(fmt.Sprintf("Failed to query rows for remove operation: %v", err)) - } else { - qr, ok := queryResult.(*infinity.QueryResult) - if ok && len(qr.Data) > 0 { - // Get the id column and columns to remove - idCol := qr.Data["id"] - removeOpt := make(map[string]map[string][]string) // column -> value -> [ids] - - for colName, colData := range qr.Data { - if colName == "id" { - continue - } - removeVal := removeValue[colName] - for i, id := range idCol { - if i < len(colData) { - existingVal := colData[i] - if removeStr, ok := removeVal.(string); ok { - // Split existing value by ### and remove the target value - if existingStr, ok := existingVal.(string); ok { - parts := strings.Split(existingStr, "###") - var newParts []string - for _, p := range parts { - if p != removeStr { - newParts = append(newParts, p) - } - } - if len(newParts) != len(parts) { - idStr := fmt.Sprintf("%v", id) - if removeOpt[colName] == nil { - removeOpt[colName] = make(map[string][]string) - } - removeOpt[colName][strings.Join(newParts, "###")] = append(removeOpt[colName][strings.Join(newParts, "###")], idStr) - } - } - } - } - } - } - - // Execute remove updates - for colName, valueToIDs := range removeOpt { - for newVal, ids := range valueToIDs { - idFilter := filter + " AND id IN (" + strings.Join(ids, ", ") + ")" - common.Info(fmt.Sprintf("INFINITY remove update: table=%s, idFilter=%s, column=%s, newValue=%v", tableName, idFilter, colName, newVal)) - _, err := table.Update(idFilter, map[string]interface{}{colName: newVal}) - if err != nil { - common.Warn(fmt.Sprintf("Failed to remove value from column %s: %v", colName, err)) - } - } - } - } - } - } - - // Execute the main update - common.Info(fmt.Sprintf("INFINITY update: table=%s, filter=%s, newValue=%v", tableName, filter, newValue)) - _, err = table.Update(filter, newValue) - if err != nil { - return fmt.Errorf("Failed to update chunks: %w", err) - } - - common.Info("InfinityConnection.UpdateDataset completes", zap.String("tableName", tableName)) - return nil -} - -// TransformChunkFields transforms chunk field name for insert/update -// It handles field name conversions and value transformations: -// - docnm_kwd -> docnm -// - title_kwd/title_sm_tks -> docnm (if docnm_kwd not set) -// - important_kwd -> important_keywords (+ important_kwd_empty_count) -// - content_with_weight/content_ltks/content_sm_ltks -> content -// - authors_tks/authors_sm_tks -> authors -// - question_kwd -> questions (joined with \n), question_tks -> questions (if question_kwd not set) -// - kb_id: list -> str (first element) -// - position_int: list -> hex_joined string -// - page_num_int, top_int: list -> hex string -// - *_feas fields -> JSON string -// - keyword fields with list values -> ### joined string -// - chunk_data: dict -> JSON string -// - Missing embeddings filled with zeros if embeddingCols provided -func TransformChunkFields(chunk map[string]interface{}, embeddingCols [][2]interface{}) map[string]interface{} { - d := make(map[string]interface{}) - - for k, v := range chunk { - switch k { - case "docnm_kwd": - d["docnm"] = v - case "title_kwd": - if _, exists := chunk["docnm_kwd"]; !exists { - d["docnm"] = utility.ConvertToString(v) - } - case "title_sm_tks": - if _, exists := chunk["docnm_kwd"]; !exists { - d["docnm"] = utility.ConvertToString(v) - } - case "important_kwd": - if list, ok := v.([]interface{}); ok { - emptyCount := 0 - tokens := make([]string, 0) - for _, item := range list { - if str, ok := item.(string); ok { - if str == "" { - emptyCount++ - } else { - tokens = append(tokens, str) - } - } - } - d["important_keywords"] = strings.Join(tokens, ",") - d["important_kwd_empty_count"] = emptyCount - } else { - d["important_keywords"] = utility.ConvertToString(v) - } - case "important_tks": - if _, exists := chunk["important_kwd"]; !exists { - d["important_keywords"] = v - } - case "content_with_weight": - d["content"] = v - case "content_ltks": - if _, exists := chunk["content_with_weight"]; !exists { - d["content"] = v - } - case "content_sm_ltks": - if _, exists := chunk["content_with_weight"]; !exists { - d["content"] = v - } - case "authors_tks": - d["authors"] = v - case "authors_sm_tks": - if _, exists := chunk["authors_tks"]; !exists { - d["authors"] = v - } - case "question_kwd": - d["questions"] = strings.Join(utility.ConvertToStringSlice(v), "\n") - case "tag_kwd": - d["tag_kwd"] = strings.Join(utility.ConvertToStringSlice(v), "###") - case "question_tks": - if _, exists := chunk["question_kwd"]; !exists { - d["questions"] = utility.ConvertToString(v) - } - case "kb_id": - if list, ok := v.([]interface{}); ok && len(list) > 0 { - d["kb_id"] = list[0] - } else { - d["kb_id"] = v - } - case "position_int": - if list, ok := v.([]interface{}); ok { - d["position_int"] = utility.ConvertPositionIntArrayToHex(list) - } else { - d["position_int"] = v - } - case "page_num_int", "top_int": - if list, ok := v.([]interface{}); ok { - d[k] = utility.ConvertIntArrayToHex(list) - } else { - d[k] = v - } - case "chunk_data": - d["chunk_data"] = utility.ConvertMapToJSONString(v) - default: - // Check for *_feas fields - if strings.HasSuffix(k, "_feas") { - jsonBytes, _ := json.Marshal(v) - d[k] = string(jsonBytes) - } else if fieldKeyword(k) { - // keyword fields with list values -> ### joined - if list, ok := v.([]interface{}); ok { - d[k] = strings.Join(utility.ConvertToStringSlice(list), "###") - } else { - d[k] = v - } - } else { - d[k] = v - } - } - } - - // Remove intermediate token fields - for _, key := range []string{"docnm_kwd", "title_tks", "title_sm_tks", "important_kwd", "important_tks", - "content_with_weight", "content_ltks", "content_sm_ltks", "authors_tks", "authors_sm_tks", - "question_kwd", "question_tks"} { - delete(d, key) - } - - // Fill missing embedding columns with zeros if embedding info provided - for _, ec := range embeddingCols { - name, ok1 := ec[0].(string) - size, ok2 := ec[1].(int) - if !ok1 || !ok2 { - continue - } - if _, exists := d[name]; !exists { - zeros := make([]float64, size) - for i := range zeros { - zeros[i] = 0 - } - d[name] = zeros - } - } - - return d -} diff --git a/internal/engine/infinity/get.go b/internal/engine/infinity/get.go deleted file mode 100644 index 8adbb4adedb..00000000000 --- a/internal/engine/infinity/get.go +++ /dev/null @@ -1,303 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package infinity - -import ( - "context" - "fmt" - "ragflow/internal/common" - "strings" - - "ragflow/internal/utility" - - infinity "github.com/infiniflow/infinity-go-sdk" - - "go.uber.org/zap" -) - -// GetChunk gets a chunk by ID -func (e *infinityEngine) GetChunk(ctx context.Context, tableName, chunkID string, kbIDs []string) (interface{}, error) { - if e.client == nil || e.client.conn == nil { - return nil, fmt.Errorf("Infinity client not initialized") - } - - // Build list of table names to search - var tableNames []string - if strings.HasPrefix(tableName, "ragflow_doc_meta_") { - tableNames = []string{tableName} - } else { - // Search in tables like _ for each kbID - if len(kbIDs) > 0 { - for _, kbID := range kbIDs { - tableNames = append(tableNames, fmt.Sprintf("%s_%s", tableName, kbID)) - } - } - // Also try the base tableName - tableNames = append(tableNames, tableName) - } - - // Try each table and collect results from all tables - db, err := e.client.conn.GetDatabase(e.client.dbName) - if err != nil { - return nil, fmt.Errorf("failed to get database: %w", err) - } - - // Collect chunks from all tables (same as Python's concat_dataframes) - allChunks := make(map[string]map[string]interface{}) - - for _, tblName := range tableNames { - table, err := db.GetTable(tblName) - if err != nil { - continue - } - - // Query with filter for the specific chunk ID - filter := fmt.Sprintf("id = '%s'", chunkID) - result, err := table.Output([]string{"*"}).Filter(filter).ToResult() - if err != nil { - continue - } - - qr, ok := result.(*infinity.QueryResult) - if !ok { - continue - } - - if len(qr.Data) == 0 { - continue - } - - // Convert to chunk format - chunks := make([]map[string]interface{}, 0) - for colName, colData := range qr.Data { - for i, val := range colData { - for len(chunks) <= i { - chunks = append(chunks, make(map[string]interface{})) - } - chunks[i][colName] = val - } - } - - // Merge chunks into allChunks (by id), keeping first non-empty value - for _, chunk := range chunks { - if idVal, ok := chunk["id"].(string); ok { - if existing, exists := allChunks[idVal]; exists { - // Merge: keep first non-empty value for each field - for k, v := range chunk { - if _, has := existing[k]; !has || utility.IsEmpty(v) { - existing[k] = v - } - } - } else { - allChunks[idVal] = chunk - } - } - } - } - - // Get the chunk by chunkID - chunk, found := allChunks[chunkID] - if !found { - return nil, nil - } - - common.Debug("infinity get chunk", zap.String("chunkID", chunkID), zap.Any("tables", tableNames)) - - // Apply field mappings (same as in GetFields) - // docnm -> docnm_kwd, title_tks, title_sm_tks - if val, ok := chunk["docnm"].(string); ok { - chunk["docnm_kwd"] = val - chunk["title_tks"] = val - chunk["title_sm_tks"] = val - } - - // content -> content_with_weight, content_ltks, content_sm_ltks - if val, ok := chunk["content"].(string); ok { - chunk["content_with_weight"] = val - chunk["content_ltks"] = val - chunk["content_sm_ltks"] = val - } - - // important_keywords -> important_kwd (split by comma), important_tks - if val, ok := chunk["important_keywords"].(string); ok { - if val == "" { - chunk["important_kwd"] = []interface{}{} - } else { - parts := strings.Split(val, ",") - chunk["important_kwd"] = parts - } - chunk["important_tks"] = val - } else { - chunk["important_kwd"] = []interface{}{} - chunk["important_tks"] = []interface{}{} - } - - // questions -> question_kwd (split by newline), question_tks - if val, ok := chunk["questions"].(string); ok { - if val == "" { - chunk["question_kwd"] = []interface{}{} - } else { - parts := strings.Split(val, "\n") - chunk["question_kwd"] = parts - } - chunk["question_tks"] = val - } else { - chunk["question_kwd"] = []interface{}{} - chunk["question_tks"] = []interface{}{} - } - - if posVal, ok := chunk["position_int"].(string); ok { - chunk["position_int"] = utility.ConvertHexToPositionIntArray(posVal) - } else { - chunk["position_int"] = []interface{}{} - } - - return chunk, nil -} - -// GetFields applies field mappings to chunks and returns a dict keyed by chunk ID. -// Equivalent to Python's get_fields() in infinity_conn.py. -// When fields is nil/empty, returns all fields from chunks. -func GetFields(chunks []map[string]interface{}, fields []string) map[string]map[string]interface{} { - result := make(map[string]map[string]interface{}) - if len(chunks) == 0 { - return result - } - - // If fields is provided, create a set for lookup - fieldSet := make(map[string]bool) - for _, f := range fields { - fieldSet[f] = true - } - - for _, chunk := range chunks { - // Apply field mappings - // docnm -> docnm_kwd, title_tks, title_sm_tks - if val, ok := chunk["docnm"].(string); ok { - chunk["docnm_kwd"] = val - chunk["title_tks"] = val - chunk["title_sm_tks"] = val - } - - // important_keywords -> important_kwd (split by comma), important_tks - if val, ok := chunk["important_keywords"].(string); ok { - if val == "" { - chunk["important_kwd"] = []interface{}{} - } else { - parts := strings.Split(val, ",") - chunk["important_kwd"] = parts - } - chunk["important_tks"] = val - } else { - chunk["important_kwd"] = []interface{}{} - chunk["important_tks"] = []interface{}{} - } - - // questions -> question_kwd (split by newline), question_tks - if val, ok := chunk["questions"].(string); ok { - if val == "" { - chunk["question_kwd"] = []interface{}{} - } else { - parts := strings.Split(val, "\n") - chunk["question_kwd"] = parts - } - chunk["question_tks"] = val - } else { - chunk["question_kwd"] = []interface{}{} - chunk["question_tks"] = []interface{}{} - } - - // content -> content_with_weight, content_ltks, content_sm_ltks - if val, ok := chunk["content"].(string); ok { - chunk["content_with_weight"] = val - chunk["content_ltks"] = val - chunk["content_sm_ltks"] = val - } - - // authors -> authors_tks, authors_sm_tks - if val, ok := chunk["authors"].(string); ok { - chunk["authors_tks"] = val - chunk["authors_sm_tks"] = val - } - - // position_int: convert from hex string to array format (grouped by 5) - if val, ok := chunk["position_int"].(string); ok { - chunk["position_int"] = utility.ConvertHexToPositionIntArray(val) - } - - // Convert page_num_int and top_int from hex string to array - for _, colName := range []string{"page_num_int", "top_int"} { - if val, ok := chunk[colName].(string); ok && val != "" { - chunk[colName] = utility.ConvertHexToIntArray(val) - } - } - - // Post-process: convert nil/empty values to empty slices for array-like fields - // and split _kwd fields by "###" (except knowledge_graph_kwd, docnm_kwd, important_kwd, question_kwd) - kwdNoSplit := map[string]bool{ - "knowledge_graph_kwd": true, "docnm_kwd": true, - "important_kwd": true, "question_kwd": true, - } - arrayFields := []string{ - "doc_type_kwd", "important_kwd", "important_tks", "question_tks", - "question_kwd", "authors_tks", "authors_sm_tks", "title_tks", - "title_sm_tks", "content_ltks", "content_sm_ltks", "tag_kwd", - } - for _, colName := range arrayFields { - val, ok := chunk[colName] - if !ok || val == nil || val == "" { - chunk[colName] = []interface{}{} - } else if !kwdNoSplit[colName] { - // Split by "###" for _kwd fields - if strVal, ok := val.(string); ok && strings.Contains(strVal, "###") { - parts := strings.Split(strVal, "###") - var filtered []interface{} - for _, p := range parts { - if p != "" { - filtered = append(filtered, p) - } - } - chunk[colName] = filtered - } - } - } - - // Handle row_id mapping - Infinity returns "ROW_ID" but we use "row_id()" - if val, ok := chunk["ROW_ID"]; ok { - chunk["row_id()"] = val - delete(chunk, "ROW_ID") - } - - // Build result map keyed by id - if id, ok := chunk["id"].(string); ok { - fieldMap := make(map[string]interface{}) - for field, value := range chunk { - if len(fieldSet) == 0 || fieldSet[field] { - fieldMap[field] = value - } - } - result[id] = fieldMap - } - } - - return result -} - -// GetFields is a method wrapper for infinityEngine to satisfy DocEngine interface -func (e *infinityEngine) GetFields(chunks []map[string]interface{}, fields []string) map[string]map[string]interface{} { - return GetFields(chunks, fields) -} diff --git a/internal/engine/infinity/metadata.go b/internal/engine/infinity/metadata.go index 31ef64bccbb..c62a64ba0a9 100644 --- a/internal/engine/infinity/metadata.go +++ b/internal/engine/infinity/metadata.go @@ -22,18 +22,20 @@ import ( "fmt" "os" "path/filepath" - "ragflow/internal/common" "strings" - "ragflow/internal/utility" - infinity "github.com/infiniflow/infinity-go-sdk" + "ragflow/internal/common" + "ragflow/internal/utility" "go.uber.org/zap" ) -// CreateMetadata creates the document metadata table/index -func (e *infinityEngine) CreateMetadata(ctx context.Context, indexName string) error { +// CreateMetadataStore creates a metadata table in Infinity +// tenantID is the tenant identifier used to build the table name +func (e *infinityEngine) CreateMetadataStore(ctx context.Context, tenantID string) error { + tableName := buildMetadataTableName(tenantID) + // Get database db, err := e.client.conn.GetDatabase(e.client.dbName) if err != nil { @@ -41,12 +43,12 @@ func (e *infinityEngine) CreateMetadata(ctx context.Context, indexName string) e } // Check if table already exists - exists, err := e.TableExists(ctx, indexName) + exists, err := e.tableExists(ctx, tableName) if err != nil { return fmt.Errorf("Failed to check if table exists: %w", err) } if exists { - return fmt.Errorf("metadata table '%s' already exists", indexName) + return fmt.Errorf("metadata table '%s' already exists", tableName) } // Use configured doc_meta mapping file @@ -69,27 +71,26 @@ func (e *infinityEngine) CreateMetadata(ctx context.Context, indexName string) e Name: fieldName, DataType: fieldInfo.Type, Default: fieldInfo.Default, - // Comment: fieldInfo.Comment, } columns = append(columns, &col) } // Create table - _, err = db.CreateTable(indexName, columns, infinity.ConflictTypeIgnore) + _, err = db.CreateTable(tableName, columns, infinity.ConflictTypeIgnore) if err != nil { return fmt.Errorf("Failed to create doc meta table: %w", err) } - common.Debug("Infinity created doc meta table", zap.String("tableName", indexName)) + common.Debug("Infinity created doc meta table", zap.String("tableName", tableName)) // Get table for creating indexes - table, err := db.GetTable(indexName) + table, err := db.GetTable(tableName) if err != nil { return fmt.Errorf("Failed to get table: %w", err) } // Create secondary index on id _, err = table.CreateIndex( - fmt.Sprintf("idx_%s_id", indexName), + fmt.Sprintf("idx_%s_id", tableName), infinity.NewIndexInfo("id", infinity.IndexTypeSecondary, nil), infinity.ConflictTypeIgnore, "", @@ -100,7 +101,7 @@ func (e *infinityEngine) CreateMetadata(ctx context.Context, indexName string) e // Create secondary index on kb_id _, err = table.CreateIndex( - fmt.Sprintf("idx_%s_kb_id", indexName), + fmt.Sprintf("idx_%s_kb_id", tableName), infinity.NewIndexInfo("kb_id", infinity.IndexTypeSecondary, nil), infinity.ConflictTypeIgnore, "", @@ -113,11 +114,10 @@ func (e *infinityEngine) CreateMetadata(ctx context.Context, indexName string) e } // InsertMetadata inserts document metadata into tenant's metadata table -// Table name format: ragflow_doc_meta_{tenant_id} // Auto-create the table if it doesn't exist // Replace existing metadata with same id and kb_id func (e *infinityEngine) InsertMetadata(ctx context.Context, metadata []map[string]interface{}, tenantID string) ([]string, error) { - tableName := fmt.Sprintf("ragflow_doc_meta_%s", tenantID) + tableName := buildMetadataTableName(tenantID) common.Info("InfinityConnection.InsertMetadata called", zap.String("tableName", tableName), zap.Int("metaCount", len(metadata))) db, err := e.client.conn.GetDatabase(e.client.dbName) @@ -134,7 +134,7 @@ func (e *infinityEngine) InsertMetadata(ctx context.Context, metadata []map[stri } // Create metadata table - if createErr := e.CreateMetadata(ctx, tableName); createErr != nil { + if createErr := e.CreateMetadataStore(ctx, tenantID); createErr != nil { return nil, fmt.Errorf("Failed to create metadata table: %w", createErr) } @@ -188,12 +188,11 @@ func (e *infinityEngine) InsertMetadata(ctx context.Context, metadata []map[stri } // UpdateMetadata updates or inserts document metadata in tenant's metadata table. -// If a row with the given docID and kbID exists, it merges the new metadata with existing. +// If a row with the given docID and datasetID exists, it merges the new metadata with existing. // If no row exists, it inserts a new row. -// Table name format: ragflow_doc_meta_{tenant_id} -func (e *infinityEngine) UpdateMetadata(ctx context.Context, docID string, kbID string, metaFields map[string]interface{}, tenantID string) error { - tableName := fmt.Sprintf("ragflow_doc_meta_%s", tenantID) - common.Info("InfinityConnection.UpdateMetadata called", zap.String("tableName", tableName), zap.String("docID", docID), zap.String("kbID", kbID)) +func (e *infinityEngine) UpdateMetadata(ctx context.Context, docID string, datasetID string, metaFields map[string]interface{}, tenantID string) error { + tableName := buildMetadataTableName(tenantID) + common.Info("InfinityConnection.UpdateMetadata called", zap.String("tableName", tableName), zap.String("docID", docID), zap.String("datasetID", datasetID)) db, err := e.client.conn.GetDatabase(e.client.dbName) if err != nil { @@ -205,10 +204,10 @@ func (e *infinityEngine) UpdateMetadata(ctx context.Context, docID string, kbID return fmt.Errorf("failed to get metadata table %s: %w", tableName, err) } - // Build filter to find existing row by docID and kbID + // Build filter to find existing row by docID and datasetID escapedDocID := strings.ReplaceAll(docID, "'", "''") - escapedKbID := strings.ReplaceAll(kbID, "'", "''") - filter := fmt.Sprintf("id = '%s' AND kb_id = '%s'", escapedDocID, escapedKbID) + escapedDatasetID := strings.ReplaceAll(datasetID, "'", "''") + filter := fmt.Sprintf("id = '%s' AND kb_id = '%s'", escapedDocID, escapedDatasetID) // Query existing metadata using the chainable API queryTable := table.Output([]string{"id", "kb_id", "meta_fields"}).Filter(filter).Limit(1).Offset(0) @@ -271,7 +270,7 @@ func (e *infinityEngine) UpdateMetadata(ctx context.Context, docID string, kbID // Row doesn't exist: insert new row insertFields := map[string]interface{}{ "id": docID, - "kb_id": kbID, + "kb_id": datasetID, "meta_fields": utility.ConvertMapToJSONString(metaFields), } common.Info(fmt.Sprintf("UpdateMetadata: inserting new row, table=%s, newValue=%v", tableName, insertFields)) @@ -284,3 +283,72 @@ func (e *infinityEngine) UpdateMetadata(ctx context.Context, docID string, kbID common.Info("InfinityConnection.UpdateMetadata completes", zap.String("tableName", tableName), zap.String("docID", docID)) return nil } + +// DeleteMetadata deletes metadata from tenant's metadata table by condition +func (e *infinityEngine) DeleteMetadata(ctx context.Context, condition map[string]interface{}, tenantID string) (int64, error) { + tableName := buildMetadataTableName(tenantID) + + db, err := e.client.conn.GetDatabase(e.client.dbName) + if err != nil { + return 0, fmt.Errorf("failed to get database: %w", err) + } + + table, err := db.GetTable(tableName) + if err != nil { + common.Warn(fmt.Sprintf("Metadata table %s does not exist, skipping delete", tableName)) + return 0, nil + } + + // Get table columns for building filter + clmns := make(map[string]struct { + Type string + Default interface{} + }) + colsResp, err := table.ShowColumns() + if err != nil { + return 0, fmt.Errorf("failed to get columns: %w", err) + } + result, ok := colsResp.(*infinity.QueryResult) + if ok { + if nameArr, ok := result.Data["name"]; ok { + if typeArr, ok := result.Data["type"]; ok { + if defArr, ok := result.Data["default"]; ok { + for i := 0; i < len(nameArr); i++ { + colName, _ := nameArr[i].(string) + colType, _ := typeArr[i].(string) + var colDefault interface{} + if i < len(defArr) { + colDefault = defArr[i] + } + clmns[colName] = struct { + Type string + Default interface{} + }{colType, colDefault} + } + } + } + } + } + + // Build filter from condition + filter := buildFilterFromCondition(condition, clmns) + + delResp, err := table.Delete(filter) + if err != nil { + return 0, fmt.Errorf("failed to delete metadata: %w", err) + } + + return delResp.DeletedRows, nil +} + +// DropMetadataStore drops a metadata table from Infinity +func (e *infinityEngine) DropMetadataStore(ctx context.Context, tenantID string) error { + tableName := buildMetadataTableName(tenantID) + return e.dropTable(ctx, tableName) +} + +// MetadataStoreExists checks if a metadata table exists in Infinity +func (e *infinityEngine) MetadataStoreExists(ctx context.Context, tenantID string) (bool, error) { + tableName := buildMetadataTableName(tenantID) + return e.tableExists(ctx, tableName) +} \ No newline at end of file diff --git a/internal/engine/infinity/search.go b/internal/engine/infinity/search.go deleted file mode 100644 index 3656854b31f..00000000000 --- a/internal/engine/infinity/search.go +++ /dev/null @@ -1,1100 +0,0 @@ -// -// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. -// - -package infinity - -import ( - "context" - "encoding/json" - "fmt" - "ragflow/internal/common" - "ragflow/internal/engine/types" - "ragflow/internal/utility" - "regexp" - "slices" - "sort" - "strconv" - "strings" - "unicode" - - infinity "github.com/infiniflow/infinity-go-sdk" - "go.uber.org/zap" -) - -// Search searches the Infinity engine for matching chunks. -// It supports three matching types: MatchTextExpr (full-text), MatchDenseExpr (vector), and FusionExpr (combined). -// If no match expressions are provided, Search relies solely on filter (e.g., doc_id, available_int) to find results. -func (e *infinityEngine) Search(ctx context.Context, req *types.SearchRequest) (*types.SearchResult, error) { - common.Debug("Search in Infinity started", zap.Any("indexNames", req.IndexNames)) - if common.IsDebugEnabled() { - // Format match expressions for logging - var matchExprsStr string - for i, expr := range req.MatchExprs { - switch e := expr.(type) { - case *types.MatchTextExpr: - matchExprsStr += fmt.Sprintf(" [%d] MatchTextExpr: fields=%v, matchingText=%s, topN=%d, extraOptions=%v\n", i, e.Fields, e.MatchingText, e.TopN, e.ExtraOptions) - case *types.MatchDenseExpr: - matchExprsStr += fmt.Sprintf(" [%d] MatchDenseExpr: vectorColumn=%s, vectorSize=%d, topN=%d, extraOptions=%v\n", i, e.VectorColumnName, len(e.EmbeddingData), e.TopN, e.ExtraOptions) - case *types.FusionExpr: - matchExprsStr += fmt.Sprintf(" [%d] FusionExpr: method=%s, topN=%d, fusionParams=%v\n", i, e.Method, e.TopN, e.FusionParams) - default: - matchExprsStr += fmt.Sprintf(" [%d] unknown type\n", i) - } - } - common.Debug(fmt.Sprintf("Search request:\n"+ - " indexNames=%v\n"+ - " KbIDs=%v\n"+ - " offset=%d, limit=%d\n"+ - " SelectFields=%v\n"+ - " Filter=%v\n"+ - " MatchExprs:\n%s orderBy=%v\n"+ - " RankFeature=%v", - req.IndexNames, req.KbIDs, req.Offset, req.Limit, req.SelectFields, req.Filter, matchExprsStr, req.OrderBy, req.RankFeature)) - } - - if len(req.IndexNames) == 0 { - return nil, fmt.Errorf("index names cannot be empty") - } - - // Get retrieval parameters with defaults - pageSize := req.Limit - if pageSize <= 0 { - pageSize = 30 - } - - offset := req.Offset - if offset < 0 { - offset = 0 - } - - db, err := e.client.conn.GetDatabase(e.client.dbName) - if err != nil { - return nil, fmt.Errorf("failed to get database: %w", err) - } - - isMetadataTable := false - isSkillIndex := false - for _, idx := range req.IndexNames { - if strings.HasPrefix(idx, "ragflow_doc_meta_") { - isMetadataTable = true - break - } - if strings.HasPrefix(idx, "skill_") { - isSkillIndex = true - break - } - } - - var outputColumns []string - if isMetadataTable { - outputColumns = []string{"id", "kb_id", "meta_fields"} - } else if isSkillIndex { - outputColumns = []string{ - "skill_id", "space_id", "folder_id", "name", "tags", "description", "content", - "version", "status", "create_time", "update_time", - } - outputColumns = convertSelectFields(outputColumns, true) - } else { - outputColumns = []string{ - "id", "doc_id", "kb_id", "content_ltks", "content_with_weight", - "title_tks", "docnm_kwd", "img_id", "available_int", "important_kwd", - "position_int", "page_num_int", "top_int", "chunk_order_int", - "create_timestamp_flt", "knowledge_graph_kwd", "question_kwd", "question_tks", - "doc_type_kwd", "mom_id", "tag_kwd", "pagerank_fea", "tag_feas", - } - outputColumns = convertSelectFields(outputColumns) - } - - hasTextMatch := false - hasVectorMatch := false - var matchText *types.MatchTextExpr - var matchDense *types.MatchDenseExpr - if req.MatchExprs != nil && len(req.MatchExprs) > 0 { - for _, expr := range req.MatchExprs { - if expr == nil { - continue - } - switch e := expr.(type) { - case string: - if e != "" { - hasTextMatch = true - matchText = &types.MatchTextExpr{ - MatchingText: e, - TopN: pageSize, - } - } - case *types.MatchTextExpr: - if e.MatchingText != "" { - hasTextMatch = true - matchText = e - } - case *types.MatchDenseExpr: - if len(e.EmbeddingData) > 0 { - hasVectorMatch = true - matchDense = e - } - } - } - } - - if hasTextMatch || hasVectorMatch { - if hasTextMatch { - outputColumns = append(outputColumns, "score()") - } - // similarity() is only allowed by Infinity when there is ONLY MATCH VECTOR. - // When both text and vector matches exist (hybrid search with Fusion), - // only score() is valid — Fusion produces a unified SCORE column. - if hasVectorMatch && !hasTextMatch { - outputColumns = append(outputColumns, "similarity()") - } - // Skill index does not have pagerank_fea and tag_feas columns - if !isSkillIndex { - if !slices.Contains(outputColumns, common.PAGERANK_FLD) { - outputColumns = append(outputColumns, common.PAGERANK_FLD) - } - if !slices.Contains(outputColumns, common.TAG_FLD) { - outputColumns = append(outputColumns, common.TAG_FLD) - } - } - } - - if !slices.Contains(outputColumns, "row_id") && !slices.Contains(outputColumns, "row_id()") { - outputColumns = append(outputColumns, "row_id()") - } - - outputColumns = convertSelectFields(outputColumns, isSkillIndex) - if hasVectorMatch && matchDense != nil && matchDense.VectorColumnName != "" { - outputColumns = append(outputColumns, matchDense.VectorColumnName) - } - - var filterParts []string - if isMetadataTable && len(req.KbIDs) > 0 && req.KbIDs[0] != "" { - kbIDs := req.KbIDs - if len(kbIDs) == 1 { - filterParts = append(filterParts, fmt.Sprintf("kb_id = '%s'", kbIDs[0])) - } else { - kbIDStr := strings.Join(kbIDs, "', '") - filterParts = append(filterParts, fmt.Sprintf("kb_id IN ('%s')", kbIDStr)) - } - } - - if !isMetadataTable && (hasTextMatch || hasVectorMatch) { - if req.Filter != nil { - if availInt, ok := req.Filter["available_int"]; ok { - filterParts = append(filterParts, fmt.Sprintf("available_int=%v", availInt)) - } else if status, ok := req.Filter["status"]; ok { - filterParts = append(filterParts, fmt.Sprintf("status='%s'", status)) - } else { - if isSkillIndex { - filterParts = append(filterParts, "status='1'") - } else { - filterParts = append(filterParts, "available_int=1") - } - } - } else { - if isSkillIndex { - filterParts = append(filterParts, "status='1'") - } else { - filterParts = append(filterParts, "available_int=1") - } - } - } - - // Build filter string from req.Filter - if req.Filter != nil { - filterCopy := req.Filter - if !isMetadataTable { - filterCopy = make(map[string]interface{}) - for k, v := range req.Filter { - if k != "kb_id" { - filterCopy[k] = v - } - } - } - - condStr := equivalentConditionToStr(filterCopy) - if condStr != "" { - filterParts = append(filterParts, condStr) - } - } - filterStr := strings.Join(filterParts, " AND ") - - orderBy := req.OrderBy - var rankFeature map[string]float64 - if req.RankFeature != nil { - rankFeature = req.RankFeature - } - - var fusionExpr *types.FusionExpr - if len(req.MatchExprs) > 2 { - if fe, ok := req.MatchExprs[2].(*types.FusionExpr); ok { - fusionExpr = fe - } - } - - var allResults []map[string]interface{} - totalHits := int64(0) - - for _, indexName := range req.IndexNames { - var tableNames []string - if strings.HasPrefix(indexName, "ragflow_doc_meta_") { - tableNames = []string{indexName} - } else { - kbIDs := req.KbIDs - if len(kbIDs) == 0 { - kbIDs = []string{""} - } - for _, kbID := range kbIDs { - if kbID == "" { - tableNames = append(tableNames, indexName) - } else { - tableNames = append(tableNames, fmt.Sprintf("%s_%s", indexName, kbID)) - } - } - } - - minMatch := 0.3 - - var questionText string - var vectorData []float64 - textTopN := pageSize - var originalQuery string - if matchText != nil { - questionText = matchText.MatchingText - textTopN = int(matchText.TopN) - if matchText.ExtraOptions != nil { - if oq, ok := matchText.ExtraOptions["original_query"].(string); ok { - originalQuery = oq - } - } - } - if matchDense != nil { - vectorData = matchDense.EmbeddingData - } - - for _, tableName := range tableNames { - tbl, err := db.GetTable(tableName) - if err != nil { - continue - } - table := tbl.Output(outputColumns) - - var textFields []string - if matchText != nil && len(matchText.Fields) > 0 { - textFields = matchText.Fields - } else if isSkillIndex { - textFields = []string{ - "name^10", - "tags^5", - "description^3", - "content^1", - } - } else { - textFields = []string{ - "title_tks^10", - "title_sm_tks^5", - "important_kwd^30", - "important_tks^20", - "question_tks^20", - "content_ltks^2", - "content_sm_ltks", - } - } - - // Convert field names for Infinity - var convertedFields []string - for _, f := range textFields { - cf := convertMatchingField(f) - convertedFields = append(convertedFields, cf) - } - fields := strings.Join(convertedFields, ",") - - hasTextMatch := questionText != "" - hasVectorMatch := len(vectorData) > 0 - // Add text match if question is provided - if hasTextMatch { - extraOptions := map[string]string{ - "minimum_should_match": fmt.Sprintf("%d%%", int(minMatch*100)), - } - - if filterStr != "" { - extraOptions["filter"] = filterStr - } - - if rankFeature != nil { - var rankFeaturesList []string - for featureName, weight := range rankFeature { - rankFeaturesList = append(rankFeaturesList, fmt.Sprintf("%s^%s^%.0f", common.TAG_FLD, featureName, weight)) - } - if len(rankFeaturesList) > 0 { - extraOptions["rank_features"] = strings.Join(rankFeaturesList, ",") - } - } - - if originalQuery != "" { - extraOptions["original_query"] = originalQuery - } - - table = table.MatchText(fields, questionText, textTopN, extraOptions) - - common.Debug(fmt.Sprintf( - "MatchTextExpr:\n"+ - " fields=%s\n"+ - " matching_text=%s\n"+ - " topn=%d\n"+ - " extra_options=%v", - fields, questionText, textTopN, extraOptions, - )) - } - - // Add vector match if provided - if hasVectorMatch { - vectorSize := len(vectorData) - fieldName := fmt.Sprintf("q_%d_vec", vectorSize) - dataType := "float" - distanceType := "cosine" - - if matchDense != nil { - if matchDense.VectorColumnName != "" { - fieldName = matchDense.VectorColumnName - } - if matchDense.EmbeddingDataType != "" { - dataType = matchDense.EmbeddingDataType - } - if matchDense.DistanceType != "" { - distanceType = matchDense.DistanceType - } - } - - vectorTopN := pageSize - if matchDense != nil && matchDense.TopN > 0 { - vectorTopN = int(matchDense.TopN) - } - - denseFilterStr := filterStr - if denseFilterStr == "" { - if isSkillIndex { - denseFilterStr = "status='1'" - } else { - denseFilterStr = "available_int=1" - } - } - - if hasTextMatch && fusionExpr == nil { - fieldsStr := strings.Join(convertedFields, ",") - filterFulltext := fmt.Sprintf("filter_fulltext('%s', '%s')", fieldsStr, questionText) - denseFilterStr = fmt.Sprintf("(%s) AND %s", denseFilterStr, filterFulltext) - } - extraOptions := map[string]string{ - "threshold": utility.FloatToString(0.0), - "filter": denseFilterStr, - } - - common.Debug("MatchDense for hybrid search", - zap.String("fieldName", fieldName), - zap.String("distanceType", distanceType), - zap.Int("topN", vectorTopN), - zap.Bool("hasFusion", fusionExpr != nil)) - - table = table.MatchDense(fieldName, vectorData, dataType, distanceType, vectorTopN, extraOptions) - } - - // Add fusion (for text + vector combination) - if hasTextMatch && hasVectorMatch && fusionExpr != nil { - fusionMethod := fusionExpr.Method - fusionTopK := fusionExpr.TopN - if fusionTopK == 0 { - fusionTopK = pageSize - } - fusionParams := map[string]interface{}{ - "normalize": "atan", - } - if fusionExpr.FusionParams != nil { - for k, v := range fusionExpr.FusionParams { - fusionParams[k] = v - } - } - - common.Debug("Applying Fusion for hybrid search", - zap.String("method", fusionMethod), - zap.Int("topN", fusionTopK), - zap.Any("params", fusionParams)) - - table = table.Fusion(fusionMethod, fusionTopK, fusionParams) - } - - // Add order_by if provided - if orderBy != nil && len(orderBy.Fields) > 0 { - var sortFields [][2]interface{} - for _, orderField := range orderBy.Fields { - sortType := infinity.SortTypeAsc - if orderField.Type == types.SortDesc { - sortType = infinity.SortTypeDesc - } - sortFields = append(sortFields, [2]interface{}{orderField.Field, sortType}) - } - table = table.Sort(sortFields) - } - - // Add filter when there's no text/vector match (like metadata queries) - if !hasTextMatch && !hasVectorMatch && filterStr != "" { - common.Debug(fmt.Sprintf("Adding filter for no-match query: %s", filterStr)) - table = table.Filter(filterStr) - } - - // Set limit and offset - table = table.Limit(pageSize) - if offset > 0 { - table = table.Offset(offset) - } - - // Request total_hits_count from Infinity - table = table.Option(map[string]interface{}{"total_hits_count": true}) - - // Execute query - df, err := table.ToDataFrame() - if err != nil { - common.Warn("Infinity query failed", - zap.String("tableName", tableName), - zap.Bool("hasTextMatch", hasTextMatch), - zap.Bool("hasVectorMatch", hasVectorMatch), - zap.Bool("hasFusion", fusionExpr != nil), - zap.Error(err)) - continue - } - - // Convert DataFrame to chunks format (column-oriented to row-oriented) - chunks := make([]map[string]interface{}, 0) - for colName, colData := range df.ColumnData { - for i, val := range colData { - for len(chunks) <= i { - chunks = append(chunks, make(map[string]interface{})) - } - chunks[i][colName] = val - } - } - - // Apply field name mapping and row_id handling - // Skill index uses different schema - // so we skip the document-specific field mappings - if !isSkillIndex { - GetFields(chunks, nil) - } else { - // For skill index, only handle ROW_ID -> row_id() mapping - for _, chunk := range chunks { - if val, ok := chunk["ROW_ID"]; ok { - chunk["row_id()"] = val - delete(chunk, "ROW_ID") - } - } - } - - // Parse total_hits_count from ExtraInfo - var tableTotal int64 - if df.ExtraInfo != "" { - var extraResult map[string]interface{} - if err := json.Unmarshal([]byte(df.ExtraInfo), &extraResult); err == nil { - if count, ok := extraResult["total_hits_count"].(float64); ok { - tableTotal = int64(count) - } - } - } - - searchResult := &types.SearchResult{ - Chunks: chunks, - Total: tableTotal, - } - - allResults = append(allResults, searchResult.Chunks...) - totalHits += searchResult.Total - } - } - - if hasTextMatch || hasVectorMatch { - scoreColumn := "" - if hasTextMatch && hasVectorMatch { - scoreColumn = "SCORE" - } else if hasTextMatch { - scoreColumn = "SCORE" - } else if hasVectorMatch { - scoreColumn = "SIMILARITY" - } - pagerankField := common.PAGERANK_FLD - if isSkillIndex { - pagerankField = "" // Skill index has no pagerank field - } - - allResults = calculateScores(allResults, scoreColumn, pagerankField) - allResults = sortByScore(allResults, len(allResults)) - } - - if len(allResults) > pageSize { - allResults = allResults[:pageSize] - } - - common.Debug("Search in Infinity completed", zap.Int("returnedRows", len(allResults)), zap.Int64("totalHits", totalHits)) - - return &types.SearchResult{ - Chunks: allResults, - Total: totalHits, - }, nil -} - -// convertSelectFields converts field names to Infinity format -// isSkillIndex indicates if this is a skill index (uses skill_id instead of id) -func convertSelectFields(output []string, isSkillIndex ...bool) []string { - fieldMapping := map[string]string{ - "docnm_kwd": "docnm", - "title_tks": "docnm", - "title_sm_tks": "docnm", - "important_kwd": "important_keywords", - "important_tks": "important_keywords", - "question_kwd": "questions", - "question_tks": "questions", - "content_with_weight": "content", - "content_ltks": "content", - "content_sm_ltks": "content", - "authors_tks": "authors", - "authors_sm_tks": "authors", - } - - skillIndex := false - if len(isSkillIndex) > 0 { - skillIndex = isSkillIndex[0] - } - - needEmptyCount := false - for i, field := range output { - if field == "important_kwd" { - needEmptyCount = true - } - if newField, ok := fieldMapping[field]; ok { - output[i] = newField - } - } - - // Remove duplicates - seen := make(map[string]bool) - result := []string{} - for _, f := range output { - if f != "" && !seen[f] { - seen[f] = true - result = append(result, f) - } - } - - // Add id and empty count if needed - // For skill index, use skill_id instead of id - hasID := false - idField := "id" - if skillIndex { - idField = "skill_id" - } - for _, f := range result { - if f == idField { - hasID = true - break - } - } - if !hasID { - result = append([]string{idField}, result...) - } - - if needEmptyCount { - result = append(result, "important_kwd_empty_count") - } - - return result -} - -// convertMatchingField converts field names for matching -// For regular document indices: maps _tks/_kwd fields to column@index_name format -// For skill indices: maps raw field names to column@index_name format -// Infinity requires column@index_name when a column has multiple full-text indexes -func convertMatchingField(fieldWeightStr string) string { - // Split on ^ to get field name - parts := strings.Split(fieldWeightStr, "^") - field := parts[0] - - // Field name conversion - fieldMapping := map[string]string{ - "docnm_kwd": "docnm@ft_docnm_rag_coarse", - "title_tks": "docnm@ft_docnm_rag_coarse", - "title_sm_tks": "docnm@ft_docnm_rag_fine", - "important_kwd": "important_keywords@ft_important_keywords_rag_coarse", - "important_tks": "important_keywords@ft_important_keywords_rag_fine", - "question_kwd": "questions@ft_questions_rag_coarse", - "question_tks": "questions@ft_questions_rag_fine", - "content_with_weight": "content@ft_content_rag_coarse", - "content_ltks": "content@ft_content_rag_coarse", - "content_sm_ltks": "content@ft_content_rag_fine", - "authors_tks": "authors@ft_authors_rag_coarse", - "authors_sm_tks": "authors@ft_authors_rag_fine", - "tag_kwd": "tag_kwd@ft_tag_kwd_whitespace__", - // Skill index fields - "name": "name@ft_name_rag_coarse", - "tags": "tags@ft_tags_rag_coarse", - "description": "description@ft_description_rag_coarse", - "content": "content@ft_content_rag_coarse", - } - - if newField, ok := fieldMapping[field]; ok { - parts[0] = newField - } - - return strings.Join(parts, "^") -} - -// escapeFilterValue escapes single quotes for filter values -func escapeFilterValue(s string) string { - return strings.ReplaceAll(s, "'", "''") -} - -// equivalentConditionToStr converts a condition map to an Infinity filter string -func equivalentConditionToStr(condition map[string]interface{}) string { - if len(condition) == 0 { - return "" - } - - var cond []string - - for k, v := range condition { - if k == "_id" || utility.IsEmpty(v) { - continue - } - - // Handle must_not specially - if k == "must_not" { - if m, ok := v.(map[string]interface{}); ok { - for kk, vv := range m { - if kk == "exists" { - // For must_not exists, use !='' since we don't have table schema - cond = append(cond, fmt.Sprintf("NOT (%v!='')", vv)) - } - } - } - continue - } - - // Handle exists specially (without table schema, use string comparison) - if k == "exists" { - cond = append(cond, fmt.Sprintf("%v!=''", v)) - continue - } - - // Handle keyword fields (using full-text filter) - if fieldKeyword(k) { - // For keyword fields, values are always treated as strings for filter_fulltext - switch val := v.(type) { - case []string: - var inCond []string - for _, item := range val { - inCond = append(inCond, fmt.Sprintf("filter_fulltext('%s', '%s')", - convertMatchingField(k), escapeFilterValue(item))) - } - if len(inCond) > 0 { - cond = append(cond, "("+strings.Join(inCond, " or ")+")") - } - case []interface{}: - var inCond []string - for _, item := range val { - if s, ok := item.(string); ok { - inCond = append(inCond, fmt.Sprintf("filter_fulltext('%s', '%s')", - convertMatchingField(k), escapeFilterValue(s))) - } else { - inCond = append(inCond, fmt.Sprintf("filter_fulltext('%s', '%s')", - convertMatchingField(k), escapeFilterValue(fmt.Sprintf("%v", item)))) - } - } - if len(inCond) > 0 { - cond = append(cond, "("+strings.Join(inCond, " or ")+")") - } - case string: - cond = append(cond, fmt.Sprintf("filter_fulltext('%s', '%s')", - convertMatchingField(k), escapeFilterValue(val))) - default: - cond = append(cond, fmt.Sprintf("filter_fulltext('%s', '%s')", - convertMatchingField(k), escapeFilterValue(fmt.Sprintf("%v", v)))) - } - continue - } - - // Handle list values (mixed types - strings get quotes, numbers don't) - if list, ok := v.([]interface{}); ok && len(list) > 0 { - var strItems, numItems []string - for _, item := range list { - if s, ok := item.(string); ok { - strItems = append(strItems, fmt.Sprintf("'%s'", escapeFilterValue(s))) - } else if n, ok := item.(int); ok { - numItems = append(numItems, strconv.Itoa(n)) - } else if n, ok := item.(int64); ok { - numItems = append(numItems, strconv.FormatInt(n, 10)) - } else if f, ok := item.(float64); ok { - numItems = append(numItems, strconv.FormatFloat(f, 'f', -1, 64)) - } else if s, ok := item.(fmt.Stringer); ok { - strItems = append(strItems, fmt.Sprintf("'%s'", escapeFilterValue(s.String()))) - } else { - strItems = append(strItems, fmt.Sprintf("'%s'", escapeFilterValue(fmt.Sprintf("%v", item)))) - } - } - if len(strItems) > 0 { - if len(strItems) == 1 { - cond = append(cond, fmt.Sprintf("%s=%s", k, strItems[0])) - } else { - cond = append(cond, fmt.Sprintf("%s IN (%s)", k, strings.Join(strItems, ", "))) - } - } - if len(numItems) > 0 { - if len(numItems) == 1 { - cond = append(cond, fmt.Sprintf("%s=%s", k, numItems[0])) - } else { - cond = append(cond, fmt.Sprintf("%s IN (%s)", k, strings.Join(numItems, ", "))) - } - } - continue - } - - if list, ok := v.([]string); ok && len(list) > 0 { - if len(list) == 1 { - cond = append(cond, fmt.Sprintf("%s='%s'", k, escapeFilterValue(list[0]))) - } else { - var items []string - for _, item := range list { - items = append(items, fmt.Sprintf("'%s'", escapeFilterValue(item))) - } - cond = append(cond, fmt.Sprintf("%s IN (%s)", k, strings.Join(items, ", "))) - } - continue - } - - if list, ok := v.([]int); ok && len(list) > 0 { - if len(list) == 1 { - cond = append(cond, fmt.Sprintf("%s=%d", k, list[0])) - } else { - var strs []string - for _, n := range list { - strs = append(strs, strconv.Itoa(n)) - } - cond = append(cond, fmt.Sprintf("%s IN (%s)", k, strings.Join(strs, ", "))) - } - continue - } - - // Handle numeric values (no quotes) - if utility.IsNumericValue(v) { - cond = append(cond, fmt.Sprintf("%s=%v", k, v)) - continue - } - - // Handle string values (with quotes and escaping) - if str, ok := v.(string); ok { - cond = append(cond, fmt.Sprintf("%s='%s'", k, escapeFilterValue(str))) - continue - } - - // Fallback: treat as string - cond = append(cond, fmt.Sprintf("%s='%s'", k, escapeFilterValue(fmt.Sprintf("%v", v)))) - } - - if len(cond) == 0 { - return "" - } - return strings.Join(cond, " AND ") -} - -// calculateScores calculates _score = score_column + pagerank -func calculateScores(chunks []map[string]interface{}, scoreColumn, pagerankField string) []map[string]interface{} { - for i := range chunks { - score := 0.0 - if scoreVal, ok := chunks[i][scoreColumn]; ok { - if f, ok := utility.ToFloat64(scoreVal); ok { - score += f - } - } - if pagerankField != "" { - if prVal, ok := chunks[i][pagerankField]; ok { - if f, ok := utility.ToFloat64(prVal); ok { - score += f - } - } - } - chunks[i]["_score"] = score - } - return chunks -} - -// sortByScore sorts by _score descending and limits -func sortByScore(chunks []map[string]interface{}, limit int) []map[string]interface{} { - if len(chunks) == 0 { - return chunks - } - - // Sort by _score descending - sort.Slice(chunks, func(i, j int) bool { - scoreI := getChunkScore(chunks[i]) - scoreJ := getChunkScore(chunks[j]) - return scoreI > scoreJ - }) - - // Limit - if len(chunks) > limit && limit > 0 { - chunks = chunks[:limit] - } - - return chunks -} - -// getChunkScore extracts the score from a chunk -func getChunkScore(chunk map[string]interface{}) float64 { - if v, ok := chunk["_score"].(float64); ok { - return v - } - if v, ok := chunk["SCORE"].(float64); ok { - return v - } - if v, ok := chunk["SIMILARITY"].(float64); ok { - return v - } - return 0.0 -} - -// GetAggregation aggregates field values from search results. -// -// Example: -// input chunks: -// -// [{"docnm_kwd": "docA"}, {"docnm_kwd": "docA"}, {"docnm_kwd": "docB"}] -// -// GetAggregation(chunks, "docnm_kwd") returns: -// -// [{"key": "docA", "count": 2}, {"key": "docB", "count": 1}] -// -// For tag_kwd field, splits values by "###" separator. -// For other fields, uses comma separation. -func (e *infinityEngine) GetAggregation(chunks []map[string]interface{}, fieldName string) []map[string]interface{} { - if len(chunks) == 0 { - return []map[string]interface{}{} - } - - // Check if field exists in first chunk - hasField := false - for _, chunk := range chunks { - if _, ok := chunk[fieldName]; ok { - hasField = true - break - } - } - if !hasField { - return []map[string]interface{}{} - } - - // Count occurrences - tagCounts := make(map[string]int) - for _, chunk := range chunks { - value, ok := chunk[fieldName] - if !ok || value == nil { - continue - } - - // Handle string value - if valueStr, ok := value.(string); ok { - if valueStr == "" { - continue - } - - var tags []string - // Split by "###" for tag_kwd field - if fieldName == "tag_kwd" && strings.Contains(valueStr, "###") { - for _, tag := range strings.Split(valueStr, "###") { - tag = strings.TrimSpace(tag) - if tag != "" { - tags = append(tags, tag) - } - } - } else { - // Fallback to comma separation - for _, tag := range strings.Split(valueStr, ",") { - tag = strings.TrimSpace(tag) - if tag != "" { - tags = append(tags, tag) - } - } - } - - for _, tag := range tags { - tagCounts[tag]++ - } - continue - } - - // Handle list value - if valueList, ok := value.([]interface{}); ok { - for _, item := range valueList { - if itemStr, ok := item.(string); ok { - tag := strings.TrimSpace(itemStr) - if tag != "" { - tagCounts[tag]++ - } - } - } - } - } - - if len(tagCounts) == 0 { - return []map[string]interface{}{} - } - - // Convert to slice and sort by count descending - type tagCountPair struct { - tag string - count int - } - pairs := make([]tagCountPair, 0, len(tagCounts)) - for tag, count := range tagCounts { - pairs = append(pairs, tagCountPair{tag, count}) - } - sort.Slice(pairs, func(i, j int) bool { - return pairs[i].count > pairs[j].count - }) - - // Convert to []map[string]interface{} directly - result := make([]map[string]interface{}, len(pairs)) - for i, p := range pairs { - result[i] = map[string]interface{}{"key": p.tag, "count": p.count} - } - - return result -} - -// GetDocIDs extracts document IDs from search results. -// Extracts "id" field from each chunk and returns as a list. -func (e *infinityEngine) GetDocIDs(chunks []map[string]interface{}) []string { - if len(chunks) == 0 { - return nil - } - ids := make([]string, 0, len(chunks)) - for _, chunk := range chunks { - if id, ok := chunk["id"].(string); ok { - ids = append(ids, id) - } - } - return ids -} - -// GetHighlight generates highlighted text snippets for search results. -// Matches keywords in text and wraps them with tags. -func (e *infinityEngine) GetHighlight(chunks []map[string]interface{}, keywords []string, fieldName string) map[string]string { - result := make(map[string]string) - if len(chunks) == 0 || len(keywords) == 0 { - return result - } - - // Check if field exists - hasField := false - for _, chunk := range chunks { - if _, ok := chunk[fieldName]; ok { - hasField = true - break - } - } - if !hasField { - // Try alternative field names - if fieldName == "content_with_weight" { - if _, ok := chunks[0]["content"]; ok { - fieldName = "content" - hasField = true - } - } - } - if !hasField { - return result - } - - emTag := regexp.MustCompile(`[^<>]+`) - - for _, chunk := range chunks { - id := "" - if idVal, ok := chunk["id"].(string); ok { - id = idVal - } - - txt, ok := chunk[fieldName].(string) - if !ok || txt == "" { - continue - } - - // Check if already highlighted - if emTag.MatchString(txt) { - result[id] = txt - continue - } - - // Replace newlines with spaces - txt = regexp.MustCompile(`[\r\n]`).ReplaceAllString(txt, " ") - - // Split by sentence delimiters - delimiters := regexp.MustCompile(`[.?!;\n]`) - segments := delimiters.Split(txt, -1) - - var highlightedSegments []string - for _, segment := range segments { - // Check if segment is English or contains keywords - englishCount := 0 - totalCount := 0 - for _, r := range segment { - if unicode.IsLetter(r) { - totalCount++ - if (r >= 'a' && r <= 'z') || (r >= 'A' && r <= 'Z') { - englishCount++ - } - } - } - isEnglish := totalCount > 0 && float64(englishCount)/float64(totalCount) > 0.5 - segmentToCheck := segment - if isEnglish { - // For English: match whole words with boundaries - for _, kw := range keywords { - re := regexp.MustCompile(`(^|[ .?/'\"\(\)!,:;-])` + regexp.QuoteMeta(kw) + `([ .?/'\"\(\)!,:;-]|$)`) - segmentToCheck = re.ReplaceAllString(segmentToCheck, "$1"+kw+"$2") - } - } else { - // For non-English: simple keyword replacement (sorted by length desc for longer matches first) - sortedKeywords := make([]string, len(keywords)) - copy(sortedKeywords, keywords) - sort.Slice(sortedKeywords, func(i, j int) bool { - return len(sortedKeywords[i]) > len(sortedKeywords[j]) - }) - for _, kw := range sortedKeywords { - re := regexp.MustCompile(regexp.QuoteMeta(kw)) - segmentToCheck = re.ReplaceAllString(segmentToCheck, ""+kw+"") - } - } - - // Check if any keywords were highlighted - if emTag.MatchString(segmentToCheck) { - highlightedSegments = append(highlightedSegments, segmentToCheck) - } - } - - if len(highlightedSegments) > 0 { - result[id] = "..." + strings.Join(highlightedSegments, "...") + "..." - } else { - result[id] = txt - } - } - - return result -} diff --git a/internal/entity/base.go b/internal/entity/base.go index 748fea87132..774be466f3d 100644 --- a/internal/entity/base.go +++ b/internal/entity/base.go @@ -20,6 +20,8 @@ import ( "database/sql/driver" "encoding/json" "time" + + "gorm.io/gorm" ) // BaseModel base model @@ -31,6 +33,92 @@ type BaseModel struct { UpdateDate *time.Time `gorm:"column:update_date;index" json:"update_date,omitempty"` } +func autoModelTime() (int64, time.Time) { + now := time.Now().Local() + return now.UnixMilli(), now.Truncate(time.Second) +} + +func statementHasTimeField(tx *gorm.DB, fieldNames ...string) bool { + if tx == nil || tx.Statement == nil { + return false + } + + switch dest := tx.Statement.Dest.(type) { + case map[string]interface{}: + for _, fieldName := range fieldNames { + if _, ok := dest[fieldName]; ok { + return true + } + } + case []map[string]interface{}: + for _, item := range dest { + for _, fieldName := range fieldNames { + if _, ok := item[fieldName]; ok { + return true + } + } + } + } + + return false +} + +// BeforeCreate injects timestamps for models embedding BaseModel. +func (m *BaseModel) BeforeCreate(tx *gorm.DB) error { + timestamp, dateTime := autoModelTime() + + if m.CreateTime == nil { + m.CreateTime = ×tamp + } + if m.CreateDate == nil { + m.CreateDate = &dateTime + } + if m.UpdateTime == nil { + m.UpdateTime = ×tamp + } + if m.UpdateDate == nil { + m.UpdateDate = &dateTime + } + + if tx != nil && tx.Statement != nil { + if !statementHasTimeField(tx, "create_time", "CreateTime") && m.CreateTime != nil { + tx.Statement.SetColumn("CreateTime", *m.CreateTime) + } + if !statementHasTimeField(tx, "create_date", "CreateDate") && m.CreateDate != nil { + tx.Statement.SetColumn("CreateDate", *m.CreateDate) + } + if !statementHasTimeField(tx, "update_time", "UpdateTime") && m.UpdateTime != nil { + tx.Statement.SetColumn("UpdateTime", *m.UpdateTime) + } + if !statementHasTimeField(tx, "update_date", "UpdateDate") && m.UpdateDate != nil { + tx.Statement.SetColumn("UpdateDate", *m.UpdateDate) + } + } + return nil +} + +// BeforeUpdate injects update timestamps for models embedding BaseModel. +func (m *BaseModel) BeforeUpdate(tx *gorm.DB) error { + timestamp, dateTime := autoModelTime() + + if !statementHasTimeField(tx, "update_time", "UpdateTime") { + m.UpdateTime = ×tamp + } + if !statementHasTimeField(tx, "update_date", "UpdateDate") { + m.UpdateDate = &dateTime + } + + if tx != nil && tx.Statement != nil { + if !statementHasTimeField(tx, "update_time", "UpdateTime") && m.UpdateTime != nil { + tx.Statement.SetColumn("UpdateTime", *m.UpdateTime) + } + if !statementHasTimeField(tx, "update_date", "UpdateDate") && m.UpdateDate != nil { + tx.Statement.SetColumn("UpdateDate", *m.UpdateDate) + } + } + return nil +} + // JSONMap is a map type that can store JSON data type JSONMap map[string]interface{} diff --git a/internal/entity/document.go b/internal/entity/document.go index 36012196663..9b25ff46634 100644 --- a/internal/entity/document.go +++ b/internal/entity/document.go @@ -46,6 +46,39 @@ type Document struct { BaseModel } +// DocumentListItem represents a document list row with joined fields. +type DocumentListItem struct { + ID string `gorm:"column:id" json:"id"` + Thumbnail *string `gorm:"column:thumbnail" json:"thumbnail,omitempty"` + KbID string `gorm:"column:kb_id" json:"kb_id"` + ParserID string `gorm:"column:parser_id" json:"parser_id"` + PipelineID *string `gorm:"column:pipeline_id" json:"pipeline_id,omitempty"` + PipelineName *string `gorm:"column:pipeline_name" json:"pipeline_name,omitempty"` + ParserConfig string `gorm:"column:parser_config" json:"parser_config"` + SourceType string `gorm:"column:source_type" json:"source_type"` + Type string `gorm:"column:type" json:"type"` + CreatedBy string `gorm:"column:created_by" json:"created_by"` + Nickname *string `gorm:"column:nickname" json:"nickname,omitempty"` + Name *string `gorm:"column:name" json:"name,omitempty"` + Location *string `gorm:"column:location" json:"location,omitempty"` + Size int64 `gorm:"column:size" json:"size"` + TokenNum int64 `gorm:"column:token_num" json:"token_num"` + ChunkNum int64 `gorm:"column:chunk_num" json:"chunk_num"` + Progress float64 `gorm:"column:progress" json:"progress"` + ProgressMsg *string `gorm:"column:progress_msg" json:"progress_msg,omitempty"` + ProcessBeginAt *time.Time `gorm:"column:process_begin_at" json:"process_begin_at,omitempty"` + ProcessDuration float64 `gorm:"column:process_duration" json:"process_duration"` + ContentHash *string `gorm:"column:content_hash" json:"content_hash,omitempty"` + MetaFields *string `gorm:"column:meta_fields" json:"meta_fields,omitempty"` + Suffix string `gorm:"column:suffix" json:"suffix"` + Run *string `gorm:"column:run" json:"run,omitempty"` + Status *string `gorm:"column:status" json:"status,omitempty"` + CreateTime *int64 `gorm:"column:create_time" json:"create_time,omitempty"` + CreateDate *time.Time `gorm:"column:create_date" json:"create_date,omitempty"` + UpdateTime *int64 `gorm:"column:update_time" json:"update_time,omitempty"` + UpdateDate *time.Time `gorm:"column:update_date" json:"update_date,omitempty"` +} + // TableName specify table name func (Document) TableName() string { return "document" diff --git a/internal/entity/model.go b/internal/entity/model.go index 08a2958a5f4..a0b69f7020f 100644 --- a/internal/entity/model.go +++ b/internal/entity/model.go @@ -230,8 +230,10 @@ func NewProviderManager(dirPath string) (*ProviderManager, error) { // if the prefix of mode.Name is matched with keys of modelSupportThinking if provider.Class == "" { pos := strings.Index(model.Name, "-") - modelClass := model.Name[0:pos] - model.Class = &modelClass + if pos >= 0 { + modelClass := model.Name[0:pos] + model.Class = &modelClass + } } else { model.Class = &provider.Name } diff --git a/internal/entity/models/302ai.go b/internal/entity/models/302ai.go new file mode 100644 index 00000000000..a093681a685 --- /dev/null +++ b/internal/entity/models/302ai.go @@ -0,0 +1,937 @@ +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "ragflow/internal/common" + "strconv" + "strings" + "time" +) + +type AI302Model struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func NewAI302Model(baseURL map[string]string, urlSuffix URLSuffix) *AI302Model { + return &AI302Model{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: time.Second * 120, + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + MaxIdleConns: 10, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: time.Second * 90, + DisableCompression: false, + }, + }, + } +} + +func (a *AI302Model) NewInstance(baseURL map[string]string) ModelDriver { + return &AI302Model{ + BaseURL: baseURL, + URLSuffix: a.URLSuffix, + httpClient: &http.Client{ + Timeout: time.Second * 120, + Transport: &http.Transport{ + Proxy: http.ProxyFromEnvironment, + MaxIdleConns: 10, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: time.Second * 90, + DisableCompression: false, + }, + }, + } +} + +func (a *AI302Model) Name() string { + return "302ai" +} + +func (a *AI302Model) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", a.BaseURL[region], a.URLSuffix.Chat) + + // Convert messages to API format + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + // Build request body + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + "temperature": 1, + } + + if chatModelConfig != nil { + if chatModelConfig.Stream != nil { + reqBody["stream"] = *chatModelConfig.Stream + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + + if chatModelConfig.Effort != nil { + reqBody["reasoning"] = map[string]interface{}{ + "effort": *chatModelConfig.Effort, + } + } + + if chatModelConfig.Thinking != nil { + if *chatModelConfig.Thinking { + reqBody["thinking"] = map[string]interface{}{ + "type": "enabled", + } + } else { + reqBody["thinking"] = map[string]interface{}{ + "type": "disabled", + } + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Accept", "application/json") + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + var reasonContent string + if chatModelConfig != nil && chatModelConfig.Thinking != nil && *chatModelConfig.Thinking { + reasonContent, ok = messageMap["reasoning_content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + // if first char of reasonContent is \n remove the '\n' + if reasonContent != "" && reasonContent[0] == '\n' { + reasonContent = reasonContent[1:] + } + } + + chatResponse := &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + } + + return chatResponse, nil +} + +func (a *AI302Model) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(a.BaseURL[region], "/"), a.URLSuffix.Chat) + + // Convert messages to API format + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + // Build request body with streaming enabled + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + "temperature": 1, + } + + if modelConfig != nil { + if modelConfig.Stream != nil { + reqBody["stream"] = *modelConfig.Stream + } + + if modelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *modelConfig.MaxTokens + } + + if modelConfig.Temperature != nil { + reqBody["temperature"] = *modelConfig.Temperature + } + + if modelConfig.DoSample != nil { + reqBody["do_sample"] = *modelConfig.DoSample + } + + if modelConfig.TopP != nil { + reqBody["top_p"] = *modelConfig.TopP + } + + if modelConfig.Stop != nil { + reqBody["stop"] = *modelConfig.Stop + } + + if modelConfig.Effort != nil { + reqBody["reasoning"] = map[string]interface{}{ + "effort": *modelConfig.Effort, + } + } + + if modelConfig.Thinking != nil { + if *modelConfig.Thinking { + reqBody["thinking"] = map[string]interface{}{ + "type": "enabled", + } + } else { + reqBody["thinking"] = map[string]interface{}{ + "type": "disabled", + } + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := a.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: read line by line + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + common.Info(line) + + // SSE data line starts with "data:" + if !strings.HasPrefix(line, "data:") { + continue + } + + // Extract JSON after "data:" + data := strings.TrimSpace(line[5:]) + + // [DONE] marks the end of stream + if data == "[DONE]" { + break + } + + // Parse the JSON event + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + reasoningContent, ok := delta["reasoning_content"].(string) + if ok && reasoningContent != "" { + if err := sender(nil, &reasoningContent); err != nil { + return err + } + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + break + } + } + + // Send [DONE] marker for OpenAI compatibility + endOfStream := "[DONE]" + if err = sender(&endOfStream, nil); err != nil { + return err + } + + return scanner.Err() +} + +func (a *AI302Model) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", a.BaseURL[region], a.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Jina embedding API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var parsedResponse struct { + Data []struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` + } + + if err = json.Unmarshal(body, &parsedResponse); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + if len(parsedResponse.Data) == 0 { + return nil, fmt.Errorf("Jina embedding response contains no data: %s", string(body)) + } + + var embeddings []EmbeddingData + for _, dataElem := range parsedResponse.Data { + embeddings = append(embeddings, EmbeddingData{ + Embedding: dataElem.Embedding, + Index: dataElem.Index, + }) + } + + return embeddings, nil +} + +func (a *AI302Model) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + if len(documents) == 0 { + return &RerankResponse{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", a.BaseURL[region], a.URLSuffix.Rerank) + + var topN = rerankConfig.TopN + if rerankConfig.TopN != 0 { + topN = rerankConfig.TopN + } + + reqBody := map[string]interface{}{ + "model": *modelName, + "query": query, + "documents": documents, + "top_n": topN, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("302.ai Rerank API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var rerankResp struct { + Results []struct { + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` + } `json:"results"` + } + + if err = json.Unmarshal(body, &rerankResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + var rerankResponse RerankResponse + for _, result := range rerankResp.Results { + rerankResult := RerankResult{ + Index: result.Index, + RelevanceScore: result.RelevanceScore, + } + rerankResponse.Data = append(rerankResponse.Data, rerankResult) + } + + return &rerankResponse, nil +} + +func (a *AI302Model) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + if file == nil || *file == "" { + return nil, fmt.Errorf("file is missing") + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", a.BaseURL[region], a.URLSuffix.ASR) + + // multipart body + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + // open audio file + audioFile, err := os.Open(*file) + if err != nil { + return nil, fmt.Errorf("failed to open audio file: %w", err) + } + defer audioFile.Close() + + // create multipart file field + part, err := writer.CreateFormFile( + "file", + filepath.Base(*file), + ) + if err != nil { + return nil, fmt.Errorf("failed to create multipart file: %w", err) + } + + // copy file content + if _, err = io.Copy(part, audioFile); err != nil { + return nil, fmt.Errorf("failed to copy audio data: %w", err) + } + + // model field + if err := writer.WriteField("model", *modelName); err != nil { + return nil, fmt.Errorf("failed to write model field: %w", err) + } + + // extra params + if asrConfig != nil && asrConfig.Params != nil { + for key, value := range asrConfig.Params { + + var val string + + switch v := value.(type) { + case string: + val = v + case bool: + val = strconv.FormatBool(v) + case int: + val = strconv.Itoa(v) + case int64: + val = strconv.FormatInt(v, 10) + case float32: + val = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + val = strconv.FormatFloat(v, 'f', -1, 64) + default: + val = fmt.Sprintf("%v", v) + } + + if err = writer.WriteField(key, val); err != nil { + return nil, fmt.Errorf("failed to write field %s: %w", key, err) + } + } + } + + if err = writer.Close(); err != nil { + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + // build request + req, err := http.NewRequest("POST", url, &body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Accept", "application/json") + + // send request + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("302.ai ASR error: %s - %s", resp.Status, string(respBody)) + } + + // Response + var result struct { + Text string `json:"text"` + } + + if err = json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w, body=%s", err, string(respBody)) + } + + return &ASRResponse{Text: result.Text}, nil +} + +func (a *AI302Model) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s no such method", a.Name()) +} + +func (a *AI302Model) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + // TODO https://302ai-en.apifox.cn/225254060e0 + return nil, fmt.Errorf("%s no such method", a.Name()) +} + +func (a *AI302Model) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s no such method", a.Name()) +} + +func (a *AI302Model) OCRFile(modelName *string, content []byte, urls *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + if (urls == nil || *urls == "") && (content == nil || len(content) == 0) { + return nil, fmt.Errorf("file url or content is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", a.BaseURL[region], a.URLSuffix.OCR) + + var docURL string + if urls != nil && *urls != "" { + docURL = *urls + } else { + mimeType := http.DetectContentType(content) + base64Str := base64.StdEncoding.EncodeToString(content) + docURL = fmt.Sprintf("data:%s;base64,%s", mimeType, base64Str) + } + + reqData := map[string]interface{}{ + "model": *modelName, + "document": map[string]interface{}{ + "type": "document_url", + "document_url": docURL, + }, + } + + jsonData, err := json.Marshal(reqData) + if err != nil { + return nil, fmt.Errorf("failed to marshal json payload: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Mistral OCR API error: %s, body: %s", resp.Status, string(body)) + } + + var mistralResp struct { + Pages []struct { + Index int `json:"index"` + Markdown string `json:"markdown"` + } `json:"pages"` + } + + if err = json.Unmarshal(body, &mistralResp); err != nil { + return nil, fmt.Errorf("failed to parse response json: %w", err) + } + + var fullMarkdown strings.Builder + for _, page := range mistralResp.Pages { + fullMarkdown.WriteString(page.Markdown) + fullMarkdown.WriteString("\n\n") + } + + resultText := strings.TrimSpace(fullMarkdown.String()) + + return &OCRFileResponse{ + Text: &resultText, + }, nil +} + +func (a *AI302Model) ParseFile(modelName *string, content []byte, documentURL *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + if documentURL == nil || *documentURL == "" { + return nil, fmt.Errorf("302.ai API requires a valid public document URL; direct file upload is not supported") + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + apiURL := fmt.Sprintf("%s/%s", a.BaseURL[region], a.URLSuffix.DocumentParse) + + reqBody := map[string]interface{}{ + "url": *documentURL, + } + + if modelName != nil && *modelName != "" { + reqBody["model_version"] = *modelName + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("302.ai API failed with status %d: %s", resp.StatusCode, string(body)) + } + + var taskResp mineruTaskSubmitResponse + if err := json.Unmarshal(body, &taskResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if taskResp.Code != 0 { + return nil, fmt.Errorf("MinerU task creation failed (code %d): %s", taskResp.Code, taskResp.Msg) + } + + return &ParseFileResponse{ + TaskID: taskResp.Data.TaskID, + }, nil +} + +func (a *AI302Model) ListModels(apiConfig *APIConfig) ([]string, error) { + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", a.BaseURL[region], a.URLSuffix.Models) + + reqBody := map[string]string{} + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // convert result["data"] to []map[string]interface{} + models := make([]string, 0) + for _, model := range result["data"].([]interface{}) { + modelMap := model.(map[string]interface{}) + modelName := modelMap["id"].(string) + models = append(models, modelName) + } + + return models, nil +} + +func (a *AI302Model) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("%s no such method", a.Name()) +} + +func (a *AI302Model) CheckConnection(apiConfig *APIConfig) error { + _, err := a.ListModels(apiConfig) + return err +} + +func (a *AI302Model) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s no such method", a.Name()) +} + +func (a *AI302Model) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + // URL: https://mineru.net/api/v4/extract/task/{task_id} + apiURL := fmt.Sprintf("%s/%s/%s", a.BaseURL[region], a.URLSuffix.DocumentParse, taskID) + + req, err := http.NewRequest("GET", apiURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("MinerU query API failed with status %d: %s", resp.StatusCode, string(body)) + } + + var queryResp mineruTaskQueryResponse + if err := json.Unmarshal(body, &queryResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if queryResp.Code != 0 { + return nil, fmt.Errorf("MinerU task query failed: %s", queryResp.Msg) + } + + // failed state + if queryResp.Data.State == "failed" { + return nil, fmt.Errorf("MinerU task failed: %s", queryResp.Data.ErrMsg) + } + + content := "" + if queryResp.Data.State == "done" { + content = queryResp.Data.FullZipURL + } else if queryResp.Data.State == "running" { + content = fmt.Sprintf("Task is running... Progress: %d / %d pages", + queryResp.Data.ExtractProgress.ExtractedPages, + queryResp.Data.ExtractProgress.TotalPages) + } else { + // queue or formating + content = fmt.Sprintf("Task state: %s", queryResp.Data.State) + } + + return &TaskResponse{ + Segments: []TaskSegment{ + { + Index: 0, + Content: content, + }, + }, + }, nil +} diff --git a/internal/entity/models/aliyun.go b/internal/entity/models/aliyun.go index a1ddd6dddb7..79a01b02f70 100644 --- a/internal/entity/models/aliyun.go +++ b/internal/entity/models/aliyun.go @@ -71,7 +71,12 @@ func (z *AliyunModel) ChatWithMessages(modelName string, messages []Message, api region = *apiConfig.Region } - url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Chat) + baseURL, ok := z.BaseURL[region] + if !ok || baseURL == "" { + return nil, fmt.Errorf("aliyun: no base URL configured for region %q", region) + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), z.URLSuffix.Chat) // Convert messages to the format expected by API apiMessages := make([]map[string]interface{}, len(messages)) @@ -207,7 +212,12 @@ func (z *AliyunModel) ChatStreamlyWithSender(modelName string, messages []Messag region = *apiConfig.Region } - url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Chat) + baseURL, ok := z.BaseURL[region] + if !ok || baseURL == "" { + return fmt.Errorf("aliyun: no base URL configured for region %q", region) + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), z.URLSuffix.Chat) // Convert messages to API format apiMessages := make([]map[string]interface{}, len(messages)) @@ -352,16 +362,28 @@ func (z *AliyunModel) ChatStreamlyWithSender(modelName string, messages []Messag } type aliyunEmbeddingResponse struct { - Data []struct { - Index int `json:"index"` - Embedding []interface{} `json:"embedding"` - } `json:"data"` + Data []EmbeddingData `json:"data"` + Model string `json:"model"` + Object string `json:"object"` + Usage aliyunUsage `json:"usage"` + ID string `json:"id"` +} + +type aliyunEmbeddingData struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + Object string `json:"object"` +} + +type aliyunUsage struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` } -// Encode encodes a list of texts into embeddings -func (z *AliyunModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Embed embeds a list of texts into embeddings +func (z *AliyunModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { if len(texts) == 0 { - return [][]float64{}, nil + return []EmbeddingData{}, nil } if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { @@ -430,29 +452,12 @@ func (z *AliyunModel) Encode(modelName *string, texts []string, apiConfig *APICo return nil, fmt.Errorf("failed to parse response: %w", err) } - embeddings := make([][]float64, len(texts)) - for _, item := range parsed.Data { - if item.Index < 0 || item.Index >= len(texts) { - return nil, fmt.Errorf("unexpected embedding index %d for %d inputs", item.Index, len(texts)) - } - vec := make([]float64, len(item.Embedding)) - for j, v := range item.Embedding { - switch val := v.(type) { - case float64: - vec[j] = val - case float32: - vec[j] = float64(val) - default: - return nil, fmt.Errorf("unexpected embedding value type at item %d index %d", item.Index, j) - } - } - embeddings[item.Index] = vec - } - - for i, vec := range embeddings { - if vec == nil { - return nil, fmt.Errorf("missing embedding for input at index %d", i) - } + var embeddings []EmbeddingData + for _, dataElem := range parsed.Data { + var embeddingData EmbeddingData + embeddingData.Embedding = dataElem.Embedding + embeddingData.Index = dataElem.Index + embeddings = append(embeddings, embeddingData) } return embeddings, nil @@ -550,6 +555,34 @@ func (z *AliyunModel) Rerank(modelName *string, query string, documents []string return &rerankResponse, nil } +// TranscribeAudio transcribe audio +func (z *AliyunModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *AliyunModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (z *AliyunModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *AliyunModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (z *AliyunModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +// ParseFile parse file +func (z *AliyunModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + type AliyunModelItem struct { ModelName string `json:"model_name"` BaseCapacity int `json:"base_capacity"` @@ -573,7 +606,12 @@ func (z *AliyunModel) ListModels(apiConfig *APIConfig) ([]string, error) { region = *apiConfig.Region } - url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Models) + baseURL, ok := z.BaseURL[region] + if !ok || baseURL == "" { + return nil, fmt.Errorf("aliyun: no base URL configured for region %q", region) + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), z.URLSuffix.Models) // Build request body reqBody := map[string]interface{}{} @@ -632,3 +670,11 @@ func (z *AliyunModel) CheckConnection(apiConfig *APIConfig) error { } return nil } + +func (z *AliyunModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *AliyunModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/anthropic.go b/internal/entity/models/anthropic.go new file mode 100644 index 00000000000..8777712f7d9 --- /dev/null +++ b/internal/entity/models/anthropic.go @@ -0,0 +1,491 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +const anthropicVersion = "2023-06-01" + +// AnthropicModel implements ModelDriver for Claude models through the +// Anthropic Messages API. +type AnthropicModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func NewAnthropicModel(baseURL map[string]string, urlSuffix URLSuffix) *AnthropicModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.ResponseHeaderTimeout = 60 * time.Second + + return &AnthropicModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (a *AnthropicModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewAnthropicModel(baseURL, a.URLSuffix) +} + +func (a *AnthropicModel) Name() string { + return "anthropic" +} + +func (a *AnthropicModel) baseURLForRegion(region string) (string, error) { + base, ok := a.BaseURL[region] + if !ok || strings.TrimSpace(base) == "" { + return "", fmt.Errorf("anthropic: no base URL configured for region %q", region) + } + return strings.TrimRight(base, "/"), nil +} + +func (a *AnthropicModel) region(apiConfig *APIConfig) string { + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + return *apiConfig.Region + } + return "default" +} + +func (a *AnthropicModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + apiKey, err := anthropicAPIKey(apiConfig) + if err != nil { + return nil, err + } + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + apiMessages, systemPrompt, err := anthropicMessages(messages) + if err != nil { + return nil, err + } + + baseURL, err := a.baseURLForRegion(a.region(apiConfig)) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, strings.TrimLeft(a.URLSuffix.Chat, "/")) + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "max_tokens": 1024, + } + if systemPrompt != "" { + reqBody["system"] = systemPrompt + } + applyAnthropicChatConfig(reqBody, chatModelConfig) + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + setAnthropicHeaders(req, apiKey) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Anthropic messages API error: %s, body: %s", resp.Status, string(body)) + } + + answer, reasoning, err := parseAnthropicChatResponse(body) + if err != nil { + return nil, err + } + return &ChatResponse{ + Answer: &answer, + ReasonContent: &reasoning, + }, nil +} + +func anthropicAPIKey(apiConfig *APIConfig) (string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || strings.TrimSpace(*apiConfig.ApiKey) == "" { + return "", fmt.Errorf("api key is required") + } + return strings.TrimSpace(*apiConfig.ApiKey), nil +} + +func applyAnthropicChatConfig(reqBody map[string]interface{}, chatModelConfig *ChatConfig) { + if chatModelConfig == nil { + return + } + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop_sequences"] = *chatModelConfig.Stop + } +} + +func setAnthropicHeaders(req *http.Request, apiKey string) { + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("x-api-key", apiKey) + req.Header.Set("anthropic-version", anthropicVersion) +} + +func anthropicMessages(messages []Message) ([]map[string]interface{}, string, error) { + apiMessages := make([]map[string]interface{}, 0, len(messages)) + systemPrompts := make([]string, 0) + for _, msg := range messages { + role := strings.ToLower(strings.TrimSpace(msg.Role)) + content, err := anthropicContent(msg.Content) + if err != nil { + return nil, "", err + } + switch role { + case "system": + if text, ok := anthropicSystemText(content); ok && text != "" { + systemPrompts = append(systemPrompts, text) + } + case "user", "assistant": + apiMessages = append(apiMessages, map[string]interface{}{ + "role": role, + "content": content, + }) + default: + return nil, "", fmt.Errorf("anthropic: unsupported message role %q", msg.Role) + } + } + if len(apiMessages) == 0 { + return nil, "", fmt.Errorf("messages is empty") + } + return apiMessages, strings.Join(systemPrompts, "\n\n"), nil +} + +func anthropicSystemText(content interface{}) (string, bool) { + switch value := content.(type) { + case string: + return value, true + case []map[string]interface{}: + parts := make([]string, 0, len(value)) + for _, block := range value { + if block["type"] == "text" { + if text, ok := block["text"].(string); ok { + parts = append(parts, text) + } + } + } + return strings.Join(parts, "\n"), true + default: + return "", false + } +} + +func anthropicContent(content interface{}) (interface{}, error) { + switch value := content.(type) { + case string: + return value, nil + case []interface{}: + return anthropicContentBlocks(value) + case []map[string]interface{}: + blocks := make([]interface{}, 0, len(value)) + for _, block := range value { + blocks = append(blocks, block) + } + return anthropicContentBlocks(blocks) + default: + return nil, fmt.Errorf("anthropic: unsupported message content type %T", content) + } +} + +func anthropicContentBlocks(blocks []interface{}) ([]map[string]interface{}, error) { + apiBlocks := make([]map[string]interface{}, 0, len(blocks)) + for _, item := range blocks { + block, ok := item.(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("anthropic: invalid content block %T", item) + } + converted, err := anthropicContentBlock(block) + if err != nil { + return nil, err + } + apiBlocks = append(apiBlocks, converted) + } + return apiBlocks, nil +} + +func anthropicContentBlock(block map[string]interface{}) (map[string]interface{}, error) { + blockType, _ := block["type"].(string) + switch blockType { + case "text": + text, ok := block["text"].(string) + if !ok { + return nil, fmt.Errorf("anthropic: text block missing or invalid text field %T", block["text"]) + } + return map[string]interface{}{"type": "text", "text": text}, nil + case "image": + return validateAnthropicImageBlock(block) + case "image_url": + return anthropicImageURLBlock(block) + default: + return nil, fmt.Errorf("anthropic: unsupported content block type %q", blockType) + } +} + +func validateAnthropicImageBlock(block map[string]interface{}) (map[string]interface{}, error) { + source, ok := block["source"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("anthropic: image block missing source object") + } + sourceType, ok := source["type"].(string) + if !ok || sourceType == "" { + return nil, fmt.Errorf("anthropic: image source missing type") + } + switch sourceType { + case "url": + if url, ok := source["url"].(string); !ok || url == "" { + return nil, fmt.Errorf("anthropic: image url source missing url") + } + case "base64": + mediaType, ok := source["media_type"].(string) + if !ok || mediaType == "" { + return nil, fmt.Errorf("anthropic: image base64 source missing media_type") + } + data, ok := source["data"].(string) + if !ok || data == "" { + return nil, fmt.Errorf("anthropic: image base64 source missing data") + } + if _, err := base64.StdEncoding.DecodeString(data); err != nil { + return nil, fmt.Errorf("anthropic: invalid base64 image data: %w", err) + } + default: + return nil, fmt.Errorf("anthropic: unsupported image source type %q", sourceType) + } + return block, nil +} + +func anthropicImageURLBlock(block map[string]interface{}) (map[string]interface{}, error) { + imageURL, ok := block["image_url"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("anthropic: image_url block missing image_url object") + } + url, _ := imageURL["url"].(string) + if url == "" { + return nil, fmt.Errorf("anthropic: image_url block missing url") + } + source := map[string]interface{}{ + "type": "url", + "url": url, + } + if strings.HasPrefix(url, "data:") { + mediaType, data, err := parseDataImageURL(url) + if err != nil { + return nil, err + } + source = map[string]interface{}{ + "type": "base64", + "media_type": mediaType, + "data": data, + } + } + return map[string]interface{}{ + "type": "image", + "source": source, + }, nil +} + +func parseDataImageURL(url string) (string, string, error) { + const marker = ";base64," + if !strings.HasPrefix(url, "data:") || !strings.Contains(url, marker) { + return "", "", fmt.Errorf("anthropic: unsupported data image url") + } + trimmed := strings.TrimPrefix(url, "data:") + parts := strings.SplitN(trimmed, marker, 2) + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return "", "", fmt.Errorf("anthropic: invalid data image url") + } + if _, err := base64.StdEncoding.DecodeString(parts[1]); err != nil { + return "", "", fmt.Errorf("anthropic: invalid base64 image data: %w", err) + } + return parts[0], parts[1], nil +} + +func parseAnthropicChatResponse(body []byte) (string, string, error) { + var result struct { + Content []struct { + Type string `json:"type"` + Text string `json:"text"` + Thinking string `json:"thinking"` + } `json:"content"` + } + if err := json.Unmarshal(body, &result); err != nil { + return "", "", fmt.Errorf("failed to parse response: %w", err) + } + if len(result.Content) == 0 { + return "", "", fmt.Errorf("no content in Anthropic response") + } + + var answer strings.Builder + var reasoning strings.Builder + for _, block := range result.Content { + switch block.Type { + case "text": + answer.WriteString(block.Text) + case "thinking": + reasoning.WriteString(block.Thinking) + } + } + if answer.Len() == 0 { + return "", "", fmt.Errorf("no text content in Anthropic response") + } + return answer.String(), reasoning.String(), nil +} + +func (a *AnthropicModel) ListModels(apiConfig *APIConfig) ([]string, error) { + apiKey, err := anthropicAPIKey(apiConfig) + if err != nil { + return nil, err + } + + baseURL, err := a.baseURLForRegion(a.region(apiConfig)) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, strings.TrimLeft(a.URLSuffix.Models, "/")) + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + setAnthropicHeaders(req, apiKey) + + resp, err := a.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Anthropic models API error: %s, body: %s", resp.Status, string(body)) + } + + var result struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` + } + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + models := make([]string, 0, len(result.Data)) + for _, item := range result.Data { + if item.ID != "" { + models = append(models, item.ID) + } + } + return models, nil +} + +func (a *AnthropicModel) CheckConnection(apiConfig *APIConfig) error { + _, err := a.ListModels(apiConfig) + return err +} + +func (a *AnthropicModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", a.Name()) +} + +func (a *AnthropicModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + return nil, fmt.Errorf("%s, no such method", a.Name()) +} + +func (a *AnthropicModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s, no such method", a.Name()) +} + +func (a *AnthropicModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", a.Name()) +} + +func (a *AnthropicModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", a.Name()) +} + +func (a *AnthropicModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", a.Name()) +} + +func (a *AnthropicModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", a.Name()) +} + +func (a *AnthropicModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", a.Name()) +} + +func (a *AnthropicModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", a.Name()) +} + +func (a *AnthropicModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("%s, no such method", a.Name()) +} + +func (a *AnthropicModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", a.Name()) +} + +func (a *AnthropicModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", a.Name()) +} diff --git a/internal/entity/models/anthropic_test.go b/internal/entity/models/anthropic_test.go new file mode 100644 index 00000000000..e2ad0768d23 --- /dev/null +++ b/internal/entity/models/anthropic_test.go @@ -0,0 +1,383 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newAnthropicServer(t *testing.T, expectedPath string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != expectedPath { + t.Errorf("expected path=%s, got %s", expectedPath, r.URL.Path) + return + } + if got := r.Header.Get("x-api-key"); got != "test-key" { + t.Errorf("expected x-api-key=test-key, got %q", got) + return + } + if got := r.Header.Get("anthropic-version"); got != anthropicVersion { + t.Errorf("expected anthropic-version=%s, got %q", anthropicVersion, got) + return + } + if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") { + t.Errorf("expected Content-Type to start with application/json, got %q", got) + return + } + if r.Method == http.MethodPost { + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("unmarshal: %v\nraw=%s", err, string(raw)) + return + } + handler(t, body, w) + return + } + handler(t, nil, w) + })) +} + +func newAnthropicForTest(baseURL string) *AnthropicModel { + return NewAnthropicModel( + map[string]string{"default": baseURL}, + URLSuffix{Chat: "v1/messages", Models: "v1/models"}, + ) +} + +func TestAnthropicName(t *testing.T) { + if got := newAnthropicForTest("http://unused").Name(); got != "anthropic" { + t.Errorf("Name()=%q, want anthropic", got) + } +} + +func TestAnthropicChatHappyPath(t *testing.T) { + srv := newAnthropicServer(t, "/v1/messages", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "claude-sonnet-4-5-20250929" { + t.Errorf("model=%v", body["model"]) + } + if body["max_tokens"] != float64(1024) { + t.Errorf("max_tokens=%v want 1024", body["max_tokens"]) + } + msgs, ok := body["messages"].([]interface{}) + if !ok || len(msgs) != 1 { + t.Errorf("messages=%v, want one message", body["messages"]) + return + } + msg, ok := msgs[0].(map[string]interface{}) + if !ok || msg["role"] != "user" || msg["content"] != "ping" { + t.Errorf("message=%v, want user ping", msgs[0]) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "content": []map[string]interface{}{ + {"type": "thinking", "thinking": "reasoning"}, + {"type": "text", "text": "pong"}, + }, + }) + }) + defer srv.Close() + + apiKey := "test-key" + resp, err := newAnthropicForTest(srv.URL).ChatWithMessages( + "claude-sonnet-4-5-20250929", + []Message{{Role: "user", Content: "ping"}}, + &APIConfig{ApiKey: &apiKey}, + nil, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if resp.Answer == nil || *resp.Answer != "pong" { + t.Errorf("answer=%v, want pong", resp.Answer) + } + if resp.ReasonContent == nil || *resp.ReasonContent != "reasoning" { + t.Errorf("reason=%v, want reasoning", resp.ReasonContent) + } +} + +func TestAnthropicChatMapsSystemConfigAndImages(t *testing.T) { + srv := newAnthropicServer(t, "/v1/messages", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["system"] != "be concise" { + t.Errorf("system=%v, want be concise", body["system"]) + } + if body["max_tokens"] != float64(64) { + t.Errorf("max_tokens=%v want 64", body["max_tokens"]) + } + if body["temperature"] != 0.25 { + t.Errorf("temperature=%v want 0.25", body["temperature"]) + } + if body["top_p"] != 0.8 { + t.Errorf("top_p=%v want 0.8", body["top_p"]) + } + stop, ok := body["stop_sequences"].([]interface{}) + if !ok || len(stop) != 1 || stop[0] != "END" { + t.Errorf("stop_sequences=%v want [END]", body["stop_sequences"]) + } + msgs, ok := body["messages"].([]interface{}) + if !ok || len(msgs) == 0 { + t.Errorf("messages=%v, want non-empty array", body["messages"]) + return + } + first, ok := msgs[0].(map[string]interface{}) + if !ok { + t.Errorf("first message=%v, want object", msgs[0]) + return + } + content, ok := first["content"].([]interface{}) + if !ok || len(content) < 2 { + t.Errorf("content=%v, want at least 2 blocks", first["content"]) + return + } + image, ok := content[1].(map[string]interface{}) + if !ok { + t.Errorf("image block=%v, want object", content[1]) + return + } + source, ok := image["source"].(map[string]interface{}) + if !ok { + t.Errorf("image source=%v, want object", image["source"]) + return + } + if image["type"] != "image" || source["type"] != "url" || source["url"] != "https://example.com/cat.png" { + t.Errorf("image block=%v", image) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "content": []map[string]interface{}{{"type": "text", "text": "ok"}}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + maxTokens := 64 + temperature := 0.25 + topP := 0.8 + stop := []string{"END"} + _, err := newAnthropicForTest(srv.URL).ChatWithMessages( + "claude-opus-4-5-20251101", + []Message{ + {Role: "system", Content: "be concise"}, + {Role: "user", Content: []interface{}{ + map[string]interface{}{"type": "text", "text": "what is this?"}, + map[string]interface{}{"type": "image_url", "image_url": map[string]interface{}{"url": "https://example.com/cat.png"}}, + }}, + }, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{MaxTokens: &maxTokens, Temperature: &temperature, TopP: &topP, Stop: &stop}, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } +} + +func TestAnthropicChatMapsDataImageURL(t *testing.T) { + srv := newAnthropicServer(t, "/v1/messages", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + msgs, ok := body["messages"].([]interface{}) + if !ok || len(msgs) == 0 { + t.Errorf("messages=%v, want non-empty array", body["messages"]) + return + } + first, ok := msgs[0].(map[string]interface{}) + if !ok { + t.Errorf("first message=%v, want object", msgs[0]) + return + } + content, ok := first["content"].([]interface{}) + if !ok || len(content) == 0 { + t.Errorf("content=%v, want non-empty array", first["content"]) + return + } + image, ok := content[0].(map[string]interface{}) + if !ok { + t.Errorf("image block=%v, want object", content[0]) + return + } + source, ok := image["source"].(map[string]interface{}) + if !ok { + t.Errorf("source=%v, want object", image["source"]) + return + } + if source["type"] != "base64" || source["media_type"] != "image/png" || source["data"] != "aGVsbG8=" { + t.Errorf("source=%v", source) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "content": []map[string]interface{}{{"type": "text", "text": "ok"}}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + _, err := newAnthropicForTest(srv.URL).ChatWithMessages( + "claude-sonnet-4-5-20250929", + []Message{{Role: "user", Content: []interface{}{ + map[string]interface{}{"type": "image_url", "image_url": map[string]interface{}{"url": "data:image/png;base64,aGVsbG8="}}, + }}}, + &APIConfig{ApiKey: &apiKey}, + nil, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } +} + +func TestAnthropicChatValidationErrors(t *testing.T) { + m := newAnthropicForTest("http://unused") + apiKey := "test-key" + if _, err := m.ChatWithMessages("claude", []Message{{Role: "user", Content: "x"}}, nil, nil); err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("nil api config: got %v", err) + } + if _, err := m.ChatWithMessages("claude", nil, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "messages is empty") { + t.Errorf("empty messages: got %v", err) + } + if _, err := m.ChatWithMessages("claude", []Message{{Role: "tool", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "unsupported message role") { + t.Errorf("bad role: got %v", err) + } + if _, err := m.ChatWithMessages("claude", []Message{{Role: "user", Content: []interface{}{map[string]interface{}{"type": "video_url"}}}}, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "unsupported content block type") { + t.Errorf("bad block: got %v", err) + } + if _, err := m.ChatWithMessages("claude", []Message{{Role: "user", Content: []interface{}{map[string]interface{}{"type": "text", "text": 42}}}}, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "invalid text field") { + t.Errorf("bad text block: got %v", err) + } + if _, err := m.ChatWithMessages("claude", []Message{{Role: "user", Content: []interface{}{map[string]interface{}{"type": "image"}}}}, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "image block missing source") { + t.Errorf("bad image block: got %v", err) + } +} + +func TestAnthropicChatRejectsHTTPError(t *testing.T) { + srv := newAnthropicServer(t, "/v1/messages", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":{"message":"bad key"}}`)) + }) + defer srv.Close() + + apiKey := "test-key" + _, err := newAnthropicForTest(srv.URL).ChatWithMessages("claude", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "401") || !strings.Contains(err.Error(), "bad key") { + t.Errorf("expected provider error, got %v", err) + } +} + +func TestAnthropicChatRejectsMalformedResponse(t *testing.T) { + srv := newAnthropicServer(t, "/v1/messages", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{"content": []map[string]interface{}{{"type": "tool_use"}}}) + }) + defer srv.Close() + + apiKey := "test-key" + _, err := newAnthropicForTest(srv.URL).ChatWithMessages("claude", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "no text content") { + t.Errorf("expected no-text error, got %v", err) + } +} + +func TestAnthropicListModelsAndCheckConnection(t *testing.T) { + var calls int + srv := newAnthropicServer(t, "/v1/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + calls++ + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"id": "claude-sonnet-4-5-20250929"}, + {"id": "claude-haiku-4-5-20251001"}, + }, + }) + }) + defer srv.Close() + + apiKey := "test-key" + m := newAnthropicForTest(srv.URL) + models, err := m.ListModels(&APIConfig{ApiKey: &apiKey}) + if err != nil { + t.Fatalf("ListModels: %v", err) + } + if strings.Join(models, ",") != "claude-sonnet-4-5-20250929,claude-haiku-4-5-20251001" { + t.Errorf("models=%v", models) + } + if err := m.CheckConnection(&APIConfig{ApiKey: &apiKey}); err != nil { + t.Errorf("CheckConnection: %v", err) + } + if calls != 2 { + t.Errorf("calls=%d, want 2", calls) + } +} + +func TestAnthropicListModelsRejectsProviderError(t *testing.T) { + srv := newAnthropicServer(t, "/v1/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusForbidden) + _, _ = w.Write([]byte(`{"error":"forbidden"}`)) + }) + defer srv.Close() + + apiKey := "test-key" + _, err := newAnthropicForTest(srv.URL).ListModels(&APIConfig{ApiKey: &apiKey}) + if err == nil || !strings.Contains(err.Error(), "403") { + t.Errorf("expected 403 error, got %v", err) + } +} + +func TestAnthropicFactoryRegistration(t *testing.T) { + driver, err := NewModelFactory().CreateModelDriver("Anthropic", map[string]string{"default": "http://unused"}, URLSuffix{}) + if err != nil { + t.Fatalf("CreateModelDriver: %v", err) + } + if _, ok := driver.(*AnthropicModel); !ok { + t.Fatalf("driver type=%T, want *AnthropicModel", driver) + } +} + +func TestAnthropicUnsupportedMethods(t *testing.T) { + m := newAnthropicForTest("http://unused") + apiKey := "test-key" + modelName := "claude" + checks := []struct { + name string + err error + }{ + {"stream", m.ChatStreamlyWithSender(modelName, []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil, func(*string, *string) error { return nil })}, + } + for _, check := range checks { + if check.err == nil || !strings.Contains(check.err.Error(), "no such method") { + t.Errorf("%s: want no such method, got %v", check.name, check.err) + } + } + if _, err := m.Embed(&modelName, []string{"x"}, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Embed: got %v", err) + } + if _, err := m.Rerank(&modelName, "q", []string{"d"}, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Rerank: got %v", err) + } + if _, err := m.Balance(&APIConfig{ApiKey: &apiKey}); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Balance: got %v", err) + } + if _, err := m.TranscribeAudio(&modelName, &modelName, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("TranscribeAudio: got %v", err) + } + if err := m.TranscribeAudioWithSender(&modelName, &modelName, &APIConfig{ApiKey: &apiKey}, nil, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("TranscribeAudioWithSender: got %v", err) + } + if _, err := m.AudioSpeech(&modelName, &modelName, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("AudioSpeech: got %v", err) + } + if err := m.AudioSpeechWithSender(&modelName, &modelName, &APIConfig{ApiKey: &apiKey}, nil, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("AudioSpeechWithSender: got %v", err) + } + if _, err := m.OCRFile(&modelName, nil, &modelName, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("OCRFile: got %v", err) + } + if _, err := m.ParseFile(&modelName, nil, &modelName, &APIConfig{ApiKey: &apiKey}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("ParseFile: got %v", err) + } + if _, err := m.ListTasks(&APIConfig{ApiKey: &apiKey}); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("ListTasks: got %v", err) + } + if _, err := m.ShowTask("task-id", &APIConfig{ApiKey: &apiKey}); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("ShowTask: got %v", err) + } +} diff --git a/internal/entity/models/baichuan.go b/internal/entity/models/baichuan.go new file mode 100644 index 00000000000..d139b8a672a --- /dev/null +++ b/internal/entity/models/baichuan.go @@ -0,0 +1,429 @@ +package models + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "ragflow/internal/common" + "strings" + "time" +) + +// sk-6e16f0a6bfaa7fc58e30a50962665d1d +type BaichuanModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func NewBaichuanModel(baseURL map[string]string, urlSuffix URLSuffix) *BaichuanModel { + return &BaichuanModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: 120 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DisableCompression: false, + }, + }, + } +} + +func (b *BaichuanModel) NewInstance(baseURL map[string]string) ModelDriver { + return &BaichuanModel{ + BaseURL: baseURL, + URLSuffix: b.URLSuffix, + httpClient: &http.Client{ + Timeout: 120 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DisableCompression: false, + }, + }, + } +} + +func (b *BaichuanModel) Name() string { + return "baichuan" +} + +func (b *BaichuanModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is nil or empty") + } + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", b.BaseURL[region], b.URLSuffix.Chat) + + // Convert messages to API format + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + // Build request body + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + "temperature": 1, + } + + if chatModelConfig != nil { + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + + if chatModelConfig.Stream != nil { + reqBody["stream"] = *chatModelConfig.Stream + } + + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Add("Content-Type", "application/json") + req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := b.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to send request: %d %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("no choices in response") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("no message in response") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("no message in response") + } + + // baichuan not support think + emptyReason := "" + chatResponse := &ChatResponse{ + Answer: &content, + ReasonContent: &emptyReason, + } + + return chatResponse, nil +} + +func (b *BaichuanModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", b.BaseURL[region], b.URLSuffix.Chat) + + // Convert messages to API format + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + "temperature": 1, + } + + if modelConfig != nil { + if modelConfig.Stream != nil { + reqBody["stream"] = *modelConfig.Stream + } + + if modelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *modelConfig.MaxTokens + } + + if modelConfig.Temperature != nil { + reqBody["temperature"] = *modelConfig.Temperature + } + + if modelConfig.TopP != nil { + reqBody["top_p"] = *modelConfig.TopP + } + + if modelConfig.Stop != nil { + reqBody["stop"] = *modelConfig.Stop + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := b.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("invalid status code: %d, body: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: read line by line + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + common.Info(line) + + // SSE data line starts with "data:" + if !strings.HasPrefix(line, "data:") { + continue + } + + // Extract JSON after "data:" + data := strings.TrimSpace(line[5:]) + + // [DONE] marks the end of stream + if data == "[DONE]" { + break + } + + // Parse the JSON event + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + break + } + } + + // Send [DONE] marker for OpenAI compatibility + endOfStream := "[DONE]" + if err = sender(&endOfStream, nil); err != nil { + return err + } + + return scanner.Err() +} + +func (b *BaichuanModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", b.BaseURL[region], b.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", strings.TrimSpace(*apiConfig.ApiKey))) + + resp, err := b.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Baichuan embedding API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var parsedResponse struct { + Data []struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` + } + + if err = json.Unmarshal(body, &parsedResponse); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + if len(parsedResponse.Data) == 0 { + return nil, fmt.Errorf("Baichuan embedding response contains no data: %s", string(body)) + } + + var embeddings []EmbeddingData + for _, dataElem := range parsedResponse.Data { + embeddings = append(embeddings, EmbeddingData{ + Embedding: dataElem.Embedding, + Index: dataElem.Index, + }) + } + + return embeddings, nil +} + +func (b *BaichuanModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("no such method") +} + +// TranscribeAudio transcribe audio +func (z *BaichuanModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *BaichuanModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (z *BaichuanModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *BaichuanModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (z *BaichuanModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +// ParseFile parse file +func (z *BaichuanModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (b *BaichuanModel) ListModels(apiConfig *APIConfig) ([]string, error) { + return nil, fmt.Errorf("no such method") +} + +func (b *BaichuanModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("no such method") +} + +func (b *BaichuanModel) CheckConnection(apiConfig *APIConfig) error { + return fmt.Errorf("no such method") +} + +func (z *BaichuanModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *BaichuanModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/baidu.go b/internal/entity/models/baidu.go new file mode 100644 index 00000000000..b2bcf84393c --- /dev/null +++ b/internal/entity/models/baidu.go @@ -0,0 +1,795 @@ +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "ragflow/internal/common" + "strings" + "time" +) + +type BaiduModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func (b *BaiduModel) NewInstance(baseURL map[string]string) ModelDriver { + return &BaiduModel{ + BaseURL: baseURL, + URLSuffix: b.URLSuffix, + httpClient: &http.Client{ + Timeout: 120 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 10, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: 90 * time.Second, + DisableCompression: false, + }, + }, + } +} + +func NewBaiduModel(baseURL map[string]string, urlSuffix URLSuffix) *BaiduModel { + return &BaiduModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: 120 * time.Second, + Transport: &http.Transport{ + MaxConnsPerHost: 10, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: 90 * time.Second, + DisableCompression: false, + }, + }, + } +} + +func (b *BaiduModel) Name() string { + return "baidu" +} + +func (b *BaiduModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is nil or empty") + } + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", b.BaseURL[region], b.URLSuffix.Chat) + + // Convert messages to API format + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + // Build request body + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + "temperature": 1, + } + + if chatModelConfig != nil { + if chatModelConfig.Stream != nil { + reqBody["stream"] = *chatModelConfig.Stream + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + + if chatModelConfig.Thinking != nil { + lowerModelName := strings.ToLower(modelName) + + // `enable_think` for qwen and erine + if strings.HasPrefix(lowerModelName, "qwen") || strings.HasPrefix(lowerModelName, "ernie") { + reqBody["enable_thinking"] = *chatModelConfig.Thinking + } else { + if *chatModelConfig.Thinking { + thinkingFlag := "enabled" + + if strings.Contains(lowerModelName, "deepseek-v4") { + effort := "high" + if chatModelConfig.Effort != nil { + effort = *chatModelConfig.Effort + } + switch effort { + case "none", "low", "medium": + thinkingFlag = "disabled" + case "high", "default": + thinkingFlag = "enabled" + reqBody["reasoning_effort"] = "high" + case "max": + thinkingFlag = "enabled" + reqBody["reasoning_effort"] = "max" + default: + thinkingFlag = "enabled" + reqBody["reasoning_effort"] = effort + } + } + + reqBody["thinking"] = map[string]interface{}{ + "type": thinkingFlag, + } + } else { + reqBody["thinking"] = map[string]interface{}{ + "type": "disabled", + } + } + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := b.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + var reasonContent string + if chatModelConfig != nil && chatModelConfig.Thinking != nil && *chatModelConfig.Thinking { + reasonContent, ok = messageMap["reasoning_content"].(string) + if !ok { + return nil, fmt.Errorf("invalid reasoning content format") + } + // if first char of reasonContent is \n remove the '\n' + if reasonContent != "" && reasonContent[0] == '\n' { + reasonContent = reasonContent[1:] + } + } + + chatResponse := &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + } + + return chatResponse, nil +} + +func (b *BaiduModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(b.BaseURL[region], "/"), b.URLSuffix.Chat) + + // Convert messages to API format + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + // Build request body with streaming enabled + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + "temperature": 1, + } + + if modelConfig != nil { + if modelConfig.Stream != nil { + reqBody["stream"] = *modelConfig.Stream + } + + if modelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *modelConfig.MaxTokens + } + + if modelConfig.Temperature != nil { + reqBody["temperature"] = *modelConfig.Temperature + } + + if modelConfig.DoSample != nil { + reqBody["do_sample"] = *modelConfig.DoSample + } + + if modelConfig.TopP != nil { + reqBody["top_p"] = *modelConfig.TopP + } + + if modelConfig.Stop != nil { + reqBody["stop"] = *modelConfig.Stop + } + + if modelConfig.Thinking != nil { + lowerModelName := strings.ToLower(modelName) + + // `enable_think` for qwen and erine + if strings.HasPrefix(lowerModelName, "qwen") || strings.HasPrefix(lowerModelName, "ernie") { + reqBody["enable_thinking"] = *modelConfig.Thinking + } else { + if *modelConfig.Thinking { + thinkingFlag := "enabled" + + if strings.Contains(lowerModelName, "deepseek-v4") { + effort := "high" + if modelConfig.Effort != nil { + effort = *modelConfig.Effort + } + switch effort { + case "none", "low", "medium": + thinkingFlag = "disabled" + case "high", "default": + thinkingFlag = "enabled" + reqBody["reasoning_effort"] = "high" + case "max": + thinkingFlag = "enabled" + reqBody["reasoning_effort"] = "max" + default: + thinkingFlag = "enabled" + reqBody["reasoning_effort"] = effort + } + } + + reqBody["thinking"] = map[string]interface{}{ + "type": thinkingFlag, + } + } else { + reqBody["thinking"] = map[string]interface{}{ + "type": "disabled", + } + } + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := b.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: read line by line + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + common.Info(line) + + // SSE data line starts with "data:" + if !strings.HasPrefix(line, "data:") { + continue + } + + // Extract JSON after "data:" + data := strings.TrimSpace(line[5:]) + + // [DONE] marks the end of stream + if data == "[DONE]" { + break + } + + // Parse the JSON event + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + reasoningContent, ok := delta["reasoning_content"].(string) + if ok && reasoningContent != "" { + if err = sender(nil, &reasoningContent); err != nil { + return err + } + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err = sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + break + } + } + + // Send [DONE] marker for OpenAI compatibility + endOfStream := "[DONE]" + if err = sender(&endOfStream, nil); err != nil { + return err + } + + return scanner.Err() +} + +type baiduEmbeddingResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Data []baiduEmbeddingData `json:"data"` + Model string `json:"model"` + Usage baiduUsage `json:"usage"` +} + +type baiduEmbeddingData struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` + Index *int `json:"index"` +} + +type baiduUsage struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` +} + +func (b *BaiduModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", b.BaseURL[region], b.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := b.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Baidu embedding API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var parsed baiduEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if len(parsed.Data) != len(texts) { + return nil, fmt.Errorf("expected %d embeddings, got %d", len(texts), len(parsed.Data)) + } + + embeddings := make([]EmbeddingData, len(texts)) + seen := make([]bool, len(texts)) + for _, item := range parsed.Data { + if item.Index == nil { + return nil, fmt.Errorf("missing index field in embedding item") + } + idx := *item.Index + if idx < 0 || idx >= len(texts) { + return nil, fmt.Errorf("embedding index %d out of range", idx) + } + if seen[idx] { + return nil, fmt.Errorf("duplicate embedding index %d", idx) + } + if len(item.Embedding) == 0 { + return nil, fmt.Errorf("empty embedding at index %d", idx) + } + seen[idx] = true + embeddings[idx] = EmbeddingData{ + Embedding: item.Embedding, + Index: idx, + } + } + + for i, ok := range seen { + if !ok { + return nil, fmt.Errorf("missing embedding index %d", i) + } + } + + return embeddings, nil +} + +func (b *BaiduModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + if len(documents) == 0 { + return &RerankResponse{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(b.BaseURL[region], "/"), b.URLSuffix.Rerank) + + var topN = rerankConfig.TopN + if rerankConfig.TopN == 0 { + topN = len(documents) + } + + reqBody := map[string]interface{}{ + "model": *modelName, + "query": query, + "documents": documents, + "top_n": topN, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := b.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Baidu rerank API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var rerankResp struct { + Results []struct { + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` + } `json:"results"` + } + + if err = json.Unmarshal(body, &rerankResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + var rerankResponse RerankResponse + for _, result := range rerankResp.Results { + rerankResult := RerankResult{ + Index: result.Index, + RelevanceScore: result.RelevanceScore, + } + rerankResponse.Data = append(rerankResponse.Data, rerankResult) + } + + return &rerankResponse, nil +} + +// TranscribeAudio transcribe audio +func (b *BaiduModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", b.Name()) +} + +func (z *BaiduModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (b *BaiduModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", b.Name()) +} + +func (z *BaiduModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +type qianfanOCRResponse struct { + Id string `json:"id"` + Result struct { + LayoutParsingResults []struct { + Markdown struct { + Text string `json:"text"` + } `json:"markdown"` + } `json:"layoutParsingResults"` + } `json:"result"` +} + +func (b *BaiduModel) OCRFile(modelName *string, content []byte, fileURL *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + if (fileURL == nil || *fileURL == "") && (content == nil || len(content) == 0) { + return nil, fmt.Errorf("image url or content is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", b.BaseURL[region], b.URLSuffix.OCR) + + reqData := map[string]interface{}{ + "model": *modelName, + } + + if fileURL != nil && *fileURL != "" { + reqData["file"] = *fileURL + if strings.HasSuffix(strings.ToLower(*fileURL), ".pdf") { + reqData["fileType"] = 0 // PDF + } else { + reqData["fileType"] = 1 // img + } + } else if content != nil && len(content) > 0 { + reqData["file"] = base64.StdEncoding.EncodeToString(content) + + mimeType := http.DetectContentType(content) + if strings.Contains(mimeType, "pdf") { + reqData["fileType"] = 0 // PDF + } else { + reqData["fileType"] = 1 // img + } + } + + jsonData, err := json.Marshal(reqData) + if err != nil { + return nil, fmt.Errorf("failed to marshal json payload: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := b.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API error: %s, body: %s", resp.Status, string(body)) + } + + var apiResponse qianfanOCRResponse + if err = json.Unmarshal(body, &apiResponse); err != nil { + return nil, fmt.Errorf("failed to parse response json: %w", err) + } + + var extractedText string + if len(apiResponse.Result.LayoutParsingResults) > 0 { + extractedText = apiResponse.Result.LayoutParsingResults[0].Markdown.Text + } else { + return nil, fmt.Errorf("no parsing results returned from API") + } + + var ocrResponse = OCRFileResponse{ + Text: &extractedText, + } + + return &ocrResponse, nil +} + +func (b *BaiduModel) ListModels(apiConfig *APIConfig) ([]string, error) { + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", b.BaseURL[region], b.URLSuffix.Models) + + reqBody := map[string]string{} + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := b.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // convert result["data"] to []map[string]interface{} + models := make([]string, 0) + for _, model := range result["data"].([]interface{}) { + modelMap := model.(map[string]interface{}) + modelName := modelMap["id"].(string) + models = append(models, modelName) + } + + return models, nil +} + +func (b *BaiduModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf(b.Name() + "no such method") +} + +func (b *BaiduModel) CheckConnection(apiConfig *APIConfig) error { + _, err := b.ListModels(apiConfig) + return err +} + +func (z *BaiduModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("no such method", z.Name()) +} + +func (z *BaiduModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("no such method", z.Name()) +} + +func (z *BaiduModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("no such method", z.Name()) +} diff --git a/internal/entity/models/cohere.go b/internal/entity/models/cohere.go new file mode 100644 index 00000000000..1d16ece24dd --- /dev/null +++ b/internal/entity/models/cohere.go @@ -0,0 +1,687 @@ +package models + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "time" +) + +type CoHereModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func (c *CoHereModel) NewInstance(baseURL map[string]string) ModelDriver { + return &CoHereModel{ + BaseURL: baseURL, + URLSuffix: c.URLSuffix, + httpClient: &http.Client{ + Timeout: 120 * time.Second, + }, + } +} + +func NewCoHereModel(baseURL map[string]string, urlSuffix URLSuffix) *CoHereModel { + return &CoHereModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: 120 * time.Second, + }, + } +} + +func (c *CoHereModel) Name() string { + return "cohere" +} + +func (c *CoHereModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is nil or empty") + } + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", c.BaseURL[region], c.URLSuffix.Chat) + + // Convert messages to API format + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + // Build request body + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + "temperature": 0.3, + } + + if chatModelConfig != nil { + if chatModelConfig.Stream != nil { + reqBody["stream"] = *chatModelConfig.Stream + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + + if chatModelConfig.Thinking != nil { + if *chatModelConfig.Thinking { + reqBody["thinking"] = map[string]interface{}{ + "type": "enabled", + } + } else { + reqBody["thinking"] = map[string]interface{}{ + "type": "disabled", + } + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("content-Type", "application/json") + req.Header.Set("accept", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("bearer %s", strings.TrimSpace(*apiConfig.ApiKey))) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Cohere chat API error: %d %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + messageMap, ok := result["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("no message found in Cohere response: %s", string(body)) + } + + contentArray, ok := messageMap["content"].([]interface{}) + if !ok { + return nil, fmt.Errorf("content is not an array in Cohere response") + } + + var fullContent string + var reasonContent string + for _, cBlock := range contentArray { + cmap, ok := cBlock.(map[string]interface{}) + if !ok { + continue + } + if blockType, ok := cmap["type"].(string); ok && blockType == "thinking" { + if thinkingText, ok := cmap["thinking"].(string); ok { + reasonContent += thinkingText + } + } else if text, ok := cmap["text"].(string); ok { + fullContent += text + } + } + + chatResponse := &ChatResponse{ + Answer: &fullContent, + ReasonContent: &reasonContent, + } + + return chatResponse, nil +} + +func (c *CoHereModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", c.BaseURL[region], c.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + "temperature": 1, + } + + if modelConfig != nil { + if modelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *modelConfig.MaxTokens + } + if modelConfig.Temperature != nil { + reqBody["temperature"] = *modelConfig.Temperature + } + if modelConfig.TopP != nil { + reqBody["p"] = *modelConfig.TopP + } + } + + if modelConfig != nil { + if modelConfig.Stream != nil { + reqBody["stream"] = *modelConfig.Stream + } + + if modelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *modelConfig.MaxTokens + } + + if modelConfig.Temperature != nil { + reqBody["temperature"] = *modelConfig.Temperature + } + + if modelConfig.TopP != nil { + reqBody["top_p"] = *modelConfig.TopP + } + + if modelConfig.Thinking != nil { + if *modelConfig.Thinking { + reqBody["thinking"] = map[string]interface{}{ + "type": "enabled", + } + } else { + reqBody["thinking"] = map[string]interface{}{ + "type": "disabled", + } + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("content-type", "application/json") + req.Header.Set("accept", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", strings.TrimSpace(*apiConfig.ApiKey))) + + resp, err := c.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("Cohere stream API error %d: %s", resp.StatusCode, string(body)) + } + + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + data := strings.TrimSpace(line) + + if strings.HasPrefix(data, "data:") { + data = strings.TrimSpace(data[5:]) + } + + if data == "" || data == "[DONE]" { + continue + } + + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + eventType, ok := event["type"].(string) + if !ok { + continue + } + + if eventType == "message-end" { + break + } + + if eventType == "content-delta" { + delta, ok := event["delta"].(map[string]interface{}) + if !ok { + continue + } + msg, ok := delta["message"].(map[string]interface{}) + if !ok { + continue + } + content, ok := msg["content"].(map[string]interface{}) + if !ok { + continue + } + + if thinking, ok := content["thinking"].(string); ok && thinking != "" { + if err := sender(nil, &thinking); err != nil { + return err + } + } + + if text, ok := content["text"].(string); ok && text != "" { + if err := sender(&text, nil); err != nil { + return err + } + } + } + } + + endOfStream := "[DONE]" + if err = sender(&endOfStream, nil); err != nil { + return err + } + + return scanner.Err() +} + +func (c *CoHereModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := strings.TrimSuffix(c.BaseURL[region], "/") + suffix := strings.TrimPrefix(c.URLSuffix.Embedding, "/") + url := fmt.Sprintf("%s/%s", baseURL, suffix) + + reqBody := map[string]interface{}{ + "model": *modelName, + "texts": texts, + "input_type": "search_document", + "embedding_types": []string{"float"}, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", strings.TrimSpace(*apiConfig.ApiKey))) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Cohere embedding API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var result struct { + Embeddings struct { + Float [][]float64 `json:"float"` + } `json:"embeddings"` + } + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + if len(result.Embeddings.Float) == 0 { + return nil, fmt.Errorf("Cohere embedding response contains no float data: %s", string(body)) + } + + var embeddings []EmbeddingData + for i, floatArr := range result.Embeddings.Float { + embeddings = append(embeddings, EmbeddingData{ + Embedding: floatArr, + Index: i, + }) + } + + return embeddings, nil +} + +func (c *CoHereModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + if len(documents) == 0 { + return &RerankResponse{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := strings.TrimSuffix(c.BaseURL[region], "/") + suffix := strings.TrimPrefix(c.URLSuffix.Rerank, "/") + url := fmt.Sprintf("%s/%s", baseURL, suffix) + + var topN = rerankConfig.TopN + if rerankConfig.TopN == 0 { + topN = len(documents) + } + + reqBody := map[string]interface{}{ + "model": *modelName, + "query": query, + "documents": documents, + "top_n": topN, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", strings.TrimSpace(*apiConfig.ApiKey))) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Cohere rerank API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var rerankResp struct { + Results []struct { + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` + } `json:"results"` + } + + if err := json.Unmarshal(body, &rerankResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + var rerankResponse RerankResponse + for _, result := range rerankResp.Results { + rerankResult := RerankResult{ + Index: result.Index, + RelevanceScore: result.RelevanceScore, + } + rerankResponse.Data = append(rerankResponse.Data, rerankResult) + } + + return &rerankResponse, nil +} + +// TranscribeAudio transcribe audio +func (c *CoHereModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + if file == nil || *file == "" { + return nil, fmt.Errorf("file is missing") + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", c.BaseURL[region], c.URLSuffix.ASR) + + // multipart body + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + // open audio file + audioFile, err := os.Open(*file) + if err != nil { + return nil, fmt.Errorf("failed to open audio file: %w", err) + } + defer audioFile.Close() + + // create multipart file field + + if err = writer.WriteField("model", *modelName); err != nil { + return nil, fmt.Errorf("failed to write model name: %w", err) + } + // extra params + if asrConfig != nil && asrConfig.Params != nil { + for key, value := range asrConfig.Params { + + var val string + + switch v := value.(type) { + case string: + val = v + case bool: + val = strconv.FormatBool(v) + case int: + val = strconv.Itoa(v) + case int64: + val = strconv.FormatInt(v, 10) + case float32: + val = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + val = strconv.FormatFloat(v, 'f', -1, 64) + default: + val = fmt.Sprintf("%v", v) + } + + if err = writer.WriteField(key, val); err != nil { + return nil, fmt.Errorf("failed to write field %s: %w", key, err) + } + } + } + + // all form fields (model, language) must appear before the file part in the multipart body + part, err := writer.CreateFormFile("file", filepath.Base(*file)) + if err != nil { + return nil, fmt.Errorf("failed to create form file: %w", err) + } + + if _, err := io.Copy(part, audioFile); err != nil { + return nil, fmt.Errorf("failed to copy audio file: %w", err) + } + + if err = writer.Close(); err != nil { + return nil, fmt.Errorf("failed to close writer: %w", err) + } + + // build request + req, err := http.NewRequest("POST", url, &body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Content-Type", writer.FormDataContentType()) + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Cohere ASR API error: status %d, body: %s", resp.StatusCode, string(respBody)) + } + + var result struct { + Text string `json:"text"` + } + + if err = json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &ASRResponse{Text: result.Text}, nil +} + +func (z *CoHereModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (c *CoHereModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", c.Name()) +} + +func (z *CoHereModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (c *CoHereModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", c.Name()) +} + +// ParseFile parse file +func (z *CoHereModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (c *CoHereModel) ListModels(apiConfig *APIConfig) ([]string, error) { + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", c.BaseURL[region], c.URLSuffix.Models) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("accept", "application/json") + if apiConfig != nil && apiConfig.ApiKey != nil { + req.Header.Set("Authorization", fmt.Sprintf("bearer %s", strings.TrimSpace(*apiConfig.ApiKey))) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Cohere API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + models := make([]string, 0) + if modelsRaw, ok := result["models"].([]interface{}); ok { + for _, model := range modelsRaw { + if modelMap, ok := model.(map[string]interface{}); ok { + if modelName, ok := modelMap["name"].(string); ok { + models = append(models, modelName) + } + } + } + } else { + return nil, fmt.Errorf("failed to find 'models' array in response") + } + + return models, nil +} + +func (c *CoHereModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf(c.Name() + " no such method") +} + +func (c *CoHereModel) CheckConnection(apiConfig *APIConfig) error { + _, err := c.ListModels(apiConfig) + return err +} + +func (z *CoHereModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *CoHereModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/cometapi.go b/internal/entity/models/cometapi.go new file mode 100644 index 00000000000..118a0748af0 --- /dev/null +++ b/internal/entity/models/cometapi.go @@ -0,0 +1,818 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "net/url" + "os" + "path/filepath" + "strconv" + "strings" + "time" +) + +// CometAPIModel implements ModelDriver for CometAPI AI. +// +// CometAPI exposes OpenAI-compatible chat and embeddings under +// https://api.cometapi.com/v1, a public model catalog under +// https://api.cometapi.com/api/models, and account quota data through the +// separate query service at https://query.cometapi.com/user/quota. +type CometAPIModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +// NewCometAPIModel creates a new CometAPI model instance. +// +// We clone http.DefaultTransport so we keep Go's defaults for +// ProxyFromEnvironment, DialContext (with KeepAlive), HTTP/2, +// TLSHandshakeTimeout, and ExpectContinueTimeout, and only override +// the connection-pool fields we care about. +// +// The Client itself has no Timeout. http.Client.Timeout would also +// cap the time spent reading the response body, which would cut off +// long-lived SSE streams in ChatStreamlyWithSender. Non-streaming +// callers wrap each request with context.WithTimeout instead. +func NewCometAPIModel(baseURL map[string]string, urlSuffix URLSuffix) *CometAPIModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &CometAPIModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (m *CometAPIModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewCometAPIModel(baseURL, m.URLSuffix) +} + +func (m *CometAPIModel) Name() string { + return "cometapi" +} + +func validateCometAPIAPIKey(apiConfig *APIConfig) (string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return "", fmt.Errorf("api key is required") + } + return *apiConfig.ApiKey, nil +} + +func validateCometAPIModelName(modelName string) error { + if strings.TrimSpace(modelName) == "" { + return fmt.Errorf("model name is required") + } + return nil +} + +func cometapiRegion(apiConfig *APIConfig) string { + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + return *apiConfig.Region + } + return "default" +} + +// baseURLForRegion returns the base URL for the given region, or an +// error if no entry exists. This makes a misconfigured region fail +// fast with a clear message, instead of silently producing a relative +// URL that the HTTP transport then rejects. +func (m *CometAPIModel) baseURLForRegion(region string) (string, error) { + base, ok := m.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("cometapi: no base URL configured for region %q", region) + } + return strings.TrimRight(base, "/"), nil +} + +func (m *CometAPIModel) endpointURL(region, suffix string) (string, error) { + baseURL, err := m.baseURLForRegion(region) + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s", baseURL, strings.TrimLeft(suffix, "/")), nil +} + +func (m *CometAPIModel) balanceURL(apiKey string) string { + rawURL := strings.TrimSpace(m.URLSuffix.Balance) + if !strings.HasPrefix(rawURL, "http://") && !strings.HasPrefix(rawURL, "https://") { + rawURL = fmt.Sprintf("https://query.cometapi.com/%s", strings.TrimLeft(rawURL, "/")) + } + parsed, err := url.Parse(rawURL) + if err != nil { + return rawURL + } + query := parsed.Query() + query.Set("key", apiKey) + parsed.RawQuery = query.Encode() + return parsed.String() +} + +type cometapiChatRequest struct { + Model string `json:"model"` + Messages []cometapiAPIMessage `json:"messages"` + Stream bool `json:"stream"` + MaxTokens *int `json:"max_tokens,omitempty"` + Temperature *float64 `json:"temperature,omitempty"` + TopP *float64 `json:"top_p,omitempty"` + Stop *[]string `json:"stop,omitempty"` +} + +type cometapiAPIMessage struct { + Role string `json:"role"` + Content interface{} `json:"content"` +} + +func buildCometAPIChatRequest(modelName string, messages []Message, stream bool, chatModelConfig *ChatConfig) cometapiChatRequest { + apiMessages := make([]cometapiAPIMessage, len(messages)) + for i, msg := range messages { + apiMessages[i] = cometapiAPIMessage{ + Role: msg.Role, + Content: msg.Content, + } + } + + reqBody := cometapiChatRequest{ + Model: modelName, + Messages: apiMessages, + Stream: stream, + } + if chatModelConfig != nil { + reqBody.MaxTokens = chatModelConfig.MaxTokens + reqBody.Temperature = chatModelConfig.Temperature + reqBody.TopP = chatModelConfig.TopP + reqBody.Stop = chatModelConfig.Stop + } + return reqBody +} + +func newCometAPIJSONRequest(ctx context.Context, method string, endpoint string, payload interface{}, apiKey string) (*http.Request, error) { + jsonData, err := json.Marshal(payload) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, method, endpoint, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + if apiKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + } + return req, nil +} + +type cometapiHTTPResponse struct { + StatusCode int + Status string + Body []byte +} + +func (m *CometAPIModel) doCometAPIRequest(req *http.Request) (*cometapiHTTPResponse, error) { + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + return &cometapiHTTPResponse{ + StatusCode: resp.StatusCode, + Status: resp.Status, + Body: body, + }, nil +} + +type cometapiChatResponsePayload struct { + Choices []cometapiChatChoice `json:"choices"` +} + +type cometapiChatChoice struct { + Message cometapiChatMessage `json:"message"` + Delta cometapiChatDelta `json:"delta"` + FinishReason string `json:"finish_reason"` +} + +type cometapiChatMessage struct { + Content *string `json:"content"` + ReasoningContent string `json:"reasoning_content"` +} + +type cometapiChatDelta struct { + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content"` +} + +func parseCometAPIChatResponse(body []byte) (*ChatResponse, error) { + var parsed cometapiChatResponsePayload + if err := json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + if len(parsed.Choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + if parsed.Choices[0].Message.Content == nil { + return nil, fmt.Errorf("invalid content format") + } + + content := *parsed.Choices[0].Message.Content + reasonContent := strings.TrimLeft(parsed.Choices[0].Message.ReasoningContent, "\n") + return &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + }, nil +} + +func parseCometAPIStreamEvent(data string) (content string, reasonContent string, terminal bool, ok bool) { + var event cometapiChatResponsePayload + if err := json.Unmarshal([]byte(data), &event); err != nil { + return "", "", false, false + } + if len(event.Choices) == 0 { + return "", "", false, false + } + choice := event.Choices[0] + return choice.Delta.Content, choice.Delta.ReasoningContent, choice.FinishReason != "", true +} + +type cometapiModelCatalogResponse struct { + Data []cometapiModelCatalogItem `json:"data"` +} + +type cometapiModelCatalogItem struct { + ID string `json:"id"` +} + +func parseCometAPIModelCatalog(body []byte) ([]string, error) { + var parsed cometapiModelCatalogResponse + if err := json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + models := make([]string, 0, len(parsed.Data)) + for _, model := range parsed.Data { + if model.ID != "" { + models = append(models, model.ID) + } + } + return models, nil +} + +// ChatWithMessages sends multiple messages with roles and returns the response. +func (m *CometAPIModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + apiKey, err := validateCometAPIAPIKey(apiConfig) + if err != nil { + return nil, err + } + if err := validateCometAPIModelName(modelName); err != nil { + return nil, err + } + + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + url, err := m.endpointURL(cometapiRegion(apiConfig), m.URLSuffix.Chat) + if err != nil { + return nil, err + } + + // Note: do NOT propagate chatModelConfig.Stream into the request body + // here. ChatWithMessages parses a single JSON response, so stream must + // always be off for this code path. + reqBody := buildCometAPIChatRequest(modelName, messages, false, chatModelConfig) + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := newCometAPIJSONRequest(ctx, "POST", url, reqBody, apiKey) + if err != nil { + return nil, err + } + resp, err := m.doCometAPIRequest(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(resp.Body)) + } + return parseCometAPIChatResponse(resp.Body) +} + +// ChatStreamlyWithSender sends messages and streams the response via the +// sender function. The CometAPI SSE stream uses the same shape as OpenAI: +// "data:" lines carrying JSON events, with a final "[DONE]" line. +func (m *CometAPIModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + + if err := validateCometAPIModelName(modelName); err != nil { + return err + } + + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + apiKey, err := validateCometAPIAPIKey(apiConfig) + if err != nil { + return err + } + + url, err := m.endpointURL(cometapiRegion(apiConfig), m.URLSuffix.Chat) + if err != nil { + return err + } + + if chatModelConfig != nil { + // Refuse to run if the caller explicitly asked for stream=false. + // The body of this method only knows how to read SSE, so a + // non-SSE JSON response would be parsed as if it were a stream + // and produce no chunks. Better to fail clearly. + if chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + } + reqBody := buildCometAPIChatRequest(modelName, messages, true, chatModelConfig) + + // Use an explicit background context. SSE streams are long-lived + // so we do not attach a hard deadline here; the transport's + // ResponseHeaderTimeout caps the connection-establishment phase. + req, err := newCometAPIJSONRequest(context.Background(), "POST", url, reqBody, apiKey) + if err != nil { + return err + } + resp, err := m.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: bump the scanner buffer from the 64KB default to 1MB + // so we never silently truncate a long data: line. + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + sawTerminal := false + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "data:") { + continue + } + + data := strings.TrimSpace(line[5:]) + + if data == "[DONE]" { + sawTerminal = true + break + } + + content, reasoningContent, terminal, ok := parseCometAPIStreamEvent(data) + if !ok { + continue + } + + if reasoningContent != "" { + if err := sender(nil, &reasoningContent); err != nil { + return err + } + } + + if content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + if terminal { + sawTerminal = true + break + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawTerminal { + return fmt.Errorf("cometapi: stream ended before [DONE] or finish_reason") + } + + endOfStream := "[DONE]" + if err := sender(&endOfStream, nil); err != nil { + return err + } + + return nil +} + +type cometapiEmbeddingData struct { + Embedding []float64 `json:"embedding"` + Object string `json:"object"` + Index int `json:"index"` +} + +type cometapiEmbeddingResponse struct { + Data []cometapiEmbeddingData `json:"data"` + Model string `json:"model"` + Object string `json:"object"` +} + +type cometapiEmbeddingRequest struct { + Model string `json:"model"` + Input []string `json:"input"` + Dimensions int `json:"dimensions,omitempty"` +} + +// Embed turns a list of texts into embedding vectors using the +// CometAPI /v1/embeddings endpoint. The output has one vector per input, +// in the same order the inputs were given. +func (m *CometAPIModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + apiKey, err := validateCometAPIAPIKey(apiConfig) + if err != nil { + return nil, err + } + + if modelName == nil || strings.TrimSpace(*modelName) == "" { + return nil, fmt.Errorf("model name is required") + } + + url, err := m.endpointURL(cometapiRegion(apiConfig), m.URLSuffix.Embedding) + if err != nil { + return nil, err + } + + reqBody := cometapiEmbeddingRequest{ + Model: *modelName, + Input: texts, + } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody.Dimensions = embeddingConfig.Dimension + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := newCometAPIJSONRequest(ctx, "POST", url, reqBody, apiKey) + if err != nil { + return nil, err + } + + resp, err := m.doCometAPIRequest(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("CometAPI embeddings API error: %s, body: %s", resp.Status, string(resp.Body)) + } + + var parsed cometapiEmbeddingResponse + if err = json.Unmarshal(resp.Body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Reorder the returned vectors by their reported index so the output + // always lines up with the input texts, even if the upstream API ever + // returns items out of order. A nil slot at the end indicates the + // upstream did not return an embedding for that input. + embeddings := make([]EmbeddingData, len(texts)) + filled := make([]bool, len(texts)) + for _, item := range parsed.Data { + if item.Index < 0 || item.Index >= len(texts) { + return nil, fmt.Errorf("cometapi: response index %d out of range for %d inputs", item.Index, len(texts)) + } + if filled[item.Index] { + // A malformed response that repeats the same index would + // silently overwrite the earlier vector. Fail loudly so + // the caller never uses ambiguous output. + return nil, fmt.Errorf("cometapi: duplicate embedding index %d in response", item.Index) + } + embeddings[item.Index] = EmbeddingData{ + Embedding: item.Embedding, + Index: item.Index, + } + filled[item.Index] = true + } + for i, ok := range filled { + if !ok { + return nil, fmt.Errorf("cometapi: missing embedding for input index %d", i) + } + } + + return embeddings, nil +} + +// ListModels returns the public CometAPI model catalog. +func (m *CometAPIModel) ListModels(apiConfig *APIConfig) ([]string, error) { + url, err := m.endpointURL(cometapiRegion(apiConfig), m.URLSuffix.Models) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := m.doCometAPIRequest(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(resp.Body)) + } + return parseCometAPIModelCatalog(resp.Body) +} + +// Balance queries CometAPI's quota service. Unlike model requests, this +// endpoint authenticates with the key query parameter on query.cometapi.com. +func (m *CometAPIModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if strings.TrimSpace(m.URLSuffix.Balance) == "" { + return nil, fmt.Errorf("balance URL is required") + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", m.balanceURL(*apiConfig.ApiKey), nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + resp, err := m.doCometAPIRequest(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("CometAPI quota API error: %s, body: %s", resp.Status, string(resp.Body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(resp.Body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return result, nil +} + +// CheckConnection runs a quota query to verify the API key. +func (m *CometAPIModel) CheckConnection(apiConfig *APIConfig) error { + _, err := m.Balance(apiConfig) + if err != nil { + return err + } + return nil +} + +// Rerank calculates similarity scores between query and documents. CometAPI +// does not expose a public rerank API, so this returns "no such method". +func (m *CometAPIModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("no such method") +} + +// TranscribeAudio transcribe audio +func (m *CometAPIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + if file == nil || *file == "" { + return nil, fmt.Errorf("file is missing") + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", m.BaseURL[region], m.URLSuffix.ASR) + + // multipart body + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + // open audio file + audioFile, err := os.Open(*file) + if err != nil { + return nil, fmt.Errorf("failed to open audio file: %w", err) + } + defer audioFile.Close() + + // create multipart file field + part, err := writer.CreateFormFile("file", filepath.Base(*file)) + if err != nil { + return nil, fmt.Errorf("failed to create multipart file: %w", err) + } + + // copy file content + if _, err = io.Copy(part, audioFile); err != nil { + return nil, fmt.Errorf("failed to copy audio data: %w", err) + } + + // model field + if err := writer.WriteField("model", *modelName); err != nil { + return nil, fmt.Errorf("failed to write model field: %w", err) + } + + // extra params + if asrConfig != nil && asrConfig.Params != nil { + for key, value := range asrConfig.Params { + + var val string + + switch v := value.(type) { + case string: + val = v + case bool: + val = strconv.FormatBool(v) + case int: + val = strconv.Itoa(v) + case int64: + val = strconv.FormatInt(v, 10) + case float32: + val = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + val = strconv.FormatFloat(v, 'f', -1, 64) + default: + val = fmt.Sprintf("%v", v) + } + + if err = writer.WriteField(key, val); err != nil { + return nil, fmt.Errorf("failed to write field %s: %w", key, err) + } + } + } + + if err = writer.Close(); err != nil { + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + // build request + req, err := http.NewRequest("POST", url, &body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Accept", "application/json") + + // send request + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("SiliconFlow ASR error: %s - %s", resp.Status, string(respBody)) + } + + // SiliconFlow response + var result struct { + Text string `json:"text"` + } + + if err = json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w, body=%s", err, string(respBody)) + } + + return &ASRResponse{Text: result.Text}, nil +} + +func (m *CometAPIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", m.Name()) +} + +// AudioSpeech synthesizes speech audio from text. +func (m *CometAPIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + if audioContent == nil || *audioContent == "" { + return nil, fmt.Errorf("audio content is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", m.BaseURL[region], m.URLSuffix.TTS) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": *audioContent, + } + + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + reqBody[key] = value + } + } + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["response_format"] = ttsConfig.Format + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s - %s", resp.Status, string(body)) + } + + return &TTSResponse{Audio: body}, nil +} + +func (m *CometAPIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", m.Name()) +} + +// OCRFile OCR file +func (m *CometAPIModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *CometAPIModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *CometAPIModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +func (m *CometAPIModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} diff --git a/internal/entity/models/cometapi_test.go b/internal/entity/models/cometapi_test.go new file mode 100644 index 00000000000..34cfe9c6ce5 --- /dev/null +++ b/internal/entity/models/cometapi_test.go @@ -0,0 +1,744 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" +) + +// newCometAPIServer stands up an httptest server that asserts the +// request shape and lets the caller decide what to return. +func newCometAPIServer(t *testing.T, expectedPath string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != expectedPath { + t.Errorf("expected path=%s, got %s", expectedPath, r.URL.Path) + return + } + if r.Method != http.MethodGet && r.Header.Get("Authorization") != "Bearer test-key" { + got := r.Header.Get("Authorization") + t.Errorf("expected Authorization=Bearer test-key, got %q", got) + return + } + if r.Method == http.MethodPost { + if got := r.Header.Get("Content-Type"); got != "application/json" { + t.Errorf("expected Content-Type=application/json, got %q", got) + return + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("failed to read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("invalid JSON body: %v\n%s", err, string(raw)) + return + } + handler(t, body, w) + return + } + // GET path: no body + handler(t, nil, w) + })) +} + +func newCometAPIForTest(baseURL string) *CometAPIModel { + return NewCometAPIModel( + map[string]string{"default": baseURL}, + URLSuffix{ + Chat: "v1/chat/completions", + Models: "api/models", + Embedding: "v1/embeddings", + Balance: "user/quota", + }, + ) +} + +func TestCometAPIName(t *testing.T) { + m := newCometAPIForTest("http://unused") + if got := m.Name(); got != "cometapi" { + t.Errorf("Name()=%q, want %q", got, "cometapi") + } +} + +func TestCometAPIFactoryRoute(t *testing.T) { + driver, err := NewModelFactory().CreateModelDriver("cometapi", map[string]string{"default": "http://unused"}, URLSuffix{}) + if err != nil { + t.Fatalf("CreateModelDriver: %v", err) + } + if _, ok := driver.(*CometAPIModel); !ok { + t.Fatalf("driver type=%T, want *CometAPIModel", driver) + } +} + +func TestCometAPIChatHappyPath(t *testing.T) { + srv := newCometAPIServer(t, "/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "gpt-5" { + t.Errorf("expected model=gpt-5, got %v", body["model"]) + } + if body["stream"] != false { + t.Errorf("expected stream=false, got %v", body["stream"]) + } + msgs, ok := body["messages"].([]interface{}) + if !ok || len(msgs) != 1 { + t.Errorf("expected 1 message, got %v", body["messages"]) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{ + {"message": map[string]interface{}{"content": "pong"}}, + }, + }) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + resp, err := m.ChatWithMessages("gpt-5", []Message{ + {Role: "user", Content: "ping"}, + }, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if resp.Answer == nil || *resp.Answer != "pong" { + t.Errorf("answer=%v, want pong", resp.Answer) + } + if resp.ReasonContent == nil || *resp.ReasonContent != "" { + t.Errorf("expected empty reason content, got %v", resp.ReasonContent) + } +} + +func TestCometAPIChatPropagatesConfig(t *testing.T) { + srv := newCometAPIServer(t, "/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["max_tokens"] != float64(64) { + t.Errorf("max_tokens=%v want 64", body["max_tokens"]) + } + if body["temperature"] != 0.3 { + t.Errorf("temperature=%v want 0.3", body["temperature"]) + } + if body["top_p"] != 0.9 { + t.Errorf("top_p=%v want 0.9", body["top_p"]) + } + stop, ok := body["stop"].([]interface{}) + if !ok || len(stop) != 1 || stop[0] != "END" { + t.Errorf("stop=%v want [END]", body["stop"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{"message": map[string]interface{}{"content": "ok"}}}, + }) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + mt := 64 + temp := 0.3 + topP := 0.9 + stop := []string{"END"} + _, err := m.ChatWithMessages("gpt-5", []Message{{Role: "user", Content: "ping"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{MaxTokens: &mt, Temperature: &temp, TopP: &topP, Stop: &stop}, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } +} + +func TestCometAPIChatReturnsReasoningContent(t *testing.T) { + srv := newCometAPIServer(t, "/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{ + {"message": map[string]interface{}{"content": "answer", "reasoning_content": "\nreason"}}, + }, + }) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + resp, err := m.ChatWithMessages("gpt-5", []Message{{Role: "user", Content: "ping"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if resp.ReasonContent == nil || *resp.ReasonContent != "reason" { + t.Errorf("reason=%v want reason", resp.ReasonContent) + } +} + +func TestCometAPIChatRequiresAPIKey(t *testing.T) { + m := newCometAPIForTest("http://unused") + _, err := m.ChatWithMessages("gpt-5", []Message{{Role: "user", Content: "x"}}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("expected api-key error, got %v", err) + } + emptyKey := "" + _, err = m.ChatWithMessages("gpt-5", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &emptyKey}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("empty key: expected api-key error, got %v", err) + } +} + +func TestCometAPIChatRequiresModelName(t *testing.T) { + m := newCometAPIForTest("http://unused") + apiKey := "test-key" + _, err := m.ChatWithMessages("", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("expected model-name error, got %v", err) + } + err = m.ChatStreamlyWithSender(" ", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil, func(*string, *string) error { return nil }) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("stream: expected model-name error, got %v", err) + } +} + +func TestCometAPIChatRequiresMessages(t *testing.T) { + m := newCometAPIForTest("http://unused") + apiKey := "test-key" + _, err := m.ChatWithMessages("gpt-5", nil, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "messages is empty") { + t.Errorf("expected messages-empty error, got %v", err) + } +} + +func TestCometAPIChatRejectsHTTPError(t *testing.T) { + srv := newCometAPIServer(t, "/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + _, err := m.ChatWithMessages("gpt-5", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "401") { + t.Errorf("expected 401 propagated, got %v", err) + } +} + +func TestCometAPIChatFallsBackToDefaultOnEmptyRegion(t *testing.T) { + // Empty *Region pointer must fall back to the "default" entry, not + // be treated as an explicit "" region (which would miss the lookup). + srv := newCometAPIServer(t, "/v1/chat/completions", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{"message": map[string]interface{}{"content": "ok"}}}, + }) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + emptyRegion := "" + _, err := m.ChatWithMessages("gpt-5", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey, Region: &emptyRegion}, nil) + if err != nil { + t.Errorf("empty Region: expected fallback to default, got %v", err) + } +} + +func TestCometAPIListModelsFallsBackToDefaultOnEmptyRegion(t *testing.T) { + srv := newCometAPIServer(t, "/api/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{"data": []map[string]interface{}{{"id": "x"}}}) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + emptyRegion := "" + if _, err := m.ListModels(&APIConfig{ApiKey: &apiKey, Region: &emptyRegion}); err != nil { + t.Errorf("empty Region: expected fallback to default, got %v", err) + } +} + +func TestCometAPIStreamRequiresSender(t *testing.T) { + m := newCometAPIForTest("http://unused") + apiKey := "test-key" + err := m.ChatStreamlyWithSender("gpt-5", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, nil) + if err == nil || !strings.Contains(err.Error(), "sender is required") { + t.Errorf("expected sender-required error, got %v", err) + } +} + +func TestCometAPIChatRejectsUnknownRegion(t *testing.T) { + m := newCometAPIForTest("http://unused") + apiKey := "test-key" + region := "eu" + _, err := m.ChatWithMessages("gpt-5", []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey, Region: ®ion}, nil) + if err == nil || !strings.Contains(err.Error(), "no base URL configured for region") { + t.Errorf("expected region error, got %v", err) + } +} + +func TestCometAPIBaseURLNormalizesSlashes(t *testing.T) { + tests := []struct { + name string + path string + run func(*CometAPIModel, *APIConfig) error + }{ + { + name: "Chat", + path: "/v1/chat/completions", + run: func(m *CometAPIModel, apiConfig *APIConfig) error { + _, err := m.ChatWithMessages("gpt-5", []Message{{Role: "user", Content: "x"}}, apiConfig, nil) + return err + }, + }, + { + name: "Stream", + path: "/v1/chat/completions", + run: func(m *CometAPIModel, apiConfig *APIConfig) error { + return m.ChatStreamlyWithSender("gpt-5", []Message{{Role: "user", Content: "x"}}, apiConfig, nil, func(*string, *string) error { return nil }) + }, + }, + { + name: "Embed", + path: "/v1/embeddings", + run: func(m *CometAPIModel, apiConfig *APIConfig) error { + model := "text-embedding-3-small" + _, err := m.Embed(&model, []string{"x"}, apiConfig, nil) + return err + }, + }, + { + name: "ListModels", + path: "/api/models", + run: func(m *CometAPIModel, apiConfig *APIConfig) error { + _, err := m.ListModels(apiConfig) + return err + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + srv := newCometAPIServer(t, tt.path, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + switch tt.name { + case "Chat": + _ = json.NewEncoder(w).Encode(map[string]interface{}{"choices": []map[string]interface{}{{"message": map[string]interface{}{"content": "ok"}}}}) + case "Stream": + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, `data: {"choices":[{"delta":{"content":"ok"},"finish_reason":"stop"}]}`+"\n") + case "Embed": + _ = json.NewEncoder(w).Encode(map[string]interface{}{"data": []map[string]interface{}{{"embedding": []float64{1}, "index": 0}}}) + case "ListModels": + _ = json.NewEncoder(w).Encode(map[string]interface{}{"data": []map[string]interface{}{{"id": "gpt-5"}}}) + } + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL + "/") + m.URLSuffix.Chat = "/v1/chat/completions" + m.URLSuffix.Models = "/api/models" + m.URLSuffix.Embedding = "/v1/embeddings" + apiKey := "test-key" + if err := tt.run(m, &APIConfig{ApiKey: &apiKey}); err != nil { + t.Fatalf("%s: %v", tt.name, err) + } + }) + } +} + +func TestCometAPIStreamHappyPath(t *testing.T) { + srv := newCometAPIServer(t, "/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["stream"] != true { + t.Errorf("expected stream=true, got %v", body["stream"]) + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + // Two content chunks then finish_reason terminator, then [DONE]. + _, _ = io.WriteString(w, + `data: {"choices":[{"delta":{"content":"Hello "}}]}`+"\n"+ + `data: {"choices":[{"delta":{"content":"world"}}]}`+"\n"+ + `data: {"choices":[{"delta":{},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + var chunks []string + var sawDone int32 + err := m.ChatStreamlyWithSender("gpt-5", + []Message{{Role: "user", Content: "hi"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(content *string, _ *string) error { + if content == nil { + return nil + } + if *content == "[DONE]" { + atomic.StoreInt32(&sawDone, 1) + return nil + } + chunks = append(chunks, *content) + return nil + }, + ) + if err != nil { + t.Fatalf("stream: %v", err) + } + if strings.Join(chunks, "") != "Hello world" { + t.Errorf("chunks=%v want [\"Hello \" \"world\"]", chunks) + } + if atomic.LoadInt32(&sawDone) != 1 { + t.Error("expected sender to receive [DONE] sentinel") + } +} + +func TestCometAPIStreamRejectsExplicitFalse(t *testing.T) { + m := newCometAPIForTest("http://unused") + apiKey := "test-key" + stream := false + err := m.ChatStreamlyWithSender("gpt-5", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{Stream: &stream}, + func(*string, *string) error { return nil }, + ) + if err == nil || !strings.Contains(err.Error(), "stream must be true") { + t.Errorf("expected stream-true guard, got %v", err) + } +} + +func TestCometAPIStreamFailsWithoutTerminal(t *testing.T) { + // Body closes before [DONE] or a finish_reason -> driver must complain + // instead of pretending the stream finished cleanly. + srv := newCometAPIServer(t, "/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `data: {"choices":[{"delta":{"content":"half"}}]}`+"\n") + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + err := m.ChatStreamlyWithSender("gpt-5", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(*string, *string) error { return nil }, + ) + if err == nil || !strings.Contains(err.Error(), "stream ended before") { + t.Errorf("expected stream-truncation error, got %v", err) + } +} + +func TestCometAPIListModelsHappyPath(t *testing.T) { + srv := newCometAPIServer(t, "/api/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"id": "gpt-5"}, + {"id": "gpt-4o-mini"}, + {"id": "text-embedding-3-small"}, + }, + }) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + ids, err := m.ListModels(&APIConfig{ApiKey: &apiKey}) + if err != nil { + t.Fatalf("ListModels: %v", err) + } + if len(ids) != 3 || ids[0] != "gpt-5" || ids[2] != "text-embedding-3-small" { + t.Errorf("ids=%v, want [gpt-5 gpt-4o-mini text-embedding-3-small]", ids) + } +} + +func TestCometAPIListModelsAllowsNilAPIConfig(t *testing.T) { + srv := newCometAPIServer(t, "/api/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{"data": []map[string]interface{}{{"id": "gpt-5"}}}) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + ids, err := m.ListModels(nil) + if err != nil { + t.Fatalf("ListModels(nil): %v", err) + } + if len(ids) != 1 || ids[0] != "gpt-5" { + t.Errorf("ids=%v want [gpt-5]", ids) + } +} + +func TestCometAPICheckConnectionDelegatesToBalance(t *testing.T) { + // 200 -> CheckConnection succeeds; 401 -> CheckConnection propagates. + okSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/user/quota" { + t.Errorf("path=%s want /user/quota", r.URL.Path) + } + if got := r.URL.Query().Get("key"); got != "test-key" { + t.Errorf("key query=%q want test-key", got) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{"total_quota": 10.0}) + })) + defer okSrv.Close() + failSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + })) + defer failSrv.Close() + + apiKey := "test-key" + mOK := newCometAPIForTest(okSrv.URL) + mOK.URLSuffix.Balance = okSrv.URL + "/user/quota" + if err := mOK.CheckConnection(&APIConfig{ApiKey: &apiKey}); err != nil { + t.Errorf("CheckConnection(ok): %v", err) + } + mFail := newCometAPIForTest(failSrv.URL) + mFail.URLSuffix.Balance = failSrv.URL + "/user/quota" + if err := mFail.CheckConnection(&APIConfig{ApiKey: &apiKey}); err == nil { + t.Error("CheckConnection(fail): expected error, got nil") + } +} + +func TestCometAPIBalanceHappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/user/quota" { + t.Errorf("path=%s want /user/quota", r.URL.Path) + } + if got := r.URL.Query().Get("key"); got != "test-key" { + t.Errorf("key query=%q want test-key", got) + } + if got := r.Header.Get("Authorization"); got != "" { + t.Errorf("Authorization=%q want empty", got) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "username": "tester", + "total_quota": 20.5, + "total_used_quota": 1.25, + "request_count": 7, + }) + })) + defer srv.Close() + + m := newCometAPIForTest("http://unused") + m.URLSuffix.Balance = srv.URL + "/user/quota" + apiKey := "test-key" + balance, err := m.Balance(&APIConfig{ApiKey: &apiKey}) + if err != nil { + t.Fatalf("Balance: %v", err) + } + if balance["username"] != "tester" || balance["total_quota"] != 20.5 { + t.Errorf("balance=%v", balance) + } +} + +func TestCometAPIBalanceRequiresAPIKey(t *testing.T) { + m := newCometAPIForTest("http://unused") + _, err := m.Balance(&APIConfig{}) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("Balance: expected api-key error, got %v", err) + } +} + +func TestCometAPIBalanceRequiresConfiguredURL(t *testing.T) { + m := newCometAPIForTest("http://unused") + m.URLSuffix.Balance = "" + apiKey := "test-key" + _, err := m.Balance(&APIConfig{ApiKey: &apiKey}) + if err == nil || !strings.Contains(err.Error(), "balance URL is required") { + t.Errorf("Balance: expected balance URL error, got %v", err) + } +} + +func TestCometAPIRerankReturnsNoSuchMethod(t *testing.T) { + m := newCometAPIForTest("http://unused") + q := "gpt-5" + _, err := m.Rerank(&q, "what is rag?", []string{"a", "b"}, &APIConfig{}, &RerankConfig{TopN: 2}) + if err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Rerank: expected 'no such method', got %v", err) + } +} + +func TestCometAPIEmbedHappyPath(t *testing.T) { + srv := newCometAPIServer(t, "/v1/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "text-embedding-3-small" { + t.Errorf("model=%v want text-embedding-3-small", body["model"]) + } + if body["dimensions"] != float64(256) { + t.Errorf("dimensions=%v want 256", body["dimensions"]) + } + inputs, ok := body["input"].([]interface{}) + if !ok || len(inputs) != 3 { + t.Errorf("input=%v want 3-element array", body["input"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{0.1, 0.2}, "index": 0}, + {"embedding": []float64{0.3, 0.4}, "index": 1}, + {"embedding": []float64{0.5, 0.6}, "index": 2}, + }, + }) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + model := "text-embedding-3-small" + vecs, err := m.Embed(&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, &EmbeddingConfig{Dimension: 256}) + if err != nil { + t.Fatalf("Embed: %v", err) + } + if len(vecs) != 3 { + t.Fatalf("len(vecs)=%d want 3", len(vecs)) + } + if vecs[1].Embedding[0] != 0.3 || vecs[1].Index != 1 { + t.Errorf("vecs[1]=%+v want {Embedding:[0.3 0.4] Index:1}", vecs[1]) + } +} + +func TestCometAPIEmbedReordersByIndex(t *testing.T) { + // Upstream returns the three vectors in shuffled order. The driver + // must reorder them so the slot at position i corresponds to input i. + srv := newCometAPIServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{2}, "index": 2}, + {"embedding": []float64{0}, "index": 0}, + {"embedding": []float64{1}, "index": 1}, + }, + }) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + model := "text-embedding-3-small" + vecs, err := m.Embed(&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } + for i, v := range vecs { + if v.Index != i || v.Embedding[0] != float64(i) { + t.Errorf("slot %d = %+v, want Embedding=[%d] Index=%d", i, v, i, i) + } + } +} + +func TestCometAPIEmbedEmptyInputShortCircuits(t *testing.T) { + // Empty input must NOT make an HTTP call; the test fails the request + // rather than the assertion if it does. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + t.Error("Embed([]) made an unexpected HTTP call") + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + model := "text-embedding-3-small" + vecs, err := m.Embed(&model, []string{}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed([]): %v", err) + } + if len(vecs) != 0 { + t.Errorf("len(vecs)=%d want 0", len(vecs)) + } +} + +func TestCometAPIEmbedRequiresAPIKey(t *testing.T) { + m := newCometAPIForTest("http://unused") + model := "text-embedding-3-small" + _, err := m.Embed(&model, []string{"a"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("expected api-key error, got %v", err) + } +} + +func TestCometAPIEmbedRequiresModelName(t *testing.T) { + m := newCometAPIForTest("http://unused") + apiKey := "test-key" + _, err := m.Embed(nil, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("expected model-name error, got %v", err) + } + empty := "" + _, err = m.Embed(&empty, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("empty model: expected model-name error, got %v", err) + } +} + +func TestCometAPIEmbedRejectsDuplicateIndex(t *testing.T) { + // A malformed upstream that repeats data[*].index would silently + // overwrite the earlier vector; the driver must fail loudly instead. + srv := newCometAPIServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 0}, + {"embedding": []float64{2}, "index": 0}, + }, + }) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + model := "text-embedding-3-small" + _, err := m.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "duplicate embedding index 0") { + t.Errorf("expected duplicate-index error, got %v", err) + } +} + +func TestCometAPIEmbedRejectsOutOfRangeIndex(t *testing.T) { + srv := newCometAPIServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 7}, // out of range for 2-input request + }, + }) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + model := "text-embedding-3-small" + _, err := m.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "out of range") { + t.Errorf("expected out-of-range error, got %v", err) + } +} + +func TestCometAPIEmbedRejectsMissingSlot(t *testing.T) { + // Upstream returns only one of the two requested embeddings. + srv := newCometAPIServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 0}, + }, + }) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + model := "text-embedding-3-small" + _, err := m.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "missing embedding for input index 1") { + t.Errorf("expected missing-embedding error for slot 1, got %v", err) + } +} + +func TestCometAPIEmbedRejectsHTTPError(t *testing.T) { + srv := newCometAPIServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + }) + defer srv.Close() + + m := newCometAPIForTest(srv.URL) + apiKey := "test-key" + model := "text-embedding-3-small" + _, err := m.Embed(&model, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "CometAPI embeddings API error") { + t.Errorf("expected CometAPI embeddings API error, got %v", err) + } +} diff --git a/internal/entity/models/deepinfra.go b/internal/entity/models/deepinfra.go new file mode 100644 index 00000000000..dd70981fe6f --- /dev/null +++ b/internal/entity/models/deepinfra.go @@ -0,0 +1,840 @@ +package models + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "io/ioutil" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "ragflow/internal/common" + "strconv" + "strings" + "time" +) + +type DeepInfraModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func NewDeepInfraModel(baseURL map[string]string, urlSuffix URLSuffix) *DeepInfraModel { + return &DeepInfraModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: time.Second * 120, + Transport: &http.Transport{ + MaxIdleConns: 10, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: time.Second * 90, + DisableCompression: false, + }, + }, + } +} + +func (d *DeepInfraModel) NewInstance(baseURL map[string]string) ModelDriver { + return &DeepInfraModel{ + BaseURL: baseURL, + URLSuffix: d.URLSuffix, + httpClient: &http.Client{ + Timeout: time.Second * 120, + Transport: &http.Transport{ + MaxIdleConns: 10, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: time.Second * 90, + DisableCompression: false, + }, + }, + } +} + +func (d *DeepInfraModel) Name() string { + return "deepinfra" +} + +func (d *DeepInfraModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", d.BaseURL[region], d.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + } + + if chatModelConfig != nil { + if chatModelConfig.Stream != nil { + reqBody["stream"] = *chatModelConfig.Stream + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + + if chatModelConfig.Effort != nil { + reqBody["reasoning_effort"] = *chatModelConfig.Effort + } + + if chatModelConfig.Thinking != nil && *chatModelConfig.Thinking { + reasoningMap := map[string]interface{}{ + "enabled": true, + } + if chatModelConfig.Effort != nil { + reasoningMap["effort"] = *chatModelConfig.Effort + } + reqBody["reasoning"] = reasoningMap + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := d.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + // Parse result + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response body: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + var reasonContent string + if rc, ok := messageMap["reasoning_content"].(string); ok { + reasonContent = rc + } + + chatResponse := &ChatResponse{ + Answer: &content, + } + if reasonContent != "" { + chatResponse.ReasonContent = &reasonContent + } + + return chatResponse, nil +} + +func (d *DeepInfraModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", d.BaseURL[region], d.URLSuffix.Chat) + + // Convert messages to API format + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + // Build request body with streaming enabled + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + "temperature": 1, + } + + if modelConfig != nil { + if modelConfig.Stream != nil { + reqBody["stream"] = *modelConfig.Stream + } + + if modelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *modelConfig.MaxTokens + } + + if modelConfig.Temperature != nil { + reqBody["temperature"] = *modelConfig.Temperature + } + + if modelConfig.DoSample != nil { + reqBody["do_sample"] = *modelConfig.DoSample + } + + if modelConfig.TopP != nil { + reqBody["top_p"] = *modelConfig.TopP + } + + if modelConfig.Stop != nil { + reqBody["stop"] = *modelConfig.Stop + } + + if modelConfig.Effort != nil { + reqBody["reasoning_effort"] = *modelConfig.Effort + } + + if modelConfig.Thinking != nil && *modelConfig.Thinking { + reasoningMap := map[string]interface{}{ + "enabled": true, + } + if modelConfig.Effort != nil { + reasoningMap["effort"] = *modelConfig.Effort + } + reqBody["reasoning"] = reasoningMap + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := d.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: read line by line + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + common.Info(line) + + // SSE data line starts with "data:" + if !strings.HasPrefix(line, "data:") { + continue + } + + // Extract JSON after "data:" + data := strings.TrimSpace(line[5:]) + + // [DONE] marks the end of stream + if data == "[DONE]" { + break + } + + // Parse the JSON event + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + reasoningContent, ok := delta["reasoning_content"].(string) + if ok && reasoningContent != "" { + if err := sender(nil, &reasoningContent); err != nil { + return err + } + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + break + } + } + + // Send [DONE] marker for OpenAI compatibility + endOfStream := "[DONE]" + if err = sender(&endOfStream, nil); err != nil { + return err + } + + return scanner.Err() +} + +func (d *DeepInfraModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, fmt.Errorf("texts is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", d.BaseURL[region], d.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + + if embeddingConfig != nil && embeddingConfig.Dimension >= 32 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := d.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("DeepInfra embedding API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var parsed struct { + Data []struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` + } + + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + // 组装 RAGFlow 需要的返回格式 + var embeddings []EmbeddingData + for _, data := range parsed.Data { + embeddings = append(embeddings, EmbeddingData{ + Embedding: data.Embedding, + Index: data.Index, + }) + } + + return embeddings, nil + + return embeddings, nil +} + +func (d *DeepInfraModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s no such method", d.Name()) +} + +func (d *DeepInfraModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("DeepInfra API key is missing") + } + + if file == nil || *file == "" { + return nil, fmt.Errorf("file is missing") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is missing") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", d.BaseURL[region], d.URLSuffix.ASR) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + if err := writer.WriteField("model", *modelName); err != nil { + return nil, fmt.Errorf("failed to write model field: %w", err) + } + + // Open File + audioFile, err := os.Open(*file) + if err != nil { + return nil, fmt.Errorf("failed to open audio file: %w", err) + } + defer audioFile.Close() + + part, err := writer.CreateFormFile("file", filepath.Base(*file)) + if err != nil { + return nil, fmt.Errorf("failed to create multipart file: %w", err) + } + + if _, err = io.Copy(part, audioFile); err != nil { + return nil, fmt.Errorf("failed to copy audio data: %w", err) + } + + // get config + if asrConfig != nil && asrConfig.Params != nil { + for key, value := range asrConfig.Params { + var val string + + switch v := value.(type) { + case string: + val = v + case bool: + val = strconv.FormatBool(v) + case int: + val = strconv.Itoa(v) + case int64: + val = strconv.FormatInt(v, 10) + case float64: + val = strconv.FormatFloat(v, 'f', -1, 64) + case float32: + val = strconv.FormatFloat(float64(v), 'f', -1, 32) + default: + val = fmt.Sprintf("%v", v) + } + + if err := writer.WriteField(key, val); err != nil { + return nil, fmt.Errorf("failed to write field %s: %w", key, err) + } + } + } + + if err = writer.Close(); err != nil { + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + req, err := http.NewRequest("POST", url, &body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Accept", "application/json") + + resp, err := d.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("DeepInfra ASR error: %s - %s", resp.Status, string(respBody)) + } + + // Parse result + var result struct { + Text string `json:"text"` + } + + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &ASRResponse{ + Text: result.Text, + }, nil +} + +func (d *DeepInfraModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s no such method", d.Name()) +} + +func (d *DeepInfraModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("DeepInfra API key is missing") + } + + if audioContent == nil || *audioContent == "" { + return nil, fmt.Errorf("text content is missing") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + reqBody := map[string]interface{}{ + "text": *audioContent, + } + voiceID := "" + + if ttsConfig != nil && ttsConfig.Params != nil { + if v, ok := ttsConfig.Params["voice_id"].(string); ok && v != "" { + voiceID = v + } else if v, ok := ttsConfig.Params["voice"].(string); ok && v != "" { + voiceID = v + } + + for key, value := range ttsConfig.Params { + if key != "voice_id" && key != "voice" { + reqBody[key] = value + } + } + } + + if voiceID == "" { + return nil, fmt.Errorf("voice_id is missing (must be provided in params or model name)") + } + + // URL: https://api.deepinfra.com/v1/text-to-speech/{voice_id} + url := fmt.Sprintf("%s/%s/%s", d.BaseURL[region], d.URLSuffix.TTS, voiceID) + + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["output_format"] = ttsConfig.Format + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := d.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("DeepInfra TTS error: status %d - %s", resp.StatusCode, string(body)) + } + + return &TTSResponse{Audio: body}, nil +} + +func (d *DeepInfraModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("DeepInfra API key is missing") + } + + if audioContent == nil || *audioContent == "" { + return fmt.Errorf("text content is missing") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + voiceID := "" + + reqBody := map[string]interface{}{ + "text": *audioContent, + } + + if ttsConfig != nil && ttsConfig.Params != nil { + if v, ok := ttsConfig.Params["voice_id"].(string); ok && v != "" { + voiceID = v + } else if v, ok := ttsConfig.Params["voice"].(string); ok && v != "" { + voiceID = v + } + + for key, value := range ttsConfig.Params { + if key != "voice_id" && key != "voice" { + reqBody[key] = value + } + } + } + + if voiceID == "" { + return fmt.Errorf("voice_id is missing") + } + + // URL: https://api.deepinfra.com/v1/text-to-speech/{voice_id}/stream + url := fmt.Sprintf("%s/%s/%s/stream", d.BaseURL[region], d.URLSuffix.TTS, voiceID) + + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["output_format"] = ttsConfig.Format + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := d.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("DeepInfra TTS Stream error: status %d - %s", resp.StatusCode, string(body)) + } + + buffer := make([]byte, 4096) + for { + n, err := resp.Body.Read(buffer) + + if n > 0 { + chunkStr := string(buffer[:n]) + if sendErr := sender(&chunkStr, nil); sendErr != nil { + return sendErr + } + } + + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("error reading stream: %w", err) + } + } + + endOfStream := "[DONE]" + if err = sender(&endOfStream, nil); err != nil { + return err + } + + return nil +} + +func (d *DeepInfraModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s no such method", d.Name()) +} + +func (d *DeepInfraModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s no such method", d.Name()) +} + +func (d *DeepInfraModel) ListModels(apiConfig *APIConfig) ([]string, error) { + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", d.BaseURL[region], d.URLSuffix.Models) + + reqBody := map[string]interface{}{} + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := d.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to read response: %s", string(body)) + } + + // Parse response + var result []struct { + ModelName string `json:"model_name"` + } + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + models := make([]string, 0) + for _, model := range result { + if model.ModelName != "" { + models = append(models, model.ModelName) + } + } + + return models, nil +} + +func (d *DeepInfraModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", d.BaseURL[region], d.URLSuffix.Balance) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := d.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to read response: %s", string(body)) + } + + var result struct { + Balance interface{} `json:"stripe_balance"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return map[string]interface{}{ + "balance": result.Balance, + "currence": "USD", + }, nil +} + +func (d *DeepInfraModel) CheckConnection(apiConfig *APIConfig) error { + _, err := d.ListModels(apiConfig) + return err +} + +func (d *DeepInfraModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s no such method", d.Name()) +} + +func (d *DeepInfraModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s no such method", d.Name()) +} diff --git a/internal/entity/models/deepseek.go b/internal/entity/models/deepseek.go index dc06ebbfbd7..78f5607c424 100644 --- a/internal/entity/models/deepseek.go +++ b/internal/entity/models/deepseek.go @@ -415,8 +415,8 @@ func (z *DeepSeekModel) ChatStreamlyWithSender(modelName string, messages []Mess return scanner.Err() } -// Encode encodes a list of texts into embeddings -func (z *DeepSeekModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Embed embeds a list of texts into embeddings +func (z *DeepSeekModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { return nil, fmt.Errorf("%s, no such method", z.Name()) } @@ -584,3 +584,39 @@ func (z *DeepSeekModel) CheckConnection(apiConfig *APIConfig) error { func (z *DeepSeekModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) } + +// TranscribeAudio transcribe audio +func (d *DeepSeekModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", d.Name()) +} + +func (z *DeepSeekModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (d *DeepSeekModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", d.Name()) +} + +func (z *DeepSeekModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (d *DeepSeekModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", d.Name()) +} + +// ParseFile parse file +func (z *DeepSeekModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *DeepSeekModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *DeepSeekModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/dummy.go b/internal/entity/models/dummy.go index ffc0f9f4b78..718dd4bca1b 100644 --- a/internal/entity/models/dummy.go +++ b/internal/entity/models/dummy.go @@ -34,42 +34,78 @@ func NewDummyModel(baseURL map[string]string, urlSuffix URLSuffix) *DummyModel { } } -func (z *DummyModel) NewInstance(baseURL map[string]string) ModelDriver { +func (d *DummyModel) NewInstance(baseURL map[string]string) ModelDriver { return nil } -func (z *DummyModel) Name() string { +func (d *DummyModel) Name() string { return "dummy" } // ChatWithMessages sends multiple messages with roles and returns response -func (z *DummyModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { +func (d *DummyModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { return nil, fmt.Errorf("not implemented") } // ChatStreamlyWithSender sends messages and streams response via sender function (best performance, no channel) -func (z *DummyModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { +func (d *DummyModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { return fmt.Errorf("not implemented") } -// Encode encodes a list of texts into embeddings -func (z *DummyModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Embed embeds a list of texts into embeddings +func (d *DummyModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { return nil, fmt.Errorf("not implemented") } -func (z *DummyModel) ListModels(apiConfig *APIConfig) ([]string, error) { +func (d *DummyModel) ListModels(apiConfig *APIConfig) ([]string, error) { return nil, fmt.Errorf("not implemented") } -func (z *DummyModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { +func (d *DummyModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { return nil, fmt.Errorf("no such method") } -func (z *DummyModel) CheckConnection(apiConfig *APIConfig) error { +func (d *DummyModel) CheckConnection(apiConfig *APIConfig) error { return fmt.Errorf("no such method") } // Rerank calculates similarity scores between query and documents -func (z *DummyModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { - return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) +func (d *DummyModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s, Rerank not implemented", d.Name()) +} + +// TranscribeAudio transcribe audio +func (d *DummyModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", d.Name()) +} + +func (z *DummyModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (d *DummyModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", d.Name()) +} + +func (z *DummyModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (d *DummyModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", d.Name()) +} + +// ParseFile parse file +func (z *DummyModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *DummyModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *DummyModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) } diff --git a/internal/entity/models/factory.go b/internal/entity/models/factory.go index b38e4ff9d45..819878ed62e 100644 --- a/internal/entity/models/factory.go +++ b/internal/entity/models/factory.go @@ -33,6 +33,8 @@ func NewModelFactory() *ModelFactory { func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string]string, urlSuffix URLSuffix) (ModelDriver, error) { providerLower := strings.ToLower(providerName) switch providerLower { + case "anthropic": + return NewAnthropicModel(baseURL, urlSuffix), nil case "zhipu-ai": return NewZhipuAIModel(baseURL, urlSuffix), nil case "deepseek": @@ -57,12 +59,62 @@ func (f *ModelFactory) CreateModelDriver(providerName string, baseURL map[string return NewXAIModel(baseURL, urlSuffix), nil case "lmstudio": return NewLmStudioModel(baseURL, urlSuffix), nil + case "ollama": + return NewOllamaModel(baseURL, urlSuffix), nil + case "openai": + return NewOpenAIModel(baseURL, urlSuffix), nil case "nvidia": return NewNvidiaModel(baseURL, urlSuffix), nil case "openrouter": return NewOpenRouterModel(baseURL, urlSuffix), nil case "huggingface": return NewHuggingFaceModel(baseURL, urlSuffix), nil + case "baidu": + return NewBaiduModel(baseURL, urlSuffix), nil + case "cohere": + return NewCoHereModel(baseURL, urlSuffix), nil + case "cometapi": + return NewCometAPIModel(baseURL, urlSuffix), nil + case "fishaudio": + return NewFishAudioModel(baseURL, urlSuffix), nil + case "mistral": + return NewMistralModel(baseURL, urlSuffix), nil + case "upstage": + return NewUpstageModel(baseURL, urlSuffix), nil + case "stepfun": + return NewStepFunModel(baseURL, urlSuffix), nil + case "baichuan": + return NewBaichuanModel(baseURL, urlSuffix), nil + case "jina": + return NewJinaModel(baseURL, urlSuffix), nil + case "localai": + return NewLocalAIModel(baseURL, urlSuffix), nil + case "xinference": + return NewXinferenceModel(baseURL, urlSuffix), nil + case "longcat": + return NewLongCatModel(baseURL, urlSuffix), nil + case "novita": + return NewNovitaModel(baseURL, urlSuffix), nil + case "replicate": + return NewReplicateModel(baseURL, urlSuffix), nil + case "togetherai": + return NewTogetherAIModel(baseURL, urlSuffix), nil + case "voyage": + return NewVoyageModel(baseURL, urlSuffix), nil + case "paddleocr": + return NewPaddleOCRModel(baseURL, urlSuffix), nil + case "xunfei": + return NewXunFeiModel(baseURL, urlSuffix), nil + case "deepinfra": + return NewDeepInfraModel(baseURL, urlSuffix), nil + case "mineru": + return NewMinerUModel(baseURL, urlSuffix), nil + case "jiekouai": + return NewJieKouAIModel(baseURL, urlSuffix), nil + case "302.ai": + return NewAI302Model(baseURL, urlSuffix), nil + case "mineru_local": + return NewMinerLocalUModel(baseURL, urlSuffix), nil default: return NewDummyModel(baseURL, urlSuffix), nil } diff --git a/internal/entity/models/fishaudio.go b/internal/entity/models/fishaudio.go new file mode 100644 index 00000000000..70f67076211 --- /dev/null +++ b/internal/entity/models/fishaudio.go @@ -0,0 +1,455 @@ +package models + +import ( + "bufio" + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "os" + "path/filepath" + "strconv" + "strings" + "time" +) + +// 208cc2d0e4594ca896a600c43c9497aa + +type FishAudioModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func NewFishAudioModel(baseURL map[string]string, urlSuffix URLSuffix) *FishAudioModel { + return &FishAudioModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: 120 * time.Second, + }, + } +} + +func (f *FishAudioModel) NewInstance(baseURL map[string]string) ModelDriver { + return &FishAudioModel{ + BaseURL: baseURL, + URLSuffix: f.URLSuffix, + httpClient: &http.Client{ + Timeout: 120 * time.Second, + }, + } +} + +func (f *FishAudioModel) Name() string { + return "fishaudio" +} + +func (f *FishAudioModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + return nil, fmt.Errorf(f.Name() + " no such method") +} + +func (f *FishAudioModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + return fmt.Errorf(f.Name() + " no such method") +} + +func (f *FishAudioModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + return nil, fmt.Errorf("no such method") +} + +func (f *FishAudioModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("no such method") +} + +// TranscribeAudio transcribe audio +func (f *FishAudioModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("FishAudio API key is missing") + } + + if file == nil || *file == "" { + return nil, fmt.Errorf("file is missing") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", f.BaseURL[region], f.URLSuffix.ASR) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + // audio file + audioFile, err := os.Open(*file) + if err != nil { + return nil, fmt.Errorf("failed to open audio file: %w", err) + } + defer audioFile.Close() + + part, err := writer.CreateFormFile("audio", filepath.Base(*file)) + if err != nil { + return nil, fmt.Errorf("failed to create multipart file: %w", err) + } + + if _, err = io.Copy(part, audioFile); err != nil { + return nil, fmt.Errorf("failed to copy audio data: %w", err) + } + + // extra params + if asrConfig != nil && asrConfig.Params != nil { + for key, value := range asrConfig.Params { + + var val string + + switch v := value.(type) { + case string: + val = v + case bool: + val = strconv.FormatBool(v) + case int: + val = strconv.Itoa(v) + case float64: + val = strconv.FormatFloat(v, 'f', -1, 64) + default: + val = fmt.Sprintf("%v", v) + } + + if err := writer.WriteField(key, val); err != nil { + return nil, fmt.Errorf("failed to write field %s: %w", key, err) + } + } + } + + if err := writer.Close(); err != nil { + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + // request + req, err := http.NewRequest("POST", url, &body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Content-Type", writer.FormDataContentType()) + + resp, err := f.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf( + "FishAudio ASR error: %s - %s", + resp.Status, + string(respBody), + ) + } + + // result + var result struct { + Text string `json:"text"` + } + + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w", err) + } + + return &ASRResponse{ + Text: result.Text, + }, nil +} + +func (f *FishAudioModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", f.Name()) +} + +// AudioSpeech convert text to audio +func (f *FishAudioModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("FishAudio API key is missing") + } + + if audioContent == nil || *audioContent == "" { + return nil, fmt.Errorf("text content is missing") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", f.BaseURL[region], f.URLSuffix.TTS) + + reqBody := map[string]interface{}{ + "text": *audioContent, + } + + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + reqBody[key] = value + } + } + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["format"] = ttsConfig.Format + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("model", *modelName) + + resp, err := f.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s - %s", resp.Status, string(body)) + } + + return &TTSResponse{Audio: body}, nil +} + +func (f *FishAudioModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("FishAudio API key is missing") + } + + if audioContent == nil || *audioContent == "" { + return fmt.Errorf("text content is missing") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s/%s", f.BaseURL[region], f.URLSuffix.TTS, "stream/with-timestamp") + + reqBody := map[string]interface{}{ + "text": *audioContent, + } + + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + reqBody[key] = value + } + } + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["format"] = ttsConfig.Format + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + // Build Request + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("model", *modelName) + + resp, err := f.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + buf := make([]byte, 1024) + n, _ := resp.Body.Read(buf) + return fmt.Errorf("FishAudio stream API error: %d - %s", resp.StatusCode, string(buf[:n])) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 8*1024*1024) + + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "data: ") { + continue + } + + dataStr := strings.TrimSpace(line[6:]) + if dataStr == "" { + continue + } + + var event struct { + AudioBase64 string `json:"audio_base64"` + } + + if err := json.Unmarshal([]byte(dataStr), &event); err != nil { + continue + } + + if event.AudioBase64 != "" { + audioBytes, err := base64.StdEncoding.DecodeString(event.AudioBase64) + if err == nil && len(audioBytes) > 0 { + chunk := string(audioBytes) + if errSend := sender(&chunk, nil); errSend != nil { + return errSend + } + } + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("error reading FishAudio stream: %w", err) + } + + return nil +} + +// OCRFile OCR file +func (f *FishAudioModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", f.Name()) +} + +// ParseFile parse file +func (z *FishAudioModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (f *FishAudioModel) ListModels(apiConfig *APIConfig) ([]string, error) { + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", f.BaseURL[region], f.URLSuffix.Models) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + } else { + return nil, fmt.Errorf("Fish Audio API key is missing") + } + + resp, err := f.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Fish Audio API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result struct { + Items []struct { + ID string `json:"_id"` + Title string `json:"title"` + } `json:"items"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + models := make([]string, 0, len(result.Items)) + for _, item := range result.Items { + models = append(models, item.Title) + } + + return models, nil +} + +func (f *FishAudioModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := f.BaseURL[region] + if baseURL == "" { + baseURL = f.BaseURL["default"] + } + + url := fmt.Sprintf("%s/wallet/self/api-credit", strings.TrimSuffix(baseURL, "/")) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := f.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Fish Audio balance API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + return result, nil +} + +func (f *FishAudioModel) CheckConnection(apiConfig *APIConfig) error { + _, err := f.ListModels(apiConfig) + return err +} + +func (z *FishAudioModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *FishAudioModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/gitee.go b/internal/entity/models/gitee.go index 34d04251029..05e4baa33a7 100644 --- a/internal/entity/models/gitee.go +++ b/internal/entity/models/gitee.go @@ -23,6 +23,7 @@ import ( "encoding/json" "fmt" "io" + "mime/multipart" "net/http" "ragflow/internal/common" "strings" @@ -53,16 +54,16 @@ func NewGiteeModel(baseURL map[string]string, urlSuffix URLSuffix) *GiteeModel { } } -func (z *GiteeModel) NewInstance(baseURL map[string]string) ModelDriver { +func (g *GiteeModel) NewInstance(baseURL map[string]string) ModelDriver { return nil } -func (z *GiteeModel) Name() string { +func (g *GiteeModel) Name() string { return "gitee" } // ChatWithMessages sends multiple messages with roles and returns response -func (z *GiteeModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { +func (g *GiteeModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { return nil, fmt.Errorf("api key is nil or empty") } @@ -75,7 +76,7 @@ func (z *GiteeModel) ChatWithMessages(modelName string, messages []Message, apiC if apiConfig.Region != nil && *apiConfig.Region != "" { region = *apiConfig.Region } - url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Chat) + url := fmt.Sprintf("%s/%s", g.BaseURL[region], g.URLSuffix.Chat) // Convert messages to the format expected by API apiMessages := make([]map[string]interface{}, len(messages)) @@ -144,7 +145,7 @@ func (z *GiteeModel) ChatWithMessages(modelName string, messages []Message, apiC req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - resp, err := z.httpClient.Do(req) + resp, err := g.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } @@ -213,7 +214,7 @@ func (z *GiteeModel) ChatWithMessages(modelName string, messages []Message, apiC } // ChatStreamlyWithSender sends messages and streams response via sender function (best performance, no channel) -func (z *GiteeModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { +func (g *GiteeModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { if len(messages) == 0 { return fmt.Errorf("messages is empty") } @@ -223,7 +224,7 @@ func (z *GiteeModel) ChatStreamlyWithSender(modelName string, messages []Message region = *apiConfig.Region } - url := fmt.Sprintf("%s/chat/completions", z.BaseURL[region]) + url := fmt.Sprintf("%s/chat/completions", g.BaseURL[region]) // Convert messages to API format apiMessages := make([]map[string]interface{}, len(messages)) @@ -291,7 +292,7 @@ func (z *GiteeModel) ChatStreamlyWithSender(modelName string, messages []Message req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - resp, err := z.httpClient.Do(req) + resp, err := g.httpClient.Do(req) if err != nil { return fmt.Errorf("failed to send request: %w", err) } @@ -398,9 +399,108 @@ func (z *GiteeModel) ChatStreamlyWithSender(modelName string, messages []Message return scanner.Err() } -// Encode encodes a list of texts into embeddings -func (z *GiteeModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { - return nil, fmt.Errorf("%s, no such method", z.Name()) +type giteeEmbeddingResponse struct { + Object string `json:"object"` + Data []giteeEmbeddingData `json:"data"` + Model string `json:"model"` + Usage giteeUsage `json:"usage"` +} + +type giteeEmbeddingData struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` +} + +type giteeUsage struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// Embed embeds a list of texts into embeddings +func (g *GiteeModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := g.BaseURL["default"] + if region != "default" { + if regional, ok := g.BaseURL[region]; ok && regional != "" { + baseURL = regional + } + } + if baseURL == "" { + return nil, fmt.Errorf("gitee: no base URL configured for default region") + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), g.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := g.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Gitee embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed giteeEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + var embeddings []EmbeddingData + for _, dataElem := range parsed.Data { + var embeddingData EmbeddingData + embeddingData.Embedding = dataElem.Embedding + embeddingData.Index = dataElem.Index + embeddings = append(embeddings, embeddingData) + } + + return embeddings, nil } type giteeRerankRequest struct { @@ -412,7 +512,7 @@ type giteeRerankRequest struct { } // Rerank calculates similarity scores between query and documents -func (z *GiteeModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { +func (g *GiteeModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { if len(documents) == 0 { return &RerankResponse{}, nil } @@ -430,9 +530,9 @@ func (z *GiteeModel) Rerank(modelName *string, query string, documents []string, region = *apiConfig.Region } - baseURL := z.BaseURL["default"] + baseURL := g.BaseURL["default"] if region != "default" { - if regional, ok := z.BaseURL[region]; ok && regional != "" { + if regional, ok := g.BaseURL[region]; ok && regional != "" { baseURL = regional } } @@ -440,7 +540,7 @@ func (z *GiteeModel) Rerank(modelName *string, query string, documents []string, return nil, fmt.Errorf("gitee: no base URL configured for default region") } - url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), z.URLSuffix.Rerank) + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), g.URLSuffix.Rerank) var topN = rerankConfig.TopN if rerankConfig.TopN == 0 { @@ -471,7 +571,7 @@ func (z *GiteeModel) Rerank(modelName *string, query string, documents []string, req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - resp, err := z.httpClient.Do(req) + resp, err := g.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } @@ -483,7 +583,7 @@ func (z *GiteeModel) Rerank(modelName *string, query string, documents []string, } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("Gitee rerank API error: %s, body: %s", resp.Status, string(body)) + return nil, fmt.Errorf("gitee rerank API error: %s, body: %s", resp.Status, string(body)) } var rerankResponse RerankResponse @@ -494,13 +594,291 @@ func (z *GiteeModel) Rerank(modelName *string, query string, documents []string, return &rerankResponse, nil } -func (z *GiteeModel) ListModels(apiConfig *APIConfig) ([]string, error) { +// TranscribeAudio transcribe audio +func (g *GiteeModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", g.Name()) +} + +func (z *GiteeModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (g *GiteeModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", g.Name()) +} + +func (z *GiteeModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +type giteeOCRResponse struct { + Text string `json:"text_result"` + Prompt string `json:"prompt"` +} + +// OCRFile OCR file +func (g *GiteeModel) OCRFile(modelName *string, content []byte, imageURL *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + if imageURL == nil && content == nil { + return nil, fmt.Errorf("url or content is required") + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := g.BaseURL["default"] + if region != "default" { + if regional, ok := g.BaseURL[region]; ok && regional != "" { + baseURL = regional + } + } + if baseURL == "" { + return nil, fmt.Errorf("gitee: no base URL configured for default region") + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), g.URLSuffix.OCR) + + payload := &bytes.Buffer{} + writer := multipart.NewWriter(payload) + + if err := writer.WriteField("model", *modelName); err != nil { + return nil, fmt.Errorf("failed to write model field: %w", err) + } + + if imageURL != nil { + if err := writer.WriteField("image", *imageURL); err != nil { + return nil, fmt.Errorf("failed to write image URL: %w", err) + } + } else if content != nil && len(content) > 0 { + part, err := writer.CreateFormFile("image", "image") + if err != nil { + return nil, fmt.Errorf("failed to create image form file: %w", err) + } + if _, err = part.Write(content); err != nil { + return nil, fmt.Errorf("failed to write image content: %w", err) + } + } else { + return nil, fmt.Errorf("image or image URL is required") + } + + writer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, payload) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := g.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("gitee OCR API error: %s, body: %s", resp.Status, string(body)) + } + + var giteeResponse giteeOCRResponse + if err = json.Unmarshal(body, &giteeResponse); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + var ocrResponse = OCRFileResponse{ + Text: &giteeResponse.Text, + } + + return &ocrResponse, nil +} + +type giteeParseFileResponse struct { + TaskID string `json:"task_id"` + Status string `json:"status"` + CreatedAt int64 `json:"created_at"` + URLs giteeURLs `json:"urls"` +} + +type giteeURLs struct { + Get string `json:"get"` + Cancel string `json:"cancel"` +} + +// ParseFile parse file +func (g *GiteeModel) ParseFile(modelName *string, content []byte, documentURL *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + if documentURL == nil && content == nil { + return nil, fmt.Errorf("url or content is required") + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := g.BaseURL["default"] + if region != "default" { + if regional, ok := g.BaseURL[region]; ok && regional != "" { + baseURL = regional + } + } + if baseURL == "" { + return nil, fmt.Errorf("gitee: no base URL configured for default region") + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), g.URLSuffix.DocumentParse) + + payload := &bytes.Buffer{} + writer := multipart.NewWriter(payload) + + if err := writer.WriteField("model", *modelName); err != nil { + return nil, fmt.Errorf("failed to write model field: %w", err) + } + + if documentURL != nil { + if err := writer.WriteField("file", *documentURL); err != nil { + return nil, fmt.Errorf("failed to write file URL: %w", err) + } + } else if content != nil && len(content) > 0 { + part, err := writer.CreateFormFile("file", "file") + if err != nil { + return nil, fmt.Errorf("failed to create file form file: %w", err) + } + if _, err = part.Write(content); err != nil { + return nil, fmt.Errorf("failed to write file content: %w", err) + } + } else { + return nil, fmt.Errorf("file or file URL is required") + } + + writer.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, payload) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := g.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("gitee OCR API error: %s, body: %s", resp.Status, string(body)) + } + + var giteeParseFileResp giteeParseFileResponse + if err = json.Unmarshal(body, &giteeParseFileResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + _, err = g.getParseFile(&baseURL, apiConfig.ApiKey, &giteeParseFileResp.TaskID, 5*time.Second, 10) + if err != nil { + return nil, err + } + + var parseFileResponse = ParseFileResponse{} + + return &parseFileResponse, nil +} + +type giteeGetParseFileResponse struct { +} + +func (g *GiteeModel) getParseFile(baseURL *string, apiKey, taskID *string, timeOut time.Duration, count int) (*giteeGetParseFileResponse, error) { + url := fmt.Sprintf("%s/task/%s/status", strings.TrimSuffix(*baseURL, "/"), *taskID) + + reqBody := map[string]interface{}{} + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey)) + + var resp *http.Response + for i := 0; i < count; i++ { + resp, err = g.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + var body []byte + body, err = io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + time.Sleep(timeOut) + } + + // if resp show the file is ok, download it. otherwise, provide timeout info + return nil, nil +} + +func (g *GiteeModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" - if apiConfig.Region != nil { + if apiConfig.Region != nil && *apiConfig.Region != "" { region = *apiConfig.Region } - url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Models) + url := fmt.Sprintf("%s/%s", g.BaseURL[region], g.URLSuffix.Models) // Build request body reqBody := map[string]interface{}{} @@ -518,7 +896,7 @@ func (z *GiteeModel) ListModels(apiConfig *APIConfig) ([]string, error) { req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - resp, err := z.httpClient.Do(req) + resp, err := g.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } @@ -551,13 +929,13 @@ func (z *GiteeModel) ListModels(apiConfig *APIConfig) ([]string, error) { return models, nil } -func (z *GiteeModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { +func (g *GiteeModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { var region = "default" - if apiConfig.Region != nil { + if apiConfig.Region != nil && *apiConfig.Region != "" { region = *apiConfig.Region } - url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Balance) + url := fmt.Sprintf("%s/%s", g.BaseURL[region], g.URLSuffix.Balance) // Build request body reqBody := map[string]interface{}{} @@ -575,7 +953,7 @@ func (z *GiteeModel) Balance(apiConfig *APIConfig) (map[string]interface{}, erro req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - resp, err := z.httpClient.Do(req) + resp, err := g.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("failed to send request: %w", err) } @@ -606,13 +984,13 @@ func (z *GiteeModel) Balance(apiConfig *APIConfig) (map[string]interface{}, erro return response, nil } -func (z *GiteeModel) CheckConnection(apiConfig *APIConfig) error { +func (g *GiteeModel) CheckConnection(apiConfig *APIConfig) error { var region = "default" if apiConfig.Region != nil { region = *apiConfig.Region } - url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Status) + url := fmt.Sprintf("%s/%s", g.BaseURL[region], g.URLSuffix.Status) // Build request body reqBody := map[string]interface{}{} @@ -630,7 +1008,7 @@ func (z *GiteeModel) CheckConnection(apiConfig *APIConfig) error { req.Header.Set("Content-Type", "application/json") req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - resp, err := z.httpClient.Do(req) + resp, err := g.httpClient.Do(req) if err != nil { return fmt.Errorf("failed to send request: %w", err) } @@ -647,3 +1025,144 @@ func (z *GiteeModel) CheckConnection(apiConfig *APIConfig) error { return nil } + +type giteeTaskListResponse struct { + Total int `json:"total"` + Items []giteeTaskItem `json:"items"` +} + +type giteeTaskItem struct { + TaskID string `json:"task_id"` + //Output giteeTaskOutput `json:"output"` + Status string `json:"status"` + CreatedAt int64 `json:"created_at"` + StartedAt int64 `json:"started_at,omitempty"` + CompletedAt int64 `json:"completed_at,omitempty"` + Price float64 `json:"price"` + Currency string `json:"currency"` + URLs giteeTaskURLs `json:"urls"` +} + +type giteeTaskOutput struct { + Segments []giteeSegment `json:"segments"` +} + +type giteeSegment struct { + Index int `json:"index"` + Content string `json:"content"` +} +type giteeTaskURLs struct { + Get string `json:"get"` + Cancel string `json:"cancel"` +} + +func (g *GiteeModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", g.BaseURL[region], g.URLSuffix.Tasks) + + // Build request body + reqBody := map[string]interface{}{} + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := g.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + var body []byte + body, err = io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + var giteeTaskList giteeTaskListResponse + if err = json.Unmarshal(body, &giteeTaskList); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + taskListResp := []ListTaskStatus{} + for _, item := range giteeTaskList.Items { + taskListResp = append(taskListResp, ListTaskStatus{ + TaskID: item.TaskID, + Status: item.Status, + }) + } + return taskListResp, nil +} + +func (g *GiteeModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s/%s/get", g.BaseURL[region], g.URLSuffix.Task, taskID) + + // Build request body + reqBody := map[string]interface{}{} + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := g.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var taskOutput giteeTaskOutput + if err = json.Unmarshal(body, &taskOutput); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + taskResp := &TaskResponse{} + + for _, segment := range taskOutput.Segments { + taskResp.Segments = append(taskResp.Segments, TaskSegment{ + Index: segment.Index, + Content: segment.Content, + }) + } + + return taskResp, nil +} diff --git a/internal/entity/models/google.go b/internal/entity/models/google.go index b5679ac8da9..2702a04384e 100644 --- a/internal/entity/models/google.go +++ b/internal/entity/models/google.go @@ -20,11 +20,58 @@ import ( "context" "fmt" "ragflow/internal/common" + "strings" "google.golang.org/genai" ) -// GoogleModel implements ModelDriver for Dummy AI +type googleModelPage struct { + items []string + nextPageToken string +} + +func collectGoogleModelNames(ctx context.Context, listPage func(context.Context, string) (googleModelPage, error)) ([]string, error) { + var modelNames []string + pageToken := "" + + for { + page, err := listPage(ctx, pageToken) + if err != nil { + return nil, err + } + + modelNames = append(modelNames, page.items...) + if page.nextPageToken == "" { + return modelNames, nil + } + pageToken = page.nextPageToken + } +} + +var googleListModels = func(ctx context.Context, apiKey string) ([]string, error) { + client, err := genai.NewClient(ctx, &genai.ClientConfig{ + APIKey: apiKey, + Backend: genai.BackendGeminiAPI, + }) + if err != nil { + return nil, err + } + + return collectGoogleModelNames(ctx, func(ctx context.Context, pageToken string) (googleModelPage, error) { + models, err := client.Models.List(ctx, &genai.ListModelsConfig{PageToken: pageToken}) + if err != nil { + return googleModelPage{}, err + } + + var modelNames []string + for _, m := range models.Items { + modelNames = append(modelNames, m.Name) + } + return googleModelPage{items: modelNames, nextPageToken: models.NextPageToken}, nil + }) +} + +// GoogleModel implements ModelDriver for Google AI type GoogleModel struct { BaseURL map[string]string URLSuffix URLSuffix @@ -38,15 +85,15 @@ func NewGoogleModel(baseURL map[string]string, urlSuffix URLSuffix) *GoogleModel } } -func (z *GoogleModel) NewInstance(baseURL map[string]string) ModelDriver { +func (g *GoogleModel) NewInstance(baseURL map[string]string) ModelDriver { return nil } -func (z *GoogleModel) Name() string { +func (g *GoogleModel) Name() string { return "google" } -func (z *GoogleModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { +func (g *GoogleModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { return nil, fmt.Errorf("api key is nil or empty") } @@ -120,7 +167,7 @@ func (z *GoogleModel) ChatWithMessages(modelName string, messages []Message, api } // ChatStreamlyWithSender sends messages and streams response via sender function (best performance, no channel) -func (z *GoogleModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { +func (g *GoogleModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { if len(messages) == 0 { return fmt.Errorf("messages is empty") } @@ -212,43 +259,119 @@ func (z *GoogleModel) ChatStreamlyWithSender(modelName string, messages []Messag return err } -// Encode encodes a list of texts into embeddings -func (z *GoogleModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { - return nil, fmt.Errorf("not implemented") -} +// Embed generates embeddings for a batch of texts using the Gemini embeddings API. +// The SDK routes to batchEmbedContents internally, so all texts are sent in one request. +func (g *GoogleModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + if len(texts) == 0 { + return nil, fmt.Errorf("texts is empty") + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() -func (z *GoogleModel) ListModels(apiConfig *APIConfig) ([]string, error) { - ctx := context.Background() client, err := genai.NewClient(ctx, &genai.ClientConfig{ APIKey: *apiConfig.ApiKey, Backend: genai.BackendGeminiAPI, }) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create client: %w", err) } - // Retrieve the list of models. - models, err := client.Models.List(ctx, &genai.ListModelsConfig{}) + contents := make([]*genai.Content, len(texts)) + for i, text := range texts { + contents[i] = genai.NewContentFromText(text, genai.RoleUser) + } + + var cfg *genai.EmbedContentConfig + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + dim := int32(embeddingConfig.Dimension) + cfg = &genai.EmbedContentConfig{OutputDimensionality: &dim} + } + + resp, err := client.Models.EmbedContent(ctx, *modelName, contents, cfg) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to embed content: %w", err) } - var modelNames []string - for _, m := range models.Items { - modelNames = append(modelNames, m.Name) + if len(resp.Embeddings) != len(texts) { + return nil, fmt.Errorf("expected %d embeddings, got %d", len(texts), len(resp.Embeddings)) + } + + result := make([]EmbeddingData, len(resp.Embeddings)) + for i, emb := range resp.Embeddings { + vec := make([]float64, len(emb.Values)) + for j, v := range emb.Values { + vec[j] = float64(v) + } + result[i] = EmbeddingData{ + Embedding: vec, + Index: i, + } + } + + return result, nil +} + +func (g *GoogleModel) ListModels(apiConfig *APIConfig) ([]string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || strings.TrimSpace(*apiConfig.ApiKey) == "" { + return nil, fmt.Errorf("api key is required") } - return modelNames, nil + + return googleListModels(context.Background(), *apiConfig.ApiKey) } -func (z *GoogleModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { +func (g *GoogleModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { return nil, fmt.Errorf("no such method") } -func (z *GoogleModel) CheckConnection(apiConfig *APIConfig) error { - return fmt.Errorf("no such method") +func (g *GoogleModel) CheckConnection(apiConfig *APIConfig) error { + _, err := g.ListModels(apiConfig) + return err } // Rerank calculates similarity scores between query and documents -func (z *GoogleModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { - return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) +func (g *GoogleModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s, Rerank not implemented", g.Name()) +} + +// TranscribeAudio transcribe audio +func (g *GoogleModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", g.Name()) +} + +func (z *GoogleModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (g *GoogleModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", g.Name()) +} + +func (z *GoogleModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (g *GoogleModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", g.Name()) +} + +// ParseFile parse file +func (z *GoogleModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *GoogleModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *GoogleModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) } diff --git a/internal/entity/models/google_test.go b/internal/entity/models/google_test.go new file mode 100644 index 00000000000..5b09c7a1686 --- /dev/null +++ b/internal/entity/models/google_test.go @@ -0,0 +1,249 @@ +package models + +import ( + "context" + "errors" + "reflect" + "strings" + "sync" + "testing" +) + +var googleListModelsMu sync.Mutex + +func withGoogleListModelsStub(t *testing.T, fn func(context.Context, string) ([]string, error)) { + t.Helper() + + googleListModelsMu.Lock() + original := googleListModels + googleListModels = fn + t.Cleanup(func() { + googleListModels = original + googleListModelsMu.Unlock() + }) +} + +func TestGoogleModelListModelsRequiresAPIKey(t *testing.T) { + model := &GoogleModel{} + cases := []struct { + name string + apiConfig *APIConfig + }{ + { + name: "nil config", + apiConfig: nil, + }, + { + name: "nil api key", + apiConfig: &APIConfig{}, + }, + { + name: "empty api key", + apiConfig: &APIConfig{ + ApiKey: stringPtr(""), + }, + }, + { + name: "blank api key", + apiConfig: &APIConfig{ + ApiKey: stringPtr(" \t\n "), + }, + }, + } + + calls := 0 + withGoogleListModelsStub(t, func(context.Context, string) ([]string, error) { + calls++ + return nil, nil + }) + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + models, err := model.ListModels(tc.apiConfig) + if err == nil { + t.Fatal("expected an API key error") + } + if !strings.Contains(err.Error(), "api key is required") { + t.Fatalf("expected API key error, got %v", err) + } + if models != nil { + t.Fatalf("expected no models, got %v", models) + } + }) + } + + if calls != 0 { + t.Fatalf("expected no ListModels calls without an API key, got %d", calls) + } +} + +func TestGoogleModelListModelsReturnsModelNames(t *testing.T) { + model := &GoogleModel{} + apiKey := "test-api-key" + expected := []string{"models/gemini-2.5-flash", "models/gemini-2.5-pro"} + + withGoogleListModelsStub(t, func(_ context.Context, gotAPIKey string) ([]string, error) { + if gotAPIKey != apiKey { + t.Fatalf("expected API key %q, got %q", apiKey, gotAPIKey) + } + return expected, nil + }) + + models, err := model.ListModels(&APIConfig{ApiKey: &apiKey}) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if !reflect.DeepEqual(models, expected) { + t.Fatalf("expected models %v, got %v", expected, models) + } +} + +func TestGoogleModelCheckConnectionUsesListModels(t *testing.T) { + model := &GoogleModel{} + apiKey := "test-api-key" + calls := 0 + + withGoogleListModelsStub(t, func(_ context.Context, gotAPIKey string) ([]string, error) { + calls++ + if gotAPIKey != apiKey { + t.Fatalf("expected API key %q, got %q", apiKey, gotAPIKey) + } + return []string{"models/gemini-2.5-flash"}, nil + }) + + if err := model.CheckConnection(&APIConfig{ApiKey: &apiKey}); err != nil { + t.Fatalf("expected no error, got %v", err) + } + if calls != 1 { + t.Fatalf("expected one ListModels call, got %d", calls) + } +} + +func TestGoogleModelCheckConnectionRequiresAPIKey(t *testing.T) { + model := &GoogleModel{} + calls := 0 + + withGoogleListModelsStub(t, func(context.Context, string) ([]string, error) { + calls++ + return nil, nil + }) + + cases := []struct { + name string + apiConfig *APIConfig + }{ + { + name: "nil config", + apiConfig: nil, + }, + { + name: "nil api key", + apiConfig: &APIConfig{}, + }, + { + name: "empty api key", + apiConfig: &APIConfig{ + ApiKey: stringPtr(""), + }, + }, + { + name: "blank api key", + apiConfig: &APIConfig{ + ApiKey: stringPtr(" \t\n "), + }, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + err := model.CheckConnection(tc.apiConfig) + if err == nil { + t.Fatal("expected an API key error") + } + if !strings.Contains(err.Error(), "api key is required") { + t.Fatalf("expected API key error, got %v", err) + } + }) + } + if calls != 0 { + t.Fatalf("expected no ListModels calls without an API key, got %d", calls) + } +} + +func TestGoogleModelCheckConnectionReturnsListModelsError(t *testing.T) { + model := &GoogleModel{} + apiKey := "test-api-key" + listErr := errors.New("list models failed") + + withGoogleListModelsStub(t, func(context.Context, string) ([]string, error) { + return nil, listErr + }) + + err := model.CheckConnection(&APIConfig{ApiKey: &apiKey}) + if !errors.Is(err, listErr) { + t.Fatalf("expected ListModels error %v, got %v", listErr, err) + } +} + +func TestCollectGoogleModelNamesPaginates(t *testing.T) { + pages := []googleModelPage{ + {items: []string{"models/gemini-2.5-flash"}, nextPageToken: "page-2"}, + {items: []string{"models/gemini-2.5-pro"}, nextPageToken: ""}, + } + var pageTokens []string + + models, err := collectGoogleModelNames(context.Background(), func(_ context.Context, pageToken string) (googleModelPage, error) { + pageTokens = append(pageTokens, pageToken) + if len(pageTokens) > len(pages) { + t.Fatalf("unexpected extra page request with token %q", pageToken) + } + return pages[len(pageTokens)-1], nil + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + + expectedModels := []string{"models/gemini-2.5-flash", "models/gemini-2.5-pro"} + if !reflect.DeepEqual(models, expectedModels) { + t.Fatalf("expected models %v, got %v", expectedModels, models) + } + expectedPageTokens := []string{"", "page-2"} + if !reflect.DeepEqual(pageTokens, expectedPageTokens) { + t.Fatalf("expected page tokens %v, got %v", expectedPageTokens, pageTokens) + } +} + +func TestCollectGoogleModelNamesPreservesEmptyResult(t *testing.T) { + models, err := collectGoogleModelNames(context.Background(), func(context.Context, string) (googleModelPage, error) { + return googleModelPage{}, nil + }) + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + if models != nil { + t.Fatalf("expected nil models, got %v", models) + } +} + +func TestCollectGoogleModelNamesReturnsPageError(t *testing.T) { + pageErr := errors.New("next page failed") + calls := 0 + + models, err := collectGoogleModelNames(context.Background(), func(context.Context, string) (googleModelPage, error) { + calls++ + if calls == 1 { + return googleModelPage{items: []string{"models/gemini-2.5-flash"}, nextPageToken: "page-2"}, nil + } + return googleModelPage{}, pageErr + }) + if !errors.Is(err, pageErr) { + t.Fatalf("expected page error %v, got %v", pageErr, err) + } + if models != nil { + t.Fatalf("expected no models on error, got %v", models) + } +} + +func stringPtr(value string) *string { + return &value +} diff --git a/internal/entity/models/huggingface.go b/internal/entity/models/huggingface.go index d1160d1c46c..87e41242bfb 100644 --- a/internal/entity/models/huggingface.go +++ b/internal/entity/models/huggingface.go @@ -26,12 +26,6 @@ func NewHuggingFaceModel(baseURL map[string]string, urlSuffix URLSuffix) *Huggin URLSuffix: urlSuffix, httpClient: &http.Client{ Timeout: 120 * time.Second, - Transport: &http.Transport{ - MaxIdleConns: 10, - MaxIdleConnsPerHost: 100, - IdleConnTimeout: 90 * time.Second, - DisableCompression: false, - }, }, } } @@ -41,12 +35,6 @@ func (h *HuggingFaceModel) NewInstance(baseURL map[string]string) ModelDriver { URLSuffix: h.URLSuffix, httpClient: &http.Client{ Timeout: 120 * time.Second, - Transport: &http.Transport{ - MaxIdleConns: 10, - MaxIdleConnsPerHost: 100, - IdleConnTimeout: 90 * time.Second, - DisableCompression: false, - }, }, } } @@ -204,7 +192,7 @@ func (h *HuggingFaceModel) ChatStreamlyWithSender(modelName string, messages []M region = *apiConfig.Region } - url := fmt.Sprintf("%s/chat/completions", h.BaseURL[region]) + url := fmt.Sprintf("%s/%s", h.BaseURL[region], h.URLSuffix.Chat) // Convert messages to API format apiMessages := make([]map[string]interface{}, len(messages)) @@ -351,15 +339,14 @@ func (h *HuggingFaceModel) ChatStreamlyWithSender(modelName string, messages []M return scanner.Err() } -type hfEmbeddingRequest struct { - Inputs []string `json:"inputs"` -} - -type hfEmbeddingResponse [][]float64 - -func (h *HuggingFaceModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +func (h *HuggingFaceModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { if len(texts) == 0 { - return [][]float64{}, nil + return []EmbeddingData{}, nil + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region } if modelName == nil || *modelName == "" { @@ -379,7 +366,7 @@ func (h *HuggingFaceModel) Encode(modelName *string, texts []string, apiConfig * return nil, err } - url := fmt.Sprintf("https://router.huggingface.co/hf-inference/models/%s", *modelName) + url := fmt.Sprintf("%s/%s/%s", h.BaseURL[region], h.URLSuffix.Embedding, *modelName) req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil { @@ -404,18 +391,54 @@ func (h *HuggingFaceModel) Encode(modelName *string, texts []string, apiConfig * return nil, fmt.Errorf("HF embeddings API error: %s", string(body)) } - var result [][]float64 - if err = json.Unmarshal(body, &result); err != nil { - return nil, err + var parsed openaiEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) } - return result, nil + var embeddings []EmbeddingData + for _, dataElem := range parsed.Data { + var embeddingData EmbeddingData + embeddingData.Embedding = dataElem.Embedding + embeddingData.Index = dataElem.Index + embeddings = append(embeddings, embeddingData) + } + + return embeddings, nil } func (h *HuggingFaceModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("no such method") } +// TranscribeAudio transcribe audio +func (h *HuggingFaceModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", h.Name()) +} + +func (z *HuggingFaceModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (h *HuggingFaceModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", h.Name()) +} + +func (z *HuggingFaceModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (h *HuggingFaceModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", h.Name()) +} + +// ParseFile parse file +func (z *HuggingFaceModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + func (h *HuggingFaceModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" if apiConfig.Region != nil && *apiConfig.Region != "" { @@ -479,3 +502,11 @@ func (h *HuggingFaceModel) CheckConnection(apiConfig *APIConfig) error { _, err := h.ListModels(apiConfig) return err } + +func (z *HuggingFaceModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *HuggingFaceModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/jiekouai.go b/internal/entity/models/jiekouai.go new file mode 100644 index 00000000000..459b8d959cf --- /dev/null +++ b/internal/entity/models/jiekouai.go @@ -0,0 +1,586 @@ +package models + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "ragflow/internal/common" + "strings" + "time" +) + +type JieKouAIModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func NewJieKouAIModel(baseURL map[string]string, urlSuffix URLSuffix) *JieKouAIModel { + return &JieKouAIModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: time.Second * 120, + Transport: &http.Transport{ + MaxIdleConns: 10, + MaxConnsPerHost: 100, + IdleConnTimeout: time.Second * 90, + DisableCompression: false, + }, + }, + } +} + +func (j *JieKouAIModel) NewInstance(baseURL map[string]string) ModelDriver { + return &JieKouAIModel{ + BaseURL: baseURL, + URLSuffix: j.URLSuffix, + httpClient: &http.Client{ + Timeout: time.Second * 120, + Transport: &http.Transport{ + MaxIdleConns: 10, + MaxConnsPerHost: 100, + IdleConnTimeout: time.Second * 90, + DisableCompression: false, + }, + }, + } +} + +func (j *JieKouAIModel) Name() string { + return "jiekouai" +} + +func (j *JieKouAIModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", j.BaseURL[region], j.URLSuffix.Chat) + + // Convert messages to API format + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + // Build request body + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + "temperature": 1, + } + + if chatModelConfig != nil { + if chatModelConfig.Stream != nil { + reqBody["stream"] = *chatModelConfig.Stream + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + + // For zai-org/glm-4.5: https://docs.jiekou.ai/docs/models/reference-llm-create-chat-completion + if chatModelConfig.Thinking != nil { + if *chatModelConfig.Thinking { + reqBody["enable_thinking"] = *chatModelConfig.Thinking + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := j.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + var reasonContent string + if chatModelConfig != nil && chatModelConfig.Thinking != nil && *chatModelConfig.Thinking { + reasonContent, ok = messageMap["reasoning_content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + // if first char of reasonContent is \n remove the '\n' + if reasonContent != "" && reasonContent[0] == '\n' { + reasonContent = reasonContent[1:] + } + } + + chatResponse := &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + } + + return chatResponse, nil +} + +func (j *JieKouAIModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(j.BaseURL[region], "/"), j.URLSuffix.Chat) + + // Convert messages to API format + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + // Build request body with streaming enabled + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + "temperature": 1, + } + + if modelConfig != nil { + if modelConfig.Stream != nil { + reqBody["stream"] = *modelConfig.Stream + } + + if modelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *modelConfig.MaxTokens + } + + if modelConfig.Temperature != nil { + reqBody["temperature"] = *modelConfig.Temperature + } + + if modelConfig.DoSample != nil { + reqBody["do_sample"] = *modelConfig.DoSample + } + + if modelConfig.TopP != nil { + reqBody["top_p"] = *modelConfig.TopP + } + + if modelConfig.Stop != nil { + reqBody["stop"] = *modelConfig.Stop + } + + if modelConfig.Thinking != nil { + if *modelConfig.Thinking { + reqBody["enable_thinking"] = *modelConfig.Thinking + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := j.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: read line by line + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + common.Info(line) + + // SSE data line starts with "data:" + if !strings.HasPrefix(line, "data:") { + continue + } + + // Extract JSON after "data:" + data := strings.TrimSpace(line[5:]) + + // [DONE] marks the end of stream + if data == "[DONE]" { + break + } + + // Parse the JSON event + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + reasoningContent, ok := delta["reasoning_content"].(string) + if ok && reasoningContent != "" { + if err := sender(nil, &reasoningContent); err != nil { + return err + } + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + break + } + } + + // Send [DONE] marker for OpenAI compatibility + endOfStream := "[DONE]" + if err = sender(&endOfStream, nil); err != nil { + return err + } + + return scanner.Err() +} + +func (j *JieKouAIModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, fmt.Errorf("texts is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(j.BaseURL[region], "/"), j.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := j.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var parsedResponse struct { + Data []struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` + } + + if err = json.Unmarshal(body, &parsedResponse); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if len(parsedResponse.Data) == 0 { + return nil, fmt.Errorf("failed to parse response") + } + + var embeddings []EmbeddingData + for _, embedding := range parsedResponse.Data { + embeddings = append(embeddings, EmbeddingData{ + Embedding: embedding.Embedding, + Index: embedding.Index, + }) + } + + return embeddings, nil +} + +func (j *JieKouAIModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + if len(documents) == 0 { + return &RerankResponse{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(j.BaseURL[region], "/"), j.URLSuffix.Rerank) + + var topN = rerankConfig.TopN + if rerankConfig.TopN != 0 { + topN = rerankConfig.TopN + } + + reqBody := map[string]interface{}{ + "model": *modelName, + "query": query, + "documents": documents, + "top_n": topN, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := j.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var rerankResp struct { + Results []struct { + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` + } `json:"results"` + } + + if err = json.Unmarshal(body, &rerankResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + var rerankResponse RerankResponse + for _, result := range rerankResp.Results { + rerankResult := RerankResult{ + Index: result.Index, + RelevanceScore: result.RelevanceScore, + } + rerankResponse.Data = append(rerankResponse.Data, rerankResult) + } + + return &rerankResponse, nil +} + +func (j *JieKouAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", j.Name()) +} + +func (j *JieKouAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", j.Name()) +} + +func (j *JieKouAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", j.Name()) +} + +func (j *JieKouAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", j.Name()) +} + +func (j *JieKouAIModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", j.Name()) +} + +func (j *JieKouAIModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", j.Name()) +} + +func (j *JieKouAIModel) ListModels(apiConfig *APIConfig) ([]string, error) { + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", j.BaseURL[region], j.URLSuffix.Models) + + reqBody := map[string]string{} + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := j.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // convert result["data"] to []map[string]interface{} + models := make([]string, 0) + for _, model := range result["data"].([]interface{}) { + modelMap := model.(map[string]interface{}) + modelName := modelMap["id"].(string) + models = append(models, modelName) + } + + return models, nil +} + +func (j *JieKouAIModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("%s, no such method", j.Name()) +} + +func (j *JieKouAIModel) CheckConnection(apiConfig *APIConfig) error { + _, err := j.ListModels(apiConfig) + return err +} + +func (j *JieKouAIModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s no such method", j.Name()) +} + +func (j *JieKouAIModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s no such method", j.Name()) +} diff --git a/internal/entity/models/jina.go b/internal/entity/models/jina.go new file mode 100644 index 00000000000..fca88664a1b --- /dev/null +++ b/internal/entity/models/jina.go @@ -0,0 +1,406 @@ +package models + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +type JinaModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func NewJinaModel(baseURL map[string]string, urlSuffix URLSuffix) *JinaModel { + return &JinaModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: time.Second * 90, + }, + } +} + +func (j *JinaModel) NewInstance(baseURL map[string]string) ModelDriver { + return &JinaModel{ + BaseURL: baseURL, + URLSuffix: j.URLSuffix, + httpClient: &http.Client{ + Timeout: time.Second * 90, + }, + } +} + +func (j *JinaModel) Name() string { + return "jina" +} + +func (j *JinaModel) baseURLForRegion(region string) (string, error) { + base, ok := j.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("jina: no base URL configured for region %q", region) + } + return base, nil +} + +func (j *JinaModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if modelName == "" { + return nil, fmt.Errorf("model name is required") + } + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := j.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, j.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + } + + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := j.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Jina chat API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + reasonContent := "" + return &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + }, nil +} + +func (j *JinaModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + //TODO implement me: https://api.jina.ai/docs#/Search%20Foundation%20Models/chat_completions_v1_chat_completions_post + return fmt.Errorf("jina does not implement ChatStreamlyWithSender(not available for now)") +} + +func (j *JinaModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", j.BaseURL[region], j.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := j.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Jina embedding API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var parsedResponse struct { + Data []struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + } `json:"data"` + } + + if err = json.Unmarshal(body, &parsedResponse); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + if len(parsedResponse.Data) == 0 { + return nil, fmt.Errorf("Jina embedding response contains no data: %s", string(body)) + } + + var embeddings []EmbeddingData + for _, dataElem := range parsedResponse.Data { + embeddings = append(embeddings, EmbeddingData{ + Embedding: dataElem.Embedding, + Index: dataElem.Index, + }) + } + + return embeddings, nil +} + +func (j *JinaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + if len(documents) == 0 { + return &RerankResponse{}, nil + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", j.BaseURL[region], j.URLSuffix.Rerank) + + var topN = rerankConfig.TopN + if rerankConfig.TopN != 0 { + topN = rerankConfig.TopN + } + + reqBody := map[string]interface{}{ + "model": *modelName, + "query": query, + "documents": documents, + "top_n": topN, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := j.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Jina Rerank API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var rerankResp struct { + Results []struct { + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` + } `json:"results"` + } + + if err = json.Unmarshal(body, &rerankResp); err != nil { + return nil, fmt.Errorf("failed to decode response: %w", err) + } + + var rerankResponse RerankResponse + for _, result := range rerankResp.Results { + rerankResult := RerankResult{ + Index: result.Index, + RelevanceScore: result.RelevanceScore, + } + rerankResponse.Data = append(rerankResponse.Data, rerankResult) + } + + return &rerankResponse, nil +} + +func (j *JinaModel) ListModels(apiConfig *APIConfig) ([]string, error) { + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", j.BaseURL[region], j.URLSuffix.Models) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := j.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // convert result["data"] to []map[string]interface{} + models := make([]string, 0) + for _, model := range result["data"].([]interface{}) { + modelMap := model.(map[string]interface{}) + modelName := modelMap["name"].(string) + models = append(models, modelName) + } + + return models, nil +} + +func (j *JinaModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("no such method") +} + +func (j *JinaModel) CheckConnection(apiConfig *APIConfig) error { + _, err := j.ListModels(apiConfig) + return err +} + +// TranscribeAudio transcribe audio +func (z *JinaModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *JinaModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (z *JinaModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *JinaModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (z *JinaModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +// ParseFile parse file +func (z *JinaModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *JinaModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *JinaModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/jina_test.go b/internal/entity/models/jina_test.go new file mode 100644 index 00000000000..2ae8a7be86d --- /dev/null +++ b/internal/entity/models/jina_test.go @@ -0,0 +1,240 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newJinaServer(t *testing.T, expectedPath string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != expectedPath { + t.Errorf("expected path=%s, got %s", expectedPath, r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("expected Authorization=Bearer test-key, got %q", got) + return + } + if got := r.Header.Get("Content-Type"); got != "application/json" { + t.Errorf("expected Content-Type=application/json, got %q", got) + return + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("failed to read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("invalid JSON body: %v\n%s", err, string(raw)) + return + } + handler(t, body, w) + })) +} + +func newJinaForTest(baseURL string) *JinaModel { + return NewJinaModel( + map[string]string{"default": baseURL}, + URLSuffix{ + Chat: "chat/completions", + Models: "models", + Embedding: "embeddings", + Rerank: "rerank", + }, + ) +} + +func TestJinaChatHappyPath(t *testing.T) { + srv := newJinaServer(t, "/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "jina-vlm" { + t.Errorf("expected model=jina-vlm, got %v", body["model"]) + } + if body["stream"] != false { + t.Errorf("expected stream=false, got %v", body["stream"]) + } + msgs, ok := body["messages"].([]interface{}) + if !ok || len(msgs) != 1 { + t.Errorf("expected 1 message, got %v", body["messages"]) + return + } + msg, ok := msgs[0].(map[string]interface{}) + if !ok || msg["role"] != "user" || msg["content"] != "ping" { + t.Errorf("unexpected message payload: %v", msgs[0]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{ + {"message": map[string]interface{}{"content": "pong"}}, + }, + }) + }) + defer srv.Close() + + j := newJinaForTest(srv.URL) + apiKey := "test-key" + resp, err := j.ChatWithMessages("jina-vlm", []Message{{Role: "user", Content: "ping"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if resp.Answer == nil || *resp.Answer != "pong" { + t.Errorf("answer=%v, want pong", resp.Answer) + } + if resp.ReasonContent == nil || *resp.ReasonContent != "" { + t.Errorf("expected empty reason content, got %v", resp.ReasonContent) + } +} + +func TestJinaChatPropagatesConfig(t *testing.T) { + srv := newJinaServer(t, "/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["max_tokens"] != float64(128) { + t.Errorf("max_tokens=%v want 128", body["max_tokens"]) + } + if body["temperature"] != 0.2 { + t.Errorf("temperature=%v want 0.2", body["temperature"]) + } + if body["top_p"] != 0.8 { + t.Errorf("top_p=%v want 0.8", body["top_p"]) + } + stop, ok := body["stop"].([]interface{}) + if !ok || len(stop) != 1 || stop[0] != "END" { + t.Errorf("stop=%v want [END]", body["stop"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{"message": map[string]interface{}{"content": "ok"}}}, + }) + }) + defer srv.Close() + + j := newJinaForTest(srv.URL) + apiKey := "test-key" + maxTokens := 128 + temperature := 0.2 + topP := 0.8 + stop := []string{"END"} + _, err := j.ChatWithMessages("jina-vlm", []Message{{Role: "user", Content: "ping"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{MaxTokens: &maxTokens, Temperature: &temperature, TopP: &topP, Stop: &stop}, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } +} + +func TestJinaChatValidation(t *testing.T) { + j := newJinaForTest("http://unused") + apiKey := "test-key" + emptyKey := "" + + tests := []struct { + name string + modelName string + messages []Message + apiConfig *APIConfig + want string + }{ + { + name: "missing api config", + modelName: "jina-vlm", + messages: []Message{{Role: "user", Content: "x"}}, + want: "api key is required", + }, + { + name: "missing api key", + modelName: "jina-vlm", + messages: []Message{{Role: "user", Content: "x"}}, + apiConfig: &APIConfig{}, + want: "api key is required", + }, + { + name: "empty api key", + modelName: "jina-vlm", + messages: []Message{{Role: "user", Content: "x"}}, + apiConfig: &APIConfig{ApiKey: &emptyKey}, + want: "api key is required", + }, + { + name: "missing model", + messages: []Message{{Role: "user", Content: "x"}}, + apiConfig: &APIConfig{ApiKey: &apiKey}, + want: "model name is required", + }, + { + name: "missing messages", + modelName: "jina-vlm", + apiConfig: &APIConfig{ApiKey: &apiKey}, + want: "messages is empty", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := j.ChatWithMessages(tt.modelName, tt.messages, tt.apiConfig, nil) + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("expected %q error, got %v", tt.want, err) + } + }) + } +} + +func TestJinaChatRejectsHTTPError(t *testing.T) { + srv := newJinaServer(t, "/chat/completions", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"detail":"invalid api key"}`)) + }) + defer srv.Close() + + j := newJinaForTest(srv.URL) + apiKey := "test-key" + _, err := j.ChatWithMessages("jina-vlm", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "status 401") { + t.Errorf("expected 401 propagated, got %v", err) + } +} + +func TestJinaChatRejectsMalformedResponse(t *testing.T) { + srv := newJinaServer(t, "/chat/completions", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{"choices": []map[string]interface{}{}}) + }) + defer srv.Close() + + j := newJinaForTest(srv.URL) + apiKey := "test-key" + _, err := j.ChatWithMessages("jina-vlm", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "no choices in response") { + t.Errorf("expected malformed-response error, got %v", err) + } +} + +func TestJinaChatRejectsUnknownRegion(t *testing.T) { + j := newJinaForTest("http://unused") + apiKey := "test-key" + region := "eu" + _, err := j.ChatWithMessages("jina-vlm", []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey, Region: ®ion}, nil) + if err == nil || !strings.Contains(err.Error(), "no base URL configured for region") { + t.Errorf("expected region error, got %v", err) + } +} + +func TestJinaChatFallsBackToDefaultOnEmptyRegion(t *testing.T) { + srv := newJinaServer(t, "/chat/completions", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{"message": map[string]interface{}{"content": "ok"}}}, + }) + }) + defer srv.Close() + + j := newJinaForTest(srv.URL) + apiKey := "test-key" + emptyRegion := "" + _, err := j.ChatWithMessages("jina-vlm", []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey, Region: &emptyRegion}, nil) + if err != nil { + t.Errorf("empty Region: expected fallback to default, got %v", err) + } +} diff --git a/internal/entity/models/lmstudio.go b/internal/entity/models/lmstudio.go index 89a40e4685b..d5d60efb37f 100644 --- a/internal/entity/models/lmstudio.go +++ b/internal/entity/models/lmstudio.go @@ -3,6 +3,7 @@ package models import ( "bufio" "bytes" + "context" "encoding/json" "fmt" "io" @@ -361,14 +362,119 @@ func (l *LmStudioModel) ChatStreamlyWithSender(modelName string, messages []Mess return scanner.Err() } -func (l *LmStudioModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { - return nil, fmt.Errorf("no such method") +func (l *LmStudioModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := l.BaseURL[region] + if baseURL == "" { + baseURL = l.BaseURL["default"] + } + if baseURL == "" { + return nil, fmt.Errorf("missing base URL: please configure the local access address for LM Studio (e.g., http://127.0.0.1:1234/v1)") + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), l.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + } + + resp, err := l.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("LM Studio embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed openaiEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + var embeddings []EmbeddingData + for _, dataElem := range parsed.Data { + var embeddingData EmbeddingData + embeddingData.Embedding = dataElem.Embedding + embeddingData.Index = dataElem.Index + embeddings = append(embeddings, embeddingData) + } + + return embeddings, nil } func (l *LmStudioModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("no such method") } +// TranscribeAudio transcribe audio +func (z *LmStudioModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *LmStudioModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (z *LmStudioModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *LmStudioModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (l *LmStudioModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", l.Name()) +} + +// ParseFile parse file +func (z *LmStudioModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + // ListModels list supported models func (l *LmStudioModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" @@ -447,3 +553,11 @@ func (l *LmStudioModel) CheckConnection(apiConfig *APIConfig) error { _, err := l.ListModels(apiConfig) return err } + +func (z *LmStudioModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *LmStudioModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/localai.go b/internal/entity/models/localai.go new file mode 100644 index 00000000000..b218709ba19 --- /dev/null +++ b/internal/entity/models/localai.go @@ -0,0 +1,837 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" +) + +// localAIStreamIdleTimeout bounds how long ChatStreamlyWithSender will +// wait between SSE chunks before assuming the upstream has stalled and +// aborting the request. A local LLM normally emits at least one token +// every few seconds; 60s is generous enough to never break a working +// stream but tight enough to bound a worst-case mid-body hang. +// +// var (not const) so tests can lower it without waiting a real minute. +var localAIStreamIdleTimeout = 60 * time.Second + +// LocalAIModel implements ModelDriver for LocalAI, a self-hosted +// OpenAI-compatible inference server (https://localai.io). +// +// Unlike cloud providers, LocalAI runs on a tenant-supplied base URL +// (for example http://127.0.0.1:8080/v1). The driver therefore reads +// the base URL from the per-instance map at call time and does not +// assume a "default" entry. The API key is optional: LocalAI accepts +// an empty key by default, and the driver only sets the Authorization +// header when a non-empty key was supplied. +type LocalAIModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +// NewLocalAIModel creates a new LocalAI model instance. +// +// We clone http.DefaultTransport so we keep Go's defaults for +// ProxyFromEnvironment, DialContext (with KeepAlive), HTTP/2, +// TLSHandshakeTimeout, and ExpectContinueTimeout, and only override +// the connection-pool fields we care about. +// +// The Client itself has no Timeout. http.Client.Timeout would also +// cap the time spent reading the response body, which would cut off +// long-lived SSE streams in ChatStreamlyWithSender. Non-streaming +// callers wrap each request with context.WithTimeout instead. +func NewLocalAIModel(baseURL map[string]string, urlSuffix URLSuffix) *LocalAIModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &LocalAIModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (l *LocalAIModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewLocalAIModel(baseURL, l.URLSuffix) +} + +func (l *LocalAIModel) Name() string { + return "localai" +} + +// resolveBaseURL returns the tenant-supplied base URL for the given +// region, falling back to the "default" entry, and fails with a clear +// message when nothing is configured. LocalAI is self-hosted so the +// driver cannot fall back to a public endpoint. +func (l *LocalAIModel) resolveBaseURL(region string) (string, error) { + if base, ok := l.BaseURL[region]; ok && base != "" { + return strings.TrimSuffix(base, "/"), nil + } + if base, ok := l.BaseURL["default"]; ok && base != "" { + return strings.TrimSuffix(base, "/"), nil + } + return "", fmt.Errorf("localai: missing base URL, configure the local access address (e.g., http://127.0.0.1:8080/v1)") +} + +// setAuth sets the Authorization header only when a non-empty API key +// is supplied. LocalAI accepts an empty key by default, so sending +// "Bearer " (with an empty value) would be wrong in both directions: +// some local proxies reject it, and it leaks the fact that the +// driver was misconfigured. +func setLocalAIAuth(req *http.Request, apiConfig *APIConfig) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) +} + +// localAIReasoningFields lists the JSON field names that different +// upstream models put their chain-of-thought into. LocalAI is a proxy +// that can route to any of these, so the driver tries each in turn: +// +// - reasoning_content: OpenAI o-series, kimi-k2.6, DeepSeek-R1, +// magistral when proxied through an OpenAI-shim +// - reasoning: Upstage solar-pro3 (and its proxies) +// - thinking: Qwen3 (Ollama-style) and some local llama-r1 +// variants exposed through LocalAI's OpenAI shim +// +// The first non-empty match wins. Order matters: reasoning_content is +// the OpenAI-conformant name and the most widely used, so it's tried +// first. +var localAIReasoningFields = []string{"reasoning_content", "reasoning", "thinking"} + +// extractLocalAIReasoning pulls the chain-of-thought out of a message +// or delta object regardless of which field name the upstream model +// chose. Returns "" when no reasoning field is present or non-string. +func extractLocalAIReasoning(m map[string]interface{}) string { + for _, k := range localAIReasoningFields { + if v, ok := m[k].(string); ok && v != "" { + return v + } + } + return "" +} + +// addLocalAIReasoningRequestParams propagates the caller's request-side +// reasoning intent into the body. Different upstream models behind +// LocalAI accept different parameters: +// +// - reasoning_effort: OpenAI-compatible reasoning APIs (kimi, magistral, +// solar-pro2/pro3, gpt-o-series, R1 proxies) +// - enable_thinking: Qwen3 explicit thinking toggle +// +// Both are emitted when the caller opts in, so the request works +// against whichever family the LocalAI instance routes to. A non- +// supporting upstream simply ignores the extra field. +func addLocalAIReasoningRequestParams(reqBody map[string]interface{}, cfg *ChatConfig) { + if cfg == nil { + return + } + if cfg.Effort != nil && *cfg.Effort != "" { + reqBody["reasoning_effort"] = *cfg.Effort + } + if cfg.Thinking != nil { + reqBody["enable_thinking"] = *cfg.Thinking + } +} + +// ChatWithMessages sends multiple messages with roles and returns the response. +func (l *LocalAIModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := l.resolveBaseURL(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + } + + // Note: do NOT propagate chatModelConfig.Stream into the request body + // here. ChatWithMessages parses a single JSON response, so stream must + // always be off for this code path. + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + // LocalAI is a proxy; emit both reasoning_effort and + // enable_thinking so the request works regardless of which + // model family the LocalAI instance routes to. See + // addLocalAIReasoningRequestParams. + addLocalAIReasoningRequestParams(reqBody, chatModelConfig) + } + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + setLocalAIAuth(req, apiConfig) + + resp, err := l.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + // Pull the chain-of-thought from whichever field the upstream model + // used. See localAIReasoningFields for the priority order. + reasonContent := extractLocalAIReasoning(messageMap) + + return &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + }, nil +} + +// ChatStreamlyWithSender sends messages and streams the response via the +// sender function. The LocalAI SSE stream uses the same shape as OpenAI: +// "data:" lines carrying JSON events, with a final "[DONE]" line. +func (l *LocalAIModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := l.resolveBaseURL(region) + if err != nil { + return err + } + url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + } + + if chatModelConfig != nil { + // Refuse to run if the caller explicitly asked for stream=false. + // The body of this method only knows how to read SSE, so a + // non-SSE JSON response would be parsed as if it were a stream + // and produce no chunks. Better to fail clearly. + if chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + // LocalAI is a proxy; emit both reasoning_effort and + // enable_thinking so the streaming request works regardless of + // which model family the LocalAI instance routes to. + addLocalAIReasoningRequestParams(reqBody, chatModelConfig) + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + // SSE streams are long-lived, so we cannot attach a hard deadline: + // a legitimate response may take many minutes to finish on a busy + // local model. Instead, wrap the request with WithCancel and run + // an idle watchdog below that calls cancel() if no new data has + // arrived for streamIdleTimeout. That bounds the worst-case stall + // to a known finite window without breaking working long streams. + // + // Threading a real caller-supplied ctx through the ModelDriver + // interface remains a wider follow-up; this is the contained fix. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + setLocalAIAuth(req, apiConfig) + + resp, err := l.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // Idle watchdog: every successful Scan resets lastActive. If + // streamIdleTimeout passes without a reset, the watchdog calls + // cancel(), which closes the underlying connection. The blocking + // scanner.Scan() then returns false with the context error in + // scanner.Err(), and we surface it to the caller instead of + // hanging the goroutine forever. + lastActive := time.Now() + var lastActiveMu sync.Mutex + done := make(chan struct{}) + defer close(done) + go func() { + ticker := time.NewTicker(localAIStreamIdleTimeout / 4) + defer ticker.Stop() + for { + select { + case <-done: + return + case now := <-ticker.C: + lastActiveMu.Lock() + idle := now.Sub(lastActive) + lastActiveMu.Unlock() + if idle >= localAIStreamIdleTimeout { + cancel() + return + } + } + } + }() + + // SSE parsing: bump the scanner buffer from the 64KB default to 1MB + // so we never silently truncate a long data: line. + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + sawTerminal := false + for scanner.Scan() { + lastActiveMu.Lock() + lastActive = time.Now() + lastActiveMu.Unlock() + + line := scanner.Text() + + if !strings.HasPrefix(line, "data:") { + continue + } + + data := strings.TrimSpace(line[5:]) + + if data == "[DONE]" { + sawTerminal = true + break + } + + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + // Reasoning chunk first, content second. When an SSE event + // carries both, callers that pipe them to a UI render the + // chain-of-thought before the answer for that token, matching + // the wire ordering Upstage solar-pro3 and kimi-k2.6 emit. + // extractLocalAIReasoning tries reasoning_content, reasoning, + // and thinking in that order so this works against whichever + // model family LocalAI routes to. + if reasoning := extractLocalAIReasoning(delta); reasoning != "" { + if err := sender(nil, &reasoning); err != nil { + return err + } + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + sawTerminal = true + break + } + } + + if err := scanner.Err(); err != nil { + // If the watchdog fired, the context is done; surface that as + // a clearer "idle" error instead of leaking the raw + // "context canceled" string. + if ctx.Err() != nil { + return fmt.Errorf("localai: stream idle for more than %s, aborted", localAIStreamIdleTimeout) + } + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawTerminal { + return fmt.Errorf("localai: stream ended before [DONE] or finish_reason") + } + + endOfStream := "[DONE]" + if err := sender(&endOfStream, nil); err != nil { + return err + } + + return nil +} + +type localAIEmbeddingData struct { + Embedding []float64 `json:"embedding"` + Object string `json:"object"` + Index int `json:"index"` +} + +type localAIEmbeddingResponse struct { + Data []localAIEmbeddingData `json:"data"` + Model string `json:"model"` + Object string `json:"object"` +} + +// Embed turns a list of texts into embedding vectors using the LocalAI +// /v1/embeddings endpoint. The output has one vector per input, in the +// same order the inputs were given. +func (l *LocalAIModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := l.resolveBaseURL(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + setLocalAIAuth(req, apiConfig) + + resp, err := l.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("LocalAI embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed localAIEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Reorder by the reported index so the output always lines up with + // the input texts, even if the upstream API ever returns items out + // of order. A nil slot at the end indicates the upstream did not + // return an embedding for that input. + embeddings := make([]EmbeddingData, len(texts)) + filled := make([]bool, len(texts)) + for _, item := range parsed.Data { + if item.Index < 0 || item.Index >= len(texts) { + return nil, fmt.Errorf("localai: response index %d out of range for %d inputs", item.Index, len(texts)) + } + if filled[item.Index] { + // A malformed response that repeats the same index would + // silently overwrite the earlier vector. Fail loudly so + // the caller never uses ambiguous output. + return nil, fmt.Errorf("localai: duplicate embedding index %d in response", item.Index) + } + embeddings[item.Index] = EmbeddingData{ + Embedding: item.Embedding, + Index: item.Index, + } + filled[item.Index] = true + } + for i, ok := range filled { + if !ok { + return nil, fmt.Errorf("localai: missing embedding for input index %d", i) + } + } + + return embeddings, nil +} + +type localAIRerankRequest struct { + Model string `json:"model"` + Query string `json:"query"` + Documents []string `json:"documents"` + TopN int `json:"top_n"` +} + +type localAIRerankResponse struct { + Results []struct { + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` + } `json:"results"` +} + +// Rerank calculates similarity scores between a query and a list of documents +// using LocalAI's /v1/rerank endpoint. The response shape is Cohere-style: +// {results: [{index, relevance_score}]}. The output is copied into the shared +// RerankResponse{Data: []RerankResult{Index, RelevanceScore}} shape that the +// rest of the codebase already consumes. +func (l *LocalAIModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + if len(documents) == 0 { + return &RerankResponse{}, nil + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := l.resolveBaseURL(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Rerank) + + topN := len(documents) + if rerankConfig != nil && rerankConfig.TopN > 0 { + topN = rerankConfig.TopN + } + + reqBody := localAIRerankRequest{ + Model: *modelName, + Query: query, + Documents: documents, + TopN: topN, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + setLocalAIAuth(req, apiConfig) + + resp, err := l.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("LocalAI rerank API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed localAIRerankResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + rerankResponse := &RerankResponse{} + for _, r := range parsed.Results { + if r.Index < 0 || r.Index >= len(documents) { + return nil, fmt.Errorf("localai: rerank result index %d out of range for %d documents", r.Index, len(documents)) + } + rerankResponse.Data = append(rerankResponse.Data, RerankResult{ + Index: r.Index, + RelevanceScore: r.RelevanceScore, + }) + } + + return rerankResponse, nil +} + +// ListModels returns the list of model ids the running LocalAI instance has +// loaded. There is no fixed model list at the SaaS level because LocalAI is +// self-hosted; the answer depends on what the tenant has configured. +func (l *LocalAIModel) ListModels(apiConfig *APIConfig) ([]string, error) { + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := l.resolveBaseURL(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Models) + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + setLocalAIAuth(req, apiConfig) + + resp, err := l.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + data, ok := result["data"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid models list format") + } + + models := make([]string, 0) + for _, model := range data { + modelMap, ok := model.(map[string]interface{}) + if !ok { + continue + } + modelName, ok := modelMap["id"].(string) + if !ok { + continue + } + models = append(models, modelName) + } + + return models, nil +} + +// Balance is not exposed by LocalAI (it is self-hosted and free), so this +// returns "no such method". +func (l *LocalAIModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("no such method") +} + +// CheckConnection runs a lightweight ListModels call to verify the LocalAI +// base URL is reachable. +func (l *LocalAIModel) CheckConnection(apiConfig *APIConfig) error { + _, err := l.ListModels(apiConfig) + if err != nil { + return err + } + return nil +} + +// TranscribeAudio (ASR): LocalAI can route audio to a Whisper backend +// when one is loaded, but the wire shape and driver-side plumbing for +// streaming audio uploads is separate from this PR's scope. Stub here +// to satisfy the ModelDriver interface; follow-up PR welcome. +func (l *LocalAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", l.Name()) +} + +func (l *LocalAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", l.Name()) +} + +// AudioSpeech convert text to audio +func (l *LocalAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", l.Name()) +} + +func (l *LocalAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", l.Name()) +} + +// OCRFile: LocalAI has no OCR pipeline in its OpenAI-compatible surface; +// document parsing belongs to a different interface entirely. +func (l *LocalAIModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", l.Name()) +} + +// ParseFile parse file +func (z *LocalAIModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *LocalAIModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *LocalAIModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/localai_test.go b/internal/entity/models/localai_test.go new file mode 100644 index 00000000000..6f8739b0221 --- /dev/null +++ b/internal/entity/models/localai_test.go @@ -0,0 +1,626 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync" + "testing" + "time" +) + +func newLocalAIForTest(baseURL string) *LocalAIModel { + return NewLocalAIModel( + map[string]string{"default": baseURL}, + URLSuffix{ + Chat: "chat/completions", + Models: "models", + Embedding: "embeddings", + Rerank: "rerank", + }, + ) +} + +// withLocalAIIdleTimeout swaps the package-level idle timeout for the +// duration of the test. Tests that exercise the stall watchdog use a +// sub-second value so they finish quickly. +func withLocalAIIdleTimeout(t *testing.T, d time.Duration) { + t.Helper() + original := localAIStreamIdleTimeout + localAIStreamIdleTimeout = d + t.Cleanup(func() { + localAIStreamIdleTimeout = original + }) +} + +func TestLocalAIName(t *testing.T) { + l := newLocalAIForTest("http://unused") + if got := l.Name(); got != "localai" { + t.Errorf("Name()=%q, want %q", got, "localai") + } +} + +func TestLocalAIStreamCancelsOnIdle(t *testing.T) { + // The server emits one valid chunk and then stalls. Without the + // watchdog, scanner.Scan() would hang forever. With the watchdog + // at 200ms, it must return a clear "stream idle" error in well + // under a second. + withLocalAIIdleTimeout(t, 200*time.Millisecond) + + // hold blocks the handler until the test closes it. Register + // close(hold) FIRST so it runs LAST (defers are LIFO) — wait, + // that's the opposite. We want close(hold) to run BEFORE + // srv.Close() so the handler can return. Use t.Cleanup, which + // runs in reverse-registration order: register srv.Close first + // so it runs last, then close(hold) so it runs first. + hold := make(chan struct{}) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + _, _ = io.WriteString(w, `data: {"choices":[{"delta":{"content":"hi"}}]}`+"\n") + f.Flush() + } + // Hang until either the client disconnects (watchdog cancels + // the request, which causes r.Context() to fire) or the test + // teardown signals via `hold`. + select { + case <-hold: + case <-r.Context().Done(): + } + })) + t.Cleanup(srv.Close) + t.Cleanup(func() { close(hold) }) + + l := newLocalAIForTest(srv.URL) + var got []string + var mu sync.Mutex + start := time.Now() + err := l.ChatStreamlyWithSender("gpt-4", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{}, nil, + func(content *string, _ *string) error { + if content == nil || *content == "" { + return nil + } + mu.Lock() + got = append(got, *content) + mu.Unlock() + return nil + }, + ) + elapsed := time.Since(start) + + if err == nil { + t.Fatal("expected an idle-timeout error, got nil") + } + if !strings.Contains(err.Error(), "idle for more than") { + t.Errorf("expected idle-timeout error, got %v", err) + } + if elapsed > 5*time.Second { + t.Errorf("watchdog did not fire promptly; elapsed=%v", elapsed) + } + mu.Lock() + defer mu.Unlock() + if len(got) == 0 || got[0] != "hi" { + t.Errorf("expected first chunk before stall, got %v", got) + } +} + +func TestLocalAIStreamCompletesWithoutTriggeringWatchdog(t *testing.T) { + // Sanity check: a fast, complete stream should not trip the + // watchdog even with a moderately tight idle window. + withLocalAIIdleTimeout(t, 500*time.Millisecond) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + f, _ := w.(http.Flusher) + _, _ = io.WriteString(w, + `data: {"choices":[{"delta":{"content":"a"}}]}`+"\n"+ + `data: {"choices":[{"delta":{"content":"b"}}]}`+"\n"+ + `data: {"choices":[{"delta":{},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + if f != nil { + f.Flush() + } + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + var chunks []string + err := l.ChatStreamlyWithSender("gpt-4", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{}, nil, + func(content *string, _ *string) error { + if content != nil && *content != "" && *content != "[DONE]" { + chunks = append(chunks, *content) + } + return nil + }, + ) + if err != nil { + t.Fatalf("stream: %v", err) + } + if strings.Join(chunks, "") != "ab" { + t.Errorf("chunks=%v want [a b]", chunks) + } +} + +func TestLocalAIStreamRequiresSender(t *testing.T) { + l := newLocalAIForTest("http://unused") + err := l.ChatStreamlyWithSender("gpt-4", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{}, nil, nil) + if err == nil || !strings.Contains(err.Error(), "sender is required") { + t.Errorf("expected sender-required error, got %v", err) + } +} + +func TestLocalAIChatMissingBaseURLFailsClearly(t *testing.T) { + // LocalAI has no public default; resolveBaseURL must fail with a + // helpful message when neither the requested region nor "default" + // is configured. + l := NewLocalAIModel(map[string]string{}, URLSuffix{Chat: "chat/completions"}) + _, err := l.ChatWithMessages("gpt-4", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "missing base URL") { + t.Errorf("expected missing-base-URL error, got %v", err) + } +} + +func TestLocalAIChatOmitsAuthHeaderWhenKeyEmpty(t *testing.T) { + // Optional-auth contract: LocalAI accepts an empty key, so the + // driver must NOT send a "Bearer " header in that case. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "" { + t.Errorf("expected no Authorization header, got %q", got) + } + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`) + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + resp, err := l.ChatWithMessages("gpt-4", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{}, nil) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if *resp.Answer != "ok" { + t.Errorf("answer=%q want ok", *resp.Answer) + } +} + +func TestLocalAIChatSendsAuthHeaderWhenKeyProvided(t *testing.T) { + // And conversely: when a tenant has put LocalAI behind an auth + // proxy with a token, the driver does send the Bearer header. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer secret" { + t.Errorf("expected Authorization=Bearer secret, got %q", got) + } + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`) + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + key := "secret" + _, err := l.ChatWithMessages("gpt-4", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &key}, nil) + if err != nil { + t.Fatalf("Chat: %v", err) + } +} + +func TestLocalAIBalanceReturnsNoSuchMethod(t *testing.T) { + l := newLocalAIForTest("http://unused") + _, err := l.Balance(&APIConfig{}) + if err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Balance: expected 'no such method', got %v", err) + } +} + +func TestLocalAIEmbedHappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/embeddings" { + t.Errorf("path=%s", r.URL.Path) + } + _, _ = io.WriteString(w, `{"data":[ + {"embedding":[0.1,0.2],"index":0}, + {"embedding":[0.3,0.4],"index":1}, + {"embedding":[0.5,0.6],"index":2}]}`) + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + model := "text-embedding-ada-002" + vecs, err := l.Embed(&model, []string{"a", "b", "c"}, &APIConfig{}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } + if len(vecs) != 3 { + t.Fatalf("len=%d want 3", len(vecs)) + } + if vecs[1].Embedding[0] != 0.3 || vecs[1].Index != 1 { + t.Errorf("vecs[1]=%+v", vecs[1]) + } +} + +func TestLocalAIEmbedRejectsDuplicateIndex(t *testing.T) { + // CodeRabbit caught that a response repeating data[*].index would + // silently overwrite the earlier vector. Verify the driver fails + // loudly instead. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"data":[ + {"embedding":[1],"index":0}, + {"embedding":[2],"index":0}]}`) + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + model := "text-embedding-ada-002" + _, err := l.Embed(&model, []string{"a", "b"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "duplicate embedding index 0") { + t.Errorf("expected duplicate-index error, got %v", err) + } +} + +func TestLocalAIEmbedRejectsOutOfRangeIndex(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"data":[{"embedding":[1],"index":7}]}`) + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + model := "text-embedding-ada-002" + _, err := l.Embed(&model, []string{"a", "b"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "out of range") { + t.Errorf("expected out-of-range error, got %v", err) + } +} + +func TestLocalAIEmbedRejectsMissingSlot(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"data":[{"embedding":[1],"index":0}]}`) + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + model := "text-embedding-ada-002" + _, err := l.Embed(&model, []string{"a", "b"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "missing embedding for input index 1") { + t.Errorf("expected missing-slot error, got %v", err) + } +} + +func TestLocalAIEmbedEmptyInputShortCircuits(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + t.Error("Embed([]) made an unexpected HTTP call") + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + model := "text-embedding-ada-002" + vecs, err := l.Embed(&model, []string{}, &APIConfig{}, nil) + if err != nil || len(vecs) != 0 { + t.Errorf("Embed([])=(%v,%v) want ([],nil)", vecs, err) + } +} + +// ---------- reasoning extraction (multi-field) ---------- + +// Table-driven unit coverage of the helper. Mirrors the priority order +// reasoning_content > reasoning > thinking declared in +// localAIReasoningFields. New upstream field names can be added by +// extending that slice without touching call sites. +func TestExtractLocalAIReasoning(t *testing.T) { + cases := []struct { + name string + in map[string]interface{} + want string + }{ + {"empty map", map[string]interface{}{}, ""}, + {"reasoning_content wins", map[string]interface{}{ + "reasoning_content": "rc", + "reasoning": "r", + "thinking": "t", + }, "rc"}, + {"reasoning when no reasoning_content", map[string]interface{}{ + "reasoning": "r", + "thinking": "t", + }, "r"}, + {"thinking when only that is set", map[string]interface{}{ + "thinking": "qwen3-thought", + }, "qwen3-thought"}, + {"empty string treated as absent", map[string]interface{}{ + "reasoning_content": "", + "reasoning": "fallback", + }, "fallback"}, + {"non-string ignored", map[string]interface{}{ + "reasoning_content": 42, + "reasoning": "fallback", + }, "fallback"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := extractLocalAIReasoning(tc.in) + if got != tc.want { + t.Errorf("got=%q want=%q", got, tc.want) + } + }) + } +} + +// Non-streaming chat against an upstream that puts the trace in +// message.reasoning_content (kimi-k2.6, OpenAI o-series, DeepSeek-R1 +// when proxied through OpenAI-shim). The driver must surface it on +// ChatResponse.ReasonContent. +func TestLocalAIChatExtractsReasoningContent(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"choices":[{"message":{ + "role":"assistant", + "content":"The answer is 12.", + "reasoning_content":"15% = 0.15; 0.15 * 80 = 12." + }}]}`) + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + resp, err := l.ChatWithMessages("kimi-k2.6", + []Message{{Role: "user", Content: "15% of 80?"}}, + &APIConfig{}, nil) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if *resp.Answer != "The answer is 12." { + t.Errorf("Answer=%q", *resp.Answer) + } + if *resp.ReasonContent != "15% = 0.15; 0.15 * 80 = 12." { + t.Errorf("ReasonContent=%q", *resp.ReasonContent) + } +} + +// Non-streaming chat that uses message.thinking (Qwen3 via Ollama-shim +// inside LocalAI). The driver must surface it on ReasonContent too. +func TestLocalAIChatExtractsThinking(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"choices":[{"message":{ + "role":"assistant", + "content":"12", + "thinking":"Compute 15/100 * 80" + }}]}`) + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + resp, err := l.ChatWithMessages("qwen3-32b", + []Message{{Role: "user", Content: "15% of 80?"}}, + &APIConfig{}, nil) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if *resp.ReasonContent != "Compute 15/100 * 80" { + t.Errorf("ReasonContent=%q want %q", *resp.ReasonContent, "Compute 15/100 * 80") + } +} + +// Regression net: a response with no reasoning field at all (any +// non-reasoning model) must produce empty ReasonContent without +// crashing or erroring. +func TestLocalAIChatHandlesAbsentReasoning(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"choices":[{"message":{ + "role":"assistant","content":"hello" + }}]}`) + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + resp, err := l.ChatWithMessages("llama-3-8b-instruct", + []Message{{Role: "user", Content: "hi"}}, + &APIConfig{}, nil) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if *resp.Answer != "hello" { + t.Errorf("Answer=%q", *resp.Answer) + } + if *resp.ReasonContent != "" { + t.Errorf("ReasonContent=%q want empty", *resp.ReasonContent) + } +} + +// Streaming chat where the upstream interleaves delta.reasoning_content +// chunks and delta.content chunks (kimi-k2.6, o-series shape). +// Reasoning must reach the sender's 2nd arg, content the 1st. +func TestLocalAIStreamExtractsReasoningContentDelta(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, + `data: {"choices":[{"index":0,"delta":{"role":"assistant"}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"reasoning_content":"step 1. "}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"reasoning_content":"step 2."}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"content":"Answer."},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + var content, reasoning []string + err := l.ChatStreamlyWithSender("kimi-k2.6", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{}, nil, + func(c *string, r *string) error { + if c != nil && r != nil { + t.Errorf("sender called with both args non-nil") + } + if r != nil && *r != "" { + reasoning = append(reasoning, *r) + } + if c != nil && *c != "" && *c != "[DONE]" { + content = append(content, *c) + } + return nil + }, + ) + if err != nil { + t.Fatalf("stream: %v", err) + } + if got := strings.Join(reasoning, ""); got != "step 1. step 2." { + t.Errorf("reasoning joined=%q", got) + } + if got := strings.Join(content, ""); got != "Answer." { + t.Errorf("content joined=%q", got) + } +} + +// Streaming chat where the upstream uses delta.thinking (Qwen3 shape). +// The same handler must work. +func TestLocalAIStreamExtractsThinkingDelta(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, + `data: {"choices":[{"index":0,"delta":{"thinking":"qwen-trace"}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"content":"final"},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + var got []string + err := l.ChatStreamlyWithSender("qwen3-32b", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{}, nil, + func(c *string, r *string) error { + if r != nil && *r != "" { + got = append(got, "R:"+*r) + } + if c != nil && *c != "" && *c != "[DONE]" { + got = append(got, "C:"+*c) + } + return nil + }, + ) + if err != nil { + t.Fatalf("stream: %v", err) + } + want := []string{"R:qwen-trace", "C:final"} + if len(got) != 2 || got[0] != want[0] || got[1] != want[1] { + t.Errorf("seq=%v want %v", got, want) + } +} + +// Request-side: ChatConfig.Effort must flow into request body as +// reasoning_effort. +func TestLocalAIChatPropagatesReasoningEffort(t *testing.T) { + var seen map[string]interface{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + if err := json.Unmarshal(raw, &seen); err != nil { + t.Errorf("unmarshal request body: %v\nraw=%s", err, string(raw)) + return + } + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`) + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + effort := "high" + _, err := l.ChatWithMessages("kimi-k2.6", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{}, &ChatConfig{Effort: &effort}) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if seen["reasoning_effort"] != "high" { + t.Errorf("reasoning_effort=%v want high", seen["reasoning_effort"]) + } + if _, present := seen["enable_thinking"]; present { + t.Errorf("enable_thinking should be absent when Thinking nil") + } +} + +// Request-side: ChatConfig.Thinking must flow into request body as +// enable_thinking (Qwen3-style). +func TestLocalAIChatPropagatesEnableThinking(t *testing.T) { + var seen map[string]interface{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + if err := json.Unmarshal(raw, &seen); err != nil { + t.Errorf("unmarshal request body: %v\nraw=%s", err, string(raw)) + return + } + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`) + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + think := true + _, err := l.ChatWithMessages("qwen3-32b", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{}, &ChatConfig{Thinking: &think}) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if seen["enable_thinking"] != true { + t.Errorf("enable_thinking=%v want true", seen["enable_thinking"]) + } +} + +// Stream request also propagates the reasoning params. +func TestLocalAIStreamPropagatesReasoningParams(t *testing.T) { + var seen map[string]interface{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + if err := json.Unmarshal(raw, &seen); err != nil { + t.Errorf("unmarshal request body: %v\nraw=%s", err, string(raw)) + return + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, + `data: {"choices":[{"index":0,"delta":{"content":"x"},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + })) + defer srv.Close() + + l := newLocalAIForTest(srv.URL) + effort := "medium" + think := true + err := l.ChatStreamlyWithSender("kimi-k2.6", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{}, &ChatConfig{Effort: &effort, Thinking: &think}, + func(*string, *string) error { return nil }, + ) + if err != nil { + t.Fatalf("stream: %v", err) + } + if seen["reasoning_effort"] != "medium" { + t.Errorf("reasoning_effort=%v want medium", seen["reasoning_effort"]) + } + if seen["enable_thinking"] != true { + t.Errorf("enable_thinking=%v want true", seen["enable_thinking"]) + } +} diff --git a/internal/entity/models/longcat.go b/internal/entity/models/longcat.go new file mode 100644 index 00000000000..e35dea19669 --- /dev/null +++ b/internal/entity/models/longcat.go @@ -0,0 +1,477 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// LongCatModel implements ModelDriver for LongCat (Meituan). +// +// LongCat exposes an OpenAI-compatible chat completions endpoint at +// https://api.longcat.chat/openai/v1/chat/completions. The official +// docs (https://longcat.chat/platform/docs/APIDocs.html) only describe +// the chat-completions surface — no /models, /embeddings, /rerank, +// /audio, or /ocr endpoints are advertised. The wire shape matches the +// OpenAI convention: response/delta carry reasoning_content alongside +// content for thinking models. +// +// Documented request fields are limited to: model, messages, stream, +// max_tokens, temperature, top_p. Sending other OpenAI-style fields +// (stop, reasoning_effort, etc.) is not documented and is therefore +// omitted to avoid relying on undocumented upstream behavior. +type LongCatModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +// NewLongCatModel creates a new LongCat model instance. +// +// We clone http.DefaultTransport so we keep Go's defaults for +// ProxyFromEnvironment, DialContext (with KeepAlive), HTTP/2, +// TLSHandshakeTimeout, and ExpectContinueTimeout, and only override +// the connection-pool fields we care about. +// +// The Client itself has no Timeout. http.Client.Timeout would also +// cap the time spent reading the response body, which would cut off +// long-lived SSE streams in ChatStreamlyWithSender. Non-streaming +// callers wrap each request with context.WithTimeout instead. +func NewLongCatModel(baseURL map[string]string, urlSuffix URLSuffix) *LongCatModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &LongCatModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (l *LongCatModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewLongCatModel(baseURL, l.URLSuffix) +} + +func (l *LongCatModel) Name() string { + return "longcat" +} + +// baseURLForRegion returns the base URL for the given region, or an +// error if no entry exists. This makes a misconfigured region fail +// fast with a clear message, instead of silently producing a relative +// URL that the HTTP transport then rejects. +func (l *LongCatModel) baseURLForRegion(region string) (string, error) { + base, ok := l.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("longcat: no base URL configured for region %q", region) + } + return base, nil +} + +// ChatWithMessages sends multiple messages with roles and returns the response. +func (l *LongCatModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := l.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + } + + // Note: do NOT propagate chatModelConfig.Stream into the request body + // here. ChatWithMessages parses a single JSON response, so stream must + // always be off for this code path. + // + // Only the fields documented at + // https://longcat.chat/platform/docs/APIDocs.html are forwarded. + // Other ChatConfig fields (Stop, Effort, ...) are dropped on the + // floor because the upstream behavior is undefined. + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := l.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + // LongCat-Flash-Thinking returns the chain-of-thought in a + // `reasoning_content` field on the message (OpenAI o-series shape, + // also used by kimi-k2.6 and DeepSeek-R1). Pass it through when + // present so callers can surface reasoning to the UI. Absent or + // non-string means no reasoning was emitted — leave it empty. + reasonContent := "" + if r, ok := messageMap["reasoning_content"].(string); ok { + reasonContent = r + } + + return &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + }, nil +} + +// ChatStreamlyWithSender sends messages and streams the response via the +// sender function. The LongCat SSE stream uses the same shape as the +// OpenAI o-series: "data:" lines carrying JSON events with +// delta.content for the visible answer and delta.reasoning_content for +// the chain-of-thought (LongCat-Flash-Thinking only), terminated by +// a [DONE] line. +func (l *LongCatModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("api key is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := l.baseURLForRegion(region) + if err != nil { + return err + } + url := fmt.Sprintf("%s/%s", baseURL, l.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + } + + if chatModelConfig != nil { + // Refuse to run if the caller explicitly asked for stream=false. + // The body of this method only knows how to read SSE, so a + // non-SSE JSON response would be parsed as if it were a stream + // and produce no chunks. Better to fail clearly. + if chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + + // Only documented fields are forwarded; see ChatWithMessages. + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + // SSE streams are long-lived. Rely on the transport's + // ResponseHeaderTimeout to cap the connection-establishment phase + // instead of attaching a hard deadline here. + req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := l.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: bump the scanner buffer from the 64KB default to 1MB + // so we never silently truncate a long data: line. + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + sawTerminal := false + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "data:") { + continue + } + + data := strings.TrimSpace(line[5:]) + + if data == "[DONE]" { + sawTerminal = true + break + } + + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + // A malformed frame can mean a truncated SSE event or an + // upstream incident; either way, the caller is better + // served by a hard failure than by silent partial output. + return fmt.Errorf("longcat: invalid SSE event: %w", err) + } + + // LongCat (like other OpenAI-compatible upstreams) can emit a + // terminal `{"error": ...}` frame instead of a normal choices + // chunk when something goes wrong mid-stream. Surface it + // instead of falling through to the choices-missing branch. + if apiErr, ok := event["error"]; ok { + return fmt.Errorf("longcat: upstream stream error: %v", apiErr) + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + // Reasoning chunks first, content second. When an SSE event + // carries both, callers that pipe them to a UI render the + // chain-of-thought before the answer for that token, matching + // the wire ordering LongCat-Flash-Thinking emits. + if r, ok := delta["reasoning_content"].(string); ok && r != "" { + if err := sender(nil, &r); err != nil { + return err + } + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + sawTerminal = true + break + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawTerminal { + return fmt.Errorf("longcat: stream ended before [DONE] or finish_reason") + } + + endOfStream := "[DONE]" + if err := sender(&endOfStream, nil); err != nil { + return err + } + + return nil +} + +// ListModels is not exposed by the LongCat platform. The official +// docs at https://longcat.chat/platform/docs/APIDocs.html only +// document /openai/v1/chat/completions and /anthropic/v1/messages; +// no /models endpoint exists. The shipped catalog lives in +// conf/models/longcat.json; this driver method does not invent a +// fake one. +func (l *LongCatModel) ListModels(apiConfig *APIConfig) ([]string, error) { + return nil, fmt.Errorf("%s, no such method", l.Name()) +} + +// CheckConnection is not exposed by the LongCat platform. With no +// documented /models or /health endpoint, there is no cheap way to +// verify the API key without burning a real chat completion against +// a tenant's quota. Return the documented sentinel rather than +// pretend. +func (l *LongCatModel) CheckConnection(apiConfig *APIConfig) error { + return fmt.Errorf("%s, no such method", l.Name()) +} + +// Embed is not exposed by the LongCat API. The /v1/embeddings endpoint +// does not exist on api.longcat.chat; this returns the documented +// sentinel. +func (l *LongCatModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + return nil, fmt.Errorf("%s, no such method", l.Name()) +} + +// Rerank is not exposed by the LongCat API. +func (l *LongCatModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s, no such method", l.Name()) +} + +// Balance is not exposed by the LongCat API. +func (l *LongCatModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("%s, no such method", l.Name()) +} + +// TranscribeAudio (ASR) is not exposed by the LongCat API. +func (l *LongCatModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", l.Name()) +} + +func (l *LongCatModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", l.Name()) +} + +// AudioSpeech (TTS) is not exposed by the LongCat API. +func (l *LongCatModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", l.Name()) +} + +func (l *LongCatModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", l.Name()) +} + +// OCRFile is not exposed by the LongCat API. +func (l *LongCatModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", l.Name()) +} + +// ParseFile parse file +func (z *LongCatModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *LongCatModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *LongCatModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/longcat_test.go b/internal/entity/models/longcat_test.go new file mode 100644 index 00000000000..14870f8f69e --- /dev/null +++ b/internal/entity/models/longcat_test.go @@ -0,0 +1,467 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newLongCatServer(t *testing.T, expectedPath string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != expectedPath { + t.Errorf("expected path=%s, got %s", expectedPath, r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("expected Authorization=Bearer test-key, got %q", got) + return + } + if r.Method == http.MethodPost { + // Accept "application/json" with or without a parameter + // suffix like "; charset=utf-8" — both are valid JSON. + if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") { + t.Errorf("expected Content-Type to start with application/json, got %q", got) + return + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("unmarshal: %v\nraw=%s", err, string(raw)) + return + } + handler(t, body, w) + return + } + handler(t, nil, w) + })) +} + +func newLongCatForTest(baseURL string) *LongCatModel { + return NewLongCatModel( + map[string]string{"default": baseURL}, + URLSuffix{Chat: "openai/v1/chat/completions"}, + ) +} + +// newLongCatSSEServer returns an httptest.Server that asserts the +// request contract (POST + path + Authorization + Content-Type prefix) +// before writing the supplied SSE payload. Used by the streaming tests +// so a regression in the wire shape can't slip through unnoticed. +func newLongCatSSEServer(t *testing.T, expectedPath, ssePayload string) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + return + } + if r.URL.Path != expectedPath { + t.Errorf("expected path=%s, got %s", expectedPath, r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("expected Authorization=Bearer test-key, got %q", got) + return + } + if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") { + t.Errorf("expected Content-Type to start with application/json, got %q", got) + return + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, ssePayload) + })) +} + +func TestLongCatName(t *testing.T) { + if got := newLongCatForTest("http://unused").Name(); got != "longcat" { + t.Errorf("Name()=%q, want %q", got, "longcat") + } +} + +func TestLongCatChatHappyPath(t *testing.T) { + srv := newLongCatServer(t, "/openai/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "LongCat-Flash-Chat" { + t.Errorf("model=%v", body["model"]) + } + if body["stream"] != false { + t.Errorf("stream=%v want false", body["stream"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{"content": "pong"}, + }}, + }) + }) + defer srv.Close() + + m := newLongCatForTest(srv.URL) + apiKey := "test-key" + resp, err := m.ChatWithMessages("LongCat-Flash-Chat", + []Message{{Role: "user", Content: "ping"}}, + &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if resp.Answer == nil || resp.ReasonContent == nil { + t.Fatalf("Answer/ReasonContent must be non-nil pointers, got Answer=%v ReasonContent=%v", resp.Answer, resp.ReasonContent) + } + if *resp.Answer != "pong" { + t.Errorf("answer=%q want pong", *resp.Answer) + } + if *resp.ReasonContent != "" { + t.Errorf("ReasonContent=%q want empty", *resp.ReasonContent) + } +} + +func TestLongCatChatExtractsReasoningContent(t *testing.T) { + // LongCat-Flash-Thinking returns the chain-of-thought in + // message.reasoning_content (OpenAI o-series shape). Live-probed + // against api.longcat.chat; the fixture mimics the actual response + // shape captured there. + srv := newLongCatServer(t, "/openai/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "LongCat-Flash-Thinking" { + t.Errorf("model=%v", body["model"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{ + "role": "assistant", + "content": "15% of 80 is 12.", + "reasoning_content": "We need to compute 15% of 80. 0.15 * 80 = 12.", + }, + }}, + }) + }) + defer srv.Close() + + m := newLongCatForTest(srv.URL) + apiKey := "test-key" + resp, err := m.ChatWithMessages("LongCat-Flash-Thinking", + []Message{{Role: "user", Content: "15% of 80?"}}, + &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if resp.Answer == nil || resp.ReasonContent == nil { + t.Fatalf("Answer/ReasonContent must be non-nil pointers, got Answer=%v ReasonContent=%v", resp.Answer, resp.ReasonContent) + } + if *resp.Answer != "15% of 80 is 12." { + t.Errorf("Answer=%q", *resp.Answer) + } + if *resp.ReasonContent != "We need to compute 15% of 80. 0.15 * 80 = 12." { + t.Errorf("ReasonContent=%q", *resp.ReasonContent) + } +} + +// TestLongCatChatDropsUndocumentedFields guards against re-introducing +// stop / reasoning_effort / response_format / tools etc. The LongCat +// docs only list model, messages, stream, max_tokens, temperature, +// top_p — anything else is undocumented and must not be sent, since +// the maintainer specifically flagged this on PR #14809. +func TestLongCatChatDropsUndocumentedFields(t *testing.T) { + srv := newLongCatServer(t, "/openai/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + for _, k := range []string{"stop", "reasoning_effort", "response_format", "tools", "tool_choice", "presence_penalty", "frequency_penalty", "n", "logprobs"} { + if _, present := body[k]; present { + t.Errorf("undocumented field %q must not be sent: %v", k, body[k]) + } + } + // Documented fields, on the other hand, MUST be forwarded when set. + for _, k := range []string{"model", "messages", "stream", "max_tokens", "temperature", "top_p"} { + if _, present := body[k]; !present { + t.Errorf("documented field %q missing from request body", k) + } + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{"content": "ok"}, + }}, + }) + }) + defer srv.Close() + + m := newLongCatForTest(srv.URL) + apiKey := "test-key" + mt := 32 + temp := 0.7 + topP := 0.9 + stop := []string{"END"} + effort := "high" + _, err := m.ChatWithMessages("LongCat-Flash-Chat", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + // Deliberately pass Stop/Effort to prove they are filtered out. + &ChatConfig{MaxTokens: &mt, Temperature: &temp, TopP: &topP, Stop: &stop, Effort: &effort}) + if err != nil { + t.Fatalf("Chat: %v", err) + } +} + +func TestLongCatChatRequiresAPIKey(t *testing.T) { + m := newLongCatForTest("http://unused") + _, err := m.ChatWithMessages("LongCat-Flash-Chat", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("expected api-key error, got %v", err) + } +} + +func TestLongCatChatRequiresMessages(t *testing.T) { + m := newLongCatForTest("http://unused") + apiKey := "test-key" + _, err := m.ChatWithMessages("LongCat-Flash-Chat", nil, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "messages is empty") { + t.Errorf("expected messages-empty error, got %v", err) + } +} + +func TestLongCatChatRejectsHTTPError(t *testing.T) { + srv := newLongCatServer(t, "/openai/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + }) + defer srv.Close() + + m := newLongCatForTest(srv.URL) + apiKey := "test-key" + _, err := m.ChatWithMessages("LongCat-Flash-Chat", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "401") { + t.Errorf("expected 401 propagated, got %v", err) + } +} + +func TestLongCatStreamHappyPath(t *testing.T) { + srv := newLongCatSSEServer(t, "/openai/v1/chat/completions", + `data: {"choices":[{"index":0,"delta":{"role":"assistant"}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"content":"Hello"}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"content":" world"},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + defer srv.Close() + + m := newLongCatForTest(srv.URL) + apiKey := "test-key" + var chunks []string + var sawDone bool + err := m.ChatStreamlyWithSender("LongCat-Flash-Chat", + []Message{{Role: "user", Content: "hi"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(c *string, _ *string) error { + if c == nil { + return nil + } + if *c == "[DONE]" { + sawDone = true + return nil + } + chunks = append(chunks, *c) + return nil + }) + if err != nil { + t.Fatalf("stream: %v", err) + } + if strings.Join(chunks, "") != "Hello world" { + t.Errorf("content=%v", chunks) + } + if !sawDone { + t.Error("expected [DONE] sentinel") + } +} + +func TestLongCatStreamExtractsReasoningContent(t *testing.T) { + // Fixture matches the shape captured live from + // LongCat-Flash-Thinking against api.longcat.chat: deltas + // interleave reasoning_content and content within the stream. + srv := newLongCatSSEServer(t, "/openai/v1/chat/completions", + `data: {"choices":[{"index":0,"delta":{"role":"assistant"}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"reasoning_content":"step 1. "}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"reasoning_content":"step 2."}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"content":"final answer"},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + defer srv.Close() + + m := newLongCatForTest(srv.URL) + apiKey := "test-key" + var content, reasoning []string + err := m.ChatStreamlyWithSender("LongCat-Flash-Thinking", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(c *string, r *string) error { + if c != nil && r != nil { + t.Errorf("sender called with both args non-nil") + } + if r != nil && *r != "" { + reasoning = append(reasoning, *r) + } + if c != nil && *c != "" && *c != "[DONE]" { + content = append(content, *c) + } + return nil + }) + if err != nil { + t.Fatalf("stream: %v", err) + } + if got := strings.Join(reasoning, ""); got != "step 1. step 2." { + t.Errorf("reasoning=%q", got) + } + if got := strings.Join(content, ""); got != "final answer" { + t.Errorf("content=%q", got) + } +} + +func TestLongCatStreamRejectsExplicitFalse(t *testing.T) { + m := newLongCatForTest("http://unused") + apiKey := "test-key" + stream := false + err := m.ChatStreamlyWithSender("LongCat-Flash-Chat", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{Stream: &stream}, + func(*string, *string) error { return nil }) + if err == nil || !strings.Contains(err.Error(), "stream must be true") { + t.Errorf("expected stream-true guard, got %v", err) + } +} + +func TestLongCatStreamRequiresSender(t *testing.T) { + m := newLongCatForTest("http://unused") + apiKey := "test-key" + err := m.ChatStreamlyWithSender("LongCat-Flash-Chat", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, nil) + if err == nil || !strings.Contains(err.Error(), "sender is required") { + t.Errorf("expected sender-required error, got %v", err) + } +} + +func TestLongCatStreamFailsWithoutTerminal(t *testing.T) { + srv := newLongCatSSEServer(t, "/openai/v1/chat/completions", + `data: {"choices":[{"delta":{"content":"half"}}]}`+"\n", + ) + defer srv.Close() + + m := newLongCatForTest(srv.URL) + apiKey := "test-key" + err := m.ChatStreamlyWithSender("LongCat-Flash-Chat", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(*string, *string) error { return nil }) + if err == nil || !strings.Contains(err.Error(), "stream ended before") { + t.Errorf("expected truncation error, got %v", err) + } +} + +// A malformed SSE frame (invalid JSON) used to be silently skipped, +// which masked truncated or corrupted streams. The driver must now +// fail hard with a "longcat: invalid SSE event" wrapper. +func TestLongCatStreamRejectsMalformedFrame(t *testing.T) { + srv := newLongCatSSEServer(t, "/openai/v1/chat/completions", + `data: {"choices":[{"delta":{"content":"ok"}}]}`+"\n"+ + `data: {this is not valid json}`+"\n", + ) + defer srv.Close() + + m := newLongCatForTest(srv.URL) + apiKey := "test-key" + err := m.ChatStreamlyWithSender("LongCat-Flash-Chat", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(*string, *string) error { return nil }) + if err == nil || !strings.Contains(err.Error(), "invalid SSE event") { + t.Errorf("expected invalid-SSE error, got %v", err) + } +} + +// An upstream {"error": ...} frame mid-stream used to fall through to +// the "no choices" continue and leave the caller with a generic +// truncation error. The driver must surface the upstream error verbatim. +func TestLongCatStreamSurfacesUpstreamError(t *testing.T) { + srv := newLongCatSSEServer(t, "/openai/v1/chat/completions", + `data: {"choices":[{"delta":{"content":"partial "}}]}`+"\n"+ + `data: {"error":{"message":"rate limit exceeded","type":"rate_limit_error"}}`+"\n", + ) + defer srv.Close() + + m := newLongCatForTest(srv.URL) + apiKey := "test-key" + err := m.ChatStreamlyWithSender("LongCat-Flash-Chat", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(*string, *string) error { return nil }) + if err == nil || !strings.Contains(err.Error(), "upstream stream error") { + t.Errorf("expected upstream-error surfacing, got %v", err) + } + if err != nil && !strings.Contains(err.Error(), "rate limit") { + t.Errorf("expected upstream message included, got %v", err) + } +} + +// LongCat does not document /models or /health endpoints, so per +// maintainer guidance ListModels and CheckConnection both return the +// "no such method" sentinel rather than inventing fake catalogs or +// burning chat completions for connection checks. +func TestLongCatListModelsReturnsNoSuchMethod(t *testing.T) { + apiKey := "test-key" + _, err := newLongCatForTest("http://unused").ListModels(&APIConfig{ApiKey: &apiKey}) + if err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("ListModels: want 'no such method', got %v", err) + } +} + +func TestLongCatCheckConnectionReturnsNoSuchMethod(t *testing.T) { + apiKey := "test-key" + err := newLongCatForTest("http://unused").CheckConnection(&APIConfig{ApiKey: &apiKey}) + if err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("CheckConnection: want 'no such method', got %v", err) + } +} + +func TestLongCatEmbedReturnsNoSuchMethod(t *testing.T) { + m := newLongCatForTest("http://unused") + model := "x" + _, err := m.Embed(&model, []string{"a"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Embed: want 'no such method', got %v", err) + } +} + +func TestLongCatRerankReturnsNoSuchMethod(t *testing.T) { + m := newLongCatForTest("http://unused") + model := "x" + _, err := m.Rerank(&model, "q", []string{"a"}, &APIConfig{}, &RerankConfig{TopN: 1}) + if err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Rerank: want 'no such method', got %v", err) + } +} + +func TestLongCatBalanceReturnsNoSuchMethod(t *testing.T) { + m := newLongCatForTest("http://unused") + _, err := m.Balance(&APIConfig{}) + if err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Balance: want 'no such method', got %v", err) + } +} + +func TestLongCatAudioOCRReturnNoSuchMethod(t *testing.T) { + m := newLongCatForTest("http://unused") + model := "x" + if _, err := m.TranscribeAudio(&model, &model, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("TranscribeAudio: want 'no such method', got %v", err) + } + if _, err := m.AudioSpeech(&model, &model, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("AudioSpeech: want 'no such method', got %v", err) + } + if _, err := m.OCRFile(&model, nil, &model, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("OCRFile: want 'no such method', got %v", err) + } +} diff --git a/internal/entity/models/mineru.go b/internal/entity/models/mineru.go new file mode 100644 index 00000000000..1ff4697db2f --- /dev/null +++ b/internal/entity/models/mineru.go @@ -0,0 +1,265 @@ +package models + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +type MinerUModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func NewMinerUModel(baseURL map[string]string, urlSuffix URLSuffix) *MinerUModel { + return &MinerUModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: time.Second * 120, + Transport: &http.Transport{ + MaxIdleConns: 10, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: time.Second * 90, + DisableCompression: false, + }, + }, + } +} + +func (m *MinerUModel) NewInstance(baseURL map[string]string) ModelDriver { + return &MinerUModel{ + BaseURL: baseURL, + URLSuffix: m.URLSuffix, + httpClient: &http.Client{ + Timeout: time.Second * 120, + Transport: &http.Transport{ + MaxIdleConns: 10, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: time.Second * 90, + DisableCompression: false, + }, + }, + } +} + +func (m *MinerUModel) Name() string { + return "mineru" +} + +func (m *MinerUModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerUModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerUModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerUModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerUModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerUModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerUModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerUModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerUModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerUModel) ListModels(apiConfig *APIConfig) ([]string, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerUModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerUModel) CheckConnection(apiConfig *APIConfig) error { + return fmt.Errorf("%s no such method", m.Name()) +} + +type mineruTaskSubmitResponse struct { + Code int `json:"code"` + Data struct { + TaskID string `json:"task_id"` + } `json:"data"` + Msg string `json:"msg"` + TraceID string `json:"trace_id"` +} + +func (m *MinerUModel) ParseFile(modelName *string, content []byte, documentURL *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + if documentURL == nil || *documentURL == "" { + return nil, fmt.Errorf("MinerU API requires a valid public document URL; direct file upload is not supported") + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + apiURL := fmt.Sprintf("%s/api/%s", m.BaseURL[region], m.URLSuffix.DocumentParse) + + reqBody := map[string]interface{}{ + "url": *documentURL, + } + + if modelName != nil && *modelName != "" { + reqBody["model_version"] = *modelName + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", apiURL, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("MinerU API failed with status %d: %s", resp.StatusCode, string(body)) + } + + var taskResp mineruTaskSubmitResponse + if err := json.Unmarshal(body, &taskResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if taskResp.Code != 0 { + return nil, fmt.Errorf("MinerU task creation failed (code %d): %s", taskResp.Code, taskResp.Msg) + } + + return &ParseFileResponse{ + TaskID: taskResp.Data.TaskID, + }, nil +} + +func (m *MinerUModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +type mineruTaskQueryResponse struct { + Code int `json:"code"` + Data struct { + TaskID string `json:"task_id"` + State string `json:"state"` // including: pending, running, done, failed, converting + FullZipURL string `json:"full_zip_url"` + ErrMsg string `json:"err_msg"` + ExtractProgress struct { + ExtractedPages int `json:"extracted_pages"` + TotalPages int `json:"total_pages"` + } `json:"extract_progress"` + } `json:"data"` + Msg string `json:"msg"` +} + +func (m *MinerUModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + // URL: https://mineru.net/api/v4/extract/task/{task_id} + apiURL := fmt.Sprintf("%s/api/%s/%s", m.BaseURL[region], m.URLSuffix.DocumentParse, taskID) + + req, err := http.NewRequest("GET", apiURL, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("MinerU query API failed with status %d: %s", resp.StatusCode, string(body)) + } + + var queryResp mineruTaskQueryResponse + if err := json.Unmarshal(body, &queryResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if queryResp.Code != 0 { + return nil, fmt.Errorf("MinerU task query failed: %s", queryResp.Msg) + } + + // failed state + if queryResp.Data.State == "failed" { + return nil, fmt.Errorf("MinerU task failed: %s", queryResp.Data.ErrMsg) + } + + content := "" + if queryResp.Data.State == "done" { + content = queryResp.Data.FullZipURL + } else if queryResp.Data.State == "running" { + content = fmt.Sprintf("Task is running... Progress: %d / %d pages", + queryResp.Data.ExtractProgress.ExtractedPages, + queryResp.Data.ExtractProgress.TotalPages) + } else { + // queue or formating + content = fmt.Sprintf("Task state: %s", queryResp.Data.State) + } + + return &TaskResponse{ + Segments: []TaskSegment{ + { + Index: 0, + Content: content, + }, + }, + }, nil +} diff --git a/internal/entity/models/mineru_local.go b/internal/entity/models/mineru_local.go new file mode 100644 index 00000000000..177ce775301 --- /dev/null +++ b/internal/entity/models/mineru_local.go @@ -0,0 +1,267 @@ +package models + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "time" +) + +type MinerULocalModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func NewMinerLocalUModel(baseURL map[string]string, urlSuffix URLSuffix) *MinerULocalModel { + return &MinerULocalModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: time.Second * 120, + Transport: &http.Transport{ + MaxIdleConns: 10, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: time.Second * 90, + DisableCompression: false, + }, + }, + } +} + +func (m *MinerULocalModel) NewInstance(baseURL map[string]string) ModelDriver { + return &MinerULocalModel{ + BaseURL: baseURL, + URLSuffix: m.URLSuffix, + httpClient: &http.Client{ + Timeout: time.Second * 120, + Transport: &http.Transport{ + MaxIdleConns: 10, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: time.Second * 90, + DisableCompression: false, + }, + }, + } +} + +func (m *MinerULocalModel) Name() string { + return "mineru_local" +} + +func (m *MinerULocalModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerULocalModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerULocalModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerULocalModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerULocalModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerULocalModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerULocalModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerULocalModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerULocalModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerULocalModel) ListModels(apiConfig *APIConfig) ([]string, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerULocalModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerULocalModel) CheckConnection(apiConfig *APIConfig) error { + return fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerULocalModel) ParseFile(modelName *string, content []byte, documentURL *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + if len(content) == 0 { + return nil, fmt.Errorf("local MinerU API requires file content byte array, but content is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + apiURL := fmt.Sprintf("%s/%s", m.BaseURL[region], m.URLSuffix.DocumentParse) + + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + // Get file + part, err := writer.CreateFormFile("files", "upload_document.pdf") + if err != nil { + return nil, fmt.Errorf("failed to create multipart file field: %w", err) + } + if _, err = part.Write(content); err != nil { + return nil, fmt.Errorf("failed to write file content: %w", err) + } + + if modelName != nil && *modelName != "" { + _ = writer.WriteField("backend", *modelName) + } else { + _ = writer.WriteField("backend", "pipeline") + } + + if err = writer.Close(); err != nil { + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + req, err := http.NewRequest("POST", apiURL, &body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", writer.FormDataContentType()) + + if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + } + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK && resp.StatusCode != 202 { + return nil, fmt.Errorf("local MinerU API failed with status %d: %s (URL: %s)", resp.StatusCode, string(respBody), apiURL) + } + + var result map[string]interface{} + if err := json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to parse response JSON: %w, body: %s", err, string(respBody)) + } + // Get task ID + var taskID string + if dataMap, ok := result["data"].(map[string]interface{}); ok { + if tid, ok := dataMap["task_id"].(string); ok { + taskID = tid + } + } else if tid, ok := result["task_id"].(string); ok { + taskID = tid + } + + if taskID == "" { + return nil, fmt.Errorf("failed to extract task_id from local MinerU response: %s", string(respBody)) + } + + return &ParseFileResponse{ + TaskID: taskID, + }, nil +} + +func (m *MinerULocalModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s no such method", m.Name()) +} + +func (m *MinerULocalModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + if taskID == "" { + return nil, fmt.Errorf("taskID is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s/%s/result", m.BaseURL[region], m.URLSuffix.Task, taskID) + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create status request: %w", err) + } + + if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + } + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send status request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read status response: %w", err) + } + + if resp.StatusCode != http.StatusOK && resp.StatusCode != 202 { + return nil, fmt.Errorf("MinerU local status API failed with status %d: %s", resp.StatusCode, string(body)) + } + + // parse JSON + var result map[string]interface{} + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse JSON: %w", err) + } + + content := "" + + // results + results, ok := result["results"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("missing results field") + } + + // Get markdown + for _, fileObj := range results { + + fileMap, ok := fileObj.(map[string]interface{}) + if !ok { + continue + } + + md, ok := fileMap["md_content"].(string) + if ok { + content = md + break + } + } + + if content == "" { + return nil, fmt.Errorf("md_content not found") + } + + return &TaskResponse{ + Segments: []TaskSegment{ + { + Index: 1, + Content: content, + }, + }, + }, nil +} diff --git a/internal/entity/models/minimax.go b/internal/entity/models/minimax.go index d40bfef4bd2..61e4110e7cf 100644 --- a/internal/entity/models/minimax.go +++ b/internal/entity/models/minimax.go @@ -19,6 +19,7 @@ package models import ( "bufio" "bytes" + "encoding/hex" "encoding/json" "fmt" "io" @@ -200,7 +201,7 @@ func (z *MinimaxModel) ChatStreamlyWithSender(modelName string, messages []Messa var region = "default" - if apiConfig.Region != nil { + if apiConfig.Region != nil && *apiConfig.Region != "" { region = *apiConfig.Region } @@ -344,8 +345,8 @@ func (z *MinimaxModel) ChatStreamlyWithSender(modelName string, messages []Messa return scanner.Err() } -// Encode encodes a list of texts into embeddings -func (z *MinimaxModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Embed embeds a list of texts into embeddings +func (z *MinimaxModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { return nil, fmt.Errorf("not implemented") } @@ -447,3 +448,233 @@ func (z *MinimaxModel) CheckConnection(apiConfig *APIConfig) error { func (z *MinimaxModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) } + +// TranscribeAudio transcribe audio +func (z *MinimaxModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *MinimaxModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (z *MinimaxModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("MiniMax API key is missing") + } + if audioContent == nil || *audioContent == "" { + return nil, fmt.Errorf("text content is empty") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.TTS) + + reqBody := map[string]interface{}{ + "model": modelName, + "text": audioContent, + } + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + reqBody[key] = value + } + } + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["audio_setting"] = map[string]interface{}{ + "format": ttsConfig.Format, + } + } + reqBody["stream"] = false + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", strings.TrimSpace(*apiConfig.ApiKey))) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("MiniMax TTS API error: status %d, body: %s", resp.StatusCode, string(body)) + } + + var result struct { + BaseResp struct { + StatusCode int `json:"status_code"` + StatusMsg string `json:"status_msg"` + } `json:"base_resp"` + Data struct { + Audio string `json:"audio"` // HEX + } `json:"data"` + } + + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + if result.BaseResp.StatusCode != 0 { + return nil, fmt.Errorf("MiniMax TTS returned error: %d - %s", result.BaseResp.StatusCode, result.BaseResp.StatusMsg) + } + + // format HEX + audioBytes, err := hex.DecodeString(result.Data.Audio) + if err != nil { + return nil, fmt.Errorf("failed to decode MiniMax hex audio: %w", err) + } + + return &TTSResponse{ + Audio: audioBytes, + }, nil +} + +func (z *MinimaxModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("MiniMax API key is missing") + } + if audioContent == nil || *audioContent == "" { + return fmt.Errorf("text content is empty") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := strings.TrimSuffix(z.BaseURL[region], "/") + if baseURL == "" { + baseURL = strings.TrimSuffix(z.BaseURL["default"], "/") + } + suffix := strings.TrimPrefix(z.URLSuffix.TTS, "/") + if suffix == "" { + suffix = "v1/t2a_v2" + } + url := fmt.Sprintf("%s/%s", baseURL, suffix) + + reqBody := map[string]interface{}{ + "model": modelName, + "text": audioContent, + } + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + reqBody[key] = value + } + } + reqBody["stream"] = false + + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["audio_setting"] = map[string]interface{}{ + "format": ttsConfig.Format, + } + } + + reqBody["stream"] = true + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", strings.TrimSpace(*apiConfig.ApiKey))) + + resp, err := z.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("MiniMax stream TTS API error: %d, body: %s", resp.StatusCode, string(body)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 2*1024*1024) + + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "data:") { + continue + } + + dataStr := strings.TrimSpace(line[5:]) + if dataStr == "" { + continue + } + + var event struct { + Data struct { + Audio string `json:"audio"` + Status int `json:"status"` + } `json:"data"` + } + + if err := json.Unmarshal([]byte(dataStr), &event); err != nil { + continue + } + + if event.Data.Audio != "" { + audioBytes, err := hex.DecodeString(event.Data.Audio) + if err == nil && len(audioBytes) > 0 { + chunk := string(audioBytes) + if errSend := sender(&chunk, nil); errSend != nil { + return errSend + } + } + } + + if event.Data.Status == 2 { + break + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("error reading minimax stream: %w", err) + } + + return nil +} + +// OCRFile OCR file +func (m *MinimaxModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +// ParseFile parse file +func (z *MinimaxModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *MinimaxModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *MinimaxModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/mistral.go b/internal/entity/models/mistral.go new file mode 100644 index 00000000000..1b526a87763 --- /dev/null +++ b/internal/entity/models/mistral.go @@ -0,0 +1,682 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// MistralModel implements ModelDriver for Mistral AI. +// +// Mistral exposes an OpenAI-compatible REST API at https://api.mistral.ai/v1 +// (chat completions at /chat/completions, list models at /models). The wire +// shape matches OpenAI closely enough that the chat path here is a direct +// port of the OpenAI driver, with the differences kept small on purpose: +// no reasoning_content pass-through (Mistral does not expose one), and a +// distinct Name() so the factory can route to this driver. +type MistralModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +// NewMistralModel creates a new Mistral model instance. +// +// We clone http.DefaultTransport so we keep Go's defaults for +// ProxyFromEnvironment, DialContext (with KeepAlive), HTTP/2, +// TLSHandshakeTimeout, and ExpectContinueTimeout, and only override +// the connection-pool fields we care about. +// +// The Client itself has no Timeout. http.Client.Timeout would also +// cap the time spent reading the response body, which would cut off +// long-lived SSE streams in ChatStreamlyWithSender. Non-streaming +// callers wrap each request with context.WithTimeout instead. +func NewMistralModel(baseURL map[string]string, urlSuffix URLSuffix) *MistralModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &MistralModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (m *MistralModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewMistralModel(baseURL, m.URLSuffix) +} + +func (m *MistralModel) Name() string { + return "mistral" +} + +// baseURLForRegion returns the base URL for the given region, or an +// error if no entry exists. This makes a misconfigured region fail +// fast with a clear message, instead of silently producing a relative +// URL that the HTTP transport then rejects. +func (m *MistralModel) baseURLForRegion(region string) (string, error) { + base, ok := m.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("mistral: no base URL configured for region %q", region) + } + return base, nil +} + +// ChatWithMessages sends multiple messages with roles and returns the response. +func (m *MistralModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := m.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, m.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + } + + // Note: do NOT propagate chatModelConfig.Stream into the request body + // here. ChatWithMessages parses a single JSON response, so stream must + // always be off for this code path. + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + emptyReason := "" + return &ChatResponse{ + Answer: &content, + ReasonContent: &emptyReason, + }, nil +} + +// ChatStreamlyWithSender sends messages and streams the response via the +// sender function. The Mistral SSE stream uses the same shape as OpenAI: +// "data:" lines carrying JSON events, with a final "[DONE]" line. +func (m *MistralModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("api key is required") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := m.baseURLForRegion(region) + if err != nil { + return err + } + url := fmt.Sprintf("%s/%s", baseURL, m.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + } + + if chatModelConfig != nil { + // Refuse to run if the caller explicitly asked for stream=false. + // The body of this method only knows how to read SSE, so a + // non-SSE JSON response would be parsed as if it were a stream + // and produce no chunks. Better to fail clearly. + if chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + // Use an explicit background context. SSE streams are long-lived + // so we do not attach a hard deadline here; the transport's + // ResponseHeaderTimeout caps the connection-establishment phase. + req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := m.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: bump the scanner buffer from the 64KB default to 1MB + // so we never silently truncate a long data: line. + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + sawTerminal := false + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "data:") { + continue + } + + data := strings.TrimSpace(line[5:]) + + if data == "[DONE]" { + sawTerminal = true + break + } + + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + sawTerminal = true + break + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawTerminal { + return fmt.Errorf("mistral: stream ended before [DONE] or finish_reason") + } + + endOfStream := "[DONE]" + if err := sender(&endOfStream, nil); err != nil { + return err + } + + return nil +} + +type mistralEmbeddingData struct { + Embedding []float64 `json:"embedding"` + Object string `json:"object"` + Index int `json:"index"` +} + +type mistralEmbeddingResponse struct { + Data []mistralEmbeddingData `json:"data"` + Model string `json:"model"` + Object string `json:"object"` +} + +// Embed turns a list of texts into embedding vectors using the +// Mistral /v1/embeddings endpoint (mistral-embed). The output has +// one vector per input, in the same order the inputs were given. +func (m *MistralModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := m.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, m.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Mistral embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed mistralEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Reorder the returned vectors by their reported index so the output + // always lines up with the input texts, even if the upstream API ever + // returns items out of order. A nil slot at the end indicates the + // upstream did not return an embedding for that input. + embeddings := make([]EmbeddingData, len(texts)) + filled := make([]bool, len(texts)) + for _, item := range parsed.Data { + if item.Index < 0 || item.Index >= len(texts) { + return nil, fmt.Errorf("mistral: response index %d out of range for %d inputs", item.Index, len(texts)) + } + if filled[item.Index] { + // A malformed response that repeats the same index would + // silently overwrite the earlier vector. Fail loudly so + // the caller never uses ambiguous output. + return nil, fmt.Errorf("mistral: duplicate embedding index %d in response", item.Index) + } + embeddings[item.Index] = EmbeddingData{ + Embedding: item.Embedding, + Index: item.Index, + } + filled[item.Index] = true + } + for i, ok := range filled { + if !ok { + return nil, fmt.Errorf("mistral: missing embedding for input index %d", i) + } + } + + return embeddings, nil +} + +// ListModels returns the list of model ids visible to the API key. +func (m *MistralModel) ListModels(apiConfig *APIConfig) ([]string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := m.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, m.URLSuffix.Models) + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := m.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + data, ok := result["data"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid models list format") + } + + models := make([]string, 0) + for _, model := range data { + modelMap, ok := model.(map[string]interface{}) + if !ok { + continue + } + modelName, ok := modelMap["id"].(string) + if !ok { + continue + } + models = append(models, modelName) + } + + return models, nil +} + +// Balance is not exposed by the Mistral API, so this returns "no such method". +func (m *MistralModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("no such method") +} + +// CheckConnection runs a lightweight ListModels call to verify the API key. +func (m *MistralModel) CheckConnection(apiConfig *APIConfig) error { + _, err := m.ListModels(apiConfig) + if err != nil { + return err + } + return nil +} + +// Rerank calculates similarity scores between query and documents. Mistral +// does not expose a public rerank API, so this returns "no such method". +func (m *MistralModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("no such method") +} + +// TranscribeAudio transcribe audio +func (z *MistralModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *MistralModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (z *MistralModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *MistralModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (z *MistralModel) OCRFile(modelName *string, content []byte, urls *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + if (urls == nil || *urls == "") && (content == nil || len(content) == 0) { + return nil, fmt.Errorf("file url or content is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.OCR) + + var docURL string + if urls != nil && *urls != "" { + docURL = *urls + } else { + mimeType := http.DetectContentType(content) + base64Str := base64.StdEncoding.EncodeToString(content) + docURL = fmt.Sprintf("data:%s;base64,%s", mimeType, base64Str) + } + + reqData := map[string]interface{}{ + "model": *modelName, + "document": map[string]interface{}{ + "type": "document_url", + "document_url": docURL, + }, + } + + jsonData, err := json.Marshal(reqData) + if err != nil { + return nil, fmt.Errorf("failed to marshal json payload: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Mistral OCR API error: %s, body: %s", resp.Status, string(body)) + } + + var mistralResp struct { + Pages []struct { + Index int `json:"index"` + Markdown string `json:"markdown"` + } `json:"pages"` + } + + if err = json.Unmarshal(body, &mistralResp); err != nil { + return nil, fmt.Errorf("failed to parse response json: %w", err) + } + + var fullMarkdown strings.Builder + for _, page := range mistralResp.Pages { + fullMarkdown.WriteString(page.Markdown) + fullMarkdown.WriteString("\n\n") + } + + resultText := strings.TrimSpace(fullMarkdown.String()) + + return &OCRFileResponse{ + Text: &resultText, + }, nil +} + +func (z *MistralModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + //TODO implement me + panic("implement me") +} + +func (z *MistralModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("no such method", z.Name()) +} + +func (z *MistralModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("no such method", z.Name()) +} diff --git a/internal/entity/models/mistral_test.go b/internal/entity/models/mistral_test.go new file mode 100644 index 00000000000..dc7f318e143 --- /dev/null +++ b/internal/entity/models/mistral_test.go @@ -0,0 +1,574 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "sync/atomic" + "testing" +) + +// newMistralServer stands up an httptest server that asserts the +// request shape and lets the caller decide what to return. +func newMistralServer(t *testing.T, expectedPath string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != expectedPath { + t.Errorf("expected path=%s, got %s", expectedPath, r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("expected Authorization=Bearer test-key, got %q", got) + return + } + if r.Method == http.MethodPost { + if got := r.Header.Get("Content-Type"); got != "application/json" { + t.Errorf("expected Content-Type=application/json, got %q", got) + return + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("failed to read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("invalid JSON body: %v\n%s", err, string(raw)) + return + } + handler(t, body, w) + return + } + // GET path: no body + handler(t, nil, w) + })) +} + +func newMistralForTest(baseURL string) *MistralModel { + return NewMistralModel( + map[string]string{"default": baseURL}, + URLSuffix{ + Chat: "chat/completions", + Models: "models", + Embedding: "embeddings", + }, + ) +} + +func TestMistralName(t *testing.T) { + m := newMistralForTest("http://unused") + if got := m.Name(); got != "mistral" { + t.Errorf("Name()=%q, want %q", got, "mistral") + } +} + +func TestMistralChatHappyPath(t *testing.T) { + srv := newMistralServer(t, "/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "mistral-large-latest" { + t.Errorf("expected model=mistral-large-latest, got %v", body["model"]) + } + if body["stream"] != false { + t.Errorf("expected stream=false, got %v", body["stream"]) + } + msgs, ok := body["messages"].([]interface{}) + if !ok || len(msgs) != 1 { + t.Errorf("expected 1 message, got %v", body["messages"]) + return + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{ + {"message": map[string]interface{}{"content": "pong"}}, + }, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + resp, err := m.ChatWithMessages("mistral-large-latest", []Message{ + {Role: "user", Content: "ping"}, + }, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if resp.Answer == nil || *resp.Answer != "pong" { + t.Errorf("answer=%v, want pong", resp.Answer) + } + if resp.ReasonContent == nil || *resp.ReasonContent != "" { + t.Errorf("expected empty reason content, got %v", resp.ReasonContent) + } +} + +func TestMistralChatPropagatesConfig(t *testing.T) { + srv := newMistralServer(t, "/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["max_tokens"] != float64(64) { + t.Errorf("max_tokens=%v want 64", body["max_tokens"]) + } + if body["temperature"] != 0.3 { + t.Errorf("temperature=%v want 0.3", body["temperature"]) + } + if body["top_p"] != 0.9 { + t.Errorf("top_p=%v want 0.9", body["top_p"]) + } + stop, ok := body["stop"].([]interface{}) + if !ok || len(stop) != 1 || stop[0] != "END" { + t.Errorf("stop=%v want [END]", body["stop"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{"message": map[string]interface{}{"content": "ok"}}}, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + mt := 64 + temp := 0.3 + topP := 0.9 + stop := []string{"END"} + _, err := m.ChatWithMessages("mistral-large-latest", []Message{{Role: "user", Content: "ping"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{MaxTokens: &mt, Temperature: &temp, TopP: &topP, Stop: &stop}, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } +} + +func TestMistralChatRequiresAPIKey(t *testing.T) { + m := newMistralForTest("http://unused") + _, err := m.ChatWithMessages("mistral-large-latest", []Message{{Role: "user", Content: "x"}}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("expected api-key error, got %v", err) + } + emptyKey := "" + _, err = m.ChatWithMessages("mistral-large-latest", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &emptyKey}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("empty key: expected api-key error, got %v", err) + } +} + +func TestMistralChatRequiresMessages(t *testing.T) { + m := newMistralForTest("http://unused") + apiKey := "test-key" + _, err := m.ChatWithMessages("mistral-large-latest", nil, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "messages is empty") { + t.Errorf("expected messages-empty error, got %v", err) + } +} + +func TestMistralChatRejectsHTTPError(t *testing.T) { + srv := newMistralServer(t, "/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + _, err := m.ChatWithMessages("mistral-large-latest", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "401") { + t.Errorf("expected 401 propagated, got %v", err) + } +} + +func TestMistralChatFallsBackToDefaultOnEmptyRegion(t *testing.T) { + // Empty *Region pointer must fall back to the "default" entry, not + // be treated as an explicit "" region (which would miss the lookup). + srv := newMistralServer(t, "/chat/completions", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{"message": map[string]interface{}{"content": "ok"}}}, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + emptyRegion := "" + _, err := m.ChatWithMessages("mistral-large-latest", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey, Region: &emptyRegion}, nil) + if err != nil { + t.Errorf("empty Region: expected fallback to default, got %v", err) + } +} + +func TestMistralListModelsFallsBackToDefaultOnEmptyRegion(t *testing.T) { + srv := newMistralServer(t, "/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{"data": []map[string]interface{}{{"id": "x"}}}) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + emptyRegion := "" + if _, err := m.ListModels(&APIConfig{ApiKey: &apiKey, Region: &emptyRegion}); err != nil { + t.Errorf("empty Region: expected fallback to default, got %v", err) + } +} + +func TestMistralStreamRequiresSender(t *testing.T) { + m := newMistralForTest("http://unused") + apiKey := "test-key" + err := m.ChatStreamlyWithSender("mistral-large-latest", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, nil) + if err == nil || !strings.Contains(err.Error(), "sender is required") { + t.Errorf("expected sender-required error, got %v", err) + } +} + +func TestMistralChatRejectsUnknownRegion(t *testing.T) { + m := newMistralForTest("http://unused") + apiKey := "test-key" + region := "eu" + _, err := m.ChatWithMessages("mistral-large-latest", []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey, Region: ®ion}, nil) + if err == nil || !strings.Contains(err.Error(), "no base URL configured for region") { + t.Errorf("expected region error, got %v", err) + } +} + +func TestMistralStreamHappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/chat/completions" { + t.Errorf("path=%s", r.URL.Path) + return + } + raw, _ := io.ReadAll(r.Body) + var body map[string]interface{} + _ = json.Unmarshal(raw, &body) + if body["stream"] != true { + t.Errorf("expected stream=true, got %v", body["stream"]) + } + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + // Two content chunks then finish_reason terminator, then [DONE]. + _, _ = io.WriteString(w, + `data: {"choices":[{"delta":{"content":"Hello "}}]}`+"\n"+ + `data: {"choices":[{"delta":{"content":"world"}}]}`+"\n"+ + `data: {"choices":[{"delta":{},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + })) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + var chunks []string + var sawDone int32 + err := m.ChatStreamlyWithSender("mistral-large-latest", + []Message{{Role: "user", Content: "hi"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(content *string, _ *string) error { + if content == nil { + return nil + } + if *content == "[DONE]" { + atomic.StoreInt32(&sawDone, 1) + return nil + } + chunks = append(chunks, *content) + return nil + }, + ) + if err != nil { + t.Fatalf("stream: %v", err) + } + if strings.Join(chunks, "") != "Hello world" { + t.Errorf("chunks=%v want [\"Hello \" \"world\"]", chunks) + } + if atomic.LoadInt32(&sawDone) != 1 { + t.Error("expected sender to receive [DONE] sentinel") + } +} + +func TestMistralStreamRejectsExplicitFalse(t *testing.T) { + m := newMistralForTest("http://unused") + apiKey := "test-key" + stream := false + err := m.ChatStreamlyWithSender("mistral-large-latest", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{Stream: &stream}, + func(*string, *string) error { return nil }, + ) + if err == nil || !strings.Contains(err.Error(), "stream must be true") { + t.Errorf("expected stream-true guard, got %v", err) + } +} + +func TestMistralStreamFailsWithoutTerminal(t *testing.T) { + // Body closes before [DONE] or a finish_reason -> driver must complain + // instead of pretending the stream finished cleanly. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, `data: {"choices":[{"delta":{"content":"half"}}]}`+"\n") + })) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + err := m.ChatStreamlyWithSender("mistral-large-latest", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(*string, *string) error { return nil }, + ) + if err == nil || !strings.Contains(err.Error(), "stream ended before") { + t.Errorf("expected stream-truncation error, got %v", err) + } +} + +func TestMistralListModelsHappyPath(t *testing.T) { + srv := newMistralServer(t, "/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"id": "mistral-large-latest"}, + {"id": "mistral-small-latest"}, + {"id": "mistral-embed"}, + }, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + ids, err := m.ListModels(&APIConfig{ApiKey: &apiKey}) + if err != nil { + t.Fatalf("ListModels: %v", err) + } + if len(ids) != 3 || ids[0] != "mistral-large-latest" || ids[2] != "mistral-embed" { + t.Errorf("ids=%v, want [mistral-large-latest mistral-small-latest mistral-embed]", ids) + } +} + +func TestMistralListModelsRequiresAPIKey(t *testing.T) { + m := newMistralForTest("http://unused") + if _, err := m.ListModels(&APIConfig{}); err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("expected api-key error, got %v", err) + } +} + +func TestMistralCheckConnectionDelegatesToListModels(t *testing.T) { + // 200 -> CheckConnection succeeds; 401 -> CheckConnection propagates. + okSrv := newMistralServer(t, "/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{"data": []map[string]interface{}{{"id": "x"}}}) + }) + defer okSrv.Close() + failSrv := newMistralServer(t, "/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + }) + defer failSrv.Close() + + apiKey := "test-key" + mOK := newMistralForTest(okSrv.URL) + if err := mOK.CheckConnection(&APIConfig{ApiKey: &apiKey}); err != nil { + t.Errorf("CheckConnection(ok): %v", err) + } + mFail := newMistralForTest(failSrv.URL) + if err := mFail.CheckConnection(&APIConfig{ApiKey: &apiKey}); err == nil { + t.Error("CheckConnection(fail): expected error, got nil") + } +} + +func TestMistralBalanceReturnsNoSuchMethod(t *testing.T) { + m := newMistralForTest("http://unused") + _, err := m.Balance(&APIConfig{}) + if err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Balance: expected 'no such method', got %v", err) + } +} + +func TestMistralRerankReturnsNoSuchMethod(t *testing.T) { + m := newMistralForTest("http://unused") + q := "mistral-large-latest" + _, err := m.Rerank(&q, "what is rag?", []string{"a", "b"}, &APIConfig{}, &RerankConfig{TopN: 2}) + if err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Rerank: expected 'no such method', got %v", err) + } +} + +func TestMistralEmbedHappyPath(t *testing.T) { + srv := newMistralServer(t, "/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "mistral-embed" { + t.Errorf("model=%v want mistral-embed", body["model"]) + } + inputs, ok := body["input"].([]interface{}) + if !ok || len(inputs) != 3 { + t.Errorf("input=%v want 3-element array", body["input"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{0.1, 0.2}, "index": 0}, + {"embedding": []float64{0.3, 0.4}, "index": 1}, + {"embedding": []float64{0.5, 0.6}, "index": 2}, + }, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + model := "mistral-embed" + vecs, err := m.Embed(&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } + if len(vecs) != 3 { + t.Fatalf("len(vecs)=%d want 3", len(vecs)) + } + if vecs[1].Embedding[0] != 0.3 || vecs[1].Index != 1 { + t.Errorf("vecs[1]=%+v want {Embedding:[0.3 0.4] Index:1}", vecs[1]) + } +} + +func TestMistralEmbedReordersByIndex(t *testing.T) { + // Upstream returns the three vectors in shuffled order. The driver + // must reorder them so the slot at position i corresponds to input i. + srv := newMistralServer(t, "/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{2}, "index": 2}, + {"embedding": []float64{0}, "index": 0}, + {"embedding": []float64{1}, "index": 1}, + }, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + model := "mistral-embed" + vecs, err := m.Embed(&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } + for i, v := range vecs { + if v.Index != i || v.Embedding[0] != float64(i) { + t.Errorf("slot %d = %+v, want Embedding=[%d] Index=%d", i, v, i, i) + } + } +} + +func TestMistralEmbedEmptyInputShortCircuits(t *testing.T) { + // Empty input must NOT make an HTTP call; the test fails the request + // rather than the assertion if it does. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + t.Error("Embed([]) made an unexpected HTTP call") + w.WriteHeader(http.StatusInternalServerError) + })) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + model := "mistral-embed" + vecs, err := m.Embed(&model, []string{}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed([]): %v", err) + } + if len(vecs) != 0 { + t.Errorf("len(vecs)=%d want 0", len(vecs)) + } +} + +func TestMistralEmbedRequiresAPIKey(t *testing.T) { + m := newMistralForTest("http://unused") + model := "mistral-embed" + _, err := m.Embed(&model, []string{"a"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("expected api-key error, got %v", err) + } +} + +func TestMistralEmbedRequiresModelName(t *testing.T) { + m := newMistralForTest("http://unused") + apiKey := "test-key" + _, err := m.Embed(nil, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("expected model-name error, got %v", err) + } + empty := "" + _, err = m.Embed(&empty, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("empty model: expected model-name error, got %v", err) + } +} + +func TestMistralEmbedRejectsDuplicateIndex(t *testing.T) { + // A malformed upstream that repeats data[*].index would silently + // overwrite the earlier vector; the driver must fail loudly instead. + srv := newMistralServer(t, "/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 0}, + {"embedding": []float64{2}, "index": 0}, + }, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + model := "mistral-embed" + _, err := m.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "duplicate embedding index 0") { + t.Errorf("expected duplicate-index error, got %v", err) + } +} + +func TestMistralEmbedRejectsOutOfRangeIndex(t *testing.T) { + srv := newMistralServer(t, "/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 7}, // out of range for 2-input request + }, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + model := "mistral-embed" + _, err := m.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "out of range") { + t.Errorf("expected out-of-range error, got %v", err) + } +} + +func TestMistralEmbedRejectsMissingSlot(t *testing.T) { + // Upstream returns only one of the two requested embeddings. + srv := newMistralServer(t, "/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 0}, + }, + }) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + model := "mistral-embed" + _, err := m.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "missing embedding for input index 1") { + t.Errorf("expected missing-embedding error for slot 1, got %v", err) + } +} + +func TestMistralEmbedRejectsHTTPError(t *testing.T) { + srv := newMistralServer(t, "/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + }) + defer srv.Close() + + m := newMistralForTest(srv.URL) + apiKey := "test-key" + model := "mistral-embed" + _, err := m.Embed(&model, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "Mistral embeddings API error") { + t.Errorf("expected Mistral embeddings API error, got %v", err) + } +} diff --git a/internal/entity/models/moonshot.go b/internal/entity/models/moonshot.go index 68af2fada8d..fa1ad76ec4e 100644 --- a/internal/entity/models/moonshot.go +++ b/internal/entity/models/moonshot.go @@ -357,8 +357,8 @@ func (k *MoonshotModel) ChatStreamlyWithSender(modelName string, messages []Mess return scanner.Err() } -// Encode encodes a list of texts into embeddings -func (z *MoonshotModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +// Embed embeds a list of texts into embeddings +func (z *MoonshotModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { return nil, fmt.Errorf("not implemented") } @@ -487,3 +487,39 @@ func (z *MoonshotModel) CheckConnection(apiConfig *APIConfig) error { func (z *MoonshotModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) } + +// TranscribeAudio transcribe audio +func (z *MoonshotModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *MoonshotModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (z *MoonshotModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *MoonshotModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *MoonshotModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +// ParseFile parse file +func (z *MoonshotModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *MoonshotModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *MoonshotModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/novita.go b/internal/entity/models/novita.go new file mode 100644 index 00000000000..33e945f6134 --- /dev/null +++ b/internal/entity/models/novita.go @@ -0,0 +1,778 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// NovitaModel implements ModelDriver for Novita.ai +// (https://novita.ai/docs/api-reference/). +// +// Novita exposes an OpenAI-compatible REST API at +// https://api.novita.ai/v3/openai (chat completions at +// /chat/completions, list models at /models). It serves a large +// catalog of third-party models (DeepSeek, Llama, Qwen3, Kimi, +// Gemma, Mistral, etc.) behind a single OpenAI-shaped surface. +// +// The wire shape matches OpenAI standard with ONE notable +// difference: reasoning models like qwen3-* embed their +// chain-of-thought INLINE inside content as ... +// tags, rather than in a separate reasoning_content field. The +// driver detects those tags and routes the inner text to +// ChatResponse.ReasonContent (non-stream) or the sender's second +// arg (stream), keeping the answer clean of tag clutter. +type NovitaModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +// NewNovitaModel creates a new Novita model instance. +// +// Same transport convention as other Go drivers in this package: +// clone http.DefaultTransport, override the connection-pool fields, +// no client-level Timeout so SSE streams are not capped. +func NewNovitaModel(baseURL map[string]string, urlSuffix URLSuffix) *NovitaModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &NovitaModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (n *NovitaModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewNovitaModel(baseURL, n.URLSuffix) +} + +func (n *NovitaModel) Name() string { + return "novita" +} + +func (n *NovitaModel) baseURLForRegion(region string) (string, error) { + base, ok := n.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("novita: no base URL configured for region %q", region) + } + // Strip a trailing "/" so callers can safely do + // fmt.Sprintf("%s/%s", base, suffix) without producing "//" in + // the path. The shipped config has no trailing slash, but a + // tenant can override the URL per-instance and may add one. + return strings.TrimSuffix(base, "/"), nil +} + +const ( + novitaThinkOpen = "" + novitaThinkClose = "" +) + +// splitNovitaThink walks a complete content string and returns the +// visible portion + the concatenated chain-of-thought from inside +// any ... blocks. Multiple think blocks are +// concatenated; tags themselves are stripped. Used by the +// non-streaming path where the whole content is available at once. +func splitNovitaThink(raw string) (visible, reasoning string) { + var v, r strings.Builder + inside := false + for { + var marker string + if inside { + marker = novitaThinkClose + } else { + marker = novitaThinkOpen + } + idx := strings.Index(raw, marker) + if idx < 0 { + if inside { + r.WriteString(raw) + } else { + v.WriteString(raw) + } + break + } + if inside { + r.WriteString(raw[:idx]) + } else { + v.WriteString(raw[:idx]) + } + raw = raw[idx+len(marker):] + inside = !inside + } + return v.String(), r.String() +} + +// novitaThinkExtractor maintains state across streaming chunks so +// that a ... block spanning multiple SSE events still +// gets split correctly between content and reasoning. The buffer +// preserves up to (len(closingMarker)-1) trailing bytes of each +// chunk in case the next chunk completes a partial tag. +type novitaThinkExtractor struct { + buf strings.Builder + inside bool +} + +// novitaThinkSegment is one routing decision: emit `content` via the +// sender's first arg, or emit `reasoning` via the sender's second arg. +// Exactly one of the two fields is non-empty. +type novitaThinkSegment struct { + content string + reasoning string +} + +// Feed appends an incoming chunk and returns any segments that are +// now safe to emit. Trailing bytes that could be the start of a tag +// are held back in the buffer until the next call. +func (e *novitaThinkExtractor) Feed(chunk string) []novitaThinkSegment { + e.buf.WriteString(chunk) + s := e.buf.String() + var out []novitaThinkSegment + for { + var marker, otherMarker string + if e.inside { + marker = novitaThinkClose + otherMarker = novitaThinkOpen + } else { + marker = novitaThinkOpen + otherMarker = novitaThinkClose + } + idx := strings.Index(s, marker) + if idx < 0 { + // No closing/opening marker yet. Emit everything except a + // possible partial-tag suffix at the very end. Reserve + // (max marker length - 1) trailing bytes so we don't + // emit "". + reserve := len(marker) - 1 + if len(otherMarker)-1 > reserve { + reserve = len(otherMarker) - 1 + } + safe := len(s) - reserve + if safe < 0 { + safe = 0 + } + // Don't reserve if the trailing bytes can't possibly be + // the start of a tag (no '<' suffix). + if safe < len(s) && !strings.Contains(s[safe:], "<") { + safe = len(s) + } + if safe > 0 { + if e.inside { + out = append(out, novitaThinkSegment{reasoning: s[:safe]}) + } else { + out = append(out, novitaThinkSegment{content: s[:safe]}) + } + s = s[safe:] + } + break + } + if idx > 0 { + if e.inside { + out = append(out, novitaThinkSegment{reasoning: s[:idx]}) + } else { + out = append(out, novitaThinkSegment{content: s[:idx]}) + } + } + s = s[idx+len(marker):] + e.inside = !e.inside + } + e.buf.Reset() + e.buf.WriteString(s) + return out +} + +// Flush returns the buffered tail when the stream ends. A stream that +// ends mid-tag would not normally happen with a well-behaved upstream, +// but if it does the partial bytes are emitted according to the +// current mode so nothing is silently lost. +func (e *novitaThinkExtractor) Flush() *novitaThinkSegment { + s := e.buf.String() + e.buf.Reset() + if s == "" { + return nil + } + if e.inside { + return &novitaThinkSegment{reasoning: s} + } + return &novitaThinkSegment{content: s} +} + +// ChatWithMessages sends multiple messages with roles and returns the response. +func (n *NovitaModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := n.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, n.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + } + + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + // Map ChatConfig.Thinking -> Novita's `enable_thinking`. + // Per https://novita.ai/docs/api-reference/model-apis-llm-create-chat-completion, + // enable_thinking (boolean | null, default true) "controls the + // switches between thinking and non-thinking modes" for + // zai-org/glm-4.5, deepseek/deepseek-v3.1[-terminus|-exp]. For + // models outside that supported set Novita ignores the field, + // so it's safe to forward whenever the caller opts in. Tenants + // can now disable thinking mode at request time without having + // to use prompt-level hacks like "/no_think". + if chatModelConfig.Thinking != nil { + reqBody["enable_thinking"] = *chatModelConfig.Thinking + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := n.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + rawContent, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + // Novita emits chain-of-thought in two different shapes depending + // on the model and on enable_thinking: + // - qwen3-* and other inline-style models: chain-of-thought is + // embedded inside content as ... tags. + // - deepseek-v3.1 / glm-4.5 (and any model with separate + // reasoning enabled): chain-of-thought arrives in a separate + // `reasoning_content` field, with `content` already cleaned. + // Handle both so the visible Answer is always tag-free and any + // reasoning the upstream supplied is preserved. + visible, reasoning := splitNovitaThink(rawContent) + if r, ok := messageMap["reasoning_content"].(string); ok && r != "" { + if reasoning != "" { + reasoning += "\n" + r + } else { + reasoning = r + } + } + + return &ChatResponse{ + Answer: &visible, + ReasonContent: &reasoning, + }, nil +} + +// ChatStreamlyWithSender sends messages and streams the response via +// the sender. Handles both reasoning shapes Novita can emit: +// - delta.reasoning_content (deepseek-v3.1 / glm-4.5 / any model +// with separate reasoning): forwarded as-is to the second arg. +// - delta.content containing ... (qwen3-* and other +// inline-style models): a stateful extractor splits tag bytes +// across SSE chunk boundaries, then routes content/reasoning to +// the first/second sender arg respectively. +func (n *NovitaModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("api key is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := n.baseURLForRegion(region) + if err != nil { + return err + } + url := fmt.Sprintf("%s/%s", baseURL, n.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + } + + if chatModelConfig != nil { + if chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + // See ChatWithMessages for why we forward this. + if chatModelConfig.Thinking != nil { + reqBody["enable_thinking"] = *chatModelConfig.Thinking + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := n.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + extractor := &novitaThinkExtractor{} + sawTerminal := false + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data:") { + continue + } + data := strings.TrimSpace(line[5:]) + if data == "[DONE]" { + sawTerminal = true + break + } + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + // deepseek-v3.1 / glm-4.5 (and other models that emit reasoning + // separately) put chain-of-thought in delta.reasoning_content + // rather than inside content as .... Surface it + // before any content from the same chunk so callers piping to + // a UI render reasoning before the visible answer for that + // token, matching the wire ordering Novita emits. + if r, ok := delta["reasoning_content"].(string); ok && r != "" { + rr := r + if err := sender(nil, &rr); err != nil { + return err + } + } + if c, ok := delta["content"].(string); ok && c != "" { + for _, seg := range extractor.Feed(c) { + if seg.reasoning != "" { + r := seg.reasoning + if err := sender(nil, &r); err != nil { + return err + } + } + if seg.content != "" { + cc := seg.content + if err := sender(&cc, nil); err != nil { + return err + } + } + } + } + if finish, ok := firstChoice["finish_reason"].(string); ok && finish != "" { + sawTerminal = true + break + } + } + + // Flush any buffered tail (rare, but covers the case where the + // stream ends right after the last chunk without us seeing the + // closing tag). + if seg := extractor.Flush(); seg != nil { + if seg.reasoning != "" { + r := seg.reasoning + if err := sender(nil, &r); err != nil { + return err + } + } + if seg.content != "" { + cc := seg.content + if err := sender(&cc, nil); err != nil { + return err + } + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawTerminal { + return fmt.Errorf("novita: stream ended before [DONE] or finish_reason") + } + + endOfStream := "[DONE]" + if err := sender(&endOfStream, nil); err != nil { + return err + } + + return nil +} + +// ListModels returns the list of model ids visible to the API key. +func (n *NovitaModel) ListModels(apiConfig *APIConfig) ([]string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := n.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, n.URLSuffix.Models) + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := n.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + data, ok := result["data"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid models list format") + } + + models := make([]string, 0) + for _, model := range data { + modelMap, ok := model.(map[string]interface{}) + if !ok { + continue + } + modelName, ok := modelMap["id"].(string) + if !ok { + continue + } + models = append(models, modelName) + } + return models, nil +} + +// CheckConnection runs a lightweight ListModels call to verify the API key. +func (n *NovitaModel) CheckConnection(apiConfig *APIConfig) error { + _, err := n.ListModels(apiConfig) + return err +} + +type novitaEmbeddingData struct { + Embedding []float64 `json:"embedding"` + Object string `json:"object"` + Index int `json:"index"` +} + +type novitaEmbeddingResponse struct { + Data []novitaEmbeddingData `json:"data"` + Model string `json:"model"` + Object string `json:"object"` +} + +// Embed turns a list of texts into embedding vectors using the Novita +// /v3/embeddings endpoint. The output has one vector per input, in the +// same order the inputs were given. +func (n *NovitaModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := n.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, n.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := n.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Novita embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed novitaEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + embeddings := make([]EmbeddingData, len(texts)) + filled := make([]bool, len(texts)) + for _, item := range parsed.Data { + if item.Index < 0 || item.Index >= len(texts) { + return nil, fmt.Errorf("novita: response index %d out of range for %d inputs", item.Index, len(texts)) + } + if filled[item.Index] { + return nil, fmt.Errorf("novita: duplicate embedding index %d in response", item.Index) + } + embeddings[item.Index] = EmbeddingData{ + Embedding: item.Embedding, + Index: item.Index, + } + filled[item.Index] = true + } + for i, ok := range filled { + if !ok { + return nil, fmt.Errorf("novita: missing embedding for input index %d", i) + } + } + + return embeddings, nil +} + +// Rerank is not exposed by the Novita API. +func (n *NovitaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s, no such method", n.Name()) +} + +// Balance is not exposed by the Novita API. +func (n *NovitaModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("%s, no such method", n.Name()) +} + +func (n *NovitaModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", n.Name()) +} + +func (n *NovitaModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", n.Name()) +} + +func (n *NovitaModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", n.Name()) +} + +func (n *NovitaModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", n.Name()) +} + +// OCRFile OCR file +func (n *NovitaModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", n.Name()) +} + +// ParseFile parse file +func (z *NovitaModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *NovitaModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *NovitaModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/novita_test.go b/internal/entity/models/novita_test.go new file mode 100644 index 00000000000..29cbdace18c --- /dev/null +++ b/internal/entity/models/novita_test.go @@ -0,0 +1,776 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newNovitaServer(t *testing.T, expectedPath string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != expectedPath { + t.Errorf("expected path=%s, got %s", expectedPath, r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("expected Authorization=Bearer test-key, got %q", got) + return + } + // Content-Type must declare JSON on BOTH POST chat (body is + // JSON) and GET ListModels (Novita platform expects callers to + // negotiate JSON content even though the body is empty — + // maintainer review explicitly flagged the missing header). + if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") { + t.Errorf("expected Content-Type to start with application/json, got %q", got) + return + } + if r.Method == http.MethodPost { + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("unmarshal: %v\nraw=%s", err, string(raw)) + return + } + handler(t, body, w) + return + } + handler(t, nil, w) + })) +} + +func newNovitaForTest(baseURL string) *NovitaModel { + return NewNovitaModel( + map[string]string{"default": baseURL}, + URLSuffix{Chat: "openai/v1/chat/completions", Models: "openai/v1/models"}, + ) +} + +// newNovitaSSEServer asserts the SSE-chat wire contract (POST, path, +// Authorization, Content-Type) the same way newNovitaServer does for +// the JSON-chat path, then writes the supplied SSE payload. Closes +// the gap CodeRabbit flagged where streaming tests used +// httptest.NewServer directly and skipped the request-shape checks. +func newNovitaSSEServer(t *testing.T, expectedPath, ssePayload string) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + return + } + if r.URL.Path != expectedPath { + t.Errorf("expected path=%s, got %s", expectedPath, r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("expected Authorization=Bearer test-key, got %q", got) + return + } + if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") { + t.Errorf("expected Content-Type to start with application/json, got %q", got) + return + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, ssePayload) + })) +} + +// ---- think-tag split helpers ---- + +func TestSplitNovitaThinkPureText(t *testing.T) { + v, r := splitNovitaThink("hello world") + if v != "hello world" || r != "" { + t.Errorf("got (%q,%q)", v, r) + } +} + +func TestSplitNovitaThinkSingleBlock(t *testing.T) { + v, r := splitNovitaThink("15% = 0.15. 0.15*80 = 12.The answer is 12.") + if v != "The answer is 12." { + t.Errorf("visible=%q", v) + } + if r != "15% = 0.15. 0.15*80 = 12." { + t.Errorf("reasoning=%q", r) + } +} + +func TestSplitNovitaThinkLeadingText(t *testing.T) { + v, r := splitNovitaThink("intro thoughttail") + if v != "intro tail" { + t.Errorf("visible=%q", v) + } + if r != "thought" { + t.Errorf("reasoning=%q", r) + } +} + +func TestSplitNovitaThinkMultipleBlocks(t *testing.T) { + v, r := splitNovitaThink("Apart1Bpart2") + if v != "part1part2" { + t.Errorf("visible=%q", v) + } + if r != "AB" { + t.Errorf("reasoning=%q", r) + } +} + +func TestSplitNovitaThinkUnclosedTag(t *testing.T) { + // Unclosed -> everything after the open tag is reasoning; + // content stops at the open tag. This matches a real upstream that + // got cut off mid-reasoning by max_tokens. + v, r := splitNovitaThink("visible still thinking when tokens ran out") + if v != "visible " { + t.Errorf("visible=%q", v) + } + if r != "still thinking when tokens ran out" { + t.Errorf("reasoning=%q", r) + } +} + +// ---- streaming extractor ---- + +// Helper to push multiple chunks through and concatenate all output by +// kind. Each chunk goes into Feed; the output is what's safe to emit. +func feedAll(e *novitaThinkExtractor, chunks []string) (content, reasoning string) { + var cb, rb strings.Builder + for _, c := range chunks { + for _, seg := range e.Feed(c) { + cb.WriteString(seg.content) + rb.WriteString(seg.reasoning) + } + } + if seg := e.Flush(); seg != nil { + cb.WriteString(seg.content) + rb.WriteString(seg.reasoning) + } + return cb.String(), rb.String() +} + +func TestNovitaThinkExtractorSingleChunk(t *testing.T) { + e := &novitaThinkExtractor{} + c, r := feedAll(e, []string{"hello thought world"}) + if c != "hello world" { + t.Errorf("content=%q", c) + } + if r != "thought" { + t.Errorf("reasoning=%q", r) + } +} + +func TestNovitaThinkExtractorTagSpansChunks(t *testing.T) { + // "" split across two SSE deltas: "" + e := &novitaThinkExtractor{} + c, r := feedAll(e, []string{"hello thoughttail"}) + if c != "hello tail" { + t.Errorf("content=%q", c) + } + if r != "thought" { + t.Errorf("reasoning=%q", r) + } +} + +func TestNovitaThinkExtractorClosingTagSpansChunks(t *testing.T) { + // "" split across two deltas + e := &novitaThinkExtractor{} + c, r := feedAll(e, []string{"reasoningvisible"}) + if c != "visible" { + t.Errorf("content=%q", c) + } + if r != "reasoning" { + t.Errorf("reasoning=%q", r) + } +} + +func TestNovitaThinkExtractorTokenBoundaries(t *testing.T) { + // Simulate the kind of chunking we saw on the wire for qwen3 — many + // small chunks, sometimes splitting tag bytes. + e := &novitaThinkExtractor{} + c, r := feedAll(e, []string{ + "<", "think>", "Ok", "ay, ", "compute. ", "12", "."}) + if c != "12." { + t.Errorf("content=%q", c) + } + if r != "Okay, compute. " { + t.Errorf("reasoning=%q", r) + } +} + +func TestNovitaThinkExtractorNoTags(t *testing.T) { + e := &novitaThinkExtractor{} + c, r := feedAll(e, []string{"plain ", "content ", "all ", "the way"}) + if c != "plain content all the way" { + t.Errorf("content=%q", c) + } + if r != "" { + t.Errorf("reasoning=%q", r) + } +} + +func TestNovitaThinkExtractorLessThanIsNotTagStart(t *testing.T) { + // "<10" or "... embedded in content. + // Driver must split it into Answer + ReasonContent. + srv := newNovitaServer(t, "/openai/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{ + "role": "assistant", + "content": "15% = 0.15; 0.15 * 80 = 12.The answer is 12.", + }, + }}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + resp, err := newNovitaForTest(srv.URL).ChatWithMessages( + "qwen/qwen3-30b-a3b-fp8", + []Message{{Role: "user", Content: "15% of 80?"}}, + &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if *resp.Answer != "The answer is 12." { + t.Errorf("Answer=%q", *resp.Answer) + } + if *resp.ReasonContent != "15% = 0.15; 0.15 * 80 = 12." { + t.Errorf("ReasonContent=%q", *resp.ReasonContent) + } +} + +// deepseek-v3.1 / glm-4.5 with enable_thinking=true return reasoning +// in a separate `reasoning_content` field on the message rather than +// inline as .... The driver must surface this field +// to ChatResponse.ReasonContent. Live-confirmed against +// api.novita.ai/openai/v1/chat/completions with deepseek/deepseek-v3.1. +func TestNovitaChatExtractsReasoningContentField(t *testing.T) { + srv := newNovitaServer(t, "/openai/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{ + "role": "assistant", + "content": "4", + "reasoning_content": "2+2 is straightforward arithmetic: the answer is 4.", + }, + }}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + resp, err := newNovitaForTest(srv.URL).ChatWithMessages( + "deepseek/deepseek-v3.1", + []Message{{Role: "user", Content: "2+2?"}}, + &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if resp.Answer == nil || resp.ReasonContent == nil { + t.Fatalf("Answer/ReasonContent must be non-nil pointers") + } + if *resp.Answer != "4" { + t.Errorf("Answer=%q", *resp.Answer) + } + if *resp.ReasonContent != "2+2 is straightforward arithmetic: the answer is 4." { + t.Errorf("ReasonContent=%q", *resp.ReasonContent) + } +} + +// Streaming deepseek-v3.1 with thinking on emits delta.reasoning_content +// (not delta.content with tags). The driver must forward +// those chunks via the sender's second arg. +func TestNovitaStreamExtractsDeltaReasoningContent(t *testing.T) { + srv := newNovitaSSEServer(t, "/openai/v1/chat/completions", + `data: {"choices":[{"index":0,"delta":{"role":"assistant","reasoning_content":"step 1. "}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"reasoning_content":"step 2."}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"content":"final answer"},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + defer srv.Close() + + apiKey := "test-key" + var content, reasoning []string + err := newNovitaForTest(srv.URL).ChatStreamlyWithSender( + "deepseek/deepseek-v3.1", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(c *string, r *string) error { + if c != nil && r != nil { + t.Errorf("sender called with both args non-nil") + } + if r != nil && *r != "" { + reasoning = append(reasoning, *r) + } + if c != nil && *c != "" && *c != "[DONE]" { + content = append(content, *c) + } + return nil + }) + if err != nil { + t.Fatalf("stream: %v", err) + } + if got := strings.Join(reasoning, ""); got != "step 1. step 2." { + t.Errorf("reasoning=%q", got) + } + if got := strings.Join(content, ""); got != "final answer" { + t.Errorf("content=%q", got) + } +} + +// TestNovitaChatPropagatesEnableThinking pins the maintainer's +// requested behaviour: when ChatConfig.Thinking is set, the driver +// MUST forward it as Novita's documented `enable_thinking` body field +// so a tenant can switch a deepseek-v3.1 / glm-4.5 / qwen3 deployment +// out of its default thinking mode without prompt-level hacks. +func TestNovitaChatPropagatesEnableThinking(t *testing.T) { + cases := []struct { + name string + value bool + }{ + {"enabled", true}, + {"disabled", false}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + srv := newNovitaServer(t, "/openai/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + got, present := body["enable_thinking"] + if !present { + t.Errorf("enable_thinking missing from body, want %v", tc.value) + } + if got != tc.value { + t.Errorf("enable_thinking=%v want %v", got, tc.value) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{"content": "ok"}, + }}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + thinking := tc.value + _, err := newNovitaForTest(srv.URL).ChatWithMessages( + "qwen/qwen3-30b-a3b-fp8", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{Thinking: &thinking}) + if err != nil { + t.Fatalf("Chat: %v", err) + } + }) + } +} + +// Sending enable_thinking when the caller didn't ask for it would +// silently flip behavior for downstream proxies that distinguish +// "field absent" from "field present with default". Leave it out. +func TestNovitaChatOmitsEnableThinkingWhenUnset(t *testing.T) { + srv := newNovitaServer(t, "/openai/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if _, present := body["enable_thinking"]; present { + t.Errorf("enable_thinking must be absent when Thinking unset, got %v", body["enable_thinking"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{"content": "ok"}, + }}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + _, err := newNovitaForTest(srv.URL).ChatWithMessages("m", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{}) // no Thinking + if err != nil { + t.Fatalf("Chat: %v", err) + } +} + +// TestNovitaStreamPropagatesEnableThinking mirrors the non-stream +// case for ChatStreamlyWithSender so callers get the same toggle +// regardless of streaming mode. +func TestNovitaStreamPropagatesEnableThinking(t *testing.T) { + var seen map[string]interface{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(raw, &seen) + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, + `data: {"choices":[{"delta":{"content":"ok"},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + })) + defer srv.Close() + + apiKey := "test-key" + thinking := false + err := newNovitaForTest(srv.URL).ChatStreamlyWithSender( + "deepseek/deepseek-v3.1", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{Thinking: &thinking}, + func(*string, *string) error { return nil }) + if err != nil { + t.Fatalf("Stream: %v", err) + } + if got, ok := seen["enable_thinking"].(bool); !ok || got != false { + t.Errorf("stream enable_thinking=%v want false", seen["enable_thinking"]) + } +} + +func TestNovitaChatRequiresAPIKey(t *testing.T) { + _, err := newNovitaForTest("http://unused").ChatWithMessages("m", + []Message{{Role: "user", Content: "x"}}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("got %v", err) + } +} + +func TestNovitaChatRequiresMessages(t *testing.T) { + apiKey := "test-key" + _, err := newNovitaForTest("http://unused").ChatWithMessages("m", nil, + &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "messages is empty") { + t.Errorf("got %v", err) + } +} + +func TestNovitaChatRejectsHTTPError(t *testing.T) { + srv := newNovitaServer(t, "/openai/v1/chat/completions", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"detail":"unauthorized"}`)) + }) + defer srv.Close() + + apiKey := "test-key" + _, err := newNovitaForTest(srv.URL).ChatWithMessages("m", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "401") { + t.Errorf("got %v", err) + } +} + +// Streaming: an SSE stream that emits ... inline in +// delta.content must surface reasoning chunks through the sender's +// second arg, and visible content through the first. +func TestNovitaStreamSplitsThinkTags(t *testing.T) { + // Simulate the realistic case where tags span deltas — split + // "" across two chunks, and split "" too. + srv := newNovitaSSEServer(t, "/openai/v1/chat/completions", + `data: {"choices":[{"index":0,"delta":{"role":"assistant"}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"content":"<"}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"content":"think>"}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"content":"Okay, "}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"content":"compute. "}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"content":"12"}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"content":"."},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + defer srv.Close() + + apiKey := "test-key" + var content, reasoning []string + err := newNovitaForTest(srv.URL).ChatStreamlyWithSender( + "qwen/qwen3-30b-a3b-fp8", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(c *string, r *string) error { + if c != nil && r != nil { + t.Errorf("sender called with both args non-nil") + } + if r != nil && *r != "" { + reasoning = append(reasoning, *r) + } + if c != nil && *c != "" && *c != "[DONE]" { + content = append(content, *c) + } + return nil + }) + if err != nil { + t.Fatalf("stream: %v", err) + } + gotContent := strings.Join(content, "") + gotReason := strings.Join(reasoning, "") + if gotContent != "12." { + t.Errorf("content=%q want %q", gotContent, "12.") + } + if gotReason != "Okay, compute. " { + t.Errorf("reasoning=%q", gotReason) + } +} + +// Streaming for a non-reasoning model that emits only content chunks +// must continue to work unchanged. +func TestNovitaStreamPureContent(t *testing.T) { + srv := newNovitaSSEServer(t, "/openai/v1/chat/completions", + `data: {"choices":[{"index":0,"delta":{"role":"assistant"}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"content":"Hello "}}]}`+"\n"+ + `data: {"choices":[{"index":0,"delta":{"content":"world"},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + defer srv.Close() + + apiKey := "test-key" + var chunks []string + var sawDone bool + err := newNovitaForTest(srv.URL).ChatStreamlyWithSender("meta-llama/llama-3.3-70b-instruct", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(c *string, _ *string) error { + if c == nil { + return nil + } + if *c == "[DONE]" { + sawDone = true + return nil + } + chunks = append(chunks, *c) + return nil + }) + if err != nil { + t.Fatalf("stream: %v", err) + } + if strings.Join(chunks, "") != "Hello world" { + t.Errorf("content=%v", chunks) + } + if !sawDone { + t.Error("expected [DONE] sentinel") + } +} + +func TestNovitaStreamRequiresSender(t *testing.T) { + apiKey := "test-key" + err := newNovitaForTest("http://unused").ChatStreamlyWithSender("m", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, nil) + if err == nil || !strings.Contains(err.Error(), "sender is required") { + t.Errorf("got %v", err) + } +} + +func TestNovitaStreamRejectsExplicitFalse(t *testing.T) { + apiKey := "test-key" + stream := false + err := newNovitaForTest("http://unused").ChatStreamlyWithSender("m", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{Stream: &stream}, + func(*string, *string) error { return nil }) + if err == nil || !strings.Contains(err.Error(), "stream must be true") { + t.Errorf("got %v", err) + } +} + +func TestNovitaListModelsHappyPath(t *testing.T) { + srv := newNovitaServer(t, "/openai/v1/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"id": "meta-llama/llama-3.3-70b-instruct"}, + {"id": "qwen/qwen3-30b-a3b-fp8"}, + {"id": "deepseek/deepseek-v4-pro"}, + }, + }) + }) + defer srv.Close() + + apiKey := "test-key" + ids, err := newNovitaForTest(srv.URL).ListModels(&APIConfig{ApiKey: &apiKey}) + if err != nil { + t.Fatalf("ListModels: %v", err) + } + if len(ids) != 3 { + t.Errorf("ids=%v", ids) + } +} + +func TestNovitaCheckConnection(t *testing.T) { + srv := newNovitaServer(t, "/openai/v1/models", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{"data": []map[string]interface{}{{"id": "x"}}}) + }) + defer srv.Close() + + apiKey := "test-key" + if err := newNovitaForTest(srv.URL).CheckConnection(&APIConfig{ApiKey: &apiKey}); err != nil { + t.Errorf("CheckConnection: %v", err) + } +} + +func TestNovitaEmbedReturnsNoSuchMethod(t *testing.T) { + m := "x" + _, err := newNovitaForTest("http://unused").Embed(&m, []string{"a"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("got %v", err) + } +} + +func TestNovitaRerankReturnsNoSuchMethod(t *testing.T) { + m := "x" + _, err := newNovitaForTest("http://unused").Rerank(&m, "q", []string{"a"}, &APIConfig{}, &RerankConfig{TopN: 1}) + if err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("got %v", err) + } +} + +func TestNovitaBalanceReturnsNoSuchMethod(t *testing.T) { + if _, err := newNovitaForTest("http://unused").Balance(&APIConfig{}); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("got %v", err) + } +} + +func TestNovitaAudioOCRReturnNoSuchMethod(t *testing.T) { + m := "x" + v := newNovitaForTest("http://unused") + if _, err := v.TranscribeAudio(&m, &m, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("TranscribeAudio: %v", err) + } + if _, err := v.AudioSpeech(&m, &m, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("AudioSpeech: %v", err) + } + if _, err := v.OCRFile(&m, nil, &m, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("OCRFile: %v", err) + } +} + +// TestNovitaBaseURLTrimsTrailingSlash pins the fix for a `//`-in-path +// bug a tenant could hit by configuring a baseURL like +// "https://api.novita.ai/v3/openai/". Every URL the driver builds via +// fmt.Sprintf("%s/%s", base, suffix) would then produce a double +// slash. baseURLForRegion now trims the trailing "/" so all three +// endpoint builders (Chat, Stream, ListModels) emit clean paths. +func TestNovitaBaseURLTrimsTrailingSlash(t *testing.T) { + cases := []struct { + name string + path string + method string + invoke func(n *NovitaModel, apiKey string) error + urlSuffix URLSuffix + respBody string + respHeaders map[string]string + }{ + { + name: "Chat", + path: "/openai/v1/chat/completions", + method: http.MethodPost, + urlSuffix: URLSuffix{Chat: "openai/v1/chat/completions"}, + invoke: func(n *NovitaModel, apiKey string) error { + _, err := n.ChatWithMessages("m", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil) + return err + }, + respBody: `{"choices":[{"message":{"content":"ok"}}]}`, + }, + { + name: "ListModels", + path: "/openai/v1/models", + method: http.MethodGet, + urlSuffix: URLSuffix{Models: "openai/v1/models"}, + invoke: func(n *NovitaModel, apiKey string) error { + _, err := n.ListModels(&APIConfig{ApiKey: &apiKey}) + return err + }, + respBody: `{"data":[]}`, + }, + { + name: "Stream", + path: "/openai/v1/chat/completions", + method: http.MethodPost, + urlSuffix: URLSuffix{Chat: "openai/v1/chat/completions"}, + invoke: func(n *NovitaModel, apiKey string) error { + return n.ChatStreamlyWithSender("m", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(*string, *string) error { return nil }) + }, + respBody: `data: {"choices":[{"index":0,"delta":{"content":"hi"},"finish_reason":"stop"}]}` + "\n" + + `data: [DONE]` + "\n", + respHeaders: map[string]string{"Content-Type": "text/event-stream"}, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // The load-bearing assertion: path is the clean + // "/openai/v1/chat/completions" or "/openai/v1/models", never "//chat/...". + if r.URL.Path != tc.path { + t.Errorf("path=%q want %q (double-slash bug?)", r.URL.Path, tc.path) + return + } + if r.Method != tc.method { + t.Errorf("method=%s want %s", r.Method, tc.method) + return + } + for k, v := range tc.respHeaders { + w.Header().Set(k, v) + } + _, _ = io.WriteString(w, tc.respBody) + })) + defer srv.Close() + + // Configure baseURL WITH a trailing slash on purpose. + n := NewNovitaModel( + map[string]string{"default": srv.URL + "/"}, + tc.urlSuffix, + ) + apiKey := "test-key" + if err := tc.invoke(n, apiKey); err != nil { + t.Fatalf("%s: %v", tc.name, err) + } + }) + } +} diff --git a/internal/entity/models/nvidia.go b/internal/entity/models/nvidia.go index 4fd6a9b3206..998196253c5 100644 --- a/internal/entity/models/nvidia.go +++ b/internal/entity/models/nvidia.go @@ -3,6 +3,7 @@ package models import ( "bufio" "bytes" + "context" "encoding/json" "fmt" "io" @@ -329,12 +330,254 @@ func (n *NvidiaModel) ChatStreamlyWithSender(modelName string, messages []Messag return scanner.Err() } -func (n NvidiaModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { - return nil, fmt.Errorf("no such method") +type nvidiaEmbeddingResponse struct { + Data []struct { + Index int `json:"index"` + Embedding []float64 `json:"embedding"` + } `json:"data"` } +func (n NvidiaModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := n.BaseURL[region] + if baseURL == "" { + baseURL = n.BaseURL["default"] + } + if baseURL == "" { + return nil, fmt.Errorf("nvidia: no base URL configured for region %q", region) + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), n.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + "input_type": "query", + "encoding_format": "float", + "truncate": "END", + } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := n.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Nvidia embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed nvidiaEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + var embeddings []EmbeddingData + for _, dataElem := range parsed.Data { + var embeddingData EmbeddingData + embeddingData.Embedding = dataElem.Embedding + embeddingData.Index = dataElem.Index + embeddings = append(embeddings, embeddingData) + } + + return embeddings, nil +} + +// nvidiaRerankRequest mirrors the NIM /ranking request shape: +// query is an object with a "text" field, passages is an array of +// objects each with a "text" field. truncate=END matches the Python +// NvidiaRerank reference at rag/llm/rerank_model.py. +type nvidiaRerankRequest struct { + Model string `json:"model"` + Query nvidiaRerankText `json:"query"` + Passages []nvidiaRerankText `json:"passages"` + Truncate string `json:"truncate,omitempty"` + TopN int `json:"top_n"` +} + +type nvidiaRerankText struct { + Text string `json:"text"` +} + +// nvidiaRerankResponse maps the NIM rankings array. Each entry pairs +// the original passage index with a logit score; the caller uses the +// index to restore original input order. +type nvidiaRerankResponse struct { + Rankings []struct { + Index int `json:"index"` + Logit float64 `json:"logit"` + } `json:"rankings"` +} + +// Rerank scores documents against the query using an NVIDIA NIM +// reranking model. Mirrors the Python NvidiaRerank class in +// rag/llm/rerank_model.py for payload shape (passages/query/logit). +// Defaults top_n to len(documents) so the API returns a score per +// input; callers may shrink it via RerankConfig.TopN, in which case +// only the top RerankConfig.TopN entries come back. Returned +// RerankResult entries are in the API's ranking order; callers that +// need original-input order should sort by Index. Same return-shape +// contract as the Aliyun and ZhipuAI Rerank drivers. func (n NvidiaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { - return nil, fmt.Errorf("no such method") + if len(documents) == 0 { + return &RerankResponse{}, nil + } + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := n.BaseURL[region] + if baseURL == "" { + baseURL = n.BaseURL["default"] + } + if baseURL == "" { + return nil, fmt.Errorf("nvidia: no base URL configured for region %q", region) + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), n.URLSuffix.Rerank) + + topN := len(documents) + if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN { + topN = rerankConfig.TopN + } + + passages := make([]nvidiaRerankText, len(documents)) + for i, doc := range documents { + passages[i] = nvidiaRerankText{Text: doc} + } + + reqBody := nvidiaRerankRequest{ + Model: *modelName, + Query: nvidiaRerankText{Text: query}, + Passages: passages, + Truncate: "END", + TopN: topN, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := n.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Nvidia rerank API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed nvidiaRerankResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + rerankResponse := RerankResponse{Data: make([]RerankResult, 0, len(parsed.Rankings))} + for _, r := range parsed.Rankings { + if r.Index < 0 || r.Index >= len(documents) { + return nil, fmt.Errorf("unexpected rerank index %d for %d inputs", r.Index, len(documents)) + } + rerankResponse.Data = append(rerankResponse.Data, RerankResult{ + Index: r.Index, + RelevanceScore: r.Logit, + }) + } + + return &rerankResponse, nil +} + +// TranscribeAudio transcribe audio +func (n *NvidiaModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", n.Name()) +} + +func (z *NvidiaModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (n *NvidiaModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", n.Name()) +} + +func (z *NvidiaModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *NvidiaModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +// ParseFile parse file +func (z *NvidiaModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) } // ListModels calls /v1/models on the configured NVIDIA NIM base URL @@ -419,3 +662,11 @@ func (n NvidiaModel) CheckConnection(apiConfig *APIConfig) error { _, err := n.ListModels(apiConfig) return err } + +func (z *NvidiaModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *NvidiaModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/nvidia_rerank_test.go b/internal/entity/models/nvidia_rerank_test.go new file mode 100644 index 00000000000..c92249bfbb6 --- /dev/null +++ b/internal/entity/models/nvidia_rerank_test.go @@ -0,0 +1,195 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newNvidiaRerankServer(t *testing.T, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + // Use t.Errorf + return inside the handler goroutine; t.Fatalf would + // only Goexit the handler goroutine and the test would silently pass. + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + return + } + if r.URL.Path != "/ranking" { + t.Errorf("expected path=/ranking, got %s", r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("expected Authorization=Bearer test-key, got %q", got) + return + } + if got := r.Header.Get("Content-Type"); got != "application/json" { + t.Errorf("expected Content-Type=application/json, got %q", got) + return + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("failed to read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("invalid JSON body: %v\n%s", err, string(raw)) + return + } + handler(t, body, w) + })) +} + +func newNvidiaModelForTest(baseURL string) *NvidiaModel { + return NewNvidiaModel( + map[string]string{"default": baseURL}, + URLSuffix{Rerank: "ranking"}, + ) +} + +func TestNvidiaRerankHappyPath(t *testing.T) { + srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "nvidia/nv-rerankqa-mistral-4b-v3" { + t.Errorf("expected model=nvidia/nv-rerankqa-mistral-4b-v3, got %v", body["model"]) + } + query, ok := body["query"].(map[string]interface{}) + if !ok || query["text"] != "What is RAPTOR?" { + t.Errorf("expected query.text=What is RAPTOR?, got %v", body["query"]) + } + passages, ok := body["passages"].([]interface{}) + if !ok || len(passages) != 3 { + t.Errorf("expected 3 passages, got %v", body["passages"]) + return + } + if body["truncate"] != "END" { + t.Errorf("expected truncate=END, got %v", body["truncate"]) + } + if body["top_n"] != float64(3) { + t.Errorf("expected top_n=3 (matching len(documents)), got %v", body["top_n"]) + } + // Return rankings out of input order to verify Index preservation. + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "rankings": []map[string]interface{}{ + {"index": 2, "logit": 9.5}, + {"index": 0, "logit": 4.25}, + {"index": 1, "logit": 7.8}, + }, + }) + }) + defer srv.Close() + + model := newNvidiaModelForTest(srv.URL) + apiKey := "test-key" + modelName := "nvidia/nv-rerankqa-mistral-4b-v3" + resp, err := model.Rerank( + &modelName, + "What is RAPTOR?", + []string{"doc-zero", "doc-one", "doc-two"}, + &APIConfig{ApiKey: &apiKey}, + &RerankConfig{}, + ) + if err != nil { + t.Fatalf("Rerank failed: %v", err) + } + if len(resp.Data) != 3 { + t.Fatalf("expected 3 results, got %d", len(resp.Data)) + } + want := map[int]float64{0: 4.25, 1: 7.8, 2: 9.5} + for _, r := range resp.Data { + if got, ok := want[r.Index]; !ok || got != r.RelevanceScore { + t.Errorf("unexpected result Index=%d RelevanceScore=%v", r.Index, r.RelevanceScore) + } + } +} + +func TestNvidiaRerankTopNClamp(t *testing.T) { + srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["top_n"] != float64(2) { + t.Errorf("expected top_n clamp to RerankConfig.TopN=2, got %v", body["top_n"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{"rankings": []map[string]interface{}{}}) + }) + defer srv.Close() + + model := newNvidiaModelForTest(srv.URL) + apiKey := "test-key" + modelName := "nvidia/nv-rerankqa-mistral-4b-v3" + if _, err := model.Rerank( + &modelName, "q", + []string{"a", "b", "c", "d"}, + &APIConfig{ApiKey: &apiKey}, + &RerankConfig{TopN: 2}, + ); err != nil { + t.Fatalf("Rerank failed: %v", err) + } +} + +func TestNvidiaRerankEmptyDocuments(t *testing.T) { + model := newNvidiaModelForTest("http://unused") + apiKey := "test-key" + modelName := "nvidia/nv-rerankqa-mistral-4b-v3" + resp, err := model.Rerank(&modelName, "q", nil, &APIConfig{ApiKey: &apiKey}, &RerankConfig{}) + if err != nil { + t.Fatalf("expected nil error for empty documents, got %v", err) + } + if len(resp.Data) != 0 { + t.Errorf("expected empty Data, got %d entries", len(resp.Data)) + } +} + +func TestNvidiaRerankRequiresAPIKey(t *testing.T) { + model := newNvidiaModelForTest("http://unused") + modelName := "nvidia/nv-rerankqa-mistral-4b-v3" + _, err := model.Rerank(&modelName, "q", []string{"a"}, &APIConfig{}, &RerankConfig{}) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("expected api-key error, got %v", err) + } +} + +func TestNvidiaRerankRequiresModelName(t *testing.T) { + model := newNvidiaModelForTest("http://unused") + apiKey := "test-key" + _, err := model.Rerank(nil, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{}) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("expected model-name error, got %v", err) + } +} + +func TestNvidiaRerankRejectsHTTPError(t *testing.T) { + srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"error":"unauthorized"}`)) + }) + defer srv.Close() + + model := newNvidiaModelForTest(srv.URL) + apiKey := "test-key" + modelName := "nvidia/nv-rerankqa-mistral-4b-v3" + _, err := model.Rerank(&modelName, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{}) + if err == nil || !strings.Contains(err.Error(), "Nvidia rerank API error") { + t.Errorf("expected API error, got %v", err) + } +} + +func TestNvidiaRerankRejectsOutOfRangeIndex(t *testing.T) { + srv := newNvidiaRerankServer(t, func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "rankings": []map[string]interface{}{ + {"index": 5, "logit": 1.0}, // out of range for 2-input request + }, + }) + }) + defer srv.Close() + + model := newNvidiaModelForTest(srv.URL) + apiKey := "test-key" + modelName := "nvidia/nv-rerankqa-mistral-4b-v3" + _, err := model.Rerank(&modelName, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{}) + if err == nil || !strings.Contains(err.Error(), "unexpected rerank index") { + t.Errorf("expected out-of-range error, got %v", err) + } +} diff --git a/internal/entity/models/ollama.go b/internal/entity/models/ollama.go index 4e8e42ad0de..d95e9e8c734 100644 --- a/internal/entity/models/ollama.go +++ b/internal/entity/models/ollama.go @@ -3,6 +3,7 @@ package models import ( "bufio" "bytes" + "context" "encoding/json" "fmt" "io" @@ -359,14 +360,119 @@ func (o *OllamaModel) ChatStreamlyWithSender(modelName string, messages []Messag return scanner.Err() } -func (o *OllamaModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { - return nil, fmt.Errorf("no such method") +func (o *OllamaModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := o.BaseURL[region] + if baseURL == "" { + baseURL = o.BaseURL["default"] + } + if baseURL == "" { + return nil, fmt.Errorf("missing base URL: please configure the local access address for Ollama (e.g., http://127.0.0.1:11434/v1)") + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), o.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + } + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Ollama embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed openaiEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + var embeddings []EmbeddingData + for _, dataElem := range parsed.Data { + var embeddingData EmbeddingData + embeddingData.Embedding = dataElem.Embedding + embeddingData.Index = dataElem.Index + embeddings = append(embeddings, embeddingData) + } + + return embeddings, nil } func (o *OllamaModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("no such method") } +// TranscribeAudio transcribe audio +func (o *OllamaModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *OllamaModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (o *OllamaModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *OllamaModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *OllamaModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +// ParseFile parse file +func (z *OllamaModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + func (o *OllamaModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" @@ -445,3 +551,11 @@ func (o *OllamaModel) CheckConnection(apiConfig *APIConfig) error { _, err := o.ListModels(apiConfig) return err } + +func (z *OllamaModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *OllamaModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/openai.go b/internal/entity/models/openai.go index 1adbb35cbc0..c19751231e3 100644 --- a/internal/entity/models/openai.go +++ b/internal/entity/models/openai.go @@ -403,12 +403,105 @@ func (z *OpenAIModel) ChatStreamlyWithSender(modelName string, messages []Messag return nil } -// Encode encodes a list of texts into embeddings. OpenAI does expose -// embedding endpoints (text-embedding-3-* and text-embedding-ada-002), -// but this initial driver intentionally leaves embedding support -// unimplemented. A follow-up PR can add it. -func (z *OpenAIModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { - return nil, fmt.Errorf("not implemented") +type openaiEmbeddingResponse struct { + Data []openrouterEmbeddingData `json:"data"` + Model string `json:"model"` + Object string `json:"object"` + Usage openrouterUsage `json:"usage"` +} + +type openaiEmbeddingData struct { + Embedding []float64 `json:"embedding"` + Object string `json:"object"` + Index int `json:"index"` +} + +type openaiUsage struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// Embed turns a list of texts into embedding vectors using the +// OpenAI /v1/embeddings endpoint (e.g. text-embedding-3-small, +// text-embedding-3-large, text-embedding-ada-002). The output has +// one vector per input, in the same order the inputs were given. +func (z *OpenAIModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := z.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, z.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("OpenAI embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed openaiEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + var embeddings []EmbeddingData + for _, dataElem := range parsed.Data { + var embeddingData EmbeddingData + embeddingData.Embedding = dataElem.Embedding + embeddingData.Index = dataElem.Index + embeddings = append(embeddings, embeddingData) + } + + return embeddings, nil } // ListModels returns the list of model ids visible to the API key. @@ -500,3 +593,39 @@ func (z *OpenAIModel) CheckConnection(apiConfig *APIConfig) error { func (z *OpenAIModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) } + +// TranscribeAudio transcribe audio +func (o *OpenAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *OpenAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (o *OpenAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *OpenAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *OpenAIModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +// ParseFile parse file +func (z *OpenAIModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *OpenAIModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *OpenAIModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/openrouter.go b/internal/entity/models/openrouter.go index 505af9ee6ac..c42493d0109 100644 --- a/internal/entity/models/openrouter.go +++ b/internal/entity/models/openrouter.go @@ -109,8 +109,10 @@ func (o *OpenRouterModel) ChatWithMessages(modelName string, messages []Message, reqBody["do_sample"] = *chatModelConfig.DoSample } - reqBody["reasoning"] = map[string]interface{}{ - "effort": "low", + if chatModelConfig.Effort != nil { + reqBody["reasoning"] = map[string]interface{}{ + "effort": chatModelConfig.Effort, + } } } @@ -351,9 +353,30 @@ func (o *OpenRouterModel) ChatStreamlyWithSender(modelName string, messages []Me return scanner.Err() } -func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +type openrouterEmbeddingResponse struct { + Data []openrouterEmbeddingData `json:"data"` + Model string `json:"model"` + Object string `json:"object"` + Usage openrouterUsage `json:"usage"` +} + +type openrouterEmbeddingData struct { + Embedding []float64 `json:"embedding"` + Object string `json:"object"` + Index int `json:"index"` +} + +type openrouterUsage struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` +} + +func (o *OpenRouterModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { if len(texts) == 0 { - return [][]float64{}, nil + return []EmbeddingData{}, nil + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") } var region = "default" @@ -368,6 +391,10 @@ func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *A "input": texts, } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + jsonData, err := json.Marshal(reqBody) if err != nil { return nil, fmt.Errorf("failed to marshal request: %w", err) @@ -398,52 +425,17 @@ func (o *OpenRouterModel) Encode(modelName *string, texts []string, apiConfig *A return nil, fmt.Errorf("OpenRouter embedding API error: status %d, body: %s", resp.StatusCode, string(body)) } - var result map[string]interface{} - if err = json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("failed to decode response: %w", err) - } - - dataObj, ok := result["data"].([]interface{}) - if !ok || len(dataObj) == 0 { - return nil, fmt.Errorf("OpenRouter embedding response contains no data: %s", string(body)) + var parsed openrouterEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) } - embeddings := make([][]float64, len(texts)) - - for _, item := range dataObj { - dataMap, ok := item.(map[string]interface{}) - if !ok { - continue - } - - indexFloat, ok := dataMap["index"].(float64) - if !ok { - continue - } - index := int(indexFloat) - - if index < 0 || index >= len(texts) { - continue - } - - embeddingSlice, ok := dataMap["embedding"].([]interface{}) - if !ok { - continue - } - - embedding := make([]float64, len(embeddingSlice)) - for j, v := range embeddingSlice { - switch val := v.(type) { - case float64: - embedding[j] = val - case float32: - embedding[j] = float64(val) - default: - return nil, fmt.Errorf("unexpected embedding value type") - } - } - - embeddings[index] = embedding + var embeddings []EmbeddingData + for _, dataElem := range parsed.Data { + var embeddingData EmbeddingData + embeddingData.Embedding = dataElem.Embedding + embeddingData.Index = dataElem.Index + embeddings = append(embeddings, embeddingData) } return embeddings, nil @@ -539,6 +531,91 @@ func (o *OpenRouterModel) Rerank(modelName *string, query string, documents []st return &rerankResponse, nil } +// TranscribeAudio transcribe audio +func (o *OpenRouterModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *OpenRouterModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (o *OpenRouterModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("OpenRouter API key is missing") + } + if audioContent == nil || *audioContent == "" { + return nil, fmt.Errorf("text content is empty") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.TTS) + + // OpenRouter:response Audio bytes stream + reqBody := map[string]interface{}{ + "model": modelName, + "input": audioContent, + } + + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + reqBody[key] = value + } + } + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["response_format"] = ttsConfig.Format + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("OpenRouter API error: %s, body: %s", resp.Status, string(body)) + } + + return &TTSResponse{Audio: body}, nil +} + +func (z *OpenRouterModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *OpenRouterModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +// ParseFile parse file +func (z *OpenRouterModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + func (o *OpenRouterModel) ListModels(apiConfig *APIConfig) ([]string, error) { var region = "default" if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { @@ -575,7 +652,7 @@ func (o *OpenRouterModel) ListModels(apiConfig *APIConfig) ([]string, error) { } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API requestssss failed with status %d: %s : %s", resp.StatusCode, string(body), url) + return nil, fmt.Errorf("API request failed with status %d: %s : %s", resp.StatusCode, string(body)) } // Parse response @@ -651,3 +728,11 @@ func (o *OpenRouterModel) CheckConnection(apiConfig *APIConfig) error { _, err := o.Balance(apiConfig) return err } + +func (z *OpenRouterModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *OpenRouterModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/paddleocr.go b/internal/entity/models/paddleocr.go new file mode 100644 index 00000000000..25445ca38da --- /dev/null +++ b/internal/entity/models/paddleocr.go @@ -0,0 +1,300 @@ +package models + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "mime/multipart" + "net/http" + "strings" + "time" +) + +type PaddleOCRModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func NewPaddleOCRModel(baseURL map[string]string, urlSuffix URLSuffix) *PaddleOCRModel { + return &PaddleOCRModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: 120 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DisableCompression: false, + }, + }, + } +} + +func (p PaddleOCRModel) NewInstance(baseURL map[string]string) ModelDriver { + return &PaddleOCRModel{ + BaseURL: baseURL, + URLSuffix: p.URLSuffix, + httpClient: &http.Client{ + Timeout: 120 * time.Second, + Transport: &http.Transport{ + MaxIdleConns: 100, + MaxIdleConnsPerHost: 10, + IdleConnTimeout: 90 * time.Second, + DisableCompression: false, + }, + }, + } +} + +func (p *PaddleOCRModel) Name() string { + return "paddle_ocr" +} + +func (p *PaddleOCRModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + return nil, fmt.Errorf("no such method", p.Name()) +} + +func (p *PaddleOCRModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + return fmt.Errorf("no such method", p.Name()) +} + +func (p *PaddleOCRModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + return nil, fmt.Errorf("no such method", p.Name()) +} + +func (p *PaddleOCRModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("no such method", p.Name()) +} + +func (p *PaddleOCRModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("no such method", p.Name()) +} + +func (p *PaddleOCRModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("no such method", p.Name()) +} + +func (p *PaddleOCRModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("no such method", p.Name()) +} + +func (p *PaddleOCRModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("no such method", p.Name()) +} + +type paddleSubmitResponse struct { + Data struct { + JobId string `json:"jobId"` + } `json:"data"` +} + +type paddlePollResponse struct { + Data struct { + State string `json:"state"` + ErrorMsg string `json:"errorMsg"` + ResultUrl struct { + JsonUrl string `json:"jsonUrl"` + } `json:"resultUrl"` + } `json:"data"` +} + +type paddleJsonlLine struct { + Result struct { + LayoutParsingResults []struct { + Markdown struct { + Text string `json:"text"` + } `json:"markdown"` + } `json:"layoutParsingResults"` + } `json:"result"` +} + +func (p *PaddleOCRModel) OCRFile(modelName *string, content []byte, fileURL *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + if (content == nil || len(content) == 0) && (fileURL == nil || *fileURL == "") { + return nil, fmt.Errorf("content and fileURL cannot be both empty") + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", p.BaseURL[region], p.URLSuffix.OCR) + + optionalPayload := map[string]bool{ + "useDocOrientationClassify": false, + "useDocUnwarping": false, + "useChartRecognition": false, + } + optBytes, _ := json.Marshal(optionalPayload) + + var req *http.Request + var err error + + if fileURL != nil && strings.HasPrefix(*fileURL, "http") { + reqData := map[string]interface{}{ + "fileUrl": *fileURL, + "model": *modelName, + "optionalPayload": optionalPayload, + } + jsonData, err := json.Marshal(reqData) + if err != nil { + return nil, fmt.Errorf("failed to marshal json: %w", err) + } + req, err = http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + req.Header.Set("Content-Type", "application/json") + } else { + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + + _ = writer.WriteField("model", *modelName) + _ = writer.WriteField("optionalPayload", string(optBytes)) + + part, err := writer.CreateFormFile("file", "document.pdf") + if err != nil { + return nil, fmt.Errorf("failed to create form file: %w", err) + } + part.Write(content) + writer.Close() + + req, err = http.NewRequest("POST", url, body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + } + + req.Header.Set("Authorization", fmt.Sprintf("bearer %s", *apiConfig.ApiKey)) + + resp, err := p.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to submit job: %w", err) + } + defer resp.Body.Close() + + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("submit job failed: %s", string(respBody)) + } + + var submitResp paddleSubmitResponse + if err := json.Unmarshal(respBody, &submitResp); err != nil { + return nil, fmt.Errorf("failed to parse submit response: %w", err) + } + + jobId := submitResp.Data.JobId + if jobId == "" { + return nil, fmt.Errorf("failed to get jobId from response") + } + + pollUrl := fmt.Sprintf("%s/%s", url, jobId) + var jsonlUrl string + + for { + time.Sleep(3 * time.Second) + + pollReq, _ := http.NewRequest("GET", pollUrl, nil) + pollReq.Header.Set("Authorization", fmt.Sprintf("bearer %s", *apiConfig.ApiKey)) + + pollResp, err := p.httpClient.Do(pollReq) + if err != nil { + return nil, fmt.Errorf("failed to poll job status: %w", err) + } + + pollBody, _ := io.ReadAll(pollResp.Body) + pollResp.Body.Close() + + if pollResp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("poll job failed: %s", string(pollBody)) + } + + var pollData paddlePollResponse + if err = json.Unmarshal(pollBody, &pollData); err != nil { + return nil, fmt.Errorf("failed to parse poll response: %w", err) + } + + // end if 'done' or 'failed' + state := pollData.Data.State + if state == "done" { + jsonlUrl = pollData.Data.ResultUrl.JsonUrl + break + } else if state == "failed" { + return nil, fmt.Errorf("ocr job failed on server: %s", pollData.Data.ErrorMsg) + } + } + + if jsonlUrl == "" { + return nil, fmt.Errorf("job done but jsonl url is empty") + } + + resReq, err := http.NewRequest("GET", jsonlUrl, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request for jsonl: %w", err) + } + + resResp, err := p.httpClient.Do(resReq) + if err != nil { + return nil, fmt.Errorf("failed to download jsonl result: %w", err) + } + defer resResp.Body.Close() + + if resResp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("failed to download jsonl, status: %d", resResp.StatusCode) + } + + var fullMarkdown strings.Builder + scanner := bufio.NewScanner(resResp.Body) + + for scanner.Scan() { + line := strings.TrimSpace(scanner.Text()) + if line == "" { + continue + } + + var lineData paddleJsonlLine + if err := json.Unmarshal([]byte(line), &lineData); err != nil { + continue + } + + for _, layoutRes := range lineData.Result.LayoutParsingResults { + fullMarkdown.WriteString(layoutRes.Markdown.Text) + fullMarkdown.WriteString("\n\n") + } + } + + if err = scanner.Err(); err != nil { + return nil, fmt.Errorf("error reading jsonl: %w", err) + } + + extractedText := strings.TrimSpace(fullMarkdown.String()) + + return &OCRFileResponse{Text: &extractedText}, nil +} + +func (p *PaddleOCRModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("no such method", p.Name()) +} + +func (p *PaddleOCRModel) ListModels(apiConfig *APIConfig) ([]string, error) { + return nil, fmt.Errorf("no such method", p.Name()) +} + +func (p *PaddleOCRModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("no such method", p.Name()) +} + +func (p *PaddleOCRModel) CheckConnection(apiConfig *APIConfig) error { + return fmt.Errorf("no such method", p.Name()) +} + +func (p *PaddleOCRModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("no such method", p.Name()) +} + +func (p *PaddleOCRModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("no such method", p.Name()) +} diff --git a/internal/entity/models/replicate.go b/internal/entity/models/replicate.go new file mode 100644 index 00000000000..0757b832507 --- /dev/null +++ b/internal/entity/models/replicate.go @@ -0,0 +1,611 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +const replicatePollInterval = time.Second + +type ReplicateModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func NewReplicateModel(baseURL map[string]string, urlSuffix URLSuffix) *ReplicateModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &ReplicateModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (r *ReplicateModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewReplicateModel(baseURL, r.URLSuffix) +} + +func (r *ReplicateModel) Name() string { + return "replicate" +} + +type replicatePredictionURLs struct { + Get string `json:"get"` + Stream string `json:"stream"` +} + +type replicatePrediction struct { + ID string `json:"id"` + Status string `json:"status"` + Output interface{} `json:"output"` + Error interface{} `json:"error"` + URLs replicatePredictionURLs `json:"urls"` +} + +type replicateModelsResponse struct { + Results []struct { + Owner string `json:"owner"` + Name string `json:"name"` + } `json:"results"` +} + +type replicateSSEEvent struct { + event string + data string +} + +func (r *ReplicateModel) baseURLForRegion(region string) (string, error) { + base, ok := r.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("replicate: no base URL configured for region %q", region) + } + return strings.TrimSuffix(base, "/"), nil +} + +func (r *ReplicateModel) endpoint(apiConfig *APIConfig, suffix string) (string, error) { + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := r.baseURLForRegion(region) + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s", baseURL, suffix), nil +} + +func replicateUsesVersionEndpoint(modelName string) bool { + name := strings.TrimSpace(modelName) + return !strings.Contains(name, "/") || strings.Contains(name, ":") +} + +func (r *ReplicateModel) predictionEndpoint(apiConfig *APIConfig, modelName string) (string, string, error) { + if replicateUsesVersionEndpoint(modelName) { + endpoint, err := r.endpoint(apiConfig, r.URLSuffix.Chat) + return endpoint, modelName, err + } + + parts := strings.Split(modelName, "/") + if len(parts) != 2 || parts[0] == "" || parts[1] == "" { + return "", "", fmt.Errorf("replicate: official model name must be owner/name") + } + + modelsPrefix := strings.TrimSuffix(r.URLSuffix.Models, "models") + if modelsPrefix == "" { + modelsPrefix = "v1/" + } + officialSuffix := fmt.Sprintf("%smodels/%s/%s/predictions", + modelsPrefix, + url.PathEscape(parts[0]), + url.PathEscape(parts[1]), + ) + endpoint, err := r.endpoint(apiConfig, officialSuffix) + return endpoint, "", err +} + +func replicateMessageContent(content interface{}) string { + switch v := content.(type) { + case string: + return v + default: + b, err := json.Marshal(v) + if err != nil { + return fmt.Sprint(v) + } + return string(b) + } +} + +func replicatePromptFromMessages(messages []Message) (string, string) { + var systemParts []string + var promptParts []string + nonSystemCount := 0 + for _, msg := range messages { + content := replicateMessageContent(msg.Content) + if msg.Role == "system" { + systemParts = append(systemParts, content) + continue + } + nonSystemCount++ + if nonSystemCount == 1 && msg.Role == "user" && len(messages) == len(systemParts)+1 { + promptParts = append(promptParts, content) + continue + } + promptParts = append(promptParts, fmt.Sprintf("%s: %s", msg.Role, content)) + } + return strings.Join(promptParts, "\n"), strings.Join(systemParts, "\n\n") +} + +func replicateInputFromMessages(messages []Message, chatModelConfig *ChatConfig) map[string]interface{} { + prompt, systemPrompt := replicatePromptFromMessages(messages) + input := map[string]interface{}{ + "prompt": prompt, + } + if systemPrompt != "" { + input["system_prompt"] = systemPrompt + } + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + input["max_new_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + input["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + input["top_p"] = *chatModelConfig.TopP + } + // Replicate model inputs are model-specific. Forward only the + // common prompt-model fields above; Stop is intentionally + // omitted because upstream behavior is undefined for many + // hosted models. + } + return input +} + +func replicateOutputToString(output interface{}) (string, error) { + switch v := output.(type) { + case nil: + return "", nil + case string: + return v, nil + case []interface{}: + var b strings.Builder + for _, item := range v { + text, err := replicateOutputToString(item) + if err != nil { + return "", err + } + b.WriteString(text) + } + return b.String(), nil + case map[string]interface{}: + raw, err := json.Marshal(v) + if err != nil { + return "", err + } + return string(raw), nil + default: + return fmt.Sprint(v), nil + } +} + +func (r *ReplicateModel) createPrediction(ctx context.Context, url string, version string, input map[string]interface{}, stream bool, apiKey string, preferWait bool) (*replicatePrediction, error) { + body := map[string]interface{}{ + "input": input, + "stream": stream, + } + if version != "" { + body["version"] = version + } + + jsonData, err := json.Marshal(body) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + if preferWait { + req.Header.Set("Prefer", "wait=60") + } + + resp, err := r.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) + } + + var prediction replicatePrediction + if err = json.Unmarshal(bodyBytes, &prediction); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + if prediction.Error != nil { + return nil, fmt.Errorf("replicate: upstream error: %v", prediction.Error) + } + return &prediction, nil +} + +func replicatePredictionDone(status string) bool { + return replicatePredictionSucceeded(status) || status == "failed" || status == "canceled" +} + +func replicatePredictionSucceeded(status string) bool { + return status == "successful" +} + +func (r *ReplicateModel) getPrediction(ctx context.Context, url string, apiKey string) (*replicatePrediction, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + + resp, err := r.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var prediction replicatePrediction + if err = json.Unmarshal(body, &prediction); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + if prediction.Error != nil { + return nil, fmt.Errorf("replicate: upstream error: %v", prediction.Error) + } + return &prediction, nil +} + +func (r *ReplicateModel) waitForPrediction(ctx context.Context, prediction *replicatePrediction, apiKey string) (*replicatePrediction, error) { + if prediction == nil { + return nil, fmt.Errorf("replicate: empty prediction response") + } + if replicatePredictionDone(prediction.Status) { + return prediction, nil + } + if prediction.URLs.Get == "" { + return nil, fmt.Errorf("replicate: prediction is %q and no polling URL was returned", prediction.Status) + } + + ticker := time.NewTicker(replicatePollInterval) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return nil, fmt.Errorf("replicate: prediction did not finish before timeout: %w", ctx.Err()) + case <-ticker.C: + next, err := r.getPrediction(ctx, prediction.URLs.Get, apiKey) + if err != nil { + return nil, err + } + if replicatePredictionDone(next.Status) { + return next, nil + } + } + } +} + +func (r *ReplicateModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if strings.TrimSpace(modelName) == "" { + return nil, fmt.Errorf("model name is required") + } + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + url, version, err := r.predictionEndpoint(apiConfig, modelName) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + prediction, err := r.createPrediction(ctx, url, version, replicateInputFromMessages(messages, chatModelConfig), false, *apiConfig.ApiKey, true) + if err != nil { + return nil, err + } + prediction, err = r.waitForPrediction(ctx, prediction, *apiConfig.ApiKey) + if err != nil { + return nil, err + } + if !replicatePredictionSucceeded(prediction.Status) { + return nil, fmt.Errorf("replicate: prediction ended with status %q", prediction.Status) + } + + answer, err := replicateOutputToString(prediction.Output) + if err != nil { + return nil, fmt.Errorf("failed to parse prediction output: %w", err) + } + reasonContent := "" + return &ChatResponse{Answer: &answer, ReasonContent: &reasonContent}, nil +} + +func (r *ReplicateModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("api key is required") + } + if strings.TrimSpace(modelName) == "" { + return fmt.Errorf("model name is required") + } + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + if chatModelConfig != nil && chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + + url, version, err := r.predictionEndpoint(apiConfig, modelName) + if err != nil { + return err + } + + prediction, err := r.createPrediction(context.Background(), url, version, replicateInputFromMessages(messages, chatModelConfig), true, *apiConfig.ApiKey, false) + if err != nil { + return err + } + if prediction.URLs.Stream == "" { + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + prediction, err = r.waitForPrediction(ctx, prediction, *apiConfig.ApiKey) + if err != nil { + return err + } + answer, err := replicateOutputToString(prediction.Output) + if err != nil { + return fmt.Errorf("failed to parse prediction output: %w", err) + } + if answer != "" { + if err := sender(&answer, nil); err != nil { + return err + } + } + endOfStream := "[DONE]" + return sender(&endOfStream, nil) + } + + return r.readPredictionStream(prediction.URLs.Stream, *apiConfig.ApiKey, sender) +} + +func (r *ReplicateModel) readPredictionStream(url string, apiKey string, sender func(*string, *string) error) error { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, url, nil) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey)) + req.Header.Set("Accept", "text/event-stream") + + resp, err := r.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + current := replicateSSEEvent{} + sawDone := false + for scanner.Scan() { + line := scanner.Text() + if line == "" { + done, err := dispatchReplicateSSEEvent(current, sender) + if err != nil { + return err + } + if done { + sawDone = true + break + } + current = replicateSSEEvent{} + continue + } + if strings.HasPrefix(line, "event:") { + current.event = strings.TrimSpace(line[6:]) + } + if strings.HasPrefix(line, "data:") { + if current.data != "" { + current.data += "\n" + } + data := line[5:] + if strings.HasPrefix(data, " ") { + data = data[1:] + } + current.data += data + } + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawDone && (current.event != "" || current.data != "") { + done, err := dispatchReplicateSSEEvent(current, sender) + if err != nil { + return err + } + sawDone = done + } + if !sawDone { + return fmt.Errorf("replicate: stream ended before done event") + } + + endOfStream := "[DONE]" + return sender(&endOfStream, nil) +} + +func dispatchReplicateSSEEvent(event replicateSSEEvent, sender func(*string, *string) error) (bool, error) { + switch event.event { + case "output", "": + if event.data == "" { + return false, nil + } + return false, sender(&event.data, nil) + case "error": + return false, fmt.Errorf("replicate: upstream stream error: %s", event.data) + case "done": + return true, nil + default: + return false, nil + } +} + +func (r *ReplicateModel) ListModels(apiConfig *APIConfig) ([]string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + url, err := r.endpoint(apiConfig, r.URLSuffix.Models) + if err != nil { + return nil, err + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := r.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result replicateModelsResponse + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + models := make([]string, 0, len(result.Results)) + for _, model := range result.Results { + if model.Owner != "" && model.Name != "" { + models = append(models, fmt.Sprintf("%s/%s", model.Owner, model.Name)) + } + } + return models, nil +} + +func (r *ReplicateModel) CheckConnection(apiConfig *APIConfig) error { + _, err := r.ListModels(apiConfig) + return err +} + +func (r *ReplicateModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} + +func (r *ReplicateModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", r.Name()) +} diff --git a/internal/entity/models/replicate_test.go b/internal/entity/models/replicate_test.go new file mode 100644 index 00000000000..d9eb1efd6ba --- /dev/null +++ b/internal/entity/models/replicate_test.go @@ -0,0 +1,321 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newReplicateForTest(baseURL string) *ReplicateModel { + return NewReplicateModel( + map[string]string{"default": baseURL}, + URLSuffix{Chat: "v1/predictions", Models: "v1/models"}, + ) +} + +func TestReplicateName(t *testing.T) { + if got := newReplicateForTest("http://unused").Name(); got != "replicate" { + t.Errorf("Name()=%q", got) + } +} + +func TestReplicateFactory(t *testing.T) { + driver, err := NewModelFactory().CreateModelDriver("Replicate", map[string]string{"default": "http://unused"}, URLSuffix{}) + if err != nil { + t.Fatalf("CreateModelDriver: %v", err) + } + if _, ok := driver.(*ReplicateModel); !ok { + t.Fatalf("driver type=%T, want *ReplicateModel", driver) + } +} + +func TestReplicatePromptFromMessages(t *testing.T) { + prompt, system := replicatePromptFromMessages([]Message{ + {Role: "system", Content: "be terse"}, + {Role: "user", Content: "hello"}, + {Role: "assistant", Content: "hi"}, + {Role: "user", Content: map[string]interface{}{"text": "again"}}, + }) + if system != "be terse" { + t.Errorf("system=%q", system) + } + want := "user: hello\nassistant: hi\nuser: {\"text\":\"again\"}" + if prompt != want { + t.Errorf("prompt=%q want %q", prompt, want) + } +} + +func TestReplicatePredictionEndpoint(t *testing.T) { + m := newReplicateForTest("https://api.example.test") + + endpoint, version, err := m.predictionEndpoint(&APIConfig{}, "meta/meta-llama-3-70b-instruct") + if err != nil { + t.Fatalf("official endpoint: %v", err) + } + if endpoint != "https://api.example.test/v1/models/meta/meta-llama-3-70b-instruct/predictions" { + t.Errorf("official endpoint=%q", endpoint) + } + if version != "" { + t.Errorf("official version=%q want empty", version) + } + + endpoint, version, err = m.predictionEndpoint(&APIConfig{}, "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa") + if err != nil { + t.Fatalf("version endpoint: %v", err) + } + if endpoint != "https://api.example.test/v1/predictions" { + t.Errorf("version endpoint=%q", endpoint) + } + if version != "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" { + t.Errorf("version=%q", version) + } +} + +func TestReplicateOfficialChatHappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models/meta/meta-llama-3-70b-instruct/predictions" { + t.Errorf("path=%s", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("Authorization=%q", got) + } + if got := r.Header.Get("Prefer"); got != "wait=60" { + t.Errorf("Prefer=%q", got) + } + raw, _ := io.ReadAll(r.Body) + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("body: %v", err) + return + } + if body["version"] != nil { + t.Errorf("official model requests must not send version=%v", body["version"]) + } + if body["stream"] != false { + t.Errorf("stream=%v", body["stream"]) + } + input := body["input"].(map[string]interface{}) + if input["prompt"] != "hello" { + t.Errorf("prompt=%v", input["prompt"]) + } + if input["system_prompt"] != "be helpful" { + t.Errorf("system_prompt=%v", input["system_prompt"]) + } + if input["max_new_tokens"] != float64(128) { + t.Errorf("max_new_tokens=%v", input["max_new_tokens"]) + } + // Stop is deliberately filtered out because Replicate model + // inputs are model-specific and upstream support is undefined. + if input["stop"] != nil { + t.Errorf("unexpected stop=%v", input["stop"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "successful", + "output": []string{"hel", "lo"}, + }) + })) + defer srv.Close() + + apiKey := "test-key" + maxTokens := 128 + stop := []string{"END"} + resp, err := newReplicateForTest(srv.URL).ChatWithMessages( + "meta/meta-llama-3-70b-instruct", + []Message{{Role: "system", Content: "be helpful"}, {Role: "user", Content: "hello"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{MaxTokens: &maxTokens, Stop: &stop}, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if *resp.Answer != "hello" { + t.Errorf("Answer=%q", *resp.Answer) + } + if *resp.ReasonContent != "" { + t.Errorf("ReasonContent=%q", *resp.ReasonContent) + } +} + +func TestReplicateCommunityChatUsesVersionEndpoint(t *testing.T) { + const version = "replicate/hello-world:5c7d5dc6dd8bf75c1acaa8565735e7986bc5b66206b55cca93cb72c9bf15ccaa" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/predictions" { + t.Errorf("path=%s", r.URL.Path) + } + raw, _ := io.ReadAll(r.Body) + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("body: %v", err) + return + } + if body["version"] != version { + t.Errorf("version=%v", body["version"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "successful", + "output": "ok", + }) + })) + defer srv.Close() + + apiKey := "test-key" + resp, err := newReplicateForTest(srv.URL).ChatWithMessages( + version, + []Message{{Role: "user", Content: "hello"}}, + &APIConfig{ApiKey: &apiKey}, nil, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if *resp.Answer != "ok" { + t.Errorf("Answer=%q", *resp.Answer) + } +} + +func TestReplicateChatPollsUntilSucceeded(t *testing.T) { + var getCount int + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("Authorization=%q", got) + } + switch r.URL.Path { + case "/v1/models/meta/meta-llama-3-70b-instruct/predictions": + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "processing", + "urls": map[string]string{ + "get": "http://" + r.Host + "/v1/predictions/p1", + }, + }) + case "/v1/predictions/p1": + getCount++ + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "successful", + "output": "done", + }) + default: + t.Errorf("unexpected path=%s", r.URL.Path) + } + })) + defer srv.Close() + + apiKey := "test-key" + resp, err := newReplicateForTest(srv.URL).ChatWithMessages( + "meta/meta-llama-3-70b-instruct", + []Message{{Role: "user", Content: "hello"}}, + &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if getCount != 1 { + t.Errorf("getCount=%d", getCount) + } + if *resp.Answer != "done" { + t.Errorf("Answer=%q", *resp.Answer) + } +} + +func TestReplicateStreamHappyPath(t *testing.T) { + var streamURL string + streamSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Accept"); got != "text/event-stream" { + t.Errorf("Accept=%q", got) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, "event: output\n") + _, _ = io.WriteString(w, "data: Hello\n\n") + _, _ = io.WriteString(w, "event: output\n") + _, _ = io.WriteString(w, "data: world\n\n") + _, _ = io.WriteString(w, "event: done\n") + _, _ = io.WriteString(w, "data: {}\n\n") + })) + defer streamSrv.Close() + streamURL = streamSrv.URL + + apiSrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models/meta/meta-llama-3-70b-instruct/predictions" { + t.Errorf("path=%s", r.URL.Path) + } + raw, _ := io.ReadAll(r.Body) + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("body: %v", err) + return + } + if body["stream"] != true { + t.Errorf("stream=%v", body["stream"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "status": "starting", + "urls": map[string]string{ + "stream": streamURL, + }, + }) + })) + defer apiSrv.Close() + + apiKey := "test-key" + var chunks []string + err := newReplicateForTest(apiSrv.URL).ChatStreamlyWithSender( + "meta/meta-llama-3-70b-instruct", + []Message{{Role: "user", Content: "hello"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(c *string, _ *string) error { + if c != nil { + chunks = append(chunks, *c) + } + return nil + }) + if err != nil { + t.Fatalf("ChatStreamlyWithSender: %v", err) + } + if strings.Join(chunks, "") != "Hello world[DONE]" { + t.Errorf("chunks=%q", strings.Join(chunks, "")) + } +} + +func TestReplicateListModelsAndCheckConnection(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models" { + t.Errorf("path=%s", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("Authorization=%q", got) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "results": []map[string]string{ + {"owner": "meta", "name": "meta-llama-3-70b-instruct"}, + {"owner": "replicate", "name": "hello-world"}, + }, + }) + })) + defer srv.Close() + + apiKey := "test-key" + model := newReplicateForTest(srv.URL) + models, err := model.ListModels(&APIConfig{ApiKey: &apiKey}) + if err != nil { + t.Fatalf("ListModels: %v", err) + } + if strings.Join(models, ",") != "meta/meta-llama-3-70b-instruct,replicate/hello-world" { + t.Errorf("models=%v", models) + } + if err := model.CheckConnection(&APIConfig{ApiKey: &apiKey}); err != nil { + t.Fatalf("CheckConnection: %v", err) + } +} + +func TestReplicateUnsupportedMethods(t *testing.T) { + m := newReplicateForTest("http://unused") + if _, err := m.Embed(nil, nil, nil, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Embed error=%v", err) + } + if _, err := m.Rerank(nil, "", nil, nil, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Rerank error=%v", err) + } + if _, err := m.Balance(nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Balance error=%v", err) + } +} diff --git a/internal/entity/models/siliconflow.go b/internal/entity/models/siliconflow.go index f3c658662cb..f726e3db998 100644 --- a/internal/entity/models/siliconflow.go +++ b/internal/entity/models/siliconflow.go @@ -22,7 +22,10 @@ import ( "encoding/json" "fmt" "io" + "mime/multipart" "net/http" + "os" + "path/filepath" "ragflow/internal/common" "strconv" "strings" @@ -218,7 +221,7 @@ func (z *SiliconflowModel) ChatStreamlyWithSender(modelName string, messages []M region = *apiConfig.Region } - url := fmt.Sprintf("%s/chat/completions", z.BaseURL[region]) + url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Chat) // Convert messages to API format apiMessages := make([]map[string]interface{}, len(messages)) @@ -368,10 +371,40 @@ func (z *SiliconflowModel) ChatStreamlyWithSender(modelName string, messages []M return scanner.Err() } -// Encode encodes a list of texts into embeddings -func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +type siliconflowEmbeddingResponse struct { + Object string `json:"object"` + Model string `json:"model"` + Data []siliconflowEmbeddingData `json:"data"` + Usage siliconflowUsage `json:"usage"` +} + +type siliconflowEmbeddingData struct { + Object string `json:"object"` + Embedding []float64 `json:"embedding"` + Index int `json:"index"` +} + +type siliconflowUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// siliconflowMaxBatchSize is the per-request input limit documented at +// https://docs.siliconflow.cn/en/api-reference/embeddings/create-embeddings. +const siliconflowMaxBatchSize = 32 + +// Embed embeds a list of texts into embeddings +func (s *SiliconflowModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { if len(texts) == 0 { - return [][]float64{}, nil + return []EmbeddingData{}, nil + } + if len(texts) > siliconflowMaxBatchSize { + return nil, fmt.Errorf("siliconflow supports a maximum of %d inputs per request", siliconflowMaxBatchSize) + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") } var region = "default" @@ -386,79 +419,53 @@ func (s *SiliconflowModel) Encode(modelName *string, texts []string, apiConfig * apiKey = *apiConfig.ApiKey } - embeddings := make([][]float64, len(texts)) - - for i, text := range texts { - reqBody := map[string]interface{}{ - "model": modelName, - "input": text, - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } - - req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - if apiKey != "" { - req.Header.Set("Authorization", "Bearer "+apiKey) - } - - resp, err := s.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } + reqBody := map[string]interface{}{ + "model": modelName, + "input": texts, + } - body, err := io.ReadAll(resp.Body) - resp.Body.Close() + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } - if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) - } + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("SILICONFLOW API error: %s, body: %s", resp.Status, string(body)) - } + req.Header.Set("Content-Type", "application/json") + if apiKey != "" { + req.Header.Set("Authorization", "Bearer "+apiKey) + } - // Parse response - var result map[string]interface{} - if err = json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) - } + resp, err := s.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } - data, ok := result["data"].([]interface{}) - if !ok || len(data) == 0 { - return nil, fmt.Errorf("no data in response") - } + body, err := io.ReadAll(resp.Body) + resp.Body.Close() - firstData, ok := data[0].(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("invalid data format") - } + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } - embeddingSlice, ok := firstData["embedding"].([]interface{}) - if !ok { - return nil, fmt.Errorf("invalid embedding format") - } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("SILICONFLOW API error: %s, body: %s", resp.Status, string(body)) + } - embedding := make([]float64, len(embeddingSlice)) - for j, v := range embeddingSlice { - switch val := v.(type) { - case float64: - embedding[j] = val - case float32: - embedding[j] = float64(val) - default: - return nil, fmt.Errorf("unexpected embedding value type") - } - } + var parsed siliconflowEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } - embeddings[i] = embedding + var embeddings []EmbeddingData + for _, dataElem := range parsed.Data { + var embeddingData EmbeddingData + embeddingData.Embedding = dataElem.Embedding + embeddingData.Index = dataElem.Index + embeddings = append(embeddings, embeddingData) } return embeddings, nil @@ -657,11 +664,16 @@ func (s *SiliconflowModel) Rerank(modelName *string, query string, documents []s apiKey = *apiConfig.ApiKey } + var topN = rerankConfig.TopN + if rerankConfig.TopN == 0 { + topN = len(documents) + } + reqBody := SiliconflowRerankRequest{ Model: *modelName, Query: query, Documents: documents, - TopN: rerankConfig.TopN, + TopN: topN, ReturnDocuments: false, MaxChunksPerDoc: 1024, OverlapTokens: 80, @@ -711,3 +723,275 @@ func (s *SiliconflowModel) Rerank(modelName *string, query string, documents []s } return &rerankResponse, nil } + +// TranscribeAudio transcribe audio +func (o *SiliconflowModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + if file == nil || *file == "" { + return nil, fmt.Errorf("file is missing") + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.ASR) + + // multipart body + var body bytes.Buffer + writer := multipart.NewWriter(&body) + + // open audio file + audioFile, err := os.Open(*file) + if err != nil { + return nil, fmt.Errorf("failed to open audio file: %w", err) + } + defer audioFile.Close() + + // create multipart file field + part, err := writer.CreateFormFile( + "file", + filepath.Base(*file), + ) + if err != nil { + return nil, fmt.Errorf("failed to create multipart file: %w", err) + } + + // copy file content + if _, err = io.Copy(part, audioFile); err != nil { + return nil, fmt.Errorf("failed to copy audio data: %w", err) + } + + // model field + if err := writer.WriteField("model", *modelName); err != nil { + return nil, fmt.Errorf("failed to write model field: %w", err) + } + + // extra params + if asrConfig != nil && asrConfig.Params != nil { + for key, value := range asrConfig.Params { + + var val string + + switch v := value.(type) { + case string: + val = v + case bool: + val = strconv.FormatBool(v) + case int: + val = strconv.Itoa(v) + case int64: + val = strconv.FormatInt(v, 10) + case float32: + val = strconv.FormatFloat(float64(v), 'f', -1, 32) + case float64: + val = strconv.FormatFloat(v, 'f', -1, 64) + default: + val = fmt.Sprintf("%v", v) + } + + if err = writer.WriteField(key, val); err != nil { + return nil, fmt.Errorf("failed to write field %s: %w", key, err) + } + } + } + + if err = writer.Close(); err != nil { + return nil, fmt.Errorf("failed to close multipart writer: %w", err) + } + + // build request + req, err := http.NewRequest("POST", url, &body) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Content-Type", writer.FormDataContentType()) + req.Header.Set("Accept", "application/json") + + // send request + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("SiliconFlow ASR error: %s - %s", resp.Status, string(respBody)) + } + + // SiliconFlow response + var result struct { + Text string `json:"text"` + } + + if err = json.Unmarshal(respBody, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response: %w, body=%s", err, string(respBody)) + } + + return &ASRResponse{Text: result.Text}, nil +} + +func (z *SiliconflowModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (o *SiliconflowModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + if audioContent == nil || *audioContent == "" { + return nil, fmt.Errorf("audio content is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", o.BaseURL[region], o.URLSuffix.TTS) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": *audioContent, + "stream": false, + } + + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + reqBody[key] = value + } + } + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["response_format"] = ttsConfig.Format + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := o.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s - %s", resp.Status, string(body)) + } + + return &TTSResponse{Audio: body}, nil +} + +func (z *SiliconflowModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("SiliconFlow API key is missing") + } + + if audioContent == nil || *audioContent == "" { + return fmt.Errorf("audio content is empty") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.TTS) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": *audioContent, + "stream": true, + } + + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + reqBody[key] = value + } + } + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["response_format"] = ttsConfig.Format + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := z.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("SiliconFlow stream TTS API error: %d, body: %s", resp.StatusCode, string(body)) + } + + buf := make([]byte, 32*1024) + + for { + n, err := resp.Body.Read(buf) + if n > 0 { + chunk := string(buf[:n]) + if errSend := sender(&chunk, nil); errSend != nil { + return errSend + } + } + + if err != nil { + if err == io.EOF { + break + } + return fmt.Errorf("error reading SiliconFlow binary audio stream: %w", err) + } + } + + return nil +} + +// OCRFile OCR file +func (m *SiliconflowModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +// ParseFile parse file +func (z *SiliconflowModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *SiliconflowModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *SiliconflowModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/stepfun.go b/internal/entity/models/stepfun.go new file mode 100644 index 00000000000..a4cce4bb9eb --- /dev/null +++ b/internal/entity/models/stepfun.go @@ -0,0 +1,661 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// StepFunModel implements ModelDriver for StepFun (阶跃星辰). +// +// StepFun exposes an OpenAI-compatible REST API at https://api.stepfun.com/v1 +// (chat completions at /chat/completions, list models at /models). The wire +// shape matches OpenAI closely enough that the chat path here is a direct +// port of the OpenAI driver. +type StepFunModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +// NewStepFunModel creates a new StepFun model instance. +// +// We clone http.DefaultTransport so we keep Go's defaults for +// ProxyFromEnvironment, DialContext (with KeepAlive), HTTP/2, +// TLSHandshakeTimeout, and ExpectContinueTimeout, and only override +// the connection-pool fields we care about. +// +// The Client itself has no Timeout. http.Client.Timeout would also +// cap the time spent reading the response body, which would cut off +// long-lived SSE streams in ChatStreamlyWithSender. Non-streaming +// callers wrap each request with context.WithTimeout instead. +func NewStepFunModel(baseURL map[string]string, urlSuffix URLSuffix) *StepFunModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &StepFunModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +/* + +RAGFlow(user)> tts with 'fnlp/MOSS-TTSD-v0.5@test@siliconflow' text 'He who desires but acts not, breeds pestilence.' play format 'wav' param '{"voice": "fnlp/MOSS-TTSD-v0.5:alex"}' +SUCCESS +RAGFlow(user)> stream tts with 'fnlp/MOSS-TTSD-v0.5@test@siliconflow' text 'He who desires but acts not, breeds pestilence.' play format 'wav' param '{"voice": "fnlp/MOSS-TTSD-v0.5:claire"}' +SUCCESS + +*/ + +func (s *StepFunModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewStepFunModel(baseURL, s.URLSuffix) +} + +func (s *StepFunModel) Name() string { + return "stepfun" +} + +// baseURLForRegion returns the base URL for the given region, or an +// error if no entry exists. This makes a misconfigured region fail +// fast with a clear message, instead of silently producing a relative +// URL that the HTTP transport then rejects. +func (s *StepFunModel) baseURLForRegion(region string) (string, error) { + base, ok := s.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("stepfun: no base URL configured for region %q", region) + } + return base, nil +} + +// ChatWithMessages sends multiple messages with roles and returns the response. +func (s *StepFunModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := s.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, s.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + } + + // Note: do NOT propagate chatModelConfig.Stream into the request body + // here. ChatWithMessages parses a single JSON response, so stream must + // always be off for this code path. + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := s.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + emptyReason := "" + return &ChatResponse{ + Answer: &content, + ReasonContent: &emptyReason, + }, nil +} + +// ChatStreamlyWithSender sends messages and streams the response via the +// sender function. The StepFun SSE stream uses the same shape as OpenAI: +// "data:" lines carrying JSON events, with a final "[DONE]" line. +func (s *StepFunModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("api key is required") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := s.baseURLForRegion(region) + if err != nil { + return err + } + url := fmt.Sprintf("%s/%s", baseURL, s.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + } + + if chatModelConfig != nil { + // Refuse to run if the caller explicitly asked for stream=false. + // The body of this method only knows how to read SSE, so a + // non-SSE JSON response would be parsed as if it were a stream + // and produce no chunks. Better to fail clearly. + if chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + // SSE streams are long-lived. We rely on the transport's + // ResponseHeaderTimeout to cap the connection-establishment phase + // instead of attaching a hard deadline here. + req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := s.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: bump the scanner buffer from the 64KB default to 1MB + // so we never silently truncate a long data: line. + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + sawTerminal := false + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "data:") { + continue + } + + data := strings.TrimSpace(line[5:]) + + if data == "[DONE]" { + sawTerminal = true + break + } + + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + sawTerminal = true + break + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawTerminal { + return fmt.Errorf("stepfun: stream ended before [DONE] or finish_reason") + } + + endOfStream := "[DONE]" + if err := sender(&endOfStream, nil); err != nil { + return err + } + + return nil +} + +// Embed is left as a stub. StepFun has not advertised a public embeddings +// endpoint in the API reference linked from the umbrella issue, so any real +// implementation belongs in a follow-up only after the endpoint is verified. +func (s *StepFunModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + return nil, fmt.Errorf("not implemented") +} + +// ListModels returns the list of model ids visible to the API key. +func (s *StepFunModel) ListModels(apiConfig *APIConfig) ([]string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := s.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, s.URLSuffix.Models) + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := s.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + data, ok := result["data"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid models list format") + } + + models := make([]string, 0) + for _, model := range data { + modelMap, ok := model.(map[string]interface{}) + if !ok { + continue + } + modelName, ok := modelMap["id"].(string) + if !ok { + continue + } + models = append(models, modelName) + } + + return models, nil +} + +// Balance is not exposed by the StepFun API, so this returns "no such method". +func (s *StepFunModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("no such method") +} + +// CheckConnection runs a lightweight ListModels call to verify the API key. +func (s *StepFunModel) CheckConnection(apiConfig *APIConfig) error { + _, err := s.ListModels(apiConfig) + if err != nil { + return err + } + return nil +} + +// Rerank calculates similarity scores between query and documents. StepFun +// does not expose a public rerank API, so this returns "no such method". +func (s *StepFunModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("no such method") +} + +// TranscribeAudio transcribe audio +func (s *StepFunModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", s.Name()) +} + +func (s *StepFunModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", s.Name()) +} + +// AudioSpeech convert text to audio +func (s *StepFunModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + // TODO Test it + if audioContent == nil || *audioContent == "" { + return nil, fmt.Errorf("audio content is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", s.BaseURL[region], s.URLSuffix.TTS) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": *audioContent, + } + + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + reqBody[key] = value + } + } + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["response_format"] = ttsConfig.Format + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := s.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%s - %s", resp.Status, string(body)) + } + + return &TTSResponse{Audio: body}, nil +} + +// AudioSpeechWithSender for Streaming TTS +func (s *StepFunModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + // TODO Test it + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("StepFun API key is missing") + } + + if audioContent == nil || *audioContent == "" { + return fmt.Errorf("audio content is empty") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", s.BaseURL[region], s.URLSuffix.TTS) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": *audioContent, + "stream_format": "sse", + } + + if ttsConfig != nil && ttsConfig.Params != nil { + for key, value := range ttsConfig.Params { + reqBody[key] = value + } + } + if ttsConfig != nil && ttsConfig.Format != "" { + reqBody["response_format"] = ttsConfig.Format + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := s.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("StepFun stream TTS API error: %d - %s", resp.StatusCode, string(body)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 8*1024*1024) + + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "data: ") { + continue + } + + dataStr := strings.TrimSpace(line[6:]) + // [DONE] + if dataStr == "" || dataStr == "[DONE]" { + continue + } + + // Parse + var event struct { + Type string `json:"type"` + Audio string `json:"audio"` + } + + if err := json.Unmarshal([]byte(dataStr), &event); err != nil { + continue + } + + if event.Type == "speech.audio.error" { + return fmt.Errorf("StepFun stream encountered an error during generation") + } + + // Extract the Base64 string containing the audio and decode it + if event.Type == "speech.audio.delta" && event.Audio != "" { + audioBytes, err := base64.StdEncoding.DecodeString(event.Audio) + if err == nil && len(audioBytes) > 0 { + chunk := string(audioBytes) + if errSend := sender(&chunk, nil); errSend != nil { + return errSend + } + } + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("error reading StepFun stream: %w", err) + } + + return nil +} + +// OCRFile OCR file +func (z *StepFunModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +// ParseFile parse file +func (z *StepFunModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *StepFunModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *StepFunModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/togetherai.go b/internal/entity/models/togetherai.go new file mode 100644 index 00000000000..7e9219180e9 --- /dev/null +++ b/internal/entity/models/togetherai.go @@ -0,0 +1,430 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +type TogetherAIModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func NewTogetherAIModel(baseURL map[string]string, urlSuffix URLSuffix) *TogetherAIModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &TogetherAIModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (t *TogetherAIModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewTogetherAIModel(baseURL, t.URLSuffix) +} + +func (t *TogetherAIModel) Name() string { + return "togetherai" +} + +func (t *TogetherAIModel) baseURLForRegion(region string) (string, error) { + base, ok := t.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("togetherai: no base URL configured for region %q", region) + } + return strings.TrimSuffix(base, "/"), nil +} + +type togetherAIReasoningOptions struct { + Enabled bool `json:"enabled"` +} + +func (t *TogetherAIModel) chatPayload(modelName string, messages []Message, stream bool, chatModelConfig *ChatConfig) map[string]interface{} { + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": stream, + } + + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + if chatModelConfig.Thinking != nil { + reqBody["reasoning"] = togetherAIReasoningOptions{ + Enabled: *chatModelConfig.Thinking, + } + } + if chatModelConfig.Effort != nil && strings.Contains(strings.ToLower(modelName), "gpt-oss") { + reqBody["reasoning_effort"] = *chatModelConfig.Effort + } + } + + return reqBody +} + +func (t *TogetherAIModel) chatURL(apiConfig *APIConfig) (string, error) { + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := t.baseURLForRegion(region) + if err != nil { + return "", err + } + return fmt.Sprintf("%s/%s", baseURL, t.URLSuffix.Chat), nil +} + +type togetherAIChatMessage struct { + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content"` + Reasoning string `json:"reasoning"` +} + +type togetherAIChatChoice struct { + Message togetherAIChatMessage `json:"message"` + Delta togetherAIChatMessage `json:"delta"` + FinishReason string `json:"finish_reason"` +} + +type togetherAIChatResponse struct { + Choices []togetherAIChatChoice `json:"choices"` + Error interface{} `json:"error"` + FinishReason string `json:"finish_reason"` +} + +func (t *TogetherAIModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if strings.TrimSpace(modelName) == "" { + return nil, fmt.Errorf("model name is required") + } + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + url, err := t.chatURL(apiConfig) + if err != nil { + return nil, err + } + + jsonData, err := json.Marshal(t.chatPayload(modelName, messages, false, chatModelConfig)) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := t.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result togetherAIChatResponse + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + if result.Error != nil { + return nil, fmt.Errorf("togetherai: upstream error: %v", result.Error) + } + if len(result.Choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + content := result.Choices[0].Message.Content + reasonContent := result.Choices[0].Message.ReasoningContent + if reasonContent == "" { + reasonContent = result.Choices[0].Message.Reasoning + } + return &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + }, nil +} + +const togetherAIStreamTimeout = 10 * time.Minute + +func (t *TogetherAIModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("api key is required") + } + if strings.TrimSpace(modelName) == "" { + return fmt.Errorf("model name is required") + } + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + if chatModelConfig != nil && chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + + url, err := t.chatURL(apiConfig) + if err != nil { + return err + } + + jsonData, err := json.Marshal(t.chatPayload(modelName, messages, true, chatModelConfig)) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + // ResponseHeaderTimeout caps the initial header wait. This context + // also caps the body-read phase so a stalled SSE stream cannot hold + // the caller's goroutine and connection indefinitely. + ctx, cancel := context.WithTimeout(context.Background(), togetherAIStreamTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + req.Header.Set("Accept", "text/event-stream") + + resp, err := t.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + sawTerminal := false + for scanner.Scan() { + line := scanner.Text() + if !strings.HasPrefix(line, "data:") { + continue + } + + data := strings.TrimSpace(line[5:]) + if data == "[DONE]" { + sawTerminal = true + break + } + + var event togetherAIChatResponse + if err = json.Unmarshal([]byte(data), &event); err != nil { + return fmt.Errorf("togetherai: invalid SSE event: %w", err) + } + if event.Error != nil { + return fmt.Errorf("togetherai: upstream stream error: %v", event.Error) + } + if len(event.Choices) == 0 { + continue + } + + choice := event.Choices[0] + if choice.Delta.ReasoningContent != "" { + if err := sender(nil, &choice.Delta.ReasoningContent); err != nil { + return err + } + } + if choice.Delta.Reasoning != "" { + if err := sender(nil, &choice.Delta.Reasoning); err != nil { + return err + } + } + if choice.Delta.Content != "" { + if err := sender(&choice.Delta.Content, nil); err != nil { + return err + } + } + if choice.FinishReason != "" || event.FinishReason != "" { + sawTerminal = true + break + } + } + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawTerminal { + return fmt.Errorf("togetherai: stream ended before [DONE] or finish_reason") + } + + endOfStream := "[DONE]" + return sender(&endOfStream, nil) +} + +type togetherAIModelInfo struct { + ID string `json:"id"` +} + +func (t *TogetherAIModel) ListModels(apiConfig *APIConfig) ([]string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := t.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, t.URLSuffix.Models) + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := t.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result []togetherAIModelInfo + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + models := make([]string, 0, len(result)) + for _, model := range result { + if model.ID != "" { + models = append(models, model.ID) + } + } + return models, nil +} + +func (t *TogetherAIModel) CheckConnection(apiConfig *APIConfig) error { + _, err := t.ListModels(apiConfig) + return err +} + +func (t *TogetherAIModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + return nil, fmt.Errorf("%s, no such method", t.Name()) +} + +func (t *TogetherAIModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s, no such method", t.Name()) +} + +func (t *TogetherAIModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("%s, no such method", t.Name()) +} + +func (t *TogetherAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", t.Name()) +} + +func (t *TogetherAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", t.Name()) +} + +func (t *TogetherAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", t.Name()) +} + +func (t *TogetherAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", t.Name()) +} + +func (t *TogetherAIModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", t.Name()) +} + +func (t *TogetherAIModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", t.Name()) +} + +func (t *TogetherAIModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", t.Name()) +} + +func (t *TogetherAIModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", t.Name()) +} diff --git a/internal/entity/models/togetherai_test.go b/internal/entity/models/togetherai_test.go new file mode 100644 index 00000000000..aecdf20c95b --- /dev/null +++ b/internal/entity/models/togetherai_test.go @@ -0,0 +1,277 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newTogetherAIServer(t *testing.T, handler func(t *testing.T, r *http.Request, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("expected Authorization=Bearer test-key, got %q", got) + return + } + if got := r.Header.Get("Content-Type"); !strings.HasPrefix(got, "application/json") { + t.Errorf("expected Content-Type to start with application/json, got %q", got) + return + } + var body map[string]interface{} + if r.Method == http.MethodPost { + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("unmarshal: %v\nraw=%s", err, string(raw)) + return + } + } + handler(t, r, body, w) + })) +} + +func newTogetherAIForTest(baseURL string) *TogetherAIModel { + return NewTogetherAIModel( + map[string]string{"default": baseURL}, + URLSuffix{Chat: "chat/completions", Models: "models"}, + ) +} + +func TestTogetherAIName(t *testing.T) { + if got := newTogetherAIForTest("http://unused").Name(); got != "togetherai" { + t.Errorf("Name()=%q", got) + } +} + +func TestTogetherAIFactory(t *testing.T) { + driver, err := NewModelFactory().CreateModelDriver("TogetherAI", map[string]string{"default": "http://unused"}, URLSuffix{}) + if err != nil { + t.Fatalf("CreateModelDriver: %v", err) + } + if _, ok := driver.(*TogetherAIModel); !ok { + t.Fatalf("driver type=%T, want *TogetherAIModel", driver) + } +} + +func TestTogetherAIChatHappyPath(t *testing.T) { + srv := newTogetherAIServer(t, func(t *testing.T, r *http.Request, body map[string]interface{}, w http.ResponseWriter) { + if r.URL.Path != "/chat/completions" { + t.Errorf("path=%s", r.URL.Path) + } + if body["model"] != "openai/gpt-oss-20b" { + t.Errorf("model=%v", body["model"]) + } + if body["stream"] != false { + t.Errorf("stream=%v want false", body["stream"]) + } + if body["reasoning_effort"] != "high" { + t.Errorf("reasoning_effort=%v", body["reasoning_effort"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{ + "content": "pong", + "reasoning": "thinking", + }, + }}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + mt := 32 + temp := 0.3 + topP := 0.9 + stop := []string{"END"} + effort := "high" + resp, err := newTogetherAIForTest(srv.URL).ChatWithMessages( + "openai/gpt-oss-20b", + []Message{{Role: "user", Content: "ping"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{MaxTokens: &mt, Temperature: &temp, TopP: &topP, Stop: &stop, Effort: &effort}, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if *resp.Answer != "pong" { + t.Errorf("Answer=%q", *resp.Answer) + } + if *resp.ReasonContent != "thinking" { + t.Errorf("ReasonContent=%q", *resp.ReasonContent) + } +} + +func TestTogetherAIChatForwardsReasoningEnabled(t *testing.T) { + srv := newTogetherAIServer(t, func(t *testing.T, r *http.Request, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "Qwen/Qwen3.5-9B" { + t.Errorf("model=%v", body["model"]) + } + reasoning, ok := body["reasoning"].(map[string]interface{}) + if !ok { + t.Fatalf("reasoning=%T, want object", body["reasoning"]) + } + if reasoning["enabled"] != false { + t.Errorf("reasoning.enabled=%v, want false", reasoning["enabled"]) + } + if _, ok := body["reasoning_effort"]; ok { + t.Errorf("reasoning_effort should not be sent for non-GPT-OSS model: %v", body["reasoning_effort"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "choices": []map[string]interface{}{{ + "message": map[string]interface{}{ + "content": "pong", + }, + }}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + thinking := false + resp, err := newTogetherAIForTest(srv.URL).ChatWithMessages( + "Qwen/Qwen3.5-9B", + []Message{{Role: "user", Content: "ping"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{Thinking: &thinking}, + ) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if *resp.Answer != "pong" { + t.Errorf("Answer=%q", *resp.Answer) + } +} + +func TestTogetherAIChatRequiresModelName(t *testing.T) { + apiKey := "test-key" + _, err := newTogetherAIForTest("http://unused").ChatWithMessages("", []Message{{Role: "user", Content: "x"}}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("expected model-name error, got %v", err) + } +} + +func TestTogetherAIStreamHappyPath(t *testing.T) { + srv := newTogetherAIServer(t, func(t *testing.T, r *http.Request, body map[string]interface{}, w http.ResponseWriter) { + if r.URL.Path != "/chat/completions" { + t.Errorf("path=%s", r.URL.Path) + } + if body["stream"] != true { + t.Errorf("stream=%v want true", body["stream"]) + } + if got := r.Header.Get("Accept"); got != "text/event-stream" { + t.Errorf("Accept=%q", got) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, + `data: {"choices":[{"delta":{"reasoning":"think "}}]}`+"\n"+ + `data: {"choices":[{"delta":{"content":"Hello"}}]}`+"\n"+ + `data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}]}`+"\n", + ) + }) + defer srv.Close() + + apiKey := "test-key" + var content []string + var reasoning []string + err := newTogetherAIForTest(srv.URL).ChatStreamlyWithSender( + "meta-llama/Llama-3.3-70B-Instruct-Turbo", + []Message{{Role: "user", Content: "hi"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(c *string, r *string) error { + if c != nil { + content = append(content, *c) + } + if r != nil { + reasoning = append(reasoning, *r) + } + return nil + }, + ) + if err != nil { + t.Fatalf("ChatStreamlyWithSender: %v", err) + } + if strings.Join(content, "") != "Hello world[DONE]" { + t.Errorf("content=%q", strings.Join(content, "")) + } + if strings.Join(reasoning, "") != "think " { + t.Errorf("reasoning=%q", strings.Join(reasoning, "")) + } +} + +func TestTogetherAIStreamStopsOnRootFinishReason(t *testing.T) { + srv := newTogetherAIServer(t, func(t *testing.T, r *http.Request, body map[string]interface{}, w http.ResponseWriter) { + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, + `data: {"choices":[{"delta":{"content":"Done"}}],"finish_reason":"stop"}`+"\n", + ) + }) + defer srv.Close() + + apiKey := "test-key" + var chunks []string + err := newTogetherAIForTest(srv.URL).ChatStreamlyWithSender( + "meta-llama/Llama-3.3-70B-Instruct-Turbo", + []Message{{Role: "user", Content: "hi"}}, + &APIConfig{ApiKey: &apiKey}, nil, + func(c *string, _ *string) error { + if c != nil { + chunks = append(chunks, *c) + } + return nil + }, + ) + if err != nil { + t.Fatalf("ChatStreamlyWithSender: %v", err) + } + if strings.Join(chunks, "") != "Done[DONE]" { + t.Errorf("chunks=%q", strings.Join(chunks, "")) + } +} + +func TestTogetherAIListModelsAndCheckConnection(t *testing.T) { + srv := newTogetherAIServer(t, func(t *testing.T, r *http.Request, body map[string]interface{}, w http.ResponseWriter) { + if r.Method != http.MethodGet { + t.Errorf("method=%s", r.Method) + } + if r.URL.Path != "/models" { + t.Errorf("path=%s", r.URL.Path) + } + _ = json.NewEncoder(w).Encode([]map[string]interface{}{ + {"id": "openai/gpt-oss-20b"}, + {"id": "meta-llama/Llama-3.3-70B-Instruct-Turbo"}, + }) + }) + defer srv.Close() + + apiKey := "test-key" + model := newTogetherAIForTest(srv.URL) + models, err := model.ListModels(&APIConfig{ApiKey: &apiKey}) + if err != nil { + t.Fatalf("ListModels: %v", err) + } + if strings.Join(models, ",") != "openai/gpt-oss-20b,meta-llama/Llama-3.3-70B-Instruct-Turbo" { + t.Errorf("models=%v", models) + } + if err := model.CheckConnection(&APIConfig{ApiKey: &apiKey}); err != nil { + t.Fatalf("CheckConnection: %v", err) + } +} + +func TestTogetherAIUnsupportedMethods(t *testing.T) { + m := newTogetherAIForTest("http://unused") + if _, err := m.Embed(nil, nil, nil, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Embed error=%v", err) + } + if _, err := m.Rerank(nil, "", nil, nil, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Rerank error=%v", err) + } + if _, err := m.Balance(nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Balance error=%v", err) + } +} diff --git a/internal/entity/models/types.go b/internal/entity/models/types.go index 250e41bc51a..12534851046 100644 --- a/internal/entity/models/types.go +++ b/internal/entity/models/types.go @@ -17,21 +17,34 @@ type ModelDriver interface { Name() string - // ChatWithMessages sends multiple messages with role and content + // ChatWithMessages sends multiple messages synchronously ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) - // ChatStreamlyWithSender sends messages and streams response via sender function (best performance, no channel) - // messages accepts []Message which supports multimodal content (e.g., [{"type": "text", "text": "..."}, {"type": "image_url", "image_url": {"url": "..."}}]) + // ChatStreamlyWithSender sends multiple messages asynchronously ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error - // Encode encodes a list of texts into embeddings - Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) + // Embed a list of texts into embeddings + Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) // Rerank calculates similarity scores between query and texts Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) + // TranscribeAudio transcribe audio + TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) + TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error + // AudioSpeech convert text to audio + AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) + AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error + // OCRFile OCR file + OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) + // ParseFile parse file + ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) // ListModels List supported models ListModels(apiConfig *APIConfig) ([]string, error) Balance(apiConfig *APIConfig) (map[string]interface{}, error) CheckConnection(apiConfig *APIConfig) error + + ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) + + ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) } type ChatResponse struct { @@ -39,14 +52,9 @@ type ChatResponse struct { ReasonContent *string `json:"reason_content"` } -type EmbeddingResult struct { - Index int `json:"index"` - Dimension int `json:"dimension"` - //Embedding []float64 `json:"embedding"` -} - -type EmbeddingResponse struct { - Data []EmbeddingResult `json:"data"` +type EmbeddingData struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` } type RerankResult struct { @@ -58,17 +66,53 @@ type RerankResponse struct { Data []RerankResult `json:"data"` } +type ASRResponse struct { + Text string `json:"text"` +} + +type TTSResponse struct { + Audio []byte `json:"audio"` +} + +type OCRFileResponse struct { + Text *string `json:"text"` +} + +type ParseFileResponse struct { + TaskID string `json:"task_id"` +} + +type ListTaskStatus struct { + TaskID string `json:"task_id"` + Status string `json:"status"` +} + +type TaskSegment struct { + Index int `json:"index"` + Content string `json:"content"` +} + +type TaskResponse struct { + Segments []TaskSegment `json:"segments"` +} + // URLSuffix represents the URL suffixes for different API endpoints type URLSuffix struct { - Chat string `json:"chat"` - AsyncChat string `json:"async_chat"` - AsyncResult string `json:"async_result"` - Embedding string `json:"embedding"` - Rerank string `json:"rerank"` - Models string `json:"models"` - Balance string `json:"balance"` - Files string `json:"files"` - Status string `json:"status"` + Chat string `json:"chat"` + AsyncChat string `json:"async_chat"` + AsyncResult string `json:"async_result"` + Embedding string `json:"embedding"` + Rerank string `json:"rerank"` + TTS string `json:"tts"` + ASR string `json:"asr"` + OCR string `json:"ocr"` + DocumentParse string `json:"doc_parse"` + Models string `json:"models"` + Balance string `json:"balance"` + Files string `json:"files"` + Status string `json:"status"` + Tasks string `json:"tasks"` + Task string `json:"task"` } type ChatConfig struct { @@ -98,6 +142,21 @@ type RerankConfig struct { TopN int } +type ASRConfig struct { + Params map[string]interface{} `json:"params"` +} + +type TTSConfig struct { + Format string `json:"format"` + Params map[string]interface{} `json:"params"` +} + +type OCRConfig struct { +} + +type ParseFileConfig struct { +} + // EmbeddingModel wraps a ModelDriver with embedding-specific configuration type EmbeddingModel struct { ModelDriver ModelDriver diff --git a/internal/entity/models/upstage.go b/internal/entity/models/upstage.go new file mode 100644 index 00000000000..1b492b60e2b --- /dev/null +++ b/internal/entity/models/upstage.go @@ -0,0 +1,622 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// UpstageModel implements ModelDriver for Upstage (Solar models). +// +// Upstage exposes an OpenAI-compatible REST API at +// https://api.upstage.ai/v1 (chat completions at /chat/completions, list +// models at /models, embeddings at /embeddings). The wire shape matches +// OpenAI closely enough that the chat path here is a direct port of the +// OpenAI driver. The legacy /v1/solar/* paths still work but the canonical +// base is /v1. +type UpstageModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +// NewUpstageModel creates a new Upstage model instance. +// +// We clone http.DefaultTransport so we keep Go's defaults for +// ProxyFromEnvironment, DialContext (with KeepAlive), HTTP/2, +// TLSHandshakeTimeout, and ExpectContinueTimeout, and only override +// the connection-pool fields we care about. +// +// The Client itself has no Timeout. http.Client.Timeout would also +// cap the time spent reading the response body, which would cut off +// long-lived SSE streams in ChatStreamlyWithSender. Non-streaming +// callers wrap each request with context.WithTimeout instead. +func NewUpstageModel(baseURL map[string]string, urlSuffix URLSuffix) *UpstageModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &UpstageModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (u *UpstageModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewUpstageModel(baseURL, u.URLSuffix) +} + +func (u *UpstageModel) Name() string { + return "upstage" +} + +// baseURLForRegion returns the base URL for the given region, or an +// error if no entry exists. This makes a misconfigured region fail +// fast with a clear message, instead of silently producing a relative +// URL that the HTTP transport then rejects. +func (u *UpstageModel) baseURLForRegion(region string) (string, error) { + base, ok := u.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("upstage: no base URL configured for region %q", region) + } + return base, nil +} + +// ChatWithMessages sends multiple messages with roles and returns the response. +func (u *UpstageModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := u.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, u.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + } + + // Note: do NOT propagate chatModelConfig.Stream into the request body + // here. ChatWithMessages parses a single JSON response, so stream must + // always be off for this code path. + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + // Upstage Solar reasoning models (solar-pro2 and the upcoming + // solar-pro3) accept reasoning_effort=low|medium|high to trade + // latency for chain-of-thought depth, matching the OpenAI + // o-series shape. ChatConfig.Effort is the canonical carrier. + if chatModelConfig.Effort != nil && *chatModelConfig.Effort != "" { + reqBody["reasoning_effort"] = *chatModelConfig.Effort + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := u.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok || len(choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid choice format") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("invalid message format") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + + // Upstage Solar reasoning models (solar-pro3, solar-pro2 with + // reasoning_effort >= medium) return the chain-of-thought in a + // `reasoning` field on the message. Pass it through when present + // so callers that opted into reasoning can show it. Absent or + // non-string means no reasoning was emitted — leave it empty. + reasonContent := "" + if r, ok := messageMap["reasoning"].(string); ok { + reasonContent = r + } + + return &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + }, nil +} + +// ChatStreamlyWithSender sends messages and streams the response via the +// sender function. The Upstage SSE stream uses the same shape as OpenAI: +// "data:" lines carrying JSON events, with a final "[DONE]" line. +func (u *UpstageModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return fmt.Errorf("api key is required") + } + + var region = "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := u.baseURLForRegion(region) + if err != nil { + return err + } + url := fmt.Sprintf("%s/%s", baseURL, u.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + } + + if chatModelConfig != nil { + // Refuse to run if the caller explicitly asked for stream=false. + // The body of this method only knows how to read SSE, so a + // non-SSE JSON response would be parsed as if it were a stream + // and produce no chunks. Better to fail clearly. + if chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + // reasoning_effort: same as the non-streaming path above. + if chatModelConfig.Effort != nil && *chatModelConfig.Effort != "" { + reqBody["reasoning_effort"] = *chatModelConfig.Effort + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + // SSE streams are long-lived. We rely on the transport's + // ResponseHeaderTimeout to cap the connection-establishment phase + // instead of attaching a hard deadline here. + req, err := http.NewRequestWithContext(context.Background(), "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := u.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: bump the scanner buffer from the 64KB default to 1MB + // so we never silently truncate a long data: line. + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + sawTerminal := false + for scanner.Scan() { + line := scanner.Text() + + if !strings.HasPrefix(line, "data:") { + continue + } + + data := strings.TrimSpace(line[5:]) + + if data == "[DONE]" { + sawTerminal = true + break + } + + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + sawTerminal = true + break + } + } + + if err := scanner.Err(); err != nil { + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawTerminal { + return fmt.Errorf("upstage: stream ended before [DONE] or finish_reason") + } + + endOfStream := "[DONE]" + if err := sender(&endOfStream, nil); err != nil { + return err + } + + return nil +} + +type upstageEmbeddingData struct { + Embedding []float64 `json:"embedding"` + Object string `json:"object"` + Index int `json:"index"` +} + +type upstageEmbeddingResponse struct { + Data []upstageEmbeddingData `json:"data"` + Model string `json:"model"` + Object string `json:"object"` +} + +// Embed turns a list of texts into embedding vectors using the Upstage +// /v1/solar/embeddings endpoint (solar-embedding-1-large-query for queries, +// solar-embedding-1-large-passage for passages). The output has one vector +// per input, in the same order the inputs were given. +func (u *UpstageModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := u.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, u.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := u.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Upstage embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed upstageEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Reorder by the reported index so the output always lines up with + // the input texts, even if the upstream API ever returns items out + // of order. A nil slot at the end indicates the upstream did not + // return an embedding for that input. + embeddings := make([]EmbeddingData, len(texts)) + filled := make([]bool, len(texts)) + for _, item := range parsed.Data { + if item.Index < 0 || item.Index >= len(texts) { + return nil, fmt.Errorf("upstage: response index %d out of range for %d inputs", item.Index, len(texts)) + } + if filled[item.Index] { + // A malformed response that repeats the same index would + // silently overwrite the earlier vector. Fail loudly so + // the caller never uses ambiguous output. + return nil, fmt.Errorf("upstage: duplicate embedding index %d in response", item.Index) + } + embeddings[item.Index] = EmbeddingData{ + Embedding: item.Embedding, + Index: item.Index, + } + filled[item.Index] = true + } + for i, ok := range filled { + if !ok { + return nil, fmt.Errorf("upstage: missing embedding for input index %d", i) + } + } + + return embeddings, nil +} + +// ListModels returns the list of model ids visible to the API key. +func (u *UpstageModel) ListModels(apiConfig *APIConfig) ([]string, error) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := u.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, u.URLSuffix.Models) + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := u.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + data, ok := result["data"].([]interface{}) + if !ok { + return nil, fmt.Errorf("invalid models list format") + } + + models := make([]string, 0) + for _, model := range data { + modelMap, ok := model.(map[string]interface{}) + if !ok { + continue + } + modelName, ok := modelMap["id"].(string) + if !ok { + continue + } + models = append(models, modelName) + } + + return models, nil +} + +// Balance is not exposed by the Upstage API, so this returns "no such method". +func (u *UpstageModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("no such method") +} + +// CheckConnection runs a lightweight ListModels call to verify the API key. +func (u *UpstageModel) CheckConnection(apiConfig *APIConfig) error { + _, err := u.ListModels(apiConfig) + if err != nil { + return err + } + return nil +} + +// Rerank calculates similarity scores between query and documents. Upstage +// does not expose a public rerank API, so this returns "no such method". +func (u *UpstageModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("no such method") +} + +// TranscribeAudio transcribe audio +func (z *UpstageModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *UpstageModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (z *UpstageModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *UpstageModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (z *UpstageModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +// ParseFile parse file +func (z *UpstageModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *UpstageModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *UpstageModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/upstage_test.go b/internal/entity/models/upstage_test.go new file mode 100644 index 00000000000..cb651df94af --- /dev/null +++ b/internal/entity/models/upstage_test.go @@ -0,0 +1,271 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newUpstageForTest(baseURL string) *UpstageModel { + return NewUpstageModel( + map[string]string{"default": baseURL}, + URLSuffix{ + Chat: "chat/completions", + Models: "models", + Embedding: "embeddings", + }, + ) +} + +// ---------- reasoning_effort / reasoning field ---------- + +func TestUpstageChatPropagatesReasoningEffort(t *testing.T) { + // Per https://console.upstage.ai/api/docs/for-agents/raw, Upstage + // Solar models accept `reasoning_effort: minimal|low|medium|high`. + // ChatConfig.Effort is the canonical carrier; this test asserts it + // flows into the wire body verbatim. + var seen map[string]interface{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(raw, &seen) + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + effort := "high" + _, err := u.ChatWithMessages("solar-pro2", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{Effort: &effort}) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if got, ok := seen["reasoning_effort"].(string); !ok || got != "high" { + t.Errorf("reasoning_effort=%v want \"high\"", seen["reasoning_effort"]) + } +} + +func TestUpstageChatOmitsReasoningEffortWhenUnset(t *testing.T) { + // If the caller does not opt in, the field must NOT be sent. Sending + // "minimal" by default would silently change behavior for downstream + // proxies that treat a present field differently from an absent one. + var seen map[string]interface{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(raw, &seen) + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + _, err := u.ChatWithMessages("solar-pro2", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{}, // no Effort + ) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if _, present := seen["reasoning_effort"]; present { + t.Errorf("reasoning_effort should be absent when Effort is unset, got %v", seen["reasoning_effort"]) + } +} + +func TestUpstageStreamPropagatesReasoningEffort(t *testing.T) { + var seen map[string]interface{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(raw, &seen) + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, + `data: {"choices":[{"delta":{"content":"hi"},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + effort := "medium" + err := u.ChatStreamlyWithSender("solar-pro2", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{Effort: &effort}, + func(*string, *string) error { return nil }, + ) + if err != nil { + t.Fatalf("Stream: %v", err) + } + if got, ok := seen["reasoning_effort"].(string); !ok || got != "medium" { + t.Errorf("stream reasoning_effort=%v want \"medium\"", seen["reasoning_effort"]) + } +} + +func TestUpstageChatExtractsReasoningField(t *testing.T) { + // Per the Upstage docs: when reasoning_effort is high|medium for + // solar-pro3 (or high for solar-pro2), the response's + // choices[0].message includes a `reasoning` field. The driver must + // pass it through as ChatResponse.ReasonContent. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"choices":[{"message":{ + "content":"15% of 80 is **12**.", + "reasoning":"15/100 = 0.15; 0.15 * 80 = 12" + }}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + resp, err := u.ChatWithMessages("solar-pro3", + []Message{{Role: "user", Content: "What is 15% of 80?"}}, + &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if resp.ReasonContent == nil || *resp.ReasonContent != "15/100 = 0.15; 0.15 * 80 = 12" { + t.Errorf("ReasonContent=%v want the reasoning trace", resp.ReasonContent) + } + if resp.Answer == nil || *resp.Answer != "15% of 80 is **12**." { + t.Errorf("Answer=%v", resp.Answer) + } +} + +func TestUpstageChatHandlesAbsentReasoning(t *testing.T) { + // Models without reasoning (solar-mini, syn-pro) or low-effort + // requests return no `reasoning` field. The driver must leave + // ReasonContent empty without crashing. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + resp, err := u.ChatWithMessages("solar-mini", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if resp.ReasonContent == nil || *resp.ReasonContent != "" { + t.Errorf("ReasonContent=%v want empty string for no-reasoning response", resp.ReasonContent) + } + if resp.Answer == nil || *resp.Answer != "ok" { + t.Errorf("Answer=%v want ok", resp.Answer) + } +} + +// Ensure the same JSON shape used by the maintainer's docs (per +// https://console.upstage.ai/api/chat) round-trips through the request +// body for both streaming and non-streaming paths. +func TestUpstageRequestBodyMatchesSolarAPIShape(t *testing.T) { + var seen map[string]interface{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + raw, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(raw, &seen) + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + mt := 256 + temp := 0.7 + topP := 0.9 + stop := []string{"END"} + effort := "high" + _, err := u.ChatWithMessages("solar-pro2", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &apiKey}, + &ChatConfig{MaxTokens: &mt, Temperature: &temp, TopP: &topP, Stop: &stop, Effort: &effort}) + if err != nil { + t.Fatalf("Chat: %v", err) + } + want := map[string]interface{}{ + "model": "solar-pro2", + "stream": false, + "max_tokens": float64(256), + "temperature": 0.7, + "top_p": 0.9, + "reasoning_effort": "high", + } + for k, v := range want { + if got, ok := seen[k]; !ok { + t.Errorf("missing key %q in body", k) + } else if !strings.HasPrefix(k, "stop") && got != v { + t.Errorf("body[%q]=%v want %v", k, got, v) + } + } + if stopArr, ok := seen["stop"].([]interface{}); !ok || len(stopArr) != 1 || stopArr[0] != "END" { + t.Errorf("body[stop]=%v want [END]", seen["stop"]) + } + if _, ok := seen["messages"].([]interface{}); !ok { + t.Errorf("body[messages] missing or wrong type") + } +} + +// ---------- Embed: duplicate / out-of-range / reorder ---------- + +func TestUpstageEmbedRejectsDuplicateIndex(t *testing.T) { + // A malformed upstream that repeats data[*].index would silently + // overwrite the earlier vector; the driver must fail loudly instead. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"data":[ + {"embedding":[1],"index":0}, + {"embedding":[2],"index":0}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + model := "solar-embedding-1-large-passage" + _, err := u.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "duplicate embedding index 0") { + t.Errorf("expected duplicate-index error, got %v", err) + } +} + +func TestUpstageEmbedRejectsOutOfRangeIndex(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"data":[{"embedding":[1],"index":7}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + model := "solar-embedding-1-large-passage" + _, err := u.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "out of range") { + t.Errorf("expected out-of-range error, got %v", err) + } +} + +func TestUpstageEmbedHappyPathReordersByIndex(t *testing.T) { + // Upstream returns vectors in shuffled order; driver must realign. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"data":[ + {"embedding":[2],"index":2}, + {"embedding":[0],"index":0}, + {"embedding":[1],"index":1}]}`) + })) + defer srv.Close() + + u := newUpstageForTest(srv.URL) + apiKey := "test-key" + model := "solar-embedding-1-large-passage" + vecs, err := u.Embed(&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } + for i, v := range vecs { + if v.Index != i || v.Embedding[0] != float64(i) { + t.Errorf("slot %d = %+v, want index=%d embedding=[%d]", i, v, i, i) + } + } +} diff --git a/internal/entity/models/vllm.go b/internal/entity/models/vllm.go index 97ade07d1ea..c8dbcdab7c0 100644 --- a/internal/entity/models/vllm.go +++ b/internal/entity/models/vllm.go @@ -19,6 +19,7 @@ package models import ( "bufio" "bytes" + "context" "encoding/json" "fmt" "io" @@ -378,8 +379,93 @@ func (z *VllmModel) ChatStreamlyWithSender(modelName string, messages []Message, } // Encode encodes a list of texts into embeddings -func (z *VllmModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { - return nil, fmt.Errorf("not implemented") +type vllmEmbeddingResponse struct { + Data []struct { + Index int `json:"index"` + Embedding []float64 `json:"embedding"` + } `json:"data"` +} + +// Embed embeds a list of texts into embeddings +func (z *VllmModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := z.BaseURL[region] + if baseURL == "" { + baseURL = z.BaseURL["default"] + } + if baseURL == "" { + return nil, fmt.Errorf("missing base URL: please configure the local access address for vLLM (e.g., http://127.0.0.1:8000/v1)") + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), z.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + } + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("vLLM embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed vllmEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + var embeddings []EmbeddingData + for _, dataElem := range parsed.Data { + var embeddingData EmbeddingData + embeddingData.Embedding = dataElem.Embedding + embeddingData.Index = dataElem.Index + embeddings = append(embeddings, embeddingData) + } + + return embeddings, nil } func (z *VllmModel) ListModels(apiConfig *APIConfig) ([]string, error) { @@ -461,7 +547,154 @@ func (z *VllmModel) CheckConnection(apiConfig *APIConfig) error { return err } -// Rerank calculates similarity scores between query and documents +// vllmRerankRequest mirrors vLLM's Jina/Cohere-compatible /v1/rerank +// payload. Unlike NVIDIA NIM (which wraps each passage as {text: "..."}), +// vLLM accepts documents as a flat []string. +type vllmRerankRequest struct { + Model string `json:"model"` + Query string `json:"query"` + Documents []string `json:"documents"` + TopN int `json:"top_n"` +} + +// vllmRerankResponse maps the Jina-style results array. The `document` +// field is intentionally ignored — callers reconstruct text from the +// original input via Index. +type vllmRerankResponse struct { + Results []struct { + Index int `json:"index"` + RelevanceScore float64 `json:"relevance_score"` + } `json:"results"` +} + +// Rerank scores documents against the query using a vLLM rerank model +// served at /v1/rerank (stable since vLLM v0.7). Mirrors the contract +// of NvidiaModel.Rerank: defaults top_n to len(documents) so callers +// get a score per input, shrinks to RerankConfig.TopN only when set +// and smaller. Returned RerankResult entries are in the API's ranking +// order; callers that need original-input order sort by Index. The +// Authorization header is sent only when APIConfig.ApiKey is non-empty, +// matching the existing Embed/ListModels behaviour for this local +// driver. func (z *VllmModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { - return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) + if len(documents) == 0 { + return &RerankResponse{}, nil + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := z.BaseURL[region] + if baseURL == "" { + baseURL = z.BaseURL["default"] + } + if baseURL == "" { + return nil, fmt.Errorf("missing base URL: please configure the local access address for vLLM (e.g., http://127.0.0.1:8000/v1)") + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), z.URLSuffix.Rerank) + + topN := len(documents) + if rerankConfig != nil && rerankConfig.TopN > 0 && rerankConfig.TopN < topN { + topN = rerankConfig.TopN + } + + reqBody := vllmRerankRequest{ + Model: *modelName, + Query: query, + Documents: documents, + TopN: topN, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + } + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("vLLM rerank API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed vllmRerankResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + rerankResponse := RerankResponse{Data: make([]RerankResult, 0, len(parsed.Results))} + for _, r := range parsed.Results { + if r.Index < 0 || r.Index >= len(documents) { + return nil, fmt.Errorf("unexpected rerank index %d for %d inputs", r.Index, len(documents)) + } + rerankResponse.Data = append(rerankResponse.Data, RerankResult{ + Index: r.Index, + RelevanceScore: r.RelevanceScore, + }) + } + + return &rerankResponse, nil +} + +// TranscribeAudio transcribe audio +func (o *VllmModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *VllmModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (o *VllmModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *VllmModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *VllmModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +// ParseFile parse file +func (z *VllmModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *VllmModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *VllmModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) } diff --git a/internal/entity/models/vllm_rerank_test.go b/internal/entity/models/vllm_rerank_test.go new file mode 100644 index 00000000000..42fda948c2f --- /dev/null +++ b/internal/entity/models/vllm_rerank_test.go @@ -0,0 +1,209 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newVllmRerankServer(t *testing.T, expectAuth string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Errorf("expected POST, got %s", r.Method) + return + } + if r.URL.Path != "/rerank" { + t.Errorf("expected path=/rerank, got %s", r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != expectAuth { + t.Errorf("expected Authorization=%q, got %q", expectAuth, got) + return + } + if got := r.Header.Get("Content-Type"); got != "application/json" { + t.Errorf("expected Content-Type=application/json, got %q", got) + return + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("failed to read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("invalid JSON body: %v\n%s", err, string(raw)) + return + } + handler(t, body, w) + })) +} + +func newVllmModelForTest(baseURL string) *VllmModel { + return NewVllmModel( + map[string]string{"default": baseURL}, + URLSuffix{Rerank: "rerank"}, + ) +} + +func TestVllmRerankHappyPath(t *testing.T) { + srv := newVllmRerankServer(t, "Bearer test-key", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "BAAI/bge-reranker-v2-m3" { + t.Errorf("expected model=BAAI/bge-reranker-v2-m3, got %v", body["model"]) + } + if body["query"] != "What is RAPTOR?" { + t.Errorf("expected query=What is RAPTOR?, got %v", body["query"]) + } + // vLLM differs from NVIDIA: documents is a flat []string, not [{text}]. + docs, ok := body["documents"].([]interface{}) + if !ok || len(docs) != 3 { + t.Errorf("expected 3 documents, got %v", body["documents"]) + return + } + for i, want := range []string{"doc-zero", "doc-one", "doc-two"} { + if docs[i] != want { + t.Errorf("documents[%d]=%v, want %s", i, docs[i], want) + } + } + if body["top_n"] != float64(3) { + t.Errorf("expected top_n=3 (matching len(documents)), got %v", body["top_n"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "results": []map[string]interface{}{ + {"index": 2, "relevance_score": 0.95}, + {"index": 0, "relevance_score": 0.42}, + {"index": 1, "relevance_score": 0.78}, + }, + }) + }) + defer srv.Close() + + model := newVllmModelForTest(srv.URL) + apiKey := "test-key" + modelName := "BAAI/bge-reranker-v2-m3" + resp, err := model.Rerank( + &modelName, + "What is RAPTOR?", + []string{"doc-zero", "doc-one", "doc-two"}, + &APIConfig{ApiKey: &apiKey}, + &RerankConfig{}, + ) + if err != nil { + t.Fatalf("Rerank failed: %v", err) + } + if len(resp.Data) != 3 { + t.Fatalf("expected 3 results, got %d", len(resp.Data)) + } + want := map[int]float64{0: 0.42, 1: 0.78, 2: 0.95} + for _, r := range resp.Data { + if got, ok := want[r.Index]; !ok || got != r.RelevanceScore { + t.Errorf("unexpected result Index=%d RelevanceScore=%v", r.Index, r.RelevanceScore) + } + } +} + +func TestVllmRerankTopNClamp(t *testing.T) { + srv := newVllmRerankServer(t, "Bearer test-key", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["top_n"] != float64(2) { + t.Errorf("expected top_n clamp to RerankConfig.TopN=2, got %v", body["top_n"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{"results": []map[string]interface{}{}}) + }) + defer srv.Close() + + model := newVllmModelForTest(srv.URL) + apiKey := "test-key" + modelName := "BAAI/bge-reranker-v2-m3" + if _, err := model.Rerank( + &modelName, "q", + []string{"a", "b", "c", "d"}, + &APIConfig{ApiKey: &apiKey}, + &RerankConfig{TopN: 2}, + ); err != nil { + t.Fatalf("Rerank failed: %v", err) + } +} + +func TestVllmRerankEmptyDocuments(t *testing.T) { + model := newVllmModelForTest("http://unused") + apiKey := "test-key" + modelName := "BAAI/bge-reranker-v2-m3" + resp, err := model.Rerank(&modelName, "q", nil, &APIConfig{ApiKey: &apiKey}, &RerankConfig{}) + if err != nil { + t.Fatalf("expected nil error for empty documents, got %v", err) + } + if len(resp.Data) != 0 { + t.Errorf("expected empty Data, got %d entries", len(resp.Data)) + } +} + +// vLLM is a local driver; the Authorization header must be omitted when +// no APIConfig.ApiKey is configured. This diverges from the NVIDIA driver +// which requires an API key. +func TestVllmRerankWithoutAPIKey(t *testing.T) { + srv := newVllmRerankServer(t, "", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "results": []map[string]interface{}{ + {"index": 0, "relevance_score": 0.5}, + }, + }) + }) + defer srv.Close() + + model := newVllmModelForTest(srv.URL) + modelName := "BAAI/bge-reranker-v2-m3" + resp, err := model.Rerank(&modelName, "q", []string{"a"}, &APIConfig{}, &RerankConfig{}) + if err != nil { + t.Fatalf("Rerank failed without api key: %v", err) + } + if len(resp.Data) != 1 || resp.Data[0].Index != 0 { + t.Errorf("unexpected response: %+v", resp) + } +} + +func TestVllmRerankRequiresModelName(t *testing.T) { + model := newVllmModelForTest("http://unused") + apiKey := "test-key" + _, err := model.Rerank(nil, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{}) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("expected model-name error, got %v", err) + } +} + +func TestVllmRerankRejectsHTTPError(t *testing.T) { + srv := newVllmRerankServer(t, "Bearer test-key", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte(`{"error":"boom"}`)) + }) + defer srv.Close() + + model := newVllmModelForTest(srv.URL) + apiKey := "test-key" + modelName := "BAAI/bge-reranker-v2-m3" + _, err := model.Rerank(&modelName, "q", []string{"a"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{}) + if err == nil || !strings.Contains(err.Error(), "vLLM rerank API error") { + t.Errorf("expected API error, got %v", err) + } +} + +func TestVllmRerankRejectsOutOfRangeIndex(t *testing.T) { + srv := newVllmRerankServer(t, "Bearer test-key", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "results": []map[string]interface{}{ + {"index": 5, "relevance_score": 0.9}, + }, + }) + }) + defer srv.Close() + + model := newVllmModelForTest(srv.URL) + apiKey := "test-key" + modelName := "BAAI/bge-reranker-v2-m3" + _, err := model.Rerank(&modelName, "q", []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, &RerankConfig{}) + if err == nil || !strings.Contains(err.Error(), "unexpected rerank index") { + t.Errorf("expected out-of-range error, got %v", err) + } +} diff --git a/internal/entity/models/volcengine.go b/internal/entity/models/volcengine.go index 8b5670756dc..8f3133aa416 100644 --- a/internal/entity/models/volcengine.go +++ b/internal/entity/models/volcengine.go @@ -406,10 +406,35 @@ func (z *VolcEngine) ChatStreamlyWithSender(modelName string, messages []Message return scanner.Err() } -// Encode encodes a list of texts into embeddings -func (z *VolcEngine) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +type volcengineEmbeddingResponse struct { + Created int64 `json:"created"` + Data volcengineEmbeddingData `json:"data"` + ID string `json:"id"` + Model string `json:"model"` + Object string `json:"object"` + Usage volcengineUsage `json:"usage"` +} + +type volcengineEmbeddingData struct { + Embedding []float64 `json:"embedding"` + Object string `json:"object"` +} + +type volcengineUsage struct { + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` + PromptTokensDetails *volcenginePromptTokensDetails `json:"prompt_tokens_details,omitempty"` +} + +type volcenginePromptTokensDetails struct { + ImageTokens int `json:"image_tokens"` + TextTokens int `json:"text_tokens"` +} + +// Embed embeds a list of texts into embeddings +func (z *VolcEngine) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { if len(texts) == 0 { - return [][]float64{}, nil + return []EmbeddingData{}, nil } var region = "default" @@ -419,7 +444,7 @@ func (z *VolcEngine) Encode(modelName *string, texts []string, apiConfig *APICon url := fmt.Sprintf("%s/%s", z.BaseURL[region], z.URLSuffix.Embedding) - embeddings := make([][]float64, len(texts)) + var embeddings []EmbeddingData for i, text := range texts { @@ -466,25 +491,15 @@ func (z *VolcEngine) Encode(modelName *string, texts []string, apiConfig *APICon return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) } - // Volcengine multimodal embedding response - type VolcengineEmbeddingResponse struct { - Data struct { - Embedding []float64 `json:"embedding"` - Object string `json:"object"` - } `json:"data"` - } - - var result VolcengineEmbeddingResponse - - if err = json.Unmarshal(body, &result); err != nil { + var parsed volcengineEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } - if len(result.Data.Embedding) == 0 { - return nil, fmt.Errorf("empty embedding in response") - } - - embeddings[i] = result.Data.Embedding + var embeddingData EmbeddingData + embeddingData.Index = i + embeddingData.Embedding = parsed.Data.Embedding + embeddings = append(embeddings, embeddingData) } return embeddings, nil @@ -495,10 +510,91 @@ func (z *VolcEngine) Rerank(modelName *string, query string, documents []string, return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) } -func (z *VolcEngine) ListModels(apiConfig *APIConfig) ([]string, error) { +// TranscribeAudio transcribe audio +func (o *VolcEngine) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *VolcEngine) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (o *VolcEngine) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *VolcEngine) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *VolcEngine) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +// ParseFile parse file +func (z *VolcEngine) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { return nil, fmt.Errorf("%s, no such method", z.Name()) } +func (z *VolcEngine) ListModels(apiConfig *APIConfig) ([]string, error) { + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL := z.BaseURL[region] + if baseURL == "" { + baseURL = z.BaseURL["default"] + } + if baseURL == "" { + return nil, fmt.Errorf("volcengine: no base URL configured for region %q", region) + } + + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), z.URLSuffix.Models) + + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + } + + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("VolcEngine models API error: %s, body: %s", resp.Status, string(body)) + } + + var modelList DSModelList + if err = json.Unmarshal(body, &modelList); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + models := make([]string, 0, len(modelList.Models)) + for _, model := range modelList.Models { + modelName := model.ID + if model.OwnedBy != "" { + modelName = model.ID + "@" + model.OwnedBy + } + models = append(models, modelName) + } + + return models, nil +} + func (z *VolcEngine) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { return nil, fmt.Errorf("%s, no such method", z.Name()) } @@ -536,3 +632,11 @@ func (z *VolcEngine) CheckConnection(apiConfig *APIConfig) error { return nil } + +func (z *VolcEngine) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *VolcEngine) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/voyage.go b/internal/entity/models/voyage.go new file mode 100644 index 00000000000..c17a0e80815 --- /dev/null +++ b/internal/entity/models/voyage.go @@ -0,0 +1,389 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// VoyageModel implements ModelDriver for Voyage AI. +// +// Voyage AI exposes a focused REST API at https://api.voyageai.com/v1 +// with embedding (/embeddings) and reranking (/rerank) only — no chat, +// no streaming, no /v1/models, no balance. This driver covers Embed +// and Rerank with real implementations and returns "no such method" +// for every other ModelDriver method. +// +// Wire shape, captured live: +// +// Embed response: {object, data:[{object,embedding,index,text}], model, usage} +// Rerank response: {object, data:[{relevance_score,index}], model, usage} +// +// Rerank uses top_k as the request param name (not top_n like +// Aliyun/SiliconFlow); the driver translates RerankConfig.TopN to +// top_k on the wire. +type VoyageModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +// NewVoyageModel creates a new Voyage AI model instance. +// +// We clone http.DefaultTransport so we keep Go's defaults for +// ProxyFromEnvironment, DialContext (with KeepAlive), HTTP/2, +// TLSHandshakeTimeout, and ExpectContinueTimeout, and only override +// the connection-pool fields we care about. +func NewVoyageModel(baseURL map[string]string, urlSuffix URLSuffix) *VoyageModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &VoyageModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (v *VoyageModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewVoyageModel(baseURL, v.URLSuffix) +} + +func (v *VoyageModel) Name() string { + return "voyage" +} + +// baseURLForRegion returns the base URL for the given region, or an +// error if no entry exists. Single-region for Voyage but kept here +// for consistency with other drivers. +func (v *VoyageModel) baseURLForRegion(region string) (string, error) { + base, ok := v.BaseURL[region] + if !ok || base == "" { + return "", fmt.Errorf("voyage: no base URL configured for region %q", region) + } + return base, nil +} + +type voyageEmbeddingData struct { + Embedding []float64 `json:"embedding"` + Object string `json:"object"` + Index int `json:"index"` +} + +type voyageEmbeddingResponse struct { + Object string `json:"object"` + Data []voyageEmbeddingData `json:"data"` + Model string `json:"model"` +} + +// Embed turns a list of texts into embedding vectors using the +// Voyage AI /v1/embeddings endpoint. Output is one vector per input, +// in the same order the inputs were given. +func (v *VoyageModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := v.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), v.URLSuffix.Embedding) + + reqBody := map[string]interface{}{ + "model": *modelName, + "input": texts, + } + + // Voyage's Matryoshka models (voyage-3.5, voyage-3.5-lite, + // voyage-3-large, voyage-code-3) accept output_dimension to + // truncate the vector. The wire param is output_dimension + // (singular) per https://docs.voyageai.com/reference/embeddings-api; + // passing "dimensions" or "output_dimensions" gets rejected with + // HTTP 400, so it's worth matching the docs spelling exactly. + if embeddingConfig != nil && embeddingConfig.Dimension > 0 { + reqBody["output_dimension"] = embeddingConfig.Dimension + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := v.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Voyage embeddings API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed voyageEmbeddingResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Reorder by the reported index so the output always lines up with + // the input texts. Reject duplicates (silent overwrite would hide + // a malformed response) and out-of-range indices (silent panic on + // slice growth would mask the bug). + embeddings := make([]EmbeddingData, len(texts)) + filled := make([]bool, len(texts)) + for _, item := range parsed.Data { + if item.Index < 0 || item.Index >= len(texts) { + return nil, fmt.Errorf("voyage: response index %d out of range for %d inputs", item.Index, len(texts)) + } + if filled[item.Index] { + return nil, fmt.Errorf("voyage: duplicate embedding index %d in response", item.Index) + } + embeddings[item.Index] = EmbeddingData{ + Embedding: item.Embedding, + Index: item.Index, + } + filled[item.Index] = true + } + for i, ok := range filled { + if !ok { + return nil, fmt.Errorf("voyage: missing embedding for input index %d", i) + } + } + + return embeddings, nil +} + +type voyageRerankRequest struct { + Model string `json:"model"` + Query string `json:"query"` + Documents []string `json:"documents"` + TopK int `json:"top_k"` +} + +type voyageRerankResponse struct { + Object string `json:"object"` + Data []struct { + RelevanceScore float64 `json:"relevance_score"` + Index int `json:"index"` + } `json:"data"` + Model string `json:"model"` +} + +// Rerank calculates similarity scores between a query and a list of +// documents using Voyage AI's /v1/rerank endpoint. Unlike many other +// rerank APIs that use `top_n`, Voyage uses `top_k` as the request +// parameter; the driver translates RerankConfig.TopN -> top_k. +func (v *VoyageModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + if len(documents) == 0 { + return &RerankResponse{}, nil + } + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + + region := "default" + if apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + baseURL, err := v.baseURLForRegion(region) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), v.URLSuffix.Rerank) + + topK := len(documents) + if rerankConfig != nil && rerankConfig.TopN > 0 { + topK = rerankConfig.TopN + } + + reqBody := voyageRerankRequest{ + Model: *modelName, + Query: query, + Documents: documents, + TopK: topK, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := v.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("Voyage rerank API error: %s, body: %s", resp.Status, string(body)) + } + + var parsed voyageRerankResponse + if err = json.Unmarshal(body, &parsed); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // Match Embed's defensive posture: rerank only returns top_k of + // len(documents) results, but a duplicate index would still be + // a malformed response and should fail loudly. + rerankResponse := &RerankResponse{} + seen := make(map[int]bool, len(parsed.Data)) + for _, r := range parsed.Data { + if r.Index < 0 || r.Index >= len(documents) { + return nil, fmt.Errorf("voyage: rerank result index %d out of range for %d documents", r.Index, len(documents)) + } + if seen[r.Index] { + return nil, fmt.Errorf("voyage: duplicate rerank index %d in response", r.Index) + } + seen[r.Index] = true + rerankResponse.Data = append(rerankResponse.Data, RerankResult{ + Index: r.Index, + RelevanceScore: r.RelevanceScore, + }) + } + + return rerankResponse, nil +} + +// ListModels is not exposed by the Voyage AI API. The docs at +// https://docs.voyageai.com publish embeddings and rerank endpoints +// only; /v1/models is not documented (live-confirmed: 404). The +// shipped catalog lives in conf/models/voyage.json; this driver +// method does not invent a fake one. +func (v *VoyageModel) ListModels(apiConfig *APIConfig) ([]string, error) { + return nil, fmt.Errorf("%s, no such method", v.Name()) +} + +// CheckConnection is not exposed by the Voyage AI API. With no +// documented /models or /health endpoint, the only way to verify +// credentials is to burn an embedding or rerank call against the +// tenant's quota — which is what this method exists to avoid. +// Return the documented sentinel rather than pretend. +func (v *VoyageModel) CheckConnection(apiConfig *APIConfig) error { + return fmt.Errorf("%s, no such method", v.Name()) +} + +// ChatWithMessages is not exposed by the Voyage AI API. +func (v *VoyageModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + return nil, fmt.Errorf("%s, no such method", v.Name()) +} + +func (v *VoyageModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", v.Name()) +} + +// Balance is not exposed by the Voyage AI API. +func (v *VoyageModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("%s, no such method", v.Name()) +} + +// TranscribeAudio / AudioSpeech / OCRFile: Voyage does not host any of +// these surfaces. +func (v *VoyageModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", v.Name()) +} + +func (v *VoyageModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", v.Name()) +} + +func (v *VoyageModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", v.Name()) +} + +func (v *VoyageModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", v.Name()) +} + +func (v *VoyageModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", v.Name()) +} + +// ParseFile parse file +func (z *VoyageModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *VoyageModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *VoyageModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/voyage_test.go b/internal/entity/models/voyage_test.go new file mode 100644 index 00000000000..255915bf98a --- /dev/null +++ b/internal/entity/models/voyage_test.go @@ -0,0 +1,399 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func newVoyageServer(t *testing.T, expectedPath string, handler func(t *testing.T, body map[string]interface{}, w http.ResponseWriter)) *httptest.Server { + t.Helper() + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != expectedPath { + t.Errorf("expected path=%s, got %s", expectedPath, r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer test-key" { + t.Errorf("expected Authorization=Bearer test-key, got %q", got) + return + } + if got := r.Header.Get("Content-Type"); got != "application/json" { + t.Errorf("expected Content-Type=application/json, got %q", got) + return + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + var body map[string]interface{} + if err := json.Unmarshal(raw, &body); err != nil { + t.Errorf("unmarshal: %v\nraw=%s", err, string(raw)) + return + } + handler(t, body, w) + })) +} + +func newVoyageForTest(baseURL string) *VoyageModel { + return NewVoyageModel( + map[string]string{"default": baseURL}, + URLSuffix{Embedding: "v1/embeddings", Rerank: "v1/rerank"}, + ) +} + +func TestVoyageName(t *testing.T) { + if got := newVoyageForTest("http://unused").Name(); got != "voyage" { + t.Errorf("Name()=%q, want %q", got, "voyage") + } +} + +func TestVoyageEmbedHappyPath(t *testing.T) { + srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["model"] != "voyage-3.5" { + t.Errorf("model=%v", body["model"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "object": "list", + "data": []map[string]interface{}{ + {"object": "embedding", "embedding": []float64{0.1, 0.2}, "index": 0}, + {"object": "embedding", "embedding": []float64{0.3, 0.4}, "index": 1}, + {"object": "embedding", "embedding": []float64{0.5, 0.6}, "index": 2}, + }, + "model": "voyage-3.5", + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + vecs, err := v.Embed(&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } + if len(vecs) != 3 { + t.Fatalf("len=%d want 3", len(vecs)) + } + if vecs[1].Embedding[0] != 0.3 || vecs[1].Index != 1 { + t.Errorf("vecs[1]=%+v", vecs[1]) + } +} + +// TestVoyageEmbedPropagatesOutputDimension pins the docs-spelled +// param name. Voyage 400s on any other key (live-verified — sending +// "dimensions" returns "Argument 'dimensions' is not supported by our +// API"), so this name matters and must not regress. +func TestVoyageEmbedPropagatesOutputDimension(t *testing.T) { + srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if got, ok := body["output_dimension"].(float64); !ok || got != 256 { + t.Errorf("output_dimension=%v want 256", body["output_dimension"]) + } + for _, wrong := range []string{"dimensions", "output_dimensions", "dimension"} { + if _, present := body[wrong]; present { + t.Errorf("must not send %q (Voyage rejects unknown fields)", wrong) + } + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{{"embedding": []float64{0.1}, "index": 0}}, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + _, err := v.Embed(&model, []string{"x"}, &APIConfig{ApiKey: &apiKey}, + &EmbeddingConfig{Dimension: 256}) + if err != nil { + t.Fatalf("Embed: %v", err) + } +} + +// And when Dimension is zero/unset, the field MUST be absent — Voyage +// would default the vector length, but only if we don't send the key +// at all (sending output_dimension: 0 is a 400). +func TestVoyageEmbedOmitsOutputDimensionWhenUnset(t *testing.T) { + srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if _, present := body["output_dimension"]; present { + t.Errorf("output_dimension must be absent when Dimension is unset, got %v", body["output_dimension"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{{"embedding": []float64{0.1}, "index": 0}}, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + _, err := v.Embed(&model, []string{"x"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } +} + +func TestVoyageEmbedReordersByIndex(t *testing.T) { + srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{2}, "index": 2}, + {"embedding": []float64{0}, "index": 0}, + {"embedding": []float64{1}, "index": 1}, + }, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + vecs, err := v.Embed(&model, []string{"a", "b", "c"}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil { + t.Fatalf("Embed: %v", err) + } + for i, vec := range vecs { + if vec.Index != i || vec.Embedding[0] != float64(i) { + t.Errorf("slot %d=%+v", i, vec) + } + } +} + +func TestVoyageEmbedEmptyInputShortCircuits(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + t.Error("Embed([]) made an unexpected HTTP call") + })) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + vecs, err := v.Embed(&model, []string{}, &APIConfig{ApiKey: &apiKey}, nil) + if err != nil || len(vecs) != 0 { + t.Errorf("Embed([])=(%v,%v)", vecs, err) + } +} + +func TestVoyageEmbedRequiresAPIKey(t *testing.T) { + v := newVoyageForTest("http://unused") + model := "voyage-3.5" + _, err := v.Embed(&model, []string{"a"}, &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "api key is required") { + t.Errorf("expected api-key error, got %v", err) + } +} + +func TestVoyageEmbedRequiresModelName(t *testing.T) { + v := newVoyageForTest("http://unused") + apiKey := "test-key" + _, err := v.Embed(nil, []string{"a"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "model name is required") { + t.Errorf("expected model-name error, got %v", err) + } +} + +func TestVoyageEmbedRejectsDuplicateIndex(t *testing.T) { + srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 0}, + {"embedding": []float64{2}, "index": 0}, + }, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + _, err := v.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "duplicate embedding index 0") { + t.Errorf("expected duplicate error, got %v", err) + } +} + +func TestVoyageEmbedRejectsOutOfRangeIndex(t *testing.T) { + srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 7}, + }, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + _, err := v.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "out of range") { + t.Errorf("expected out-of-range error, got %v", err) + } +} + +func TestVoyageEmbedRejectsMissingSlot(t *testing.T) { + srv := newVoyageServer(t, "/v1/embeddings", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"embedding": []float64{1}, "index": 0}, + }, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "voyage-3.5" + _, err := v.Embed(&model, []string{"a", "b"}, &APIConfig{ApiKey: &apiKey}, nil) + if err == nil || !strings.Contains(err.Error(), "missing embedding for input index 1") { + t.Errorf("expected missing-slot error, got %v", err) + } +} + +func TestVoyageRerankHappyPath(t *testing.T) { + srv := newVoyageServer(t, "/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + // Voyage's request key is top_k (not top_n). + if body["top_k"] != float64(3) { + t.Errorf("top_k=%v want 3", body["top_k"]) + } + if body["query"] != "x" { + t.Errorf("query=%v", body["query"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "object": "list", + "data": []map[string]interface{}{ + {"relevance_score": 0.8, "index": 2}, + {"relevance_score": 0.5, "index": 0}, + {"relevance_score": 0.3, "index": 1}, + }, + "model": "rerank-2", + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "rerank-2" + resp, err := v.Rerank(&model, "x", []string{"a", "b", "c"}, + &APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 3}) + if err != nil { + t.Fatalf("Rerank: %v", err) + } + if len(resp.Data) != 3 { + t.Fatalf("len=%d want 3", len(resp.Data)) + } + want := map[int]float64{0: 0.5, 1: 0.3, 2: 0.8} + for _, r := range resp.Data { + if got, ok := want[r.Index]; !ok || got != r.RelevanceScore { + t.Errorf("unexpected result index=%d score=%v", r.Index, r.RelevanceScore) + } + } +} + +func TestVoyageRerankTopKDefaultsToLenDocuments(t *testing.T) { + srv := newVoyageServer(t, "/v1/rerank", func(t *testing.T, body map[string]interface{}, w http.ResponseWriter) { + if body["top_k"] != float64(4) { + t.Errorf("top_k=%v want 4 (len(documents))", body["top_k"]) + } + _ = json.NewEncoder(w).Encode(map[string]interface{}{"data": []map[string]interface{}{}}) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "rerank-2" + _, err := v.Rerank(&model, "x", []string{"a", "b", "c", "d"}, + &APIConfig{ApiKey: &apiKey}, &RerankConfig{}) + if err != nil { + t.Fatalf("Rerank: %v", err) + } +} + +func TestVoyageRerankEmptyDocuments(t *testing.T) { + v := newVoyageForTest("http://unused") + apiKey := "test-key" + model := "rerank-2" + resp, err := v.Rerank(&model, "x", nil, + &APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 0}) + if err != nil { + t.Fatalf("Rerank: %v", err) + } + if len(resp.Data) != 0 { + t.Errorf("expected empty Data, got %d", len(resp.Data)) + } +} + +func TestVoyageRerankRejectsOutOfRangeIndex(t *testing.T) { + srv := newVoyageServer(t, "/v1/rerank", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"relevance_score": 0.9, "index": 7}, + }, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "rerank-2" + _, err := v.Rerank(&model, "x", []string{"a", "b"}, + &APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 2}) + if err == nil || !strings.Contains(err.Error(), "out of range") { + t.Errorf("expected out-of-range error, got %v", err) + } +} + +func TestVoyageRerankRejectsDuplicateIndex(t *testing.T) { + // A duplicate index would silently overwrite an earlier slot, which + // is the same failure mode Embed already guards against. Make sure + // Rerank fails loudly too. + srv := newVoyageServer(t, "/v1/rerank", func(t *testing.T, _ map[string]interface{}, w http.ResponseWriter) { + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{ + {"relevance_score": 0.9, "index": 0}, + {"relevance_score": 0.8, "index": 0}, + }, + }) + }) + defer srv.Close() + + v := newVoyageForTest(srv.URL) + apiKey := "test-key" + model := "rerank-2" + _, err := v.Rerank(&model, "x", []string{"a", "b"}, + &APIConfig{ApiKey: &apiKey}, &RerankConfig{TopN: 2}) + if err == nil || !strings.Contains(err.Error(), "duplicate rerank index 0") { + t.Errorf("expected duplicate-index error, got %v", err) + } +} + +// TestVoyageEmbedTrimsTrailingSlashInBaseURL guards against a +// misconfigured baseURL ending in "/" producing a double-slash path +// (e.g. `.../v1//embeddings`). Rerank already trims, so Embed must +// trim too; CodeRabbit flagged the inconsistency. +func TestVoyageEmbedTrimsTrailingSlashInBaseURL(t *testing.T) { + var sawPath string + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sawPath = r.URL.Path + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "data": []map[string]interface{}{{"embedding": []float64{1}, "index": 0}}, + }) + })) + defer srv.Close() + + v := NewVoyageModel( + map[string]string{"default": srv.URL + "/"}, // trailing slash + URLSuffix{Embedding: "v1/embeddings", Rerank: "v1/rerank"}, + ) + apiKey := "test-key" + model := "voyage-3.5" + if _, err := v.Embed(&model, []string{"x"}, &APIConfig{ApiKey: &apiKey}, nil); err != nil { + t.Fatalf("Embed: %v", err) + } + if sawPath != "/v1/embeddings" { + t.Errorf("path=%q want %q (no double slash)", sawPath, "/v1/embeddings") + } +} diff --git a/internal/entity/models/xai.go b/internal/entity/models/xai.go index 96617320cf9..b19f93ca7dc 100644 --- a/internal/entity/models/xai.go +++ b/internal/entity/models/xai.go @@ -397,9 +397,9 @@ func (z *XAIModel) ChatStreamlyWithSender(modelName string, messages []Message, return nil } -// Encode encodes a list of texts into embeddings. xAI does not expose a +// Embed embeds a list of texts into embeddings. xAI does not expose a // public embedding API yet, so this is left unimplemented. -func (z *XAIModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +func (z *XAIModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { return nil, fmt.Errorf("not implemented") } @@ -492,3 +492,39 @@ func (z *XAIModel) CheckConnection(apiConfig *APIConfig) error { func (z *XAIModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { return nil, fmt.Errorf("%s, Rerank not implemented", z.Name()) } + +// TranscribeAudio transcribe audio +func (o *XAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *XAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (o *XAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *XAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *XAIModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +// ParseFile parse file +func (z *XAIModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *XAIModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *XAIModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/models/xinference.go b/internal/entity/models/xinference.go new file mode 100644 index 00000000000..d8f8fa39f54 --- /dev/null +++ b/internal/entity/models/xinference.go @@ -0,0 +1,476 @@ +// +// Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package models + +import ( + "bufio" + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + "time" +) + +// xinferenceStreamIdleTimeout bounds how long a stream can go without +// receiving any SSE line. Self-hosted models can be slow, but a stream +// that stays silent for a full minute is more useful as a surfaced error +// than as a stuck goroutine. +var xinferenceStreamIdleTimeout = 60 * time.Second + +// XinferenceModel implements ModelDriver for Xinference chat models. +// +// Xinference exposes an OpenAI-compatible API under /v1. The +// tenant may configure either the root endpoint (http://127.0.0.1:9997) +// or the OpenAI-compatible endpoint (http://127.0.0.1:9997/v1); the +// driver normalizes both to the root endpoint before adding URLSuffix +// values that match Xinference docs, such as v1/chat/completions. +// Authentication is optional: no-auth deployments ignore API keys, while +// auth-enabled deployments require Authorization: Bearer . +type XinferenceModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +type xinferenceChatChoice struct { + Message struct { + Content string `json:"content"` + ReasoningContent string `json:"reasoning_content"` + Reasoning string `json:"reasoning"` + Thinking string `json:"thinking"` + } `json:"message"` +} + +type xinferenceChatResponse struct { + Choices []xinferenceChatChoice `json:"choices"` +} + +type xinferenceModelListResponse struct { + Data []struct { + ID string `json:"id"` + } `json:"data"` +} + +// NewXinferenceModel creates a new Xinference model instance. +func NewXinferenceModel(baseURL map[string]string, urlSuffix URLSuffix) *XinferenceModel { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.MaxIdleConns = 100 + transport.MaxIdleConnsPerHost = 10 + transport.IdleConnTimeout = 90 * time.Second + transport.DisableCompression = false + transport.ResponseHeaderTimeout = 60 * time.Second + + return &XinferenceModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Transport: transport, + }, + } +} + +func (x *XinferenceModel) NewInstance(baseURL map[string]string) ModelDriver { + return NewXinferenceModel(baseURL, x.URLSuffix) +} + +func (x *XinferenceModel) Name() string { + return "xinference" +} + +func (x *XinferenceModel) baseURLForRegion(region string) (string, error) { + if base, ok := x.BaseURL[region]; ok && strings.TrimSpace(base) != "" { + return normalizeXinferenceBaseURL(base), nil + } + if base, ok := x.BaseURL["default"]; ok && strings.TrimSpace(base) != "" { + return normalizeXinferenceBaseURL(base), nil + } + return "", fmt.Errorf("xinference: missing base URL, configure the Xinference endpoint (e.g., http://127.0.0.1:9997 or http://127.0.0.1:9997/v1)") +} + +func normalizeXinferenceBaseURL(base string) string { + trimmed := strings.TrimRight(strings.TrimSpace(base), "/") + if trimmed == "" { + return trimmed + } + if strings.HasSuffix(trimmed, "/v1") { + return strings.TrimSuffix(trimmed, "/v1") + } + return trimmed +} + +func setXinferenceAuth(req *http.Request, apiConfig *APIConfig) { + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return + } + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) +} + +func xinferenceRegion(apiConfig *APIConfig) string { + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + return *apiConfig.Region + } + return "default" +} + +func xinferenceReasoningFromStrings(reasoningContent string, reasoning string, thinking string) string { + switch { + case reasoningContent != "": + return reasoningContent + case reasoning != "": + return reasoning + case thinking != "": + return thinking + default: + return "" + } +} + +func xinferenceReasoningFromMap(value map[string]interface{}) string { + for _, field := range []string{"reasoning_content", "reasoning", "thinking"} { + if text, ok := value[field].(string); ok && text != "" { + return text + } + } + return "" +} + +func buildXinferenceChatBody(modelName string, messages []Message, stream bool, chatModelConfig *ChatConfig) map[string]interface{} { + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": stream, + } + + if chatModelConfig != nil { + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + if chatModelConfig.Stop != nil { + reqBody["stop"] = *chatModelConfig.Stop + } + } + + return reqBody +} + +// ChatWithMessages sends multiple messages with roles and returns the response. +func (x *XinferenceModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + baseURL, err := x.baseURLForRegion(xinferenceRegion(apiConfig)) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, x.URLSuffix.Chat) + + reqBody := buildXinferenceChatBody(modelName, messages, false, chatModelConfig) + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + setXinferenceAuth(req, apiConfig) + + resp, err := x.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result xinferenceChatResponse + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + if len(result.Choices) == 0 { + return nil, fmt.Errorf("no choices in response") + } + + content := result.Choices[0].Message.Content + reasonContent := xinferenceReasoningFromStrings( + result.Choices[0].Message.ReasoningContent, + result.Choices[0].Message.Reasoning, + result.Choices[0].Message.Thinking, + ) + + return &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + }, nil +} + +// ChatStreamlyWithSender sends messages and streams response via sender. +func (x *XinferenceModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig, sender func(*string, *string) error) error { + if sender == nil { + return fmt.Errorf("sender is required") + } + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + if chatModelConfig != nil && chatModelConfig.Stream != nil && !*chatModelConfig.Stream { + return fmt.Errorf("stream must be true in ChatStreamlyWithSender") + } + + baseURL, err := x.baseURLForRegion(xinferenceRegion(apiConfig)) + if err != nil { + return err + } + url := fmt.Sprintf("%s/%s", baseURL, x.URLSuffix.Chat) + + reqBody := buildXinferenceChatBody(modelName, messages, true, chatModelConfig) + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + req.Header.Set("Content-Type", "application/json") + setXinferenceAuth(req, apiConfig) + + resp, err := x.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + lastActive := time.Now() + var lastActiveMu sync.Mutex + done := make(chan struct{}) + defer close(done) + go func() { + ticker := time.NewTicker(xinferenceStreamIdleTimeout / 4) + defer ticker.Stop() + for { + select { + case <-done: + return + case now := <-ticker.C: + lastActiveMu.Lock() + idle := now.Sub(lastActive) + lastActiveMu.Unlock() + if idle >= xinferenceStreamIdleTimeout { + cancel() + return + } + } + } + }() + + scanner := bufio.NewScanner(resp.Body) + scanner.Buffer(make([]byte, 64*1024), 1024*1024) + sawTerminal := false + for scanner.Scan() { + lastActiveMu.Lock() + lastActive = time.Now() + lastActiveMu.Unlock() + + line := scanner.Text() + if !strings.HasPrefix(line, "data:") { + continue + } + data := strings.TrimSpace(line[5:]) + if data == "[DONE]" { + sawTerminal = true + break + } + + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + if delta, ok := firstChoice["delta"].(map[string]interface{}); ok { + if reasoning := xinferenceReasoningFromMap(delta); reasoning != "" { + if err := sender(nil, &reasoning); err != nil { + return err + } + } + if content, ok := delta["content"].(string); ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + } + + if finishReason, ok := firstChoice["finish_reason"].(string); ok && finishReason != "" { + sawTerminal = true + break + } + } + + if err := scanner.Err(); err != nil { + if ctx.Err() != nil { + return fmt.Errorf("xinference: stream idle for more than %s, aborted", xinferenceStreamIdleTimeout) + } + return fmt.Errorf("failed to scan response body: %w", err) + } + if !sawTerminal { + return fmt.Errorf("xinference: stream ended before [DONE] or finish_reason") + } + + endOfStream := "[DONE]" + return sender(&endOfStream, nil) +} + +func (x *XinferenceModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + return nil, fmt.Errorf("%s, no such method", x.Name()) +} + +func (x *XinferenceModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + return nil, fmt.Errorf("%s, no such method", x.Name()) +} + +func (x *XinferenceModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", x.Name()) +} + +func (x *XinferenceModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", x.Name()) +} + +func (x *XinferenceModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, asrConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", x.Name()) +} + +func (x *XinferenceModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", x.Name()) +} + +func (x *XinferenceModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", x.Name()) +} + +func (x *XinferenceModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", x.Name()) +} + +// ListModels returns the model IDs exposed by Xinference's OpenAI-compatible +// /v1/models endpoint. +func (x *XinferenceModel) ListModels(apiConfig *APIConfig) ([]string, error) { + baseURL, err := x.baseURLForRegion(xinferenceRegion(apiConfig)) + if err != nil { + return nil, err + } + url := fmt.Sprintf("%s/%s", baseURL, x.URLSuffix.Models) + + ctx, cancel := context.WithTimeout(context.Background(), nonStreamCallTimeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + setXinferenceAuth(req, apiConfig) + + resp, err := x.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } + + var result xinferenceModelListResponse + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + models := make([]string, 0, len(result.Data)) + for _, model := range result.Data { + if model.ID != "" { + models = append(models, model.ID) + } + } + return models, nil +} + +func (x *XinferenceModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + return nil, fmt.Errorf("%s, no such method", x.Name()) +} + +func (x *XinferenceModel) CheckConnection(apiConfig *APIConfig) error { + _, err := x.ListModels(apiConfig) + return err +} + +func (x *XinferenceModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", x.Name()) +} + +func (x *XinferenceModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", x.Name()) +} diff --git a/internal/entity/models/xinference_test.go b/internal/entity/models/xinference_test.go new file mode 100644 index 00000000000..af3179ea0ed --- /dev/null +++ b/internal/entity/models/xinference_test.go @@ -0,0 +1,313 @@ +package models + +import ( + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func newXinferenceForTest(baseURL string) *XinferenceModel { + return NewXinferenceModel( + map[string]string{"default": baseURL}, + URLSuffix{ + Chat: "v1/chat/completions", + Models: "v1/models", + }, + ) +} + +func withXinferenceIdleTimeout(t *testing.T, d time.Duration) { + t.Helper() + original := xinferenceStreamIdleTimeout + xinferenceStreamIdleTimeout = d + t.Cleanup(func() { + xinferenceStreamIdleTimeout = original + }) +} + +func TestXinferenceName(t *testing.T) { + x := newXinferenceForTest("http://unused") + if got := x.Name(); got != "xinference" { + t.Errorf("Name()=%q, want %q", got, "xinference") + } +} + +func TestNormalizeXinferenceBaseURL(t *testing.T) { + cases := []struct { + in string + want string + }{ + {"http://127.0.0.1:9997", "http://127.0.0.1:9997"}, + {"http://127.0.0.1:9997/", "http://127.0.0.1:9997"}, + {"http://127.0.0.1:9997/v1", "http://127.0.0.1:9997"}, + {" http://127.0.0.1:9997/v1/ ", "http://127.0.0.1:9997"}, + } + for _, tc := range cases { + if got := normalizeXinferenceBaseURL(tc.in); got != tc.want { + t.Errorf("normalizeXinferenceBaseURL(%q)=%q, want %q", tc.in, got, tc.want) + } + } +} + +func TestXinferenceFactoryRoute(t *testing.T) { + driver, err := NewModelFactory().CreateModelDriver("xinference", map[string]string{"default": "http://unused"}, URLSuffix{}) + if err != nil { + t.Fatalf("CreateModelDriver: %v", err) + } + if driver.Name() != "xinference" { + t.Errorf("driver.Name()=%q, want xinference", driver.Name()) + } +} + +func TestXinferenceChatHappyPathNormalizesBaseURLAndOmitsEmptyAuth(t *testing.T) { + var seen map[string]interface{} + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/chat/completions" { + t.Errorf("path=%s, want /v1/chat/completions", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "" { + t.Errorf("expected no Authorization header, got %q", got) + } + raw, err := io.ReadAll(r.Body) + if err != nil { + t.Errorf("read body: %v", err) + return + } + if err := json.Unmarshal(raw, &seen); err != nil { + t.Errorf("unmarshal request: %v", err) + return + } + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"pong"}}]}`) + })) + defer srv.Close() + + x := newXinferenceForTest(srv.URL) + maxTokens := 32 + temp := 0.2 + resp, err := x.ChatWithMessages("qwen2.5-instruct", + []Message{{Role: "user", Content: "ping"}}, + &APIConfig{}, + &ChatConfig{MaxTokens: &maxTokens, Temperature: &temp}) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if resp.Answer == nil || *resp.Answer != "pong" { + t.Fatalf("Answer=%v, want pong", resp.Answer) + } + if seen["stream"] != false { + t.Errorf("stream=%v, want false", seen["stream"]) + } + if seen["max_tokens"] != float64(32) { + t.Errorf("max_tokens=%v, want 32", seen["max_tokens"]) + } + if seen["temperature"] != 0.2 { + t.Errorf("temperature=%v, want 0.2", seen["temperature"]) + } +} + +func TestXinferenceChatSendsAuthHeaderWhenKeyProvided(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if got := r.Header.Get("Authorization"); got != "Bearer sk-test" { + t.Errorf("Authorization=%q, want Bearer sk-test", got) + } + _, _ = io.WriteString(w, `{"choices":[{"message":{"content":"ok"}}]}`) + })) + defer srv.Close() + + x := newXinferenceForTest(srv.URL + "/v1") + key := "sk-test" + _, err := x.ChatWithMessages("qwen2.5-instruct", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{ApiKey: &key}, nil) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } +} + +func TestXinferenceChatExtractsReasoningFields(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = io.WriteString(w, `{"choices":[{"message":{ + "content":"12", + "reasoning_content":"0.15 * 80 = 12" + }}]}`) + })) + defer srv.Close() + + x := newXinferenceForTest(srv.URL) + resp, err := x.ChatWithMessages("qwen3", + []Message{{Role: "user", Content: "15% of 80?"}}, + &APIConfig{}, nil) + if err != nil { + t.Fatalf("ChatWithMessages: %v", err) + } + if resp.ReasonContent == nil || *resp.ReasonContent != "0.15 * 80 = 12" { + t.Errorf("ReasonContent=%v", resp.ReasonContent) + } +} + +func TestXinferenceStreamHappyPath(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/chat/completions" { + t.Errorf("path=%s", r.URL.Path) + } + var seen map[string]interface{} + raw, _ := io.ReadAll(r.Body) + _ = json.Unmarshal(raw, &seen) + if seen["stream"] != true { + t.Errorf("stream=%v, want true", seen["stream"]) + } + w.Header().Set("Content-Type", "text/event-stream") + _, _ = io.WriteString(w, + `data: {"choices":[{"delta":{"reasoning_content":"step. "}}]}`+"\n"+ + `data: {"choices":[{"delta":{"content":"Hello"}}]}`+"\n"+ + `data: {"choices":[{"delta":{"content":" world"},"finish_reason":"stop"}]}`+"\n"+ + `data: [DONE]`+"\n", + ) + })) + defer srv.Close() + + x := newXinferenceForTest(srv.URL) + var content []string + var reasoning []string + var sawDone bool + err := x.ChatStreamlyWithSender("qwen2.5-instruct", + []Message{{Role: "user", Content: "hi"}}, + &APIConfig{}, nil, + func(c *string, r *string) error { + if r != nil && *r != "" { + reasoning = append(reasoning, *r) + } + if c != nil && *c == "[DONE]" { + sawDone = true + } + if c != nil && *c != "" && *c != "[DONE]" { + content = append(content, *c) + } + return nil + }) + if err != nil { + t.Fatalf("ChatStreamlyWithSender: %v", err) + } + if strings.Join(reasoning, "") != "step. " { + t.Errorf("reasoning=%q", strings.Join(reasoning, "")) + } + if strings.Join(content, "") != "Hello world" { + t.Errorf("content=%q", strings.Join(content, "")) + } + if !sawDone { + t.Error("expected [DONE] callback") + } +} + +func TestXinferenceStreamRejectsFalseStreamConfig(t *testing.T) { + x := newXinferenceForTest("http://unused") + stream := false + err := x.ChatStreamlyWithSender("qwen2.5-instruct", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{}, + &ChatConfig{Stream: &stream}, + func(*string, *string) error { return nil }) + if err == nil || !strings.Contains(err.Error(), "stream must be true") { + t.Errorf("expected stream-must-be-true error, got %v", err) + } +} + +func TestXinferenceStreamCancelsOnIdle(t *testing.T) { + withXinferenceIdleTimeout(t, 200*time.Millisecond) + + hold := make(chan struct{}) + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/event-stream") + w.WriteHeader(http.StatusOK) + if f, ok := w.(http.Flusher); ok { + _, _ = io.WriteString(w, `data: {"choices":[{"delta":{"content":"hi"}}]}`+"\n") + f.Flush() + } + select { + case <-hold: + case <-r.Context().Done(): + } + })) + t.Cleanup(srv.Close) + t.Cleanup(func() { close(hold) }) + + x := newXinferenceForTest(srv.URL) + err := x.ChatStreamlyWithSender("qwen2.5-instruct", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{}, nil, + func(*string, *string) error { return nil }) + if err == nil || !strings.Contains(err.Error(), "stream idle") { + t.Errorf("expected stream-idle error, got %v", err) + } +} + +func TestXinferenceListModelsAndCheckConnection(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/v1/models" { + t.Errorf("path=%s, want /v1/models", r.URL.Path) + } + if got := r.Header.Get("Authorization"); got != "Bearer sk-test" { + t.Errorf("Authorization=%q, want Bearer sk-test", got) + } + _, _ = io.WriteString(w, `{"object":"list","data":[{"id":"qwen2.5-instruct"},{"id":"custom-chat"}]}`) + })) + defer srv.Close() + + x := newXinferenceForTest(srv.URL) + key := "sk-test" + apiConfig := &APIConfig{ApiKey: &key} + models, err := x.ListModels(apiConfig) + if err != nil { + t.Fatalf("ListModels: %v", err) + } + if strings.Join(models, ",") != "qwen2.5-instruct,custom-chat" { + t.Errorf("models=%v", models) + } + if err := x.CheckConnection(apiConfig); err != nil { + t.Fatalf("CheckConnection: %v", err) + } +} + +func TestXinferenceMissingBaseURLFailsClearly(t *testing.T) { + x := NewXinferenceModel(map[string]string{}, URLSuffix{Chat: "v1/chat/completions"}) + _, err := x.ChatWithMessages("qwen2.5-instruct", + []Message{{Role: "user", Content: "x"}}, + &APIConfig{}, nil) + if err == nil || !strings.Contains(err.Error(), "missing base URL") { + t.Errorf("expected missing-base-URL error, got %v", err) + } +} + +func TestXinferenceUnsupportedMethodsReturnNoSuchMethod(t *testing.T) { + x := newXinferenceForTest("http://unused") + model := "qwen2.5-instruct" + + if _, err := x.Embed(&model, []string{"x"}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Embed: expected no such method, got %v", err) + } + if _, err := x.Rerank(&model, "q", []string{"d"}, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Rerank: expected no such method, got %v", err) + } + if _, err := x.Balance(&APIConfig{}); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("Balance: expected no such method, got %v", err) + } + if _, err := x.TranscribeAudio(&model, nil, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("TranscribeAudio: expected no such method, got %v", err) + } + if err := x.TranscribeAudioWithSender(&model, nil, &APIConfig{}, nil, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("TranscribeAudioWithSender: expected no such method, got %v", err) + } + if _, err := x.AudioSpeech(&model, nil, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("AudioSpeech: expected no such method, got %v", err) + } + if err := x.AudioSpeechWithSender(&model, nil, &APIConfig{}, nil, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("AudioSpeechWithSender: expected no such method, got %v", err) + } + if _, err := x.OCRFile(&model, nil, nil, &APIConfig{}, nil); err == nil || !strings.Contains(err.Error(), "no such method") { + t.Errorf("OCRFile: expected no such method, got %v", err) + } +} diff --git a/internal/entity/models/xunfei.go b/internal/entity/models/xunfei.go new file mode 100644 index 00000000000..e1b89b535d6 --- /dev/null +++ b/internal/entity/models/xunfei.go @@ -0,0 +1,452 @@ +package models + +import ( + "bufio" + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "ragflow/internal/common" + "strings" + "time" +) + +type XunFeiModel struct { + BaseURL map[string]string + URLSuffix URLSuffix + httpClient *http.Client +} + +func NewXunFeiModel(baseURL map[string]string, urlSuffix URLSuffix) *XunFeiModel { + return &XunFeiModel{ + BaseURL: baseURL, + URLSuffix: urlSuffix, + httpClient: &http.Client{ + Timeout: time.Second * 120, + Transport: &http.Transport{ + MaxIdleConns: 10, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: time.Second * 90, + DisableCompression: false, + }, + }, + } +} + +func (x *XunFeiModel) NewInstance(baseURL map[string]string) ModelDriver { + return &XunFeiModel{ + BaseURL: baseURL, + URLSuffix: x.URLSuffix, + httpClient: &http.Client{ + Timeout: time.Second * 120, + Transport: &http.Transport{ + MaxIdleConns: 10, + MaxIdleConnsPerHost: 100, + IdleConnTimeout: time.Second * 90, + DisableCompression: false, + }, + }, + } +} + +func (x *XunFeiModel) Name() string { + return "xunfei" +} + +func (x *XunFeiModel) ChatWithMessages(modelName string, messages []Message, apiConfig *APIConfig, chatModelConfig *ChatConfig) (*ChatResponse, error) { + if len(messages) == 0 { + return nil, fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", x.BaseURL[region], x.URLSuffix.Chat) + + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": false, + "temperature": 1, + } + + if chatModelConfig != nil { + if chatModelConfig.Temperature != nil { + reqBody["temperature"] = *chatModelConfig.Temperature + } + + if chatModelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *chatModelConfig.MaxTokens + } + + if chatModelConfig.Stream != nil { + reqBody["stream"] = *chatModelConfig.Stream + } + + if chatModelConfig.TopP != nil { + reqBody["top_p"] = *chatModelConfig.TopP + } + + if chatModelConfig.Thinking != nil { + if *chatModelConfig.Thinking { + reqBody["thinking"] = map[string]interface{}{ + "type": "enabled", + } + } else { + reqBody["thinking"] = map[string]interface{}{ + "type": "disabled", + } + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request body: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := x.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response body: %w", err) + } + + // Parse Response + var result map[string]interface{} + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to unmarshal response body: %w", err) + } + + choices, ok := result["choices"].([]interface{}) + if !ok { + return nil, fmt.Errorf("no choices in response") + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("no choices in response") + } + + messageMap, ok := firstChoice["message"].(map[string]interface{}) + if !ok { + return nil, fmt.Errorf("no message in response") + } + + content, ok := messageMap["content"].(string) + if !ok { + return nil, fmt.Errorf("no message in response") + } + + var reasonContent string + if chatModelConfig != nil && chatModelConfig.Thinking != nil && *chatModelConfig.Thinking { + reasonContent, ok = messageMap["reasoning_content"].(string) + if !ok { + return nil, fmt.Errorf("invalid content format") + } + if reasonContent != "" && reasonContent[0] == '\n' { + reasonContent = reasonContent[1:] + } + } + + chatResponse := &ChatResponse{ + Answer: &content, + ReasonContent: &reasonContent, + } + + return chatResponse, nil +} + +func (x *XunFeiModel) ChatStreamlyWithSender(modelName string, messages []Message, apiConfig *APIConfig, modelConfig *ChatConfig, sender func(*string, *string) error) error { + if len(messages) == 0 { + return fmt.Errorf("messages is empty") + } + + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", x.BaseURL[region], x.URLSuffix.Chat) + + // Convert messages to API format + apiMessages := make([]map[string]interface{}, len(messages)) + for i, msg := range messages { + apiMessages[i] = map[string]interface{}{ + "role": msg.Role, + "content": msg.Content, + } + } + + reqBody := map[string]interface{}{ + "model": modelName, + "messages": apiMessages, + "stream": true, + "temperature": 1, + } + + if modelConfig != nil { + if modelConfig.Stream != nil { + reqBody["stream"] = *modelConfig.Stream + } + + if modelConfig.MaxTokens != nil { + reqBody["max_tokens"] = *modelConfig.MaxTokens + } + + if modelConfig.Temperature != nil { + reqBody["temperature"] = *modelConfig.Temperature + } + + if modelConfig.TopP != nil { + reqBody["top_p"] = *modelConfig.TopP + } + + if modelConfig.Stop != nil { + reqBody["stop"] = *modelConfig.Stop + } + + if modelConfig.Thinking != nil { + if *modelConfig.Thinking { + reqBody["thinking"] = map[string]interface{}{ + "type": "enabled", + } + } else { + reqBody["thinking"] = map[string]interface{}{ + "type": "disabled", + } + } + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := x.httpClient.Do(req) + if err != nil { + return fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return fmt.Errorf("invalid status code: %d, body: %s", resp.StatusCode, string(body)) + } + + // SSE parsing: read line by line + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + common.Info(line) + + // SSE data line starts with "data:" + if !strings.HasPrefix(line, "data:") { + continue + } + + // Extract JSON after "data:" + data := strings.TrimSpace(line[5:]) + + // [DONE] marks the end of stream + if data == "[DONE]" { + break + } + + // Parse the JSON event + var event map[string]interface{} + if err = json.Unmarshal([]byte(data), &event); err != nil { + continue + } + + choices, ok := event["choices"].([]interface{}) + if !ok || len(choices) == 0 { + continue + } + + firstChoice, ok := choices[0].(map[string]interface{}) + if !ok { + continue + } + + delta, ok := firstChoice["delta"].(map[string]interface{}) + if !ok { + continue + } + + reasoningContent, ok := delta["reasoning_content"].(string) + if ok && reasoningContent != "" { + if err := sender(nil, &reasoningContent); err != nil { + return err + } + } + + content, ok := delta["content"].(string) + if ok && content != "" { + if err := sender(&content, nil); err != nil { + return err + } + } + + finishReason, ok := firstChoice["finish_reason"].(string) + if ok && finishReason != "" { + break + } + } + + // Send [DONE] marker for OpenAI compatibility + endOfStream := "[DONE]" + if err = sender(&endOfStream, nil); err != nil { + return err + } + + return scanner.Err() +} + +func (x *XunFeiModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + //TODO implement me + panic("implement me") +} + +func (x *XunFeiModel) Rerank(modelName *string, query string, documents []string, apiConfig *APIConfig, rerankConfig *RerankConfig) (*RerankResponse, error) { + //TODO implement me + panic("implement me") +} + +func (x *XunFeiModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + //TODO implement me + panic("implement me") +} + +func (x *XunFeiModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + //TODO implement me + panic("implement me") +} + +func (x *XunFeiModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + //TODO implement me + panic("implement me") +} + +func (x *XunFeiModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + //TODO implement me + panic("implement me") +} + +func (x *XunFeiModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + //TODO implement me + panic("implement me") +} + +func (x *XunFeiModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + //TODO implement me + panic("implement me") +} + +func (x *XunFeiModel) ListModels(apiConfig *APIConfig) ([]string, error) { + var region = "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } + + url := fmt.Sprintf("%s/%s", x.BaseURL[region], x.URLSuffix.Models) + + // Build request body + reqBody := map[string]interface{}{} + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("GET", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + + resp, err := x.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s : %s", resp.StatusCode, string(body)) + } + + // Parse response + var result map[string]interface{} + if err = json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + // convert result["data"] to []map[string]interface{} + models := make([]string, 0) + for _, model := range result["data"].([]interface{}) { + modelMap := model.(map[string]interface{}) + modelName := modelMap["id"].(string) + models = append(models, modelName) + } + + return models, nil +} + +func (x *XunFeiModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { + //TODO implement me + panic("implement me") +} + +func (x *XunFeiModel) CheckConnection(apiConfig *APIConfig) error { + //TODO implement me + panic("implement me") +} + +func (x *XunFeiModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + //TODO implement me + panic("implement me") +} + +func (x *XunFeiModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + //TODO implement me + panic("implement me") +} diff --git a/internal/entity/models/zhipu-ai.go b/internal/entity/models/zhipu-ai.go index 98bd5a7a52e..d90d63559be 100644 --- a/internal/entity/models/zhipu-ai.go +++ b/internal/entity/models/zhipu-ai.go @@ -157,7 +157,7 @@ func (z *ZhipuAIModel) ChatWithMessages(modelName string, messages []Message, ap // Parse response var result map[string]interface{} - if err := json.Unmarshal(body, &result); err != nil { + if err = json.Unmarshal(body, &result); err != nil { return nil, fmt.Errorf("failed to parse response: %w", err) } @@ -362,95 +362,157 @@ func (z *ZhipuAIModel) ChatStreamlyWithSender(modelName string, messages []Messa return scanner.Err() } +type zhipuEmbeddingResponse struct { + Data []zhipuEmbeddingData `json:"data"` + Model string `json:"model"` + Object string `json:"object"` + Usage zhipuUsage `json:"usage"` +} + +type zhipuEmbeddingData struct { + Embedding []float64 `json:"embedding"` + Index int `json:"index"` + Object string `json:"object"` +} + +type zhipuUsage struct { + CompletionTokens int `json:"completion_tokens"` + PromptTokens int `json:"prompt_tokens"` + TotalTokens int `json:"total_tokens"` +} + // Encode encodes a list of texts into embeddings -func (z *ZhipuAIModel) Encode(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([][]float64, error) { +func (z *ZhipuAIModel) Embed(modelName *string, texts []string, apiConfig *APIConfig, embeddingConfig *EmbeddingConfig) ([]EmbeddingData, error) { + if len(texts) == 0 { + return []EmbeddingData{}, nil + } + + if apiConfig == nil || apiConfig.ApiKey == nil || *apiConfig.ApiKey == "" { + return nil, fmt.Errorf("api key is required") + } + + if modelName == nil || *modelName == "" { + return nil, fmt.Errorf("model name is required") + } + var region = "default" - if apiConfig.Region != nil { + if apiConfig.Region != nil && *apiConfig.Region != "" { region = *apiConfig.Region } url := fmt.Sprintf("%s/%s", strings.TrimSuffix(z.BaseURL[region], "/"), z.URLSuffix.Embedding) - embeddings := make([][]float64, len(texts)) + reqBody := map[string]interface{}{} + reqBody["model"] = modelName + reqBody["input"] = texts + if embeddingConfig.Dimension > 0 { + reqBody["dimensions"] = embeddingConfig.Dimension + } - for i, text := range texts { - reqBody := map[string]interface{}{} - reqBody["model"] = modelName - reqBody["input"] = text - if embeddingConfig.Dimension > 0 { - reqBody["dimensions"] = embeddingConfig.Dimension - } + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("failed to marshal request: %w", err) + } - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("failed to marshal request: %w", err) - } + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) + } - req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) - if err != nil { - return nil, fmt.Errorf("failed to create request: %w", err) - } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } - resp, err := z.httpClient.Do(req) - if err != nil { - return nil, fmt.Errorf("failed to send request: %w", err) - } + body, err := io.ReadAll(resp.Body) + resp.Body.Close() - body, err := io.ReadAll(resp.Body) - resp.Body.Close() + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } - if err != nil { - return nil, fmt.Errorf("failed to read response: %w", err) - } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) + } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API request failed with status %d: %s", resp.StatusCode, string(body)) - } + // Parse response + var zhipuResp zhipuEmbeddingResponse + if err = json.Unmarshal(body, &zhipuResp); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } - // Parse response - var result map[string]interface{} - if err = json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("failed to parse response: %w", err) - } + var embeddings []EmbeddingData + for _, dataElem := range zhipuResp.Data { + var embeddingData EmbeddingData + embeddingData.Embedding = dataElem.Embedding + embeddingData.Index = dataElem.Index + embeddings = append(embeddings, embeddingData) + } - data, ok := result["data"].([]interface{}) - if !ok || len(data) == 0 { - return nil, fmt.Errorf("no data in response") - } + return embeddings, nil +} - firstData, ok := data[0].(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("invalid data format") - } +func (z *ZhipuAIModel) ListModels(apiConfig *APIConfig) ([]string, error) { + region := "default" + if apiConfig != nil && apiConfig.Region != nil && *apiConfig.Region != "" { + region = *apiConfig.Region + } - embeddingSlice, ok := firstData["embedding"].([]interface{}) - if !ok { - return nil, fmt.Errorf("invalid embedding format") + baseURL := z.BaseURL["default"] + if region != "default" { + if regional, ok := z.BaseURL[region]; ok && regional != "" { + baseURL = regional } + } + if baseURL == "" { + return nil, fmt.Errorf("zhipu-ai: no base URL configured for default region") + } - embedding := make([]float64, len(embeddingSlice)) - for j, v := range embeddingSlice { - switch val := v.(type) { - case float64: - embedding[j] = val - case float32: - embedding[j] = float64(val) - default: - return nil, fmt.Errorf("unexpected embedding value type") - } - } + url := fmt.Sprintf("%s/%s", strings.TrimSuffix(baseURL, "/"), z.URLSuffix.Models) - embeddings[i] = embedding + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create request: %w", err) } - return embeddings, nil -} + req.Header.Set("Content-Type", "application/json") + if apiConfig != nil && apiConfig.ApiKey != nil && *apiConfig.ApiKey != "" { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiConfig.ApiKey)) + } -func (z *ZhipuAIModel) ListModels(apiConfig *APIConfig) ([]string, error) { - return nil, fmt.Errorf("%s, no such method", z.Name()) + resp, err := z.httpClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("failed to read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("ZhipuAI models API error: %s, body: %s", resp.Status, string(body)) + } + + var modelList DSModelList + if err = json.Unmarshal(body, &modelList); err != nil { + return nil, fmt.Errorf("failed to parse response: %w", err) + } + + models := make([]string, 0, len(modelList.Models)) + for _, model := range modelList.Models { + modelName := model.ID + if model.OwnedBy != "" { + modelName = model.ID + "@" + model.OwnedBy + } + models = append(models, modelName) + } + + return models, nil } func (z *ZhipuAIModel) Balance(apiConfig *APIConfig) (map[string]interface{}, error) { @@ -604,3 +666,39 @@ func (z *ZhipuAIModel) Rerank(modelName *string, query string, documents []strin return &rerankResponse, nil } + +// TranscribeAudio transcribe audio +func (o *ZhipuAIModel) TranscribeAudio(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig) (*ASRResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *ZhipuAIModel) TranscribeAudioWithSender(modelName *string, file *string, apiConfig *APIConfig, asrConfig *ASRConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// AudioSpeech convert text to audio +func (o *ZhipuAIModel) AudioSpeech(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig) (*TTSResponse, error) { + return nil, fmt.Errorf("%s, no such method", o.Name()) +} + +func (z *ZhipuAIModel) AudioSpeechWithSender(modelName *string, audioContent *string, apiConfig *APIConfig, ttsConfig *TTSConfig, sender func(*string, *string) error) error { + return fmt.Errorf("%s, no such method", z.Name()) +} + +// OCRFile OCR file +func (m *ZhipuAIModel) OCRFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, ocrConfig *OCRConfig) (*OCRFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", m.Name()) +} + +// ParseFile parse file +func (z *ZhipuAIModel) ParseFile(modelName *string, content []byte, url *string, apiConfig *APIConfig, parseFileConfig *ParseFileConfig) (*ParseFileResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *ZhipuAIModel) ListTasks(apiConfig *APIConfig) ([]ListTaskStatus, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} + +func (z *ZhipuAIModel) ShowTask(taskID string, apiConfig *APIConfig) (*TaskResponse, error) { + return nil, fmt.Errorf("%s, no such method", z.Name()) +} diff --git a/internal/entity/skill_search.go b/internal/entity/skill_search.go index 3a31dfb486e..011499bbcbb 100644 --- a/internal/entity/skill_search.go +++ b/internal/entity/skill_search.go @@ -56,8 +56,7 @@ type SkillSearchConfig struct { TenantRerankID *int64 `gorm:"column:tenant_rerank_id" json:"tenant_rerank_id,omitempty"` TopK int64 `gorm:"column:top_k;default:10" json:"top_k"` IndexVersion string `gorm:"column:index_version;size:32;default:'1.0.0'" json:"index_version"` - CreateTime *int64 `gorm:"column:create_time" json:"create_time,omitempty"` - UpdateTime *time.Time `gorm:"column:update_time" json:"update_time,omitempty"` + BaseModel } // TableName returns the table name for SkillSearchConfig model @@ -90,7 +89,7 @@ func (s *SkillSearchConfig) ToMap() map[string]interface{} { result["create_time"] = s.CreateTime } if s.UpdateTime != nil { - result["update_time"] = s.UpdateTime.Format("2006-01-02 15:04:05") + result["update_time"] = time.UnixMilli(*s.UpdateTime).Format("2006-01-02 15:04:05") } return result diff --git a/internal/entity/skill_space.go b/internal/entity/skill_space.go index 0e90a398171..1df53a9197b 100644 --- a/internal/entity/skill_space.go +++ b/internal/entity/skill_space.go @@ -36,8 +36,7 @@ type SkillSpace struct { RerankID string `gorm:"column:rerank_id;size:128" json:"rerank_id"` TopK int `gorm:"column:top_k;default:10" json:"top_k"` Status string `gorm:"column:status;size:1;default:1" json:"status"` - CreateTime *int64 `gorm:"column:create_time" json:"create_time,omitempty"` - UpdateTime *time.Time `gorm:"column:update_time" json:"update_time,omitempty"` + BaseModel } // TableName returns the table name for SkillSpace model @@ -83,7 +82,7 @@ func (s *SkillSpace) ToMap() map[string]interface{} { result["create_time"] = s.CreateTime } if s.UpdateTime != nil { - result["update_time"] = s.UpdateTime.Format("2006-01-02 15:04:05") + result["update_time"] = time.UnixMilli(*s.UpdateTime).Format("2006-01-02 15:04:05") } return result diff --git a/internal/entity/system.go b/internal/entity/system.go index 831bb7397f9..49541ec3ef8 100644 --- a/internal/entity/system.go +++ b/internal/entity/system.go @@ -16,18 +16,13 @@ package entity -import "time" - // SystemSettings system settings model type SystemSettings struct { - Name string `gorm:"column:name;primaryKey;size:128" json:"name"` - Source string `gorm:"column:source;size:32;not null" json:"source"` - DataType string `gorm:"column:data_type;size:32;not null" json:"data_type"` - Value string `gorm:"column:value;type:longtext;not null" json:"value"` - CreateTime *int64 `gorm:"column:create_time" json:"create_time"` - CreateDate *time.Time `gorm:"column:create_date" json:"create_date"` - UpdateTime *int64 `gorm:"column:update_time" json:"update_time"` - UpdateDate *time.Time `gorm:"column:update_date" json:"update_date"` + Name string `gorm:"column:name;primaryKey;size:128" json:"name"` + Source string `gorm:"column:source;size:32;not null" json:"source"` + DataType string `gorm:"column:data_type;size:32;not null" json:"data_type"` + Value string `gorm:"column:value;type:longtext;not null" json:"value"` + BaseModel } // TableName specify table name diff --git a/internal/handler/chunk.go b/internal/handler/chunk.go index 207edfee488..8159ce05961 100644 --- a/internal/handler/chunk.go +++ b/internal/handler/chunk.go @@ -92,7 +92,7 @@ func (h *ChunkHandler) RetrievalTest(c *gin.Context) { }) return } - if req.KbID == nil { + if req.Datasets == nil { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, "message": "kb_id is required", @@ -100,52 +100,10 @@ func (h *ChunkHandler) RetrievalTest(c *gin.Context) { return } - // Validate kb_id type: string or []string - switch v := req.KbID.(type) { - case string: - if v == "" { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "kb_id cannot be empty string", - }) - return - } - case []interface{}: - // Convert to []string - var kbIDs []string - for _, item := range v { - if str, ok := item.(string); ok && str != "" { - kbIDs = append(kbIDs, str) - } else { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "kb_id array must contain non-empty strings", - }) - return - } - } - if len(kbIDs) == 0 { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "kb_id array cannot be empty", - }) - return - } - // Convert back to interface{} for service - req.KbID = kbIDs - case []string: - // Already correct type - if len(v) == 0 { - c.JSON(http.StatusBadRequest, gin.H{ - "code": 400, - "message": "kb_id array cannot be empty", - }) - return - } - default: + if len(req.Datasets) == 0 { c.JSON(http.StatusBadRequest, gin.H{ "code": 400, - "message": "kb_id must be string or array of strings", + "message": "kb_id array cannot be empty", }) return } diff --git a/internal/handler/datasets.go b/internal/handler/datasets.go index a1768e63fb0..250e0ea208d 100644 --- a/internal/handler/datasets.go +++ b/internal/handler/datasets.go @@ -18,7 +18,11 @@ package handler import ( "encoding/json" + "fmt" "net/http" + "ragflow/internal/engine" + "ragflow/internal/engine/types" + "sort" "strconv" "strings" @@ -30,7 +34,7 @@ import ( // DatasetsHandler handles the RESTful dataset endpoints. type DatasetsHandler struct { - datasetsService *service.DatasetsService + datasetsService *service.DatasetService } type listDatasetsExt struct { @@ -40,7 +44,7 @@ type listDatasetsExt struct { } // NewDatasetsHandler creates a new datasets handler. -func NewDatasetsHandler(datasetsService *service.DatasetsService) *DatasetsHandler { +func NewDatasetsHandler(datasetsService *service.DatasetService) *DatasetsHandler { return &DatasetsHandler{datasetsService: datasetsService} } @@ -142,6 +146,27 @@ func (h *DatasetsHandler) CreateDataset(c *gin.Context) { }) } +// GetDataset handles GET /api/v1/datasets/:dataset_id. +func (h *DatasetsHandler) GetDataset(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + datasetID := c.Param("dataset_id") + result, code, err := h.datasetsService.GetDataset(datasetID, user.ID) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeSuccess, + "data": result, + }) +} + // DeleteDatasets handles DELETE /api/v1/datasets. func (h *DatasetsHandler) DeleteDatasets(c *gin.Context) { user, errorCode, errorMessage := GetUser(c) @@ -177,3 +202,243 @@ func (h *DatasetsHandler) DeleteDatasets(c *gin.Context) { "data": result, }) } + +// GetKnowledgeGraph handles GET /api/v1/datasets/:dataset_id/graph. +func (h *DatasetsHandler) GetKnowledgeGraph(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + datasetID := strings.TrimSpace(c.Param("dataset_id")) + if datasetID == "" { + jsonError(c, common.CodeDataError, "dataset_id is required") + return + } + + dataset, code, err := h.datasetsService.GetDataset(datasetID, user.ID) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + tenantID, _ := dataset["tenant_id"].(string) + if tenantID == "" { + jsonError(c, common.CodeDataError, "tenant_id is required") + return + } + + docEngine := engine.Get() + if docEngine == nil { + jsonError(c, common.CodeServerError, "Document engine is not initialized") + return + } + + indexName := fmt.Sprintf("ragflow_%s", tenantID) + exists, err := docEngine.ChunkStoreExists(c.Request.Context(), indexName, datasetID) + if err != nil { + jsonError(c, common.CodeServerError, err.Error()) + return + } + + result := gin.H{ + "graph": map[string]interface{}{}, + "mind_map": map[string]interface{}{}, + } + if !exists { + jsonResponse(c, common.CodeSuccess, result, "success") + return + } + + searchResult, err := docEngine.Search(c.Request.Context(), &types.SearchRequest{ + IndexNames: []string{indexName}, + KbIDs: []string{datasetID}, + Offset: 0, + Limit: 1, + SelectFields: []string{"content_with_weight", "knowledge_graph_kwd"}, + Filter: map[string]interface{}{ + "kb_id": []string{datasetID}, + "knowledge_graph_kwd": []string{"graph"}, + }, + }) + if err != nil { + jsonError(c, common.CodeServerError, err.Error()) + return + } + if searchResult == nil || len(searchResult.Chunks) == 0 { + jsonResponse(c, common.CodeSuccess, result, "success") + return + } + + chunk := searchResult.Chunks[0] + graphType := firstStringValue(chunk["knowledge_graph_kwd"]) + contentWithWeight, _ := chunk["content_with_weight"].(string) + if strings.TrimSpace(contentWithWeight) == "" { + jsonResponse(c, common.CodeSuccess, result, "success") + return + } + + var graphData map[string]interface{} + if err := json.Unmarshal([]byte(contentWithWeight), &graphData); err != nil { + jsonResponse(c, common.CodeSuccess, result, "success") + return + } + if len(graphData) == 0 { + jsonResponse(c, common.CodeSuccess, result, "success") + return + } + + if graphType == "" { + graphType = "graph" + } + if graphType == "graph" { + sortKnowledgeGraph(graphData) + result["graph"] = graphData + } else { + result[graphType] = graphData + } + + jsonResponse(c, common.CodeSuccess, result, "success") +} + +// DeleteKnowledgeGraph handles DELETE /api/v1/datasets/:dataset_id/graph. +func (h *DatasetsHandler) DeleteKnowledgeGraph(c *gin.Context) { + user, errorCode, errorMessage := GetUser(c) + if errorCode != common.CodeSuccess { + jsonError(c, errorCode, errorMessage) + return + } + + datasetID := strings.TrimSpace(c.Param("dataset_id")) + if datasetID == "" { + jsonError(c, common.CodeDataError, "dataset_id is required") + return + } + + dataset, code, err := h.datasetsService.GetDataset(datasetID, user.ID) + if err != nil { + jsonError(c, code, err.Error()) + return + } + + tenantID, _ := dataset["tenant_id"].(string) + if tenantID == "" { + jsonError(c, common.CodeDataError, "tenant_id is required") + return + } + + docEngine := engine.Get() + if docEngine == nil { + jsonError(c, common.CodeServerError, "Document engine is not initialized") + return + } + + indexName := fmt.Sprintf("ragflow_%s", tenantID) + if _, err := docEngine.DeleteChunks(c.Request.Context(), map[string]interface{}{ + "knowledge_graph_kwd": []string{"graph", "subgraph", "entity", "relation", "community_report"}, + }, indexName, datasetID); err != nil { + jsonError(c, common.CodeServerError, err.Error()) + return + } + + jsonResponse(c, common.CodeSuccess, true, "success") +} + +func firstStringValue(value interface{}) string { + switch v := value.(type) { + case string: + return strings.TrimSpace(v) + case []string: + if len(v) > 0 { + return strings.TrimSpace(v[0]) + } + case []interface{}: + for _, item := range v { + if s, ok := item.(string); ok { + s = strings.TrimSpace(s) + if s != "" { + return s + } + } + } + } + return "" +} + +func sortKnowledgeGraph(graphData map[string]interface{}) { + nodes := mapSlice(graphData["nodes"]) + if len(nodes) > 0 { + sort.Slice(nodes, func(i, j int) bool { + return numericValue(nodes[i]["pagerank"]) > numericValue(nodes[j]["pagerank"]) + }) + if len(nodes) > 256 { + nodes = nodes[:256] + } + graphData["nodes"] = nodes + } + + edges := mapSlice(graphData["edges"]) + if len(edges) > 0 { + nodeIDSet := make(map[string]struct{}, len(nodes)) + for _, node := range nodes { + if id, ok := node["id"].(string); ok { + nodeIDSet[id] = struct{}{} + } + } + filteredEdges := make([]map[string]interface{}, 0, len(edges)) + for _, edge := range edges { + source, _ := edge["source"].(string) + target, _ := edge["target"].(string) + if source == "" || target == "" || source == target { + continue + } + if _, ok := nodeIDSet[source]; !ok { + continue + } + if _, ok := nodeIDSet[target]; !ok { + continue + } + filteredEdges = append(filteredEdges, edge) + } + sort.Slice(filteredEdges, func(i, j int) bool { + return numericValue(filteredEdges[i]["weight"]) > numericValue(filteredEdges[j]["weight"]) + }) + if len(filteredEdges) > 128 { + filteredEdges = filteredEdges[:128] + } + graphData["edges"] = filteredEdges + } +} + +func mapSlice(value interface{}) []map[string]interface{} { + raw, ok := value.([]interface{}) + if !ok { + return nil + } + result := make([]map[string]interface{}, 0, len(raw)) + for _, item := range raw { + if m, ok := item.(map[string]interface{}); ok { + result = append(result, m) + } + } + return result +} + +func numericValue(value interface{}) float64 { + switch v := value.(type) { + case float64: + return v + case float32: + return float64(v) + case int: + return float64(v) + case int64: + return float64(v) + case json.Number: + f, _ := v.Float64() + return f + default: + return 0 + } +} diff --git a/internal/handler/document.go b/internal/handler/document.go index a4152c07dc8..9b3307a724a 100644 --- a/internal/handler/document.go +++ b/internal/handler/document.go @@ -21,8 +21,10 @@ import ( "fmt" "net/http" "ragflow/internal/common" + "ragflow/internal/entity" "strconv" "strings" + "time" "github.com/gin-gonic/gin" @@ -32,12 +34,14 @@ import ( // DocumentHandler document handler type DocumentHandler struct { documentService *service.DocumentService + datasetService *service.DatasetService } // NewDocumentHandler create document handler -func NewDocumentHandler(documentService *service.DocumentService) *DocumentHandler { +func NewDocumentHandler(documentService *service.DocumentService, datasetService *service.DatasetService) *DocumentHandler { return &DocumentHandler{ documentService: documentService, + datasetService: datasetService, } } @@ -198,35 +202,22 @@ func (h *DocumentHandler) DeleteDocument(c *gin.Context) { } // ListDocuments document list -// @Summary Document List -// @Description Get paginated document list -// @Tags documents -// @Accept json -// @Produce json -// @Param page query int false "page number" default(1) -// @Param page_size query int false "items per page" default(10) -// @Success 200 {object} map[string]interface{} -// @Router /api/v1/document/list [post] + func (h *DocumentHandler) ListDocuments(c *gin.Context) { - _, errorCode, errorMessage := GetUser(c) - if errorCode != common.CodeSuccess { - jsonError(c, errorCode, errorMessage) - return - } - kbID := c.Query("kb_id") - if kbID == "" { - c.JSON(http.StatusOK, gin.H{ - "code": 1, - "message": "Lack of KB ID", - "data": false, - }) + datasetID := c.Param("dataset_id") + pageStr := c.Query("page") + pageSizeStr := c.Query("page_size") + page, _ := strconv.Atoi(pageStr) + pageSize, _ := strconv.Atoi(pageSizeStr) + + userID := c.GetString("user_id") + + if !h.datasetService.Accessible(datasetID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization.") return } - page, _ := strconv.Atoi(c.DefaultQuery("page", "1")) - pageSize, _ := strconv.Atoi(c.DefaultQuery("page_size", "10")) - if page < 1 { page = 1 } @@ -235,7 +226,7 @@ func (h *DocumentHandler) ListDocuments(c *gin.Context) { } // Use kbID to filter documents - documents, total, err := h.documentService.ListDocumentsByKBID(kbID, page, pageSize) + documents, total, err := h.documentService.ListDocumentsByDatasetID(datasetID, page, pageSize) if err != nil { c.JSON(http.StatusOK, gin.H{ "code": 1, @@ -252,15 +243,7 @@ func (h *DocumentHandler) ListDocuments(c *gin.Context) { metaFields = make(map[string]interface{}) } - docs = append(docs, map[string]interface{}{ - "id": doc.ID, - "name": doc.Name, - "size": doc.Size, - "type": doc.Type, - "status": doc.Status, - "created_at": doc.CreatedAt, - "meta_fields": metaFields, - }) + docs = append(docs, mapDocumentListItem(doc, metaFields)) } c.JSON(http.StatusOK, gin.H{ @@ -273,6 +256,104 @@ func (h *DocumentHandler) ListDocuments(c *gin.Context) { }) } +func mapDocumentListItem(doc *entity.DocumentListItem, metaFields map[string]interface{}) map[string]interface{} { + item := map[string]interface{}{ + "id": doc.ID, + "dataset_id": doc.KbID, + "name": stringValue(doc.Name), + "thumbnail": stringValue(doc.Thumbnail), + "size": doc.Size, + "type": doc.Type, + "created_by": doc.CreatedBy, + "location": stringValue(doc.Location), + "token_count": doc.TokenNum, + "chunk_count": doc.ChunkNum, + "progress": doc.Progress, + "progress_msg": stringValue(doc.ProgressMsg), + "process_begin_at": formatTimePtr(doc.ProcessBeginAt), + "process_duration": doc.ProcessDuration, + "suffix": doc.Suffix, + "run": mapRunStatus(doc.Run), + "status": stringValue(doc.Status), + "chunk_method": doc.ParserID, + "parser_id": doc.ParserID, + "pipeline_id": stringValue(doc.PipelineID), + "pipeline_name": stringValue(doc.PipelineName), + "nickname": stringValue(doc.Nickname), + "parser_config": decodeJSONMap(string(doc.ParserConfig)), + "meta_fields": metaFields, + "create_time": int64(0), + "create_date": "", + "update_time": int64(0), + "update_date": "", + } + + if doc.CreateTime != nil { + item["create_time"] = *doc.CreateTime + } + if doc.CreateDate != nil { + item["create_date"] = doc.CreateDate.Format("2006-01-02 15:04:05") + } + if doc.UpdateTime != nil { + item["update_time"] = *doc.UpdateTime + } + if doc.UpdateDate != nil { + item["update_date"] = doc.UpdateDate.Format("2006-01-02 15:04:05") + } + + return item +} + +func decodeJSONMap(raw string) map[string]interface{} { + if strings.TrimSpace(raw) == "" { + return map[string]interface{}{} + } + + var data map[string]interface{} + if err := json.Unmarshal([]byte(raw), &data); err != nil { + return map[string]interface{}{} + } + + return data +} + +func mapRunStatus(run *string) string { + if run == nil { + return "UNSTART" + } + + switch strings.TrimSpace(*run) { + case "0": + return "UNSTART" + case "1": + return "RUNNING" + case "2": + return "CANCEL" + case "3": + return "DONE" + case "4": + return "FAIL" + default: + return strings.TrimSpace(*run) + } +} + +func formatTimePtr(value *time.Time) string { + if value == nil { + return "" + } + + return value.Format("2006-01-02 15:04:05") +} + +func stringValue(value *string) string { + if value == nil { + return "" + } + + return *value +} + // GetDocumentsByAuthorID get documents by author ID // @Summary Get Author Documents // @Description Get paginated document list by author ID @@ -482,4 +563,37 @@ func (h *DocumentHandler) SetMeta(c *gin.Context) { "message": "success", "data": true, }) -} \ No newline at end of file +} + +type ParseDocumentRequest struct { + Documents []string `json:"documents" binding:"required"` + DatasetID string `json:"dataset_id" binding:"required"` +} + +func (h *DocumentHandler) ParseDocuments(c *gin.Context) { + var req ParseDocumentRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + }) + return + } + + userID := c.GetString("user_id") + + if !h.datasetService.Accessible(req.DatasetID, userID) { + jsonError(c, common.CodeAuthenticationError, "No authorization to access the dataset.") + return + } + + err := h.documentService.ParseDocuments(req.DatasetID, userID, req.Documents) + if err != nil { + jsonError(c, common.CodeExceptionError, err.Error()) + return + } + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + }) +} diff --git a/internal/handler/kb.go b/internal/handler/kb.go index 580e24fdcac..3debfc969b1 100644 --- a/internal/handler/kb.go +++ b/internal/handler/kb.go @@ -650,7 +650,7 @@ func (h *KnowledgebaseHandler) InsertDatasetFromFile(c *gin.Context) { // Get the document engine and insert docEngine := engine.Get() - result, err := docEngine.InsertDataset(c.Request.Context(), debugFormat.Chunks, debugFormat.TableNamePrefix, debugFormat.KnowledgebaseID) + result, err := docEngine.InsertChunks(c.Request.Context(), debugFormat.Chunks, debugFormat.TableNamePrefix, debugFormat.KnowledgebaseID) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{ "code": 500, diff --git a/internal/handler/memory.go b/internal/handler/memory.go index b8e04d06d84..745cd9ddcc6 100644 --- a/internal/handler/memory.go +++ b/internal/handler/memory.go @@ -194,7 +194,7 @@ func (h *MemoryHandler) CreateMemory(c *gin.Context) { // Return success response c.JSON(http.StatusOK, gin.H{ "code": common.CodeSuccess, - "message": "success", + "message": "success", "data": result, }) } @@ -293,7 +293,7 @@ func (h *MemoryHandler) UpdateMemory(c *gin.Context) { // Return success response c.JSON(http.StatusOK, gin.H{ "code": common.CodeSuccess, - "message": "success", + "message": "success", "data": result, }) } @@ -347,7 +347,7 @@ func (h *MemoryHandler) DeleteMemory(c *gin.Context) { // Return success response c.JSON(http.StatusOK, gin.H{ "code": common.CodeSuccess, - "message": "success", + "message": "success", "data": nil, }) } @@ -436,7 +436,7 @@ func (h *MemoryHandler) ListMemories(c *gin.Context) { // Return success response c.JSON(http.StatusOK, gin.H{ "code": common.CodeSuccess, - "message": "success", + "message": "success", "data": result, }) } @@ -490,7 +490,7 @@ func (h *MemoryHandler) GetMemoryConfig(c *gin.Context) { // Return success response c.JSON(http.StatusOK, gin.H{ "code": common.CodeSuccess, - "message": "success", + "message": "success", "data": result, }) } diff --git a/internal/handler/providers.go b/internal/handler/providers.go index 758919f406b..9f7b238fc6e 100644 --- a/internal/handler/providers.go +++ b/internal/handler/providers.go @@ -423,6 +423,91 @@ func (h *ProviderHandler) CheckProviderConnection(c *gin.Context) { }) } +func (h *ProviderHandler) ListTasks(c *gin.Context) { + providerName := c.Param("provider_name") + if providerName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } + + instanceName := c.Param("instance_name") + if instanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } + + userID := c.GetString("user_id") + + // Get tenant ID from user + listTaskResponse, errorCode, err := h.modelProviderService.ListTasks(providerName, instanceName, userID) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": errorCode, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": listTaskResponse, + }) +} + +func (h *ProviderHandler) ShowTask(c *gin.Context) { + providerName := c.Param("provider_name") + if providerName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } + + instanceName := c.Param("instance_name") + if instanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } + + taskID := c.Param("task_id") + if taskID == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Task id is required", + }) + return + } + + userID := c.GetString("user_id") + + // Get tenant ID from user + taskResponse, errorCode, err := h.modelProviderService.ShowTask(providerName, instanceName, taskID, userID) + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": errorCode, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "message": "success", + "data": taskResponse, + }) +} + type AlterProviderInstanceRequest struct { LLMName string `json:"llm_name" binding:"required"` } @@ -950,7 +1035,7 @@ func (h *ProviderHandler) EmbedText(c *gin.Context) { } // Non-stream response - var response *models.EmbeddingResponse + var response []models.EmbeddingData var errorCode common.ErrorCode var err error @@ -966,7 +1051,7 @@ func (h *ProviderHandler) EmbedText(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ "code": 0, - "data": response.Data, + "data": response, "message": "success", }) } @@ -1047,3 +1132,391 @@ func (h *ProviderHandler) RerankDocument(c *gin.Context) { "message": "success", }) } + +type TranscribeAudioRequest struct { + ProviderName *string `json:"provider_name"` + InstanceName *string `json:"instance_name"` + ModelName *string `json:"model_name"` + File *string `json:"file"` + Language []string `json:"language"` + Prompt int `json:"prompt"` + Stream bool `json:"stream"` + ASRConfig *models.ASRConfig `json:"asr_config"` +} + +func (h *ProviderHandler) TranscribeAudio(c *gin.Context) { + var req TranscribeAudioRequest + if err := c.ShouldBindJSON(&req); err != nil { + println("JSON bind error: %v (type: %T)", err, err) + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + }) + return + } + + if req.ProviderName == nil || *req.ProviderName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } + + if req.InstanceName == nil || *req.InstanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } + + if req.ModelName == nil || *req.ModelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + + userID := c.GetString("user_id") + + apiConfig := models.APIConfig{ + ApiKey: nil, + Region: nil, + } + + asrConfig := models.ASRConfig{} + if req.ASRConfig != nil { + asrConfig = *req.ASRConfig + } + + // Check if it's a stream request + if req.Stream { + // Set SSE headers + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Writer.WriteHeader(http.StatusOK) + c.Writer.Flush() + + // Create sender function that writes directly to response + sender := func(content, reasoningContent *string) error { + // Check for [DONE] marker (OpenAI compatible) + if content != nil { + if *content == "[DONE]" { + c.SSEvent("done", "[DONE]") + return nil + } + message := fmt.Sprintf("[MESSAGE]%s", *content) + c.SSEvent("message", message) + c.Writer.Flush() + } + + if reasoningContent != nil { + message := fmt.Sprintf("[REASONING]%s", *reasoningContent) + c.SSEvent("message", message) + c.Writer.Flush() + } + + //logger.Info(data) + return nil + } + + // Stream response using sender function (best performance, no channel) + errorCode, err := h.modelProviderService.TranscribeAudioStream(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.File, &apiConfig, &asrConfig, sender) + + if errorCode != common.CodeSuccess { + c.SSEvent("error", err.Error()) + } + return + } + + // Non-stream response + var response *models.ASRResponse + var errorCode common.ErrorCode + var err error + + response, errorCode, err = h.modelProviderService.TranscribeAudio(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.File, &apiConfig, &asrConfig) + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": errorCode, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": response, + "message": "success", + }) +} + +type AudioSpeechRequest struct { + ProviderName *string `json:"provider_name"` + InstanceName *string `json:"instance_name"` + ModelName *string `json:"model_name"` + Text *string `json:"text"` + Stream bool `json:"stream"` + TTSConfig *models.TTSConfig `json:"tts_config"` +} + +func (h *ProviderHandler) AudioSpeech(c *gin.Context) { + var req AudioSpeechRequest + if err := c.ShouldBindJSON(&req); err != nil { + println("JSON bind error: %v (type: %T)", err, err) + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + }) + return + } + + if req.ProviderName == nil || *req.ProviderName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } + + if req.InstanceName == nil || *req.InstanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } + + if req.ModelName == nil || *req.ModelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + + userID := c.GetString("user_id") + + apiConfig := models.APIConfig{ + ApiKey: nil, + Region: nil, + } + + ttsConfig := models.TTSConfig{} + if req.TTSConfig != nil { + ttsConfig = *req.TTSConfig + } + + // Check if it's a stream request + if req.Stream { + // Set SSE headers + c.Header("Content-Type", "text/event-stream") + c.Header("Cache-Control", "no-cache") + c.Header("Connection", "keep-alive") + c.Writer.WriteHeader(http.StatusOK) + c.Writer.Flush() + + // Create sender function that writes directly to response + sender := func(content, reasoningContent *string) error { + // Check for [DONE] marker (OpenAI compatible) + if content != nil { + if *content == "[DONE]" { + c.SSEvent("done", "[DONE]") + return nil + } + message := fmt.Sprintf("[MESSAGE]%s", *content) + c.SSEvent("message", message) + c.Writer.Flush() + } + + if reasoningContent != nil { + message := fmt.Sprintf("[REASONING]%s", *reasoningContent) + c.SSEvent("message", message) + c.Writer.Flush() + } + + //logger.Info(data) + return nil + } + + // Stream response using sender function (best performance, no channel) + errorCode, err := h.modelProviderService.AudioSpeechStream(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Text, &apiConfig, &ttsConfig, sender) + + if errorCode != common.CodeSuccess { + c.SSEvent("error", err.Error()) + } + return + } + + // Non-stream response + var response *models.TTSResponse + var errorCode common.ErrorCode + var err error + + response, errorCode, err = h.modelProviderService.AudioSpeech(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Text, &apiConfig, &ttsConfig) + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": errorCode, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": response, + "message": "success", + }) +} + +type OCRFileRequest struct { + ProviderName *string `json:"provider_name"` + InstanceName *string `json:"instance_name"` + ModelName *string `json:"model_name"` + Content []byte `json:"content"` + URL *string `json:"url"` +} + +func (h *ProviderHandler) OCRFile(c *gin.Context) { + var req OCRFileRequest + if err := c.ShouldBindJSON(&req); err != nil { + println("JSON bind error: %v (type: %T)", err, err) + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + }) + return + } + + if req.ProviderName == nil || *req.ProviderName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } + + if req.InstanceName == nil || *req.InstanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } + + if req.ModelName == nil || *req.ModelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + + userID := c.GetString("user_id") + + apiConfig := models.APIConfig{ + ApiKey: nil, + Region: nil, + } + + OCRConfig := models.OCRConfig{} + + // Non-stream response + var response *models.OCRFileResponse + var errorCode common.ErrorCode + var err error + + response, errorCode, err = h.modelProviderService.OCRFile(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Content, req.URL, &apiConfig, &OCRConfig) + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": errorCode, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": response, + "message": "success", + }) +} + +type ParseFileRequest struct { + ProviderName *string `json:"provider_name"` + InstanceName *string `json:"instance_name"` + ModelName *string `json:"model_name"` + Content []byte `json:"content"` + URL *string `json:"url"` +} + +func (h *ProviderHandler) ParseFile(c *gin.Context) { + var req ParseFileRequest + if err := c.ShouldBindJSON(&req); err != nil { + println("JSON bind error: %v (type: %T)", err, err) + c.JSON(http.StatusOK, gin.H{ + "code": common.CodeBadRequest, + "message": err.Error(), + }) + return + } + + if req.ProviderName == nil || *req.ProviderName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Provider name is required", + }) + return + } + + if req.InstanceName == nil || *req.InstanceName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Instance name is required", + }) + return + } + + if req.ModelName == nil || *req.ModelName == "" { + c.JSON(http.StatusBadRequest, gin.H{ + "code": 400, + "message": "Model name is required", + }) + return + } + + userID := c.GetString("user_id") + + apiConfig := models.APIConfig{ + ApiKey: nil, + Region: nil, + } + + parseFileConfig := models.ParseFileConfig{} + + // Non-stream response + var response *models.ParseFileResponse + var errorCode common.ErrorCode + var err error + + response, errorCode, err = h.modelProviderService.ParseFile(*req.ProviderName, *req.InstanceName, *req.ModelName, userID, req.Content, req.URL, &apiConfig, &parseFileConfig) + + if err != nil { + c.JSON(http.StatusOK, gin.H{ + "code": errorCode, + "message": err.Error(), + }) + return + } + + c.JSON(http.StatusOK, gin.H{ + "code": 0, + "data": response, + "message": "success", + }) +} diff --git a/internal/router/router.go b/internal/router/router.go index 97c9b90984c..5b8a840e7a1 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -50,7 +50,6 @@ func NewRouter( documentHandler *handler.DocumentHandler, datasetsHandler *handler.DatasetsHandler, systemHandler *handler.SystemHandler, - knowledgebaseHandler *handler.KnowledgebaseHandler, chunkHandler *handler.ChunkHandler, llmHandler *handler.LLMHandler, chatHandler *handler.ChatHandler, @@ -63,23 +62,22 @@ func NewRouter( providerHandler *handler.ProviderHandler, ) *Router { return &Router{ - authHandler: authHandler, - userHandler: userHandler, - tenantHandler: tenantHandler, - documentHandler: documentHandler, - datasetsHandler: datasetsHandler, - systemHandler: systemHandler, - knowledgebaseHandler: knowledgebaseHandler, - chunkHandler: chunkHandler, - llmHandler: llmHandler, - chatHandler: chatHandler, - chatSessionHandler: chatSessionHandler, - connectorHandler: connectorHandler, - searchHandler: searchHandler, - fileHandler: fileHandler, - memoryHandler: memoryHandler, - skillSearchHandler: skillSearchHandler, - providerHandler: providerHandler, + authHandler: authHandler, + userHandler: userHandler, + tenantHandler: tenantHandler, + documentHandler: documentHandler, + datasetsHandler: datasetsHandler, + systemHandler: systemHandler, + chunkHandler: chunkHandler, + llmHandler: llmHandler, + chatHandler: chatHandler, + chatSessionHandler: chatSessionHandler, + connectorHandler: connectorHandler, + searchHandler: searchHandler, + fileHandler: fileHandler, + memoryHandler: memoryHandler, + skillSearchHandler: skillSearchHandler, + providerHandler: providerHandler, } } @@ -144,6 +142,10 @@ func (r *Router) Setup(engine *gin.Engine) { users.GET("/me", r.userHandler.Info) // User settings endpoint users.PATCH("/me", r.userHandler.Setting) + // User tenant info endpoint + users.GET("/me/models", r.tenantHandler.TenantInfo) + // User set tenant info endpoint + users.PATCH("/me/models", r.userHandler.SetTenantInfo) } tenants := v1.Group("/tenants") @@ -151,6 +153,8 @@ func (r *Router) Setup(engine *gin.Engine) { tenants.GET("", r.tenantHandler.TenantList) } + v1.GET("/tenant/list", r.tenantHandler.TenantList) + // Document routes documents := v1.Group("/documents") { @@ -159,6 +163,7 @@ func (r *Router) Setup(engine *gin.Engine) { documents.GET("/:id", r.documentHandler.GetDocumentByID) documents.PUT("/:id", r.documentHandler.UpdateDocument) documents.DELETE("/:id", r.documentHandler.DeleteDocument) + documents.POST("/parse", r.documentHandler.ParseDocuments) } // Chat routes @@ -173,9 +178,15 @@ func (r *Router) Setup(engine *gin.Engine) { datasets := v1.Group("/datasets") { datasets.GET("", r.datasetsHandler.ListDatasets) + datasets.GET("/:dataset_id", r.datasetsHandler.GetDataset) + datasets.GET("/:dataset_id/graph", r.datasetsHandler.GetKnowledgeGraph) + datasets.DELETE("/:dataset_id/graph", r.datasetsHandler.DeleteKnowledgeGraph) datasets.POST("", r.datasetsHandler.CreateDataset) datasets.DELETE("", r.datasetsHandler.DeleteDatasets) datasets.POST("/search", r.chunkHandler.RetrievalTest) + + // Dataset documents + datasets.GET("/:dataset_id/documents", r.documentHandler.ListDocuments) } // Search routes @@ -195,6 +206,7 @@ func (r *Router) Setup(engine *gin.Engine) { file.DELETE("", r.fileHandler.DeleteFiles) file.POST("/move", r.fileHandler.MoveFiles) file.GET("/:id/ancestors", r.fileHandler.GetFileAncestors) + file.GET("/:id/parent", r.fileHandler.GetParentFolder) file.GET("/:id", r.fileHandler.Download) } @@ -262,6 +274,8 @@ func (r *Router) Setup(engine *gin.Engine) { provider.GET("/:provider_name/instances/:instance_name", r.providerHandler.ShowProviderInstance) provider.GET("/:provider_name/instances/:instance_name/balance", r.providerHandler.ShowInstanceBalance) provider.GET("/:provider_name/instances/:instance_name/connection", r.providerHandler.CheckProviderConnection) + provider.GET("/:provider_name/instances/:instance_name/tasks", r.providerHandler.ListTasks) + provider.GET("/:provider_name/instances/:instance_name/tasks/:task_id", r.providerHandler.ShowTask) provider.PUT("/:provider_name/instances/:instance_name", r.providerHandler.AlterProviderInstance) provider.DELETE("/:provider_name/instances", r.providerHandler.DropProviderInstance) provider.GET("/:provider_name/instances/:instance_name/models", r.providerHandler.ListInstanceModels) @@ -271,6 +285,10 @@ func (r *Router) Setup(engine *gin.Engine) { v1.POST("/chat/completions", r.providerHandler.ChatToModel) v1.POST("/embeddings", r.providerHandler.EmbedText) v1.POST("/rerank", r.providerHandler.RerankDocument) + v1.POST("/audio/transcriptions", r.providerHandler.TranscribeAudio) + v1.POST("/audio/speech", r.providerHandler.AudioSpeech) + v1.POST("/file/ocr", r.providerHandler.OCRFile) + v1.POST("/file/parse", r.providerHandler.ParseFile) } model := v1.Group("/models") @@ -279,6 +297,11 @@ func (r *Router) Setup(engine *gin.Engine) { model.PATCH("/", r.tenantHandler.SetModels) } + connector := v1.Group("/connectors") + { + connector.GET("/", r.connectorHandler.ListConnectors) + } + system := v1.Group("/system") { system.GET("/configs", r.systemHandler.GetConfigs) diff --git a/internal/server/config.go b/internal/server/config.go index 25f1b41876c..27e97b24720 100644 --- a/internal/server/config.go +++ b/internal/server/config.go @@ -723,6 +723,23 @@ func FromConfigFile(configPath string) error { } } + if v.IsSet("minio_0") { + minioConfig := v.Sub("minio_0") + if minioConfig != nil { + if globalConfig.StorageEngine.Minio == nil { + globalConfig.StorageEngine.Minio = &MinioConfig{ + Host: minioConfig.GetString("host"), + User: minioConfig.GetString("user"), + Password: minioConfig.GetString("password"), + Secure: minioConfig.GetBool("secure"), + PrefixPath: minioConfig.GetString("prefix_path"), + Verify: minioConfig.GetBool("verify"), + Bucket: minioConfig.GetString("bucket"), + } + } + } + } + if v.IsSet("s3") { s3Config := v.Sub("s3") if s3Config != nil { diff --git a/internal/service/api_token.go b/internal/service/api_token.go index 9f44d740199..667610ae1e5 100644 --- a/internal/service/api_token.go +++ b/internal/service/api_token.go @@ -20,7 +20,6 @@ import ( "ragflow/internal/dao" "ragflow/internal/entity" "ragflow/internal/utility" - "time" ) // TokenResponse token response @@ -67,9 +66,6 @@ type CreateAPITokenRequest struct { func (s *SystemService) CreateAPIToken(tenantID string, req *CreateAPITokenRequest) (*TokenResponse, error) { APITokenDAO := dao.NewAPITokenDAO() - now := time.Now().Unix() - nowDate := time.Now() - // Generate token and beta values // token: "ragflow-" + secrets.token_urlsafe(32) APIToken := utility.GenerateAPIToken() @@ -81,8 +77,6 @@ func (s *SystemService) CreateAPIToken(tenantID string, req *CreateAPITokenReque Token: APIToken, Beta: &betaAPIKey, } - APITokenData.CreateDate = &nowDate - APITokenData.CreateTime = &now if err := APITokenDAO.Create(APITokenData); err != nil { return nil, err diff --git a/internal/service/chat.go b/internal/service/chat.go index f386d727997..060bd3cc566 100644 --- a/internal/service/chat.go +++ b/internal/service/chat.go @@ -21,7 +21,6 @@ import ( "fmt" "ragflow/internal/entity" "strings" - "time" "unicode/utf8" "github.com/google/uuid" @@ -453,10 +452,6 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo newID = newID[:32] } - // Get current time - now := time.Now().Truncate(time.Second) - createTime := now.UnixMilli() - // Set default language language := "English" @@ -480,10 +475,6 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo KBIDs: kbIDsJSON, Status: strPtr("1"), } - chat.CreateTime = &createTime - chat.CreateDate = &now - chat.UpdateTime = &createTime - chat.UpdateDate = &now if err := s.chatDAO.Create(chat); err != nil { return nil, errors.New("Fail to new a chat") @@ -498,9 +489,6 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo }, nil } - // Update existing chat - also update update_time - now := time.Now().Truncate(time.Second) - updateTime := now.UnixMilli() updateData := map[string]interface{}{ "name": name, "description": description, @@ -515,8 +503,6 @@ func (s *ChatService) SetDialog(userID string, req *SetDialogRequest) (*SetDialo "similarity_threshold": similarityThreshold, "vector_similarity_weight": vectorSimilarityWeight, "kb_ids": kbIDsJSON, - "update_time": updateTime, - "update_date": now, } if err := s.chatDAO.UpdateByID(req.DialogID, updateData); err != nil { diff --git a/internal/service/chat_session.go b/internal/service/chat_session.go index 206b6e76b43..50402f9d7c2 100644 --- a/internal/service/chat_session.go +++ b/internal/service/chat_session.go @@ -79,8 +79,6 @@ func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRe updates := map[string]interface{}{ "name": name, "user_id": userID, - "update_time": time.Now().UnixMilli(), - "update_date": time.Now(), } if err := s.chatSessionDAO.UpdateByID(req.SessionID, updates); err != nil { @@ -118,9 +116,6 @@ func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRe } } - now := time.Now().Truncate(time.Second) - createTime := time.Now().UnixMilli() - // Create initial message - store as JSON object with messages array messagesObj := map[string]interface{}{ "messages": []map[string]interface{}{ @@ -144,10 +139,6 @@ func (s *ChatSessionService) SetChatSession(userID string, req *SetChatSessionRe UserID: &userID, Reference: referenceJSON, } - session.CreateTime = &createTime - session.CreateDate = &now - session.UpdateTime = &createTime - session.UpdateDate = &now if err := s.chatSessionDAO.Create(session); err != nil { return nil, errors.New("Fail to create a chat session") @@ -459,8 +450,6 @@ func (s *ChatSessionService) updateSessionMessages(session *entity.ChatSession, updates := map[string]interface{}{ "message": messagesJSON, "reference": referenceJSON, - "update_time": time.Now().UnixMilli(), - "update_date": time.Now(), } s.chatSessionDAO.UpdateByID(session.ID, updates) } diff --git a/internal/service/chunk.go b/internal/service/chunk.go index c2ce08d4e5b..5cc3dcf97c3 100644 --- a/internal/service/chunk.go +++ b/internal/service/chunk.go @@ -63,7 +63,7 @@ func NewChunkService() *ChunkService { // RetrievalTestRequest retrieval test request type RetrievalTestRequest struct { - KbID interface{} `json:"kb_id" binding:"required"` // string or []string + Datasets []string `json:"dataset_ids" binding:"required"` // string or []string Question string `json:"question" binding:"required"` Page *int `json:"page,omitempty"` Size *int `json:"size,omitempty"` @@ -105,7 +105,7 @@ type RetrievalTestResponse struct { // 7. knowledge graph retrieval (not implemented) // 8. Apply retrieval by children to group child chunks under parent chunks func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) (*RetrievalTestResponse, error) { - common.Info("RetrievalTest started", zap.String("userID", userID), zap.Any("kbID", req.KbID), zap.String("question", req.Question)) + common.Info("RetrievalTest started", zap.String("userID", userID), zap.Any("kbID", req.Datasets), zap.String("question", req.Question)) common.Debug(fmt.Sprintf("RetrievalTest request:\n"+ " kbID=%v\n"+ @@ -120,7 +120,7 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) ( " rerankID=%v\n"+ " keyword=%v\n"+ " similarityThreshold=%v, vectorSimilarityWeight=%v", - req.KbID, req.Question, + req.Datasets, req.Question, ptrString(req.Page), ptrString(req.Size), req.DocIDs, ptrString(req.UseKG), ptrString(req.TopK), req.CrossLanguages, ptrString(req.SearchID), req.Filter, @@ -134,20 +134,6 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) ( ctx := context.Background() - // Determine kb_id list and check permission for each kb_id - var kbIDs []string - switch v := req.KbID.(type) { - case string: - kbIDs = []string{v} - case []string: - kbIDs = v - default: - return nil, fmt.Errorf("kb_id must be string or array of strings") - } - if len(kbIDs) == 0 { - return nil, fmt.Errorf("kb_id cannot be empty") - } - tenants, err := s.userTenantDAO.GetByUserID(userID) if err != nil { return nil, fmt.Errorf("failed to get user tenants: %w", err) @@ -159,13 +145,13 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) ( var tenantIDs []string var kbRecords []*entity.Knowledgebase - for _, kbID := range kbIDs { + for _, datasetID := range req.Datasets { found := false for _, tenant := range tenants { - kb, err := s.kbDAO.GetByIDAndTenantID(kbID, tenant.TenantID) + kb, err := s.kbDAO.GetByIDAndTenantID(datasetID, tenant.TenantID) if err == nil && kb != nil { common.Debug("Found knowledge base in database", - zap.String("kbID", kbID), + zap.String("datasetID", datasetID), zap.String("tenantID", tenant.TenantID), zap.String("kbName", kb.Name), zap.String("embdID", kb.EmbdID)) @@ -227,7 +213,7 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) ( } } - // If no chatID from search_config, or chatModel not found, use tenant default + // If no chatID from search_config, or chatModel not found, use tenant default if chatModelForFilter == nil { tenantSvc := NewTenantService() modelName, err := tenantSvc.GetDefaultModelName(tenantIDs[0], entity.ModelTypeChat) @@ -253,7 +239,7 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) ( if filter != nil { // Get flattened metadata metadataSvc := NewMetadataService() - flattedMeta, err := metadataSvc.GetFlattedMetaByKBs(kbIDs) + flattedMeta, err := metadataSvc.GetFlattedMetaByKBs(req.Datasets) if err != nil { common.Warn("Failed to get flatted metadata", zap.Error(err)) } else { @@ -393,7 +379,7 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) ( retrievalReq := &nlp.RetrievalRequest{ TenantIDs: tenantIDs, Question: modifiedQuestion, - KbIDs: kbIDs, + KbIDs: req.Datasets, DocIDs: docIDs, Page: getPageNum(req.Page, 1), PageSize: getPageSize(req.Size, 30), @@ -427,7 +413,7 @@ func (s *ChunkService) RetrievalTest(req *RetrievalTestRequest, userID string) ( delete(filteredChunks[i], "vector") } - common.Info("RetrievalTest completed", zap.String("userID", userID), zap.Any("kbID", req.KbID), zap.String("question", req.Question), zap.Int64("chunkCount", int64(len(filteredChunks)))) + common.Info("RetrievalTest completed", zap.String("userID", userID), zap.Any("kbID", req.Datasets), zap.String("question", req.Question), zap.Int64("chunkCount", int64(len(filteredChunks)))) return &RetrievalTestResponse{ Chunks: filteredChunks, @@ -904,7 +890,7 @@ func (s *ChunkService) UpdateChunk(req *UpdateChunkRequest, userID string) error "id": req.ChunkID, } - err = s.docEngine.UpdateDataset(ctx, condition, d, indexName, req.DatasetID) + err = s.docEngine.UpdateChunks(ctx, condition, d, indexName, req.DatasetID) if err != nil { return fmt.Errorf("failed to update chunk: %w", err) } @@ -984,7 +970,7 @@ func (s *ChunkService) RemoveChunks(req *RemoveChunksRequest, userID string) (in return 0, fmt.Errorf("either chunk_ids or delete_all must be provided") } - deletedCount, err := s.docEngine.Delete(ctx, condition, indexName, doc.KbID) + deletedCount, err := s.docEngine.DeleteChunks(ctx, condition, indexName, doc.KbID) if err != nil { return 0, fmt.Errorf("failed to delete chunks: %w", err) } diff --git a/internal/service/datasets.go b/internal/service/dataset.go similarity index 91% rename from internal/service/datasets.go rename to internal/service/dataset.go index 271f457a20d..19be7425258 100644 --- a/internal/service/datasets.go +++ b/internal/service/dataset.go @@ -22,7 +22,6 @@ import ( "fmt" "ragflow/internal/entity" "strings" - "time" "github.com/google/uuid" "gorm.io/gorm" @@ -58,17 +57,21 @@ var ( datasetChunkMethodErrorMessage = "Input should be 'naive', 'book', 'email', 'laws', 'manual', 'one', 'paper', 'picture', 'presentation', 'qa', 'resume', 'table' or 'tag'" ) -// DatasetsService implements the RESTful dataset APIs from dataset_api.py. -type DatasetsService struct { +// DatasetService implements the RESTful dataset APIs from dataset_api.py. +type DatasetService struct { kbDAO *dao.KnowledgebaseDAO + documentDAO *dao.DocumentDAO + connectorDAO *dao.ConnectorDAO tenantDAO *dao.TenantDAO tenantLLMDAO *dao.TenantLLMDAO } -// NewDatasetsService creates a new datasets service. -func NewDatasetsService() *DatasetsService { - return &DatasetsService{ +// NewDatasetService creates a new datasets service. +func NewDatasetService() *DatasetService { + return &DatasetService{ kbDAO: dao.NewKnowledgebaseDAO(), + documentDAO: dao.NewDocumentDAO(), + connectorDAO: dao.NewConnectorDAO(), tenantDAO: dao.NewTenantDAO(), tenantLLMDAO: dao.NewTenantLLMDAO(), } @@ -105,7 +108,7 @@ type CreateDatasetRequest struct { } // ListDatasets lists datasets with pagination and filtering. -func (s *DatasetsService) ListDatasets(id, name string, page, pageSize int, orderby string, desc bool, keywords string, ownerIDs []string, parserID, userID string) ([]map[string]interface{}, int64, common.ErrorCode, error) { +func (s *DatasetService) ListDatasets(id, name string, page, pageSize int, orderby string, desc bool, keywords string, ownerIDs []string, parserID, userID string) ([]map[string]interface{}, int64, common.ErrorCode, error) { id = strings.TrimSpace(id) if id != "" { normalizedID, err := normalizeDatasetUUID1(id) @@ -187,7 +190,7 @@ func (s *DatasetsService) ListDatasets(id, name string, page, pageSize int, orde } // CreateDataset creates a new dataset. -func (s *DatasetsService) CreateDataset(req *CreateDatasetRequest, tenantID string) (map[string]interface{}, common.ErrorCode, error) { +func (s *DatasetService) CreateDataset(req *CreateDatasetRequest, tenantID string) (map[string]interface{}, common.ErrorCode, error) { if !isValidString(req.Name) { return nil, common.CodeDataError, errors.New("Dataset name must be string.") } @@ -392,8 +395,6 @@ func (s *DatasetsService) CreateDataset(req *CreateDatasetRequest, tenantID stri return nil, common.CodeServerError, errors.New("Internal server error") } - now := time.Now().Unix() - nowDate := time.Now().Truncate(time.Second) status := string(entity.StatusValid) // Deduplicate name within tenant duplicateName, err := common.DuplicateName(func(n, tid string) bool { @@ -416,10 +417,6 @@ func (s *DatasetsService) CreateDataset(req *CreateDatasetRequest, tenantID stri EmbdID: embdID, Status: &status, } - kb.CreateTime = &now - kb.UpdateTime = &now - kb.CreateDate = &nowDate - kb.UpdateDate = &nowDate if description != nil { kb.Description = description @@ -444,7 +441,7 @@ func (s *DatasetsService) CreateDataset(req *CreateDatasetRequest, tenantID stri } // DeleteDatasets deletes multiple datasets. -func (s *DatasetsService) DeleteDatasets(ids []string, deleteAll bool, tenantID string) (map[string]interface{}, common.ErrorCode, error) { +func (s *DatasetService) DeleteDatasets(ids []string, deleteAll bool, tenantID string) (map[string]interface{}, common.ErrorCode, error) { normalizedIDs := make([]string, 0, len(ids)) seenIDs := make(map[string]struct{}, len(ids)) @@ -523,7 +520,51 @@ func (s *DatasetsService) DeleteDatasets(ids []string, deleteAll bool, tenantID }, common.CodeSuccess, nil } -func (s *DatasetsService) deleteDataset(tenantID string, kb *entity.Knowledgebase) error { +// GetDataset gets a single dataset with its size and linked connectors. +func (s *DatasetService) GetDataset(datasetID, userID string) (map[string]interface{}, common.ErrorCode, error) { + datasetID = strings.TrimSpace(datasetID) + if datasetID == "" { + return nil, common.CodeDataError, errors.New("Lack of \"Dataset ID\"") + } + + normalizedID, err := normalizeDatasetUUID1(datasetID) + if err != nil { + return nil, common.CodeDataError, err + } + datasetID = normalizedID + + if !s.kbDAO.Accessible(datasetID, userID) { + return nil, common.CodeDataError, fmt.Errorf("User '%s' lacks permission for dataset '%s'", userID, datasetID) + } + + kb, err := s.kbDAO.GetByID(datasetID) + if err != nil || kb == nil { + return nil, common.CodeDataError, errors.New("Invalid Dataset ID") + } + + data := datasetToMap(kb) + + size, err := s.documentDAO.SumSizeByDatasetID(datasetID) + if err != nil { + return nil, common.CodeServerError, errors.New("Database operation failed") + } + data["size"] = size + + connectors, err := s.connectorDAO.ListByDatasetID(datasetID) + if err != nil { + return nil, common.CodeServerError, errors.New("Database operation failed") + } + data["connectors"] = connectors + + return data, common.CodeSuccess, nil +} + +// Accessible checks if a knowledge base is accessible by a user +func (s *DatasetService) Accessible(kbID, userID string) bool { + return s.kbDAO.Accessible(kbID, userID) +} + +func (s *DatasetService) deleteDataset(tenantID string, kb *entity.Knowledgebase) error { return dao.DB.Transaction(func(tx *gorm.DB) error { var documents []entity.Document if err := tx.Where("kb_id = ?", kb.ID).Find(&documents).Error; err != nil { @@ -670,7 +711,7 @@ func normalizeDatasetUUID1(id string) (string, error) { return strings.ReplaceAll(parsedUUID.String(), "-", ""), nil } -func (s *DatasetsService) verifyEmbeddingAvailability(embdID string, tenantID string) (bool, string) { +func (s *DatasetService) verifyEmbeddingAvailability(embdID string, tenantID string) (bool, string) { modelName, _, provider, err := parseModelName(embdID) if err != nil { return false, "Embedding model identifier must follow @ format" diff --git a/internal/service/document.go b/internal/service/document.go index 9c1fa0a2912..aeef7fb0048 100644 --- a/internal/service/document.go +++ b/internal/service/document.go @@ -175,17 +175,17 @@ func (s *DocumentService) ListDocuments(page, pageSize int) ([]*DocumentResponse return responses, total, nil } -// ListDocumentsByKBID list documents by knowledge base ID -func (s *DocumentService) ListDocumentsByKBID(kbID string, page, pageSize int) ([]*DocumentResponse, int64, error) { +// ListDocumentsByDatasetID list documents by knowledge base ID +func (s *DocumentService) ListDocumentsByDatasetID(kbID string, page, pageSize int) ([]*entity.DocumentListItem, int64, error) { offset := (page - 1) * pageSize documents, total, err := s.documentDAO.ListByKBID(kbID, offset, pageSize) if err != nil { return nil, 0, err } - responses := make([]*DocumentResponse, len(documents)) + responses := make([]*entity.DocumentListItem, len(documents)) for i, doc := range documents { - responses[i] = s.toResponse(doc) + responses[i] = doc } return responses, total, nil @@ -207,6 +207,13 @@ func (s *DocumentService) GetDocumentsByAuthorID(authorID, page, pageSize int) ( return responses, total, nil } +func (s *DocumentService) ParseDocuments(datasetID, userID string, docIDs []string) error { + // create document parse id + // save to task table + // send to message queue + return nil +} + // toResponse convert model.Document to DocumentResponse func (s *DocumentService) toResponse(doc *entity.Document) *DocumentResponse { createdAt := "" @@ -223,7 +230,12 @@ func (s *DocumentService) toResponse(doc *entity.Document) *DocumentResponse { } updatedAt := "" if doc.UpdateTime != nil { - updatedAt = time.Unix(*doc.UpdateTime, 0).Format("2006-01-02 15:04:05") + // Accept both historical second-based values and current millisecond-based values. + ts := *doc.UpdateTime + if ts > 1000000000000 { + ts /= 1000 + } + updatedAt = time.Unix(ts, 0).Format("2006-01-02 15:04:05") } return &DocumentResponse{ ID: doc.ID, diff --git a/internal/service/file.go b/internal/service/file.go index 662d50010c4..27ee9e8c8ac 100644 --- a/internal/service/file.go +++ b/internal/service/file.go @@ -662,7 +662,7 @@ func (s *FileService) deleteDocumentFromEngine(ctx context.Context, doc *entity. reqCtx, cancel := context.WithTimeout(ctx, 300*time.Second) defer cancel() condition := map[string]interface{}{"doc_id": doc.ID} - if _, err := docEngine.Delete(reqCtx, condition, indexName, doc.KbID); err != nil { + if _, err := docEngine.DeleteChunks(reqCtx, condition, indexName, doc.KbID); err != nil { return fmt.Errorf("delete document from engine: %w", err) } return nil diff --git a/internal/service/kb.go b/internal/service/kb.go index 77d25779267..5017b939383 100644 --- a/internal/service/kb.go +++ b/internal/service/kb.go @@ -27,7 +27,6 @@ import ( "ragflow/internal/utility" "strings" - "time" ) // KnowledgebaseService service class for managing dataset operations @@ -112,7 +111,7 @@ func (s *KnowledgebaseService) CreateDatasetInDocEngine(req *CreateDatasetTableR // Call document engine to create table // Full table name will be built as "{tableName}_{kb_id}" - err = s.docEngine.CreateDataset(context.Background(), tableName, req.KBID, vecSize, req.ParserID) + err = s.docEngine.CreateChunkStore(context.Background(), tableName, req.KBID, vecSize, req.ParserID) if err != nil { return nil, common.CodeServerError, fmt.Errorf("failed to create dataset: %w", err) } @@ -132,11 +131,9 @@ func (s *KnowledgebaseService) DeleteDatasetInDocEngine(kbID string) (common.Err return common.CodeDataError, fmt.Errorf("knowledge base not found: %s", kbID) } - // Build table name: ragflow__ - tableName := fmt.Sprintf("ragflow_%s_%s", kb.TenantID, kbID) - // Call document engine to delete table - err = s.docEngine.DropTable(context.Background(), tableName) + err = s.docEngine.DropChunkStore(context.Background(), fmt.Sprintf("ragflow_%s", kb.TenantID), kbID) + if err != nil { return common.CodeServerError, fmt.Errorf("failed to delete table: %w", err) } @@ -213,11 +210,6 @@ func (s *KnowledgebaseService) UpdateKB(req *UpdateKBRequest, userID string) (ma updates["parser_config"] = req.ParserConfig } - now := time.Now().Unix() - nowDate := time.Now().Truncate(time.Second) - updates["update_time"] = now - updates["update_date"] = nowDate - // Update in database if err := s.kbDAO.UpdateByID(req.KBID, updates); err != nil { return nil, common.CodeServerError, fmt.Errorf("failed to update knowledge base: %w", err) @@ -291,7 +283,7 @@ func (s *KnowledgebaseService) Accessible(kbID, userID string) bool { // RemoveTag removes a tag from documents in a dataset func (s *KnowledgebaseService) RemoveTag(condition map[string]interface{}, newValue map[string]interface{}, indexName, kbID string) error { - return s.docEngine.UpdateDataset(context.Background(), condition, newValue, indexName, kbID) + return s.docEngine.UpdateChunks(context.Background(), condition, newValue, indexName, kbID) } // GetByID retrieves a knowledge base by ID diff --git a/internal/service/memory.go b/internal/service/memory.go index 2ab7272b087..93face56a93 100644 --- a/internal/service/memory.go +++ b/internal/service/memory.go @@ -370,8 +370,6 @@ func (s *MemoryService) CreateMemory(tenantID string, req *CreateMemoryRequest) } memoryTypeInt := dao.CalculateMemoryType(uniqueMemoryTypes) - timestamp := time.Now().UnixMilli() - systemPrompt := PromptAssembler{}.AssembleSystemPrompt(uniqueMemoryTypes) newID := common.GenerateUUID() @@ -402,9 +400,6 @@ func (s *MemoryService) CreateMemory(tenantID string, req *CreateMemoryRequest) memory.TenantLLMID = &llmID } } - memory.CreateTime = ×tamp - memory.UpdateTime = ×tamp - if err := s.memoryDAO.Create(memory); err != nil { return nil, errors.New("could not create new memory") } @@ -504,7 +499,7 @@ func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *Upda } if req.ForgettingPolicy != nil { - fp := ForgettingPolicy(strings.ToLower(*req.ForgettingPolicy)) + fp := ForgettingPolicy(strings.ToUpper(strings.TrimSpace(*req.ForgettingPolicy))) if !validForgettingPolicies[fp] { return nil, fmt.Errorf("forgetting policy '%s' is not supported", *req.ForgettingPolicy) } @@ -549,6 +544,108 @@ func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *Upda return formatRetDataFromMemory(currentMemory), nil } + currentMemoryTypes := dao.GetMemoryTypeHuman(currentMemory.MemoryType) + normalizedCurrentMemoryTypes := normalizeMemoryTypes(currentMemoryTypes) + + filteredUpdateDict := make(map[string]interface{}, len(updateDict)) + for field, value := range updateDict { + switch field { + case "name": + currentName := strings.TrimSpace(currentMemory.Name) + requestName := strings.TrimSpace(fmt.Sprint(value)) + if currentName != requestName { + filteredUpdateDict[field] = value + } + case "permissions": + currentPermissions := strings.ToLower(strings.TrimSpace(currentMemory.Permissions)) + requestPermissions := strings.ToLower(strings.TrimSpace(fmt.Sprint(value))) + if currentPermissions != requestPermissions { + filteredUpdateDict[field] = value + } + case "llm_id": + currentLLMID := strings.TrimSpace(currentMemory.LLMID) + requestLLMID := strings.TrimSpace(fmt.Sprint(value)) + if currentLLMID != requestLLMID { + filteredUpdateDict[field] = value + } + case "embd_id": + currentEmbdID := strings.TrimSpace(currentMemory.EmbdID) + requestEmbdID := strings.TrimSpace(fmt.Sprint(value)) + if currentEmbdID != requestEmbdID { + filteredUpdateDict[field] = value + } + case "tenant_llm_id": + if currentMemory.TenantLLMID == nil || *currentMemory.TenantLLMID != value.(int64) { + filteredUpdateDict[field] = value + } + case "tenant_embd_id": + if currentMemory.TenantEmbdID == nil || *currentMemory.TenantEmbdID != value.(int64) { + filteredUpdateDict[field] = value + } + case "memory_type": + if types, ok := value.([]string); ok { + if !sameStringSet(normalizedCurrentMemoryTypes, normalizeMemoryTypes(types)) { + filteredUpdateDict[field] = value + } + } else { + filteredUpdateDict[field] = value + } + case "memory_size": + if currentMemory.MemorySize != value.(int64) { + filteredUpdateDict[field] = value + } + case "forgetting_policy": + currentForgettingPolicy := strings.ToUpper(strings.TrimSpace(currentMemory.ForgettingPolicy)) + requestForgettingPolicy := strings.ToUpper(strings.TrimSpace(fmt.Sprint(value))) + if currentForgettingPolicy != requestForgettingPolicy { + filteredUpdateDict[field] = value + } + case "temperature": + if currentMemory.Temperature != value.(float64) { + filteredUpdateDict[field] = value + } + case "avatar": + currentAvatar := "" + if currentMemory.Avatar != nil { + currentAvatar = *currentMemory.Avatar + } + if currentAvatar != fmt.Sprint(value) { + filteredUpdateDict[field] = value + } + case "description": + currentDescription := "" + if currentMemory.Description != nil { + currentDescription = *currentMemory.Description + } + if currentDescription != fmt.Sprint(value) { + filteredUpdateDict[field] = value + } + case "system_prompt": + currentSystemPrompt := "" + if currentMemory.SystemPrompt != nil { + currentSystemPrompt = *currentMemory.SystemPrompt + } + if currentSystemPrompt != fmt.Sprint(value) { + filteredUpdateDict[field] = value + } + case "user_prompt": + currentUserPrompt := "" + if currentMemory.UserPrompt != nil { + currentUserPrompt = *currentMemory.UserPrompt + } + if currentUserPrompt != fmt.Sprint(value) { + filteredUpdateDict[field] = value + } + default: + filteredUpdateDict[field] = value + } + } + updateDict = filteredUpdateDict + + if len(updateDict) == 0 { + return formatRetDataFromMemory(currentMemory), nil + } + memorySize := currentMemory.MemorySize notAllowedUpdate := []string{} for _, f := range []string{"tenant_embd_id", "embd_id", "memory_type"} { @@ -586,6 +683,45 @@ func (s *MemoryService) UpdateMemory(tenantID string, memoryID string, req *Upda return formatRetDataFromMemory(updatedMemory), nil } +func normalizeMemoryTypes(memoryTypes []string) []string { + normalized := make([]string, 0, len(memoryTypes)) + seen := make(map[string]struct{}, len(memoryTypes)) + for _, mt := range memoryTypes { + mt = strings.ToLower(strings.TrimSpace(mt)) + if mt == "" { + continue + } + if _, exists := seen[mt]; exists { + continue + } + seen[mt] = struct{}{} + normalized = append(normalized, mt) + } + return normalized +} + +func sameStringSet(a, b []string) bool { + if len(a) != len(b) { + return false + } + counts := make(map[string]int, len(a)) + for _, item := range a { + counts[item]++ + } + for _, item := range b { + counts[item]-- + if counts[item] < 0 { + return false + } + } + for _, count := range counts { + if count != 0 { + return false + } + } + return true +} + // DeleteMemory deletes a memory by ID // It also deletes associated message indexes before removing the memory record // diff --git a/internal/service/model_service.go b/internal/service/model_service.go index 1a107d4231e..9813010a282 100644 --- a/internal/service/model_service.go +++ b/internal/service/model_service.go @@ -25,7 +25,6 @@ import ( "ragflow/internal/entity" modelModule "ragflow/internal/entity/models" "strings" - "time" ) // parseModelName parses a composite model name in format "model@instance@provider" or "model@provider" @@ -45,6 +44,25 @@ func parseModelName(compositeName string) (modelName, instanceName, providerName } } +func newModelDriverForBaseURL(driver modelModule.ModelDriver, providerName, region, baseURL string) (modelModule.ModelDriver, error) { + if driver == nil { + return nil, fmt.Errorf("provider %s driver not found", providerName) + } + + if strings.TrimSpace(baseURL) == "" { + return driver, nil + } + + newDriver := driver.NewInstance(map[string]string{ + region: baseURL, + }) + if newDriver == nil { + return nil, fmt.Errorf("provider %s does not support custom base_url", providerName) + } + + return newDriver, nil +} + func NewModelProviderService() *ModelProviderService { return &ModelProviderService{ modelProviderDAO: dao.NewTenantModelProviderDAO(), @@ -88,20 +106,14 @@ func (m *ModelProviderService) AddModelProvider(providerName, userID string) (co return common.CodeServerError, errors.New("fail to get UUID") } - now := time.Now().Unix() - nowDate := time.Now().Truncate(time.Second) tenantModelProvider := &entity.TenantModelProvider{ ID: providerID, ProviderName: providerName, TenantID: tenantID, } - tenantModelProvider.CreateTime = &now - tenantModelProvider.UpdateTime = &now - tenantModelProvider.CreateDate = &nowDate - tenantModelProvider.UpdateDate = &nowDate err = m.modelProviderDAO.Create(tenantModelProvider) if err != nil { - return common.CodeServerError, errors.New("fail to create model provider") + return common.CodeServerError, fmt.Errorf("fail to create model provider: %s", err.Error()) } return common.CodeSuccess, nil } @@ -203,11 +215,10 @@ func (m *ModelProviderService) ListSupportedModels(providerName, instanceName, u // For local deployed models if baseURL, ok := extra["base_url"]; ok && baseURL != "" { - newURL := map[string]string{ - region: baseURL, + driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL) + if err != nil { + return nil, err } - - driver = driver.NewInstance(newURL) } return driver.ListModels(apiConfig) @@ -247,8 +258,6 @@ func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName } extraStr := string(extraByte) - now := time.Now().Unix() - nowDate := time.Now().Truncate(time.Second) tenantModelProvider := &entity.TenantModelInstance{ ID: instanceID, InstanceName: instanceName, @@ -257,10 +266,6 @@ func (m *ModelProviderService) CreateProviderInstance(providerName, instanceName Status: "enable", Extra: extraStr, } - tenantModelProvider.CreateTime = &now - tenantModelProvider.UpdateTime = &now - tenantModelProvider.CreateDate = &nowDate - tenantModelProvider.UpdateDate = &nowDate err = m.modelInstanceDAO.Create(tenantModelProvider) if err != nil { @@ -460,10 +465,10 @@ func (m *ModelProviderService) CheckProviderConnection(providerName, instanceNam driver := providerInfo.ModelDriver if baseURL, ok := extra["base_url"]; ok && baseURL != "" { - newURL := map[string]string{ - region: baseURL, + driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL) + if err != nil { + return common.CodeServerError, err } - driver = driver.NewInstance(newURL) } err = driver.CheckConnection(apiConfig) @@ -473,6 +478,128 @@ func (m *ModelProviderService) CheckProviderConnection(providerName, instanceNam return common.CodeSuccess, nil } +func (m *ModelProviderService) ListTasks(providerName, instanceName, userID string) ([]modelModule.ListTaskStatus, common.ErrorCode, error) { + + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return nil, common.CodeServerError, err + } + + if len(tenants) == 0 { + return nil, common.CodeNotFound, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return nil, common.CodeServerError, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return nil, common.CodeServerError, err + } + + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeServerError, fmt.Errorf("provider %s not found", providerName) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + apiConfig := &modelModule.APIConfig{ + ApiKey: nil, + Region: nil, + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + driver := providerInfo.ModelDriver + if baseURL, ok := extra["base_url"]; ok && baseURL != "" { + driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL) + if err != nil { + return nil, common.CodeServerError, err + } + } + + var listTaskResponse []modelModule.ListTaskStatus + listTaskResponse, err = driver.ListTasks(apiConfig) + if err != nil { + return nil, common.CodeServerError, err + } + return listTaskResponse, common.CodeSuccess, nil +} + +func (m *ModelProviderService) ShowTask(providerName, instanceName, taskID, userID string) (*modelModule.TaskResponse, common.ErrorCode, error) { + + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return nil, common.CodeServerError, err + } + + if len(tenants) == 0 { + return nil, common.CodeNotFound, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return nil, common.CodeServerError, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return nil, common.CodeServerError, err + } + + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeServerError, fmt.Errorf("provider %s not found", providerName) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + apiConfig := &modelModule.APIConfig{ + ApiKey: nil, + Region: nil, + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + driver := providerInfo.ModelDriver + if baseURL, ok := extra["base_url"]; ok && baseURL != "" { + driver, err = newModelDriverForBaseURL(driver, providerName, region, baseURL) + if err != nil { + return nil, common.CodeServerError, err + } + } + + var taskResponse *modelModule.TaskResponse + taskResponse, err = driver.ShowTask(taskID, apiConfig) + if err != nil { + return nil, common.CodeServerError, err + } + return taskResponse, common.CodeSuccess, nil +} + func (m *ModelProviderService) AlterProviderInstance(providerName, instanceName, newInstanceName, apiKey, userID string) (common.ErrorCode, error) { return common.CodeSuccess, nil } @@ -738,6 +865,10 @@ func (m *ModelProviderService) ChatToModelWithMessages(providerName, instanceNam return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) } + if !model.ModelTypeMap["chat"] && !model.ModelTypeMap["vision"] { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", modelName, providerName)) + } + modelConfig.ModelClass = model.Class var extra map[string]string @@ -763,6 +894,9 @@ func (m *ModelProviderService) ChatToModelWithMessages(providerName, instanceNam } if modelInfo.Status == "active" { + if modelInfo.ModelType != "chat" && modelInfo.ModelType != "vision" { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", modelName, providerName)) + } // For local deployed models providerInfo := dao.GetModelProviderManager().FindProvider(providerName) if providerInfo == nil { @@ -781,10 +915,10 @@ func (m *ModelProviderService) ChatToModelWithMessages(providerName, instanceNam modelConfig.ModelClass = &providerInfo.Class - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response *modelModule.ChatResponse response, err = newProviderInfo.ChatWithMessages(modelName, messages, apiConfig, modelConfig) @@ -833,11 +967,16 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc return common.CodeNotFound, err } - _, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + var model *entity.Model = nil + model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) if err != nil { return common.CodeNotFound, err } + if !model.ModelTypeMap["chat"] && !model.ModelTypeMap["vision"] { + return common.CodeNotFound, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", modelName, providerName)) + } + var extra map[string]string err = json.Unmarshal([]byte(instance.Extra), &extra) if err != nil { @@ -857,6 +996,9 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc } if modelInfo.Status == "active" { + if modelInfo.ModelType != "chat" && modelInfo.ModelType != "vision" { + return common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is a chat or multimodal model", modelName, providerName)) + } // For local deployed models providerInfo := dao.GetModelProviderManager().FindProvider(providerName) if providerInfo == nil { @@ -875,10 +1017,10 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc modelConfig.ModelClass = &providerInfo.Class - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) err = newProviderInfo.ChatStreamlyWithSender(modelName, messages, apiConfig, modelConfig, sender) if err != nil { @@ -891,7 +1033,7 @@ func (m *ModelProviderService) ChatToModelStreamWithSender(providerName, instanc } // EmbedText sends texts to the embedding model -func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, userID string, texts []string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.EmbeddingConfig) (*modelModule.EmbeddingResponse, common.ErrorCode, error) { +func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, userID string, texts []string, apiConfig *modelModule.APIConfig, modelConfig *modelModule.EmbeddingConfig) ([]modelModule.EmbeddingData, common.ErrorCode, error) { if apiConfig == nil { apiConfig = &modelModule.APIConfig{} } @@ -949,30 +1091,22 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - var embeddingList [][]float64 - embeddingList, err = providerInfo.ModelDriver.Encode(&modelName, texts, apiConfig, modelConfig) + var response []modelModule.EmbeddingData + response, err = providerInfo.ModelDriver.Embed(&modelName, texts, apiConfig, modelConfig) if err != nil { return nil, common.CodeServerError, err } - if embeddingList == nil { + if response == nil || len(response) == 0 { return nil, common.CodeServerError, errors.New("empty embed response") } - response := &modelModule.EmbeddingResponse{ - Data: make([]modelModule.EmbeddingResult, len(embeddingList)), - } - for i, embedding := range embeddingList { - response.Data[i] = modelModule.EmbeddingResult{ - Index: i, - Dimension: len(embedding), - //Embedding: embedding, - } - } - return response, common.CodeSuccess, nil } if modelInfo.Status == "active" { + if modelInfo.ModelType != "embedding" { + return nil, common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is an embedding model", modelName, providerName)) + } // For local deployed models providerInfo := dao.GetModelProviderManager().FindProvider(providerName) if providerInfo == nil { @@ -989,31 +1123,20 @@ func (m *ModelProviderService) EmbedText(providerName, instanceName, modelName, apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) - var embeddingList [][]float64 - embeddingList, err = newProviderInfo.Encode(&modelName, texts, apiConfig, modelConfig) + var response []modelModule.EmbeddingData + response, err = newProviderInfo.Embed(&modelName, texts, apiConfig, modelConfig) if err != nil { return nil, common.CodeServerError, err } - if embeddingList == nil { + if response == nil || len(response) == 0 { return nil, common.CodeServerError, errors.New("empty embed response") } - response := &modelModule.EmbeddingResponse{ - Data: make([]modelModule.EmbeddingResult, len(embeddingList)), - } - for i, embedding := range embeddingList { - response.Data[i] = modelModule.EmbeddingResult{ - Index: i, - Dimension: len(embedding), - //Embedding: embedding, - } - } - return response, common.CodeSuccess, nil } @@ -1066,7 +1189,7 @@ func (m *ModelProviderService) RerankDocument(providerName, instanceName, modelN } if !model.ModelTypeMap["rerank"] { - return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not an embedding model", providerName, modelName)) + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not a rerank model", providerName, modelName)) } var extra map[string]string @@ -1089,6 +1212,9 @@ func (m *ModelProviderService) RerankDocument(providerName, instanceName, modelN } if modelInfo.Status == "active" { + if modelInfo.ModelType != "rerank" { + return nil, common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is a rerank model", modelName, providerName)) + } // For local deployed models providerInfo := dao.GetModelProviderManager().FindProvider(providerName) if providerInfo == nil { @@ -1105,10 +1231,10 @@ func (m *ModelProviderService) RerankDocument(providerName, instanceName, modelN apiConfig.Region = ®ion apiConfig.ApiKey = &instance.APIKey - newURL := map[string]string{ - region: extra["base_url"], + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err } - newProviderInfo := providerInfo.ModelDriver.NewInstance(newURL) var response *modelModule.RerankResponse response, err = newProviderInfo.Rerank(&modelName, query, documents, apiConfig, modelConfig) @@ -1122,6 +1248,636 @@ func (m *ModelProviderService) RerankDocument(providerName, instanceName, modelN return nil, common.CodeServerError, errors.New("model is disabled") } +// TranscribeAudio transcribe audio file to text +func (m *ModelProviderService) TranscribeAudio(providerName, instanceName, modelName, userID string, audioFile *string, apiConfig *modelModule.APIConfig, asrConfig *modelModule.ASRConfig) (*modelModule.ASRResponse, common.ErrorCode, error) { + if apiConfig == nil { + apiConfig = &modelModule.APIConfig{} + } + if asrConfig == nil { + asrConfig = &modelModule.ASRConfig{} + } + + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return nil, common.CodeServerError, err + } + + if len(tenants) == 0 { + return nil, common.CodeNotFound, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return nil, common.CodeServerError, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return nil, common.CodeServerError, err + } + + modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeNotFound, errors.New("provider not found") + } + + var model *entity.Model = nil + model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) + } + + if !model.ModelTypeMap["asr"] { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not an ASR model", providerName, modelName)) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + var response *modelModule.ASRResponse + response, err = providerInfo.ModelDriver.TranscribeAudio(&modelName, audioFile, apiConfig, asrConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil + } + + if modelInfo.Status == "active" { + if modelInfo.ModelType != "asr" { + return nil, common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is an ASR model", modelName, providerName)) + } + // For local deployed models + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeNotFound, errors.New("provider not found") + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err + } + + var response *modelModule.ASRResponse + response, err = newProviderInfo.TranscribeAudio(&modelName, audioFile, apiConfig, asrConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil + } + + return nil, common.CodeServerError, errors.New("model is disabled") +} + +// ChatToModelStreamWithSender streams chat response directly via sender function (best performance, no channel) +func (m *ModelProviderService) TranscribeAudioStream(providerName, instanceName, modelName, userID string, audioFile *string, apiConfig *modelModule.APIConfig, asrConfig *modelModule.ASRConfig, sender func(*string, *string) error) (common.ErrorCode, error) { + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return common.CodeServerError, err + } + + if len(tenants) == 0 { + return common.CodeNotFound, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return common.CodeServerError, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return common.CodeServerError, err + } + + modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return common.CodeNotFound, err + } + + var model *entity.Model = nil + model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return common.CodeNotFound, err + } + if !model.ModelTypeMap["asr"] { + return common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not an ASR model", providerName, modelName)) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + err = providerInfo.ModelDriver.TranscribeAudioWithSender(&modelName, audioFile, apiConfig, asrConfig, sender) + if err != nil { + return common.CodeServerError, err + } + + return common.CodeSuccess, nil + } + + if modelInfo.Status == "active" { + if modelInfo.ModelType != "asr" { + return common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is an ASR model", modelName, providerName)) + } + // For local deployed models + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return common.CodeNotFound, errors.New("provider not found") + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return common.CodeServerError, err + } + + err = newProviderInfo.TranscribeAudioWithSender(&modelName, audioFile, apiConfig, asrConfig, sender) + if err != nil { + return common.CodeServerError, err + } + return common.CodeSuccess, nil + } + + return common.CodeServerError, errors.New("model is disabled") +} + +// TranscribeAudio transcribe audio file to text +func (m *ModelProviderService) AudioSpeech(providerName, instanceName, modelName, userID string, audioContent *string, apiConfig *modelModule.APIConfig, ttsConfig *modelModule.TTSConfig) (*modelModule.TTSResponse, common.ErrorCode, error) { + if apiConfig == nil { + apiConfig = &modelModule.APIConfig{} + } + if ttsConfig == nil { + ttsConfig = &modelModule.TTSConfig{} + } + + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return nil, common.CodeServerError, err + } + + if len(tenants) == 0 { + return nil, common.CodeNotFound, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return nil, common.CodeServerError, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return nil, common.CodeServerError, err + } + + modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeNotFound, errors.New("provider not found") + } + + var model *entity.Model = nil + model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) + } + + if !model.ModelTypeMap["tts"] { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not a TTS model", providerName, modelName)) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + var response *modelModule.TTSResponse + response, err = providerInfo.ModelDriver.AudioSpeech(&modelName, audioContent, apiConfig, ttsConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil + } + + if modelInfo.Status == "active" { + if modelInfo.ModelType != "tts" { + return nil, common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is a TTS model", modelName, providerName)) + } + // For local deployed models + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeNotFound, errors.New("provider not found") + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err + } + + var response *modelModule.TTSResponse + response, err = newProviderInfo.AudioSpeech(&modelName, audioContent, apiConfig, ttsConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil + } + + return nil, common.CodeServerError, errors.New("model is disabled") +} + +func (m *ModelProviderService) AudioSpeechStream(providerName, instanceName, modelName, userID string, audioContent *string, apiConfig *modelModule.APIConfig, ttsConfig *modelModule.TTSConfig, sender func(*string, *string) error) (common.ErrorCode, error) { + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return common.CodeServerError, err + } + + if len(tenants) == 0 { + return common.CodeNotFound, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return common.CodeServerError, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return common.CodeServerError, err + } + + modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return common.CodeNotFound, err + } + + var model *entity.Model = nil + model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return common.CodeNotFound, err + } + + if !model.ModelTypeMap["tts"] { + return common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not a TTS model", providerName, modelName)) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + err = providerInfo.ModelDriver.AudioSpeechWithSender(&modelName, audioContent, apiConfig, ttsConfig, sender) + if err != nil { + return common.CodeServerError, err + } + + return common.CodeSuccess, nil + } + + if modelInfo.Status == "active" { + if modelInfo.ModelType != "tts" { + return common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is a TTS model", modelName, providerName)) + } + // For local deployed models + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return common.CodeNotFound, errors.New("provider not found") + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return common.CodeServerError, err + } + + err = newProviderInfo.AudioSpeechWithSender(&modelName, audioContent, apiConfig, ttsConfig, sender) + if err != nil { + return common.CodeServerError, err + } + return common.CodeSuccess, nil + } + + return common.CodeServerError, errors.New("model is disabled") +} + +func (m *ModelProviderService) OCRFile(providerName, instanceName, modelName, userID string, content []byte, url *string, apiConfig *modelModule.APIConfig, ocrConfig *modelModule.OCRConfig) (*modelModule.OCRFileResponse, common.ErrorCode, error) { + if apiConfig == nil { + apiConfig = &modelModule.APIConfig{} + } + if ocrConfig == nil { + ocrConfig = &modelModule.OCRConfig{} + } + + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return nil, common.CodeServerError, err + } + + if len(tenants) == 0 { + return nil, common.CodeNotFound, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return nil, common.CodeServerError, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return nil, common.CodeServerError, err + } + + modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeNotFound, errors.New("provider not found") + } + + var model *entity.Model = nil + model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) + } + + if !model.ModelTypeMap["ocr"] { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not an OCR model", providerName, modelName)) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + var response *modelModule.OCRFileResponse + response, err = providerInfo.ModelDriver.OCRFile(&modelName, content, url, apiConfig, ocrConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil + } + + if modelInfo.Status == "active" { + if modelInfo.ModelType != "ocr" { + return nil, common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is an OCR model", modelName, providerName)) + } + // For local deployed models + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeNotFound, errors.New("provider not found") + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err + } + + var response *modelModule.OCRFileResponse + response, err = newProviderInfo.OCRFile(&modelName, content, url, apiConfig, ocrConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil + } + + return nil, common.CodeServerError, errors.New("model is disabled") +} + +func (m *ModelProviderService) ParseFile(providerName, instanceName, modelName, userID string, content []byte, url *string, apiConfig *modelModule.APIConfig, parseFileConfig *modelModule.ParseFileConfig) (*modelModule.ParseFileResponse, common.ErrorCode, error) { + if apiConfig == nil { + apiConfig = &modelModule.APIConfig{} + } + if parseFileConfig == nil { + parseFileConfig = &modelModule.ParseFileConfig{} + } + + // Get tenant ID from user + tenants, err := m.userTenantDAO.GetByUserIDAndRole(userID, "owner") + if err != nil { + return nil, common.CodeServerError, err + } + + if len(tenants) == 0 { + return nil, common.CodeNotFound, errors.New("user has no tenants") + } + + tenantID := tenants[0].TenantID + + // Check if provider exists + provider, err := m.modelProviderDAO.GetByTenantIDAndProviderName(tenantID, providerName) + if err != nil { + return nil, common.CodeServerError, err + } + + instance, err := m.modelInstanceDAO.GetByProviderIDAndInstanceName(provider.ID, instanceName) + if err != nil { + return nil, common.CodeServerError, err + } + + modelInfo, err := m.modelDAO.GetModelByProviderIDAndInstanceIDAndModelName(provider.ID, instance.ID, modelName) + if err != nil { + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeNotFound, errors.New("provider not found") + } + + var model *entity.Model = nil + model, err = dao.GetModelProviderManager().GetModelByName(providerName, modelName) + if err != nil { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s not found", providerName, modelName)) + } + + if !model.ModelTypeMap["doc_parse"] { + return nil, common.CodeNotFound, errors.New(fmt.Sprintf("provider %s model %s is not a Document Parse model", providerName, modelName)) + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + var response *modelModule.ParseFileResponse + response, err = providerInfo.ModelDriver.ParseFile(&modelName, content, url, apiConfig, parseFileConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil + } + + if modelInfo.Status == "active" { + if modelInfo.ModelType != "doc_parse" { + return nil, common.CodeServerError, errors.New(fmt.Sprintf("expect model %s@%s is a Document Parse model", modelName, providerName)) + } + // For local deployed models + providerInfo := dao.GetModelProviderManager().FindProvider(providerName) + if providerInfo == nil { + return nil, common.CodeNotFound, errors.New("provider not found") + } + + var extra map[string]string + err = json.Unmarshal([]byte(instance.Extra), &extra) + if err != nil { + return nil, common.CodeServerError, err + } + + region := extra["region"] + apiConfig.Region = ®ion + apiConfig.ApiKey = &instance.APIKey + + newProviderInfo, err := newModelDriverForBaseURL(providerInfo.ModelDriver, providerName, region, extra["base_url"]) + if err != nil { + return nil, common.CodeServerError, err + } + + var response *modelModule.ParseFileResponse + response, err = newProviderInfo.ParseFile(&modelName, content, url, apiConfig, parseFileConfig) + if err != nil { + return nil, common.CodeServerError, err + } + if response == nil { + return nil, common.CodeServerError, errors.New("empty chat response") + } + + return response, common.CodeSuccess, nil + } + + return nil, common.CodeServerError, errors.New("model is disabled") +} + // GetEmbeddingModel returns an EmbeddingModel wrapper for the given tenant func (m *ModelProviderService) GetEmbeddingModel(tenantID, compositeModelName string) (*modelModule.EmbeddingModel, error) { driver, modelName, apiConfig, maxTokens, err := m.getModelConfig(tenantID, compositeModelName) diff --git a/internal/service/model_service_test.go b/internal/service/model_service_test.go new file mode 100644 index 00000000000..7b6e138c4ce --- /dev/null +++ b/internal/service/model_service_test.go @@ -0,0 +1,129 @@ +package service + +import ( + "strings" + "testing" + + modelModule "ragflow/internal/entity/models" +) + +type stubModelDriver struct { + modelModule.ModelDriver + newInstance func(map[string]string) modelModule.ModelDriver +} + +var _ modelModule.ModelDriver = (*stubModelDriver)(nil) + +func (s *stubModelDriver) NewInstance(baseURL map[string]string) modelModule.ModelDriver { + if s.newInstance != nil { + return s.newInstance(baseURL) + } + return s +} + +func (s *stubModelDriver) Name() string { + return "stub" +} + +func TestNewModelDriverForBaseURLPreservesEmptyRegion(t *testing.T) { + expected := &stubModelDriver{} + var gotBaseURL map[string]string + driver := &stubModelDriver{ + newInstance: func(baseURL map[string]string) modelModule.ModelDriver { + gotBaseURL = baseURL + return expected + }, + } + + got, err := newModelDriverForBaseURL(driver, "stub", "", "http://localhost:1234") + if err != nil { + t.Fatalf("newModelDriverForBaseURL returned error: %v", err) + } + if got != expected { + t.Fatalf("expected returned driver %p, got %p", expected, got) + } + if gotBaseURL[""] != "http://localhost:1234" { + t.Fatalf("expected empty-region base URL, got %#v", gotBaseURL) + } + if _, ok := gotBaseURL["default"]; ok { + t.Fatalf("unexpected default region key in base URL map: %#v", gotBaseURL) + } +} + +func TestNewModelDriverForBaseURLUsesProvidedRegion(t *testing.T) { + var gotBaseURL map[string]string + driver := &stubModelDriver{ + newInstance: func(baseURL map[string]string) modelModule.ModelDriver { + gotBaseURL = baseURL + return &stubModelDriver{} + }, + } + + _, err := newModelDriverForBaseURL(driver, "stub", "cn-hangzhou", "http://localhost:5678") + if err != nil { + t.Fatalf("newModelDriverForBaseURL returned error: %v", err) + } + if gotBaseURL["cn-hangzhou"] != "http://localhost:5678" { + t.Fatalf("expected regional base URL, got %#v", gotBaseURL) + } + if _, ok := gotBaseURL["default"]; ok { + t.Fatalf("unexpected default region key in base URL map: %#v", gotBaseURL) + } +} + +func TestNewModelDriverForBaseURLSkipsEmptyBaseURL(t *testing.T) { + for _, baseURL := range []string{"", " "} { + t.Run(baseURL, func(t *testing.T) { + called := false + driver := &stubModelDriver{ + newInstance: func(map[string]string) modelModule.ModelDriver { + called = true + return nil + }, + } + + got, err := newModelDriverForBaseURL(driver, "deepseek", "default", baseURL) + if err != nil { + t.Fatalf("newModelDriverForBaseURL returned error: %v", err) + } + if got != driver { + t.Fatalf("expected original driver %p, got %p", driver, got) + } + if called { + t.Fatal("expected empty base URL to skip NewInstance") + } + }) + } +} + +func TestNewModelDriverForBaseURLRejectsNilInstance(t *testing.T) { + driver := &stubModelDriver{ + newInstance: func(map[string]string) modelModule.ModelDriver { + return nil + }, + } + + got, err := newModelDriverForBaseURL(driver, "deepseek", "default", "http://localhost:1234") + if err == nil { + t.Fatal("expected nil NewInstance result to return an error") + } + if got != nil { + t.Fatalf("expected nil driver on error, got %T", got) + } + if !strings.Contains(err.Error(), "deepseek") || !strings.Contains(err.Error(), "custom base_url") { + t.Fatalf("expected provider-specific custom base_url error, got %v", err) + } +} + +func TestNewModelDriverForBaseURLRejectsNilDriver(t *testing.T) { + got, err := newModelDriverForBaseURL(nil, "deepseek", "default", "http://localhost:1234") + if err == nil { + t.Fatal("expected nil driver to return an error") + } + if got != nil { + t.Fatalf("expected nil driver on error, got %T", got) + } + if !strings.Contains(err.Error(), "driver not found") { + t.Fatalf("expected driver not found error, got %v", err) + } +} diff --git a/internal/service/nlp/retrieval.go b/internal/service/nlp/retrieval.go index 27545711206..4cfd197f89c 100644 --- a/internal/service/nlp/retrieval.go +++ b/internal/service/nlp/retrieval.go @@ -607,12 +607,15 @@ func (s *RetrievalService) Search(ctx context.Context, req *RetrievalSearchReque // GetVector computes query vector and returns MatchDenseExpr for hybrid search func (s *RetrievalService) GetVector(txt string, embModel *models.EmbeddingModel, topk int, similarity float64) (*types.MatchDenseExpr, error) { - embeddings, err := embModel.ModelDriver.Encode(embModel.ModelName, []string{txt}, embModel.APIConfig, nil) + embeddingConfig := &models.EmbeddingConfig{ + Dimension: 0, + } + embeddings, err := embModel.ModelDriver.Embed(embModel.ModelName, []string{txt}, embModel.APIConfig, embeddingConfig) if err != nil { return nil, err } - vector := embeddings[0] + vector := embeddings[0].Embedding vectorSize := len(vector) vectorColumnName := fmt.Sprintf("q_%d_vec", vectorSize) diff --git a/internal/service/skill_indexer.go b/internal/service/skill_indexer.go index ec36a7948e7..10d76d667dd 100644 --- a/internal/service/skill_indexer.go +++ b/internal/service/skill_indexer.go @@ -25,6 +25,7 @@ import ( "ragflow/internal/dao" "ragflow/internal/engine" "ragflow/internal/entity" + "ragflow/internal/entity/models" "ragflow/internal/storage" "ragflow/internal/tokenizer" "strings" @@ -170,7 +171,7 @@ func (s *SkillIndexerService) IndexSkill(ctx context.Context, tenantID, spaceID // For Infinity: ensure table exists with correct dimension BEFORE inserting if docEngine.GetType() == "infinity" { - exists, _ := docEngine.TableExists(ctx, indexName) + exists, _ := docEngine.ChunkStoreExists(ctx, indexName, "skill") if !exists { common.Info(fmt.Sprintf("Creating Infinity table with dimension %d", dimension)) if err := s.createIndexWithDimension(ctx, tenantID, spaceID, docEngine, embdID, dimension); err != nil { @@ -237,7 +238,8 @@ func (s *SkillIndexerService) BatchIndexSkills(ctx context.Context, tenantID, sp // Generate embeddings in batch common.Info(fmt.Sprintf("Generating embeddings for %d skills with embdID=%s", len(skills), embdID)) - vectors, err := s.generateEmbeddings(ctx, vectorTexts, embdID, tenantID) + var vectors []models.EmbeddingData + vectors, err = s.generateEmbeddings(ctx, vectorTexts, embdID, tenantID) if err != nil { common.Warn(fmt.Sprintf("Failed to generate embeddings: %v. Continuing with text-only index.", err)) vectors = nil // Continue without vectors @@ -250,7 +252,7 @@ func (s *SkillIndexerService) BatchIndexSkills(ctx context.Context, tenantID, sp if docEngine.GetType() == "infinity" { // For Infinity: must ensure table exists with correct dimension BEFORE inserting common.Info(fmt.Sprintf("Checking if index exists: %s", indexName)) - exists, err := docEngine.TableExists(ctx, indexName) + exists, err := docEngine.ChunkStoreExists(ctx, indexName, "skill") if err != nil { common.Warn(fmt.Sprintf("Error checking index existence: %v", err)) } @@ -311,7 +313,7 @@ func (s *SkillIndexerService) BatchIndexSkills(ctx context.Context, tenantID, sp // Add vector only if available if vectors != nil && i < len(vectors) { - doc[vectorField] = vectors[i] + doc[vectorField] = vectors[i].Embedding } else { common.Info(fmt.Sprintf("No vector for skill %s, creating text-only index", skill.ID)) // For Infinity: use zero vector as placeholder (table schema requires vector column) @@ -434,10 +436,10 @@ func (s *SkillIndexerService) ReindexAll(ctx context.Context, tenantID, spaceID // Delete existing index and recreate with new dimension (for both ES and Infinity) indexName := getSkillIndexName(tenantID, spaceID) - exists, _ := docEngine.TableExists(ctx, indexName) + exists, _ := docEngine.ChunkStoreExists(ctx, indexName, "skill") if exists { common.Info(fmt.Sprintf("ReindexAll: deleting existing index %s", indexName)) - if err := docEngine.DropTable(ctx, indexName); err != nil { + if err := docEngine.DropChunkStore(ctx, indexName, "skill"); err != nil { common.Warn(fmt.Sprintf("ReindexAll: failed to delete existing index: %v", err)) } } @@ -844,7 +846,7 @@ func (s *SkillIndexerService) InitializeIndex(ctx context.Context, tenantID, spa common.Info("Checking skill index existence", zap.String("indexName", indexName), zap.String("tenantID", tenantID), zap.String("spaceID", spaceID)) - exists, err := docEngine.TableExists(ctx, indexName) + exists, err := docEngine.ChunkStoreExists(ctx, indexName, "skill") if err != nil { common.Error("Failed to check index existence", err) return fmt.Errorf("failed to check index existence: %w", err) @@ -881,22 +883,22 @@ func (s *SkillIndexerService) createIndexWithDimension(ctx context.Context, tena // For Infinity: check if table exists and needs recreation (dimension mismatch) if docEngine.GetType() == "infinity" { - exists, err := docEngine.TableExists(ctx, indexName) + exists, err := docEngine.ChunkStoreExists(ctx, indexName, "skill") if err != nil { common.Warn(fmt.Sprintf("Error checking if index exists: %v", err)) } if exists { common.Info(fmt.Sprintf("Index exists, deleting for recreation with dimension %d", dimension), zap.String("indexName", indexName)) - if err := docEngine.DropTable(ctx, indexName); err != nil { + if err := docEngine.DropChunkStore(ctx, indexName, "skill"); err != nil { common.Warn(fmt.Sprintf("Failed to delete existing index: %v", err)) } } } - // Use the doc engine's CreateDataset method with skill-specific mapping + // Use the doc engine's CreateChunkStore method with skill-specific mapping // The mapping file is loaded from conf/skill_es_mapping.json or conf/skill_infinity_mapping.json - err := docEngine.CreateDataset(ctx, indexName, "skill", dimension, "") + err := docEngine.CreateChunkStore(ctx, indexName, "skill", dimension, "") if err != nil { common.Error("Failed to create skill index", err) return err @@ -932,20 +934,21 @@ func (s *SkillIndexerService) generateEmbedding(ctx context.Context, text, embdI } truncatedText := truncate(text, maxLen-10) - vectors, err := embeddingModel.ModelDriver.Encode(embeddingModel.ModelName, []string{truncatedText}, embeddingModel.APIConfig, nil) + var response []models.EmbeddingData + response, err = embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, []string{truncatedText}, embeddingModel.APIConfig, nil) if err != nil { return nil, fmt.Errorf("failed to encode text: %w", err) } - if len(vectors) == 0 { + if len(response) == 0 { return nil, fmt.Errorf("embedding returned empty result") } - return vectors[0], nil + return response[0].Embedding, nil } // generateEmbeddings generates embeddings for multiple texts in batch // This is more efficient than calling generateEmbedding individually -func (s *SkillIndexerService) generateEmbeddings(ctx context.Context, texts []string, embdID, tenantID string) ([][]float64, error) { +func (s *SkillIndexerService) generateEmbeddings(ctx context.Context, texts []string, embdID, tenantID string) ([]models.EmbeddingData, error) { common.Info(fmt.Sprintf("generateEmbeddings called: texts=%d, embdID=%s, tenantID=%s", len(texts), embdID, tenantID)) if s.modelProvider == nil { @@ -975,18 +978,19 @@ func (s *SkillIndexerService) generateEmbeddings(ctx context.Context, texts []st common.Info(fmt.Sprintf("Encoding %d texts", len(truncatedTexts))) // Use batch encode API (consistent with Python's encode(texts: list)) - vectors, err := embeddingModel.ModelDriver.Encode(embeddingModel.ModelName, truncatedTexts, embeddingModel.APIConfig, nil) + var response []models.EmbeddingData + response, err = embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, truncatedTexts, embeddingModel.APIConfig, nil) if err != nil { common.Error(fmt.Sprintf("Failed to encode texts: %v", err), err) return nil, fmt.Errorf("failed to encode texts: %w", err) } - common.Info(fmt.Sprintf("Encoded successfully, got %d vectors", len(vectors))) - if len(vectors) > 0 { - common.Info(fmt.Sprintf("Vector dimension: %d", len(vectors[0]))) + common.Info(fmt.Sprintf("Encoded successfully, got %d vectors", len(response))) + if len(response) > 0 { + common.Info(fmt.Sprintf("Vector dimension: %d", len(response[0].Embedding))) } - return vectors, nil + return response, nil } // truncate truncates text to maxLen characters @@ -1021,16 +1025,17 @@ func (s *SkillIndexerService) getEmbeddingDimension(ctx context.Context, tenantI // Use simple test text like Python does: embedding_model.encode(["ok"]) testText := "ok" - vectors, err := embeddingModel.ModelDriver.Encode(embeddingModel.ModelName, []string{testText}, embeddingModel.APIConfig, nil) + var response []models.EmbeddingData + response, err = embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, []string{testText}, embeddingModel.APIConfig, nil) if err != nil { return 0, fmt.Errorf("failed to encode test text: %w", err) } - if len(vectors) == 0 || len(vectors[0]) == 0 { + if len(response) == 0 || len(response[0].Embedding) == 0 { return 0, fmt.Errorf("embedding returned empty vector") } - dimension := len(vectors[0]) + dimension := len(response[0].Embedding) common.Info(fmt.Sprintf("Got embedding dimension from API: %d", dimension)) return dimension, nil } diff --git a/internal/service/skill_search.go b/internal/service/skill_search.go index c48d0f1314a..c6c86097a86 100644 --- a/internal/service/skill_search.go +++ b/internal/service/skill_search.go @@ -27,6 +27,7 @@ import ( "ragflow/internal/engine" "ragflow/internal/engine/types" "ragflow/internal/entity" + "ragflow/internal/entity/models" "ragflow/internal/utility" "strings" @@ -225,7 +226,7 @@ func (s *SkillSearchService) Search(ctx context.Context, req *SearchRequest, doc indexName := getSkillIndexName(req.TenantID, req.SpaceID) common.Debug("Searching skills", zap.String("indexName", indexName), zap.String("query", req.Query)) - indexExists, err := docEngine.TableExists(ctx, indexName) + indexExists, err := docEngine.ChunkStoreExists(ctx, indexName, "skill") if err != nil { common.Error("Failed to check index existence", err) return nil, common.CodeOperatingError, fmt.Errorf("failed to check index existence: %w", err) @@ -679,15 +680,16 @@ func (s *SkillSearchService) getEmbedding(ctx context.Context, text, embdID, ten } truncatedText := truncate(text, maxLen-10) - vectors, err := embeddingModel.ModelDriver.Encode(embeddingModel.ModelName, []string{truncatedText}, embeddingModel.APIConfig, nil) + var response []models.EmbeddingData + response, err = embeddingModel.ModelDriver.Embed(embeddingModel.ModelName, []string{truncatedText}, embeddingModel.APIConfig, nil) if err != nil { return nil, fmt.Errorf("failed to encode query: %w", err) } - if len(vectors) == 0 { + if len(response) == 0 { return nil, fmt.Errorf("embedding returned empty result") } - return vectors[0], nil + return response[0].Embedding, nil } // Helper functions diff --git a/internal/service/skill_space.go b/internal/service/skill_space.go index e40907fec4e..d325dd90028 100644 --- a/internal/service/skill_space.go +++ b/internal/service/skill_space.go @@ -126,8 +126,6 @@ func (s *SkillSpaceService) getSkillsFolderID(tenantID string) (string, error) { // Skills folder not found, create it common.Info("Creating skills folder", zap.String("tenant_id", tenantID)) folderID := generateSpaceID() - now := time.Now() - createTime := now.UnixMilli() folder := &entity.File{ ID: folderID, ParentID: rootFolder.ID, @@ -137,12 +135,6 @@ func (s *SkillSpaceService) getSkillsFolderID(tenantID string) (string, error) { Type: "folder", Size: 0, SourceType: "system", - BaseModel: entity.BaseModel{ - CreateTime: &createTime, - UpdateTime: &createTime, - CreateDate: &now, - UpdateDate: &now, - }, } if err := s.fileDAO.Create(folder); err != nil { @@ -218,8 +210,6 @@ func (s *SkillSpaceService) CreateSpace(req *CreateSpaceRequest) (map[string]int // Generate space ID and folder ID spaceID := generateSpaceID() folderID := generateSpaceID() - timestamp := time.Now().UnixMilli() - now := time.Now() // Create folder for the space under skills folder folder := &entity.File{ @@ -249,8 +239,6 @@ func (s *SkillSpaceService) CreateSpace(req *CreateSpaceRequest) (map[string]int RerankID: req.RerankID, TopK: 10, Status: "1", - CreateTime: ×tamp, - UpdateTime: &now, } if err := s.spaceDAO.Create(space); err != nil { @@ -535,7 +523,7 @@ func (s *SkillSpaceService) asyncDeleteSpace(spaceID, folderID, tenantID string, indexName := getSkillIndexName(tenantID, spaceID) common.Info("Async deleting space index", zap.String("index", indexName), zap.String("spaceID", spaceID)) deleteCtx, cancel := context.WithTimeout(context.Background(), 60*time.Second) - if err := docEngine.DropTable(deleteCtx, indexName); err != nil { + if err := docEngine.DropChunkStore(deleteCtx, indexName, "skill"); err != nil { common.Warn("Failed to delete space index during async delete", zap.String("index", indexName), zap.Error(err)) // Continue with other cleanup steps } else { diff --git a/internal/service/tenant.go b/internal/service/tenant.go index 54606f58eb8..08eee5d89aa 100644 --- a/internal/service/tenant.go +++ b/internal/service/tenant.go @@ -267,11 +267,8 @@ func (s *TenantService) GetTenantList(userID string) ([]*TenantListItem, error) // CreateMetadataInDocEngine creates the document metadata table for a tenant func (s *TenantService) CreateMetadataInDocEngine(tenantID string) (common.ErrorCode, error) { - // Build table name: ragflow_doc_meta_ - tableName := fmt.Sprintf("ragflow_doc_meta_%s", tenantID) - // Call document engine to create doc meta table - err := s.docEngine.CreateMetadata(context.Background(), tableName) + err := s.docEngine.CreateMetadataStore(context.Background(), tenantID) if err != nil { return common.CodeServerError, fmt.Errorf("failed to create metadata table: %w", err) } @@ -281,11 +278,8 @@ func (s *TenantService) CreateMetadataInDocEngine(tenantID string) (common.Error // DeleteMetadataInDocEngine deletes the document metadata table for a tenant func (s *TenantService) DeleteMetadataInDocEngine(tenantID string) (common.ErrorCode, error) { - // Build table name: ragflow_doc_meta_ - tableName := fmt.Sprintf("ragflow_doc_meta_%s", tenantID) - // Call document engine to delete doc meta table - err := s.docEngine.DropTable(context.Background(), tableName) + err := s.docEngine.DropMetadataStore(context.Background(), tenantID) if err != nil { return common.CodeServerError, fmt.Errorf("failed to delete doc meta table: %w", err) } diff --git a/internal/service/user.go b/internal/service/user.go index 6b117697c4d..dd6b9493048 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -150,14 +150,6 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error IsSuperuser: &isSuperuser, } - now := time.Now().Unix() - user.CreateTime = &now - user.UpdateTime = &now - nowDate := time.Now().Truncate(time.Second) - user.CreateDate = &nowDate - user.UpdateDate = &nowDate - user.LastLoginTime = &nowDate - tenantName := req.Nickname + "'s Kingdom" llmID := cfg.UserDefaultLLM.DefaultModels.ChatModel.Name @@ -192,11 +184,6 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error ParserIDs: "naive:General,Q&A:Q&A,manual:Manual,table:Table,paper:Research Paper,book:Book,laws:Laws,presentation:Presentation,picture:Picture,one:One,audio:Audio,email:Email,tag:Tag", Status: &status, } - tenant.CreateTime = &now - tenant.UpdateTime = &now - tenant.CreateDate = &nowDate - tenant.UpdateDate = &nowDate - userTenantID := utility.GenerateToken() userTenant := &entity.UserTenant{ ID: userTenantID, @@ -206,11 +193,6 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error InvitedBy: userID, Status: &status, } - userTenant.CreateTime = &now - userTenant.UpdateTime = &now - userTenant.CreateDate = &nowDate - userTenant.UpdateDate = &nowDate - fileID := utility.GenerateToken() rootFile := &entity.File{ ID: fileID, @@ -221,11 +203,6 @@ func (s *UserService) Register(req *RegisterRequest) (*entity.User, common.Error Type: "folder", Size: 0, } - rootFile.CreateTime = &now - rootFile.UpdateTime = &now - rootFile.CreateDate = &nowDate - rootFile.UpdateDate = &nowDate - tenantDAO := dao.NewTenantDAO() userTenantDAO := dao.NewUserTenantDAO() fileDAO := dao.NewFileDAO() @@ -298,13 +275,9 @@ func (s *UserService) Login(req *LoginRequest) (*entity.User, common.ErrorCode, // Generate new access token token := utility.GenerateToken() - if err := s.UpdateUserAccessToken(user, token); err != nil { - return nil, common.CodeServerError, fmt.Errorf("failed to update access token: %w", err) - } - - // Update timestamp - now := time.Now().Unix() - user.UpdateTime = &now + user.AccessToken = &token + now := time.Now().Truncate(time.Second) + user.LastLoginTime = &now if err := s.userDAO.Update(user); err != nil { return nil, common.CodeServerError, fmt.Errorf("failed to update user: %w", err) } @@ -339,11 +312,9 @@ func (s *UserService) LoginByEmail(req *EmailLoginRequest) (*entity.User, common // Generate new access token token := utility.GenerateToken() user.AccessToken = &token + now := time.Now().Truncate(time.Second) + user.LastLoginTime = &now - now := time.Now().Unix() - user.UpdateTime = &now - now_date := time.Now().Truncate(time.Second) - user.UpdateDate = &now_date if err := s.userDAO.Update(user); err != nil { return nil, common.CodeServerError, fmt.Errorf("failed to update user: %w", err) } @@ -803,15 +774,19 @@ func (s *UserService) UpdateUserSettings(user *entity.User, req *UpdateSettingsR if req.Avatar != nil { // In Go version, avatar might be stored differently // For now, just update if field exists + user.Avatar = req.Avatar } if req.Language != nil { // Store language preference + user.Language = req.Language } if req.ColorSchema != nil { // Store color schema preference + user.ColorSchema = req.ColorSchema } if req.Timezone != nil { // Store timezone preference + user.Timezone = req.Timezone } // Save updated user diff --git a/mcp/server/server.py b/mcp/server/server.py index bc3a362901e..81c8da3f073 100644 --- a/mcp/server/server.py +++ b/mcp/server/server.py @@ -32,7 +32,7 @@ from starlette.middleware import Middleware from starlette.responses import JSONResponse, Response from starlette.routing import Mount, Route -from strenum import StrEnum +from enum import StrEnum class LaunchMode(StrEnum): diff --git a/memory/services/query.py b/memory/services/query.py index 0e97f1fc2b0..e2bce608b98 100644 --- a/memory/services/query.py +++ b/memory/services/query.py @@ -21,7 +21,7 @@ from common.doc_store.doc_store_base import MatchDenseExpr, MatchTextExpr from common.float_utils import get_float from rag.nlp import rag_tokenizer, term_weight, synonym - +from rag.utils.redis_conn import REDIS_CONN def get_vector(txt, emb_mdl, topk=10, similarity=0.1): if isinstance(similarity, str) and len(similarity) > 0: @@ -44,7 +44,7 @@ class MsgTextQuery(QueryBase): def __init__(self): self.tw = term_weight.Dealer() - self.syn = synonym.Dealer() + self.syn = synonym.Dealer(redis=REDIS_CONN.REDIS if REDIS_CONN.is_alive() else None) self.query_fields = [ "content" ] diff --git a/pyproject.toml b/pyproject.toml index 9c41642a04e..c864f7b0567 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,17 @@ [project] name = "ragflow" -version = "0.25.2" +version = "0.25.5" description = "[RAGFlow](https://ragflow.io/) is an open-source RAG (Retrieval-Augmented Generation) engine based on deep document understanding. It offers a streamlined RAG workflow for businesses of any scale, combining LLM (Large Language Models) to provide truthful question-answering capabilities, backed by well-founded citations from various complex formatted data." authors = [{ name = "Zhichang Yu", email = "yuzhichang@gmail.com" }] license-files = ["LICENSE"] readme = "README.md" -requires-python = ">=3.12,<3.15" +requires-python = ">=3.13,<3.15" dependencies = [ + # discord-py==2.3.2 unconditionally imports audioop in discord/player.py at module- + # load time. audioop was removed from the CPython stdlib in Python 3.13 (PEP 594), + # so any import of the discord package raises ImportError on Python 3.13 — even in + # tests that never use voice features. audioop-lts provides the module as a backport. + "audioop-lts>=0.2.1", "aiosmtplib>=5.0.0", "akshare>=1.15.78,<2.0.0", "anthropic==0.34.1", @@ -16,6 +21,7 @@ dependencies = [ "azure-storage-file-datalake==12.16.0", "beartype>=0.20.0,<1.0.0", "bio==1.7.1", + "boto3>=1.28.0", "boxsdk>=10.1.0", "captcha>=0.7.1", "chardet>=5.2.0,<6.0.0", @@ -49,7 +55,7 @@ dependencies = [ "groq==0.9.0", "grpcio-status==1.67.1", "html-text==0.6.2", - "infinity-sdk==0.7.0-dev6", + "infinity-sdk==0.7.0", "infinity-emb>=0.0.66,<0.0.67", "jira==3.10.5", "json-repair==0.35.0", @@ -73,7 +79,7 @@ dependencies = [ "opencv-python-headless==4.10.0.84", "opendal>=0.45.0,<0.46.0", "opensearch-py==2.7.1", - "ormsgpack==1.5.0", + "ormsgpack>=1.5.0", "pdfplumber==0.10.4", "pluginlib>=0.10.0", "psycopg2-binary>=2.9.11,<3.0.0", @@ -86,6 +92,7 @@ dependencies = [ "pypdf>=6.10.2", "python-calamine>=0.4.0", "python-docx>=1.1.2,<2.0.0", + "paramiko>=3.5.1", "python-pptx>=1.0.2,<2.0.0", # "pywencai>=0.13.1,<1.0.0", # Temporarily disabled: conflicts with agentrun-sdk (pydash>=8), needed for agent/tools/wencai.py "qianfan==0.4.6", @@ -100,12 +107,13 @@ dependencies = [ "ruamel-yaml>=0.18.6,<0.19.0", "scholarly==1.7.11", "selenium-wire==5.1.0", + "spacy==3.8.14", + "en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl", "slack-sdk==3.37.0", "socksio==1.0.0", "agentrun-sdk>=0.0.16,<1.0.0", "nest-asyncio>=1.6.0,<2.0.0", # Needed for agent/component/message.py "sqlglotrs==0.9.0", - "strenum==0.4.15", "tavily-python==0.5.1", "tencentcloud-sdk-python==3.0.1478", "tika==2.6.0", @@ -176,7 +184,6 @@ test = [ "pycryptodomex==3.20.0", "pytest-playwright>=0.7.2", "codecov>=2.1.13", - "tensorflow-cpu>=2.17.0", ] [tool.uv] @@ -184,11 +191,31 @@ constraint-dependencies = [ # CVE-2026-30922: Denial of Service via unbounded recursion in ASN.1 decoding (CVSS 7.5 HIGH) # pyasn1 < 0.6.3 is vulnerable; pulled in transitively via google-auth / rsa / pyasn1-modules "pyasn1>=0.6.3", + # Python 3.13 added pathlib.PurePath.parser as a public class attribute holding + # the posixpath/ntpath module. trio<0.26 introspects all Path class attributes to + # generate async forwards and raises TypeError on any non-callable attribute it + # encounters (fixed in trio 0.26 by skipping non-callables). Pulled in transitively + # via selenium-wire -> trio-websocket -> trio. + "trio>=0.26.0", ] +override-dependencies = [ + # moodlepy<=0.24.1 pins attrs<23.0.0, but trio>=0.26.0 requires attrs>=23.2.0. + # attrs 23.x is backward-compatible; moodlepy works fine at runtime with it. + "attrs>=23.2.0", +] +# trio 0.26+ (Python 3.13 compatible) is not yet on the Aliyun mirror. +# Mark PyPI as explicit so it is used only for packages listed in [tool.uv.sources]. +[[tool.uv.index]] +name = "pypi" +url = "https://pypi.org/simple" +explicit = true [[tool.uv.index]] url = "https://mirrors.aliyun.com/pypi/simple" +[tool.uv.sources] +trio = [{ index = "pypi" }] + [tool.setuptools] packages = [ 'agent', diff --git a/rag/app/audio.py b/rag/app/audio.py index 29ef625fad4..2741c91a906 100644 --- a/rag/app/audio.py +++ b/rag/app/audio.py @@ -35,8 +35,8 @@ def chunk(filename, binary, tenant_id, lang, callback=None, **kwargs): if not ext: raise RuntimeError("No extension detected.") - if ext not in [".da", ".wave", ".wav", ".mp3", ".wav", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma", - ".realaudio", ".vqf", ".oggvorbis", ".aac", ".ape"]: + if ext not in [".da", ".wave", ".wav", ".mp3", ".aac", ".flac", ".ogg", ".aiff", ".au", ".midi", ".wma", + ".realaudio", ".vqf", ".oggvorbis", ".ape"]: raise RuntimeError(f"Extension {ext} is not supported yet.") tmp_path = "" diff --git a/rag/app/laws.py b/rag/app/laws.py index e2fe885ffa2..46829d23c2e 100644 --- a/rag/app/laws.py +++ b/rag/app/laws.py @@ -95,7 +95,7 @@ def __str__(self) -> str: class Pdf(PdfParser): def __init__(self): - self.model_speciess = ParserType.LAWS.value + self.model_species = ParserType.LAWS.value super().__init__() def __call__(self, filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, zoomin=3, callback=None): diff --git a/rag/app/manual.py b/rag/app/manual.py index b3f5f2edc17..c2e17aeb20d 100644 --- a/rag/app/manual.py +++ b/rag/app/manual.py @@ -32,7 +32,7 @@ class Pdf(PdfParser): def __init__(self): - self.model_speciess = ParserType.MANUAL.value + self.model_species = ParserType.MANUAL.value super().__init__() def __call__(self, filename, binary=None, from_page=0, to_page=MAXIMUM_PAGE_NUMBER, zoomin=3, callback=None): diff --git a/rag/app/naive.py b/rag/app/naive.py index f91e2a8f946..7bf4743e7db 100644 --- a/rag/app/naive.py +++ b/rag/app/naive.py @@ -131,6 +131,19 @@ def by_mineru( ocr_model_config = get_model_config_by_type_and_name(tenant_id, LLMType.OCR, mineru_llm_name) ocr_model = LLMBundle(tenant_id=tenant_id, model_config=ocr_model_config, lang=lang) pdf_parser = ocr_model.mdl + + # Closes #14869: when the tenant has an IMAGE2TEXT model + # configured, let the MinerU parser enrich image chunks with + # VLM-generated semantic descriptions (parity with deepdoc's + # VisionFigureParser). Best-effort — fall back silently if + # no vision model is available. + if "vision_model" not in kwargs: + try: + vision_model_config = get_tenant_default_model_by_type(tenant_id, LLMType.IMAGE2TEXT) + kwargs["vision_model"] = LLMBundle(tenant_id=tenant_id, model_config=vision_model_config, lang=lang) + except Exception as vlm_err: + logging.info(f"[MinerU] no IMAGE2TEXT model for tenant; skipping image VLM enhancement: {vlm_err}") + sections, tables = pdf_parser.parse_pdf( filepath=filename, binary=binary, diff --git a/rag/app/paper.py b/rag/app/paper.py index 82ddb8bc838..f578a5fc7a8 100644 --- a/rag/app/paper.py +++ b/rag/app/paper.py @@ -30,7 +30,7 @@ class Pdf(PdfParser): def __init__(self): - self.model_speciess = ParserType.PAPER.value + self.model_species = ParserType.PAPER.value super().__init__() def __call__(self, filename, binary=None, from_page=0, diff --git a/rag/app/table.py b/rag/app/table.py index ea553ca0f9d..5f4fabd527e 100644 --- a/rag/app/table.py +++ b/rag/app/table.py @@ -36,6 +36,7 @@ from deepdoc.parser import ExcelParser from common import settings +logger = logging.getLogger(__name__) class Excel(ExcelParser): def __call__(self, fnm, binary=None, from_page=0, to_page=MAXIMUM_TASK_PAGE_NUMBER, callback=None, **kwargs): @@ -49,11 +50,11 @@ def __call__(self, fnm, binary=None, from_page=0, to_page=MAXIMUM_TASK_PAGE_NUMB res, fails, done = [], [], 0 rn = 0 flow_images = [] - pending_cell_images = [] tables = [] for sheet_name in wb.sheetnames: ws = wb[sheet_name] images = Excel._extract_images_from_worksheet(ws, sheetname=sheet_name) + pending_cell_images = [] if images: image_descriptions = vision_figure_parser_figure_xlsx_wrapper(images=images, callback=callback, **kwargs) @@ -372,6 +373,11 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_TASK_PAGE_NUMBER, Every row in table will be treated as a chunk. """ + _pc0 = kwargs.get("parser_config") or {} + logger.debug(f"[TABLE_PARSER_DEBUG] parser_config keys: {list(_pc0.keys())}") + logger.debug(f"[TABLE_PARSER_DEBUG] table_column_mode: {_pc0.get('table_column_mode')}") + logger.debug(f"[TABLE_PARSER_DEBUG] table_column_roles: {_pc0.get('table_column_roles')}") + tbls = [] is_english = lang.lower() == "english" if re.search(r"\.xlsx?$", filename, re.IGNORECASE): @@ -435,6 +441,19 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_TASK_PAGE_NUMBER, # Field type suffixes for database columns # Maps data types to their database field suffixes fields_map = {"text": "_tks", "int": "_long", "keyword": "_kwd", "float": "_flt", "datetime": "_dt", "bool": "_kwd"} + parser_config = kwargs.get("parser_config") or {} + if parser_config.get("table_column_mode") == "manual": + column_roles = parser_config.get("table_column_roles") or {} + else: + column_roles = {} + logger.debug( + f"[TABLE_PARSER_DEBUG] effective table_column_mode={parser_config.get('table_column_mode')!r}, " + f"column_roles keys={list(column_roles.keys())}" + ) + + # Pass 1: infer columns per sheet (multi-sheet Excel => multiple DataFrames). Merge field_map and + # table_column_names, then update KB once so the UI role selector sees all columns, not only the last sheet. + sheet_specs = [] for df in dfs: for n in ["id", "_id", "index", "idx"]: if n in df.columns: @@ -457,22 +476,64 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_TASK_PAGE_NUMBER, txts.extend([str(c) for c in cln if c]) clmns_map = [(py_clmns[i].lower() + fields_map[clmn_tys[i]], str(clmns[i]).replace("_", " ")) for i in range(len(clmns))] - # For Infinity/OceanBase: Use original column names as keys since they're stored in chunk_data JSON - # For ES/OS: Use full field names with type suffixes (e.g., url_kwd, body_tks) + # field_map: only columns stored in chunk_data (metadata or both) — used for retrieval/SQL + stored_indices = [ + i for i in range(len(clmns)) + if column_roles.get(clmns[i], "both") in ("metadata", "both") + ] if settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE: - # For Infinity/OceanBase: key = original column name, value = display name - field_map = {py_clmns[i].lower(): str(clmns[i]).replace("_", " ") for i in range(len(clmns))} + field_map = { + py_clmns[i].lower(): str(clmns[i]).replace("_", " ") + for i in stored_indices + } else: - # For ES/OS: key = typed field name, value = display name - field_map = {k: v for k, v in clmns_map} - logging.debug(f"Field map: {field_map}") - KnowledgebaseService.update_parser_config(kwargs["kb_id"], {"field_map": field_map}) + field_map = { + clmns_map[i][0]: clmns_map[i][1] + for i in stored_indices + } + logging.debug(f"Field map (sheet): {field_map}") + sheet_specs.append( + { + "df": df, + "clmns": clmns, + "clmn_tys": clmn_tys, + "clmns_map": clmns_map, + "py_clmns": py_clmns, + "field_map": field_map, + } + ) + + merged_field_map = {} + merged_table_column_names = [] + seen_col = set() + for spec in sheet_specs: + merged_field_map.update(spec["field_map"]) + for col in spec["clmns"]: + if col not in seen_col: + seen_col.add(col) + merged_table_column_names.append(col) + + logging.debug(f"Field map (merged across sheets): {merged_field_map}") + kb_id = kwargs.get("kb_id") + if kb_id: + KnowledgebaseService.update_parser_config( + kb_id, + {"field_map": merged_field_map, "table_column_names": merged_table_column_names}, + ) - eng = lang.lower() == "english" # is_english(txts) + eng = lang.lower() == "english" # is_english(txts) + for spec in sheet_specs: + df = spec["df"] + clmns = spec["clmns"] + clmn_tys = spec["clmn_tys"] + clmns_map = spec["clmns_map"] + py_clmns = spec["py_clmns"] + _debug_row_idx = 0 for ii, row in df.iterrows(): + _debug_row_idx += 1 d = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))} - row_fields = [] - data_json = {} # For Infinity: Store all columns in a JSON object + text_fields = [] # indexing + both -> content_with_weight + stored = {} # metadata + both -> chunk_data (Infinity) or typed fields (ES) for j in range(len(clmns)): if row[clmns[j]] is None: continue @@ -480,27 +541,49 @@ def chunk(filename, binary=None, from_page=0, to_page=MAXIMUM_TASK_PAGE_NUMBER, continue if not isinstance(row[clmns[j]], pd.Series) and pd.isna(row[clmns[j]]): continue - # For Infinity/OceanBase: Store in chunk_data JSON column - # For Elasticsearch/OpenSearch: Store as individual fields with type suffixes - if settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE: - data_json[str(clmns[j])] = row[clmns[j]] - else: - fld = clmns_map[j][0] - d[fld] = row[clmns[j]] if clmn_tys[j] != "text" else rag_tokenizer.tokenize(row[clmns[j]]) - row_fields.append((clmns[j], row[clmns[j]])) - if not row_fields: + col_name = clmns[j] + role = column_roles.get(col_name, "both") + if _debug_row_idx == 1: + logger.debug(f"[TABLE_PARSER_DEBUG] Column '{col_name}' -> role '{role}'") + if role in ("indexing", "vectorize", "both"): + text_fields.append((col_name, row[col_name])) + if role in ("metadata", "both"): + if settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE: + stored[str(col_name)] = row[col_name] + else: + fld = clmns_map[j][0] + if clmn_tys[j] != "text": + stored[fld] = row[col_name] + else: + cell = row[col_name] + stored[fld] = rag_tokenizer.tokenize(cell) + raw_s = str(cell).strip() if cell is not None else "" + if raw_s: + stored[f"{py_clmns[j].lower()}_raw"] = raw_s + if not text_fields and not stored: continue - # Add the data JSON field to the document (for Infinity/OceanBase) if settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE: - d["chunk_data"] = data_json - # Format as a structured text for better LLM comprehension - # Format each field as "- Field Name: Value" on separate lines - formatted_text = "\n".join([f"- {field}: {value}" for field, value in row_fields]) + if stored: + d["chunk_data"] = stored + else: + d.update(stored) + formatted_text = "\n".join([f"- {field}: {value}" for field, value in text_fields]) if text_fields else "" tokenize(d, formatted_text, eng) + if _debug_row_idx == 1: + logger.debug( + f"[TABLE_PARSER_DEBUG] Chunk content_with_weight length: {len(d.get('content_with_weight', '') or '')}" + ) + _cd = d.get("chunk_data") + logger.debug( + f"[TABLE_PARSER_DEBUG] Chunk chunk_data keys: {list(_cd.keys()) if isinstance(_cd, dict) else 'N/A'}" + ) + if not (settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE): + _extra = [k for k in d if k not in ("docnm_kwd", "title_tks", "content_with_weight", "content_ltks", "content_sm_ltks")] + logger.debug(f"[TABLE_PARSER_DEBUG] Chunk ES extra field keys (sample): {_extra[:20]}") res.append(d) - if tbls: - doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))} - res.extend(tokenize_table(tbls, doc, is_english)) + if tbls: + doc = {"docnm_kwd": filename, "title_tks": rag_tokenizer.tokenize(re.sub(r"\.[a-zA-Z]+$", "", filename))} + res.extend(tokenize_table(tbls, doc, is_english)) callback(0.35, "") return res diff --git a/rag/flow/chunker/title_chunker/common.py b/rag/flow/chunker/title_chunker/common.py index 89981a83de5..0ca6549a960 100644 --- a/rag/flow/chunker/title_chunker/common.py +++ b/rag/flow/chunker/title_chunker/common.py @@ -73,25 +73,61 @@ async def invoke(self): def extract_line_records(self): - # Normalize all upstream payloads into an ordered record stream. - # Level resolution and chunk construction operate on this stream only, - # so strategy code does not depend on source-specific output layouts. + """ + Normalize all upstream input payloads into a unified ordered record stream. + All level resolution and chunk construction logic operates on this standard stream, + decoupling downstream chunking strategies from different upstream output formats. + """ + import logging + logger = logging.getLogger(__name__) + + payload = None + # Extract raw content payload based on upstream output format type if self.from_upstream.output_format == "markdown": payload = self.from_upstream.markdown_result or "" - return [{"text": line, "doc_type_kwd": "text", "img_id": None, "layout": "", PDF_POSITIONS_KEY: []} for line in payload.split("\n") if line] - - if self.from_upstream.output_format == "text": + elif self.from_upstream.output_format == "text": payload = self.from_upstream.text_result or "" - return [{"text": line, "doc_type_kwd": "text", "img_id": None, "layout": "", PDF_POSITIONS_KEY: []} for line in payload.split("\n") if line] - - if self.from_upstream.output_format == "html": + elif self.from_upstream.output_format == "html": payload = self.from_upstream.html_result or "" - return [{"text": line, "doc_type_kwd": "text", "img_id": None, "layout": "", PDF_POSITIONS_KEY: []} for line in payload.split("\n") if line] + + # Boundary robustness fix: explicit None check to distinguish `None` and empty string "" + # Prevents empty payload from unexpectedly falling through to structured chunk branch + if payload is not None: + lines = payload.split("\n") + input_line_count = len(lines) + + # Format-branched text processing to preserve original document semantics + # Plain text: perform full whitespace stripping and invalid empty line filtering + if self.from_upstream.output_format == "text": + clean_lines = [line.strip() for line in lines if line.strip()] + # Markdown & HTML: retain original indentation/spacing, only filter pure blank lines + else: + clean_lines = [line for line in lines if line.strip()] + + output_line_count = len(clean_lines) + # Production observability log: added format dimension per project coding guidelines + logger.info( + f"payload filter: format={self.from_upstream.output_format} before={input_line_count} after={output_line_count}" + ) + + return [ + { + "text": line, + "doc_type_kwd": "text", + "img_id": None, + "layout": "", + PDF_POSITIONS_KEY: [] + } + for line in clean_lines + ] + # Return empty array directly for null payload to block invalid branch fallthrough + return [] items = self.from_upstream.chunks if self.from_upstream.output_format == "chunks" else self.from_upstream.json_result return [ { - "text": str(item.get("text") or ""), + # Serialization fix: avoid None value being converted into literal "None" string + "text": item.get("text") or "", "doc_type_kwd": str(item.get("doc_type_kwd") or "text"), "img_id": item.get("img_id"), "layout": "{} {}".format(item.get("layout_type", ""), item.get("layoutno", "")).strip(), @@ -100,7 +136,6 @@ def extract_line_records(self): for item in items or [] ] - def extract_outlines(self): file = self.from_upstream.file or {} source = ( diff --git a/rag/graphrag/general/index.py b/rag/graphrag/general/index.py index da86fdc48e4..396f3aae0b1 100644 --- a/rag/graphrag/general/index.py +++ b/rag/graphrag/general/index.py @@ -29,6 +29,7 @@ from rag.graphrag.general.extractor import Extractor from rag.graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt from rag.graphrag.light.graph_extractor import GraphExtractor as LightKGExt +from rag.graphrag.ner.graph_extractor import GraphExtractor as NerKGExt from rag.graphrag.phase_markers import ( PHASE_COMMUNITY, PHASE_RESOLUTION, @@ -53,6 +54,40 @@ from common.doc_store.doc_store_base import OrderByExpr +DEFAULT_GRAPHRAG_BATCH_CHUNK_TOKEN_SIZE = 4096 + + +def _positive_int_config(config: dict, key: str, default: int) -> int: + value = config.get(key, default) + try: + value = int(value) + except (TypeError, ValueError): + logging.warning("Invalid GraphRAG config %s=%r, using default %s", key, value, default) + return default + if value < 512 or value > 8196: + logging.warning("Invalid GraphRAG config %s=%r, using default %s", key, value, default) + return default + return value + + +def _select_extractor(graphrag_config: dict): + """Return the extractor class matching ``graphrag_config["method"]``. + + Supported values: + - ``"general"`` – Microsoft GraphRAG LLM-based extractor (default in + earlier versions). + - ``"light"`` – LightRAG-style LLM-based extractor (the default when + *method* is omitted or unrecognised). + - ``"ner"`` – NER-based extractor using spaCy (no LLM + needed for entity / relation extraction itself). + """ + method = graphrag_config.get("method", "light") + if method == "general": + return GeneralKGExt + if method == "ner": + return NerKGExt + return LightKGExt + async def load_subgraph_from_store(tenant_id: str, kb_id: str, doc_id: str): """Load a previously saved subgraph from the doc store. @@ -102,102 +137,6 @@ async def load_subgraph_from_store(tenant_id: str, kb_id: str, doc_id: str): return None -async def run_graphrag( - row: dict, - language, - with_resolution: bool, - with_community: bool, - chat_model, - embedding_model, - callback, -): - enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") - start = asyncio.get_running_loop().time() - tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"] - chunks = [] - for d in settings.retriever.chunk_list(doc_id, tenant_id, [kb_id], max_count=10000, fields=["content_with_weight", "doc_id"], sort_by_position=True): - chunks.append(d["content_with_weight"]) - - timeout_sec = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000 - - try: - subgraph = await asyncio.wait_for( - generate_subgraph( - LightKGExt if "method" not in row["kb_parser_config"].get("graphrag", {}) - or row["kb_parser_config"]["graphrag"]["method"] != "general" - else GeneralKGExt, - tenant_id, - kb_id, - doc_id, - chunks, - language, - row["kb_parser_config"]["graphrag"].get("entity_types", []), - chat_model, - embedding_model, - callback, - ), - timeout=timeout_sec, - ) - except asyncio.TimeoutError: - logging.error("generate_subgraph timeout") - raise - - if not subgraph: - return - - graphrag_task_lock = RedisDistributedLock(f"graphrag_task_{kb_id}", lock_value=doc_id, timeout=1200) - await graphrag_task_lock.spin_acquire() - callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired") - - try: - subgraph_nodes = set(subgraph.nodes()) - new_graph = await merge_subgraph( - tenant_id, - kb_id, - doc_id, - subgraph, - embedding_model, - callback, - ) - assert new_graph is not None - - if not with_resolution and not with_community: - return - - if with_resolution: - await graphrag_task_lock.spin_acquire() - callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired") - await resolve_entities( - new_graph, - subgraph_nodes, - tenant_id, - kb_id, - doc_id, - chat_model, - embedding_model, - callback, - task_id=row["id"], - ) - if with_community: - await graphrag_task_lock.spin_acquire() - callback(msg=f"run_graphrag {doc_id} graphrag_task_lock acquired") - await extract_community( - new_graph, - tenant_id, - kb_id, - doc_id, - chat_model, - embedding_model, - callback, - task_id=row["id"], - ) - finally: - graphrag_task_lock.release() - now = asyncio.get_running_loop().time() - callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.") - return - - async def run_graphrag_for_kb( row: dict, doc_ids: list[str], @@ -215,6 +154,8 @@ async def run_graphrag_for_kb( enable_timeout_assertion = os.environ.get("ENABLE_TIMEOUT_ASSERTION") start = asyncio.get_running_loop().time() fields_for_chunks = ["content_with_weight", "doc_id"] + graphrag_config = kb_parser_config.get("graphrag", {}) + batch_chunk_token_size = _positive_int_config(graphrag_config, "batch_chunk_token_size", DEFAULT_GRAPHRAG_BATCH_CHUNK_TOKEN_SIZE) if not doc_ids: logging.info(f"Fetching all docs for {kb_id}") @@ -242,21 +183,20 @@ def load_doc_chunks(doc_id: str) -> list[str]: chunks = [] current_chunk = "" - # DEBUG: Obtener todos los chunks primero raw_chunks = list(settings.retriever.chunk_list( doc_id, tenant_id, [kb_id], - max_count=10000, # FIX: Aumentar límite para procesar todos los chunks fields=fields_for_chunks, sort_by_position=True, + retrieve_all=True )) - callback(msg=f"[DEBUG] chunk_list() returned {len(raw_chunks)} raw chunks for doc {doc_id}") + callback(msg=f"[GraphRAG] chunk_list returned {len(raw_chunks)} raw chunks for doc:{doc_id}") for d in raw_chunks: content = d["content_with_weight"] - if num_tokens_from_string(current_chunk + content) < 4096: + if num_tokens_from_string(current_chunk + content) < batch_chunk_token_size: current_chunk += content else: if current_chunk: @@ -268,16 +208,7 @@ def load_doc_chunks(doc_id: str) -> list[str]: return chunks - all_doc_chunks: dict[str, list[str]] = {} total_chunks = 0 - for doc_id in doc_ids: - chunks = load_doc_chunks(doc_id) - all_doc_chunks[doc_id] = chunks - total_chunks += len(chunks) - - if total_chunks == 0: - callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.") - return {"ok_docs": [], "failed_docs": doc_ids, "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0} semaphore = asyncio.Semaphore(max_parallel_docs) @@ -285,18 +216,13 @@ def load_doc_chunks(doc_id: str) -> list[str]: failed_docs: list[tuple[str, str]] = [] # (doc_id, error) async def build_one(doc_id: str): + nonlocal total_chunks + if has_canceled(row["id"]): callback(msg=f"Task {row['id']} cancelled, stopping execution.") raise TaskCanceledException(f"Task {row['id']} was cancelled") - chunks = all_doc_chunks.get(doc_id, []) - if not chunks: - callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.") - return - - kg_extractor = LightKGExt if ("method" not in kb_parser_config.get("graphrag", {}) or kb_parser_config["graphrag"]["method"] != "general") else GeneralKGExt - - deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000 + kg_extractor = _select_extractor(graphrag_config) async with semaphore: # CHECKPOINT: bounded by semaphore so doc-store lookups respect max_parallel_docs @@ -306,6 +232,13 @@ async def build_one(doc_id: str): callback(msg=f"[GraphRAG] doc:{doc_id} subgraph found in store, skipping LLM extraction.") return try: + chunks = load_doc_chunks(doc_id) + total_chunks += len(chunks) + if not chunks: + callback(msg=f"[GraphRAG] doc:{doc_id} has no available chunks, skip generation.") + return + + deadline = max(120, len(chunks) * 60 * 10) if enable_timeout_assertion else 10000000000 msg = f"[GraphRAG] build_subgraph doc:{doc_id}" callback(msg=f"{msg} start (chunks={len(chunks)}, timeout={deadline}s)") @@ -356,6 +289,10 @@ async def build_one(doc_id: str): await asyncio.gather(*tasks, return_exceptions=True) raise + if total_chunks == 0 and not subgraphs: + callback(msg=f"[GraphRAG] kb:{kb_id} has no available chunks in all documents, skip.") + return {"ok_docs": [], "failed_docs": [(doc_id, "no available chunks") for doc_id in doc_ids], "total_docs": len(doc_ids), "total_chunks": 0, "seconds": 0.0} + if has_canceled(row["id"]): callback(msg=f"Task {row['id']} cancelled after document processing.") raise TaskCanceledException(f"Task {row['id']} was cancelled") diff --git a/rag/graphrag/ner/__init__.py b/rag/graphrag/ner/__init__.py new file mode 100644 index 00000000000..f65b1742496 --- /dev/null +++ b/rag/graphrag/ner/__init__.py @@ -0,0 +1,18 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +from .graph_extractor import GraphExtractor + +__all__ = ["GraphExtractor"] diff --git a/rag/graphrag/ner/graph_extractor.py b/rag/graphrag/ner/graph_extractor.py new file mode 100644 index 00000000000..67d97346c1f --- /dev/null +++ b/rag/graphrag/ner/graph_extractor.py @@ -0,0 +1,644 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +spaCy-based entity and relationship extractor for GraphRAG. + +Combines techniques from **LinearRAG** and **MGranRAG**: + +* **Entity extraction** uses MGranRAG's multi-pass stacking algorithm + (hyphen/apostrophe merging → capitalised-word merging → continuous + noun/number merging) combined with spaCy NER, then deduplicated via + ``ner_all_keywords``. +* **Relationship inference** follows LinearRAG's *relation-free* approach: + entities co-occurring in the same sentence (or nearby sentences) are + linked by implicit semantic edges whose description is the shared + sentence text (semantic bridging). Edge weights are optionally TF- + normalised. + +No LLM calls are needed for the extraction step itself. The LLM is only +used downstream (inherited from ``Extractor``) for merging / summarising +duplicate entity descriptions when the same entity appears in multiple +chunks. +""" + +import logging +from collections import defaultdict + +from rag.graphrag.general.extractor import Extractor +from rag.llm.chat_model import Base as CompletionLLM + +# --------------------------------------------------------------------------- +# spaCy model loading (lazy, module-level singleton) +# --------------------------------------------------------------------------- +_nlp = None +_nlp_model_name = "" + + +def _load_spacy_model(model_name: str = "en_core_web_sm"): + """Load (or return cached) spaCy language model. + + Automatically downloads the model if it is not yet installed. + """ + global _nlp, _nlp_model_name + if _nlp is not None and _nlp_model_name == model_name: + return _nlp + try: + import spacy + except ImportError: + raise ImportError( + "spaCy is required for the spacy GraphRAG method. " + "Install it with: pip install spacy && python -m spacy download en_core_web_sm" + ) + try: + _nlp = spacy.load(model_name) + logging.info("Loaded spaCy model '%s'", model_name) + except OSError: + logging.warning( + "spaCy model '%s' not found; downloading automatically …", model_name + ) + from spacy.cli import download as spacy_download + spacy_download(model_name) + _nlp = spacy.load(model_name) + logging.info("Downloaded and loaded spaCy model '%s'", model_name) + _nlp_model_name = model_name + return _nlp + + +# --------------------------------------------------------------------------- +# spaCy ↔ application entity-type mapping +# --------------------------------------------------------------------------- +# spaCy's built-in entity labels → the application-level types used by +# ``DEFAULT_ENTITY_TYPES``. Labels not listed here fall through to +# ``"category"``. +SPACY_TO_APP_ENTITY_TYPE: dict[str, str] = { + "PERSON": "person", + "ORG": "organization", + "GPE": "geo", + "LOC": "geo", + "FAC": "geo", + "EVENT": "event", + "PRODUCT": "category", + "WORK_OF_ART": "category", + "LAW": "category", + "LANGUAGE": "category", + "NORP": "category", + "MONEY": "category", + "QUANTITY": "category", + "TIME": "event", + "DATE": "event", +} + +# Labels to skip entirely (from LinearRAG: ordinals / cardinals are rarely +# useful as graph nodes). +_SKIP_SPACY_LABELS = {"ORDINAL", "CARDINAL"} + + +# --------------------------------------------------------------------------- +# MGranRAG-style multi-pass keyword extraction +# --------------------------------------------------------------------------- + +def _has_uppercase(text: str) -> bool: + return any(c.isupper() for c in text) + + +def _replace_word(word: str) -> str: + """Normalise spaces around hyphens and apostrophes (from MGranRAG).""" + return ( + word.replace(" - ", "-") + .replace(" -", "-") + .replace("- ", "-") + .replace(" 's", "'s") + .replace(" 'S", "'S") + ) + + +def extract_keywords(spacy_doc) -> set[str]: + """MGranRAG-style 3-pass stacking keyword extraction. + + Phase 1 — Hyphen / apostrophe merging: + Tokens connected by ``-`` or ``'s`` are merged into a single + phrase labelled ``NP`` (e.g. ``New-York``, ``cat's``). + + Phase 2 — Capitalised-word merging: + Consecutive tokens whose ``shape_`` contains ``X`` (i.e. start + with an uppercase letter) are merged. Function words (ADP, CCONJ, + DET, PART) between them are absorbed as well, producing phrases + like ``King of England``. Merged results are labelled ``NX`` + unless already ``PROPN``. + + Phase 3 — Continuous noun / number merging: + Consecutive tokens with POS in ``[PROPN, NOUN, NUM, NX, NP]`` + are merged and labelled ``NNN`` (unless already ``PROPN``). + + Finally, results with a trailing lowercase non-noun word are + truncated, and coordinating conjunctions (``and``, ``or``) inside a + merged phrase cause it to be split so that each proper noun is + extracted individually (e.g. ``Bob and Lucy`` → ``Bob``, ``Lucy``). + """ + # ── Phase 1: hyphen / apostrophe ────────────────────────────────── + f1_word: list[str] = [] + f1_shape: list[str] = [] + f1_pos: list[str] = [] + f1_pos_list: list[list[str]] = [] + f1_word_list: list[list[str]] = [] + + is_right = False + for token in spacy_doc: + if token.shape_ in ("'x", "-") and token.pos_ in ("PUNCT", "PART"): + if token.shape_ == "-": + is_right = True + if f1_word: + f1_word[-1] += token.text + f1_pos[-1] = "NP" + f1_pos_list[-1].append(token.pos_) + f1_word_list[-1].append(token.text) + elif is_right: + is_right = False + if f1_word: + f1_word[-1] += token.text + f1_pos[-1] = "NP" + f1_pos_list[-1].append(token.pos_) + f1_word_list[-1].append(token.text) + else: + f1_word.append(token.text) + f1_shape.append(token.shape_) + f1_pos.append(token.pos_) + f1_pos_list.append([token.pos_]) + f1_word_list.append([token.text]) + + # ── Phase 2: capitalised-word merging ─────────────────────────── + f2_word: list[str] = [] + f2_shape: list[str] = [] + f2_pos: list[str] = [] + f2_pos_list: list[list[str]] = [] + f2_word_list: list[list[str]] = [] + + for cur in range(len(f1_word)): + cw = f1_word[cur] + cs = f1_shape[cur] + cp = f1_pos[cur] + cpl = f1_pos_list[cur] + cwl = f1_word_list[cur] + + if "X" in cs or cp in ("ADP", "CCONJ", "DET", "PART"): + if f2_word and "X" in f2_shape[-1]: + # Merge with previous capitalised token. + f2_word[-1] += " " + cw + f2_shape[-1] += "X" + if f2_pos[-1] != "PROPN": + f2_pos[-1] = "NX" + f2_pos_list[-1].extend(cpl) + f2_word_list[-1].extend(cwl) + else: + f2_word.append(cw) + f2_shape.append(cs + "Start" if "X" in cs else cs) + f2_pos.append(cp) + f2_pos_list.append(cpl) + f2_word_list.append(cwl) + else: + f2_word.append(cw) + f2_shape.append(cs) + f2_pos.append(cp) + f2_pos_list.append(cpl) + f2_word_list.append(cwl) + + # ── Phase 3: continuous noun / number merging ─────────────────── + f3_word: list[str] = [] + f3_shape: list[str] = [] + f3_pos: list[str] = [] + f3_pos_list: list[list[str]] = [] + f3_word_list: list[list[str]] = [] + + _noun_pos = {"PROPN", "NOUN", "NUM", "NX", "NP"} + _noun_pos_ext = _noun_pos | {"NNN"} + + for cur in range(len(f2_word)): + cw = f2_word[cur] + cs = f2_shape[cur] + cp = f2_pos[cur] + cpl = f2_pos_list[cur] + cwl = f2_word_list[cur] + + if cp in _noun_pos: + if f3_word and f3_pos[-1] in _noun_pos_ext: + f3_word[-1] += " " + cw + f3_shape[-1] += "X" + if f3_pos[-1] != "PROPN": + f3_pos[-1] = "NNN" + f3_pos_list[-1].extend(cpl) + f3_word_list[-1].extend(cwl) + else: + f3_word.append(cw) + f3_shape.append(cs) + f3_pos.append(cp) + f3_pos_list.append(cpl) + f3_word_list.append(cwl) + else: + f3_word.append(cw) + f3_shape.append(cs) + f3_pos.append(cp) + f3_pos_list.append(cpl) + f3_word_list.append(cwl) + + # ── Final keyword collection ──────────────────────────────────── + keywords: set[str] = set() + for cur in range(len(f3_word)): + cw = f3_word[cur] + cp = f3_pos[cur] + cpl = f3_pos_list[cur] + cwl = f3_word_list[cur] + + if cp not in _noun_pos_ext: + continue + + # Truncate trailing lowercase non-noun / non-number words. + if cwl and not _has_uppercase(cwl[-1]) and cpl[-1] not in ( + "PROPN", + "NOUN", + "NUM", + "PART", + ): + for i in range(len(cpl) - 1, 0, -1): + if cpl[i] in ("PROPN", "NOUN", "NUM", "PART") or _has_uppercase( + cwl[i] + ): + break + word = _replace_word(" ".join(cwl[: i + 1])) + keywords.add(word) + else: + word = _replace_word(cw) + keywords.add(word) + + # Split on coordinating conjunctions (and/or) inside merged + # phrases so that individual proper nouns are also extracted + # (e.g. ``Bob and Lucy`` → ``Bob``, ``Lucy``). + if any(p in ("PROPN", "NOUN", "NUM") for p in cpl): + cur_kws: list[str] = [] + for pidx, pos in enumerate(cpl): + if pos == "CCONJ" and cwl[pidx] and cwl[pidx][0].islower(): + if cur_kws: + keywords.add(_replace_word(" ".join(cur_kws))) + cur_kws = [] + else: + cur_kws.append(cwl[pidx]) + if cur_kws: + keywords.add(_replace_word(" ".join(cur_kws))) + + return keywords + + +def get_ner(spacy_doc) -> dict[str, str]: + """Return ``{entity_text: spaCy_label}`` for all NER entities.""" + entities_dict: dict[str, str] = {} + for ent in spacy_doc.ents: + if ent.label_ in _SKIP_SPACY_LABELS: + continue + text = ent.text.strip() + for t in text.split("\n"): + t = t.strip() + if t: + entities_dict[t] = ent.label_ + return entities_dict + + +def ner_all_keywords(spacy_doc) -> set[str]: + """Combine rule-based keyword extraction with spaCy NER (MGranRAG). + + Returns the union of: + - keywords from the 3-pass stacking algorithm (``extract_keywords``) + - entity texts from spaCy NER (``get_ner``) + """ + keywords = extract_keywords(spacy_doc) + ner_dict = get_ner(spacy_doc) + return keywords.union(ner_dict.keys()) + + +# --------------------------------------------------------------------------- +# Main extractor class +# --------------------------------------------------------------------------- + +class GraphExtractor(Extractor): + """Extract entities and relationships using spaCy (no LLM calls). + + Entity extraction + MGranRAG's ``ner_all_keywords`` combines a 3-pass stacking + keyword algorithm with spaCy NER, yielding broader coverage than + NER alone (e.g. it catches compound nouns, hyphenated terms, and + multi-word proper nouns that NER might miss). + + Relationship inference + LinearRAG's *relation-free* semantic bridging: entities + co-occurring in the same sentence (or within + ``max_sentence_distance`` sentences) are linked by an implicit + edge. The edge description is the shared sentence text, which + provides natural language context without requiring an LLM. + + Optionally, edge weights are TF-normalised (LinearRAG): + ``weight = count(entity_in_chunk) / sum(all_entity_counts_in_chunk)``. + + The ``llm_invoker`` is only used downstream for merging / summarising + duplicate descriptions (inherited from ``Extractor``). + + Parameters + ---------- + llm_invoker : CompletionLLM + LLM handle (used only for description summarisation, not extraction). + language : str + Language hint. + entity_types : list[str] | None + Application-level entity types to keep. Entities whose mapped + type is not in this list are discarded. + spacy_model : str + Name of the spaCy model to load (default ``en_core_web_sm``). + max_sentence_distance : int + When inferring relationships, pair entities that co-occur within + the same sentence. If > 1, also pair entities in sentences whose + indices differ by at most this value. + relationship_strength : int + Default weight assigned to every inferred relationship when + ``use_tf_weight`` is ``False``. + use_tf_weight : bool + If ``True``, use TF-normalised weighting (LinearRAG-style) for + edge weights instead of the fixed ``relationship_strength``. + """ + + def __init__( + self, + llm_invoker: CompletionLLM, + language: str | None = "English", + entity_types: list[str] | None = None, + spacy_model: str = "en_core_web_sm", + max_sentence_distance: int = 1, + relationship_strength: int = 1, + use_tf_weight: bool = False, + ): + super().__init__(llm_invoker, language, entity_types) + self._spacy_model_name = spacy_model + self._max_sentence_distance = max_sentence_distance + self._relationship_strength = relationship_strength + self._use_tf_weight = use_tf_weight + # Eagerly load the model so import errors surface early. + self._nlp = _load_spacy_model(spacy_model) + + # ------------------------------------------------------------------ + # Public interface – called by ``Extractor.__call__`` + # ------------------------------------------------------------------ + + async def _process_single_content( + self, + chunk_key_dp: tuple[str, str], + chunk_seq: int, + num_chunks: int, + out_results, + task_id="", + ): + """Process one chunk through spaCy NER + keyword stacking + co-occurrence.""" + chunk_key = chunk_key_dp[0] + content = chunk_key_dp[1] + doc = self._nlp(content) + + # ── 1. Entity extraction (MGranRAG: ner_all_keywords) ──────── + # Build a mapping from keyword text → spaCy label (if available). + ner_label_map: dict[str, str] = get_ner(doc) + all_keywords = ner_all_keywords(doc) + + # For each keyword, determine its app-level entity type. + # - If the keyword matches a NER entity, use that label. + # - Otherwise, infer from POS heuristics. + ent_records: dict[str, dict] = {} # entity_name_upper → record + ent_by_sent: dict[int, list[dict]] = defaultdict(list) + + for kw in all_keywords: + kw_upper = kw.strip().upper() + if not kw_upper: + continue + + # Determine entity type. + spacy_label = ner_label_map.get(kw) + if spacy_label: + app_type = SPACY_TO_APP_ENTITY_TYPE.get(spacy_label, "category") + else: + app_type = self._infer_type_from_pos(doc, kw) + + if app_type not in self._entity_types_set: + continue + + # Determine which sentence this keyword belongs to. + sent_idx = self._keyword_sent_idx(doc, kw) + + # Description: use the containing sentence (LinearRAG semantic bridging). + #sent_text = self._keyword_sent_text(doc, kw) + + ent_record = dict( + entity_name=kw_upper, + entity_type=app_type.upper(), + description="", #sent_text or kw, + source_id=chunk_key, + ) + # A keyword may appear multiple times; keep the first. + if kw_upper not in ent_records: + ent_records[kw_upper] = ent_record + ent_by_sent[sent_idx].append(ent_record) + + maybe_nodes: dict[str, list[dict]] = defaultdict(list) + for name, rec in ent_records.items(): + maybe_nodes[name].append(rec) + + # ── 2. Relationship inference (LinearRAG: sentence co-occurrence) ─ + maybe_edges: dict[tuple, list[dict]] = defaultdict(list) + + # Pre-compute TF weights if needed (LinearRAG). + entity_tf: dict[str, float] = {} + if self._use_tf_weight: + total_count = sum( + content.upper().count(name) for name in ent_records + ) + for name in ent_records: + count = content.upper().count(name) + entity_tf[name] = count / total_count if total_count > 0 else 0.0 + + seen_pairs: set[tuple[str, str]] = set() + for si in sorted(ent_by_sent.keys()): + ents_in_range = list(ent_by_sent[si]) + # Expand with nearby sentences. + for offset in range(1, self._max_sentence_distance + 1): + for nb_si in (si + offset, si - offset): + if nb_si in ent_by_sent: + ents_in_range.extend(ent_by_sent[nb_si]) + # Deduplicate by entity name. + unique: dict[str, dict] = {} + for e in ents_in_range: + unique[e["entity_name"]] = e + ent_list = list(unique.values()) + + for a_idx in range(len(ent_list)): + for b_idx in range(a_idx + 1, len(ent_list)): + ea, eb = ent_list[a_idx], ent_list[b_idx] + pair = tuple(sorted([ea["entity_name"], eb["entity_name"]])) + if pair in seen_pairs: + continue + seen_pairs.add(pair) + + # Relationship description: shared sentence text + # (LinearRAG semantic bridging — the sentence is the + # semantic bridge between entities). + #desc = self._cooccurrence_description(doc, ea["entity_name"], eb["entity_name"]) + + # Edge weight: TF-normalised (LinearRAG) or fixed. + if self._use_tf_weight: + w = (entity_tf.get(ea["entity_name"], 0.0) + + entity_tf.get(eb["entity_name"], 0.0)) + weight = max(w, 0.01) + else: + weight = self._relationship_strength + + # Keywords for the edge: the two entity names. + edge_record = dict( + src_id=pair[0], + tgt_id=pair[1], + weight=weight, + description="", #desc, + keywords=[ea["entity_name"], eb["entity_name"]], + source_id=chunk_key, + ) + maybe_edges[pair].append(edge_record) + + token_count = len(doc) + out_results.append((dict(maybe_nodes), dict(maybe_edges), token_count)) + if self.callback: + self.callback( + 0.5 + 0.1 * len(out_results) / num_chunks, + msg=f"[spacy] Entities extraction of chunk {chunk_seq} " + f"{len(out_results)}/{num_chunks} done, " + f"{len(maybe_nodes)} nodes, {len(maybe_edges)} edges, " + f"{token_count} tokens.", + ) + + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + + @property + def _entity_types_set(self) -> set[str]: + return {t.lower() for t in self._entity_types} + + @staticmethod + def _infer_type_from_pos(doc, keyword: str) -> str: + """Infer an application-level entity type from POS tags when the + keyword was found by the stacking algorithm but not by NER.""" + kw_upper = keyword.upper() + for token in doc: + if token.text.upper() == kw_upper or token.text.upper().startswith(kw_upper.split()[0]): + if token.pos_ == "PROPN": + return "person" + if token.pos_ == "NOUN": + return "category" + if token.pos_ == "NUM": + return "event" + break + # Fallback: check for uppercase → likely a named entity. + if _has_uppercase(keyword): + return "person" + return "category" + + @staticmethod + def _keyword_sent_idx(doc, keyword: str) -> int: + """Return the sentence index that contains *keyword*.""" + kw_lower = keyword.lower() + for i, sent in enumerate(doc.sents): + if kw_lower in sent.text.lower(): + return i + return 0 + + @staticmethod + def _keyword_sent_text(doc, keyword: str) -> str | None: + """Return the sentence text containing *keyword* (LinearRAG semantic bridging).""" + kw_lower = keyword.lower() + for sent in doc.sents: + if kw_lower in sent.text.lower(): + return sent.text.strip() + return None + + @staticmethod + def _cooccurrence_description(doc, head_name: str, tail_name: str) -> str: + """Derive a relationship description using sentence co-occurrence + (LinearRAG) with dependency-path enhancement as fallback. + + If both entities appear in the same sentence, that sentence is + used as the description (semantic bridging). Otherwise, try to + find a lowest common ancestor in the dependency tree. As a last + resort, return a generic statement. + """ + head_lower = head_name.lower() + tail_lower = tail_name.lower() + + # Primary: shared sentence text (LinearRAG semantic bridging). + for sent in doc.sents: + sent_lower = sent.text.lower() + if head_lower in sent_lower and tail_lower in sent_lower: + return sent.text.strip() + + # Fallback: dependency path via LCA. + head_tok = GraphExtractor._find_token_by_text(doc, head_name) + tail_tok = GraphExtractor._find_token_by_text(doc, tail_name) + if head_tok is not None and tail_tok is not None: + path_head = list(GraphExtractor._ancestor_path(head_tok)) + path_tail = list(GraphExtractor._ancestor_path(tail_tok)) + lca = None + for h in path_head: + for t in path_tail: + if h == t: + lca = h + break + if lca is not None: + break + if lca is not None and lca is not head_tok and lca is not tail_tok: + return f"{head_name} is related to {tail_name} via '{lca.lemma_}'" + + # Final fallback: nearby sentences. + head_sent = GraphExtractor._find_sent_for_text(doc, head_lower) + if head_sent is not None: + return head_sent.text.strip() + + return f"{head_name} is related to {tail_name}" + + @staticmethod + def _find_token_by_text(doc, ent_name: str): + """Return the head token of the first spaCy entity matching *ent_name*.""" + target = ent_name.upper() + for ent in doc.ents: + if ent.text.strip().upper() == target: + return ent.root + # Fallback: token-level match for keywords not in doc.ents. + for token in doc: + if token.text.strip().upper() == target: + return token + return None + + @staticmethod + def _find_sent_for_text(doc, text_lower: str): + """Return the first ``Span`` whose text contains *text_lower*.""" + for sent in doc.sents: + if text_lower in sent.text.lower(): + return sent + return None + + @staticmethod + def _ancestor_path(token): + """Yield *token* then each ancestor up to the root.""" + yield token + for anc in token.ancestors: + yield anc diff --git a/rag/llm/__init__.py b/rag/llm/__init__.py index 8d6db359ce6..4e30c9f91f0 100644 --- a/rag/llm/__init__.py +++ b/rag/llm/__init__.py @@ -19,7 +19,7 @@ import importlib import inspect -from strenum import StrEnum +from enum import StrEnum class SupportedLiteLLMProvider(StrEnum): diff --git a/rag/llm/chat_model.py b/rag/llm/chat_model.py index 45b81a6cc71..1d9612dff3f 100644 --- a/rag/llm/chat_model.py +++ b/rag/llm/chat_model.py @@ -28,7 +28,7 @@ import litellm import openai from openai import AsyncOpenAI, OpenAI -from strenum import StrEnum +from enum import StrEnum from common.misc_utils import thread_pool_exec from common.token_utils import num_tokens_from_string, total_token_count_from_response @@ -391,6 +391,10 @@ async def _exec_tool(tc): name = tc.function.name try: args = json_repair.loads(tc.function.arguments) + if not isinstance(args, dict): + raise TypeError( + f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}" + ) if hasattr(self.toolcall_session, "tool_call_async"): result = await self.toolcall_session.tool_call_async(name, args) else: @@ -493,6 +497,10 @@ async def _exec_tool(tc): name = tc.function.name try: args = json_repair.loads(tc.function.arguments) + if not isinstance(args, dict): + raise TypeError( + f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}" + ) if hasattr(self.toolcall_session, "tool_call_async"): result = await self.toolcall_session.tool_call_async(name, args) else: @@ -1495,7 +1503,11 @@ async def _exceptions_async(self, e, attempt): return msg def _verbose_tool_use(self, name, args, res): - return "" + json.dumps({"name": name, "args": args, "result": res}, ensure_ascii=False, indent=2) + "" + return "" + json.dumps( + {"name": name, "args": args, "result": str(res) if isinstance(res, Exception) else res}, + ensure_ascii=False, + indent=2, + ) + "" def _append_history(self, hist, tool_call, tool_res, reasoning_content=None): assistant_msg = { @@ -1604,6 +1616,8 @@ async def _exec_tool(tc): name = tc.function.name try: args = json_repair.loads(tc.function.arguments) + if not isinstance(args, dict): + raise TypeError(f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}") if hasattr(self.toolcall_session, "tool_call_async"): result = await self.toolcall_session.tool_call_async(name, args) else: @@ -1720,6 +1734,8 @@ async def _exec_tool(tc): name = tc.function.name try: args = json_repair.loads(tc.function.arguments) + if not isinstance(args, dict): + raise TypeError(f"Tool arguments for {name} must be a JSON object, got {type(args).__name__}") if hasattr(self.toolcall_session, "tool_call_async"): result = await self.toolcall_session.tool_call_async(name, args) else: diff --git a/rag/llm/cv_model.py b/rag/llm/cv_model.py index 6c3e6e7a1ef..728f1677d2d 100644 --- a/rag/llm/cv_model.py +++ b/rag/llm/cv_model.py @@ -446,6 +446,7 @@ def _request(self, msg, stream, gen_conf=None): "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json", }, + timeout=60, ) return response.json() @@ -1029,6 +1030,7 @@ def describe(self, image): "Authorization": f"Bearer {self.key}", }, json={"messages": self.prompt(b64)}, + timeout=60, ) response = response.json() return ( @@ -1046,6 +1048,7 @@ def _request(self, msg, gen_conf=None): "Authorization": f"Bearer {self.key}", }, json={"messages": msg, **gen_conf}, + timeout=60, ) return response.json() @@ -1276,14 +1279,67 @@ class RAGconCV(GptV4): _FACTORY_NAME = "RAGcon" def __init__(self, key, model_name, lang="Chinese", base_url="", **kwargs): - + if not base_url: base_url = "https://connect.ragcon.com/v1" - + # Initialize client self.client = OpenAI(api_key=key, base_url=base_url) self.async_client = AsyncOpenAI(api_key=key, base_url=base_url) self.model_name = model_name self.lang = lang - - Base.__init__(self, **kwargs) \ No newline at end of file + + Base.__init__(self, **kwargs) + + +class BedrockCV(Base): + _FACTORY_NAME = "Bedrock" + + def __init__(self, key, model_name, lang="Chinese", **kwargs): + self.model_name = f"bedrock/{model_name}" + self.lang = lang + self._parse_credentials(key) + Base.__init__(self, **kwargs) + + def _parse_credentials(self, key): + bedrock_key = json.loads(key) + self.auth_mode = bedrock_key.get("auth_mode", "") + self.aws_region = bedrock_key.get("bedrock_region", "us-east-1") + self.aws_ak = bedrock_key.get("bedrock_ak", "") + self.aws_sk = bedrock_key.get("bedrock_sk", "") + self.aws_role_arn = bedrock_key.get("aws_role_arn", "") + + def _get_aws_creds(self): + if self.auth_mode == "access_key_secret": + return { + "aws_region_name": self.aws_region, + "aws_access_key_id": self.aws_ak, + "aws_secret_access_key": self.aws_sk, + } + elif self.auth_mode == "iam_role": + import boto3 + sts_client = boto3.client("sts", region_name=self.aws_region) + resp = sts_client.assume_role(RoleArn=self.aws_role_arn, RoleSessionName="BedrockCVSession") + creds = resp["Credentials"] + return { + "aws_region_name": self.aws_region, + "aws_access_key_id": creds["AccessKeyId"], + "aws_secret_access_key": creds["SecretAccessKey"], + "aws_session_token": creds["SessionToken"], + } + else: + return {"aws_region_name": self.aws_region} + + def describe_with_prompt(self, image, prompt=None): + import litellm + b64 = self.image2base64(image) + messages = self.vision_llm_prompt(b64, prompt) + res = litellm.completion( + model=self.model_name, + messages=messages, + **self._get_aws_creds(), + ) + return res.choices[0].message.content.strip(), total_token_count_from_response(res) + + def describe(self, image): + return self.describe_with_prompt(image) \ No newline at end of file diff --git a/rag/llm/embedding_model.py b/rag/llm/embedding_model.py index 9fe1095527b..ccaa8339010 100644 --- a/rag/llm/embedding_model.py +++ b/rag/llm/embedding_model.py @@ -17,6 +17,7 @@ import os import threading from abc import ABC +from contextlib import contextmanager from urllib.parse import urljoin import dashscope @@ -32,6 +33,76 @@ import logging import base64 +logger = logging.getLogger(__name__) + + +def _dashscope_base_url_for_log(base_url: str) -> str: + """Log host/path only (no query string) so secrets in URLs are not printed.""" + return base_url.split("?", 1)[0].strip()[:256] + + +def _dashscope_native_http_api_url(base_url: str | None) -> str | None: + """ + Resolve the DashScope *native* HTTP API root for Tongyi-Qianwen (Qwen) text embeddings. + + RAGFlow often stores an OpenAI-compatible base URL (e.g. ``.../compatible-mode/v1``) for + the same provider. The ``dashscope`` Python SDK used by ``TextEmbedding.call`` does *not* + use that path; it expects ``https:///api/v1`` instead. + + Users outside mainland China are directed to the international endpoint + (``dashscope-intl.aliyuncs.com``); domestic traffic uses ``dashscope.aliyuncs.com``. + When ``base_url`` already points at the native API root (ends with ``/api/v1``), it is + returned unchanged so custom or regional deployments keep working. + """ + if not base_url: + return None + u = base_url.strip().rstrip("/") + safe = _dashscope_base_url_for_log(u) + if u.endswith("/api/v1"): + logger.debug("DashScope Tongyi-Qianwen embedding: using native API base as configured (%s)", safe) + return u + # International (Singapore) DashScope — required for overseas Tongyi-Qianwen accounts. + if "dashscope-intl.aliyuncs.com" in u: + resolved = "https://dashscope-intl.aliyuncs.com/api/v1" + logger.info( + "DashScope Tongyi-Qianwen embedding: mapped configured base_url to intl native API (%s -> %s)", + safe, + resolved, + ) + return resolved + # China mainland DashScope default host. + if "dashscope.aliyuncs.com" in u: + resolved = "https://dashscope.aliyuncs.com/api/v1" + logger.info( + "DashScope Tongyi-Qianwen embedding: mapped configured base_url to CN native API (%s -> %s)", + safe, + resolved, + ) + return resolved + logger.warning( + "DashScope Tongyi-Qianwen embedding: base_url is set but not recognized as a DashScope host; " + "using SDK default endpoint (%s)", + safe, + ) + return None + + +@contextmanager +def _dashscope_native_api_url_scope(url: str | None): + """ + Temporarily set ``dashscope.base_http_api_url`` for the duration of a single SDK call, + then restore the previous value. Narrows the window where concurrent threads see a mismatch. + """ + if not url: + yield + return + prev = getattr(dashscope, "base_http_api_url", None) + dashscope.base_http_api_url = url + try: + yield + finally: + dashscope.base_http_api_url = prev + class Base(ABC): def __init__(self, key, model_name, **kwargs): @@ -197,11 +268,21 @@ def __init__(self, key, model_name="Baichuan-Text-Embedding", base_url="https:// class QWenEmbed(Base): + """ + Embeddings for Alibaba Tongyi-Qianwen via the DashScope ``TextEmbedding`` API. + + ``base_url`` comes from the user's embedding-model configuration (often the same host + as the OpenAI-compatible chat endpoint). This class maps known DashScope hosts to the + native ``/api/v1`` base URL so international and China endpoints both work. + """ + _FACTORY_NAME = "Tongyi-Qianwen" - def __init__(self, key, model_name="text_embedding_v2", **kwargs): + def __init__(self, key, model_name="text_embedding_v2", base_url=None, **kwargs): self.key = key self.model_name = model_name + # Native API root for the SDK; None if base_url is absent or not a known DashScope host. + self._dashscope_http_api_url = _dashscope_native_http_api_url(base_url) def encode(self, texts: list): import time @@ -214,10 +295,12 @@ def encode(self, texts: list): texts = [truncate(t, 2048) for t in texts] for i in range(0, len(texts), batch_size): retry_max = 5 - resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document") + with _dashscope_native_api_url_scope(self._dashscope_http_api_url): + resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document") while (resp["output"] is None or resp["output"].get("embeddings") is None) and retry_max > 0: time.sleep(10) - resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document") + with _dashscope_native_api_url_scope(self._dashscope_http_api_url): + resp = dashscope.TextEmbedding.call(model=self.model_name, input=texts[i : i + batch_size], api_key=self.key, text_type="document") retry_max -= 1 if retry_max == 0 and (resp["output"] is None or resp["output"].get("embeddings") is None): if resp.get("message"): @@ -237,7 +320,8 @@ def encode(self, texts: list): return np.array(res), token_count def encode_queries(self, text): - resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query") + with _dashscope_native_api_url_scope(self._dashscope_http_api_url): + resp = dashscope.TextEmbedding.call(model=self.model_name, input=text[:2048], api_key=self.key, text_type="query") try: return np.array(resp["output"]["embeddings"][0]["embedding"]), total_token_count_from_response(resp) except Exception as _e: @@ -409,7 +493,7 @@ def encode(self, texts: list[str | bytes], task="retrieval.passage"): data["task"] = task data["truncate"] = True - response = requests.post(self.base_url, headers=self.headers, json=data) + response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30) try: res = response.json() for d in res["data"]: @@ -687,7 +771,7 @@ def encode(self, texts: list): "encoding_format": "float", "truncate": "END", } - response = requests.post(self.base_url, headers=self.headers, json=payload) + response = requests.post(self.base_url, headers=self.headers, json=payload, timeout=30) try: res = response.json() ress.extend([d["embedding"] for d in res["data"]]) @@ -827,7 +911,7 @@ def encode(self, texts: list): "input": texts_batch, "encoding_format": "float", } - response = requests.post(self.base_url, json=payload, headers=self.headers) + response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30) try: res = response.json() ress.extend([d["embedding"] for d in res["data"]]) @@ -844,7 +928,7 @@ def encode_queries(self, text): "input": text, "encoding_format": "float", } - response = requests.post(self.base_url, json=payload, headers=self.headers) + response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30) try: res = response.json() return np.array(res["data"][0]["embedding"]), total_token_count_from_response(res) @@ -954,7 +1038,7 @@ def __init__(self, key, model_name, base_url=None, **kwargs): self.base_url = base_url or "http://127.0.0.1:8080" def encode(self, texts: list): - response = requests.post(f"{self.base_url}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"}) + response = requests.post(f"{self.base_url}/embed", json={"inputs": texts}, headers={"Content-Type": "application/json"}, timeout=30) if response.status_code == 200: embeddings = response.json() else: @@ -962,7 +1046,7 @@ def encode(self, texts: list): return np.array(embeddings), sum([num_tokens_from_string(text) for text in texts]) def encode_queries(self, text: str): - response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}) + response = requests.post(f"{self.base_url}/embed", json={"inputs": text}, headers={"Content-Type": "application/json"}, timeout=30) if response.status_code == 200: embedding = response.json()[0] return np.array(embedding), num_tokens_from_string(text) @@ -1163,7 +1247,7 @@ def encode(self, texts: list): "input": [[chunk] for chunk in batch], "encoding_format": "base64_int8", } - response = requests.post(url, headers=self.headers, json=payload) + response = requests.post(url, headers=self.headers, json=payload, timeout=30) try: res = response.json() for doc in res["data"]: @@ -1182,7 +1266,7 @@ def encode(self, texts: list): "input": batch, "encoding_format": "base64_int8", } - response = requests.post(url, headers=self.headers, json=payload) + response = requests.post(url, headers=self.headers, json=payload, timeout=30) try: res = response.json() for d in res["data"]: diff --git a/rag/llm/rerank_model.py b/rag/llm/rerank_model.py index ed569d6bdcf..99801e00a78 100644 --- a/rag/llm/rerank_model.py +++ b/rag/llm/rerank_model.py @@ -17,8 +17,9 @@ import logging from abc import ABC from urllib.parse import urljoin +from typing import Tuple, List +from http import HTTPStatus -import httpx import numpy as np import requests from yarl import URL @@ -28,21 +29,15 @@ class Base(ABC): def __init__(self, key, model_name, **kwargs): - """ - Abstract base class constructor. - Parameters are not stored; initialization is left to subclasses. - """ pass - def similarity(self, query: str, texts: list): + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: raise NotImplementedError("Please implement encode method!") @staticmethod def _normalize_rank(rank: np.ndarray) -> np.ndarray: - """ - Normalize rank values to the range 0 to 1. - Avoids division by zero if all ranks are identical. - """ + if rank.size == 0: + return rank min_rank = np.min(rank) max_rank = np.max(rank) @@ -58,17 +53,21 @@ class JinaRerank(Base): _FACTORY_NAME = "Jina" def __init__(self, key, model_name="jina-reranker-v2-base-multilingual", base_url="https://api.jina.ai/v1/rerank"): - self.base_url = "https://api.jina.ai/v1/rerank" + self.base_url = base_url or "https://api.jina.ai/v1/rerank" self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"} self.model_name = model_name - def similarity(self, query: str, texts: list): + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + if not query or not texts: + return np.zeros(len(texts) if texts else 0, dtype=float), 0 texts = [truncate(t, 8196) for t in texts] data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)} - res = requests.post(self.base_url, headers=self.headers, json=data).json() + response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30) + response.raise_for_status() + res = response.json() rank = np.zeros(len(texts), dtype=float) try: - for d in res["results"]: + for d in res.get("results", []): rank[d["index"]] = d["relevance_score"] except Exception as _e: log_exception(_e, res) @@ -89,18 +88,20 @@ def __init__(self, key="x", model_name="", base_url=""): if key and key != "x": self.headers["Authorization"] = f"Bearer {key}" - def similarity(self, query: str, texts: list): - if len(texts) == 0: - return np.array([]), 0 + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + if not query or not texts: + return np.zeros(len(texts) if texts else 0, dtype=float), 0 pairs = [(query, truncate(t, 4096)) for t in texts] token_count = 0 for _, t in pairs: token_count += num_tokens_from_string(t) data = {"model": self.model_name, "query": query, "return_documents": "true", "return_len": "true", "documents": texts} - res = requests.post(self.base_url, headers=self.headers, json=data).json() + response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30) + response.raise_for_status() + res = response.json() rank = np.zeros(len(texts), dtype=float) try: - for d in res["results"]: + for d in res.get("results", []): rank[d["index"]] = d["relevance_score"] except Exception as _e: log_exception(_e, res) @@ -118,8 +119,9 @@ def __init__(self, key, model_name, base_url): self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"} self.model_name = model_name.split("___")[0] - def similarity(self, query: str, texts: list): - # noway to config Ragflow , use fix setting + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + if not query or not texts: + return np.zeros(len(texts), dtype=float), 0 texts = [truncate(t, 500) for t in texts] data = { "model": self.model_name, @@ -130,16 +132,17 @@ def similarity(self, query: str, texts: list): token_count = 0 for t in texts: token_count += num_tokens_from_string(t) - res = requests.post(self.base_url, headers=self.headers, json=data).json() + response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30) + response.raise_for_status() + res = response.json() rank = np.zeros(len(texts), dtype=float) try: - for d in res["results"]: + for d in res.get("results", []): rank[d["index"]] = d["relevance_score"] except Exception as _e: log_exception(_e, res) rank = Base._normalize_rank(rank) - return rank, token_count @@ -164,7 +167,9 @@ def __init__(self, key, model_name, base_url="https://ai.api.nvidia.com/v1/retri "Authorization": f"Bearer {key}", } - def similarity(self, query: str, texts: list): + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + if not query or not texts: + return np.zeros(len(texts), dtype=float), 0 token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts]) data = { "model": self.model_name, @@ -173,10 +178,12 @@ def similarity(self, query: str, texts: list): "truncate": "END", "top_n": len(texts), } - res = requests.post(self.base_url, headers=self.headers, json=data).json() + response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30) + response.raise_for_status() + res = response.json() rank = np.zeros(len(texts), dtype=float) try: - for d in res["rankings"]: + for d in res.get("rankings", []): rank[d["index"]] = d["logit"] except Exception as _e: log_exception(_e, res) @@ -189,8 +196,8 @@ class LmStudioRerank(Base): def __init__(self, key, model_name, base_url, **kwargs): pass - def similarity(self, query: str, texts: list): - raise NotImplementedError("The LmStudioRerank has not been implement") + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + raise NotImplementedError("The LmStudioRerank has not been implemented") class OpenAI_APIRerank(Base): @@ -205,8 +212,9 @@ def __init__(self, key, model_name, base_url): self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"} self.model_name = model_name.split("___")[0] - def similarity(self, query: str, texts: list): - # noway to config Ragflow , use fix setting + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + if not query or not texts: + return np.zeros(len(texts), dtype=float), 0 texts = [truncate(t, 500) for t in texts] data = { "model": self.model_name, @@ -217,16 +225,17 @@ def similarity(self, query: str, texts: list): token_count = 0 for t in texts: token_count += num_tokens_from_string(t) - res = requests.post(self.base_url, headers=self.headers, json=data).json() + response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30) + response.raise_for_status() + res = response.json() rank = np.zeros(len(texts), dtype=float) try: - for d in res["results"]: + for d in res.get("results", []): rank[d["index"]] = d["relevance_score"] except Exception as _e: log_exception(_e, res) rank = Base._normalize_rank(rank) - return rank, token_count @@ -236,14 +245,15 @@ class CoHereRerank(Base): def __init__(self, key, model_name, base_url=None): from cohere import Client - # Only pass base_url if it's a non-empty string, otherwise use default Cohere API endpoint - client_kwargs = {"api_key": key} + client_kwargs = {"api_key": key, "timeout": 30.0} if base_url and base_url.strip(): client_kwargs["base_url"] = base_url self.client = Client(**client_kwargs) self.model_name = model_name.split("___")[0] - def similarity(self, query: str, texts: list): + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + if not query or not texts: + return np.zeros(len(texts), dtype=float), 0 token_count = num_tokens_from_string(query) + sum([num_tokens_from_string(t) for t in texts]) res = self.client.rerank( model=self.model_name, @@ -267,8 +277,8 @@ class TogetherAIRerank(Base): def __init__(self, key, model_name, base_url, **kwargs): pass - def similarity(self, query: str, texts: list): - raise NotImplementedError("The api has not been implement") + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + raise NotImplementedError("The api has not been implemented") class SILICONFLOWRerank(Base): @@ -288,7 +298,9 @@ def __init__(self, key, model_name, base_url="https://api.siliconflow.cn/v1/rera "authorization": f"Bearer {key}", } - def similarity(self, query: str, texts: list): + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + if not query or not texts: + return np.zeros(len(texts), dtype=float), 0 payload = { "model": self.model_name, "query": query, @@ -298,18 +310,16 @@ def similarity(self, query: str, texts: list): "max_chunks_per_doc": 1024, "overlap_tokens": 80, } - response_raw = requests.post(self.base_url, json=payload, headers=self.headers) - response = response_raw.json() + response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30) + response.raise_for_status() + res = response.json() rank = np.zeros(len(texts), dtype=float) try: - for d in response["results"]: + for d in res.get("results", []): rank[d["index"]] = d["relevance_score"] except Exception as _e: log_exception(_e, response) - return ( - rank, - total_token_count_from_response(response), - ) + return rank, total_token_count_from_response(res) class BaiduYiyanRerank(Base): @@ -321,10 +331,12 @@ def __init__(self, key, model_name, base_url=None): key = json.loads(key) ak = key.get("yiyan_ak", "") sk = key.get("yiyan_sk", "") - self.client = Reranker(ak=ak, sk=sk) + self.client = Reranker(ak=ak, sk=sk, request_timeout=30) self.model_name = model_name - def similarity(self, query: str, texts: list): + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + if not query or not texts: + return np.zeros(len(texts), dtype=float), 0 res = self.client.do( model=self.model_name, query=query, @@ -333,7 +345,7 @@ def similarity(self, query: str, texts: list): ).body rank = np.zeros(len(texts), dtype=float) try: - for d in res["results"]: + for d in res.get("results", []): rank[d["index"]] = d["relevance_score"] except Exception as _e: log_exception(_e, res) @@ -346,12 +358,12 @@ class VoyageRerank(Base): def __init__(self, key, model_name, base_url=None): import voyageai - self.client = voyageai.Client(api_key=key) + self.client = voyageai.Client(api_key=key, timeout=30.0) self.model_name = model_name - def similarity(self, query: str, texts: list): - if not texts: - return np.array([]), 0 + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + if not query or not texts: + return np.zeros(len(texts) if texts else 0, dtype=float), 0 rank = np.zeros(len(texts), dtype=float) res = self.client.rerank(query=query, documents=texts, model=self.model_name, top_k=len(texts)) @@ -368,28 +380,31 @@ class QWenRerank(Base): def __init__(self, key, model_name="gte-rerank", **kwargs): import dashscope - self.api_key = key self.model_name = dashscope.TextReRank.Models.gte_rerank if model_name is None else model_name + # Remove invalid global timeout, use official SDK per-request timeout parameter + self.request_timeout = 30.0 - def similarity(self, query: str, texts: list): - from http import HTTPStatus - + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + if not query or not texts: + return np.zeros(len(texts), dtype=float), 0 + import dashscope - # Build call parameters - call_kwargs = { - "api_key": self.api_key, - "model": self.model_name, - "query": query, - "documents": texts, - "top_n": len(texts) - } - # qwen3-rerank does not support return_documents parameter - if not self.model_name.startswith("qwen3-rerank"): - call_kwargs["return_documents"] = False - - resp = dashscope.TextReRank.call(**call_kwargs) + # Pass official request_timeout parameter to both API call branches + if self.model_name.startswith("qwen3-rerank"): + resp = dashscope.TextReRank.call( + api_key=self.api_key, model=self.model_name, + query=query, documents=texts, top_n=len(texts), + request_timeout=self.request_timeout + ) + else: + resp = dashscope.TextReRank.call( + api_key=self.api_key, model=self.model_name, + query=query, documents=texts, + top_n=len(texts), return_documents=False, + request_timeout=self.request_timeout + ) rank = np.zeros(len(texts), dtype=float) if resp.status_code == HTTPStatus.OK: @@ -407,16 +422,26 @@ class HuggingfaceRerank(Base): _FACTORY_NAME = "HuggingFace" @staticmethod - def post(query: str, texts: list, url="127.0.0.1"): + def post(query: str, texts: list, url: str = "http://127.0.0.1"): exc = None scores = [0 for _ in range(len(texts))] batch_size = 8 + # FIX: Robust URL construction to avoid duplicate "/rerank" path suffix + base_url = url.rstrip("/") + if not base_url.startswith(("http://", "https://")): + base_url = f"http://{base_url}" + # Only append "/rerank" when endpoint does not already end with it + endpoint = base_url if base_url.endswith("/rerank") else f"{base_url}/rerank" + for i in range(0, len(texts), batch_size): try: + # Fix: Add request timeout res = requests.post( - f"http://{url}/rerank", headers={"Content-Type": "application/json"}, json={"query": query, "texts": texts[i : i + batch_size], "raw_scores": False, "truncate": True} + endpoint, headers={"Content-Type": "application/json"}, + json={"query": query, "texts": texts[i:i+batch_size], "raw_scores": False, "truncate": True}, + timeout=30 ) - + res.raise_for_status() for o in res.json(): scores[o["index"] + i] = o["score"] except Exception as e: @@ -430,9 +455,9 @@ def __init__(self, key, model_name="BAAI/bge-reranker-v2-m3", base_url="http://1 self.model_name = model_name.split("___")[0] self.base_url = base_url - def similarity(self, query: str, texts: list) -> tuple[np.ndarray, int]: - if not texts: - return np.array([]), 0 + def similarity(self, query: str, texts: List) -> tuple[np.ndarray, int]: + if not query or not texts: + return np.zeros(len(texts), dtype=float), 0 token_count = 0 for t in texts: token_count += num_tokens_from_string(t) @@ -454,7 +479,10 @@ def __init__(self, key, model_name, base_url): "authorization": f"Bearer {key}", } - def similarity(self, query: str, texts: list): + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + if not query or not texts: + return np.zeros(len(texts), dtype=float), 0 + payload = { "model": self.model_name, "query": query, @@ -463,28 +491,22 @@ def similarity(self, query: str, texts: list): } try: - response = requests.post(self.base_url, json=payload, headers=self.headers) + response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=30) response.raise_for_status() response_json = response.json() rank = np.zeros(len(texts), dtype=float) - - token_count = 0 - for t in texts: - token_count += num_tokens_from_string(t) + token_count = sum(num_tokens_from_string(t) for t in texts) try: - for result in response_json["results"]: + for result in response_json.get("results", []): rank[result["index"]] = result["relevance_score"] except Exception as _e: log_exception(_e, response) - return ( - rank, - token_count, - ) + return (rank, token_count) - except httpx.HTTPStatusError as e: - raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {e.response.status_code} - {e.response.text}") + except requests.exceptions.RequestException as e: + raise ValueError(f"Error calling GPUStackRerank model {self.model_name}: {str(e)}") from e class NovitaRerank(JinaRerank): @@ -509,9 +531,25 @@ class Ai302Rerank(Base): _FACTORY_NAME = "302.AI" def __init__(self, key, model_name, base_url="https://api.302.ai/v1/rerank"): - if not base_url: - base_url = "https://api.302.ai/v1/rerank" - super().__init__(key, model_name, base_url) + self.base_url = base_url or "https://api.302.ai/v1/rerank" + self.headers = {"Content-Type": "application/json", "Authorization": f"Bearer {key}"} + self.model_name = model_name + + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + if not query or not texts: + return np.zeros(len(texts), dtype=float), 0 + texts = [truncate(t, 500) for t in texts] + data = {"model": self.model_name, "query": query, "documents": texts, "top_n": len(texts)} + response = requests.post(self.base_url, headers=self.headers, json=data, timeout=30) + response.raise_for_status() + res = response.json() + rank = np.zeros(len(texts), dtype=float) + try: + for d in res.get("results", []): + rank[d["index"]] = d["relevance_score"] + except Exception as _e: + log_exception(_e, res) + return rank, total_token_count_from_response(res) class JiekouAIRerank(JinaRerank): @@ -534,12 +572,6 @@ def __init__(self, key, model_name, base_url="https://futurmix.ai/v1/rerank"): class RAGconRerank(Base): - """ - RAGcon Rerank Provider - routes through LiteLLM proxy - - Assumes LiteLLM proxy supports /rerank endpoint. - Default Base URL: https://connect.ragcon.ai/v1 - """ _FACTORY_NAME = "RAGcon" def __init__(self, key, model_name, base_url=None, **kwargs): @@ -553,8 +585,10 @@ def __init__(self, key, model_name, base_url=None, **kwargs): self.model_name = model_name - def similarity(self, query: str, texts: list): - # noway to config Ragflow , use fix setting + def similarity(self, query: str, texts: List) -> Tuple[np.ndarray, int]: + if not query or not texts: + return np.zeros(len(texts), dtype=float), 0 + texts = [truncate(t, 500) for t in texts] data = { "model": self.model_name, @@ -562,17 +596,16 @@ def similarity(self, query: str, texts: list): "documents": texts, "top_n": len(texts), } - token_count = 0 - for t in texts: - token_count += num_tokens_from_string(t) - res = requests.post(self._base_url + "/rerank", headers=self.headers, json=data).json() + token_count = sum(num_tokens_from_string(t) for t in texts) + response = requests.post(self._base_url + "/rerank", headers=self.headers, json=data, timeout=30) + response.raise_for_status() + res = response.json() rank = np.zeros(len(texts), dtype=float) try: - for d in res["results"]: + for d in res.get("results", []): rank[d["index"]] = d["relevance_score"] except Exception as _e: log_exception(_e, res) rank = Base._normalize_rank(rank) - return rank, token_count diff --git a/rag/llm/sequence2txt_model.py b/rag/llm/sequence2txt_model.py index 563dd47fc14..4624a2911ad 100644 --- a/rag/llm/sequence2txt_model.py +++ b/rag/llm/sequence2txt_model.py @@ -195,7 +195,7 @@ def transcription(self, audio, language="zh", prompt=None, response_format="json files = {"file": (audio_file_name, audio_data, "audio/wav")} try: - response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload) + response = requests.post(f"{self.base_url}/v1/audio/transcriptions", files=files, data=payload, timeout=60) response.raise_for_status() result = response.json() @@ -377,6 +377,7 @@ def transcription(self, audio_path): data=payload, files=files, headers=headers, + timeout=60, ) body = response.json() if response.status_code == 200: diff --git a/rag/llm/tts_model.py b/rag/llm/tts_model.py index 94a81ceba2a..f37cd89c253 100644 --- a/rag/llm/tts_model.py +++ b/rag/llm/tts_model.py @@ -116,7 +116,8 @@ def _send_request(self, endpoint, payload, stream=True): url, headers=self.headers, json=payload, - stream=stream + stream=stream, + timeout=60, ) if response.status_code != 200: @@ -532,7 +533,8 @@ def tts(self, text, voice="English Female", stream=True): f"{self.base_url}/audio/speech", headers=self.headers, json=payload, - stream=stream + stream=stream, + timeout=60, ) if response.status_code != 200: diff --git a/rag/nlp/query.py b/rag/nlp/query.py index 2d50eea3431..db04eb37532 100644 --- a/rag/nlp/query.py +++ b/rag/nlp/query.py @@ -22,12 +22,13 @@ from common.query_base import QueryBase from common.doc_store.doc_store_base import MatchTextExpr from rag.nlp import rag_tokenizer, term_weight, synonym +from rag.utils.redis_conn import REDIS_CONN class FulltextQueryer(QueryBase): def __init__(self): self.tw = term_weight.Dealer() - self.syn = synonym.Dealer() + self.syn = synonym.Dealer(redis=REDIS_CONN.REDIS if REDIS_CONN.is_alive() else None) self.query_fields = [ "title_tks^10", "title_sm_tks^5", diff --git a/rag/nlp/search.py b/rag/nlp/search.py index 57b663400ef..e79671f04eb 100644 --- a/rag/nlp/search.py +++ b/rag/nlp/search.py @@ -180,7 +180,13 @@ async def search(self, req, idx_names: str | list[str], else: matchDense = await self.get_vector(qst, emb_mdl, topk, req.get("similarity", 0.1)) q_vec = matchDense.embedding_data - if not settings.DOC_ENGINE_INFINITY: + # ES path no longer fetches chunk vectors here. The clean + # cosine score is recovered later via a second KNN-only call + # in retrieval(); chunk vectors are fetched on demand for + # citations (see Dealer.fetch_chunk_vectors). OceanBase + # still relies on local rerank against chunk vectors, so + # keep pulling them for that backend. + if settings.DOC_ENGINE_OCEANBASE: src.append(f"q_{len(q_vec)}_vec") fusionExpr = FusionExpr("weighted_sum", topk, {"weights": "0.05,0.95"}) @@ -358,6 +364,113 @@ def _rank_feature_scores(self, query_rfea, search_res): rank_fea.append(nor / np.sqrt(denor) / q_denor) return np.array(rank_fea) * 10. + pageranks + async def _knn_scores(self, sres: "Dealer.SearchResult", + idx_names: str | list[str], + kb_ids: list[str]) -> dict[str, float]: + """ + Second-pass ES call that returns the cosine similarity between the + query embedding and each candidate chunk's embedding, filtered to the + chunk ids the original search already surfaced. We rely on ES to do + the vector math so the chunk vectors never leave the engine. + """ + if not sres.ids or not sres.query_vector: + return {} + dim = len(sres.query_vector) + matchDense = MatchDenseExpr( + f"q_{dim}_vec", + sres.query_vector, + "float", + "cosine", + len(sres.ids), + {"similarity": 0.0}, + ) + condition = {"id": list(sres.ids)} + res = await thread_pool_exec( + self.dataStore.search, + [], # no _source fields needed; we only want _id and _score + [], + condition, + [matchDense], + OrderByExpr(), + 0, + len(sres.ids), + idx_names, + kb_ids, + ) + return self.dataStore.get_scores(res) + + async def fetch_chunk_vectors(self, chunk_ids: list[str], + tenant_ids: str | list[str], + kb_ids: list[str], + dim: int) -> dict[str, list[float]]: + """ + Citation-time helper: fetch only the embedding vectors for an + explicit set of chunk ids. Used by callers that need to compute + answer-vs-chunk similarity locally (e.g. insert_citations) so the + main retrieval path can keep skipping vector transport. + """ + if not chunk_ids: + return {} + if isinstance(tenant_ids, str): + idx_names = [index_name(tid) for tid in tenant_ids.split(",")] + else: + idx_names = [index_name(tid) for tid in tenant_ids] + vec_field = f"q_{dim}_vec" + res = await thread_pool_exec( + self.dataStore.search, + [vec_field], + [], + {"id": list(chunk_ids)}, + [], + OrderByExpr(), + 0, + len(chunk_ids), + idx_names, + kb_ids, + ) + fields = self.dataStore.get_fields(res, [vec_field]) + out: dict[str, list[float]] = {} + zero = [0.0] * dim + for cid, doc in fields.items(): + v = doc.get(vec_field) + if isinstance(v, str): + v = [get_float(x) for x in v.split("\t")] + if not isinstance(v, list) or len(v) != dim: + v = zero + out[cid] = v + return out + + def rerank_with_knn(self, sres, query, knn_scores: dict[str, float], + tkweight=0.3, vtweight=0.7, + cfield="content_ltks", + rank_feature: dict | None = None): + """ + Merge ES-side KNN cosine similarity with locally computed term + similarity using the user-configured weights. Replaces the older + local-only rerank() for the ES path, which depended on shipping + chunk vectors back to the application. + """ + _, keywords = self.qryr.question(query) + + for i in sres.ids: + if isinstance(sres.field[i].get("important_kwd", []), str): + sres.field[i]["important_kwd"] = [sres.field[i]["important_kwd"]] + ins_tw = [] + for i in sres.ids: + content_ltks = list(OrderedDict.fromkeys(sres.field[i][cfield].split())) + title_tks = [t for t in sres.field[i].get("title_tks", "").split() if t] + question_tks = [t for t in sres.field[i].get("question_tks", "").split() if t] + important_kwd = sres.field[i].get("important_kwd", []) + tks = content_ltks + title_tks * 2 + important_kwd * 5 + question_tks * 6 + ins_tw.append(tks) + + tksim = np.array(self.qryr.token_similarity(keywords, ins_tw), dtype=np.float64) + vtsim = np.array([knn_scores.get(chunk_id, 0.0) for chunk_id in sres.ids], + dtype=np.float64) + rank_fea = self._rank_feature_scores(rank_feature, sres) + sim = tkweight * tksim + vtweight * vtsim + rank_fea + return sim, tksim, vtsim + def rerank(self, sres, query, tkweight=0.3, vtweight=0.7, cfield="content_ltks", rank_feature: dict | None = None @@ -491,7 +604,8 @@ async def retrieval( if isinstance(tenant_ids, str): tenant_ids = tenant_ids.split(",") - sres = await self.search(req, [index_name(tid) for tid in tenant_ids], kb_ids, embd_mdl, highlight, + idx_names = [index_name(tid) for tid in tenant_ids] + sres = await self.search(req, idx_names, kb_ids, embd_mdl, highlight, rank_feature=rank_feature) # Temporary retrieval-side guard: prune chunks whose parent document no # longer exists before reranking and returning results. @@ -516,8 +630,9 @@ async def retrieval( sim = [s if s is not None else 0.0 for s in sim] tsim = sim vsim = sim - else: - # ElasticSearch doesn't normalize each way score before fusion. + elif settings.DOC_ENGINE_OCEANBASE: + # OceanBase still returns chunk vectors in the result; use + # the historical local rerank that depends on them. sim, tsim, vsim = self.rerank( sres, question, @@ -525,6 +640,20 @@ async def retrieval( vector_similarity_weight, rank_feature=rank_feature, ) + else: + # ES path: ask ES for the clean cosine score via a second + # KNN-only call filtered by the candidate ids, then merge it + # with locally computed term similarity using the user's + # weight. Chunk vectors stay in the index. + knn_scores = await self._knn_scores(sres, idx_names, kb_ids) + sim, tsim, vsim = self.rerank_with_knn( + sres, + question, + knn_scores, + 1 - vector_similarity_weight, + vector_similarity_weight, + rank_feature=rank_feature, + ) sim_np = np.array(sim, dtype=np.float64) if sim_np.size == 0: @@ -559,6 +688,11 @@ async def retrieval( did = chunk.get("doc_id", "") position_int = chunk.get("position_int", []) + # Chunk vectors are no longer fetched during the main retrieval + # call. Fall back to whatever the chunk happens to carry (Infinity + # path) and otherwise emit a zero placeholder so the downstream + # shape stays stable. Citation callers refill this via + # Dealer.fetch_chunk_vectors when needed. d = { "chunk_id": id, "content_ltks": chunk["content_ltks"], @@ -619,7 +753,13 @@ def chunk_list(self, doc_id: str, tenant_id: str, kb_ids: list[str], max_count=1024, offset=0, fields=["docnm_kwd", "content_with_weight", "img_id"], - sort_by_position: bool = False): + sort_by_position: bool = False, + retrieve_all: bool = False): + """Return chunks for a document. + + By default, preserve the historical max_count cap. When retrieve_all is + True, keep paging until the doc store returns fewer rows than requested. + """ condition = {"doc_id": doc_id} fields_set = set(fields or []) @@ -637,8 +777,9 @@ def chunk_list(self, doc_id: str, tenant_id: str, res = [] bs = 128 - for p in range(offset, max_count, bs): - limit = min(bs, max_count - p) + p = offset + while retrieve_all or p < max_count: + limit = bs if retrieve_all else min(bs, max_count - p) if limit <= 0: break es_res = self.dataStore.search(fields, [], condition, [], orderBy, p, limit, index_name(tenant_id), @@ -651,6 +792,7 @@ def chunk_list(self, doc_id: str, tenant_id: str, chunk_count = len(dict_chunks) if chunk_count == 0 or chunk_count < limit: break + p += limit return res def all_tags(self, tenant_id: str, kb_ids: list[str], S=1000): @@ -781,6 +923,13 @@ def retrieval_by_children(self, chunks: list[dict], tenant_ids: list[str]): vector_size = 1024 for id, cks in mom_chunks.items(): chunk = self.dataStore.get(id, idx_nms[0], [ck["kb_id"] for ck in cks]) + if chunk is None: + logging.warning( + "Parent chunk '%s' not found in the index; falling back to %d child chunk(s).", + id, len(cks), + ) + chunks.extend(cks) + continue d = { "chunk_id": id, "content_ltks": " ".join([ck["content_ltks"] for ck in cks]), diff --git a/rag/prompts/generator.py b/rag/prompts/generator.py index ddf99251b57..fc4999dbe45 100644 --- a/rag/prompts/generator.py +++ b/rag/prompts/generator.py @@ -76,6 +76,10 @@ def count(): total += m["count"] return total + def trim_content(content, limit): + limit = max(0, limit) + return encoder.decode(encoder.encode(content)[:limit]) + c = count() if c < max_length: return c, msg @@ -90,16 +94,34 @@ def count(): ll = num_tokens_from_string(msg_[0]["content"]) ll2 = num_tokens_from_string(msg_[-1]["content"]) - if ll / (ll + ll2) > 0.8: - m = msg_[0]["content"] - m = encoder.decode(encoder.encode(m)[: max_length - ll2]) - msg[0]["content"] = m - return max_length, msg + total = ll + ll2 + if total <= 0: + logging.debug( + "message_fit_in degenerate token counts total=%s max_length=%s ll=%s ll2=%s preserved_roles=%s", + total, + max_length, + ll, + ll2, + [m.get("role") for m in msg], + ) + return 0, msg + + if len(msg) == 1: + msg[0]["content"] = trim_content(msg[0]["content"], max_length) + return count(), msg + + if ll / total > 0.8: + preserved_last = min(ll2, max_length) + msg[-1]["content"] = trim_content(msg_[-1]["content"], preserved_last) + remaining = max(0, max_length - preserved_last) + msg[0]["content"] = trim_content(msg_[0]["content"], remaining) + return count(), msg - m = msg_[-1]["content"] - m = encoder.decode(encoder.encode(m)[: max_length - ll2]) - msg[-1]["content"] = m - return max_length, msg + preserved_system = min(ll, max_length) + msg[0]["content"] = trim_content(msg_[0]["content"], preserved_system) + remaining = max(0, max_length - preserved_system) + msg[-1]["content"] = trim_content(msg_[-1]["content"], remaining) + return count(), msg def kb_prompt(kbinfos, max_tokens, hash_id=False): @@ -472,6 +494,28 @@ async def rank_memories_async(chat_mdl, goal: str, sub_goal: str, tool_call_summ async def gen_meta_filter(chat_mdl, meta_data: dict, query: str, constraints: dict = None) -> dict: + """Generate metadata filter conditions from a user query using an LLM. + + Args: + chat_mdl: LLM bundle for generating filters + meta_data: Dict of {key: set of values} - e.g. {"character": {"Caocao", "Liubei"}, "year": {2026}} + query: User question (e.g. "Caocao in 2026") + constraints: Optional dict of {key: operator} to constrain which op to use for a key + + Returns: + Dict with "logic" ("and"/"or") and "conditions" list. + Example return value: + { + "logic": "and", + "conditions": [ + {"key": "year", "value": "2026", "op": "="}, + {"key": "character", "value": "Caocao", "op": "="} + ] + } + + The LLM is prompted with the available metadata keys and values, and is asked to + generate filter conditions that match the user's query semantics. + """ meta_data_structure = {} for key, values in meta_data.items(): meta_data_structure[key] = list(values.keys()) if isinstance(values, dict) else values diff --git a/rag/raptor.py b/rag/raptor.py index e4017319b5b..a7f2c782d33 100644 --- a/rag/raptor.py +++ b/rag/raptor.py @@ -14,11 +14,13 @@ # limitations under the License. # import asyncio +from dataclasses import dataclass, field import logging import re import numpy as np import umap +from sklearn.cluster import AgglomerativeClustering from sklearn.mixture import GaussianMixture from api.db.services.task_service import has_canceled @@ -33,9 +35,127 @@ set_llm_cache, ) from common.misc_utils import thread_pool_exec +from rag.utils.raptor_utils import ( + AHC_CLUSTERING_METHOD, + GMM_CLUSTERING_METHOD, + PSI_TREE_BUILDER, + RAPTOR_TREE_BUILDER, + SUPPORTED_CLUSTERING_METHODS, + SUPPORTED_TREE_BUILDERS, +) + + +@dataclass +class _PsiTreeNode: + """Node used to represent the in-memory Psi merge tree.""" + + index: int + text: str = "" + embedding: np.ndarray | None = None + children: list["_PsiTreeNode"] = field(default_factory=list) + parent: "_PsiTreeNode | None" = None + + +class _PsiUnionFind: + """Build parent links for the Psi merge tree from ranked leaf pairs.""" + + def __init__(self, n: int): + """Initialize the union-find state for n leaf nodes.""" + self._rank = [0 for _ in range(n)] + self._parent_chains = [[] for _ in range(n)] + self._node_ids = [[i] for i in range(n)] + self._tree = [-1 for _ in range(max(1, 2 * n - 1))] + self._next_id = n + + @staticmethod + def _ordered_extend(target: list[int], values: list[int]): + """Append unseen values while preserving their original order.""" + for value in values: + if value not in target: + target.append(value) + + def _find(self, i: int) -> list[int]: + """Return the parent chain for a leaf, extending it lazily.""" + chain = self._parent_chains[i] + if not chain or (len(chain) == 1 and chain[0] == i): + return [i] + if chain[0] == i: + self._ordered_extend(chain, self._find(chain[1])) + else: + self._ordered_extend(chain, self._find(chain[0])) + return chain + + def _rank_bisect_right(self, chain: list[int], rank: int) -> int: + """Return the first chain index whose rank is greater than rank.""" + idx = 0 + while idx < len(chain) and self._rank[chain[idx]] <= rank: + idx += 1 + return idx + + def _build(self, i: int, j: int, insert_point: int | None = None): + """Record a merge edge in the compact parent array.""" + if insert_point is not None: + parent_ids = self._node_ids[insert_point] + parent_rank_idx = self._rank[i] + 1 + if parent_rank_idx >= len(parent_ids): + logging.warning( + "RAPTOR Psi union fallback: rank index %d is out of bounds for node %d with %d parent ids", + parent_rank_idx, + insert_point, + len(parent_ids), + ) + parent_rank_idx = len(parent_ids) - 1 + self._tree[self._node_ids[i][-1]] = parent_ids[parent_rank_idx] + return + self._tree[self._node_ids[i][-1]] = self._next_id + self._tree[self._node_ids[j][-1]] = self._next_id + self._node_ids[i].append(self._next_id) + self._next_id += 1 + + def union(self, i: int, j: int) -> bool: + """Merge two ranked leaves and return whether a new edge was added.""" + root_i = self._find(i)[-1] + root_j = self._find(j)[-1] + if root_i == root_j: + return False + + if self._rank[root_i] < self._rank[root_j]: + if not self._parent_chains[root_j]: + self._parent_chains[root_j].append(root_j) + chain = self._parent_chains[j] + higher_rank_idx = self._rank_bisect_right(chain, self._rank[root_i]) + if higher_rank_idx >= len(chain): + higher_rank_idx = len(chain) - 1 + insert_point = chain[higher_rank_idx] + self._ordered_extend(self._parent_chains[root_i], chain[higher_rank_idx:]) + self._build(root_i, root_j, insert_point=insert_point) + elif self._rank[root_i] > self._rank[root_j]: + if not self._parent_chains[root_i]: + self._parent_chains[root_i].append(root_i) + chain = self._parent_chains[i] + higher_rank_idx = self._rank_bisect_right(chain, self._rank[root_j]) + if higher_rank_idx >= len(chain): + higher_rank_idx = len(chain) - 1 + insert_point = chain[higher_rank_idx] + self._ordered_extend(self._parent_chains[root_j], chain[higher_rank_idx:]) + self._build(root_j, root_i, insert_point=insert_point) + else: + if not self._parent_chains[root_i]: + self._parent_chains[root_i].append(root_i) + self._ordered_extend(self._parent_chains[root_j], self._parent_chains[i][-1:]) + self._rank[root_i] += 1 + self._build(root_i, root_j) + return True + + @property + def tree(self) -> list[int]: + """Return the compact child-to-parent array for constructed nodes.""" + return self._tree[:self._next_id] class RecursiveAbstractiveProcessing4TreeOrganizedRetrieval: + """Build RAPTOR summary layers with the classic or Psi tree strategy.""" + def __init__( self, max_cluster, @@ -45,7 +165,12 @@ def __init__( max_token=512, threshold=0.1, max_errors=3, + tree_builder=RAPTOR_TREE_BUILDER, + clustering_method=GMM_CLUSTERING_METHOD, + psi_exact_max_leaves=4096, + psi_bucket_size=1024, ): + """Configure RAPTOR summarization, clustering, and Psi limits.""" self._max_cluster = max_cluster self._llm_model = llm_model self._embd_model = embd_model @@ -54,8 +179,17 @@ def __init__( self._max_token = max_token self._max_errors = max(1, max_errors) self._error_count = 0 - + self._tree_builder = tree_builder or RAPTOR_TREE_BUILDER + if self._tree_builder not in SUPPORTED_TREE_BUILDERS: + raise ValueError(f"Unsupported RAPTOR tree builder: {self._tree_builder}") + self._clustering_method = clustering_method or GMM_CLUSTERING_METHOD + if self._clustering_method not in SUPPORTED_CLUSTERING_METHODS: + raise ValueError(f"Unsupported RAPTOR clustering method: {self._clustering_method}") + self._psi_exact_max_leaves = max(2, int(psi_exact_max_leaves or 4096)) + self._psi_bucket_size = min(max(2, int(psi_bucket_size or 1024)), self._psi_exact_max_leaves) + def _check_task_canceled(self, task_id: str, message: str = ""): + """Raise if the current document task was canceled.""" if task_id and has_canceled(task_id): log_msg = f"Task {task_id} cancelled during RAPTOR {message}." logging.info(log_msg) @@ -63,6 +197,7 @@ def _check_task_canceled(self, task_id: str, message: str = ""): @timeout(60 * 20) async def _chat(self, system, history, gen_conf): + """Call the configured LLM with caching and short retries.""" cached = await thread_pool_exec(get_llm_cache, self._llm_model.llm_name, system, history, gen_conf) if cached: return cached @@ -86,6 +221,7 @@ async def _chat(self, system, history, gen_conf): @timeout(20) async def _embedding_encode(self, txt): + """Encode text with the configured embedding model and cache result.""" response = await thread_pool_exec(get_embed_cache, self._embd_model.llm_name, txt) if response is not None: return response @@ -97,6 +233,7 @@ async def _embedding_encode(self, txt): return embds def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_id: str = ""): + """Choose the GMM cluster count with the lowest BIC score.""" max_clusters = min(self._max_cluster, len(embeddings)) n_clusters = np.arange(1, max_clusters) bics = [] @@ -109,57 +246,422 @@ def _get_optimal_clusters(self, embeddings: np.ndarray, random_state: int, task_ optimal_clusters = n_clusters[np.argmin(bics)] return optimal_clusters + def _get_clusters_ahc(self, embeddings: np.ndarray, task_id: str = "") -> np.ndarray: + """Cluster embeddings with Ward-linkage AHC and a dendrogram gap heuristic.""" + n = len(embeddings) + if n <= 1: + return np.zeros(n, dtype=int) + if n == 2: + return np.arange(n) + + self._check_task_canceled(task_id, "_get_clusters_ahc dendrogram") + full_clust = AgglomerativeClustering( + n_clusters=None, + distance_threshold=0, + compute_distances=True, + linkage="ward", + ) + full_clust.fit(embeddings) + + distances = full_clust.distances_ + if len(distances) > 1: + gaps = np.diff(distances) + max_gap_idx = int(np.argmax(gaps)) + n_clusters = max(1, min(n - max_gap_idx - 1, self._max_cluster)) + else: + n_clusters = max(1, min(n, self._max_cluster)) + if n_clusters <= 1: + logging.info("RAPTOR AHC: _get_clusters_ahc selected one cluster for %d embeddings", n) + return np.zeros(n, dtype=int) + + logging.info("RAPTOR AHC: _get_clusters_ahc selected n_clusters=%d for %d embeddings", n_clusters, n) + self._check_task_canceled(task_id, "_get_clusters_ahc fit") + clustering = AgglomerativeClustering(n_clusters=n_clusters, linkage="ward") + return clustering.fit_predict(embeddings) + + def _adjust_tree_nodes(self, embeddings: np.ndarray, labels: np.ndarray, max_iter: int = 5) -> np.ndarray: + """Refine AHC assignments by reassigning nodes to nearest centroids.""" + labels = labels.copy() + for _ in range(max_iter): + unique_labels = np.unique(labels) + if len(unique_labels) <= 1: + return labels + centroids = np.stack([embeddings[labels == lbl].mean(axis=0) for lbl in unique_labels]) + diffs = embeddings[:, np.newaxis, :] - centroids[np.newaxis, :, :] + sq_dists = (diffs**2).sum(axis=2) + new_label_indices = np.argmin(sq_dists, axis=1) + new_labels = unique_labels[new_label_indices] + if np.array_equal(new_labels, labels): + break + unique_new = np.unique(new_labels) + remap = {old: new for new, old in enumerate(unique_new)} + labels = np.array([remap[int(lbl)] for lbl in new_labels]) + return labels + + @timeout(60 * 20) + async def _summarize_texts(self, texts: list[str], callback=None, task_id: str = ""): + """Summarize a cluster and return text plus embedding when successful.""" + self._check_task_canceled(task_id, "summarization") + + len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts)) + cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts]) + try: + async with chat_limiter: + self._check_task_canceled(task_id, "before LLM call") + + cnt = await self._chat( + "You're a helpful assistant.", + [ + { + "role": "user", + "content": self._prompt.format(cluster_content=cluster_content), + } + ], + {"max_tokens": max(self._max_token, 512)}, # fix issue: #10235 + ) + cnt = re.sub( + "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", + "", + cnt, + ) + logging.debug(f"SUM: {cnt}") + + self._check_task_canceled(task_id, "before embedding") + + embds = await self._embedding_encode(cnt) + return cnt, embds + except TaskCanceledException: + raise + except Exception as exc: + self._error_count += 1 + warn_msg = f"[RAPTOR] Skip cluster ({len(texts)} chunks) due to error: {exc}" + logging.warning(warn_msg) + if callback: + callback(msg=warn_msg) + if self._error_count >= self._max_errors: + raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc + return None + + @staticmethod + def _root(node: _PsiTreeNode) -> _PsiTreeNode: + """Return the current root for a Psi tree node.""" + while node.parent is not None: + node = node.parent + return node + + def _rank_leaf_pairs(self, leaves: list[_PsiTreeNode]) -> np.ndarray: + """Rank all leaf pairs by original embedding-space cosine similarity.""" + node_embeddings = np.asarray([leaf.embedding for leaf in leaves], dtype=np.float64) + node_embeddings = self._normalize_embeddings(node_embeddings) + similarities = node_embeddings @ node_embeddings.T + lower = np.tril_indices(len(leaves), -1) + ordered = np.argsort(similarities[lower], axis=0)[::-1] + return np.stack([lower[0][ordered], lower[1][ordered]], axis=-1) + + @staticmethod + def _normalize_embeddings(node_embeddings: np.ndarray) -> np.ndarray: + """Normalize embeddings for cosine operations while tolerating zero vectors.""" + node_embeddings = np.asarray(node_embeddings, dtype=np.float64) + norms = np.linalg.norm(node_embeddings, axis=1, keepdims=True) + return node_embeddings / np.maximum(norms, 1e-12) + + def _split_psi_buckets(self, nodes: list[_PsiTreeNode]) -> list[list[_PsiTreeNode]]: + """Split large Psi inputs so exact pair ranking is bounded per bucket.""" + if len(nodes) <= self._psi_bucket_size: + return [nodes] + + node_embeddings = self._normalize_embeddings(np.asarray([node.embedding for node in nodes], dtype=np.float64)) + groups = [np.arange(len(nodes), dtype=int)] + buckets = [] + + while groups: + group = np.asarray(groups.pop(), dtype=int) + if len(group) <= self._psi_bucket_size: + buckets.append(group.tolist()) + continue + + fanout = min(max(2, int(np.ceil(len(group) / self._psi_bucket_size))), len(group), 32) + group_embeddings = node_embeddings[group] + center_idx = np.linspace(0, len(group_embeddings) - 1, num=fanout, dtype=int) + centers = group_embeddings[center_idx].copy() + + for _ in range(5): + labels = np.argmax(group_embeddings @ centers.T, axis=1) + for center_id in range(fanout): + mask = labels == center_id + if not np.any(mask): + continue + center = group_embeddings[mask].mean(axis=0) + norm = np.linalg.norm(center) + centers[center_id] = center / norm if norm > 0 else center + + labels = np.argmax(group_embeddings @ centers.T, axis=1) + split_groups = [group[labels == center_id].tolist() for center_id in range(fanout)] + split_groups = [bucket for bucket in split_groups if bucket] + if len(split_groups) <= 1: + split_groups = [ + group[start:start + self._psi_bucket_size].tolist() + for start in range(0, len(group), self._psi_bucket_size) + ] + groups.extend(split_groups) + + buckets = [bucket for bucket in buckets if bucket] + buckets.sort(key=lambda bucket: (len(bucket), bucket[0])) + return [[nodes[idx] for idx in bucket] for bucket in buckets] + + def _assign_prototype_embeddings(self, node: _PsiTreeNode) -> np.ndarray: + """Assign mean child embeddings to internal Psi nodes for bucket-level ranking.""" + if not node.children: + return np.asarray(node.embedding, dtype=np.float64) + embeddings = np.asarray([self._assign_prototype_embeddings(child) for child in node.children], dtype=np.float64) + node.embedding = embeddings.mean(axis=0) + return node.embedding + + @staticmethod + def _iter_nodes(root: _PsiTreeNode): + """Yield nodes in a Psi tree using a stack traversal.""" + stack = [root] + while stack: + node = stack.pop() + yield node + stack.extend(node.children) + + def _create_psi_parent(self, index: int, children: list[_PsiTreeNode]) -> _PsiTreeNode: + """Create a parent node and attach the provided children to it.""" + parent = _PsiTreeNode(index=index, children=children) + for child in children: + child.parent = parent + return parent + + def _rebalance_psi_tree(self, root: _PsiTreeNode, next_index: int) -> tuple[_PsiTreeNode, int]: + """Group oversized Psi tree nodes so fanout stays within max_cluster.""" + max_children = max(2, int(self._max_cluster or 2)) + + def rebalance(node: _PsiTreeNode): + """Recursively group children when a Psi node exceeds fanout.""" + nonlocal next_index + + for child in list(node.children): + rebalance(child) + + while len(node.children) > max_children: + original_children = len(node.children) + grouped_children = [] + for start in range(0, len(node.children), max_children): + batch = node.children[start:start + max_children] + if len(batch) == 1: + grouped_children.append(batch[0]) + batch[0].parent = node + else: + grouped_children.append(self._create_psi_parent(next_index, batch)) + grouped_children[-1].parent = node + next_index += 1 + node.children = grouped_children + logging.info( + "RAPTOR Psi rebalance: node=%s children=%d grouped_to=%d max_cluster=%d", + node.index, + original_children, + len(grouped_children), + max_children, + ) + + rebalance(root) + return self._root(root), next_index + + def _build_exact_psi_structure( + self, + nodes: list[_PsiTreeNode], + next_index: int, + task_id: str = "", + ) -> tuple[_PsiTreeNode, int, int]: + """Build an exact Psi subtree for a bounded node set.""" + if len(nodes) == 1: + return nodes[0], next_index, 0 + + ranked_pairs = self._rank_leaf_pairs(nodes) + union_find = _PsiUnionFind(len(nodes)) + merges = 0 + for left_idx, right_idx in ranked_pairs: + self._check_task_canceled(task_id, "Psi tree construction") + if union_find.union(int(left_idx), int(right_idx)): + merges += 1 + if merges == len(nodes) - 1: + break + + local_nodes = {idx: node for idx, node in enumerate(nodes)} + tree = union_find.tree + children_by_parent = {} + for child_idx, parent_idx in enumerate(tree): + if child_idx not in local_nodes: + local_nodes[child_idx] = _PsiTreeNode(index=next_index) + next_index += 1 + if parent_idx == -1: + continue + children_by_parent.setdefault(parent_idx, []).append(child_idx) + if parent_idx not in local_nodes: + local_nodes[parent_idx] = _PsiTreeNode(index=next_index) + next_index += 1 + + for parent_idx, child_indices in children_by_parent.items(): + parent = local_nodes[parent_idx] + parent.children = [local_nodes[child_idx] for child_idx in child_indices] + for child in parent.children: + child.parent = parent + + roots = [local_nodes[idx] for idx, parent_idx in enumerate(tree) if parent_idx == -1 and idx in local_nodes] + root = max(roots, key=lambda node: node.index) + return root, next_index, merges + + def _build_bucketed_psi_structure( + self, + nodes: list[_PsiTreeNode], + next_index: int, + task_id: str = "", + ) -> tuple[_PsiTreeNode, int, int]: + """Build large Psi trees by exact-ranking bounded buckets, then bucket roots.""" + buckets = self._split_psi_buckets(nodes) + logging.info( + "RAPTOR Psi bucketed build: nodes=%d buckets=%d bucket_size=%d exact_max_leaves=%d", + len(nodes), + len(buckets), + self._psi_bucket_size, + self._psi_exact_max_leaves, + ) + + bucket_roots = [] + merges = 0 + for bucket in buckets: + bucket_root, next_index, bucket_merges = self._build_psi_structure_from_nodes(bucket, next_index, task_id) + self._assign_prototype_embeddings(bucket_root) + bucket_roots.append(bucket_root) + merges += bucket_merges + + if len(bucket_roots) == 1: + return bucket_roots[0], next_index, merges + + root, next_index, root_merges = self._build_psi_structure_from_nodes(bucket_roots, next_index, task_id) + return root, next_index, merges + root_merges + + def _build_psi_structure_from_nodes( + self, + nodes: list[_PsiTreeNode], + next_index: int, + task_id: str = "", + ) -> tuple[_PsiTreeNode, int, int]: + """Build Psi structure exactly for small sets and bucket large sets.""" + if len(nodes) <= self._psi_exact_max_leaves: + return self._build_exact_psi_structure(nodes, next_index, task_id) + return self._build_bucketed_psi_structure(nodes, next_index, task_id) + + def _build_psi_structure(self, chunks, task_id: str = "") -> tuple[_PsiTreeNode, list[_PsiTreeNode]]: + """Build the Psi merge tree from original chunk embeddings.""" + leaves = [ + _PsiTreeNode(index=i, text=text, embedding=np.asarray(embd)) + for i, (text, embd) in enumerate(chunks) + ] + if len(leaves) == 1: + return leaves[0], leaves + + root, next_index, merges = self._build_psi_structure_from_nodes(leaves, len(leaves), task_id) + root, _ = self._rebalance_psi_tree(root, next_index) + logging.info( + "RAPTOR Psi tree built: leaves=%d merges=%d root_fanout=%d", + len(leaves), + merges, + len(root.children), + ) + return root, leaves + + @staticmethod + def _psi_layers(root: _PsiTreeNode) -> dict[int, list[_PsiTreeNode]]: + """Collect non-leaf Psi nodes by height for bottom-up summarization.""" + layers = {} + + def height(node: _PsiTreeNode) -> int: + """Return node height while collecting internal nodes by layer.""" + if not node.children: + return 0 + node_height = max(height(child) for child in node.children) + 1 + layers.setdefault(node_height, []).append(node) + return node_height + + height(root) + return layers + + async def _build_psi_layers(self, chunks, callback=None, task_id: str = ""): + """Materialize Psi tree layers as summary chunks.""" + layers = [(0, len(chunks))] + root, _ = self._build_psi_structure(chunks, task_id=task_id) + + for layer_idx, (_, nodes) in enumerate(sorted(self._psi_layers(root).items()), start=1): + layer_start = len(chunks) + + async def summarize_node(node: _PsiTreeNode): + """Summarize one Psi internal node if its children have text.""" + texts = [child.text for child in node.children if child.text] + if not texts: + logging.warning("RAPTOR Psi node %s skipped because it has no child text to summarize", node.index) + return None + result = await self._summarize_texts(texts, callback, task_id) + if result is None: + logging.warning("RAPTOR Psi node %s skipped because summarization failed", node.index) + return None + node.text, node.embedding = result + return node + + tasks = [asyncio.create_task(summarize_node(node)) for node in nodes] + try: + summarized_nodes = await asyncio.gather(*tasks, return_exceptions=False) + except Exception as e: + logging.error(f"Error in RAPTOR Psi tree processing: {e}") + for task in tasks: + task.cancel() + await asyncio.gather(*tasks, return_exceptions=True) + raise + + summarized_nodes = [node for node in summarized_nodes if node is not None] + for node in summarized_nodes: + chunks.append((node.text, node.embedding)) + + if len(chunks) > layer_start: + layers.append((layer_start, len(chunks))) + logging.info( + "RAPTOR Psi layer materialized: layer=%d nodes=%d summaries=%d", + layer_idx, + len(nodes), + len(chunks) - layer_start, + ) + if callback: + callback(msg="Build one Psi-RAG layer: {} -> {}".format(len(nodes), len(chunks) - layer_start)) + else: + logging.warning("RAPTOR Psi layer %d produced no summaries; stopping materialization", layer_idx) + break + + return chunks, layers + async def __call__(self, chunks, random_state, callback=None, task_id: str = ""): + """Build summary chunks and layer boundaries for RAPTOR retrieval.""" if len(chunks) <= 1: return [], [] chunks = [(s, a) for s, a in chunks if s and a is not None and len(a) > 0] + if len(chunks) <= 1: + return chunks, [(0, len(chunks))] + if self._tree_builder == PSI_TREE_BUILDER: + logging.info("RAPTOR: using %s tree builder for %d chunks", self._tree_builder, len(chunks)) + return await self._build_psi_layers(chunks, callback, task_id) + layers = [(0, len(chunks))] start, end = 0, len(chunks) @timeout(60 * 20) async def summarize(ck_idx: list[int]): + """Summarize one classic RAPTOR cluster into the chunk list.""" nonlocal chunks - self._check_task_canceled(task_id, "summarization") - texts = [chunks[i][0] for i in ck_idx] - len_per_chunk = int((self._llm_model.max_length - self._max_token) / len(texts)) - cluster_content = "\n".join([truncate(t, max(1, len_per_chunk)) for t in texts]) - try: - async with chat_limiter: - self._check_task_canceled(task_id, "before LLM call") - - cnt = await self._chat( - "You're a helpful assistant.", - [ - { - "role": "user", - "content": self._prompt.format(cluster_content=cluster_content), - } - ], - {"max_tokens": max(self._max_token, 512)}, # fix issue: #10235 - ) - cnt = re.sub( - "(······\n由于长度的原因,回答被截断了,要继续吗?|For the content length reason, it stopped, continue?)", - "", - cnt, - ) - logging.debug(f"SUM: {cnt}") - - self._check_task_canceled(task_id, "before embedding") - - embds = await self._embedding_encode(cnt) - chunks.append((cnt, embds)) - except TaskCanceledException: - raise - except Exception as exc: - self._error_count += 1 - warn_msg = f"[RAPTOR] Skip cluster ({len(ck_idx)} chunks) due to error: {exc}" - logging.warning(warn_msg) - if callback: - callback(msg=warn_msg) - if self._error_count >= self._max_errors: - raise RuntimeError(f"RAPTOR aborted after {self._error_count} errors. Last error: {exc}") from exc + result = await self._summarize_texts(texts, callback, task_id) + if result is not None: + chunks.append(result) while end - start > 1: self._check_task_canceled(task_id, "layer processing") @@ -167,8 +669,12 @@ async def summarize(ck_idx: list[int]): embeddings = [embd for _, embd in chunks[start:end]] if len(embeddings) == 2: await summarize([start, start + 1]) + produced = len(chunks) - end + if produced == 0: + logging.warning("RAPTOR layer produced no summaries; stopping materialization") + break if callback: - callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) + callback(msg="Cluster one layer: {} -> {}".format(end - start, produced)) layers.append((end, len(chunks))) start = end end = len(chunks) @@ -180,15 +686,37 @@ async def summarize(ck_idx: list[int]): n_components=min(12, len(embeddings) - 2), metric="cosine", ).fit_transform(embeddings) - n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state, task_id=task_id) + if self._clustering_method == AHC_CLUSTERING_METHOD: + logging.info("RAPTOR: using clustering_method=%s before _get_clusters_ahc", self._clustering_method) + raw_labels = self._get_clusters_ahc(reduced_embeddings, task_id=task_id) + raw_cluster_count = np.unique(raw_labels).size + logging.info("RAPTOR AHC: _get_clusters_ahc produced n_clusters=%d", raw_cluster_count) + if raw_cluster_count > 1: + adjusted = self._adjust_tree_nodes(reduced_embeddings, raw_labels) + adjusted_cluster_count = np.unique(adjusted).size + logging.info("RAPTOR AHC: _adjust_tree_nodes adjusted n_clusters=%d", adjusted_cluster_count) + else: + adjusted = raw_labels + logging.warning("RAPTOR AHC: _adjust_tree_nodes skipped because _get_clusters_ahc returned one cluster") + unique_labels = np.unique(adjusted) + label_map = {old: idx for idx, old in enumerate(unique_labels)} + lbls = [label_map[int(lbl)] for lbl in adjusted] + n_clusters = len(unique_labels) + else: + n_clusters = self._get_optimal_clusters(reduced_embeddings, random_state, task_id=task_id) + if n_clusters == 1: + lbls = [0 for _ in range(len(reduced_embeddings))] + else: + gm = GaussianMixture(n_components=n_clusters, random_state=random_state) + gm.fit(reduced_embeddings) + probs = gm.predict_proba(reduced_embeddings) + lbls = [np.where(prob > self._threshold)[0] for prob in probs] + lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls] + if n_clusters == 1: lbls = [0 for _ in range(len(reduced_embeddings))] else: - gm = GaussianMixture(n_components=n_clusters, random_state=random_state) - gm.fit(reduced_embeddings) - probs = gm.predict_proba(reduced_embeddings) - lbls = [np.where(prob > self._threshold)[0] for prob in probs] - lbls = [lbl[0] if isinstance(lbl, np.ndarray) else lbl for lbl in lbls] + lbls = [int(lbl[0]) if isinstance(lbl, np.ndarray) else int(lbl) for lbl in lbls] tasks = [] for c in range(n_clusters): @@ -205,10 +733,21 @@ async def summarize(ck_idx: list[int]): await asyncio.gather(*tasks, return_exceptions=True) raise - assert len(chunks) - end == n_clusters, "{} vs. {}".format(len(chunks) - end, n_clusters) + produced = len(chunks) - end + assert produced <= n_clusters, "{} vs. {}".format(produced, n_clusters) + if produced < n_clusters: + logging.warning( + "RAPTOR layer produced %d/%d cluster summaries; skipped %d cluster(s) due to errors", + produced, + n_clusters, + n_clusters - produced, + ) + if produced == 0: + logging.warning("RAPTOR layer produced no summaries; stopping materialization") + break layers.append((end, len(chunks))) if callback: - callback(msg="Cluster one layer: {} -> {}".format(end - start, len(chunks) - end)) + callback(msg="Cluster one layer: {} -> {}".format(end - start, produced)) start = end end = len(chunks) diff --git a/rag/svr/sync_data_source.py b/rag/svr/sync_data_source.py index 9a60701e793..a5ba3958204 100644 --- a/rag/svr/sync_data_source.py +++ b/rag/svr/sync_data_source.py @@ -41,7 +41,9 @@ from api.db.services.document_service import DocumentService from api.db.services.knowledgebase_service import KnowledgebaseService from common import settings +from common.constants import ConnectorTaskType, FileSource, TaskStatus from common.config_utils import show_configs +from common.data_source.config import INDEX_BATCH_SIZE from common.data_source import ( BlobStorageConnector, RSSConnector, @@ -58,9 +60,8 @@ SeaFileConnector, RDBMSConnector, DingTalkAITableConnector, + RestAPIConnector, ) -from common.constants import FileSource, TaskStatus -from common.data_source.config import INDEX_BATCH_SIZE from common.data_source.models import ConnectorFailure, SeafileSyncScope from common.data_source.webdav_connector import WebDAVConnector from common.data_source.confluence_connector import ConfluenceConnector @@ -70,12 +71,11 @@ from common.data_source.gitlab_connector import GitlabConnector from common.data_source.bitbucket.connector import BitbucketConnector from common.data_source.interfaces import CheckpointOutputWrapper +from common.data_source.exceptions import ConnectorValidationError from common.log_utils import init_root_logger from common.signal_utils import start_tracemalloc_and_snapshot, stop_tracemalloc from common.versions import get_ragflow_version from box_sdk_gen import BoxOAuth, OAuthConfig, AccessToken -from collections import namedtuple - MAX_CONCURRENT_TASKS = int(os.environ.get("MAX_CONCURRENT_TASKS", "5")) task_limiter = asyncio.Semaphore(MAX_CONCURRENT_TASKS) @@ -155,30 +155,37 @@ async def __call__(self, task: dict): }) return - SyncLogsService.schedule(task["connector_id"], task["kb_id"], task["poll_range_start"]) + task_type = task.get("task_type", ConnectorTaskType.SYNC) + if task_type == ConnectorTaskType.SYNC: + SyncLogsService.schedule( + task["connector_id"], + task["kb_id"], + task.get("poll_range_start"), + task_type=ConnectorTaskType.SYNC, + ) + elif task_type == ConnectorTaskType.PRUNE and self.conf.get("sync_deleted_files"): + SyncLogsService.schedule( + task["connector_id"], + task["kb_id"], + task_type=ConnectorTaskType.PRUNE, + ) async def _run_task_logic(self, task: dict): + task_type = task.get("task_type", ConnectorTaskType.SYNC) + if task_type == ConnectorTaskType.PRUNE: + await self._run_prune_task_logic(task) + return + await self._run_sync_task_logic(task) + + async def _run_sync_task_logic(self, task: dict): """ Executes the core synchronization pipeline for a data source task. - - This method retrieves documents from the external source via the `_generate` method, - parses and upserts them into the Knowledge Base (KB), and handles stale document - reconciliation (sync deletion) if a remote snapshot (`file_list`) is provided. """ - generate_output = await self._generate(task) - # `_generate()` currently supports two outputs: - # 1. `document_batch_generator` - # 2. `(document_batch_generator, file_list)` - if isinstance(generate_output, tuple): - document_batch_generator, file_list = generate_output - else: - document_batch_generator = generate_output - file_list = None + document_batch_generator = await self._generate(task) failed_docs = 0 added_docs = 0 updated_docs = 0 - removed_docs = 0 next_update = datetime(1970, 1, 1, tzinfo=timezone.utc) source_type = f"{self.SOURCE_NAME}/{task['connector_id']}" existing_doc_ids = { @@ -213,6 +220,8 @@ async def _run_task_logic(self, task: dict): } if doc.metadata: d["metadata"] = doc.metadata + if getattr(doc, "fingerprint", None): + d["fingerprint"] = doc.fingerprint docs.append(d) try: @@ -248,34 +257,12 @@ async def _run_task_logic(self, task: dict): prefix = self._get_source_prefix() prefix = f"{prefix} " if prefix else "" next_update_info = self._format_window_boundary(next_update) - expects_deleted_file_snapshot = ( - task.get("reindex") != "1" - and task.get("poll_range_start") - and self.conf.get("sync_deleted_files") - ) - cleanup_errors = [] - if expects_deleted_file_snapshot and file_list is None: - logging.warning( - "%s deleted-file snapshot retrieval failed " - "(connector_id=%s, kb_id=%s)", - self.SOURCE_NAME, - task["connector_id"], - task["kb_id"], - ) - elif file_list is not None: - removed_docs, cleanup_errors = ConnectorService.cleanup_stale_documents_for_task( - task["id"], - task["connector_id"], - task["kb_id"], - task["tenant_id"], - file_list, - ) - total_changed_docs = added_docs + updated_docs + removed_docs + total_changed_docs = added_docs + updated_docs summary = ( f"{prefix}sync summary till {next_update_info}: " f"total={total_changed_docs}, added={added_docs}, " - f"updated={updated_docs}, deleted={removed_docs}" + f"updated={updated_docs}" ) if failed_docs > 0: summary = f"{summary}, skipped={failed_docs}" @@ -284,23 +271,159 @@ async def _run_task_logic(self, task: dict): if ( isinstance(self, _RDBMSBase) and failed_docs == 0 - and (not expects_deleted_file_snapshot or file_list is not None) - and not cleanup_errors ): self.connector.persist_sync_state() SyncLogsService.done(task["id"], task["connector_id"]) task["poll_range_start"] = next_update + async def _run_prune_task_logic(self, task: dict): + if not self.conf.get("sync_deleted_files"): + SyncLogsService.done(task["id"], task["connector_id"]) + return + + await self._initialize_for_prune(task) + + file_list = self._collect_prune_snapshot(task) + if file_list is None: + logging.warning( + "%s prune snapshot retrieval failed (connector_id=%s, kb_id=%s)", + self.SOURCE_NAME, + task["connector_id"], + task["kb_id"], + ) + SyncLogsService.done(task["id"], task["connector_id"]) + return + + removed_docs, cleanup_errors = ConnectorService.cleanup_stale_documents_for_task( + task["id"], + task["connector_id"], + task["kb_id"], + task["tenant_id"], + file_list, + ) + logging.info( + "%s prune summary: deleted=%s, errors=%s", + self.SOURCE_NAME, + removed_docs, + len(cleanup_errors), + ) + SyncLogsService.done(task["id"], task["connector_id"]) + async def _generate(self, task: dict): raise NotImplementedError def _get_source_prefix(self): return "" + async def _initialize_for_prune(self, task: dict): + await self._generate(task) + + def _get_prune_snapshot_kwargs(self, task: dict) -> dict[str, Any]: + return {} + + def _collect_prune_snapshot(self, task: dict): + if not getattr(self, "connector", None): + return None + if not hasattr(self.connector, "retrieve_all_slim_docs_perm_sync"): + return None + + file_list = [] + snapshot_kwargs = self._get_prune_snapshot_kwargs(task) + try: + for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(**snapshot_kwargs): + file_list.extend(slim_batch) + except TypeError: + for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): + file_list.extend(slim_batch) + except Exception: + logging.exception( + "%s prune snapshot failed (connector_id=%s, kb_id=%s)", + self.SOURCE_NAME, + task["connector_id"], + task["kb_id"], + ) + return None + return file_list + class _BlobLikeBase(SyncBase): DEFAULT_BUCKET_TYPE: str = "s3" + def _fingerprint_filtered_generator(self, task: dict): + """Generator that uses list_keys() + get_value() to skip unchanged objects. + + Pre-loads {doc_id: content_hash} for the connector's existing docs in + this KB, iterates the bucket via list_keys(), and only materializes a + Document (one GetObject call) when the listing fingerprint differs from + the persisted content_hash. Unchanged objects are skipped entirely -- + no download, no re-parse. + + Per-key fetch failures are counted and surfaced via SyncLogsService so + a partially failing sync (e.g. throttling, IAM regression mid-run) + doesn't silently report DONE while half the bucket is unreachable. + Connectors yielding KeyRecord(deleted=True) are skipped here -- actual + deletion reconciliation lives in the unified delete pass (PR-4). + """ + source_type = f"{self.SOURCE_NAME}/{task['connector_id']}" + existing_fingerprints = DocumentService.list_id_content_hash_map_by_kb_and_source_type( + task["kb_id"], source_type, + ) + + bypass_count = 0 + fetch_count = 0 + fail_count = 0 + batch = [] + for key_record in self.connector.list_keys(): + if key_record.deleted: + continue + + doc_id = hash128(key_record.key) + stored = existing_fingerprints.get(doc_id, "") + if key_record.fingerprint and stored and key_record.fingerprint == stored: + bypass_count += 1 + continue + + try: + doc = self.connector.get_value(key_record.key) + except Exception as ex: + fail_count += 1 + logging.exception( + "Failed to fetch %s from %s: %s", + key_record.key, + self.SOURCE_NAME, + ex, + ) + continue + + fetch_count += 1 + batch.append(doc) + if len(batch) >= self.connector.batch_size: + yield batch + batch = [] + + if batch: + yield batch + + log_msg = ( + "[%s] fingerprint sync: %d bypassed, %d fetched, %d failed " + "(connector_id=%s, kb_id=%s)" + ) + log_args = ( + self.SOURCE_NAME, + bypass_count, + fetch_count, + fail_count, + task["connector_id"], + task["kb_id"], + ) + # Use WARNING when any fetch failed so partial-bucket regressions + # (auth, throttling, IAM drift) surface without diving into the + # per-exception traces above. + if fail_count: + logging.warning(log_msg, *log_args) + else: + logging.info(log_msg, *log_args) + async def _generate(self, task: dict): bucket_type = self.conf.get("bucket_type", self.DEFAULT_BUCKET_TYPE) @@ -312,29 +435,18 @@ async def _generate(self, task: dict): self.connector.set_allow_images(self.conf.get("allow_images", False)) self.connector.load_credentials(self.conf["credentials"]) - file_list = None - document_batch_generator = ( - self.connector.load_from_state() - if task["reindex"] == "1" or not task["poll_range_start"] - else self.connector.poll_source( - task["poll_range_start"].timestamp(), - datetime.now(timezone.utc).timestamp(), - ) - ) - - if ( - task["reindex"] != "1" - and task["poll_range_start"] - and self.conf.get("sync_deleted_files") - ): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) + # Fingerprint-bypass path: skip GetObject for unchanged ETags. Disabled + # on full reindex (we want to re-fetch everything in that case). + use_fingerprint_path = task["reindex"] != "1" + if use_fingerprint_path: + document_batch_generator = self._fingerprint_filtered_generator(task) + else: + document_batch_generator = self.connector.load_from_state() _begin_info = ( - "totally" - if task["reindex"] == "1" or not task["poll_range_start"] - else "from {}".format(task["poll_range_start"]) + "fingerprint-bypass" + if use_fingerprint_path + else "full reindex" ) logging.info( @@ -345,7 +457,7 @@ async def _generate(self, task: dict): _begin_info, ) ) - return document_batch_generator, file_list + return document_batch_generator class S3(_BlobLikeBase): @@ -383,28 +495,11 @@ async def _generate(self, task: dict): return self.connector.load_from_state() end_time = datetime.now(timezone.utc).timestamp() - file_list = None - if self.conf.get("sync_deleted_files"): - logging.info( - "[RSS] Syncing deleted files via slim snapshot (connector_id=%s)", - task["connector_id"], - ) - snapshot_start = time.perf_counter() - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - logging.info( - "[RSS] Slim snapshot fetched %d docs in %.2f seconds", - len(file_list), - time.perf_counter() - snapshot_start, - ) document_generator = self.connector.poll_source( task["poll_range_start"].timestamp(), end_time, ) - if file_list is not None: - return document_generator, file_list return document_generator @@ -447,16 +542,11 @@ async def _generate(self, task: dict): credential_json=self.conf["credentials"]) self.connector.set_credentials_provider(credentials_provider) - file_list = None # Determine the time range for synchronization based on reindex or poll_range_start if task["reindex"] == "1" or not task["poll_range_start"]: start_time = 0.0 else: start_time = task["poll_range_start"].timestamp() - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) end_time = datetime.now(timezone.utc).timestamp() @@ -502,7 +592,7 @@ def wrapper(): yield batch self.log_connection("Confluence", self.conf["wiki_base"], task) - return wrapper(), file_list + return wrapper() class Notion(SyncBase): @@ -511,7 +601,6 @@ class Notion(SyncBase): async def _generate(self, task: dict): self.connector = NotionConnector(root_page_id=self.conf["root_page_id"]) self.connector.load_credentials(self.conf["credentials"]) - file_list = None document_generator = ( self.connector.load_from_state() if task["reindex"] == "1" or not task["poll_range_start"] @@ -519,19 +608,10 @@ async def _generate(self, task: dict): datetime.now(timezone.utc).timestamp()) ) - if ( - task["reindex"] != "1" - and task["poll_range_start"] - and self.conf.get("sync_deleted_files") - ): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - _begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format( task["poll_range_start"]) self.log_connection("Notion", f"root({self.conf['root_page_id']})", task) - return document_generator, file_list + return document_generator class Discord(SyncBase): @@ -549,26 +629,17 @@ async def _generate(self, task: dict): batch_size=self.conf.get("batch_size", 1024), ) self.connector.load_credentials(self.conf["credentials"]) - file_list = None document_generator = ( self.connector.load_from_state() if task["reindex"] == "1" or not task["poll_range_start"] else self.connector.poll_source(task["poll_range_start"].timestamp(), datetime.now(timezone.utc).timestamp()) ) - if ( - task["reindex"] != "1" - and task["poll_range_start"] - and self.conf.get("sync_deleted_files") - ): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) _begin_info = "totally" if task["reindex"] == "1" or not task["poll_range_start"] else "from {}".format( task["poll_range_start"]) self.log_connection("Discord", f"servers({server_ids}), channel({channel_names})", task) - return document_generator, file_list + return document_generator class Gmail(SyncBase): @@ -607,8 +678,6 @@ async def _generate(self, task: dict): task["connector_id"], ) - file_list = None - # Decide between full reindex and incremental polling by time range. if task["reindex"] == "1" or not task.get("poll_range_start"): start_time = None @@ -628,17 +697,13 @@ async def _generate(self, task: dict): end_time = datetime.now(timezone.utc).timestamp() _begin_info = f"from {poll_start}" document_generator = self.connector.poll_source(start_time, end_time) - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) try: admin_email = self.connector.primary_admin_email except RuntimeError: admin_email = "unknown" self.log_connection("Gmail", f"as {admin_email}", task) - return document_generator, file_list + return document_generator class Dropbox(SyncBase): @@ -648,22 +713,16 @@ async def _generate(self, task: dict): self.connector = DropboxConnector(batch_size=self.conf.get("batch_size", INDEX_BATCH_SIZE)) self.connector.load_credentials(self.conf["credentials"]) poll_start = task["poll_range_start"] - file_list = None - if task["reindex"] == "1" or not poll_start: document_generator = self.connector.load_from_state() _begin_info = "totally" else: end_time = datetime.now(timezone.utc).timestamp() - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) document_generator = self.connector.poll_source(poll_start.timestamp(), end_time) _begin_info = f"from {poll_start}" self.log_connection("Dropbox", "workspace", task) - return document_generator, file_list + return document_generator class GoogleDrive(SyncBase): @@ -697,8 +756,6 @@ async def _generate(self, task: dict): if new_credentials: self._persist_rotated_credentials(task["connector_id"], new_credentials) - file_list = None - # Capture end_time BEFORE the snapshot to prevent the ingestion race condition end_time = datetime.now(timezone.utc).timestamp() @@ -708,18 +765,6 @@ async def _generate(self, task: dict): else: start_time = task["poll_range_start"].timestamp() _begin_info = f"from {task['poll_range_start']}" - - if self.conf.get("sync_deleted_files"): - file_list = [] - SlimDoc = namedtuple('SlimDoc', ['id']) - - # Add observability timing so operators can track the O(N) cost - snapshot_start = time.perf_counter() - - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(SlimDoc(doc.id) for doc in slim_batch) - - logging.info("Slim snapshot fetched %d files in %.2f seconds", len(file_list), time.perf_counter() - snapshot_start) raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE try: @@ -765,7 +810,7 @@ def document_batches(): admin_email = "unknown" self.log_connection("Google Drive", f"as {admin_email}", task) - return document_batches(), file_list + return document_batches() def _persist_rotated_credentials(self, connector_id: str, credentials: dict[str, Any]) -> None: """Saves refreshed OAuth credentials back to the database configuration.""" @@ -808,17 +853,12 @@ async def _generate(self, task: dict): self.connector.load_credentials(credentials) self.connector.validate_connector_settings() - file_list = None if task["reindex"] == "1" or not task["poll_range_start"]: start_time = 0.0 _begin_info = "totally" else: start_time = task["poll_range_start"].timestamp() - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) _begin_info = f"from {task['poll_range_start']}" end_time = datetime.now(timezone.utc).timestamp() @@ -877,7 +917,7 @@ def document_batches(): f"overlap_buffer_s={getattr(self.connector, 'time_buffer_seconds', connector_kwargs.get('time_buffer_seconds'))}" ), ) - return document_batches(), file_list + return document_batches() @staticmethod def _normalize_list(values: Any) -> list[str] | None: @@ -929,25 +969,11 @@ async def _generate(self, task: dict): self.connector.set_allow_images(self.conf.get("allow_images", False)) self.connector.load_credentials(self.conf["credentials"]) - file_list = None if task["reindex"] == "1" or not task["poll_range_start"]: document_batch_generator = self.connector.load_from_state() _begin_info = "totally" else: end_ts = datetime.now(timezone.utc).timestamp() - if self.conf.get("sync_deleted_files"): - file_list = [] - try: - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - except Exception: - logging.exception( - "WebDAV slim snapshot failed; continuing without stale-document cleanup " - "(connector_id=%s, kb_id=%s)", - task["connector_id"], - task["kb_id"], - ) - file_list = None document_batch_generator = self.connector.poll_source( task["poll_range_start"].timestamp(), end_ts, @@ -960,7 +986,7 @@ def wrapper(): for document_batch in document_batch_generator: yield document_batch - return wrapper(), file_list + return wrapper() class Moodle(SyncBase): @@ -976,7 +1002,6 @@ async def _generate(self, task: dict): # Determine the time range for synchronization based on reindex or poll_range_start poll_start = task.get("poll_range_start") - file_list = None if task["reindex"] == "1" or poll_start is None: document_generator = self.connector.load_from_state() @@ -988,20 +1013,6 @@ async def _generate(self, task: dict): # could be polled as new and at the same time be missing from # the slim list, which would mark it as stale and delete it. end_ts = datetime.now(timezone.utc).timestamp() - - if self.conf.get("sync_deleted_files"): - file_list = [] - try: - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - except Exception: - logging.exception( - "Moodle slim snapshot failed; skipping stale-document cleanup " - "(connector_id=%s, kb_id=%s)", - task.get("connector_id"), - task.get("kb_id"), - ) - file_list = None document_generator = self.connector.poll_source( poll_start.timestamp(), end_ts, @@ -1009,7 +1020,7 @@ async def _generate(self, task: dict): _begin_info = f"from {poll_start}" self.log_connection("Moodle", self.conf["moodle_url"], task) - return document_generator, file_list + return document_generator class BOX(SyncBase): @@ -1037,23 +1048,18 @@ async def _generate(self, task: dict): self.connector.load_credentials(auth) poll_start = task["poll_range_start"] - file_list = None if task["reindex"] == "1" or poll_start is None: document_generator = self.connector.load_from_state() _begin_info = "totally" else: - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) document_generator = self.connector.poll_source( poll_start.timestamp(), datetime.now(timezone.utc).timestamp(), ) _begin_info = f"from {poll_start}" self.log_connection("Box", f"folder_id({self.conf['folder_id']})", task) - return document_generator, file_list + return document_generator class Airtable(SyncBase): @@ -1078,16 +1084,11 @@ async def _generate(self, task: dict): ) poll_start = task.get("poll_range_start") - file_list = None if task.get("reindex") == "1" or poll_start is None: document_generator = self.connector.load_from_state() _begin_info = "totally" else: - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) document_generator = self.connector.poll_source( poll_start.timestamp(), datetime.now(timezone.utc).timestamp(), @@ -1100,7 +1101,7 @@ async def _generate(self, task: dict): task, ) - return document_generator, file_list + return document_generator class Asana(SyncBase): SOURCE_NAME: str = FileSource.ASANA @@ -1120,17 +1121,12 @@ async def _generate(self, task: dict): ) poll_start = task.get("poll_range_start") - file_list = None if task.get("reindex") == "1" or not poll_start: document_generator = self.connector.load_from_state() _begin_info = "totally" else: end_time = datetime.now(timezone.utc).timestamp() - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) document_generator = self.connector.poll_source( poll_start.timestamp(), end_time, @@ -1143,7 +1139,7 @@ async def _generate(self, task: dict): task, ) - return document_generator, file_list + return document_generator class Github(SyncBase): SOURCE_NAME: str = FileSource.GITHUB @@ -1157,8 +1153,8 @@ async def _generate(self, task: dict): self.connector = GithubConnector( repo_owner=self.conf.get("repository_owner"), repositories=self.conf.get("repository_name"), - include_prs=self.conf.get("include_pull_requests", False), - include_issues=self.conf.get("include_issues", False), + include_prs=self.conf.get("include_pull_requests", True), + include_issues=self.conf.get("include_issues", True), ) credentials = self.conf.get("credentials", {}) @@ -1169,15 +1165,10 @@ async def _generate(self, task: dict): {"github_access_token": credentials["github_access_token"]} ) - file_list = None if task.get("reindex") == "1" or not task.get("poll_range_start"): start_time = datetime.fromtimestamp(0, tz=timezone.utc) else: start_time = task.get("poll_range_start") - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) end_time = datetime.now(timezone.utc) @@ -1214,7 +1205,7 @@ def wrapper(): task, ) - return wrapper(), file_list + return wrapper() class IMAP(SyncBase): SOURCE_NAME: str = FileSource.IMAP @@ -1270,27 +1261,10 @@ async def _generate(self, task): task["connector_id"], ) - file_list = None - if ( - task["reindex"] != "1" - and task["poll_range_start"] - and self.conf.get("sync_deleted_files") - ): - file_list = [] - try: - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync( - start=initial_sync_start, - end=end_time, - ): - file_list.extend(slim_batch) - except Exception: - logging.exception( - "IMAP slim snapshot failed; continuing without stale-document cleanup " - "(connector_id=%s, kb_id=%s)", - task["connector_id"], - task["kb_id"], - ) - file_list = None + self._prune_snapshot_kwargs = { + "start": initial_sync_start, + "end": end_time, + } raw_batch_size = self.conf.get("sync_batch_size") or self.conf.get("batch_size") or INDEX_BATCH_SIZE try: @@ -1336,7 +1310,10 @@ def wrapper(): f"host({self.conf['imap_host']}) port({self.conf['imap_port']}) user({self.conf['credentials']['imap_username']}) folder({self.conf['imap_mailbox']})", task, ) - return wrapper(), file_list + return wrapper() + + def _get_prune_snapshot_kwargs(self, task: dict) -> dict[str, Any]: + return getattr(self, "_prune_snapshot_kwargs", {}) class Zendesk(SyncBase): @@ -1346,26 +1323,11 @@ async def _generate(self, task: dict): self.connector.load_credentials(self.conf["credentials"]) end_time = datetime.now(timezone.utc).timestamp() - file_list = None if task["reindex"] == "1" or not task.get("poll_range_start"): start_time = 0 _begin_info = "totally" else: start_time = task["poll_range_start"].timestamp() - if self.conf.get("sync_deleted_files"): - logging.info( - "[Zendesk] Syncing deleted files via slim snapshot (connector_id=%s)", - task.get("connector_id"), - ) - snapshot_start = time.perf_counter() - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - logging.info( - "[Zendesk] Slim snapshot fetched %d docs in %.2f seconds", - len(file_list), - time.perf_counter() - snapshot_start, - ) _begin_info = f"from {task['poll_range_start']}" raw_batch_size = ( @@ -1426,9 +1388,6 @@ def wrapper(): yield batch self.log_connection("Zendesk", f"subdomain({self.conf['credentials'].get('zendesk_subdomain')})", task) - - if file_list is not None: - return wrapper(), file_list return wrapper() @@ -1455,7 +1414,6 @@ async def _generate(self, task: dict): } ) - file_list = None if task["reindex"] == "1" or not task["poll_range_start"]: document_generator = self.connector.load_from_state() _begin_info = "totally" @@ -1469,13 +1427,9 @@ async def _generate(self, task: dict): poll_start.timestamp(), datetime.now(timezone.utc).timestamp() ) - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) _begin_info = "from {}".format(poll_start) self.log_connection("Gitlab", f"({self.conf['project_name']})", task) - return document_generator, file_list + return document_generator class Bitbucket(SyncBase): @@ -1494,17 +1448,12 @@ async def _generate(self, task: dict): "bitbucket_api_token": self.conf["credentials"].get("bitbucket_api_token"), } ) - file_list = None if task["reindex"] == "1" or not task["poll_range_start"]: start_time = datetime.fromtimestamp(0, tz=timezone.utc) _begin_info = "totally" else: start_time = task.get("poll_range_start") - if self.conf.get("sync_deleted_files"): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) _begin_info = f"from {start_time}" end_time = datetime.now(timezone.utc) @@ -1536,8 +1485,6 @@ def wrapper(): yield batch self.log_connection("Bitbucket", f"workspace({self.conf.get('workspace')})", task) - if file_list is not None: - return wrapper(), file_list return wrapper() @@ -1564,26 +1511,12 @@ async def _generate(self, task: dict): ) self.connector.load_credentials(conf["credentials"]) - file_list = None poll_start = task.get("poll_range_start") if task["reindex"] == "1" or poll_start is None: document_generator = self.connector.load_from_state() _begin_info = "totally" else: end_ts = datetime.now(timezone.utc).timestamp() - if self.conf.get("sync_deleted_files"): - file_list = [] - try: - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - except Exception: - logging.exception( - "SeaFile slim snapshot failed; continuing without stale-document cleanup " - "(connector_id=%s, kb_id=%s)", - task["connector_id"], - task["kb_id"], - ) - file_list = None document_generator = self.connector.poll_source( poll_start.timestamp(), end_ts, @@ -1598,7 +1531,7 @@ async def _generate(self, task: dict): extra += f" path={conf.get('sync_path')}" self.log_connection("SeaFile", f"{conf['seafile_url']} (scope={scope}{extra})", task) - return document_generator, file_list + return document_generator class DingTalkAITable(SyncBase): @@ -1631,33 +1564,12 @@ async def _generate(self, task: dict): ) poll_start = task.get("poll_range_start") - file_list = None if task.get("reindex") == "1" or poll_start is None: document_generator = self.connector.load_from_state() _begin_info = "totally" else: end_ts = datetime.now(timezone.utc).timestamp() - if self.conf.get("sync_deleted_files"): - file_list = [] - logging.info( - "DingTalk AI Table: fetching slim snapshot for stale-document reconciliation " - "(connector_id=%s, kb_id=%s, table_id=%s)", - task["connector_id"], - task["kb_id"], - self.conf.get("table_id"), - ) - try: - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - except Exception: - logging.exception( - "DingTalk AI Table slim snapshot failed; continuing without stale-document cleanup " - "(connector_id=%s, kb_id=%s)", - task["connector_id"], - task["kb_id"], - ) - file_list = None document_generator = self.connector.poll_source( poll_start.timestamp(), end_ts, @@ -1670,7 +1582,7 @@ async def _generate(self, task: dict): task, ) - return document_generator, file_list + return document_generator class _RDBMSBase(SyncBase): @@ -1700,16 +1612,6 @@ async def _generate(self, task: dict): self.connector.validate_connector_settings() self.connector.prepare_sync_state(task["connector_id"], self.conf) - file_list = None - if ( - task["reindex"] != "1" - and task["poll_range_start"] - and self.conf.get("sync_deleted_files") - ): - file_list = [] - for slim_batch in self.connector.retrieve_all_slim_docs_perm_sync(): - file_list.extend(slim_batch) - if task["reindex"] == "1" or not task["poll_range_start"]: document_generator = self.connector.load_from_state() _begin_info = "totally" @@ -1726,7 +1628,7 @@ async def _generate(self, task: dict): _begin_info = f"from {poll_start}" self.log_connection(self.LOG_NAME, f"{self.conf.get('host')}:{self.conf.get('database')}", task) - return document_generator, file_list + return document_generator class MySQL(_RDBMSBase): @@ -1743,6 +1645,33 @@ class PostgreSQL(_RDBMSBase): DEFAULT_PORT: int = 5432 +class REST_API(SyncBase): + SOURCE_NAME: str = FileSource.REST_API + + async def _generate(self, task: dict): + try: + cfg = RestAPIConnector.parse_storage_config(self.conf) + except ConnectorValidationError as exc: + raise ValueError(str(exc)) from exc + + self.connector = RestAPIConnector.from_parsed_config(cfg) + self.connector.load_credentials(self.conf.get("credentials") or {}) + + poll_start = task.get("poll_range_start") + if task.get("reindex") == "1" or poll_start is None: + document_generator = self.connector.load_from_state() + begin_info = "totally" + else: + document_generator = self.connector.poll_source( + poll_start.timestamp(), + datetime.now(timezone.utc).timestamp(), + ) + begin_info = f"from {poll_start}" + + logging.info("Connect to REST API: %s %s %s", self.conf.get("method", "GET"), self.conf.get("url"), begin_info) + return document_generator + + func_factory = { FileSource.RSS: RSS, FileSource.S3: S3, @@ -1773,6 +1702,7 @@ class PostgreSQL(_RDBMSBase): FileSource.MYSQL: MySQL, FileSource.POSTGRESQL: PostgreSQL, FileSource.DINGTALK_AI_TABLE: DingTalkAITable, + FileSource.REST_API: REST_API, } @@ -1780,14 +1710,17 @@ async def dispatch_tasks(): """Polls the database for pending synchronization tasks and dispatches them concurrently.""" while True: try: - list(SyncLogsService.list_sync_tasks()[0]) + SyncLogsService.list_due_sync_tasks() + SyncLogsService.list_due_prune_tasks() break except Exception as e: logging.warning(f"DB is not ready yet: {e}") await asyncio.sleep(3) + due_sync_tasks = SyncLogsService.list_due_sync_tasks() + due_prune_tasks = SyncLogsService.list_due_prune_tasks() tasks = [] - for task in SyncLogsService.list_sync_tasks()[0]: + for task in [*due_sync_tasks, *due_prune_tasks]: if task["poll_range_start"]: task["poll_range_start"] = task["poll_range_start"].astimezone(timezone.utc) if task["poll_range_end"]: diff --git a/rag/svr/task_executor.py b/rag/svr/task_executor.py index 4d563278424..e639ba6e46d 100644 --- a/rag/svr/task_executor.py +++ b/rag/svr/task_executor.py @@ -15,14 +15,17 @@ import time +start_ts = time.time() -from common.misc_utils import thread_pool_exec +# LiteLLM fetches a model cost map from GitHub during import unless this is set. +# Parser pods should not block startup on external network access. +import os +os.environ.setdefault("LITELLM_LOCAL_MODEL_COST_MAP", "True") # no internet, save about 10s -start_ts = time.time() +from common.misc_utils import thread_pool_exec import asyncio import socket -import concurrent # from beartype import BeartypeConf # from beartype.claw import beartype_all # <-- you didn't sign up for this # beartype_all(conf=BeartypeConf(violation_type=UserWarning)) # <-- emit warnings from all code @@ -37,10 +40,17 @@ from common.connection_utils import timeout from common.metadata_utils import turn2jsonschema, update_metadata_to from rag.utils.base64_image import image2id -from rag.utils.raptor_utils import should_skip_raptor, get_skip_reason +from rag.utils.raptor_utils import ( + collect_raptor_chunk_ids, + collect_raptor_methods, + get_raptor_clustering_method, + get_raptor_tree_builder, + get_skip_reason, + make_raptor_summary_chunk_id, + should_skip_raptor, +) from common.log_utils import init_root_logger from common.config_utils import show_configs -from rag.graphrag.general.index import run_graphrag_for_kb from rag.graphrag.utils import get_llm_cache, set_llm_cache, get_tags_from_cache, set_tags_to_cache from rag.prompts.generator import keyword_extraction, question_proposal, content_tagging, run_toc_from_text, \ gen_metadata @@ -71,7 +81,9 @@ from rag.app import laws, paper, presentation, manual, qa, table, book, resume, picture, naive, one, audio, \ email, tag from rag.nlp import search, rag_tokenizer, add_positions -from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor +from rag.raptor import ( + RAPTOR_TREE_BUILDER, +) from common.token_utils import num_tokens_from_string, truncate from rag.utils.redis_conn import REDIS_CONN, RedisDistributedLock from rag.graphrag.utils import chat_limiter @@ -79,9 +91,15 @@ from common.exceptions import TaskCanceledException from common import settings from common.constants import PAGERANK_FLD, TAG_FLD, SVR_CONSUMER_GROUP_NAME +from rag.utils.table_es_metadata import ( + aggregate_table_manual_doc_metadata, + merge_table_parser_config_from_kb, + table_parser_strip_doc_metadata_keys, +) BATCH_SIZE = 64 + FACTORY = { "general": naive, ParserType.NAIVE.value: naive, @@ -268,6 +286,16 @@ async def build_chunks(task, progress_callback): logging.exception("Chunking {}/{} got exception".format(task["location"], task["name"])) raise + # Table parser column roles / mode are stored on the dataset (KB) parser_config; + # chunk tasks carry document-level parser_config only — merge KB keys so manual roles apply. + parser_config_for_chunk = merge_table_parser_config_from_kb(task) + if task.get("parser_id", "").lower() == "table" and task.get("kb_parser_config"): + logging.debug( + "[TASK_EXECUTOR_DEBUG] table parser: merged KB keys into parser_config for chunk; " + f"mode={parser_config_for_chunk.get('table_column_mode')}, " + f"roles_keys={list((parser_config_for_chunk.get('table_column_roles') or {}).keys())}" + ) + try: async with chunk_limiter: cks = await thread_pool_exec( @@ -279,7 +307,7 @@ async def build_chunks(task, progress_callback): lang=task["language"], callback=progress_callback, kb_id=task["kb_id"], - parser_config=task["parser_config"], + parser_config=parser_config_for_chunk, tenant_id=task["tenant_id"], ) logging.info("Chunking({}) {}/{} done".format(timer() - st, task["location"], task["name"])) @@ -369,7 +397,7 @@ async def doc_keyword_extraction(chat_mdl, d, topn): cached = await keyword_extraction(chat_mdl, d["content_with_weight"], topn) set_llm_cache(chat_mdl.llm_name, d["content_with_weight"], cached, "keywords", {"topn": topn}) if cached: - d["important_kwd"] = cached.split(",") + d["important_kwd"] = [k for k in re.split(r"[,,;;、\r\n]+", cached) if k.strip()] d["important_tks"] = rag_tokenizer.tokenize(" ".join(d["important_kwd"])) return @@ -612,7 +640,8 @@ async def embedding(docs, mdl, parser_config=None, callback=None): if not c: c = d["content_with_weight"] c = re.sub(r"]{0,12})?>", " ", c) - if not c: + if not c.strip(): + logging.debug("embedding(): normalized whitespace-only chunk to placeholder 'None' (len=%d)", len(c)) c = "None" cnts.append(c) @@ -759,7 +788,7 @@ def batch_encode(txts): del ck["questions"] if "keywords" in ck: if "important_tks" not in ck: - ck["important_kwd"] = ck["keywords"].split(",") + ck["important_kwd"] = [k for k in re.split(r"[,,;;、\r\n]+", ck["keywords"]) if k.strip()] ck["important_tks"] = rag_tokenizer.tokenize(str(ck["keywords"])) del ck["keywords"] if "summary" in ck: @@ -802,61 +831,161 @@ def batch_encode(txts): dsl=str(pipeline)) -async def has_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str) -> bool: - """Return True if RAPTOR chunks already exist for doc_id in the doc store. +RAPTOR_METHOD_SEARCH_LIMIT = 10000 - Queries directly for raptor_kwd="raptor" rows so a non-RAPTOR leading - chunk cannot produce a false-negative result. Uses thread_pool_exec so - the blocking doc-store call does not stall the event loop. - """ + +async def get_raptor_chunk_field_map(doc_id: str, tenant_id: str, kb_id: str) -> dict: + """Return stored RAPTOR marker fields for a document.""" from common.doc_store.doc_store_base import OrderByExpr from rag.nlp import search as nlp_search - try: - condition = {"doc_id": doc_id, "raptor_kwd": ["raptor"]} + + async def search_fields(fields: list[str], condition: dict, order_by=None): + """Search chunk fields in the current knowledge base.""" res = await thread_pool_exec( settings.docStoreConn.search, - ["raptor_kwd"], [], condition, [], OrderByExpr(), - 0, 1, nlp_search.index_name(tenant_id), [kb_id] + fields, [], condition, [], order_by or OrderByExpr(), + 0, RAPTOR_METHOD_SEARCH_LIMIT, nlp_search.index_name(tenant_id), [kb_id] + ) + return settings.docStoreConn.get_fields(res, fields) + + primary = await search_fields(["raptor_kwd", "extra"], {"doc_id": doc_id, "raptor_kwd": ["raptor"]}) + if collect_raptor_chunk_ids(primary): + return primary + + try: + return await search_fields( + ["raptor_kwd", "extra"], + {"doc_id": doc_id}, + OrderByExpr().desc("create_timestamp_flt"), ) - field_map = settings.docStoreConn.get_fields(res, ["raptor_kwd"]) - found = bool(field_map) - if found: + except Exception: + logging.debug("RAPTOR fallback method lookup with extra field failed for doc %s", doc_id, exc_info=True) + return primary + + +async def get_raptor_chunk_methods(doc_id: str, tenant_id: str, kb_id: str) -> set[str]: + """Return the RAPTOR tree builders already stored for doc_id. + + Queries directly for raptor_kwd="raptor" rows so a non-RAPTOR leading + chunk cannot produce a false-negative result. Legacy summary chunks that + do not have method metadata are treated as the original RAPTOR builder. + """ + try: + field_map = await get_raptor_chunk_field_map(doc_id, tenant_id, kb_id) + methods = collect_raptor_methods(field_map) + if methods: logging.info( - "Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s) already exist", - doc_id, tenant_id, kb_id, + "Checkpoint hit: RAPTOR chunks for doc %s (tenant=%s kb=%s methods=%s) already exist", + doc_id, tenant_id, kb_id, sorted(methods), ) else: logging.info( "Checkpoint miss: no RAPTOR chunks for doc %s (tenant=%s kb=%s)", doc_id, tenant_id, kb_id, ) - return found + return methods except Exception: logging.exception("Failed to check RAPTOR chunks for doc %s", doc_id) - return False + raise + + +async def has_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, tree_builder: str = RAPTOR_TREE_BUILDER) -> bool: + """Return whether doc_id already has summaries for tree_builder.""" + methods = await get_raptor_chunk_methods(doc_id, tenant_id, kb_id) + return tree_builder in methods + + +async def delete_raptor_chunks(doc_id: str, tenant_id: str, kb_id: str, keep_method: str | None = None): + """Delete RAPTOR summaries for doc_id, optionally preserving one method.""" + from rag.nlp import search as nlp_search + + if keep_method is None: + logging.info( + "delete_raptor_chunks: removing all RAPTOR summaries (doc=%s tenant=%s kb=%s)", + doc_id, tenant_id, kb_id, + ) + await thread_pool_exec( + settings.docStoreConn.delete, + {"doc_id": doc_id, "raptor_kwd": ["raptor"]}, + nlp_search.index_name(tenant_id), + kb_id, + ) + return 0 + + field_map = await get_raptor_chunk_field_map(doc_id, tenant_id, kb_id) + chunk_ids = collect_raptor_chunk_ids(field_map, exclude_methods={keep_method}) + if not chunk_ids: + logging.debug( + "delete_raptor_chunks: no stale RAPTOR chunks to remove (doc=%s tenant=%s kb=%s keep=%s)", + doc_id, tenant_id, kb_id, keep_method, + ) + return 0 + + logging.info( + "delete_raptor_chunks: removing %d stale RAPTOR chunks (doc=%s tenant=%s kb=%s keep=%s)", + len(chunk_ids), doc_id, tenant_id, kb_id, keep_method, + ) + await thread_pool_exec( + settings.docStoreConn.delete, + {"id": list(chunk_ids)}, + nlp_search.index_name(tenant_id), + kb_id, + ) + return len(chunk_ids) @timeout(3600) async def run_raptor_for_kb(row, kb_parser_config, chat_mdl, embd_mdl, vector_size, callback=None, doc_ids=[]): + """Generate RAPTOR summaries for selected documents in a knowledge base.""" fake_doc_id = GRAPH_RAPTOR_FAKE_DOC_ID raptor_config = kb_parser_config.get("raptor", {}) + raptor_ext_config = raptor_config.get("ext") or {} + tree_builder = get_raptor_tree_builder(raptor_config) + clustering_method = get_raptor_clustering_method(raptor_config) vctr_nm = "q_%d_vec" % vector_size res = [] tk_count = 0 + cleanup_raptor_chunks = [] max_errors = int(os.environ.get("RAPTOR_MAX_ERRORS", 3)) - doc_name_by_id = {} + doc_info_by_id = {} for doc_id in set(doc_ids): ok, source_doc = DocumentService.get_by_id(doc_id) if not ok or not source_doc: continue - source_name = getattr(source_doc, "name", "") - if source_name: - doc_name_by_id[doc_id] = source_name + doc_info_by_id[doc_id] = { + "name": getattr(source_doc, "name", ""), + "type": getattr(source_doc, "type", ""), + "parser_id": getattr(source_doc, "parser_id", ""), + "parser_config": getattr(source_doc, "parser_config", {}) or {}, + } + + def schedule_raptor_cleanup(doc_id: str, keep_method: str | None = None): + """Queue stale RAPTOR summaries for deletion after successful insert.""" + cleanup_plan = (doc_id, keep_method) + if cleanup_plan not in cleanup_raptor_chunks: + cleanup_raptor_chunks.append(cleanup_plan) + + def skip_raptor_doc(doc_id: str) -> bool: + """Return whether RAPTOR should be skipped for this source document.""" + doc_info = doc_info_by_id.get(doc_id, {}) + file_type = doc_info.get("type") or row.get("type", "") + parser_id = doc_info.get("parser_id") or row.get("parser_id", "") + parser_config = doc_info.get("parser_config") or row.get("parser_config", {}) + if should_skip_raptor(file_type, parser_id, parser_config, raptor_config): + skip_reason = get_skip_reason(file_type, parser_id, parser_config) + doc_name = doc_info.get("name") or doc_id + logging.info("Skipping Raptor for document %s: %s", doc_name, skip_reason) + callback(msg=f"[RAPTOR] doc:{doc_id} skipped: {skip_reason}") + return True + return False async def generate(chunks, did): + """Run RAPTOR and append generated summary chunks for one doc id.""" nonlocal tk_count, res + logging.info("RAPTOR: using tree_builder=%s clustering_method=%s for doc %s", tree_builder, clustering_method, did) + from rag.raptor import RecursiveAbstractiveProcessing4TreeOrganizedRetrieval as Raptor # Lazy load, save around 8s raptor = Raptor( raptor_config.get("max_cluster", 64), chat_mdl, @@ -865,16 +994,21 @@ async def generate(chunks, did): raptor_config["max_token"], raptor_config["threshold"], max_errors=max_errors, + tree_builder=tree_builder, + clustering_method=clustering_method, + psi_exact_max_leaves=raptor_ext_config.get("psi_exact_max_leaves", 4096), + psi_bucket_size=raptor_ext_config.get("psi_bucket_size", 1024), ) original_length = len(chunks) chunks, layers = await raptor(chunks, kb_parser_config["raptor"]["random_seed"], callback, row["id"]) - effective_doc_name = row["name"] if did == fake_doc_id else doc_name_by_id.get(did, row["name"]) + effective_doc_name = row["name"] if did == fake_doc_id else doc_info_by_id.get(did, {}).get("name") or row["name"] doc = { "doc_id": did, "kb_id": [str(row["kb_id"])], "docnm_kwd": effective_doc_name, "title_tks": rag_tokenizer.tokenize(effective_doc_name), - "raptor_kwd": "raptor" + "raptor_kwd": "raptor", + "extra": {"raptor_method": tree_builder}, } if row["pagerank"]: doc[PAGERANK_FLD] = int(row["pagerank"]) @@ -891,7 +1025,7 @@ async def generate(chunks, did): for idx, (content, vctr) in enumerate(chunks[original_length:], start=original_length): d = copy.deepcopy(doc) - d["id"] = xxhash.xxh64((content + str(fake_doc_id)).encode("utf-8")).hexdigest() + d["id"] = make_raptor_summary_chunk_id(content, did) d["create_time"] = str(datetime.now()).replace("T", " ")[:19] d["create_timestamp_flt"] = datetime.now().timestamp() d[vctr_nm] = vctr.tolist() @@ -903,12 +1037,28 @@ async def generate(chunks, did): tk_count += num_tokens_from_string(content) if raptor_config.get("scope", "file") == "file": + dataset_methods = await get_raptor_chunk_methods(fake_doc_id, row["tenant_id"], row["kb_id"]) + remove_dataset_summaries = bool(dataset_methods) + has_file_level_target = False + if dataset_methods: + callback(msg="[RAPTOR] will remove dataset-level summaries after file-level summaries are available.") + for x, doc_id in enumerate(doc_ids): + if skip_raptor_doc(doc_id): + callback(prog=(x + 1.) / len(doc_ids)) + continue # CHECKPOINT: skip docs that already have RAPTOR chunks in the doc store - if await has_raptor_chunks(doc_id, row["tenant_id"], row["kb_id"]): - callback(msg=f"[RAPTOR] doc:{doc_id} already has RAPTOR chunks, skipping.") + existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"]) + if tree_builder in existing_methods: + has_file_level_target = True + if existing_methods != {tree_builder}: + schedule_raptor_cleanup(doc_id, tree_builder) + callback(msg=f"[RAPTOR] doc:{doc_id} will remove old RAPTOR summaries after insert.") + callback(msg=f"[RAPTOR] doc:{doc_id} already has {tree_builder} RAPTOR chunks, skipping.") callback(prog=(x + 1.) / len(doc_ids)) continue + if existing_methods: + callback(msg=f"[RAPTOR] doc:{doc_id} will migrate RAPTOR summaries to {tree_builder} after insert.") chunks = [] skipped_chunks = 0 @@ -930,12 +1080,52 @@ async def generate(chunks, did): callback(msg=f"[WARN] No valid chunks with vectors found for doc {doc_id}, skipping") continue + before_generate = len(res) await generate(chunks, doc_id) + if len(res) > before_generate: + has_file_level_target = True + if existing_methods: + schedule_raptor_cleanup(doc_id, tree_builder) callback(prog=(x + 1.) / len(doc_ids)) + + if remove_dataset_summaries: + if has_file_level_target: + schedule_raptor_cleanup(fake_doc_id) + else: + callback(msg="[RAPTOR] kept dataset-level summaries because no file-level summaries were built.") else: + migrated_file_docs = 0 + file_cleanup_doc_ids = [] + skipped_doc_ids = set() + for doc_id in set(doc_ids): + if skip_raptor_doc(doc_id): + skipped_doc_ids.add(doc_id) + continue + existing_methods = await get_raptor_chunk_methods(doc_id, row["tenant_id"], row["kb_id"]) + if existing_methods: + file_cleanup_doc_ids.append(doc_id) + migrated_file_docs += 1 + if migrated_file_docs: + callback(msg=f"[RAPTOR] will remove file-level summaries for {migrated_file_docs} docs after dataset-level build succeeds.") + + existing_methods = await get_raptor_chunk_methods(fake_doc_id, row["tenant_id"], row["kb_id"]) + if tree_builder in existing_methods: + if existing_methods != {tree_builder}: + schedule_raptor_cleanup(fake_doc_id, tree_builder) + callback(msg="[RAPTOR] will remove old dataset-level RAPTOR summaries after insert.") + for doc_id in file_cleanup_doc_ids: + schedule_raptor_cleanup(doc_id) + callback(msg=f"[RAPTOR] dataset-level {tree_builder} summaries already exist, skipping.") + return res, tk_count, cleanup_raptor_chunks + migrate_dataset_summaries = bool(existing_methods) + if migrate_dataset_summaries: + callback(msg=f"[RAPTOR] will migrate dataset-level RAPTOR summaries to {tree_builder} after insert.") + chunks = [] skipped_chunks = 0 for doc_id in doc_ids: + if doc_id in skipped_doc_ids: + continue for d in settings.retriever.chunk_list(doc_id, row["tenant_id"], [str(row["kb_id"])], fields=["content_with_weight", vctr_nm], sort_by_position=True): @@ -950,13 +1140,22 @@ async def generate(chunks, did): callback(msg=f"[WARN] Skipped {skipped_chunks} chunks without vector field '{vctr_nm}'. Consider re-parsing documents with the current embedding model.") if not chunks: + if skipped_doc_ids and len(skipped_doc_ids) == len(set(doc_ids)): + callback(msg="[RAPTOR] all documents were skipped by RAPTOR auto-disable rules.") + return res, tk_count, cleanup_raptor_chunks logging.error(f"RAPTOR: No valid chunks with vectors found in any document for kb {row['kb_id']}") callback(msg=f"[ERROR] No valid chunks with vectors found. Please ensure documents are parsed with the current embedding model (vector size: {vector_size}).") - return res, tk_count + return res, tk_count, cleanup_raptor_chunks + before_generate = len(res) await generate(chunks, fake_doc_id) + if len(res) > before_generate: + for doc_id in file_cleanup_doc_ids: + schedule_raptor_cleanup(doc_id) + if migrate_dataset_summaries: + schedule_raptor_cleanup(fake_doc_id, tree_builder) - return res, tk_count + return res, tk_count, cleanup_raptor_chunks async def delete_image(kb_id, chunk_id): @@ -1014,6 +1213,29 @@ async def insert_chunks(task_id, task_tenant_id, task_dataset_id, chunks, progre search.index_name(task_tenant_id), task_dataset_id, ) task_canceled = has_canceled(task_id) if task_canceled: + # Roll back partial RAPTOR summary inserts so the next run is not + # mistaken for a completed checkpoint by get_raptor_chunk_methods. + raptor_ids_to_rollback = [ + c["id"] for c in chunks[:b + settings.DOC_BULK_SIZE] + if c.get("raptor_kwd") == "raptor" + ] + if raptor_ids_to_rollback: + try: + await thread_pool_exec( + settings.docStoreConn.delete, + {"id": raptor_ids_to_rollback}, + search.index_name(task_tenant_id), + task_dataset_id, + ) + logging.info( + "insert_chunks: rolled back %d partial RAPTOR chunks after cancellation (task=%s)", + len(raptor_ids_to_rollback), task_id, + ) + except Exception: + logging.exception( + "insert_chunks: failed to roll back partial RAPTOR chunks after cancellation (task=%s)", + task_id, + ) progress_callback(-1, msg="Task has been canceled.") return False if b % 128 == 0: @@ -1073,7 +1295,7 @@ async def do_handle_task(task): task_parser_config = task["parser_config"] task_start_ts = timer() toc_thread = None - executor = concurrent.futures.ThreadPoolExecutor() + raptor_cleanup_chunks = [] # prepare the progress callback function progress_callback = partial(set_progress, task_id, task_from_page, task_to_page) @@ -1121,7 +1343,9 @@ async def do_handle_task(task): "threshold": 0.1, "max_cluster": 64, "random_seed": 0, - "scope": "file" + "scope": "file", + "clustering_method": "gmm", + "tree_builder": "raptor", }, } ) @@ -1129,23 +1353,12 @@ async def do_handle_task(task): progress_callback(prog=-1.0, msg="Internal error: Invalid RAPTOR configuration") return - # Check if Raptor should be skipped for structured data - file_type = task.get("type", "") - parser_id = task.get("parser_id", "") - raptor_config = kb_parser_config.get("raptor", {}) - - if should_skip_raptor(file_type, parser_id, task_parser_config, raptor_config): - skip_reason = get_skip_reason(file_type, parser_id, task_parser_config) - logging.info(f"Skipping Raptor for document {task_document_name}: {skip_reason}") - progress_callback(prog=1.0, msg=f"Raptor skipped: {skip_reason}") - return - # bind LLM for raptor chat_model_config = get_model_config_by_type_and_name(task_tenant_id, LLMType.CHAT, kb_task_llm_id) chat_model = LLMBundle(task_tenant_id, chat_model_config, lang=task_language) # run RAPTOR async with kg_limiter: - chunks, token_count = await run_raptor_for_kb( + chunks, token_count, raptor_cleanup_chunks = await run_raptor_for_kb( row=task, kb_parser_config=kb_parser_config, chat_mdl=chat_model, @@ -1177,6 +1390,7 @@ async def do_handle_task(task): "category", ], "method": "light", + "batch_chunk_token_size": 4096, } } ) @@ -1192,6 +1406,7 @@ async def do_handle_task(task): with_community = graphrag_conf.get("community", False) async with kg_limiter: # await run_graphrag(task, task_language, with_resolution, with_community, chat_model, embedding_model, progress_callback) + from rag.graphrag.general.index import run_graphrag_for_kb # Lazy load, save around 2s result = await run_graphrag_for_kb( row=task, doc_ids=task.get("doc_ids", []), @@ -1235,7 +1450,7 @@ async def do_handle_task(task): logging.info(progress_message) progress_callback(msg=progress_message) if task["parser_id"].lower() == "naive" and task["parser_config"].get("toc_extraction", False): - toc_thread = executor.submit(build_TOC, task, chunks, progress_callback) + toc_thread = asyncio.create_task(asyncio.to_thread(build_TOC, task, chunks, progress_callback)) chunk_count = len(set([chunk["id"] for chunk in chunks])) start_ts = timer() @@ -1254,6 +1469,18 @@ async def _maybe_insert_chunks(_chunks): progress_callback(-1, msg="Task has been canceled.") return + if raptor_cleanup_chunks: + cleaned_chunks = 0 + for cleanup_doc_id, keep_method in raptor_cleanup_chunks: + cleaned_chunks += await delete_raptor_chunks( + cleanup_doc_id, + task_tenant_id, + task_dataset_id, + keep_method=keep_method, + ) + if cleaned_chunks: + progress_callback(msg=f"Cleaned up {cleaned_chunks} stale RAPTOR chunks.") + logging.info( "Indexing doc({}), page({}-{}), chunks({}), elapsed: {:.2f}".format( task_document_name, task_from_page, task_to_page, len(chunks), timer() - start_ts @@ -1262,10 +1489,47 @@ async def _maybe_insert_chunks(_chunks): DocumentService.increment_chunk_num(task_doc_id, task_dataset_id, token_count, chunk_count, 0) + # Table parser (manual): push metadata/both column values to document-level metadata for UI / chat filters + if task.get("parser_id", "").lower() == "table": + eff_pc = merge_table_parser_config_from_kb(task) + logging.debug( + f"[TABLE_META_DEBUG] table post-index: table_column_mode={eff_pc.get('table_column_mode')!r}" + ) + if eff_pc.get("table_column_mode") == "manual": + try: + agg = aggregate_table_manual_doc_metadata(chunks, task) + logging.debug(f"[TABLE_META_DEBUG] aggregated metadata: {agg}") + strip_keys = table_parser_strip_doc_metadata_keys(eff_pc) + existing = DocMetadataService.get_document_metadata(task_doc_id) + existing = existing if isinstance(existing, dict) else {} + preserved = {k: v for k, v in existing.items() if k not in strip_keys} + merged = update_metadata_to(dict(preserved), agg) + logging.debug( + f"[TABLE_META_DEBUG] calling update_document_metadata for doc_id={task_doc_id}, " + f"meta_fields keys={list(merged.keys())}, " + f"table_strip_key_count={len(strip_keys)}, agg_keys={list(agg.keys())}" + ) + try: + DocMetadataService.update_document_metadata(task_doc_id, merged) + logging.debug("[TABLE_META_DEBUG] update_document_metadata succeeded") + except Exception as ue: + logging.error( + "update_document_metadata failed (table parser, doc_id=%s): %s", + task_doc_id, + ue, + exc_info=True, + ) + except Exception as e: + logging.exception( + "Table parser document metadata aggregation failed (doc_id=%s): %s", + task_doc_id, + e, + ) + progress_callback(msg="Indexing done ({:.2f}s).".format(timer() - start_ts)) if toc_thread: - d = toc_thread.result() + d = await toc_thread if d: if not await _maybe_insert_chunks([d]): return @@ -1284,7 +1548,8 @@ async def _maybe_insert_chunks(_chunks): ) finally: - executor.shutdown(wait=False) + if toc_thread is not None and not toc_thread.done(): + toc_thread.cancel() if has_canceled(task_id): try: exists = await thread_pool_exec( diff --git a/rag/utils/es_conn.py b/rag/utils/es_conn.py index 51356befad1..eed2e67c27c 100644 --- a/rag/utils/es_conn.py +++ b/rag/utils/es_conn.py @@ -69,8 +69,7 @@ def _es_search_once(self, index_names: list[str], query: dict, track_total_hits: index=index_names, body=query, timeout="600s", - track_total_hits=track_total_hits, - _source=True, + track_total_hits=track_total_hits ) def _search_with_search_after(self, index_names: list[str], query: dict, offset: int, limit: int): @@ -324,7 +323,7 @@ def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = try: res = [] r = self.es.bulk(index=index_name, operations=operations, - refresh=False, timeout="60s") + refresh="wait_for", timeout="60s") if re.search(r"False", str(r["errors"]), re.IGNORECASE): return res diff --git a/rag/utils/infinity_conn.py b/rag/utils/infinity_conn.py index 45290c520d6..7ffd9f13d42 100644 --- a/rag/utils/infinity_conn.py +++ b/rag/utils/infinity_conn.py @@ -318,7 +318,10 @@ def get(self, chunk_id: str, index_name: str, knowledgebase_ids: list[str]) -> d "authors_sm_tks"]: fields.add(field) res_fields = self.get_fields(res, list(fields)) - return res_fields.get(chunk_id, None) + chunk = res_fields.get(chunk_id, None) + if chunk is not None: + chunk["id"] = chunk_id + return chunk def insert(self, documents: list[dict], index_name: str, knowledgebase_id: str = None) -> list[str]: ''' diff --git a/rag/utils/ob_conn.py b/rag/utils/ob_conn.py index 22fbc9c7b1a..fde2138f0e5 100644 --- a/rag/utils/ob_conn.py +++ b/rag/utils/ob_conn.py @@ -46,6 +46,8 @@ column_group_id = Column("group_id", String(256), nullable=True, comment="group id for external retrieval") column_mom_id = Column("mom_id", String(256), nullable=True, comment="parent chunk id") column_chunk_data = Column("chunk_data", JSON, nullable=True, comment="table parser row data") +column_raptor_kwd = Column("raptor_kwd", String(256), nullable=True, comment="RAPTOR summary marker") +column_raptor_layer_int = Column("raptor_layer_int", Integer, nullable=True, comment="RAPTOR summary layer") column_definitions: list[Column] = [ Column("id", String(256), primary_key=True, comment="chunk id"), @@ -86,6 +88,8 @@ Column("rank_flt", Double, nullable=True, comment="rank of this entity"), Column("removed_kwd", String(256), nullable=True, index=True, server_default="'N'", comment="whether it has been deleted"), + column_raptor_kwd, + column_raptor_layer_int, column_chunk_data, Column("metadata", JSON, nullable=True, comment="metadata for this chunk"), Column("extra", JSON, nullable=True, comment="extra information of non-general chunk"), @@ -127,7 +131,14 @@ ] # Extra columns to add after table creation (for migration) -EXTRA_COLUMNS: list[Column] = [column_order_id, column_group_id, column_mom_id, column_chunk_data] +EXTRA_COLUMNS: list[Column] = [ + column_order_id, + column_group_id, + column_mom_id, + column_chunk_data, + column_raptor_kwd, + column_raptor_layer_int, +] class SearchResult(BaseModel): diff --git a/rag/utils/opensearch_conn.py b/rag/utils/opensearch_conn.py index cb8b70ac2d1..2239102ef31 100644 --- a/rag/utils/opensearch_conn.py +++ b/rag/utils/opensearch_conn.py @@ -126,6 +126,99 @@ def create_idx(self, indexName: str, knowledgebaseId: str, vectorSize: int, pars except Exception: logger.exception("OSConnection.createIndex error %s" % (indexName)) + def create_doc_meta_idx(self, index_name: str): + """ + Create a per-tenant document metadata index on OpenSearch. + + Mirrors ESConnectionBase.create_doc_meta_idx so that the + DocMetadataService dispatches uniformly across ES and OS backends. + Index name pattern: ragflow_doc_meta_{tenant_id} + """ + if self.index_exist(index_name, ""): + return True + try: + fp_mapping = os.path.join(get_project_base_directory(), "conf", "doc_meta_es_mapping.json") + if not os.path.exists(fp_mapping): + logger.error(f"Document metadata mapping file not found at {fp_mapping}") + return False + + with open(fp_mapping, "r") as f: + doc_meta_mapping = json.load(f) + + from opensearchpy.client import IndicesClient + body = { + "settings": doc_meta_mapping["settings"], + "mappings": doc_meta_mapping["mappings"], + } + return IndicesClient(self.os).create(index=index_name, body=body) + except Exception as e: + logger.exception(f"OSConnection.create_doc_meta_idx error creating {index_name}: {e}") + return False + + def refresh_idx(self, index_name: str) -> bool: + """ + Refresh an index so that recently inserted documents become searchable. + + DocMetadataService used to call ``settings.docStoreConn.es.indices.refresh`` + directly, which raised AttributeError on the OpenSearch backend because + OSConnection exposes ``self.os`` rather than ``self.es``. This wrapper + gives both backends a uniform abstract entry point. + """ + try: + self.os.indices.refresh(index=index_name) + return True + except NotFoundError: + return False + except Exception as e: + logger.warning(f"OSConnection.refresh_idx({index_name}) failed: {e}") + return False + + def count_idx(self, index_name: str) -> int: + """ + Return the document count for an index, or -1 if the call fails. + + Used by DocMetadataService._drop_empty_metadata_table to decide whether + a per-tenant metadata index is empty without paying a full search. + """ + try: + response = self.os.count(index=index_name) + return int(response.get("count", 0)) + except NotFoundError: + return 0 + except Exception as e: + logger.warning(f"OSConnection.count_idx({index_name}) failed: {e}") + return -1 + + def replace_meta_fields(self, index_name: str, doc_id: str, meta_fields: dict) -> bool: + """ + Replace the ``meta_fields`` object on a single document. + + ES.update with a ``doc`` body deep-merges object fields, which retains + old keys that should be removed. The fix in ESConnection is a script + that fully assigns the new meta_fields. We provide the same primitive + on OpenSearch so the service layer never reaches into ``self.es`` or + ``self.os`` directly. + """ + body = { + "script": { + "source": "ctx._source.meta_fields = params.meta_fields", + "params": {"meta_fields": meta_fields}, + } + } + for _ in range(ATTEMPT_TIME): + try: + self.os.update(index=index_name, id=doc_id, body=body, refresh=True) + return True + except NotFoundError: + return False + except Exception as e: + logger.warning(f"OSConnection.replace_meta_fields({index_name}, {doc_id}) failed: {e}") + if re.search(r"(timeout|connection)", str(e).lower()): + time.sleep(1) + continue + return False + return False + def delete_idx(self, indexName: str, knowledgebaseId: str): if len(knowledgebaseId) > 0: # The index need to be alive after any kb deletion since all kb under this tenant are in one index. @@ -327,7 +420,7 @@ def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = N try: res = [] r = self.os.bulk(index=(indexName), body=operations, - refresh=False, timeout=60) + refresh="wait_for", timeout=60) if re.search(r"False", str(r["errors"]), re.IGNORECASE): return res diff --git a/rag/utils/raptor_utils.py b/rag/utils/raptor_utils.py index dd6f75dd9a7..91d43cd9374 100644 --- a/rag/utils/raptor_utils.py +++ b/rag/utils/raptor_utils.py @@ -18,15 +18,111 @@ Utility functions for Raptor processing decisions. """ +import json import logging from typing import Optional +import xxhash + +RAPTOR_TREE_BUILDER = "raptor" +PSI_TREE_BUILDER = "psi" +SUPPORTED_TREE_BUILDERS = {RAPTOR_TREE_BUILDER, PSI_TREE_BUILDER} +GMM_CLUSTERING_METHOD = "gmm" +AHC_CLUSTERING_METHOD = "ahc" +SUPPORTED_CLUSTERING_METHODS = {GMM_CLUSTERING_METHOD, AHC_CLUSTERING_METHOD} + # File extensions for structured data types EXCEL_EXTENSIONS = {".xls", ".xlsx", ".xlsm", ".xlsb"} CSV_EXTENSIONS = {".csv", ".tsv"} STRUCTURED_EXTENSIONS = EXCEL_EXTENSIONS | CSV_EXTENSIONS +def get_raptor_tree_builder(raptor_config: dict | None) -> str: + """Return the configured RAPTOR tree builder with legacy ext fallback.""" + raptor_config = raptor_config or {} + ext = raptor_config.get("ext") or {} + tree_builder = ext.get("tree_builder") or raptor_config.get("tree_builder") or RAPTOR_TREE_BUILDER + if tree_builder not in SUPPORTED_TREE_BUILDERS: + raise ValueError(f"Unsupported RAPTOR tree builder: {tree_builder}") + return tree_builder + + +def get_raptor_clustering_method(raptor_config: dict | None) -> str: + """Return the configured RAPTOR clustering method with legacy ext fallback.""" + raptor_config = raptor_config or {} + ext = raptor_config.get("ext") or {} + clustering_method = ext.get("clustering_method") or raptor_config.get("clustering_method") or GMM_CLUSTERING_METHOD + if clustering_method not in SUPPORTED_CLUSTERING_METHODS: + raise ValueError(f"Unsupported RAPTOR clustering method: {clustering_method}") + return clustering_method + + +def _as_extra_dict(extra) -> dict: + """Normalize a chunk extra payload into a dictionary.""" + if isinstance(extra, dict): + return extra + if isinstance(extra, str) and extra: + try: + parsed = json.loads(extra) + except json.JSONDecodeError: + logging.warning( + "Ignoring malformed RAPTOR extra payload while collecting chunk metadata: %s", + extra[:200], + exc_info=True, + ) + return {} + return parsed if isinstance(parsed, dict) else {} + return {} + + +def _has_raptor_marker(marker) -> bool: + """Return whether a chunk marker identifies a RAPTOR summary chunk.""" + if isinstance(marker, list): + return any(str(item) == RAPTOR_TREE_BUILDER for item in marker) + return str(marker) == RAPTOR_TREE_BUILDER + + +def _raptor_methods_from_fields(fields: dict, extra: dict | None = None) -> set[str]: + """Read RAPTOR builder methods from stored chunk fields.""" + extra = extra if extra is not None else _as_extra_dict(fields.get("extra")) + method = extra.get("raptor_method") or RAPTOR_TREE_BUILDER + if isinstance(method, list): + return {str(item) for item in method if item} + return {str(method)} if method else set() + + +def collect_raptor_methods(field_map: dict) -> set[str]: + """Collect tree-builder methods from RAPTOR summary chunk fields.""" + methods = set() + for fields in field_map.values(): + extra = _as_extra_dict(fields.get("extra")) + marker = fields.get("raptor_kwd") or extra.get("raptor_kwd") + if not _has_raptor_marker(marker): + continue + + methods.update(_raptor_methods_from_fields(fields, extra)) + return methods + + +def collect_raptor_chunk_ids(field_map: dict, exclude_methods: set[str] | None = None) -> set[str]: + """Collect RAPTOR summary chunk IDs, optionally excluding some methods.""" + chunk_ids = set() + exclude_methods = exclude_methods or set() + for chunk_id, fields in field_map.items(): + extra = _as_extra_dict(fields.get("extra")) + marker = fields.get("raptor_kwd") or extra.get("raptor_kwd") + if _has_raptor_marker(marker): + if _raptor_methods_from_fields(fields, extra).issubset(exclude_methods): + continue + chunk_ids.add(chunk_id) + return chunk_ids + + +def make_raptor_summary_chunk_id(content: str, doc_id: str) -> str: + """Build the stable ID used for generated RAPTOR summary chunks.""" + return xxhash.xxh64((content + str(doc_id)).encode("utf-8")).hexdigest() + + def is_structured_file_type(file_type: Optional[str]) -> bool: """ Check if a file type is structured data (Excel, CSV, etc.) diff --git a/rag/utils/table_es_metadata.py b/rag/utils/table_es_metadata.py new file mode 100644 index 00000000000..18edfc4696d --- /dev/null +++ b/rag/utils/table_es_metadata.py @@ -0,0 +1,296 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Table manual-mode ES field resolution and document metadata aggregation (lightweight; used by task_executor).""" + +import logging + +from common import settings +from common.metadata_utils import dedupe_list + + +def _knowledgebase_service_cls(): + """Lazy import for KnowledgebaseService (used by aggregate; mockable in unit tests).""" + from api.db.services.knowledgebase_service import KnowledgebaseService + + return KnowledgebaseService + + +def merge_table_parser_config_from_kb(task: dict) -> dict: + """Merge dataset-level table parser keys into document parser_config (see build_chunks).""" + pc = task.get("parser_config") or {} + if task.get("parser_id", "").lower() != "table" or not task.get("kb_parser_config"): + return pc + out = dict(pc) + kb_pc = task["kb_parser_config"] + for _k in ("table_column_mode", "table_column_roles", "table_column_names"): + if _k in kb_pc: + out[_k] = kb_pc[_k] + return out + + +def table_parser_strip_doc_metadata_keys(eff_parser_config: dict) -> frozenset[str]: + """ + Table manual mode stores per-column values under document metadata keys equal to the + CSV column name. On reparse, strip these keys from existing metadata before merging + a fresh aggregate so columns switched to indexing-only (or removed) do not persist. + """ + names = eff_parser_config.get("table_column_names") + if names: + return frozenset(str(n).strip() for n in names if n is not None and str(n).strip()) + roles = eff_parser_config.get("table_column_roles") or {} + return frozenset(str(k).strip() for k in roles if k is not None and str(k).strip()) + + +def _field_map_typed_key_for_column(field_map: dict, col: str) -> str | None: + """Map CSV column name to ES typed field key (field_map: typed_key -> display name).""" + if not field_map or not col: + return None + col_s = str(col).strip() + col_norm = col_s.replace("_", " ").strip().lower() + for tk, disp in field_map.items(): + disp_s = str(disp).strip() + if disp_s.lower() == col_norm or disp_s.lower() == col_s.lower(): + return tk + return None + + +def _probe_es_typed_key_for_column(col: str, sample_chunk: dict) -> str | None: + """ + When field_map is missing/stale, try to infer the ES field key present on a chunk. + Table chunks use normalized/pinyin keys of the form , where suffix is + one of: _raw, _tks, _dt, _long, _flt, _kwd (see rag/app/table.py). + """ + if not col or not isinstance(sample_chunk, dict): + return None + base_raw = str(col).strip() + if not base_raw: + return None + base_norm = base_raw.replace("_", " ").strip().lower().replace(" ", "") + suffixes = ("_tks", "_raw", "_dt", "_long", "_flt", "_kwd") + for key in sample_chunk.keys(): + key_s = str(key) + if not key_s: + continue + key_norm = key_s.strip().lower() + if key_norm == base_raw.lower() or key_norm.replace("_", "").replace(" ", "") == base_norm: + return key_s + for key in sample_chunk.keys(): + key_s = str(key) + if not key_s: + continue + key_lower = key_s.lower() + for sfx in suffixes: + if key_lower.endswith(sfx): + core = key_lower[: -len(sfx)] + core_norm = core.replace("_", "").replace(" ", "") + if core_norm == base_norm: + return key_s + return None + + +def _resolve_es_chunk_field_key( + col: str, field_map: dict, sample_chunk: dict | None +) -> tuple[str | None, str]: + """Prefer field_map when key exists on chunk; else probe by suffix (matches table.py naming).""" + tk_fm = _field_map_typed_key_for_column(field_map, col) if field_map else None + if sample_chunk: + if tk_fm and tk_fm in sample_chunk: + return tk_fm, "field_map" + probed = _probe_es_typed_key_for_column(col, sample_chunk) + if probed: + return probed, "probe" if not tk_fm else "probe_field_map_mismatch" + if tk_fm: + return tk_fm, "field_map_absent_on_chunk" + if tk_fm: + return tk_fm, "field_map" + return None, "none" + + +def _value_to_meta_string(val) -> str | None: + """Normalize chunk field values for DocMetadataService (strings / list of strings only).""" + if val is None: + return None + if isinstance(val, bool): + return str(val).lower() + if isinstance(val, (int, float)): + return str(val) + if isinstance(val, str): + s = val.strip() + return s if s else None + return str(val) + + +def _es_raw_field_key_from_typed(tk: str | None) -> str | None: + """ES text columns use *_tks (tokenized); raw display value is stored as {same_base}_raw (see rag/app/table.py).""" + if not tk or not tk.endswith("_tks"): + return None + return tk[: -len("_tks")] + "_raw" + + +def _es_field_value_to_doc_metadata(val, *, from_tks_fallback: bool) -> str | None: + """Prefer raw strings; for legacy *_tks tokenized fields, normalize list/str to a single display string.""" + if val is None: + return None + if from_tks_fallback and isinstance(val, list): + parts = [str(x).strip() for x in val if x is not None and str(x).strip()] + if not parts: + return None + return " ".join(parts) + return _value_to_meta_string(val) + + +def aggregate_table_manual_doc_metadata(chunks: list, task: dict) -> dict: + """ + Collect unique values per metadata/both column across chunks for document-level metadata. + Used when table_column_mode == manual (parallel to LLM gen_metadata, no schema required). + """ + logging.debug( + f"[TABLE_META_DEBUG] aggregate_table_manual_doc_metadata called with {len(chunks)} chunks" + ) + eff = merge_table_parser_config_from_kb(task) + if eff.get("table_column_mode") != "manual": + logging.debug( + f"[TABLE_META_DEBUG] skip aggregate: table_column_mode={eff.get('table_column_mode')!r}" + ) + return {} + roles = eff.get("table_column_roles") or {} + table_column_names = eff.get("table_column_names") or [] + if table_column_names: + meta_cols = [ + col + for col in table_column_names + if roles.get(col, "both") in ("metadata", "both") + ] + else: + meta_cols = [c for c, r in roles.items() if r in ("metadata", "both")] + if not meta_cols: + logging.debug( + "[TABLE_META_DEBUG] skip aggregate: no metadata/both columns " + f"(table_column_names_present={bool(table_column_names)})" + ) + return {} + fm = (task.get("kb_parser_config") or {}).get("field_map") or {} + kb_id = task.get("kb_id") + if not fm and kb_id: + try: + KBS = _knowledgebase_service_cls() + ok, kb = KBS.get_by_id(kb_id) + if ok and kb: + fresh_pc = kb.parser_config or {} + reloaded = fresh_pc.get("field_map") or {} + if reloaded: + fm = reloaded + logging.debug( + f"[TABLE_META_DEBUG] reloaded field_map from DB: {len(fm)} entries" + ) + else: + logging.debug( + "[TABLE_META_DEBUG] KB reload: parser_config has no field_map yet; " + "will use ES key probe on chunk dicts if applicable" + ) + except Exception as e: + logging.debug( + "[TABLE_META_DEBUG] failed to reload field_map from DB: %s", + e, + exc_info=True, + ) + if not fm and not (settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE): + logging.debug( + "[TABLE_META_DEBUG] field_map empty on task snapshot — will use ES key probe on chunk dicts; " + f"kb_parser_config keys={list((task.get('kb_parser_config') or {}).keys())}" + ) + logging.debug( + f"[TABLE_META_DEBUG] meta_cols={meta_cols}, field_map entries={len(fm)}, " + f"infinity={settings.DOC_ENGINE_INFINITY}, oceanbase={settings.DOC_ENGINE_OCEANBASE}" + ) + sample_ck = next((c for c in chunks if isinstance(c, dict)), None) + if sample_ck: + sk = [ + k + for k in sample_ck.keys() + if not (str(k).startswith("q_") and str(k).endswith("_vec")) + ][:50] + logging.debug(f"[TABLE_META_DEBUG] first chunk non-vector keys (sample): {sk}") + + es_col_keys: dict[str, tuple[str | None, str]] = {} + if not (settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE): + for col in meta_cols: + tk, src = _resolve_es_chunk_field_key(col, fm, sample_ck) + es_col_keys[col] = (tk, src) + logging.debug( + f"[TABLE_META_DEBUG] column '{col}' -> ES key {tk!r} (source={src})" + ) + + acc: dict[str, list] = {c: [] for c in meta_cols} + + for i, ck in enumerate(chunks): + if not isinstance(ck, dict): + continue + if settings.DOC_ENGINE_INFINITY or settings.DOC_ENGINE_OCEANBASE: + cd = ck.get("chunk_data") + if not isinstance(cd, dict): + continue + for col in meta_cols: + if col not in cd: + continue + s = _value_to_meta_string(cd[col]) + if s is not None: + acc[col].append(s) + else: + for col in meta_cols: + tk, _src = es_col_keys.get(col, (None, "none")) + if not tk: + if i == 0: + logging.debug( + f"[TABLE_META_DEBUG] no resolved ES key for column '{col}'" + ) + continue + raw_k = _es_raw_field_key_from_typed(tk) + val = None + from_tks = False + if raw_k and raw_k in ck: + val = ck[raw_k] + elif tk in ck: + val = ck[tk] + from_tks = tk.endswith("_tks") + else: + if i == 0: + logging.debug( + f"[TABLE_META_DEBUG] chunk missing ES field {tk!r}" + f"{' and ' + raw_k + ' (raw)' if raw_k else ''} for column '{col}'" + ) + continue + s = _es_field_value_to_doc_metadata(val, from_tks_fallback=from_tks) + if s is not None: + acc[col].append(s) + + for col, vals in acc.items(): + logging.debug( + "[TABLE_META_DEBUG] Column '%s' values found (count=%d)", + col, + len(vals), + ) + + out = {} + for col, vals in acc.items(): + if vals: + out[col] = dedupe_list(vals) + logging.debug( + f"[TABLE_META_DEBUG] aggregated metadata dict keys={list(out.keys())}, " + f"sizes={[len(v) for v in out.values()]}" + ) + return out diff --git a/rag/utils/tts_cache.py b/rag/utils/tts_cache.py new file mode 100644 index 00000000000..a96f1925288 --- /dev/null +++ b/rag/utils/tts_cache.py @@ -0,0 +1,120 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import binascii +import hashlib +import logging +import os +from typing import Any, Optional + +from rag.utils.redis_conn import REDIS_CONN + +_DEFAULT_TTL_SECONDS = 7 * 24 * 60 * 60 +_KEY_PREFIX = "tts:cache:" + + +def _ttl_seconds() -> int: + raw = os.environ.get("RAGFLOW_TTS_CACHE_TTL_SECONDS") + if not raw: + return _DEFAULT_TTL_SECONDS + try: + v = int(raw) + return v if v > 0 else 0 + except ValueError: + logging.warning("Invalid RAGFLOW_TTS_CACHE_TTL_SECONDS=%r, using default", raw) + return _DEFAULT_TTL_SECONDS + + +def _model_id(tts_mdl: Any) -> Optional[str]: + cfg = getattr(tts_mdl, "model_config", None) + if isinstance(cfg, dict): + mid = cfg.get("id") + if mid is not None: + return str(mid) + name = cfg.get("llm_name") or cfg.get("model_name") + if name: + return str(name) + return None + + +def _build_key(tts_mdl: Any, text: str) -> Optional[str]: + mid = _model_id(tts_mdl) + if not mid: + return None + digest = hashlib.sha256(text.encode("utf-8", "ignore")).hexdigest() + return f"{_KEY_PREFIX}{mid}:{digest}" + + +def _to_hex_string(value: Any) -> Optional[str]: + if value is None: + return None + if isinstance(value, bytes): + try: + return value.decode("utf-8") + except Exception: + return None + if isinstance(value, str): + return value + return None + + +def synthesize_with_cache(tts_mdl: Any, cleaned_text: str) -> Optional[str]: + """ + Synthesize ``cleaned_text`` through ``tts_mdl`` and return a hex-encoded + audio blob, reusing a Redis-cached result when available. + + The cache key is derived from the TTS model identifier and a SHA-256 of the + text, so different models keep separate caches and the same text on the + same model resolves to the same key regardless of call site. Returns + ``None`` on synthesis failure; callers should treat that as a no-op the + same way they do today. + """ + if not tts_mdl or not cleaned_text: + return None + + key = _build_key(tts_mdl, cleaned_text) + + if key: + try: + cached = REDIS_CONN.get(key) + except Exception as e: + logging.warning("TTS cache lookup failed: %s", e) + cached = None + hex_cached = _to_hex_string(cached) + if hex_cached: + return hex_cached + + buf = b"" + try: + for chunk in tts_mdl.tts(cleaned_text): + if isinstance(chunk, (bytes, bytearray)): + buf += bytes(chunk) + except Exception as e: + logging.error("TTS failed: %s (text length=%d)", e, len(cleaned_text)) + return None + + if not buf: + return None + + hex_value = binascii.hexlify(buf).decode("utf-8") + + ttl = _ttl_seconds() + if key and ttl > 0: + try: + REDIS_CONN.set(key, hex_value, exp=ttl) + except Exception as e: + logging.warning("TTS cache store failed: %s", e) + + return hex_value diff --git a/sdk/python/pyproject.toml b/sdk/python/pyproject.toml index f28e734a9f5..970eb5a77ab 100644 --- a/sdk/python/pyproject.toml +++ b/sdk/python/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ragflow-sdk" -version = "0.25.2" +version = "0.25.5" description = "Python client sdk of [RAGFlow](https://github.com/infiniflow/ragflow). RAGFlow is an open-source RAG (Retrieval-Augmented Generation) engine based on deep document understanding." authors = [{ name = "Zhichang Yu", email = "yuzhichang@gmail.com" }] license = { text = "Apache License, Version 2.0" } diff --git a/sdk/python/ragflow_sdk/modules/session.py b/sdk/python/ragflow_sdk/modules/session.py index f9c4799fd7a..5152160f6a4 100644 --- a/sdk/python/ragflow_sdk/modules/session.py +++ b/sdk/python/ragflow_sdk/modules/session.py @@ -15,8 +15,12 @@ # import json +import logging + from .base import Base +logger = logging.getLogger(__name__) + class Session(Base): def __init__(self, rag, res_dict): @@ -33,11 +37,67 @@ def __init__(self, rag, res_dict): super().__init__(rag, res_dict) - def ask(self, question="", stream=False, **kwargs): + def ask( + self, + question="", + stream=False, + inputs=None, + release=None, + return_trace=None, + **kwargs, + ): """ - Ask a question to the session. If stream=True, yields Message objects as they arrive (SSE streaming). - If stream=False, returns a single Message object for the final answer. + Ask a question to the session. + + Parameters + ---------- + question : str + The user's question. May be empty when the agent is driven solely by + Begin component inputs. + stream : bool + If ``True``, yields ``Message`` objects as they arrive (SSE streaming). + If ``False``, yields a single ``Message`` with the final answer. + inputs : dict, optional + Values for variables declared on the agent's **Begin** component. Each + value must be a dict containing at least a ``"value"`` key, and may + include ``"type"``. Example:: + + session.ask( + "", + stream=False, + inputs={"key1": {"type": "line", "value": "hello"}}, + ) + + Only meaningful for agent sessions; ignored for chat sessions. + release : bool, optional + If ``True``, run against the latest published agent version instead of + the editable draft. Only meaningful for agent sessions. + return_trace : bool, optional + If ``True``, include execution trace information in the response. + Only meaningful for agent sessions. + **kwargs + Additional fields forwarded verbatim to the completion endpoint + (e.g. ``session_id``, ``files``, ``user_id``, ``custom_header``). + See the HTTP API reference for the full list. """ + if inputs is not None: + kwargs["inputs"] = inputs + if release is not None: + kwargs["release"] = release + if return_trace is not None: + kwargs["return_trace"] = return_trace + + if inputs is not None or release is not None or return_trace is not None: + logger.debug( + "Session.ask explicit-params session_type=%s session_id=%s " + "input_keys=%s release=%s return_trace=%s", + self.__session_type, + getattr(self, "id", None), + list(inputs.keys()) if isinstance(inputs, dict) else None, + release, + return_trace, + ) + if self.__session_type == "agent": res = self._ask_agent(question, stream, **kwargs) elif self.__session_type == "chat": diff --git a/sdk/python/ragflow_sdk/ragflow.py b/sdk/python/ragflow_sdk/ragflow.py index fe0a683719c..679f5ba5f30 100644 --- a/sdk/python/ragflow_sdk/ragflow.py +++ b/sdk/python/ragflow_sdk/ragflow.py @@ -334,6 +334,7 @@ def delete_memory(self, memory_id: str): raise Exception(res["message"]) def add_message(self, memory_id: list[str], agent_id: str, session_id: str, user_input: str, agent_response: str, user_id: str = "") -> str: + """Append messages to memories; ``user_id`` is forwarded only for API-key auth (external subject).""" payload = { "memory_id": memory_id, "agent_id": agent_id, diff --git a/sdk/python/uv.lock b/sdk/python/uv.lock index b625b4bc89a..dbfeee21e6b 100644 --- a/sdk/python/uv.lock +++ b/sdk/python/uv.lock @@ -369,7 +369,7 @@ wheels = [ [[package]] name = "ragflow-sdk" -version = "0.25.2" +version = "0.25.5" source = { virtual = "." } dependencies = [ { name = "beartype" }, diff --git a/test/README.md b/test/README.md index 15546136f50..81d33c4f489 100644 --- a/test/README.md +++ b/test/README.md @@ -33,7 +33,7 @@ uv pip install sdk/python ```env COMPOSE_PROFILES=${COMPOSE_PROFILES},tei-cpu TEI_MODEL=BAAI/bge-small-en-v1.5 -RAGFLOW_IMAGE=infiniflow/ragflow:v0.25.2 #Replace with the image you are using +RAGFLOW_IMAGE=infiniflow/ragflow:v0.25.5 #Replace with the image you are using ``` diff --git a/test/testcases/configs.py b/test/testcases/configs.py index 546cd378c9d..a4711bf1583 100644 --- a/test/testcases/configs.py +++ b/test/testcases/configs.py @@ -65,6 +65,7 @@ "category", ], "method": "light", + "batch_chunk_token_size": 4096, }, "parent_child": { "use_parent_child": False, diff --git a/test/testcases/restful_api/conftest.py b/test/testcases/restful_api/conftest.py new file mode 100644 index 00000000000..b24f0bda45e --- /dev/null +++ b/test/testcases/restful_api/conftest.py @@ -0,0 +1,163 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + +from libs.auth import RAGFlowHttpApiAuth +from test.testcases.restful_api.helpers.client import RestClient +from utils.file_utils import create_txt_file +from utils import wait_for + + +@pytest.fixture(scope="session") +def RestApiAuth(token): + return RAGFlowHttpApiAuth(token) + + +@pytest.fixture(scope="session") +def rest_client(token): + return RestClient(token=token) + + +@pytest.fixture(scope="session") +def rest_client_noauth(): + return RestClient(token=None) + + +@pytest.fixture +def clear_datasets(rest_client): + def _cleanup(): + res = rest_client.delete("/datasets", json={"ids": None, "delete_all": True}) + assert res.status_code == 200, res.text + payload = res.json() + assert payload["code"] in (0, 102), payload + + yield + _cleanup() + + +@pytest.fixture +def clear_chats(rest_client): + def _cleanup(): + res = rest_client.delete("/chats", json={"ids": None, "delete_all": True}) + assert res.status_code == 200, res.text + payload = res.json() + assert payload["code"] in (0, 102), payload + + yield + _cleanup() + + +@pytest.fixture +def create_dataset(rest_client, clear_datasets): + created_ids: list[str] = [] + + def _create(name: str = "restful_dataset") -> str: + res = rest_client.post("/datasets", json={"name": name}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + dataset_id = payload["data"]["id"] + created_ids.append(dataset_id) + return dataset_id + + yield _create + + if created_ids: + res = rest_client.delete("/datasets", json={"ids": created_ids}) + assert res.status_code == 200 + payload = res.json() + # Dataset may already be removed by test logic/cleanup. + assert payload["code"] in (0, 102), payload + + +@pytest.fixture +def create_chat(rest_client, clear_chats): + created_ids: list[str] = [] + + def _create(name: str = "restful_chat") -> str: + res = rest_client.post("/chats", json={"name": name, "dataset_ids": []}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + chat_id = payload["data"]["id"] + created_ids.append(chat_id) + return chat_id + + yield _create + + if created_ids: + res = rest_client.delete("/chats", json={"ids": created_ids}) + assert res.status_code == 200, res.text + payload = res.json() + assert payload["code"] in (0, 102), payload + + +@pytest.fixture +def create_document(rest_client, create_dataset, tmp_path): + created_docs: list[tuple[str, str]] = [] + + def _create(name: str = "restful_doc.txt") -> tuple[str, str]: + dataset_id = create_dataset("dataset_for_doc") + fp = create_txt_file(tmp_path / name) + with fp.open("rb") as file_obj: + files = [("file", (fp.name, file_obj))] + res = rest_client.post(f"/datasets/{dataset_id}/documents", files=files) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + document_id = payload["data"][0]["id"] + created_docs.append((dataset_id, document_id)) + return dataset_id, document_id + + yield _create + + for dataset_id, document_id in created_docs: + res = rest_client.delete(f"/datasets/{dataset_id}/documents", json={"ids": [document_id]}) + assert res.status_code == 200, res.text + payload = res.json() + assert payload["code"] in (0, 102), payload + + +@wait_for(60, 1, "Document parsing timeout in RESTful batch2 tests") +def _parsed(rest_client: RestClient, dataset_id: str, document_id: str): + res = rest_client.get(f"/datasets/{dataset_id}/documents", params={"id": document_id}) + if res.status_code != 200: + return False + payload = res.json() + if payload["code"] != 0: + return False + docs = payload["data"]["docs"] + if not docs: + return False + return docs[0].get("run") == "DONE" + + +@pytest.fixture +def ensure_parsed_document(rest_client, create_document): + def _ensure() -> tuple[str, str]: + dataset_id, document_id = create_document() + res = rest_client.post( + f"/datasets/{dataset_id}/documents/parse", + json={"document_ids": [document_id]}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + _parsed(rest_client, dataset_id, document_id) + return dataset_id, document_id + + return _ensure diff --git a/test/testcases/restful_api/helpers/__init__.py b/test/testcases/restful_api/helpers/__init__.py new file mode 100644 index 00000000000..117dea3cf0d --- /dev/null +++ b/test/testcases/restful_api/helpers/__init__.py @@ -0,0 +1 @@ +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. diff --git a/test/testcases/restful_api/helpers/client.py b/test/testcases/restful_api/helpers/client.py new file mode 100644 index 00000000000..8c0a198fc24 --- /dev/null +++ b/test/testcases/restful_api/helpers/client.py @@ -0,0 +1,85 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from dataclasses import dataclass +from typing import Any + +import requests +from configs import HOST_ADDRESS, VERSION + + +@dataclass +class RestClient: + token: str | None = None + timeout: int = 30 + + @property + def api_root(self) -> str: + return f"{HOST_ADDRESS}/api/{VERSION}" + + def _headers(self, headers: dict[str, str] | None = None) -> dict[str, str]: + merged: dict[str, str] = {"Content-Type": "application/json"} + if headers: + merged.update(headers) + if self.token and "Authorization" not in merged: + merged["Authorization"] = f"Bearer {self.token}" + return merged + + def request( + self, + method: str, + path: str, + *, + headers: dict[str, str] | None = None, + params: dict[str, Any] | None = None, + json: dict[str, Any] | None = None, + data: Any = None, + files: Any = None, + **request_kwargs: Any, + ) -> requests.Response: + req_headers = self._headers(headers) + if files is not None: + # requests sets multipart boundary automatically. + req_headers.pop("Content-Type", None) + + timeout = request_kwargs.pop("timeout", self.timeout) + normalized_path = f"/{path.lstrip('/')}" if path else "/" + return requests.request( + method=method, + url=f"{self.api_root}{normalized_path}", + headers=req_headers, + params=params, + json=json, + data=data, + files=files, + timeout=timeout, + **request_kwargs, + ) + + def get(self, path: str, **kwargs) -> requests.Response: + return self.request("GET", path, **kwargs) + + def post(self, path: str, **kwargs) -> requests.Response: + return self.request("POST", path, **kwargs) + + def delete(self, path: str, **kwargs) -> requests.Response: + return self.request("DELETE", path, **kwargs) + + def put(self, path: str, **kwargs) -> requests.Response: + return self.request("PUT", path, **kwargs) + + def patch(self, path: str, **kwargs) -> requests.Response: + return self.request("PATCH", path, **kwargs) diff --git a/test/testcases/restful_api/test_agents.py b/test/testcases/restful_api/test_agents.py new file mode 100644 index 00000000000..d7f01e9a44b --- /dev/null +++ b/test/testcases/restful_api/test_agents.py @@ -0,0 +1,356 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json + +import pytest + + +MINIMAL_DSL = { + "components": { + "begin": { + "obj": {"component_name": "Begin", "params": {}}, + "downstream": ["message"], + "upstream": [], + }, + "message": { + "obj": {"component_name": "Message", "params": {"content": ["{sys.query}"]}}, + "downstream": [], + "upstream": ["begin"], + }, + }, + "history": [], + "retrieval": [], + "path": [], + "globals": { + "sys.query": "", + "sys.user_id": "", + "sys.conversation_turns": 0, + "sys.files": [], + }, + "variables": {}, +} + + +def _sse_events(response_text: str) -> list[str]: + return [line[5:] for line in response_text.splitlines() if line.startswith("data:")] + + +@pytest.fixture +def create_agent_resource(rest_client): + created_agent_ids: list[str] = [] + + def _create(title: str = "restful_agent") -> str: + res = rest_client.post("/agents", json={"title": title, "dsl": MINIMAL_DSL}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + agent_id = payload["data"]["id"] + created_agent_ids.append(agent_id) + return agent_id + + yield _create + + cleanup_errors = [] + for agent_id in created_agent_ids: + res = rest_client.delete(f"/agents/{agent_id}") + if res.status_code != 200: + cleanup_errors.append((agent_id, res.status_code, res.text)) + continue + payload = res.json() + if payload["code"] not in (0, 103): + cleanup_errors.append((agent_id, res.status_code, payload)) + assert not cleanup_errors, f"Agent cleanup failed: {cleanup_errors}" + + +@pytest.mark.p2 +def test_agents_crud_validation_contract(rest_client, create_agent_resource): + list_empty = rest_client.get("/agents", params={"title": "missing_restful_agent"}) + assert list_empty.status_code == 200 + list_empty_payload = list_empty.json() + assert list_empty_payload["code"] == 0, list_empty_payload + assert "canvas" in list_empty_payload["data"], list_empty_payload + assert "total" in list_empty_payload["data"], list_empty_payload + + paged_list = rest_client.get( + "/agents", + params={"title": "missing_restful_agent", "desc": "true", "page_size": 1}, + ) + assert paged_list.status_code == 200 + paged_list_payload = paged_list.json() + assert paged_list_payload["code"] == 0, paged_list_payload + + missing_dsl = rest_client.post("/agents", json={"title": "missing_dsl_agent"}) + assert missing_dsl.status_code == 200 + missing_dsl_payload = missing_dsl.json() + assert missing_dsl_payload["code"] == 101, missing_dsl_payload + assert "No DSL data in request" in missing_dsl_payload["message"], missing_dsl_payload + + missing_title = rest_client.post("/agents", json={"dsl": MINIMAL_DSL}) + assert missing_title.status_code == 200 + missing_title_payload = missing_title.json() + assert missing_title_payload["code"] == 101, missing_title_payload + assert "No title in request" in missing_title_payload["message"], missing_title_payload + + agent_id = create_agent_resource("restful_agent_crud") + + duplicate = rest_client.post("/agents", json={"title": "restful_agent_crud", "dsl": MINIMAL_DSL}) + assert duplicate.status_code == 200 + duplicate_payload = duplicate.json() + assert duplicate_payload["code"] == 102, duplicate_payload + assert "already exists" in duplicate_payload["message"], duplicate_payload + + invalid_update = rest_client.put("/agents/invalid-agent-id", json={"title": "updated", "dsl": MINIMAL_DSL}) + assert invalid_update.status_code == 200 + invalid_update_payload = invalid_update.json() + assert invalid_update_payload["code"] == 103, invalid_update_payload + assert "Make sure you have permission to access the agent." in invalid_update_payload["message"], invalid_update_payload + + get_res = rest_client.get(f"/agents/{agent_id}") + assert get_res.status_code == 200 + get_payload = get_res.json() + assert get_payload["code"] == 0, get_payload + assert get_payload["data"]["id"] == agent_id, get_payload + + update_res = rest_client.put(f"/agents/{agent_id}", json={"title": "restful_agent_crud_updated", "dsl": MINIMAL_DSL}) + assert update_res.status_code == 200 + update_payload = update_res.json() + assert update_payload["code"] == 0, update_payload + + list_after_update = rest_client.get("/agents", params={"title": "restful_agent_crud_updated"}) + assert list_after_update.status_code == 200 + list_after_update_payload = list_after_update.json() + assert list_after_update_payload["code"] == 0, list_after_update_payload + assert list_after_update_payload["data"]["total"] >= 1, list_after_update_payload + + invalid_delete = rest_client.delete("/agents/invalid-agent-id") + assert invalid_delete.status_code == 200 + invalid_delete_payload = invalid_delete.json() + assert invalid_delete_payload["code"] == 103, invalid_delete_payload + assert "Only the owner of the agent is authorized for this operation." in invalid_delete_payload["message"], invalid_delete_payload + + delete_res = rest_client.delete(f"/agents/{agent_id}") + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + assert delete_payload["data"] is True, delete_payload + + +@pytest.mark.p2 +def test_agent_sessions_crud(rest_client, create_agent_resource): + agent_id = create_agent_resource("restful_agent_sessions") + + create_session = rest_client.post(f"/agents/{agent_id}/sessions", json={"name": "agent_session_1"}) + assert create_session.status_code == 200 + create_session_payload = create_session.json() + assert create_session_payload["code"] == 0, create_session_payload + session_id = create_session_payload["data"]["id"] + + list_sessions = rest_client.get(f"/agents/{agent_id}/sessions") + assert list_sessions.status_code == 200 + list_sessions_payload = list_sessions.json() + assert list_sessions_payload["code"] == 0, list_sessions_payload + assert isinstance(list_sessions_payload["data"], list), list_sessions_payload + assert any(item["id"] == session_id for item in list_sessions_payload["data"]), list_sessions_payload + + get_session = rest_client.get(f"/agents/{agent_id}/sessions/{session_id}") + assert get_session.status_code == 200 + get_session_payload = get_session.json() + assert get_session_payload["code"] == 0, get_session_payload + assert get_session_payload["data"]["id"] == session_id, get_session_payload + + delete_session = rest_client.delete(f"/agents/{agent_id}/sessions/{session_id}") + assert delete_session.status_code == 200 + delete_session_payload = delete_session.json() + assert delete_session_payload["code"] == 0, delete_session_payload + + +@pytest.mark.p2 +def test_agent_chat_completion_validation(rest_client): + missing_agent_id = rest_client.post( + "/agents/chat/completions", + json={"query": "hello", "stream": False}, + ) + assert missing_agent_id.status_code == 200 + missing_agent_id_payload = missing_agent_id.json() + assert missing_agent_id_payload["code"] == 101, missing_agent_id_payload + assert "`agent_id` is required." in missing_agent_id_payload["message"], missing_agent_id_payload + + +@pytest.mark.p2 +def test_agent_chat_completion_nonstream(rest_client, create_agent_resource): + agent_id = create_agent_resource("restful_agent_nonstream") + create_session = rest_client.post(f"/agents/{agent_id}/sessions", json={"name": "agent_completion_session"}) + assert create_session.status_code == 200 + create_session_payload = create_session.json() + assert create_session_payload["code"] == 0, create_session_payload + session_id = create_session_payload["data"]["id"] + + res = rest_client.post( + "/agents/chat/completions", + json={"agent_id": agent_id, "query": "hello", "stream": False, "session_id": session_id}, + timeout=60, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert isinstance(payload["data"], dict), payload + assert payload["data"].get("session_id") == session_id, payload + assert isinstance(payload["data"].get("data"), dict), payload + content = payload["data"]["data"].get("content", "") + assert content, payload + assert "hello" in content, payload + + +@pytest.mark.p2 +def test_agent_chat_completion_stream_structure_and_done(rest_client, create_agent_resource): + agent_id = create_agent_resource("restful_agent_stream") + create_session = rest_client.post(f"/agents/{agent_id}/sessions", json={"name": "agent_stream_session"}) + assert create_session.status_code == 200 + create_session_payload = create_session.json() + assert create_session_payload["code"] == 0, create_session_payload + session_id = create_session_payload["data"]["id"] + + res = rest_client.post( + "/agents/chat/completions", + json={ + "agent_id": agent_id, + "query": "hello", + "stream": True, + "session_id": session_id, + "return_trace": True, + }, + timeout=60, + ) + assert res.status_code == 200 + content_type = res.headers.get("Content-Type", "") + assert "text/event-stream" in content_type, content_type + + events = _sse_events(res.text) + assert events, res.text + assert events[-1].strip() == "[DONE]", events[-1] + + json_events = [json.loads(evt) for evt in events if evt.strip() != "[DONE]"] + assert json_events, events + assert any(isinstance(evt, dict) for evt in json_events), json_events + + +@pytest.mark.p2 +def test_agent_openai_compatible_mode(rest_client, create_agent_resource): + agent_id = create_agent_resource("restful_agent_openai_compat") + + missing_messages = rest_client.post( + "/agents/chat/completions", + json={"agent_id": agent_id, "openai-compatible": True, "model": "model", "messages": []}, + ) + assert missing_messages.status_code == 200 + missing_messages_payload = missing_messages.json() + assert missing_messages_payload["code"] == 102, missing_messages_payload + assert "at least one message" in missing_messages_payload["message"], missing_messages_payload + + nonstream = rest_client.post( + "/agents/chat/completions", + json={ + "agent_id": agent_id, + "openai-compatible": True, + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + }, + timeout=60, + ) + assert nonstream.status_code == 200 + nonstream_payload = nonstream.json() + assert isinstance(nonstream_payload, dict), nonstream_payload + assert "choices" in nonstream_payload, nonstream_payload + + stream = rest_client.post( + "/agents/chat/completions", + json={ + "agent_id": agent_id, + "openai-compatible": True, + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": True, + }, + timeout=60, + ) + assert stream.status_code == 200 + stream_content_type = stream.headers.get("Content-Type", "") + assert "text/event-stream" in stream_content_type, stream_content_type + + +@pytest.mark.p2 +def test_agent_support_routes_auth_and_contracts(rest_client, rest_client_noauth, create_agent_resource): + prompts_unauth = rest_client_noauth.get("/agents/prompts") + assert prompts_unauth.status_code == 401 + assert prompts_unauth.json()["code"] == 401 + + prompts = rest_client.get("/agents/prompts") + assert prompts.status_code == 200 + prompts_payload = prompts.json() + assert prompts_payload["code"] == 0, prompts_payload + assert "task_analysis" in prompts_payload["data"], prompts_payload + assert "citation_guidelines" in prompts_payload["data"], prompts_payload + + templates = rest_client.get("/agents/templates") + assert templates.status_code == 200 + templates_payload = templates.json() + assert templates_payload["code"] == 0, templates_payload + assert isinstance(templates_payload["data"], list), templates_payload + + agent_id = create_agent_resource("restful_agent_support") + versions = rest_client.get(f"/agents/{agent_id}/versions") + assert versions.status_code == 200 + versions_payload = versions.json() + assert versions_payload["code"] == 0, versions_payload + assert isinstance(versions_payload["data"], list), versions_payload + + logs = rest_client.get(f"/agents/{agent_id}/logs/missing_message") + assert logs.status_code == 200 + logs_payload = logs.json() + assert logs_payload["code"] == 0, logs_payload + assert isinstance(logs_payload["data"], dict), logs_payload + + +@pytest.mark.p2 +def test_agent_webhook_logs_empty_poll_contract(rest_client, create_agent_resource): + agent_id = create_agent_resource("restful_agent_webhook_logs") + res = rest_client.get(f"/agents/{agent_id}/webhook/logs", params={"since_ts": 0}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert payload["data"]["events"] == [], payload + assert payload["data"]["finished"] is False, payload + assert "next_since_ts" in payload["data"], payload + + +@pytest.mark.p2 +def test_agent_db_connection_validates_required_fields(rest_client): + res = rest_client.post("/agents/test_db_connection", json={"db_type": "mysql"}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert "required argument are missing" in payload["message"], payload + + +@pytest.mark.p2 +def test_agent_rerun_requires_required_fields(rest_client): + res = rest_client.post("/agents/rerun", json={"id": "flow-1"}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert "required argument are missing" in payload["message"], payload diff --git a/test/testcases/restful_api/test_chats.py b/test/testcases/restful_api/test_chats.py new file mode 100644 index 00000000000..eaf94a13c9d --- /dev/null +++ b/test/testcases/restful_api/test_chats.py @@ -0,0 +1,2045 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import importlib.util +import sys +from copy import deepcopy +from concurrent.futures import ThreadPoolExecutor +from enum import Enum +from functools import wraps +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + +from test.testcases.configs import CHAT_ASSISTANT_NAME_LIMIT, INVALID_API_TOKEN +from test.testcases.restful_api.helpers.client import RestClient +from test.testcases.utils import encode_avatar +from test.testcases.utils.file_utils import create_image_file + + +DEFAULT_CHAT_EMPTY_RESPONSE = "Sorry! No relevant content was found in the knowledge base!" +DEFAULT_CHAT_PROLOGUE = "Hi! I'm your assistant. What can I do for you?" +DEFAULT_CHAT_SYSTEM_PROMPT = ( + 'You are an intelligent assistant. Please summarize the content of the dataset to answer the question. ' + 'Please list the data in the dataset and answer in detail. When all dataset content is irrelevant to the ' + 'question, your answer must include the sentence "The answer you are looking for is not found in the dataset!" ' + "Answers need to consider chat history.\n" + " Here is the knowledge base:\n" + " {knowledge}\n" + " The above is the knowledge base." +) + + +def _get_nested(data, path): + current = data + for key in path: + current = current[key] + return current + + +def _chat_names(payload): + return [chat["name"] for chat in payload["data"]["chats"]] + + +def _reset_chat_batch(rest_client, prefix, count=5): + cleanup_res = rest_client.delete("/chats", json={"ids": None, "delete_all": True}) + assert cleanup_res.status_code == 200, cleanup_res.text + cleanup_payload = cleanup_res.json() + assert cleanup_payload["code"] in (0, 102), cleanup_payload + + ids = [] + for index in range(count): + res = rest_client.post("/chats", json={"name": f"{prefix}_{index}", "dataset_ids": []}) + assert res.status_code == 200, (prefix, index, res.text) + payload = res.json() + assert payload["code"] == 0, (prefix, index, payload) + ids.append(payload["data"]["id"]) + return ids + + + +@pytest.mark.p1 +class TestChatsAuthorization: + def test_create_requires_auth(self, rest_client_noauth): + res = rest_client_noauth.post("/chats", json={"name": "chat_auth", "dataset_ids": []}) + assert res.status_code == 401 + + +@pytest.mark.p1 +def test_chat_crud_cycle(rest_client, clear_chats): + create_res = rest_client.post("/chats", json={"name": "restful_chat_crud", "dataset_ids": []}) + assert create_res.status_code == 200 + create_payload = create_res.json() + assert create_payload["code"] == 0, create_payload + chat_id = create_payload["data"]["id"] + + list_res = rest_client.get("/chats", params={"id": chat_id}) + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert len(list_payload["data"]["chats"]) == 1, list_payload + assert list_payload["data"]["chats"][0]["id"] == chat_id, list_payload + + get_res = rest_client.get(f"/chats/{chat_id}") + assert get_res.status_code == 200 + get_payload = get_res.json() + assert get_payload["code"] == 0, get_payload + assert get_payload["data"]["id"] == chat_id, get_payload + + update_res = rest_client.put(f"/chats/{chat_id}", json={"name": "restful_chat_crud_updated", "dataset_ids": []}) + assert update_res.status_code == 200 + update_payload = update_res.json() + assert update_payload["code"] == 0, update_payload + assert update_payload["data"]["name"] == "restful_chat_crud_updated", update_payload + + patch_res = rest_client.patch(f"/chats/{chat_id}", json={"name": "restful_chat_crud_patched"}) + assert patch_res.status_code == 200 + patch_payload = patch_res.json() + assert patch_payload["code"] == 0, patch_payload + assert patch_payload["data"]["name"] == "restful_chat_crud_patched", patch_payload + + delete_res = rest_client.delete("/chats", json={"ids": [chat_id]}) + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + assert delete_payload["data"]["success_count"] == 1, delete_payload + + list_after_delete = rest_client.get("/chats", params={"id": chat_id}) + assert list_after_delete.status_code == 200 + list_after_delete_payload = list_after_delete.json() + assert list_after_delete_payload["code"] == 0, list_after_delete_payload + assert list_after_delete_payload["data"]["chats"] == [], list_after_delete_payload + + +@pytest.mark.p2 +@pytest.mark.parametrize( + "name, expected_fragment", + [ + ("", "`name` is required."), + (" ", "`name` is required."), + ], +) +def test_chat_create_name_validation(rest_client, clear_chats, name, expected_fragment): + res = rest_client.post("/chats", json={"name": name, "dataset_ids": []}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert expected_fragment in payload["message"], payload + + +@pytest.mark.p2 +def test_chat_duplicate_name_validation(rest_client, clear_chats): + first = rest_client.post("/chats", json={"name": "duplicate_chat_name", "dataset_ids": []}) + assert first.status_code == 200 + first_payload = first.json() + assert first_payload["code"] == 0, first_payload + + second = rest_client.post("/chats", json={"name": "duplicate_chat_name", "dataset_ids": []}) + assert second.status_code == 200 + second_payload = second.json() + assert second_payload["code"] == 102, second_payload + assert "Duplicated chat name" in second_payload["message"], second_payload + + +@pytest.mark.p2 +def test_chat_list_pagination(rest_client, clear_chats): + for i in range(3): + res = rest_client.post("/chats", json={"name": f"chat_page_{i}", "dataset_ids": []}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + + page_res = rest_client.get("/chats", params={"page": 1, "page_size": 2, "orderby": "create_time", "desc": "true"}) + assert page_res.status_code == 200 + page_payload = page_res.json() + assert page_payload["code"] == 0, page_payload + assert len(page_payload["data"]["chats"]) == 2, page_payload + assert page_payload["data"]["total"] >= 3, page_payload + + +@pytest.mark.p1 +def test_chat_delete_requires_auth(): + for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))): + res = client.delete("/chats", json={"ids": []}) + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 401, (scenario_name, payload) + assert payload["message"] == "", (scenario_name, payload) + + +@pytest.mark.p2 +def test_chat_delete_basic_scenarios(rest_client, clear_chats): + existing_ids = _reset_chat_batch(rest_client, "delete_basic") + existing_res = rest_client.delete("/chats", json={"ids": existing_ids}) + assert existing_res.status_code == 200 + existing_payload = existing_res.json() + assert existing_payload["code"] == 0, existing_payload + assert existing_payload["data"]["success_count"] == len(existing_ids), existing_payload + + list_after_existing = rest_client.get("/chats").json() + assert list_after_existing["code"] == 0, list_after_existing + assert list_after_existing["data"]["chats"] == [], list_after_existing + + empty_res = rest_client.delete("/chats", json={"ids": []}) + assert empty_res.status_code == 200 + empty_payload = empty_res.json() + assert empty_payload["code"] == 0, empty_payload + assert empty_payload["message"] == "success", empty_payload + + delete_all_ids = _reset_chat_batch(rest_client, "delete_all") + delete_all_res = rest_client.delete("/chats", json={"ids": None, "delete_all": True}) + assert delete_all_res.status_code == 200 + delete_all_payload = delete_all_res.json() + assert delete_all_payload["code"] == 0, delete_all_payload + assert delete_all_payload["data"]["success_count"] == len(delete_all_ids), delete_all_payload + + list_after_delete_all = rest_client.get("/chats").json() + assert list_after_delete_all["code"] == 0, list_after_delete_all + assert list_after_delete_all["data"]["chats"] == [], list_after_delete_all + + +@pytest.mark.p2 +def test_chat_delete_error_and_repeat_contract(rest_client, clear_chats): + partial_cases = [ + ("partial invalid id", lambda ids: {"ids": ids + ["invalid_id"]}), + ("partial invalid punctuation id", lambda ids: {"ids": ids + ["!@#$%^&*()"]}), + ] + for scenario_name, payload in partial_cases: + ids = _reset_chat_batch(rest_client, f"delete_partial_{scenario_name.replace(' ', '_')}") + res = rest_client.delete("/chats", json=payload(ids)) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == 0, (scenario_name, body) + assert len(body["data"]["errors"]) == 1, (scenario_name, body) + assert body["data"]["success_count"] == 5, (scenario_name, body) + + list_payload = rest_client.get("/chats").json() + assert list_payload["code"] == 0, (scenario_name, list_payload) + assert list_payload["data"]["chats"] == [], (scenario_name, list_payload) + + duplicate_ids = _reset_chat_batch(rest_client, "delete_duplicate_all") + duplicate_all_res = rest_client.delete("/chats", json={"ids": duplicate_ids + duplicate_ids}) + assert duplicate_all_res.status_code == 200 + duplicate_all_payload = duplicate_all_res.json() + assert duplicate_all_payload["code"] == 0, duplicate_all_payload + assert duplicate_all_payload["data"]["success_count"] == 5, duplicate_all_payload + assert len(duplicate_all_payload["data"]["errors"]) == 5, duplicate_all_payload + assert all(error.startswith("Duplicate chat ids: ") for error in duplicate_all_payload["data"]["errors"]), duplicate_all_payload + + duplicate_one_ids = _reset_chat_batch(rest_client, "delete_duplicate_one") + duplicate_one_res = rest_client.delete("/chats", json={"ids": [duplicate_one_ids[0], duplicate_one_ids[0]]}) + assert duplicate_one_res.status_code == 200 + duplicate_one_payload = duplicate_one_res.json() + assert duplicate_one_payload["code"] == 0, duplicate_one_payload + assert duplicate_one_payload["data"]["success_count"] == 1, duplicate_one_payload + assert duplicate_one_payload["data"]["errors"] == [f"Duplicate chat ids: {duplicate_one_ids[0]}"], duplicate_one_payload + + all_missing_res = rest_client.delete("/chats", json={"ids": ["missing-1", "missing-2"]}) + assert all_missing_res.status_code == 200 + all_missing_payload = all_missing_res.json() + assert all_missing_payload["code"] == 102, all_missing_payload + assert "Chat(missing-1) not found." in all_missing_payload["message"], all_missing_payload + assert "Chat(missing-2) not found." in all_missing_payload["message"], all_missing_payload + + repeated_ids = _reset_chat_batch(rest_client, "delete_repeated") + first_res = rest_client.delete("/chats", json={"ids": repeated_ids}) + assert first_res.status_code == 200 + first_payload = first_res.json() + assert first_payload["code"] == 0, first_payload + assert first_payload["data"]["success_count"] == 5, first_payload + + second_res = rest_client.delete("/chats", json={"ids": repeated_ids}) + assert second_res.status_code == 200 + second_payload = second_res.json() + assert second_payload["code"] == 102, second_payload + for chat_id in repeated_ids: + assert f"Chat({chat_id}) not found." in second_payload["message"], second_payload + + +@pytest.mark.p2 +def test_chat_delete_concurrent_and_bulk_contract(rest_client, clear_chats): + concurrent_ids = _reset_chat_batch(rest_client, "delete_concurrent", count=20) + with ThreadPoolExecutor(max_workers=5) as executor: + results = list(executor.map(lambda chat_id: rest_client.delete("/chats", json={"ids": [chat_id]}).json(), concurrent_ids)) + assert len(results) == 20, results + assert all(result["code"] == 0 for result in results), results + assert all(result["data"]["success_count"] == 1 for result in results), results + + list_after_concurrent = rest_client.get("/chats").json() + assert list_after_concurrent["code"] == 0, list_after_concurrent + assert list_after_concurrent["data"]["chats"] == [], list_after_concurrent + + bulk_ids = _reset_chat_batch(rest_client, "delete_bulk", count=100) + bulk_res = rest_client.delete("/chats", json={"ids": bulk_ids}) + assert bulk_res.status_code == 200 + bulk_payload = bulk_res.json() + assert bulk_payload["code"] == 0, bulk_payload + assert bulk_payload["data"]["success_count"] == len(bulk_ids), bulk_payload + + +@pytest.mark.p1 +def test_chat_list_requires_auth(): + for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))): + res = client.get("/chats") + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 401, (scenario_name, payload) + assert payload["message"] == "", (scenario_name, payload) + + +@pytest.mark.p1 +def test_chat_list_default_get_and_separate_lookup_contract(rest_client, clear_chats): + ids = _reset_chat_batch(rest_client, "list_default") + + default_res = rest_client.get("/chats") + assert default_res.status_code == 200 + default_payload = default_res.json() + assert default_payload["code"] == 0, default_payload + assert len(default_payload["data"]["chats"]) == 5, default_payload + assert default_payload["data"]["total"] == 5, default_payload + + valid_get_res = rest_client.get(f"/chats/{ids[0]}") + assert valid_get_res.status_code == 200 + valid_get_payload = valid_get_res.json() + assert valid_get_payload["code"] == 0, valid_get_payload + assert valid_get_payload["data"]["id"] == ids[0], valid_get_payload + + invalid_get_res = rest_client.get("/chats/unknown") + assert invalid_get_res.status_code == 200 + invalid_get_payload = invalid_get_res.json() + assert invalid_get_payload["code"] == 109, invalid_get_payload + assert invalid_get_payload["message"] == "No authorization.", invalid_get_payload + + for chat_id, keywords, expected_count in ((ids[0], "list_default_0", 1), (ids[0], "list_default_1", 1), (ids[0], "unknown", 0)): + get_res = rest_client.get(f"/chats/{chat_id}") + list_res = rest_client.get("/chats", params={"keywords": keywords}) + assert get_res.status_code == 200, (keywords, get_res.text) + assert list_res.status_code == 200, (keywords, list_res.text) + get_payload = get_res.json() + list_payload = list_res.json() + assert get_payload["code"] == 0, (keywords, get_payload) + assert list_payload["code"] == 0, (keywords, list_payload) + assert len(list_payload["data"]["chats"]) == expected_count, (keywords, list_payload) + + +@pytest.mark.p2 +def test_chat_list_keyword_and_invalid_param_contract(rest_client, clear_chats): + _reset_chat_batch(rest_client, "list_keyword") + cases = [ + ("keywords none", {"keywords": None}, 5, None), + ("keywords empty", {"keywords": ""}, 5, None), + ("keywords exact", {"keywords": "list_keyword_1"}, 1, "list_keyword_1"), + ("keywords unknown", {"keywords": "unknown"}, 0, None), + ("invalid params ignored", {"a": "b"}, 5, None), + ] + + for scenario_name, params, expected_count, expected_name in cases: + res = rest_client.get("/chats", params=params) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 0, (scenario_name, payload) + assert len(payload["data"]["chats"]) == expected_count, (scenario_name, payload) + if expected_name is not None: + assert payload["data"]["chats"][0]["name"] == expected_name, (scenario_name, payload) + + +@pytest.mark.p2 +def test_chat_list_page_and_page_size_contract(rest_client, clear_chats): + cases = [ + ("page none", {"page": None, "page_size": 2}, 0, lambda total: total, ""), + ("page zero", {"page": 0, "page_size": 2}, 0, lambda total: total, ""), + ("page two", {"page": 2, "page_size": 2}, 0, lambda total: min(max(total - 2, 0), 2), ""), + ("page three", {"page": 3, "page_size": 2}, 0, lambda total: min(max(total - 4, 0), 2), ""), + ("page string", {"page": "3", "page_size": 2}, 0, lambda total: min(max(total - 4, 0), 2), ""), + ("page negative", {"page": -1, "page_size": 2}, 100, None, "ProgrammingError(1064"), + ("page alpha", {"page": "a", "page_size": 2}, 100, None, "ValueError(\"invalid literal for int() with base 10: 'a'\")"), + ("page_size none", {"page_size": None}, 0, lambda total: total, ""), + ("page_size zero", {"page_size": 0}, 0, lambda total: total, ""), + ("page_size one", {"page_size": 1}, 0, lambda total: total, ""), + ("page_size six", {"page_size": 6}, 0, lambda total: total, ""), + ("page_size string", {"page_size": "1"}, 0, lambda total: total, ""), + ("page_size negative", {"page_size": -1}, 0, lambda total: total, ""), + ("page_size alpha", {"page_size": "a"}, 100, None, "ValueError(\"invalid literal for int() with base 10: 'a'\")"), + ] + + for scenario_name, params, expected_code, expected_count_fn, expected_message in cases: + _reset_chat_batch(rest_client, f"list_page_{scenario_name.replace(' ', '_')}") + baseline_payload = rest_client.get("/chats").json() + assert baseline_payload["code"] == 0, (scenario_name, baseline_payload) + baseline_total = baseline_payload["data"]["total"] + + res = rest_client.get("/chats", params=params) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == expected_code, (scenario_name, payload) + if expected_code == 0: + assert len(payload["data"]["chats"]) == expected_count_fn(baseline_total), (scenario_name, payload) + assert payload["data"]["total"] == baseline_total, (scenario_name, payload) + else: + assert expected_message in payload["message"], (scenario_name, payload) + + +@pytest.mark.p2 +def test_chat_list_sorting_contract(rest_client, clear_chats): + _reset_chat_batch(rest_client, "list_sort") + ascending_names = [f"list_sort_{i}" for i in range(5)] + descending_names = list(reversed(ascending_names)) + cases = [ + ("orderby none", {"orderby": None}, 0, descending_names, ""), + ("orderby create", {"orderby": "create_time"}, 0, descending_names, ""), + ("orderby update", {"orderby": "update_time"}, 0, descending_names, ""), + ("orderby name ascending", {"orderby": "name", "desc": "False"}, 0, ascending_names, ""), + ("orderby unknown", {"orderby": "unknown"}, 100, None, "AttributeError(\"type object 'Dialog' has no attribute 'unknown'\")"), + ("desc none", {"desc": None}, 0, descending_names, ""), + ("desc true", {"desc": "true"}, 0, descending_names, ""), + ("desc True", {"desc": "True"}, 0, descending_names, ""), + ("desc bool true", {"desc": True}, 0, descending_names, ""), + ("desc false", {"desc": "false"}, 0, ascending_names, ""), + ("desc False", {"desc": "False"}, 0, ascending_names, ""), + ("desc bool false", {"desc": False}, 0, ascending_names, ""), + ("desc False update_time", {"desc": "False", "orderby": "update_time"}, 0, ascending_names, ""), + ("desc unknown", {"desc": "unknown"}, 0, descending_names, ""), + ] + + for scenario_name, params, expected_code, expected_names, expected_message in cases: + res = rest_client.get("/chats", params=params) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == expected_code, (scenario_name, payload) + if expected_code == 0: + assert _chat_names(payload) == expected_names, (scenario_name, payload) + else: + assert expected_message in payload["message"], (scenario_name, payload) + + +@pytest.mark.p2 +def test_chat_list_concurrent_and_dataset_delete_contract(rest_client, clear_chats, ensure_parsed_document): + _reset_chat_batch(rest_client, "list_concurrent") + with ThreadPoolExecutor(max_workers=5) as executor: + results = list(executor.map(lambda _idx: rest_client.get("/chats").json(), range(10))) + assert len(results) == 10, results + assert all(result["code"] == 0 for result in results), results + assert all(len(result["data"]["chats"]) == 5 for result in results), results + + dataset_id, _ = ensure_parsed_document() + create_res = rest_client.post("/chats", json={"name": "list_after_dataset_delete", "dataset_ids": [dataset_id]}) + assert create_res.status_code == 200 + create_payload = create_res.json() + assert create_payload["code"] == 0, create_payload + + delete_dataset_res = rest_client.delete("/datasets", json={"ids": [dataset_id]}) + assert delete_dataset_res.status_code == 200 + delete_dataset_payload = delete_dataset_res.json() + assert delete_dataset_payload["code"] == 0, delete_dataset_payload + + list_res = rest_client.get("/chats", params={"keywords": "list_after_dataset_delete"}) + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert len(list_payload["data"]["chats"]) == 1, list_payload + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +class _AwaitableValue: + def __init__(self, value): + self._value = value + + def __await__(self): + async def _co(): + return self._value + + return _co().__await__() + + +class _DummyArgs(dict): + def get(self, key, default=None): + return super().get(key, default) + + def getlist(self, key): + value = self.get(key, []) + if value is None: + return [] + if isinstance(value, list): + return value + return [value] + + +class _StubHeaders: + def __init__(self): + self._items = [] + + def add_header(self, key, value): + self._items.append((key, value)) + + def get(self, key, default=None): + for existing_key, value in reversed(self._items): + if existing_key == key: + return value + return default + + +class _StubResponse: + def __init__(self, body=None, mimetype=None, content_type=None): + self.body = body + self.mimetype = mimetype + self.content_type = content_type + self.headers = _StubHeaders() + + +class _DummyUploadFile: + def __init__(self, filename): + self.filename = filename + self.saved_path = None + + async def save(self, path): + self.saved_path = path + + +def _passthrough_login_required(func): + @wraps(func) + async def _wrapper(*args, **kwargs): + return await func(*args, **kwargs) + + return _wrapper + + +class _DummyKB: + def __init__(self, kid="kb-1", embd_id="embd@factory", chunk_num=1, name="Dataset A", status="1"): + self.id = kid + self.embd_id = embd_id + self.chunk_num = chunk_num + self.name = name + self.status = status + + +class _DummyDialogRecord: + def __init__(self, data=None): + self._data = data or { + "id": "chat-1", + "name": "chat-name", + "description": "desc", + "icon": "icon.png", + "kb_ids": ["kb-1"], + "llm_id": "glm-4", + "llm_setting": {"temperature": 0.1}, + "prompt_config": { + "system": "Answer with {knowledge}", + "parameters": [{"key": "knowledge", "optional": False}], + "prologue": "hello", + "quote": True, + }, + "similarity_threshold": 0.2, + "vector_similarity_weight": 0.3, + "top_n": 6, + "top_k": 1024, + "rerank_id": "", + "meta_data_filter": {}, + "tenant_id": "tenant-1", + } + + def to_dict(self): + return deepcopy(self._data) + + +def _run(coro): + return asyncio.run(coro) + + +async def _collect_stream(body): + items = [] + if hasattr(body, "__aiter__"): + async for item in body: + if isinstance(item, bytes): + item = item.decode("utf-8") + items.append(item) + else: + for item in body: + if isinstance(item, bytes): + item = item.decode("utf-8") + items.append(item) + return items + + +def _load_chat_routes_unit_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + module_name = "test_chat_restful_routes_unit_module" + module_path = repo_root / "api" / "apps" / "restful_apis" / "chat_api.py" + + quart_mod = ModuleType("quart") + quart_mod.request = SimpleNamespace(args=_DummyArgs()) + quart_mod.Response = _StubResponse + monkeypatch.setitem(sys.modules, "quart", quart_mod) + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + + apps_pkg = ModuleType("api.apps") + apps_pkg.__path__ = [str(repo_root / "api" / "apps")] + apps_pkg.current_user = SimpleNamespace(id="tenant-1") + apps_pkg.login_required = _passthrough_login_required + monkeypatch.setitem(sys.modules, "api.apps", apps_pkg) + api_pkg.apps = apps_pkg + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + common_constants_mod = ModuleType("common.constants") + + class _StubLLMType(str, Enum): + CHAT = "chat" + IMAGE2TEXT = "image2text" + RERANK = "rerank" + SPEECH2TEXT = "speech2text" + TTS = "tts" + + class _StubRetCode(int, Enum): + SUCCESS = 0 + DATA_ERROR = 102 + OPERATING_ERROR = 103 + AUTHENTICATION_ERROR = 109 + + class _StubStatusEnum(str, Enum): + VALID = "1" + INVALID = "0" + + common_constants_mod.LLMType = _StubLLMType + common_constants_mod.RetCode = _StubRetCode + common_constants_mod.StatusEnum = _StubStatusEnum + from common.constants import MAXIMUM_PAGE_NUMBER as _MPN, MAXIMUM_TASK_PAGE_NUMBER as _MTPN + common_constants_mod.MAXIMUM_PAGE_NUMBER = _MPN + common_constants_mod.MAXIMUM_TASK_PAGE_NUMBER = _MTPN + monkeypatch.setitem(sys.modules, "common.constants", common_constants_mod) + + misc_utils_mod = ModuleType("common.misc_utils") + misc_utils_mod.get_uuid = lambda: "generated-chat-id" + + async def _thread_pool_exec(func, *args, **kwargs): + return func(*args, **kwargs) + + misc_utils_mod.thread_pool_exec = _thread_pool_exec + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod) + + settings_mod = ModuleType("common.settings") + settings_mod.STORAGE_IMPL = type("_StorageImpl", (), {"rm": staticmethod(lambda *_args, **_kwargs: None)})() + monkeypatch.setitem(sys.modules, "common.settings", settings_mod) + + dialog_service_mod = ModuleType("api.db.services.dialog_service") + + class _StubDialogService: + model = SimpleNamespace( + _meta=SimpleNamespace( + fields={ + "id": None, + "tenant_id": None, + "name": None, + "description": None, + "icon": None, + "kb_ids": None, + "llm_id": None, + "llm_setting": None, + "prompt_config": None, + "similarity_threshold": None, + "vector_similarity_weight": None, + "top_n": None, + "top_k": None, + "rerank_id": None, + "meta_data_filter": None, + "created_by": None, + "create_time": None, + "create_date": None, + "update_time": None, + "update_date": None, + "status": None, + } + ) + ) + + @staticmethod + def query(**_kwargs): + return [] + + @staticmethod + def save(**_kwargs): + return True + + @staticmethod + def get_by_id(_chat_id): + return False, None + + @staticmethod + def update_by_id(_chat_id, _payload): + return True + + @staticmethod + def get_by_tenant_ids(*_args, **_kwargs): + return [], 0 + + dialog_service_mod.DialogService = _StubDialogService + dialog_service_mod.async_ask = lambda *_args, **_kwargs: None + dialog_service_mod.async_chat = lambda *_args, **_kwargs: None + dialog_service_mod.gen_mindmap = lambda *_args, **_kwargs: None + monkeypatch.setitem(sys.modules, "api.db.services.dialog_service", dialog_service_mod) + + conversation_service_mod = ModuleType("api.db.services.conversation_service") + + class _StubConversationService: + @staticmethod + def query(**_kwargs): + return [] + + @staticmethod + def get_list(*_args, **_kwargs): + return [] + + @staticmethod + def get_by_id(_session_id): + return False, None + + @staticmethod + def update_by_id(_session_id, _payload): + return True + + @staticmethod + def delete_by_id(_session_id): + return True + + @staticmethod + def save(**_kwargs): + return True + + conversation_service_mod.ConversationService = _StubConversationService + conversation_service_mod.structure_answer = lambda *_args, **_kwargs: {} + monkeypatch.setitem(sys.modules, "api.db.services.conversation_service", conversation_service_mod) + + kb_service_mod = ModuleType("api.db.services.knowledgebase_service") + + class _StubKnowledgebaseService: + @staticmethod + def accessible(**_kwargs): + return [] + + @staticmethod + def query(**_kwargs): + return [] + + @staticmethod + def get_by_id(_kb_id): + return False, None + + kb_service_mod.KnowledgebaseService = _StubKnowledgebaseService + monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod) + + tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service") + + class _StubTenantLLMService: + @staticmethod + def split_model_name_and_factory(model_name): + if model_name and "@" in model_name: + return tuple(model_name.split("@", 1)) + return model_name, None + + @staticmethod + def query(**_kwargs): + return [] + + @staticmethod + def get_api_key(*_args, **_kwargs): + return SimpleNamespace(id=1) + + tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService + monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod) + + llm_service_mod = ModuleType("api.db.services.llm_service") + llm_service_mod.LLMBundle = lambda *_args, **_kwargs: None + monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod) + + search_service_mod = ModuleType("api.db.services.search_service") + search_service_mod.SearchService = SimpleNamespace() + monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod) + + tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service") + tenant_model_service_mod.get_model_config_by_type_and_name = lambda *_args, **_kwargs: {} + tenant_model_service_mod.get_tenant_default_model_by_type = lambda *_args, **_kwargs: {} + monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) + + user_service_mod = ModuleType("api.db.services.user_service") + + class _StubTenantService: + @staticmethod + def get_by_id(_tenant_id): + return True, SimpleNamespace(llm_id="glm-4") + + @staticmethod + def get_joined_tenants_by_user_id(_user_id): + return [{"tenant_id": "tenant-1"}, {"tenant_id": "team-tenant-2"}] + + class _StubUserTenantService: + @staticmethod + def query(**_kwargs): + return [] + + user_service_mod.UserService = type("UserService", (), {}) + user_service_mod.TenantService = _StubTenantService + user_service_mod.UserTenantService = _StubUserTenantService + monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) + + chunk_feedback_service_mod = ModuleType("api.db.services.chunk_feedback_service") + chunk_feedback_service_mod.ChunkFeedbackService = type( + "ChunkFeedbackService", + (), + {"apply_feedback": staticmethod(lambda **_kwargs: {"success_count": 0, "fail_count": 0, "chunk_ids": []})}, + ) + monkeypatch.setitem(sys.modules, "api.db.services.chunk_feedback_service", chunk_feedback_service_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + + def _check_duplicate_ids(ids, label): + counts = {} + for item in ids or []: + counts[item] = counts.get(item, 0) + 1 + duplicate_messages = [f"Duplicate {label} ids: {item}" for item, count in counts.items() if count > 1] + return list(dict.fromkeys(ids or [])), duplicate_messages + + api_utils_mod.check_duplicate_ids = _check_duplicate_ids + api_utils_mod.get_data_error_result = lambda message="": {"code": 102, "data": None, "message": message} + api_utils_mod.get_json_result = lambda data=None, message="", code=0: {"code": code, "data": data, "message": message} + api_utils_mod.get_request_json = lambda: _AwaitableValue({}) + api_utils_mod.server_error_response = lambda ex: {"code": 500, "data": None, "message": str(ex)} + api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func) + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + tenant_utils_mod = ModuleType("api.utils.tenant_utils") + tenant_utils_mod.ensure_tenant_model_id_for_params = lambda _tenant_id, req: req + monkeypatch.setitem(sys.modules, "api.utils.tenant_utils", tenant_utils_mod) + + rag_pkg = ModuleType("rag") + rag_pkg.__path__ = [str(repo_root / "rag")] + monkeypatch.setitem(sys.modules, "rag", rag_pkg) + + rag_prompts_pkg = ModuleType("rag.prompts") + rag_prompts_pkg.__path__ = [str(repo_root / "rag" / "prompts")] + monkeypatch.setitem(sys.modules, "rag.prompts", rag_prompts_pkg) + + rag_prompts_generator_mod = ModuleType("rag.prompts.generator") + rag_prompts_generator_mod.chunks_format = lambda reference: reference.get("chunks", []) if isinstance(reference, dict) else [] + monkeypatch.setitem(sys.modules, "rag.prompts.generator", rag_prompts_generator_mod) + + rag_prompts_template_mod = ModuleType("rag.prompts.template") + rag_prompts_template_mod.load_prompt = lambda *_args, **_kwargs: "" + monkeypatch.setitem(sys.modules, "rag.prompts.template", rag_prompts_template_mod) + + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + monkeypatch.setitem(sys.modules, module_name, module) + spec.loader.exec_module(module) + return module + + +def _set_route_unit_request_json(monkeypatch, module, payload): + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(deepcopy(payload))) + + +@pytest.mark.p2 +def test_chat_session_create_and_update_guard_matrix_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + + _set_route_unit_request_json(monkeypatch, module, {"name": "session"}) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) + res = _run(module.create_session.__wrapped__("chat-1")) + assert res["message"] == "No authorization." + + dia = SimpleNamespace(prompt_config={"prologue": "hello"}) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [dia]) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, dia)) + monkeypatch.setattr(module.ConversationService, "save", lambda **_kwargs: None) + monkeypatch.setattr(module.ConversationService, "get_by_id", lambda _id: (False, None)) + res = _run(module.create_session.__wrapped__("chat-1")) + assert "Fail to create a session" in res["message"] + + _set_route_unit_request_json(monkeypatch, module, {}) + monkeypatch.setattr(module.ConversationService, "query", lambda **_kwargs: []) + res = _run(module.update_session.__wrapped__("chat-1", "session-1")) + assert res["message"] == "Session not found!" + + monkeypatch.setattr(module.ConversationService, "query", lambda **_kwargs: [SimpleNamespace(id="session-1")]) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) + res = _run(module.update_session.__wrapped__("chat-1", "session-1")) + assert res["message"] == "No authorization." + + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) + _set_route_unit_request_json(monkeypatch, module, {"message": []}) + res = _run(module.update_session.__wrapped__("chat-1", "session-1")) + assert "`messages` cannot be changed." in res["message"] + + _set_route_unit_request_json(monkeypatch, module, {"reference": []}) + res = _run(module.update_session.__wrapped__("chat-1", "session-1")) + assert "`reference` cannot be changed." in res["message"] + + _set_route_unit_request_json(monkeypatch, module, {"name": ""}) + res = _run(module.update_session.__wrapped__("chat-1", "session-1")) + assert "`name` can not be empty." in res["message"] + + _set_route_unit_request_json(monkeypatch, module, {"name": "renamed"}) + monkeypatch.setattr(module.ConversationService, "update_by_id", lambda *_args, **_kwargs: False) + res = _run(module.update_session.__wrapped__("chat-1", "session-1")) + assert res["message"] == "Session not found!" + + +@pytest.mark.p2 +def test_chat_session_list_projection_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + args=SimpleNamespace( + get=lambda key, default=None: { + "page": 1, + "page_size": 30, + "orderby": "create_time", + "desc": "true", + "id": None, + "name": None, + "user_id": None, + }.get(key, default) + ) + ), + ) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) + monkeypatch.setattr( + module.ConversationService, + "get_list", + lambda *_args, **_kwargs: [ + { + "id": "session-1", + "dialog_id": "chat-1", + "message": [{"role": "assistant", "content": "hello"}], + "reference": [], + } + ], + ) + + res = _run(module.list_sessions.__wrapped__("chat-1")) + assert res["data"][0]["chat_id"] == "chat-1" + assert res["data"][0]["messages"][0]["content"] == "hello" + + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + args=SimpleNamespace( + get=lambda key, default=None: { + "page": 1, + "page_size": 0, + "orderby": "create_time", + "desc": "true", + "id": None, + "name": None, + "user_id": None, + }.get(key, default) + ) + ), + ) + res = _run(module.list_sessions.__wrapped__("chat-1")) + assert res["data"] == [] + + +@pytest.mark.p2 +def test_chat_session_delete_routes_partial_duplicate_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) + _set_route_unit_request_json(monkeypatch, module, {}) + res = _run(module.delete_sessions.__wrapped__("chat-1")) + assert res["code"] == 0 + + monkeypatch.setattr(module.ConversationService, "delete_by_id", lambda *_args, **_kwargs: True) + + def _conversation_query(**kwargs): + if "dialog_id" in kwargs and "id" not in kwargs: + return [SimpleNamespace(id="seed")] + if kwargs.get("id") == "ok": + return [SimpleNamespace(id="ok")] + return [] + + monkeypatch.setattr(module.ConversationService, "query", _conversation_query) + _set_route_unit_request_json(monkeypatch, module, {"ids": ["ok", "bad"]}) + monkeypatch.setattr(module, "check_duplicate_ids", lambda ids, _kind: (ids, [])) + res = _run(module.delete_sessions.__wrapped__("chat-1")) + assert res["code"] == 0 + assert res["data"]["success_count"] == 1 + assert res["data"]["errors"] == ["The chat doesn't own the session bad"] + + _set_route_unit_request_json(monkeypatch, module, {"ids": ["bad"]}) + monkeypatch.setattr(module, "check_duplicate_ids", lambda ids, _kind: (ids, [])) + res = _run(module.delete_sessions.__wrapped__("chat-1")) + assert res["message"] == "The chat doesn't own the session bad" + + _set_route_unit_request_json(monkeypatch, module, {"ids": ["ok", "ok"]}) + monkeypatch.setattr(module, "check_duplicate_ids", lambda ids, _kind: (["ok"], ["Duplicate session ids: ok"])) + res = _run(module.delete_sessions.__wrapped__("chat-1")) + assert res["code"] == 0 + assert res["data"]["success_count"] == 1 + assert res["data"]["errors"] == ["Duplicate session ids: ok"] + + +@pytest.mark.p2 +def test_chat_audio_transcription_routes_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + monkeypatch.setattr(module, "Response", _StubResponse) + monkeypatch.setattr(module.tempfile, "mkstemp", lambda suffix: (11, f"/tmp/audio{suffix}")) + monkeypatch.setattr(module.os, "close", lambda _fd: None) + + def _set_request(form, files): + monkeypatch.setattr(module, "request", SimpleNamespace(form=_AwaitableValue(form), files=_AwaitableValue(files))) + + _set_request({"stream": "false"}, {}) + res = _run(module.transcription.__wrapped__()) + assert "Missing 'file' in multipart form-data" in res["message"] + + _set_request({"stream": "false"}, {"file": _DummyUploadFile("bad.txt")}) + res = _run(module.transcription.__wrapped__()) + assert "Unsupported audio format: .txt" in res["message"] + + _set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")}) + monkeypatch.setattr( + module, + "get_tenant_default_model_by_type", + lambda *_args, **_kwargs: (_ for _ in ()).throw(LookupError("Tenant not found!")), + ) + res = _run(module.transcription.__wrapped__()) + assert res["message"] == "Tenant not found!" + + _set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")}) + monkeypatch.setattr( + module, + "get_tenant_default_model_by_type", + lambda *_args, **_kwargs: (_ for _ in ()).throw(Exception("No default ASR model is set")), + ) + res = _run(module.transcription.__wrapped__()) + assert res["message"] == "No default ASR model is set" + + class _SyncASR: + def transcription(self, _path): + return "transcribed text" + + def stream_transcription(self, _path): + return [] + + _set_request({"stream": "false"}, {"file": _DummyUploadFile("audio.wav")}) + monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: {"llm_name": "asr-x"}) + monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _SyncASR()) + monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("cleanup fail"))) + res = _run(module.transcription.__wrapped__()) + assert res["code"] == 0 + assert res["data"]["text"] == "transcribed text" + + class _StreamASR: + def transcription(self, _path): + return "" + + def stream_transcription(self, _path): + yield {"event": "partial", "text": "hello"} + + _set_request({"stream": "true"}, {"file": _DummyUploadFile("audio.wav")}) + monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _StreamASR()) + monkeypatch.setattr(module.os, "remove", lambda _path: None) + resp = _run(module.transcription.__wrapped__()) + assert isinstance(resp, _StubResponse) + assert resp.content_type == "text/event-stream" + chunks = _run(_collect_stream(resp.body)) + assert any('"event": "partial"' in chunk for chunk in chunks) + + class _ErrorASR: + def transcription(self, _path): + return "" + + def stream_transcription(self, _path): + raise RuntimeError("stream asr boom") + + _set_request({"stream": "true"}, {"file": _DummyUploadFile("audio.wav")}) + monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _ErrorASR()) + monkeypatch.setattr(module.os, "remove", lambda _path: (_ for _ in ()).throw(RuntimeError("cleanup boom"))) + resp = _run(module.transcription.__wrapped__()) + chunks = _run(_collect_stream(resp.body)) + assert any("stream asr boom" in chunk for chunk in chunks) + + +@pytest.mark.p2 +def test_chat_audio_speech_routes_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + monkeypatch.setattr(module, "Response", _StubResponse) + _set_route_unit_request_json(monkeypatch, module, {"text": "A。B"}) + + monkeypatch.setattr( + module, + "get_tenant_default_model_by_type", + lambda *_args, **_kwargs: (_ for _ in ()).throw(LookupError("Tenant not found!")), + ) + res = _run(module.tts.__wrapped__()) + assert res["message"] == "Tenant not found!" + + monkeypatch.setattr( + module, + "get_tenant_default_model_by_type", + lambda *_args, **_kwargs: (_ for _ in ()).throw(Exception("No default TTS model is set")), + ) + res = _run(module.tts.__wrapped__()) + assert res["message"] == "No default TTS model is set" + + class _TTSOk: + def tts(self, txt): + if not txt: + return [] + yield f"chunk-{txt}".encode("utf-8") + + monkeypatch.setattr(module, "get_tenant_default_model_by_type", lambda *_args, **_kwargs: {"llm_name": "tts-x"}) + monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSOk()) + resp = _run(module.tts.__wrapped__()) + assert resp.mimetype == "audio/mpeg" + assert resp.headers.get("Cache-Control") == "no-cache" + assert resp.headers.get("Connection") == "keep-alive" + assert resp.headers.get("X-Accel-Buffering") == "no" + chunks = _run(_collect_stream(resp.body)) + assert any("chunk-A" in chunk for chunk in chunks) + assert any("chunk-B" in chunk for chunk in chunks) + + class _TTSErr: + def tts(self, _txt): + raise RuntimeError("tts boom") + + monkeypatch.setattr(module, "LLMBundle", lambda *_args, **_kwargs: _TTSErr()) + resp = _run(module.tts.__wrapped__()) + chunks = _run(_collect_stream(resp.body)) + assert any('"code": 500' in chunk and "**ERROR**: tts boom" in chunk for chunk in chunks) + + +@pytest.mark.p1 +def test_chat_create_accepts_provider_scoped_rerank_id_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + saved = {} + query_calls = [] + + _set_route_unit_request_json( + monkeypatch, + module, + { + "name": "chat-a", + "icon": "icon.png", + "dataset_ids": ["kb-1"], + "llm_id": "glm-4@ZHIPU-AI", + "llm_setting": {"temperature": 0.8}, + "prompt_config": { + "system": "Answer with {knowledge}", + "parameters": [{"key": "knowledge", "optional": False}], + "prologue": "Hi", + }, + "rerank_id": "custom-reranker@OpenAI", + "vector_similarity_weight": 0.25, + }, + ) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4@ZHIPU-AI"))) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) + monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB()]) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) + + def _split_model_name_and_factory(model_name): + return { + "glm-4@ZHIPU-AI": ("glm-4", "ZHIPU-AI"), + "custom-reranker@OpenAI": ("custom-reranker", "OpenAI"), + }.get(model_name, (model_name, None)) + + def _query(**kwargs): + query_calls.append(kwargs) + if kwargs == { + "tenant_id": "tenant-1", + "llm_name": "glm-4", + "llm_factory": "ZHIPU-AI", + "model_type": "chat", + }: + return [SimpleNamespace(id="llm-1")] + if kwargs == { + "tenant_id": "tenant-1", + "llm_name": "custom-reranker", + "llm_factory": "OpenAI", + "model_type": "rerank", + }: + return [SimpleNamespace(id="rerank-1")] + return [] + + monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", _split_model_name_and_factory) + monkeypatch.setattr(module.TenantLLMService, "query", _query) + + def _save(**kwargs): + saved.update(kwargs) + return True + + monkeypatch.setattr(module.DialogService, "save", _save) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(saved))) + + res = _run(module.create.__wrapped__()) + assert res["code"] == 0 + assert saved["rerank_id"] == "custom-reranker@OpenAI" + assert { + "tenant_id": "tenant-1", + "llm_name": "custom-reranker", + "llm_factory": "OpenAI", + "model_type": "rerank", + } in query_calls + + +@pytest.mark.p1 +def test_chat_create_allows_default_knowledge_placeholder_without_sources_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + saved = {} + _set_route_unit_request_json(monkeypatch, module, {"name": "chat-a"}) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) + monkeypatch.setattr(module.TenantLLMService, "get_api_key", lambda *_args, **_kwargs: SimpleNamespace(id=1)) + + def _save(**kwargs): + saved.update(kwargs) + return True + + monkeypatch.setattr(module.DialogService, "save", _save) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(saved))) + + res = _run(module.create.__wrapped__()) + assert res["code"] == 0 + assert saved["kb_ids"] == [] + assert saved["prompt_config"]["system"].find("{knowledge}") >= 0 + assert saved["prompt_config"]["parameters"] == [{"key": "knowledge", "optional": False}] + + +@pytest.mark.p2 +def test_chat_create_uses_direct_chat_fields_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + saved = {} + _set_route_unit_request_json( + monkeypatch, + module, + { + "name": "chat-a", + "icon": "icon.png", + "dataset_ids": ["kb-1"], + "llm_id": "glm-4", + "llm_setting": {"temperature": 0.8}, + "prompt_config": { + "system": "Answer with {knowledge}", + "parameters": [{"key": "knowledge", "optional": False}], + "prologue": "Hi", + }, + "vector_similarity_weight": 0.25, + }, + ) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: []) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) + monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB()]) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) + monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda model: (model.split("@")[0], "factory")) + monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: [SimpleNamespace(id="llm-1")]) + + def _save(**kwargs): + saved.update(kwargs) + return True + + monkeypatch.setattr(module.DialogService, "save", _save) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(saved))) + + res = _run(module.create.__wrapped__()) + assert res["code"] == 0 + assert saved["kb_ids"] == ["kb-1"] + assert saved["prompt_config"]["prologue"] == "Hi" + assert saved["llm_id"] == "glm-4" + assert saved["llm_setting"]["temperature"] == 0.8 + assert res["data"]["dataset_ids"] == ["kb-1"] + assert res["data"]["kb_names"] == ["Dataset A"] + assert "kb_ids" not in res["data"] + assert "prompt" not in res["data"] + assert "llm" not in res["data"] + assert "avatar" not in res["data"] + + +@pytest.mark.p2 +def test_list_chats_defaults_to_authorized_owner_ids_when_omitted_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + captured = {} + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + args=SimpleNamespace( + get=lambda key, default=None: { + "keywords": "", + "page": "1", + "page_size": "10", + "orderby": "create_time", + "desc": "true", + "id": None, + "name": None, + }.get(key, default), + getlist=lambda _key: [], + ) + ), + ) + + def _get_by_tenant_ids(owner_ids, *_args, **_kwargs): + captured["owner_ids"] = owner_ids + return ([], 0) + + monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _get_by_tenant_ids) + res = _run(module.list_chats.__wrapped__()) + assert res["code"] == 0 + assert set(captured["owner_ids"]) == {"tenant-1", "team-tenant-2"} + + +@pytest.mark.p2 +def test_list_chats_rejects_unauthorized_owner_ids_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + args=SimpleNamespace( + get=lambda key, default=None: { + "keywords": "", + "page": "0", + "page_size": "0", + "orderby": "create_time", + "desc": "true", + "id": None, + "name": None, + }.get(key, default), + getlist=lambda key: ["foreign-tenant-id"] if key == "owner_ids" else [], + ) + ), + ) + res = _run(module.list_chats.__wrapped__()) + assert res["code"] == module.RetCode.OPERATING_ERROR + assert "authorized owner_ids" in res["message"] + + +@pytest.mark.p2 +def test_list_chats_returns_old_business_fields_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + args=SimpleNamespace( + get=lambda key, default=None: { + "keywords": "", + "page": 1, + "page_size": 20, + "orderby": "create_time", + "desc": "true", + }.get(key, default), + getlist=lambda _key: [], + ) + ), + ) + monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", lambda *_args, **_kwargs: ([_DummyDialogRecord().to_dict()], 1)) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) + + res = _run(module.list_chats.__wrapped__()) + assert res["code"] == 0 + chat = res["data"]["chats"][0] + assert chat["icon"] == "icon.png" + assert chat["dataset_ids"] == ["kb-1"] + assert chat["kb_names"] == ["Dataset A"] + assert "kb_ids" not in chat + assert chat["prompt_config"]["prologue"] == "hello" + assert "dataset_names" not in chat + assert "prompt" not in chat + assert "llm" not in chat + + +@pytest.mark.p2 +def test_patch_chat_drops_response_only_fields_before_update_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + updated = {} + existing = _DummyDialogRecord().to_dict() + payload = { + "name": "renamed-chat", + "description": existing["description"], + "icon": existing["icon"], + "dataset_ids": existing["kb_ids"], + "kb_names": ["Dataset A"], + "llm_id": existing["llm_id"], + "llm_setting": existing["llm_setting"], + "prompt_config": existing["prompt_config"], + "similarity_threshold": existing["similarity_threshold"], + "vector_similarity_weight": existing["vector_similarity_weight"], + "top_n": existing["top_n"], + "top_k": existing["top_k"], + "rerank_id": existing["rerank_id"], + } + + _set_route_unit_request_json(monkeypatch, module, payload) + monkeypatch.setattr(module.DialogService, "query", lambda **kwargs: [] if "name" in kwargs else [SimpleNamespace(id="chat-1")]) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(existing))) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: [SimpleNamespace(id="kb-1")]) + monkeypatch.setattr(module.KnowledgebaseService, "query", lambda **_kwargs: [_DummyKB()]) + monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda model: (model.split("@")[0], "factory")) + monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: [SimpleNamespace(id="llm-1")]) + + def _update(_chat_id, req): + updated.update(req) + return True + + monkeypatch.setattr(module.DialogService, "update_by_id", _update) + res = _run(module.patch_chat.__wrapped__("chat-1")) + assert res["code"] == 0 + assert updated["name"] == "renamed-chat" + assert "kb_names" not in updated + + +@pytest.mark.p2 +def test_patch_chat_merges_prompt_and_llm_settings_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + updated = {} + existing = _DummyDialogRecord().to_dict() + _set_route_unit_request_json( + monkeypatch, + module, + {"prompt_config": {"prologue": "updated opener"}, "llm_setting": {"temperature": 0.9}}, + ) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(existing))) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) + + def _update(_chat_id, payload): + updated.update(payload) + return True + + monkeypatch.setattr(module.DialogService, "update_by_id", _update) + res = _run(module.patch_chat.__wrapped__("chat-1")) + assert res["code"] == 0 + assert updated["prompt_config"]["system"] == "Answer with {knowledge}" + assert updated["prompt_config"]["prologue"] == "updated opener" + assert updated["llm_setting"]["temperature"] == 0.9 + + +@pytest.mark.p2 +def test_update_chat_allows_knowledge_placeholder_without_sources_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + existing = _DummyDialogRecord().to_dict() + _set_route_unit_request_json( + monkeypatch, + module, + { + "name": "chat-name", + "description": "desc", + "icon": "icon.png", + "dataset_ids": [], + "llm_id": "glm-4", + "llm_setting": {"temperature": 0.1}, + "prompt_config": { + "system": "Answer with {knowledge}", + "parameters": [{"key": "knowledge", "optional": False}], + "prologue": "hello", + "quote": True, + }, + "similarity_threshold": 0.2, + "vector_similarity_weight": 0.3, + "top_n": 6, + "top_k": 1024, + "rerank_id": "", + }, + ) + monkeypatch.setattr(module.DialogService, "query", lambda **_kwargs: [SimpleNamespace(id="chat-1")]) + monkeypatch.setattr(module.DialogService, "get_by_id", lambda _id: (True, _DummyDialogRecord(existing))) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda _tid: (True, SimpleNamespace(llm_id="glm-4"))) + monkeypatch.setattr(module.TenantLLMService, "split_model_name_and_factory", lambda model: (model.split("@")[0], "factory")) + monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: [SimpleNamespace(id="llm-1")]) + updated = {} + + def _update(_chat_id, payload): + updated.update(payload) + return True + + monkeypatch.setattr(module.DialogService, "update_by_id", _update) + res = _run(module.update_chat.__wrapped__("chat-1")) + assert res["code"] == 0 + assert updated["prompt_config"]["system"] == "Answer with {knowledge}" + + +@pytest.mark.p1 +def test_chat_create_dataset_ids_contract(rest_client, clear_chats, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + cases = [ + ("empty dataset_ids", [], 0, "", []), + ("owned parsed dataset", [dataset_id], 0, "", [dataset_id]), + ("invalid dataset id", ["invalid_dataset_id"], 102, "You don't own the dataset invalid_dataset_id", None), + ("dataset_ids wrong type", "invalid_dataset_id", 102, "`dataset_ids` should be a list.", None), + ] + + for index, (scenario_name, dataset_ids, expected_code, expected_message, expected_dataset_ids) in enumerate(cases, start=1): + res = rest_client.post( + "/chats", + json={"name": f"restful_chat_dataset_ids_{index}", "dataset_ids": dataset_ids}, + ) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == expected_code, (scenario_name, payload) + if expected_code == 0: + assert payload["data"]["dataset_ids"] == expected_dataset_ids, (scenario_name, payload) + else: + assert payload["message"] == expected_message, (scenario_name, payload) + + +@pytest.mark.p2 +def test_chat_create_avatar_contract(rest_client, clear_chats, tmp_path): + image_path = create_image_file(tmp_path / "restful_chat_avatar.png") + encoded_avatar = encode_avatar(image_path) + + res = rest_client.post( + "/chats", + json={"name": "restful_chat_avatar", "dataset_ids": [], "icon": encoded_avatar}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert payload["data"]["icon"] == encoded_avatar, payload + + +@pytest.mark.p2 +def test_chat_create_llm_contract(rest_client, clear_chats, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + cases = [ + ("default llm", {}, 0, "", "glm-4-flash@ZHIPU-AI", {}), + ("explicit llm_id", {"llm_id": "glm-4"}, 0, "", "glm-4", {}), + ("unknown llm_id", {"llm_id": "unknown"}, 102, "`llm_id` unknown doesn't exist", None, None), + ("temperature zero", {"llm_setting": {"temperature": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": 0}), + ("temperature one", {"llm_setting": {"temperature": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": 1}), + ("temperature negative one", {"llm_setting": {"temperature": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": -1}), + ("temperature ten", {"llm_setting": {"temperature": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": 10}), + ("temperature string", {"llm_setting": {"temperature": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": "a"}), + ("top_p zero", {"llm_setting": {"top_p": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": 0}), + ("top_p one", {"llm_setting": {"top_p": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": 1}), + ("top_p negative one", {"llm_setting": {"top_p": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": -1}), + ("top_p ten", {"llm_setting": {"top_p": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": 10}), + ("top_p string", {"llm_setting": {"top_p": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": "a"}), + ("presence_penalty zero", {"llm_setting": {"presence_penalty": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": 0}), + ("presence_penalty one", {"llm_setting": {"presence_penalty": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": 1}), + ("presence_penalty negative one", {"llm_setting": {"presence_penalty": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": -1}), + ("presence_penalty ten", {"llm_setting": {"presence_penalty": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": 10}), + ("presence_penalty string", {"llm_setting": {"presence_penalty": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": "a"}), + ("frequency_penalty zero", {"llm_setting": {"frequency_penalty": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": 0}), + ("frequency_penalty one", {"llm_setting": {"frequency_penalty": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": 1}), + ("frequency_penalty negative one", {"llm_setting": {"frequency_penalty": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": -1}), + ("frequency_penalty ten", {"llm_setting": {"frequency_penalty": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": 10}), + ("frequency_penalty string", {"llm_setting": {"frequency_penalty": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": "a"}), + ("max_token zero", {"llm_setting": {"max_token": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": 0}), + ("max_token 1024", {"llm_setting": {"max_token": 1024}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": 1024}), + ("max_token negative one", {"llm_setting": {"max_token": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": -1}), + ("max_token ten", {"llm_setting": {"max_token": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": 10}), + ("max_token string", {"llm_setting": {"max_token": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": "a"}), + ("unknown llm setting key", {"llm_setting": {"unknown": "unknown"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"unknown": "unknown"}), + ] + + for index, (scenario_name, extra_payload, expected_code, expected_message, expected_llm_id, expected_llm_setting) in enumerate(cases, start=1): + payload = { + "name": f"restful_chat_llm_{index}", + "dataset_ids": [dataset_id], + } + payload.update(extra_payload) + res = rest_client.post("/chats", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code == 0: + assert body["data"]["llm_id"] == expected_llm_id, (scenario_name, body) + assert body["data"]["llm_setting"] == expected_llm_setting, (scenario_name, body) + else: + assert body["message"] == expected_message, (scenario_name, body) + + +@pytest.mark.p2 +def test_chat_create_prompt_contract(rest_client, clear_chats): + cases = [ + ( + "default prompt config", + {}, + { + ("similarity_threshold",): 0.1, + ("vector_similarity_weight",): 0.3, + ("top_n",): 6, + ("rerank_id",): "", + ("prompt_config", "parameters"): [{"key": "knowledge", "optional": False}], + ("prompt_config", "empty_response"): DEFAULT_CHAT_EMPTY_RESPONSE, + ("prompt_config", "prologue"): DEFAULT_CHAT_PROLOGUE, + ("prompt_config", "quote"): True, + ("prompt_config", "system"): DEFAULT_CHAT_SYSTEM_PROMPT, + }, + ), + ("similarity_threshold zero", {"similarity_threshold": 0}, {("similarity_threshold",): 0}), + ("similarity_threshold one", {"similarity_threshold": 1}, {("similarity_threshold",): 1}), + ("similarity_threshold negative one", {"similarity_threshold": -1}, {("similarity_threshold",): -1.0}), + ("similarity_threshold ten", {"similarity_threshold": 10}, {("similarity_threshold",): 10.0}), + ("similarity_threshold string", {"similarity_threshold": "a"}, {("similarity_threshold",): 0.0}), + ("vector_similarity_weight one", {"vector_similarity_weight": 1}, {("vector_similarity_weight",): 1}), + ("vector_similarity_weight zero", {"vector_similarity_weight": 0}, {("vector_similarity_weight",): 0}), + ("vector_similarity_weight two", {"vector_similarity_weight": 2}, {("vector_similarity_weight",): 2.0}), + ("vector_similarity_weight negative nine", {"vector_similarity_weight": -9}, {("vector_similarity_weight",): -9.0}), + ("vector_similarity_weight string", {"vector_similarity_weight": "a"}, {("vector_similarity_weight",): 0.0}), + ("empty prompt parameters", {"prompt_config": {"parameters": []}}, {("prompt_config", "parameters"): []}), + ("top_n zero", {"top_n": 0}, {("top_n",): 0}), + ("top_n one", {"top_n": 1}, {("top_n",): 1}), + ("top_n negative one", {"top_n": -1}, {("top_n",): -1}), + ("top_n ten", {"top_n": 10}, {("top_n",): 10}), + ("top_n string", {"top_n": "a"}, {("top_n",): 0}), + ("empty_response plain text", {"prompt_config": {"empty_response": "Hello World"}}, {("prompt_config", "empty_response"): "Hello World"}), + ("empty_response empty string", {"prompt_config": {"empty_response": ""}}, {("prompt_config", "empty_response"): ""}), + ("empty_response punctuation", {"prompt_config": {"empty_response": "!@#$%^&*()"}}, {("prompt_config", "empty_response"): "!@#$%^&*()"}), + ("empty_response chinese text", {"prompt_config": {"empty_response": "中文测试"}}, {("prompt_config", "empty_response"): "中文测试"}), + ("empty_response integer", {"prompt_config": {"empty_response": 123}}, {("prompt_config", "empty_response"): 123}), + ("empty_response boolean", {"prompt_config": {"empty_response": True}}, {("prompt_config", "empty_response"): True}), + ("empty_response space", {"prompt_config": {"empty_response": " "}}, {("prompt_config", "empty_response"): " "}), + ("prologue plain text", {"prompt_config": {"prologue": "Hello World"}}, {("prompt_config", "prologue"): "Hello World"}), + ("prologue empty string", {"prompt_config": {"prologue": ""}}, {("prompt_config", "prologue"): ""}), + ("prologue punctuation", {"prompt_config": {"prologue": "!@#$%^&*()"}}, {("prompt_config", "prologue"): "!@#$%^&*()"}), + ("prologue chinese text", {"prompt_config": {"prologue": "中文测试"}}, {("prompt_config", "prologue"): "中文测试"}), + ("prologue integer", {"prompt_config": {"prologue": 123}}, {("prompt_config", "prologue"): 123}), + ("prologue boolean", {"prompt_config": {"prologue": True}}, {("prompt_config", "prologue"): True}), + ("prologue space", {"prompt_config": {"prologue": " "}}, {("prompt_config", "prologue"): " "}), + ("quote true", {"prompt_config": {"quote": True}}, {("prompt_config", "quote"): True}), + ("quote false", {"prompt_config": {"quote": False}}, {("prompt_config", "quote"): False}), + ("system prompt with knowledge prefix", {"prompt_config": {"system": "Hello World {knowledge}"}}, {("prompt_config", "system"): "Hello World {knowledge}"}), + ("system prompt only knowledge", {"prompt_config": {"system": "{knowledge}"}}, {("prompt_config", "system"): "{knowledge}"}), + ("system prompt punctuation", {"prompt_config": {"system": "!@#$%^&*() {knowledge}"}}, {("prompt_config", "system"): "!@#$%^&*() {knowledge}"}), + ("system prompt chinese text", {"prompt_config": {"system": "中文测试 {knowledge}"}}, {("prompt_config", "system"): "中文测试 {knowledge}"}), + ("system prompt plain text", {"prompt_config": {"system": "Hello World"}}, {("prompt_config", "system"): "Hello World"}), + ( + "system prompt with explicit empty parameters", + {"prompt_config": {"system": "Hello World", "parameters": []}}, + {("prompt_config", "system"): "Hello World", ("prompt_config", "parameters"): []}, + ), + ("system prompt integer", {"prompt_config": {"system": 123}}, {("prompt_config", "system"): 123}), + ("system prompt boolean", {"prompt_config": {"system": True}}, {("prompt_config", "system"): True}), + ("unknown prompt_config key", {"prompt_config": {"unknown": "unknown"}}, {("prompt_config", "unknown"): "unknown"}), + ] + + for index, (scenario_name, extra_payload, expected_values) in enumerate(cases, start=1): + res = rest_client.post( + "/chats", + json={"name": f"restful_chat_prompt_{index}", "dataset_ids": [], **extra_payload}, + ) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 0, (scenario_name, payload) + for path, expected_value in expected_values.items(): + assert _get_nested(payload["data"], path) == expected_value, (scenario_name, path, payload) + + +@pytest.mark.p2 +def test_chat_create_additional_guards_contract(rest_client, clear_chats): + cases = [ + ("reject tenant_id override", {"tenant_id": "tenant-should-not-pass"}, "`tenant_id` must not be provided."), + ("reject unknown rerank_id", {"rerank_id": "unknown-rerank-model"}, "`rerank_id` unknown-rerank-model doesn't exist"), + ] + + for index, (scenario_name, extra_payload, expected_message) in enumerate(cases, start=1): + res = rest_client.post( + "/chats", + json={"name": f"restful_chat_guard_{index}", "dataset_ids": [], **extra_payload}, + ) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 102, (scenario_name, payload) + assert expected_message in payload["message"], (scenario_name, payload) + + +@pytest.mark.p2 +def test_chat_create_rejects_unparsed_document(rest_client, clear_chats, create_document): + dataset_id, _ = create_document() + res = rest_client.post( + "/chats", + json={"name": "restful_chat_unparsed_document", "dataset_ids": [dataset_id]}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert "doesn't own parsed file" in payload["message"], payload + + +@pytest.mark.p2 +def test_chat_update_name_contract(rest_client, clear_chats): + duplicate_res = rest_client.post("/chats", json={"name": "restful_chat_update_duplicate", "dataset_ids": []}) + assert duplicate_res.status_code == 200 + duplicate_payload = duplicate_res.json() + assert duplicate_payload["code"] == 0, duplicate_payload + + target_res = rest_client.post("/chats", json={"name": "restful_chat_update_name_target", "dataset_ids": []}) + assert target_res.status_code == 200 + target_payload = target_res.json() + assert target_payload["code"] == 0, target_payload + chat_id = target_payload["data"]["id"] + + cases = [ + ("valid name", {"name": "valid_name"}, 0, "", "valid_name"), + ( + "name too long", + {"name": "a" * (CHAT_ASSISTANT_NAME_LIMIT + 1)}, + 102, + f"Chat name length is {CHAT_ASSISTANT_NAME_LIMIT + 1} which is larger than {CHAT_ASSISTANT_NAME_LIMIT}.", + None, + ), + ("name wrong type", {"name": 1}, 102, "Chat name must be a string.", None), + ("name empty", {"name": ""}, 102, "`name` cannot be empty.", None), + ("duplicate lowercase", {"name": "restful_chat_update_duplicate"}, 102, "Duplicated chat name.", None), + ("duplicate uppercase", {"name": "RESTFUL_CHAT_UPDATE_DUPLICATE"}, 102, "Duplicated chat name.", None), + ] + + for scenario_name, patch_payload, expected_code, expected_message, expected_name in cases: + res = rest_client.patch(f"/chats/{chat_id}", json=patch_payload) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == expected_code, (scenario_name, payload) + if expected_code == 0: + get_res = rest_client.get(f"/chats/{chat_id}") + assert get_res.status_code == 200, (scenario_name, get_res.text) + get_payload = get_res.json() + assert get_payload["code"] == 0, (scenario_name, get_payload) + assert get_payload["data"]["name"] == expected_name, (scenario_name, get_payload) + else: + assert payload["message"] == expected_message, (scenario_name, payload) + + +@pytest.mark.p2 +def test_chat_update_dataset_ids_contract(rest_client, clear_chats, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + target_res = rest_client.post("/chats", json={"name": "restful_chat_update_dataset_target", "dataset_ids": []}) + assert target_res.status_code == 200 + target_payload = target_res.json() + assert target_payload["code"] == 0, target_payload + chat_id = target_payload["data"]["id"] + + cases = [ + ("empty dataset_ids", [], 0, "", []), + ("owned parsed dataset", [dataset_id], 0, "", [dataset_id]), + ("invalid dataset id", ["invalid_dataset_id"], 102, "You don't own the dataset invalid_dataset_id", None), + ("dataset_ids wrong type", "invalid_dataset_id", 102, "`dataset_ids` should be a list.", None), + ] + + for scenario_name, dataset_ids, expected_code, expected_message, expected_dataset_ids in cases: + res = rest_client.put( + f"/chats/{chat_id}", + json={"name": "ragflow test", "dataset_ids": dataset_ids}, + ) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == expected_code, (scenario_name, payload) + if expected_code == 0: + get_res = rest_client.get(f"/chats/{chat_id}") + assert get_res.status_code == 200, (scenario_name, get_res.text) + get_payload = get_res.json() + assert get_payload["code"] == 0, (scenario_name, get_payload) + assert get_payload["data"]["name"] == "ragflow test", (scenario_name, get_payload) + assert get_payload["data"]["dataset_ids"] == expected_dataset_ids, (scenario_name, get_payload) + else: + assert payload["message"] == expected_message, (scenario_name, payload) + + +@pytest.mark.p2 +def test_chat_update_avatar_contract(rest_client, clear_chats, ensure_parsed_document, tmp_path): + dataset_id, _ = ensure_parsed_document() + create_res = rest_client.post("/chats", json={"name": "restful_chat_update_avatar_target", "dataset_ids": []}) + assert create_res.status_code == 200 + create_payload = create_res.json() + assert create_payload["code"] == 0, create_payload + chat_id = create_payload["data"]["id"] + + image_path = create_image_file(tmp_path / "restful_chat_update_avatar.png") + encoded_avatar = encode_avatar(image_path) + + res = rest_client.put( + f"/chats/{chat_id}", + json={"name": "avatar_test", "icon": encoded_avatar, "dataset_ids": [dataset_id]}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + + get_res = rest_client.get(f"/chats/{chat_id}") + assert get_res.status_code == 200 + get_payload = get_res.json() + assert get_payload["code"] == 0, get_payload + assert get_payload["data"]["name"] == "avatar_test", get_payload + assert get_payload["data"]["icon"] == encoded_avatar, get_payload + assert get_payload["data"]["dataset_ids"] == [dataset_id], get_payload + + +@pytest.mark.p2 +def test_chat_update_llm_contract(rest_client, clear_chats, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + cases = [ + ("default llm", {}, 0, "", "glm-4-flash@ZHIPU-AI", {}), + ("explicit llm_id", {"llm_id": "glm-4"}, 0, "", "glm-4", {}), + ("unknown llm_id", {"llm_id": "unknown"}, 102, "`llm_id` unknown doesn't exist", None, None), + ("temperature zero", {"llm_setting": {"temperature": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": 0}), + ("temperature one", {"llm_setting": {"temperature": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": 1}), + ("temperature negative one", {"llm_setting": {"temperature": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": -1}), + ("temperature ten", {"llm_setting": {"temperature": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": 10}), + ("temperature string", {"llm_setting": {"temperature": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"temperature": "a"}), + ("top_p zero", {"llm_setting": {"top_p": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": 0}), + ("top_p one", {"llm_setting": {"top_p": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": 1}), + ("top_p negative one", {"llm_setting": {"top_p": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": -1}), + ("top_p ten", {"llm_setting": {"top_p": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": 10}), + ("top_p string", {"llm_setting": {"top_p": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"top_p": "a"}), + ("presence_penalty zero", {"llm_setting": {"presence_penalty": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": 0}), + ("presence_penalty one", {"llm_setting": {"presence_penalty": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": 1}), + ("presence_penalty negative one", {"llm_setting": {"presence_penalty": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": -1}), + ("presence_penalty ten", {"llm_setting": {"presence_penalty": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": 10}), + ("presence_penalty string", {"llm_setting": {"presence_penalty": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"presence_penalty": "a"}), + ("frequency_penalty zero", {"llm_setting": {"frequency_penalty": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": 0}), + ("frequency_penalty one", {"llm_setting": {"frequency_penalty": 1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": 1}), + ("frequency_penalty negative one", {"llm_setting": {"frequency_penalty": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": -1}), + ("frequency_penalty ten", {"llm_setting": {"frequency_penalty": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": 10}), + ("frequency_penalty string", {"llm_setting": {"frequency_penalty": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"frequency_penalty": "a"}), + ("max_token zero", {"llm_setting": {"max_token": 0}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": 0}), + ("max_token 1024", {"llm_setting": {"max_token": 1024}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": 1024}), + ("max_token negative one", {"llm_setting": {"max_token": -1}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": -1}), + ("max_token ten", {"llm_setting": {"max_token": 10}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": 10}), + ("max_token string", {"llm_setting": {"max_token": "a"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"max_token": "a"}), + ("unknown llm setting key", {"llm_setting": {"unknown": "unknown"}}, 0, "", "glm-4-flash@ZHIPU-AI", {"unknown": "unknown"}), + ] + + for index, (scenario_name, extra_payload, expected_code, expected_message, expected_llm_id, expected_llm_setting) in enumerate(cases, start=1): + create_res = rest_client.post( + "/chats", + json={"name": f"restful_chat_update_llm_target_{index}", "dataset_ids": [dataset_id]}, + ) + assert create_res.status_code == 200, (scenario_name, create_res.text) + create_payload = create_res.json() + assert create_payload["code"] == 0, (scenario_name, create_payload) + chat_id = create_payload["data"]["id"] + + updated_name = f"llm_test_{index}" + payload = {"name": updated_name, "dataset_ids": [dataset_id]} + payload.update(extra_payload) + res = rest_client.put(f"/chats/{chat_id}", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code == 0: + get_res = rest_client.get(f"/chats/{chat_id}") + assert get_res.status_code == 200, (scenario_name, get_res.text) + get_payload = get_res.json() + assert get_payload["code"] == 0, (scenario_name, get_payload) + assert get_payload["data"]["name"] == updated_name, (scenario_name, get_payload) + assert get_payload["data"]["llm_id"] == expected_llm_id, (scenario_name, get_payload) + assert get_payload["data"]["llm_setting"] == expected_llm_setting, (scenario_name, get_payload) + else: + assert body["message"] == expected_message, (scenario_name, body) + + +@pytest.mark.p2 +def test_chat_update_prompt_contract(rest_client, clear_chats, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + cases = [ + ( + "default prompt config", + {}, + { + ("similarity_threshold",): 0.1, + ("vector_similarity_weight",): 0.3, + ("top_n",): 6, + ("prompt_config", "parameters"): [{"key": "knowledge", "optional": False}], + ("prompt_config", "empty_response"): DEFAULT_CHAT_EMPTY_RESPONSE, + ("prompt_config", "prologue"): DEFAULT_CHAT_PROLOGUE, + ("prompt_config", "quote"): True, + ("prompt_config", "system"): DEFAULT_CHAT_SYSTEM_PROMPT, + }, + ), + ("similarity_threshold zero", {"similarity_threshold": 0}, {("similarity_threshold",): 0}), + ("similarity_threshold one", {"similarity_threshold": 1}, {("similarity_threshold",): 1}), + ("similarity_threshold negative one", {"similarity_threshold": -1}, {("similarity_threshold",): -1.0}), + ("similarity_threshold ten", {"similarity_threshold": 10}, {("similarity_threshold",): 10.0}), + ("similarity_threshold string", {"similarity_threshold": "a"}, {("similarity_threshold",): 0.0}), + ("vector_similarity_weight zero", {"vector_similarity_weight": 0}, {("vector_similarity_weight",): 0}), + ("vector_similarity_weight one", {"vector_similarity_weight": 1}, {("vector_similarity_weight",): 1}), + ("vector_similarity_weight negative one", {"vector_similarity_weight": -1}, {("vector_similarity_weight",): -1.0}), + ("vector_similarity_weight ten", {"vector_similarity_weight": 10}, {("vector_similarity_weight",): 10.0}), + ("vector_similarity_weight string", {"vector_similarity_weight": "a"}, {("vector_similarity_weight",): 0.0}), + ("empty prompt parameters", {"prompt_config": {"parameters": []}}, {("prompt_config", "parameters"): []}), + ("top_n zero", {"top_n": 0}, {("top_n",): 0}), + ("top_n one", {"top_n": 1}, {("top_n",): 1}), + ("top_n negative one", {"top_n": -1}, {("top_n",): -1}), + ("top_n ten", {"top_n": 10}, {("top_n",): 10}), + ("top_n string", {"top_n": "a"}, {("top_n",): 0}), + ("empty_response plain text", {"prompt_config": {"empty_response": "Hello World"}}, {("prompt_config", "empty_response"): "Hello World"}), + ("empty_response empty string", {"prompt_config": {"empty_response": ""}}, {("prompt_config", "empty_response"): ""}), + ("empty_response punctuation", {"prompt_config": {"empty_response": "!@#$%^&*()"}}, {("prompt_config", "empty_response"): "!@#$%^&*()"}), + ("empty_response chinese text", {"prompt_config": {"empty_response": "中文测试"}}, {("prompt_config", "empty_response"): "中文测试"}), + ("empty_response integer", {"prompt_config": {"empty_response": 123}}, {("prompt_config", "empty_response"): 123}), + ("empty_response boolean", {"prompt_config": {"empty_response": True}}, {("prompt_config", "empty_response"): True}), + ("empty_response space", {"prompt_config": {"empty_response": " "}}, {("prompt_config", "empty_response"): " "}), + ("prologue plain text", {"prompt_config": {"prologue": "Hello World"}}, {("prompt_config", "prologue"): "Hello World"}), + ("prologue empty string", {"prompt_config": {"prologue": ""}}, {("prompt_config", "prologue"): ""}), + ("prologue punctuation", {"prompt_config": {"prologue": "!@#$%^&*()"}}, {("prompt_config", "prologue"): "!@#$%^&*()"}), + ("prologue chinese text", {"prompt_config": {"prologue": "中文测试"}}, {("prompt_config", "prologue"): "中文测试"}), + ("prologue integer", {"prompt_config": {"prologue": 123}}, {("prompt_config", "prologue"): 123}), + ("prologue boolean", {"prompt_config": {"prologue": True}}, {("prompt_config", "prologue"): True}), + ("prologue space", {"prompt_config": {"prologue": " "}}, {("prompt_config", "prologue"): " "}), + ("quote true", {"prompt_config": {"quote": True}}, {("prompt_config", "quote"): True}), + ("quote false", {"prompt_config": {"quote": False}}, {("prompt_config", "quote"): False}), + ("system prompt with knowledge prefix", {"prompt_config": {"system": "Hello World {knowledge}"}}, {("prompt_config", "system"): "Hello World {knowledge}"}), + ("system prompt only knowledge", {"prompt_config": {"system": "{knowledge}"}}, {("prompt_config", "system"): "{knowledge}"}), + ("system prompt punctuation", {"prompt_config": {"system": "!@#$%^&*() {knowledge}"}}, {("prompt_config", "system"): "!@#$%^&*() {knowledge}"}), + ("system prompt chinese text", {"prompt_config": {"system": "中文测试 {knowledge}"}}, {("prompt_config", "system"): "中文测试 {knowledge}"}), + ("system prompt plain text", {"prompt_config": {"system": "Hello World"}}, {("prompt_config", "system"): "Hello World"}), + ( + "system prompt with explicit empty parameters", + {"prompt_config": {"system": "Hello World", "parameters": []}}, + {("prompt_config", "system"): "Hello World", ("prompt_config", "parameters"): []}, + ), + ("system prompt integer", {"prompt_config": {"system": 123}}, {("prompt_config", "system"): 123}), + ("system prompt boolean", {"prompt_config": {"system": True}}, {("prompt_config", "system"): True}), + ("unknown prompt key", {"unknown": "unknown"}, {}), + ] + + for index, (scenario_name, extra_payload, expected_values) in enumerate(cases, start=1): + create_res = rest_client.post( + "/chats", + json={"name": f"restful_chat_update_prompt_target_{index}", "dataset_ids": [dataset_id]}, + ) + assert create_res.status_code == 200, (scenario_name, create_res.text) + create_payload = create_res.json() + assert create_payload["code"] == 0, (scenario_name, create_payload) + chat_id = create_payload["data"]["id"] + + updated_name = f"prompt_test_{index}" + res = rest_client.put( + f"/chats/{chat_id}", + json={"name": updated_name, "dataset_ids": [dataset_id], **extra_payload}, + ) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 0, (scenario_name, payload) + + get_res = rest_client.get(f"/chats/{chat_id}") + assert get_res.status_code == 200, (scenario_name, get_res.text) + get_payload = get_res.json() + assert get_payload["code"] == 0, (scenario_name, get_payload) + assert get_payload["data"]["name"] == updated_name, (scenario_name, get_payload) + assert get_payload["data"]["dataset_ids"] == [dataset_id], (scenario_name, get_payload) + for path, expected_value in expected_values.items(): + assert _get_nested(get_payload["data"], path) == expected_value, (scenario_name, path, get_payload) + + +@pytest.mark.p2 +def test_chat_update_mapping_and_validation_branches_p2(rest_client, clear_chats): + duplicate_res = rest_client.post("/chats", json={"name": "restful_chat_update_mapping_duplicate", "dataset_ids": []}) + assert duplicate_res.status_code == 200 + duplicate_payload = duplicate_res.json() + assert duplicate_payload["code"] == 0, duplicate_payload + + target_res = rest_client.post("/chats", json={"name": "restful_chat_update_mapping_target", "dataset_ids": []}) + assert target_res.status_code == 200 + target_payload = target_res.json() + assert target_payload["code"] == 0, target_payload + chat_id = target_payload["data"]["id"] + + unauthorized = rest_client.patch("/chats/invalid-chat-id", json={"name": "anything"}) + assert unauthorized.status_code == 200 + unauthorized_payload = unauthorized.json() + assert unauthorized_payload["code"] == 109, unauthorized_payload + assert unauthorized_payload["message"] == "No authorization.", unauthorized_payload + + quote_res = rest_client.patch(f"/chats/{chat_id}", json={"prompt_config": {"quote": False}}) + assert quote_res.status_code == 200 + quote_payload = quote_res.json() + assert quote_payload["code"] == 0, quote_payload + assert quote_payload["data"]["prompt_config"]["quote"] is False, quote_payload + + invalid_llm_res = rest_client.patch( + f"/chats/{chat_id}", + json={"llm_id": "unknown-llm-model", "llm_setting": {"model_type": "chat"}}, + ) + assert invalid_llm_res.status_code == 200 + invalid_llm_payload = invalid_llm_res.json() + assert invalid_llm_payload["code"] == 102, invalid_llm_payload + assert "`llm_id` unknown-llm-model doesn't exist" in invalid_llm_payload["message"], invalid_llm_payload + + invalid_rerank_res = rest_client.patch(f"/chats/{chat_id}", json={"rerank_id": "unknown-rerank-model"}) + assert invalid_rerank_res.status_code == 200 + invalid_rerank_payload = invalid_rerank_res.json() + assert invalid_rerank_payload["code"] == 102, invalid_rerank_payload + assert "`rerank_id` unknown-rerank-model doesn't exist" in invalid_rerank_payload["message"], invalid_rerank_payload + + empty_name_res = rest_client.patch(f"/chats/{chat_id}", json={"name": ""}) + assert empty_name_res.status_code == 200 + empty_name_payload = empty_name_res.json() + assert empty_name_payload["code"] == 102, empty_name_payload + assert empty_name_payload["message"] == "`name` cannot be empty.", empty_name_payload + + duplicate_name_res = rest_client.patch(f"/chats/{chat_id}", json={"name": "restful_chat_update_mapping_duplicate"}) + assert duplicate_name_res.status_code == 200 + duplicate_name_payload = duplicate_name_res.json() + assert duplicate_name_payload["code"] == 102, duplicate_name_payload + assert duplicate_name_payload["message"] == "Duplicated chat name.", duplicate_name_payload + + prompt_without_placeholder_res = rest_client.patch( + f"/chats/{chat_id}", + json={"prompt_config": {"system": "No required placeholder", "parameters": [{"key": "knowledge", "optional": False}]}}, + ) + assert prompt_without_placeholder_res.status_code == 200 + prompt_without_placeholder_payload = prompt_without_placeholder_res.json() + assert prompt_without_placeholder_payload["code"] == 0, prompt_without_placeholder_payload + + icon_res = rest_client.patch(f"/chats/{chat_id}", json={"icon": "raw-avatar-value"}) + assert icon_res.status_code == 200 + icon_payload = icon_res.json() + assert icon_payload["code"] == 0, icon_payload + + get_res = rest_client.get(f"/chats/{chat_id}") + assert get_res.status_code == 200 + get_payload = get_res.json() + assert get_payload["code"] == 0, get_payload + assert get_payload["data"]["prompt_config"]["system"] == "No required placeholder", get_payload + assert get_payload["data"]["prompt_config"]["parameters"] == [{"key": "knowledge", "optional": False}], get_payload + assert get_payload["data"]["icon"] == "raw-avatar-value", get_payload + + +@pytest.mark.p2 +def test_chat_update_rejects_unparsed_document(rest_client, clear_chats, create_document): + dataset_id, _ = create_document() + create_res = rest_client.post("/chats", json={"name": "restful_chat_update_unparsed_target", "dataset_ids": []}) + assert create_res.status_code == 200 + create_payload = create_res.json() + assert create_payload["code"] == 0, create_payload + chat_id = create_payload["data"]["id"] + + res = rest_client.patch(f"/chats/{chat_id}", json={"dataset_ids": [dataset_id]}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert "doesn't own parsed file" in payload["message"], payload diff --git a/test/testcases/restful_api/test_chunks.py b/test/testcases/restful_api/test_chunks.py new file mode 100644 index 00000000000..e2ed7b48c87 --- /dev/null +++ b/test/testcases/restful_api/test_chunks.py @@ -0,0 +1,817 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from concurrent.futures import ThreadPoolExecutor +import os +import pytest +from test.testcases.configs import INVALID_API_TOKEN, INVALID_ID_32 +from test.testcases.restful_api.helpers.client import RestClient +from test.testcases.utils import wait_for + + +def _assert_created_chunk_id(payload): + chunk_id = payload["data"]["chunk"].get("id") + assert chunk_id, payload + assert isinstance(chunk_id, str), payload + assert chunk_id.strip(), payload + return chunk_id + + +@wait_for(10, 1, "Chunk indexing timeout in RESTful batch 09 tests") +def _chunk_count(rest_client, base_path, expected_total): + res = rest_client.get(base_path) + if res.status_code != 200: + return False + payload = res.json() + if payload["code"] != 0: + return False + return payload["data"]["total"] == expected_total and len(payload["data"]["chunks"]) == min(expected_total, 30) + + +def _reset_chunk_batch(rest_client, base_path, count=4): + cleanup_res = rest_client.delete(base_path, json={"chunk_ids": None, "delete_all": True}) + assert cleanup_res.status_code == 200, cleanup_res.text + cleanup_payload = cleanup_res.json() + assert cleanup_payload["code"] == 0, cleanup_payload + + baseline_res = rest_client.post(base_path, json={"content": "ragflow test upload"}) + assert baseline_res.status_code == 200, baseline_res.text + baseline_payload = baseline_res.json() + assert baseline_payload["code"] == 0, baseline_payload + baseline_id = _assert_created_chunk_id(baseline_payload) + + chunk_ids = [] + for index in range(count): + res = rest_client.post(base_path, json={"content": f"chunk test {index}"}) + assert res.status_code == 200, (index, res.text) + payload = res.json() + assert payload["code"] == 0, (index, payload) + chunk_ids.append(_assert_created_chunk_id(payload)) + + _chunk_count(rest_client, base_path, count + 1) + return baseline_id, chunk_ids + + +@pytest.mark.p1 +def test_chunks_add_list_get_update_delete_cycle(rest_client, create_document): + dataset_id, document_id = create_document("chunk_cycle.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + + add_res = rest_client.post( + base_path, + json={"content": "batch2 chunk content", "important_keywords": ["batch2"], "questions": ["what is batch2?"]}, + ) + assert add_res.status_code == 200 + add_payload = add_res.json() + assert add_payload["code"] == 0, add_payload + chunk_id = _assert_created_chunk_id(add_payload) + + list_res = rest_client.get(base_path, params={"id": chunk_id}) + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert list_payload["data"]["total"] == 1, list_payload + assert list_payload["data"]["chunks"][0]["id"] == chunk_id, list_payload + + get_res = rest_client.get(f"{base_path}/{chunk_id}") + assert get_res.status_code == 200 + get_payload = get_res.json() + assert get_payload["code"] == 0, get_payload + assert get_payload["data"]["id"] == chunk_id, get_payload + + update_res = rest_client.patch( + f"{base_path}/{chunk_id}", + json={"content": "batch2 chunk content updated"}, + ) + assert update_res.status_code == 200 + update_payload = update_res.json() + assert update_payload["code"] == 0, update_payload + + get_updated_res = rest_client.get(f"{base_path}/{chunk_id}") + assert get_updated_res.status_code == 200 + get_updated_payload = get_updated_res.json() + assert get_updated_payload["code"] == 0, get_updated_payload + assert get_updated_payload["data"]["content_with_weight"] == "batch2 chunk content updated", get_updated_payload + + delete_candidate_res = rest_client.post(base_path, json={"content": "batch2 chunk content to delete"}) + assert delete_candidate_res.status_code == 200 + delete_candidate_payload = delete_candidate_res.json() + assert delete_candidate_payload["code"] == 0, delete_candidate_payload + delete_candidate_id = _assert_created_chunk_id(delete_candidate_payload) + + delete_res = rest_client.delete(base_path, json={"chunk_ids": [delete_candidate_id]}) + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + + deleted_list_res = rest_client.get(base_path, params={"id": delete_candidate_id}) + assert deleted_list_res.status_code == 200 + deleted_list_payload = deleted_list_res.json() + assert deleted_list_payload["code"] != 0, deleted_list_payload + + deleted_get_res = rest_client.get(f"{base_path}/{delete_candidate_id}") + assert deleted_get_res.status_code == 200 + deleted_get_payload = deleted_get_res.json() + assert deleted_get_payload["code"] != 0, deleted_get_payload + + +@pytest.mark.p1 +def test_chunk_add_requires_auth(create_document): + dataset_id, document_id = create_document("chunk_add_auth.txt") + path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))): + res = client.post(path, json={"content": "chunk test"}) + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 401, (scenario_name, payload) + assert payload["message"] == "", (scenario_name, payload) + + +@pytest.mark.p1 +def test_chunk_delete_requires_auth(create_document): + dataset_id, document_id = create_document("chunk_delete_auth.txt") + path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))): + res = client.delete(path, json={"chunk_ids": []}) + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 401, (scenario_name, payload) + assert payload["message"] == "", (scenario_name, payload) + + +@pytest.mark.p1 +def test_chunk_list_requires_auth(create_document): + dataset_id, document_id = create_document("chunk_list_auth.txt") + path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))): + res = client.get(path) + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 401, (scenario_name, payload) + assert payload["message"] == "", (scenario_name, payload) + + +@pytest.mark.p2 +def test_chunks_add_requires_content(rest_client, create_document): + dataset_id, document_id = create_document("chunk_requires_content.txt") + res = rest_client.post( + f"/datasets/{dataset_id}/documents/{document_id}/chunks", + json={"content": " "}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert payload["message"] == "`content` is required", payload + + +@pytest.mark.p2 +def test_chunk_add_keyword_question_and_tag_contract(rest_client, create_document): + add_cases = [ + ( + "important keywords", + [ + ({"content": "chunk test", "important_keywords": ["a", "b", "c"]}, 0, ""), + ({"content": "chunk test", "important_keywords": [""]}, 0, ""), + ({"content": "chunk test", "important_keywords": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"), + ({"content": "chunk test", "important_keywords": ["a", "a"]}, 0, ""), + ({"content": "chunk test", "important_keywords": "abc"}, 102, "`important_keywords` is required to be a list"), + ({"content": "chunk test", "important_keywords": 123}, 102, "`important_keywords` is required to be a list"), + ], + ), + ( + "questions", + [ + ({"content": "chunk test", "questions": ["a", "b", "c"]}, 0, ""), + ({"content": "chunk test", "questions": [""]}, 0, ""), + ({"content": "chunk test", "questions": [1]}, 100, "TypeError('sequence item 0: expected str instance, int found')"), + ({"content": "chunk test", "questions": ["a", "a"]}, 0, ""), + ({"content": "chunk test", "questions": "abc"}, 102, "`questions` is required to be a list"), + ({"content": "chunk test", "questions": 123}, 102, "`questions` is required to be a list"), + ], + ), + ( + "tag_kwd", + [ + ({"content": "chunk test", "tag_kwd": ["tag1", "tag2"]}, 0, ""), + ({"content": "chunk test", "tag_kwd": [""]}, 0, ""), + ({"content": "chunk test", "tag_kwd": [1]}, 102, "`tag_kwd` must be a list of strings"), + ({"content": "chunk test", "tag_kwd": ["tag", "tag"]}, 0, ""), + ({"content": "chunk test", "tag_kwd": "abc"}, 102, "`tag_kwd` is required to be a list"), + ({"content": "chunk test", "tag_kwd": 123}, 102, "`tag_kwd` is required to be a list"), + ], + ), + ] + + for group_index, (group_name, cases) in enumerate(add_cases): + dataset_id, document_id = create_document(f"chunk_add_contracts_{group_index}.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + for scenario_index, (payload, expected_code, expected_message) in enumerate(cases): + scenario_name = f"{group_name}-{scenario_index}" + before_payload = rest_client.get(base_path).json() + assert before_payload["code"] == 0, (scenario_name, before_payload) + before_total = before_payload["data"]["doc"]["chunk_count"] + + res = rest_client.post(base_path, json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code == 0: + chunk = body["data"]["chunk"] + assert chunk["dataset_id"] == dataset_id, (scenario_name, body) + assert chunk["document_id"] == document_id, (scenario_name, body) + assert chunk["content"] == payload["content"], (scenario_name, body) + if "important_keywords" in payload: + assert chunk["important_keywords"] == payload["important_keywords"], (scenario_name, body) + if "questions" in payload: + assert chunk["questions"] == [str(q).strip() for q in payload["questions"] if str(q).strip()], (scenario_name, body) + if "tag_kwd" in payload: + assert chunk["tag_kwd"] == payload["tag_kwd"], (scenario_name, body) + after_payload = rest_client.get(base_path).json() + assert after_payload["code"] == 0, (scenario_name, after_payload) + assert after_payload["data"]["doc"]["chunk_count"] == before_total + 1, (scenario_name, after_payload) + else: + assert body["message"] == expected_message, (scenario_name, body) + + +@pytest.mark.p2 +def test_chunk_add_invalid_dataset_and_document_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_invalid_targets.txt") + + invalid_dataset_res = rest_client.post( + f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks", + json={"content": "chunk test"}, + ) + assert invalid_dataset_res.status_code == 200 + invalid_dataset_payload = invalid_dataset_res.json() + assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload + assert invalid_dataset_payload["message"] == f"You don't own the dataset {INVALID_ID_32}.", invalid_dataset_payload + + invalid_document_res = rest_client.post( + f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks", + json={"content": "chunk test"}, + ) + assert invalid_document_res.status_code == 200 + invalid_document_payload = invalid_document_res.json() + assert invalid_document_payload["code"] == 102, invalid_document_payload + assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload + + +@pytest.mark.p2 +def test_chunk_add_repeated_and_deleted_document_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_repeat_deleted.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + + first_payload = rest_client.get(base_path).json() + assert first_payload["code"] == 0, first_payload + initial_count = first_payload["data"]["doc"]["chunk_count"] + + first_add_res = rest_client.post(base_path, json={"content": "chunk test"}) + second_add_res = rest_client.post(base_path, json={"content": "chunk test"}) + first_add_payload = first_add_res.json() + second_add_payload = second_add_res.json() + assert first_add_payload["code"] == 0, first_add_payload + assert second_add_payload["code"] == 0, second_add_payload + assert first_add_payload["data"]["chunk"]["id"] == second_add_payload["data"]["chunk"]["id"], (first_add_payload, second_add_payload) + + repeated_list_payload = rest_client.get(base_path).json() + assert repeated_list_payload["code"] == 0, repeated_list_payload + assert repeated_list_payload["data"]["doc"]["chunk_count"] == initial_count + 2, repeated_list_payload + assert repeated_list_payload["data"]["total"] == 1, repeated_list_payload + + delete_document_res = rest_client.delete(f"/datasets/{dataset_id}/documents", json={"ids": [document_id]}) + assert delete_document_res.status_code == 200 + delete_document_payload = delete_document_res.json() + assert delete_document_payload["code"] == 0, delete_document_payload + + add_after_delete_res = rest_client.post(base_path, json={"content": "chunk test"}) + assert add_after_delete_res.status_code == 200 + add_after_delete_payload = add_after_delete_res.json() + assert add_after_delete_payload["code"] == 102, add_after_delete_payload + assert add_after_delete_payload["message"] == f"You don't own the document {document_id}.", add_after_delete_payload + + +@pytest.mark.p2 +@pytest.mark.parametrize("count", [20]) +def test_chunk_concurrent_add_contract(rest_client, create_document, count): + dataset_id, document_id = create_document("chunk_concurrent_add.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + baseline_payload = rest_client.get(base_path).json() + assert baseline_payload["code"] == 0, baseline_payload + initial_count = baseline_payload["data"]["doc"]["chunk_count"] + + with ThreadPoolExecutor(max_workers=5) as executor: + results = list( + executor.map( + lambda index: rest_client.post(base_path, json={"content": f"chunk test {index}"}).json(), + range(count), + ) + ) + assert len(results) == count, results + assert all(result["code"] == 0 for result in results), results + + final_payload = rest_client.get(base_path).json() + assert final_payload["code"] == 0, final_payload + assert final_payload["data"]["doc"]["chunk_count"] == initial_count + count, final_payload + + +@pytest.mark.p2 +def test_chunks_list_empty_document(rest_client, create_document): + dataset_id, document_id = create_document("chunk_list_empty.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + list_res = rest_client.get(base_path) + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert "chunks" in list_payload["data"], list_payload + assert "doc" in list_payload["data"], list_payload + + +@pytest.mark.p2 +def test_chunk_delete_basic_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_delete_basic.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + + cases = [ + ("none payload", None, 0, "", 5), + ("invalid only", {"chunk_ids": ["invalid_id"]}, 102, "rm_chunk deleted chunks 0, expect 1", 5), + ("delete first", lambda ids: {"chunk_ids": ids[:1]}, 0, "", 4), + ("delete generated", lambda ids: {"chunk_ids": ids}, 0, "", 1), + ("empty ids", {"chunk_ids": []}, 0, "", 5), + ] + + for scenario_name, payload, expected_code, expected_message, remaining in cases: + _reset_chunk_batch(rest_client, base_path) + request_body = payload + generated_ids = rest_client.get(base_path).json()["data"]["chunks"][1:] + generated_ids = [chunk["id"] for chunk in generated_ids] + if callable(payload): + request_body = payload(generated_ids) + res = rest_client.delete(base_path, json=request_body) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_message: + assert body.get("message", "") == expected_message, (scenario_name, body) + + list_payload = rest_client.get(base_path).json() + assert list_payload["code"] == 0, (scenario_name, list_payload) + assert len(list_payload["data"]["chunks"]) == remaining, (scenario_name, list_payload) + assert list_payload["data"]["total"] == remaining, (scenario_name, list_payload) + + +@pytest.mark.p2 +def test_chunk_delete_partial_duplicate_repeat_and_invalid_target_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_delete_detail.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + + for scenario_name, payload_builder in ( + ("invalid first", lambda ids: {"chunk_ids": ["invalid_id"] + ids}), + ("invalid middle", lambda ids: {"chunk_ids": ids[:1] + ["invalid_id"] + ids[1:]}), + ("invalid last", lambda ids: {"chunk_ids": ids + ["invalid_id"]}), + ): + _, generated_ids = _reset_chunk_batch(rest_client, base_path) + res = rest_client.delete(base_path, json=payload_builder(generated_ids)) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == 102, (scenario_name, body) + assert body["message"] == "rm_chunk deleted chunks 4, expect 5", (scenario_name, body) + + list_payload = rest_client.get(base_path).json() + assert list_payload["code"] == 0, (scenario_name, list_payload) + assert list_payload["data"]["total"] == 1, (scenario_name, list_payload) + + _, generated_ids = _reset_chunk_batch(rest_client, base_path) + duplicate_res = rest_client.delete(base_path, json={"chunk_ids": generated_ids * 2}) + assert duplicate_res.status_code == 200 + duplicate_payload = duplicate_res.json() + assert duplicate_payload["code"] == 0, duplicate_payload + assert duplicate_payload["data"]["success_count"] == 4, duplicate_payload + assert len(duplicate_payload["data"]["errors"]) == 4, duplicate_payload + assert all(error.startswith("Duplicate chunk ids: ") for error in duplicate_payload["data"]["errors"]), duplicate_payload + duplicate_list_payload = rest_client.get(base_path).json() + assert duplicate_list_payload["code"] == 0, duplicate_list_payload + assert duplicate_list_payload["data"]["total"] == 1, duplicate_list_payload + + _, generated_ids = _reset_chunk_batch(rest_client, base_path) + first_delete_res = rest_client.delete(base_path, json={"chunk_ids": generated_ids}) + assert first_delete_res.status_code == 200 + assert first_delete_res.json()["code"] == 0 + second_delete_res = rest_client.delete(base_path, json={"chunk_ids": generated_ids}) + assert second_delete_res.status_code == 200 + second_delete_payload = second_delete_res.json() + assert second_delete_payload["code"] == 102, second_delete_payload + assert second_delete_payload["message"] == "rm_chunk deleted chunks 0, expect 4", second_delete_payload + + invalid_dataset_res = rest_client.delete( + f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks", + json={"chunk_ids": ["chunk-id"]}, + ) + assert invalid_dataset_res.status_code == 200 + invalid_dataset_payload = invalid_dataset_res.json() + assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload + assert invalid_dataset_payload["message"] == f"You don't own the dataset {INVALID_ID_32}.", invalid_dataset_payload + + invalid_document_res = rest_client.delete( + f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks", + json={"chunk_ids": ["chunk-id"]}, + ) + assert invalid_document_res.status_code == 200 + invalid_document_payload = invalid_document_res.json() + assert invalid_document_payload["code"] == 102, invalid_document_payload + assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload + + +@pytest.mark.p2 +def test_chunk_delete_web_legacy_basic_variants(rest_client, create_document): + dataset_id, document_id = create_document("chunk_delete_web_legacy_again.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + cases = [ + ("web invalid id", {"chunk_ids": ["invalid_id"]}, 102, 5), + ("web delete first", lambda ids: {"chunk_ids": ids[:1]}, 0, 4), + ("web delete generated", lambda ids: {"chunk_ids": ids}, 0, 1), + ("web empty ids", {"chunk_ids": []}, 0, 5), + ] + for scenario_name, payload, expected_code, remaining in cases: + _, generated_ids = _reset_chunk_batch(rest_client, base_path) + request_body = payload(generated_ids) if callable(payload) else payload + res = rest_client.delete(base_path, json=request_body) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + list_payload = rest_client.get(base_path).json() + assert list_payload["code"] == 0, (scenario_name, list_payload) + assert list_payload["data"]["total"] == remaining, (scenario_name, list_payload) + + +@pytest.mark.p2 +def test_chunk_delete_concurrent_and_bulk_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_delete_bulk_contract.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + + rest_client.delete(base_path, json={"chunk_ids": None, "delete_all": True}) + for index in range(12): + payload = rest_client.post(base_path, json={"content": f"chunk test {index}"}).json() + assert payload["code"] == 0, payload + ids_payload = rest_client.get(base_path).json() + assert ids_payload["code"] == 0, ids_payload + chunk_ids = [chunk["id"] for chunk in ids_payload["data"]["chunks"]] + + with ThreadPoolExecutor(max_workers=5) as executor: + results = list(executor.map(lambda chunk_id: rest_client.delete(base_path, json={"chunk_ids": [chunk_id]}).json(), chunk_ids)) + assert len(results) == len(chunk_ids), results + assert all(result["code"] == 0 for result in results), results + + final_payload = rest_client.get(base_path).json() + assert final_payload["code"] == 0, final_payload + assert final_payload["data"]["total"] == 0, final_payload + + rest_client.delete(base_path, json={"chunk_ids": None, "delete_all": True}) + for index in range(40): + payload = rest_client.post(base_path, json={"content": f"bulk chunk {index}"}).json() + assert payload["code"] == 0, payload + bulk_ids_payload = rest_client.get(base_path, params={"page_size": 200}).json() + assert bulk_ids_payload["code"] == 0, bulk_ids_payload + bulk_ids = [chunk["id"] for chunk in bulk_ids_payload["data"]["chunks"]] + bulk_res = rest_client.delete(base_path, json={"chunk_ids": bulk_ids}) + assert bulk_res.status_code == 200 + bulk_payload = bulk_res.json() + assert bulk_payload["code"] == 0, bulk_payload + + +@pytest.mark.p2 +def test_chunk_list_default_get_id_and_invalid_target_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_list_core.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + baseline_id, generated_ids = _reset_chunk_batch(rest_client, base_path) + + default_res = rest_client.get(base_path) + assert default_res.status_code == 200 + default_payload = default_res.json() + assert default_payload["code"] == 0, default_payload + assert default_payload["data"]["total"] == 5, default_payload + assert len(default_payload["data"]["chunks"]) == 5, default_payload + + get_res = rest_client.get(f"{base_path}/{generated_ids[0]}") + assert get_res.status_code == 200 + get_payload = get_res.json() + assert get_payload["code"] == 0, get_payload + assert get_payload["data"]["id"] == generated_ids[0], get_payload + assert get_payload["data"]["doc_id"] == document_id, get_payload + + invalid_get_res = rest_client.get(f"{base_path}/unknown") + assert invalid_get_res.status_code == 200 + invalid_get_payload = invalid_get_res.json() + assert invalid_get_payload["code"] == 102, invalid_get_payload + assert invalid_get_payload["message"] == "Chunk not found!", invalid_get_payload + + id_cases = [ + ("id none", {"id": None}, 0, 5, None), + ("id empty", {"id": ""}, 0, 5, None), + ("id valid", {"id": generated_ids[0]}, 0, 1, generated_ids[0]), + ("id invalid", {"id": "unknown"}, 102, None, None), + ] + for scenario_name, params, expected_code, expected_total, expected_id in id_cases: + res = rest_client.get(base_path, params=params) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == expected_code, (scenario_name, payload) + if expected_code == 0: + assert payload["data"]["total"] == expected_total, (scenario_name, payload) + if expected_id is not None: + assert payload["data"]["chunks"][0]["id"] == expected_id, (scenario_name, payload) + else: + assert payload["message"] == f"Chunk not found: {dataset_id}/unknown", (scenario_name, payload) + + invalid_dataset_res = rest_client.get(f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks") + assert invalid_dataset_res.status_code == 200 + invalid_dataset_payload = invalid_dataset_res.json() + assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload + assert invalid_dataset_payload["message"] == f"You don't own the dataset {INVALID_ID_32}.", invalid_dataset_payload + + invalid_document_res = rest_client.get(f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks") + assert invalid_document_res.status_code == 200 + invalid_document_payload = invalid_document_res.json() + assert invalid_document_payload["code"] == 102, invalid_document_payload + assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload + + +@pytest.mark.p2 +@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="infinity") +def test_chunk_list_keyword_and_invalid_param_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_list_keywords.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + _reset_chunk_batch(rest_client, base_path) + + cases = [ + ("keywords none", {"keywords": None}, 5), + ("keywords empty", {"keywords": ""}, 5), + ("keywords exact one", {"keywords": "1"}, 1), + ("keywords chunk", {"keywords": "chunk"}, 4), + ("keywords ragflow", {"keywords": "ragflow"}, 1), + ("keywords unknown", {"keywords": "unknown"}, 0), + ("invalid params ignored", {"a": "b"}, 5), + ] + + for scenario_name, params, expected_total in cases: + res = rest_client.get(base_path, params=params) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 0, (scenario_name, payload) + assert payload["data"]["total"] == expected_total, (scenario_name, payload) + assert len(payload["data"]["chunks"]) == expected_total, (scenario_name, payload) + + +@pytest.mark.p2 +@pytest.mark.skipif(os.getenv("DOC_ENGINE") == "infinity", reason="infinity") +def test_chunk_list_page_and_page_size_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_list_paging.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + _reset_chunk_batch(rest_client, base_path) + + cases = [ + ("page none", {"page": None, "page_size": 2}, 0, 2, ""), + ("page zero", {"page": 0, "page_size": 2}, 100, None, "ValueError('Search does not support negative slicing.')"), + ("page two", {"page": 2, "page_size": 2}, 0, 2, ""), + ("page three", {"page": 3, "page_size": 2}, 0, 1, ""), + ("page string", {"page": "3", "page_size": 2}, 0, 1, ""), + ("page negative", {"page": -1, "page_size": 2}, 100, None, "ValueError('Search does not support negative slicing.')"), + ("page alpha", {"page": "a", "page_size": 2}, 100, None, "ValueError(\"invalid literal for int() with base 10: 'a'\")"), + ("page_size none", {"page_size": None}, 0, 5, ""), + ("page_size zero", {"page_size": 0}, 0, 5, ""), + ("page_size one", {"page_size": 1}, 0, 1, ""), + ("page_size six", {"page_size": 6}, 0, 5, ""), + ("page_size string", {"page_size": "1"}, 0, 1, ""), + ("page_size negative", {"page_size": -1}, 0, 5, ""), + ("page_size alpha", {"page_size": "a"}, 100, None, "ValueError(\"invalid literal for int() with base 10: 'a'\")"), + ] + + for scenario_name, params, expected_code, expected_total, expected_message in cases: + res = rest_client.get(base_path, params=params) + assert res.status_code == 200, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == expected_code, (scenario_name, payload) + if expected_code == 0: + assert payload["data"]["total"] == 5, (scenario_name, payload) + assert len(payload["data"]["chunks"]) == expected_total, (scenario_name, payload) + else: + assert expected_message in payload["message"], (scenario_name, payload) + + +@pytest.mark.p2 +def test_chunk_list_concurrent_contract(rest_client, create_document): + dataset_id, document_id = create_document("chunk_list_concurrent.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + _reset_chunk_batch(rest_client, base_path) + + with ThreadPoolExecutor(max_workers=5) as executor: + results = list(executor.map(lambda _: rest_client.get(base_path).json(), range(20))) + assert len(results) == 20, results + assert all(result["code"] == 0 for result in results), results + assert all(result["data"]["total"] == 5 for result in results), results + + +def _create_chunk_for_update(rest_client, create_document, file_name): + dataset_id, document_id = create_document(file_name) + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + add_res = rest_client.post(base_path, json={"content": "chunk update test"}) + assert add_res.status_code == 200, add_res.text + add_payload = add_res.json() + assert add_payload["code"] == 0, add_payload + chunk_id = add_payload["data"]["chunk"]["id"] + return dataset_id, document_id, chunk_id, base_path + + +@pytest.mark.p2 +def test_chunk_update_requires_auth(rest_client, create_document): + _, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, "chunk_update_auth.txt") + for scenario_name, client in (("missing token", RestClient(token=None)), ("invalid token", RestClient(token=INVALID_API_TOKEN))): + res = client.patch(f"{base_path}/{chunk_id}", json={"content": "updated"}) + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == 401, (scenario_name, payload) + assert payload["message"] == "", (scenario_name, payload) + + +@pytest.mark.p2 +def test_chunk_update_content_and_available_contract(rest_client, create_document): + content_cases = [ + ("content none", {"content": None}, 0, ""), + ("content empty", {"content": ""}, 102, "`content` is required"), + ("content text", {"content": "update chunk"}, 0, ""), + ("content spaces", {"content": " "}, 102, "`content` is required"), + ("content punctuation", {"content": "\n!?。;!?\"'"}, 0, ""), + ] + for scenario_name, payload, expected_code, expected_message in content_cases: + _, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, f"{scenario_name}.txt") + res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code != 0: + assert body["message"] == expected_message, (scenario_name, body) + + available_cases = [ + ("available true", {"available": True}, 0, ""), + ("available true str", {"available": "True"}, 100, "invalid literal for int()"), + ("available one", {"available": 1}, 0, ""), + ("available false", {"available": False}, 0, ""), + ("available false str", {"available": "False"}, 100, "invalid literal for int()"), + ("available zero", {"available": 0}, 0, ""), + ] + for scenario_name, payload, expected_code, expected_message in available_cases: + _, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, f"{scenario_name}.txt") + res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code != 0: + assert expected_message in body["message"], (scenario_name, body) + + +@pytest.mark.p2 +def test_chunk_update_keywords_questions_and_tag_contract(rest_client, create_document): + _, _, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, "chunk_update_fields.txt") + cases = [ + ("important keywords", {"important_keywords": ["a", "b", "c"]}, 0, ""), + ("important keywords empty", {"important_keywords": [""]}, 0, ""), + ("important keywords int", {"important_keywords": [1]}, 100, "TypeError"), + ("important keywords dup", {"important_keywords": ["a", "a"]}, 0, ""), + ("important keywords str", {"important_keywords": "abc"}, 102, "`important_keywords` should be a list"), + ("important keywords number", {"important_keywords": 123}, 102, "`important_keywords` should be a list"), + ("questions", {"questions": ["a", "b", "c"]}, 0, ""), + ("questions empty", {"questions": [""]}, 0, ""), + ("questions int", {"questions": [1]}, 100, "TypeError"), + ("questions dup", {"questions": ["a", "a"]}, 0, ""), + ("questions str", {"questions": "abc"}, 102, "`questions` should be a list"), + ("questions number", {"questions": 123}, 102, "`questions` should be a list"), + ("tag kwd", {"tag_kwd": ["tag1", "tag2"]}, 0, ""), + ("tag kwd empty", {"tag_kwd": [""]}, 0, ""), + ("tag kwd int in list", {"tag_kwd": [1]}, 102, "`tag_kwd` must be a list of strings"), + ("tag kwd dup", {"tag_kwd": ["tag", "tag"]}, 0, ""), + ("tag kwd str", {"tag_kwd": "tag"}, 102, "`tag_kwd` should be a list"), + ("tag kwd number", {"tag_kwd": 123}, 102, "`tag_kwd` should be a list"), + ] + for scenario_name, payload, expected_code, expected_message in cases: + res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code != 0: + assert expected_message in body["message"], (scenario_name, body) + + +@pytest.mark.p2 +def test_chunk_update_invalid_target_and_param_contract(rest_client, create_document): + dataset_id, document_id, chunk_id, base_path = _create_chunk_for_update(rest_client, create_document, "chunk_update_invalid_targets.txt") + + invalid_dataset_res = rest_client.patch( + f"/datasets/{INVALID_ID_32}/documents/{document_id}/chunks/{chunk_id}", + json={"content": "updated"}, + ) + assert invalid_dataset_res.status_code == 200 + invalid_dataset_payload = invalid_dataset_res.json() + assert invalid_dataset_payload["code"] == 102, invalid_dataset_payload + assert invalid_dataset_payload["message"] in { + f"You don't own the dataset {INVALID_ID_32}.", + f"Can't find this chunk {chunk_id}", + }, invalid_dataset_payload + + invalid_document_res = rest_client.patch( + f"/datasets/{dataset_id}/documents/{INVALID_ID_32}/chunks/{chunk_id}", + json={"content": "updated"}, + ) + assert invalid_document_res.status_code == 200 + invalid_document_payload = invalid_document_res.json() + assert invalid_document_payload["code"] == 102, invalid_document_payload + assert invalid_document_payload["message"] == f"You don't own the document {INVALID_ID_32}.", invalid_document_payload + + invalid_chunk_res = rest_client.patch( + f"{base_path}/{INVALID_ID_32}", + json={"content": "updated"}, + ) + assert invalid_chunk_res.status_code == 200 + invalid_chunk_payload = invalid_chunk_res.json() + assert invalid_chunk_payload["code"] == 102, invalid_chunk_payload + assert invalid_chunk_payload["message"] == f"Can't find this chunk {INVALID_ID_32}", invalid_chunk_payload + + for scenario_name, payload in ( + ("unknown key", {"unknown_key": "unknown_value"}), + ("empty payload", {}), + ): + res = rest_client.patch(f"{base_path}/{chunk_id}", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == 0, (scenario_name, body) + + +@pytest.mark.p2 +def test_chunk_update_repeated_concurrent_and_deleted_document_contract(rest_client, create_document): + dataset_id, document_id, chunk_id, base_path = _create_chunk_for_update( + rest_client, create_document, "chunk_update_repeated_concurrent_deleted.txt" + ) + + first_res = rest_client.patch(f"{base_path}/{chunk_id}", json={"content": "chunk test 1"}) + assert first_res.status_code == 200 + assert first_res.json()["code"] == 0 + + second_res = rest_client.patch(f"{base_path}/{chunk_id}", json={"content": "chunk test 2"}) + assert second_res.status_code == 200 + assert second_res.json()["code"] == 0 + + get_after_repeat = rest_client.get(f"{base_path}/{chunk_id}") + assert get_after_repeat.status_code == 200 + get_after_repeat_payload = get_after_repeat.json() + assert get_after_repeat_payload["code"] == 0, get_after_repeat_payload + assert get_after_repeat_payload["data"]["content_with_weight"] == "chunk test 2", get_after_repeat_payload + + chunk_ids = [chunk_id] + for index in range(3): + add_res = rest_client.post(base_path, json={"content": f"concurrent update {index}"}) + assert add_res.status_code == 200, add_res.text + add_payload = add_res.json() + assert add_payload["code"] == 0, add_payload + chunk_ids.append(add_payload["data"]["chunk"]["id"]) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [] + for index in range(20): + target_id = chunk_ids[index % len(chunk_ids)] + futures.append( + executor.submit( + lambda cid, i: rest_client.patch( + f"{base_path}/{cid}", + json={"content": f"update chunk test {i}"}, + ).json(), + target_id, + index, + ) + ) + results = [future.result() for future in futures] + assert len(results) == 20, results + assert all(item["code"] == 0 for item in results), results + + delete_document_res = rest_client.delete(f"/datasets/{dataset_id}/documents", json={"ids": [document_id]}) + assert delete_document_res.status_code == 200 + assert delete_document_res.json()["code"] == 0 + + update_after_delete = rest_client.patch(f"{base_path}/{chunk_id}", json={"content": "after delete"}) + assert update_after_delete.status_code == 200 + update_after_delete_payload = update_after_delete.json() + assert update_after_delete_payload["code"] == 102, update_after_delete_payload + assert update_after_delete_payload["message"] in { + f"You don't own the document {document_id}.", + f"Can't find this chunk {chunk_id}", + }, update_after_delete_payload diff --git a/test/testcases/restful_api/test_connector_routes_unit.py b/test/testcases/restful_api/test_connector_routes_unit.py new file mode 100644 index 00000000000..80cd5662a6c --- /dev/null +++ b/test/testcases/restful_api/test_connector_routes_unit.py @@ -0,0 +1,755 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import importlib.util +import json +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +class _AwaitableValue: + def __init__(self, value): + self._value = value + + def __await__(self): + async def _co(): + return self._value + + return _co().__await__() + + +class _Args(dict): + def get(self, key, default=None, type=None): + value = super().get(key, default) + if type is None: + return value + try: + return type(value) + except (TypeError, ValueError): + return default + + def to_dict(self, flat=True): + return dict(self) + + +class _FakeResponse: + def __init__(self, body, status_code): + self.body = body + self.status_code = status_code + self.headers = {} + + +class _FakeConnectorRecord: + def __init__(self, payload): + self._payload = payload + + def to_dict(self): + return dict(self._payload) + + +class _FakeCredentials: + def __init__(self, raw='{"refresh_token":"rt","access_token":"at"}'): + self._raw = raw + + def to_json(self): + return self._raw + + +class _FakeFlow: + def __init__(self, client_config, scopes): + self.client_config = client_config + self.scopes = scopes + self.redirect_uri = None + self.credentials = _FakeCredentials() + self.auth_kwargs = None + self.token_code = None + self.token_code_verifier = None + self.code_verifier = "fake-code-verifier" + + def authorization_url(self, **kwargs): + self.auth_kwargs = dict(kwargs) + return f"https://oauth.example/{kwargs['state']}", kwargs["state"] + + def fetch_token(self, code, code_verifier=None): + self.token_code = code + self.token_code_verifier = code_verifier + + +class _FakeBoxToken: + def __init__(self, access_token, refresh_token): + self.access_token = access_token + self.refresh_token = refresh_token + + +class _FakeBoxOAuth: + def __init__(self, config): + self.config = config + self.exchange_code = None + + def get_authorize_url(self, options): + return f"https://box.example/auth?state={options.state}&redirect={options.redirect_uri}" + + def get_tokens_authorization_code_grant(self, code): + self.exchange_code = code + + def retrieve_token(self): + return _FakeBoxToken("box-access", "box-refresh") + + +class _FakeRedis: + def __init__(self): + self.store = {} + self.set_calls = [] + self.deleted = [] + + def get(self, key): + return self.store.get(key) + + def set_obj(self, key, obj, ttl): + self.set_calls.append((key, obj, ttl)) + self.store[key] = json.dumps(obj) + + def delete(self, key): + self.deleted.append(key) + self.store.pop(key, None) + + +def _run(coro): + return asyncio.run(coro) + + +def _set_request(module, *, args=None, json_body=None): + module.request = SimpleNamespace( + args=_Args(args or {}), + json=_AwaitableValue({} if json_body is None else json_body), + ) + + +@pytest.fixture(scope="session") +def auth(): + return "unit-auth" + + +@pytest.fixture(scope="session", autouse=True) +def set_tenant_info(): + return None + + +def _load_connector_app(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + + apps_mod = ModuleType("api.apps") + apps_mod.__path__ = [str(repo_root / "api" / "apps")] + apps_mod.current_user = SimpleNamespace(id="tenant-1") + apps_mod.login_required = lambda fn: fn + monkeypatch.setitem(sys.modules, "api.apps", apps_mod) + + db_mod = ModuleType("api.db") + db_mod.InputType = SimpleNamespace(POLL="POLL") + monkeypatch.setitem(sys.modules, "api.db", db_mod) + + services_pkg = ModuleType("api.db.services") + services_pkg.__path__ = [] + monkeypatch.setitem(sys.modules, "api.db.services", services_pkg) + + connector_service_mod = ModuleType("api.db.services.connector_service") + + class _StubConnectorService: + @staticmethod + def update_by_id(*_args, **_kwargs): + return True + + @staticmethod + def save(**_kwargs): + return True + + @staticmethod + def get_by_id(_connector_id): + return True, _FakeConnectorRecord({"id": _connector_id}) + + @staticmethod + def list(_tenant_id): + return [] + + @staticmethod + def accessible(*_args, **_kwargs): + return True + + @staticmethod + def cancel_tasks(*_args, **_kwargs): + return True + + @staticmethod + def rebuild(*_args, **_kwargs): + return None + + @staticmethod + def delete_by_id(*_args, **_kwargs): + return True + + class _StubSyncLogsService: + @staticmethod + def list_sync_tasks(*_args, **_kwargs): + return [], 0 + + connector_service_mod.ConnectorService = _StubConnectorService + connector_service_mod.SyncLogsService = _StubSyncLogsService + monkeypatch.setitem(sys.modules, "api.db.services.connector_service", connector_service_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + + async def _get_request_json(): + return {} + + api_utils_mod.get_request_json = _get_request_json + api_utils_mod.get_json_result = lambda data=None, message="", code=0: { + "code": code, + "message": message, + "data": data, + } + api_utils_mod.get_data_error_result = lambda message="", code=400, data=None: { + "code": code, + "message": message, + "data": data, + } + api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn) + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + constants_mod = ModuleType("common.constants") + constants_mod.RetCode = SimpleNamespace( + ARGUMENT_ERROR=101, + SERVER_ERROR=500, + RUNNING=102, + PERMISSION_ERROR=403, + AUTHENTICATION_ERROR=109, + ) + constants_mod.TaskStatus = SimpleNamespace( + UNSTART="unstart", + SCHEDULE="schedule", + CANCEL="cancel", + ) + monkeypatch.setitem(sys.modules, "common.constants", constants_mod) + + config_mod = ModuleType("common.data_source.config") + config_mod.GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI = "https://example.com/drive" + config_mod.GMAIL_WEB_OAUTH_REDIRECT_URI = "https://example.com/gmail" + config_mod.BOX_WEB_OAUTH_REDIRECT_URI = "https://example.com/box" + config_mod.DocumentSource = SimpleNamespace(GMAIL="gmail", GOOGLE_DRIVE="google-drive") + monkeypatch.setitem(sys.modules, "common.data_source.config", config_mod) + + google_constants_mod = ModuleType("common.data_source.google_util.constant") + google_constants_mod.WEB_OAUTH_POPUP_TEMPLATE = ( + "{title}" + "

{heading}

{message}

" + ) + google_constants_mod.GOOGLE_SCOPES = { + config_mod.DocumentSource.GMAIL: ["scope-gmail"], + config_mod.DocumentSource.GOOGLE_DRIVE: ["scope-drive"], + } + monkeypatch.setitem(sys.modules, "common.data_source.google_util.constant", google_constants_mod) + + misc_mod = ModuleType("common.misc_utils") + misc_mod.get_uuid = lambda: "uuid-from-helper" + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_mod) + + rag_pkg = ModuleType("rag") + rag_pkg.__path__ = [str(repo_root / "rag")] + monkeypatch.setitem(sys.modules, "rag", rag_pkg) + + rag_utils_pkg = ModuleType("rag.utils") + rag_utils_pkg.__path__ = [str(repo_root / "rag" / "utils")] + monkeypatch.setitem(sys.modules, "rag.utils", rag_utils_pkg) + + redis_mod = ModuleType("rag.utils.redis_conn") + redis_mod.REDIS_CONN = _FakeRedis() + monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", redis_mod) + + quart_mod = ModuleType("quart") + quart_mod.request = SimpleNamespace(args=_Args(), json=_AwaitableValue({})) + + async def _make_response(body, status_code): + return _FakeResponse(body, status_code) + + quart_mod.make_response = _make_response + monkeypatch.setitem(sys.modules, "quart", quart_mod) + + google_pkg = ModuleType("google_auth_oauthlib") + google_pkg.__path__ = [] + monkeypatch.setitem(sys.modules, "google_auth_oauthlib", google_pkg) + + google_flow_mod = ModuleType("google_auth_oauthlib.flow") + + class _StubFlow: + @classmethod + def from_client_config(cls, client_config, scopes): + return _FakeFlow(client_config, scopes) + + google_flow_mod.Flow = _StubFlow + monkeypatch.setitem(sys.modules, "google_auth_oauthlib.flow", google_flow_mod) + + box_mod = ModuleType("box_sdk_gen") + + class _OAuthConfig: + def __init__(self, client_id, client_secret): + self.client_id = client_id + self.client_secret = client_secret + + class _GetAuthorizeUrlOptions: + def __init__(self, redirect_uri, state): + self.redirect_uri = redirect_uri + self.state = state + + box_mod.BoxOAuth = _FakeBoxOAuth + box_mod.OAuthConfig = _OAuthConfig + box_mod.GetAuthorizeUrlOptions = _GetAuthorizeUrlOptions + monkeypatch.setitem(sys.modules, "box_sdk_gen", box_mod) + + module_path = repo_root / "api" / "apps" / "restful_apis" / "connector_api.py" + spec = importlib.util.spec_from_file_location("test_connector_routes_unit", module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + spec.loader.exec_module(module) + return module + + +@pytest.mark.p2 +def test_connector_basic_routes_and_task_controls(monkeypatch): + module = _load_connector_app(monkeypatch) + + async def _no_sleep(_secs): + return None + + monkeypatch.setattr(module.asyncio, "sleep", _no_sleep) + + records = {"conn-1": _FakeConnectorRecord({"id": "conn-1", "source": "drive"})} + update_calls = [] + save_calls = [] + cancel_calls = [] + delete_calls = [] + + monkeypatch.setattr(module.ConnectorService, "update_by_id", lambda cid, payload: update_calls.append((cid, payload))) + + def _save(**payload): + save_calls.append(payload) + records[payload["id"]] = _FakeConnectorRecord(payload) + + monkeypatch.setattr(module.ConnectorService, "save", _save) + monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, records[cid])) + monkeypatch.setattr(module.ConnectorService, "list", lambda tenant_id: [{"id": "listed", "tenant": tenant_id}]) + monkeypatch.setattr(module.SyncLogsService, "list_sync_tasks", lambda cid, page, page_size: ([{"id": "log-1"}], 9)) + monkeypatch.setattr(module.ConnectorService, "cancel_tasks", lambda cid: cancel_calls.append(cid)) + monkeypatch.setattr(module.ConnectorService, "delete_by_id", lambda cid: delete_calls.append(cid)) + monkeypatch.setattr(module, "get_uuid", lambda: "generated-id") + + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue({"id": "conn-1", "refresh_freq": 7, "config": {"x": 1}}), + ) + res = _run(module.update_connector("conn-1")) + assert update_calls == [("conn-1", {'id': 'conn-1', "refresh_freq": 7, "config": {"x": 1}})] + assert res["data"]["id"] == "conn-1" + + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue({"name": "new", "source": "gmail", "config": {"y": 2}}), + ) + res = _run(module.create_connector()) + assert save_calls[-1]["id"] == "generated-id" + assert save_calls[-1]["tenant_id"] == "tenant-1" + assert save_calls[-1]["input_type"] == module.InputType.POLL + assert save_calls[-1]["status"] == module.TaskStatus.UNSTART + assert res["data"]["id"] == "generated-id" + + list_res = module.list_connector() + assert list_res["data"] == [{"id": "listed", "tenant": "tenant-1"}] + + monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda _cid: (False, None)) + missing_res = module.get_connector("missing") + assert missing_res["message"] == "Can't find this Connector!" + + monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, _FakeConnectorRecord({"id": cid}))) + found_res = module.get_connector("conn-2") + assert found_res["data"]["id"] == "conn-2" + + _set_request(module, args={"page": "2", "page_size": "7"}) + logs_res = module.list_logs("conn-log") + assert logs_res["data"] == {"total": 9, "logs": [{"id": "log-1"}]} + + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"kb_id": "kb-1"})) + monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: "rebuild-failed") + failed_rebuild = _run(module.rebuild("conn-rb")) + assert failed_rebuild["code"] == module.RetCode.SERVER_ERROR + assert failed_rebuild["data"] is False + + monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: None) + ok_rebuild = _run(module.rebuild("conn-rb")) + assert ok_rebuild["data"] is True + + rm_res = module.rm_connector("conn-rm") + assert rm_res["data"] is True + assert cancel_calls == ["conn-rm"] + assert delete_calls == ["conn-rm"] + + +@pytest.mark.p2 +def test_connector_by_id_routes_reject_cross_tenant_access(monkeypatch): + """Verify per-id connector routes stop before body parsing or service access.""" + module = _load_connector_app(monkeypatch) + + touched = [] + monkeypatch.setattr(module.ConnectorService, "accessible", lambda cid, uid: False) + monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda *_args: touched.append("get_by_id")) + monkeypatch.setattr(module.SyncLogsService, "list_sync_tasks", lambda *_args: touched.append("list_sync_tasks")) + monkeypatch.setattr(module.ConnectorService, "cancel_tasks", lambda *_args: touched.append("cancel_tasks")) + monkeypatch.setattr(module.ConnectorService, "delete_by_id", lambda *_args: touched.append("delete_by_id")) + monkeypatch.setattr(module.ConnectorService, "update_by_id", lambda *_args: touched.append("update_by_id")) + monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: touched.append("rebuild")) + + def _get_request_json(): + touched.append("get_request_json") + return _AwaitableValue({"config": {"x": 1}}) + + monkeypatch.setattr(module, "get_request_json", _get_request_json) + + responses = [ + _run(module.update_connector("conn-victim")), + module.get_connector("conn-victim"), + module.list_logs("conn-victim"), + _run(module.rebuild("conn-victim")), + module.rm_connector("conn-victim"), + _run(module.test_connector("conn-victim")), + ] + + assert all(res["code"] == module.RetCode.AUTHENTICATION_ERROR for res in responses) + assert all(res["message"] == "No authorization." for res in responses) + assert all(res["data"] is False for res in responses) + assert touched == [] + + +@pytest.mark.p2 +def test_connector_oauth_helper_functions(monkeypatch): + module = _load_connector_app(monkeypatch) + + assert module._web_state_cache_key("flow-a", "gmail") == "gmail_web_flow_state:flow-a" + assert module._web_result_cache_key("flow-b", "google-drive") == "google-drive_web_flow_result:flow-b" + + creds_dict = {"web": {"client_id": "id"}} + assert module._load_credentials(creds_dict) == creds_dict + assert module._load_credentials(json.dumps(creds_dict)) == creds_dict + + with pytest.raises(ValueError, match="Invalid Google credentials JSON"): + module._load_credentials("{not-json") + + assert module._get_web_client_config(creds_dict) == {"web": {"client_id": "id"}} + with pytest.raises(ValueError, match="must include a 'web'"): + module._get_web_client_config({"installed": {"client_id": "id"}}) + + popup_ok = _run(module._render_web_oauth_popup("flow-1", True, "done", "gmail")) + assert popup_ok.status_code == 200 + assert popup_ok.headers["Content-Type"] == "text/html; charset=utf-8" + assert "Authorization complete" in popup_ok.body + assert "ragflow-gmail-oauth" in popup_ok.body + + popup_error = _run(module._render_web_oauth_popup("flow-2", False, "", "google-drive")) + assert popup_error.status_code == 200 + assert "Authorization failed" in popup_error.body + assert "<denied>" in popup_error.body + + +@pytest.mark.p2 +def test_start_google_web_oauth_matrix(monkeypatch): + module = _load_connector_app(monkeypatch) + + redis = _FakeRedis() + monkeypatch.setattr(module, "REDIS_CONN", redis) + monkeypatch.setattr(module.time, "time", lambda: 1700000000) + + flow_calls = [] + + def _from_client_config(client_config, scopes): + flow = _FakeFlow(client_config, scopes) + flow_calls.append(flow) + return flow + + monkeypatch.setattr(module.Flow, "from_client_config", staticmethod(_from_client_config)) + + _set_request(module, args={"type": "invalid"}) + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": "{}"})) + invalid_type = _run(module.start_google_web_oauth()) + assert invalid_type["code"] == module.RetCode.ARGUMENT_ERROR + + monkeypatch.setattr(module, "GMAIL_WEB_OAUTH_REDIRECT_URI", "") + _set_request(module, args={"type": "gmail"}) + missing_redirect = _run(module.start_google_web_oauth()) + assert missing_redirect["code"] == module.RetCode.SERVER_ERROR + + monkeypatch.setattr(module, "GMAIL_WEB_OAUTH_REDIRECT_URI", "https://example.com/gmail") + monkeypatch.setattr(module, "GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI", "https://example.com/drive") + + _set_request(module, args={"type": "google-drive"}) + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": "{invalid-json"})) + invalid_credentials = _run(module.start_google_web_oauth()) + assert invalid_credentials["code"] == module.RetCode.ARGUMENT_ERROR + + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue({"credentials": json.dumps({"web": {"client_id": "id"}, "refresh_token": "rt"})}), + ) + has_refresh_token = _run(module.start_google_web_oauth()) + assert has_refresh_token["code"] == module.RetCode.ARGUMENT_ERROR + + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"credentials": json.dumps({"installed": {"x": 1}})})) + missing_web = _run(module.start_google_web_oauth()) + assert missing_web["code"] == module.RetCode.ARGUMENT_ERROR + + ids = iter(["flow-gmail", "flow-drive"]) + monkeypatch.setattr(module.uuid, "uuid4", lambda: next(ids)) + + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue({"credentials": json.dumps({"web": {"client_id": "id", "client_secret": "secret"}})}), + ) + + _set_request(module, args={"type": "gmail"}) + gmail_ok = _run(module.start_google_web_oauth()) + assert gmail_ok["code"] == 0 + assert gmail_ok["data"]["flow_id"] == "flow-gmail" + assert gmail_ok["data"]["authorization_url"].endswith("flow-gmail") + + _set_request(module, args={}) + drive_ok = _run(module.start_google_web_oauth()) + assert drive_ok["code"] == 0 + assert drive_ok["data"]["flow_id"] == "flow-drive" + assert drive_ok["data"]["authorization_url"].endswith("flow-drive") + + assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GMAIL] for call in flow_calls) + assert any(call.scopes == module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE] for call in flow_calls) + assert "gmail_web_flow_state:flow-gmail" in redis.store + assert "google-drive_web_flow_state:flow-drive" in redis.store + assert json.loads(redis.store["gmail_web_flow_state:flow-gmail"])["code_verifier"] == "fake-code-verifier" + assert json.loads(redis.store["google-drive_web_flow_state:flow-drive"])["code_verifier"] == "fake-code-verifier" + + +@pytest.mark.p2 +def test_google_web_oauth_callbacks_matrix(monkeypatch): + module = _load_connector_app(monkeypatch) + + flow_calls = [] + + def _from_client_config(client_config, scopes): + flow = _FakeFlow(client_config, scopes) + flow_calls.append(flow) + return flow + + monkeypatch.setattr(module.Flow, "from_client_config", staticmethod(_from_client_config)) + + callback_specs = [ + ( + module.google_gmail_web_oauth_callback, + "gmail", + module.GMAIL_WEB_OAUTH_REDIRECT_URI, + module.GOOGLE_SCOPES[module.DocumentSource.GMAIL], + ), + ( + module.google_drive_web_oauth_callback, + "google-drive", + module.GOOGLE_DRIVE_WEB_OAUTH_REDIRECT_URI, + module.GOOGLE_SCOPES[module.DocumentSource.GOOGLE_DRIVE], + ), + ] + + for callback, source, expected_redirect, expected_scopes in callback_specs: + redis = _FakeRedis() + monkeypatch.setattr(module, "REDIS_CONN", redis) + + _set_request(module, args={}) + missing_state = _run(callback()) + assert "Missing OAuth state parameter." in missing_state.body + + _set_request(module, args={"state": "sid"}) + expired_state = _run(callback()) + assert "Authorization session expired" in expired_state.body + + redis.store[module._web_state_cache_key("sid", source)] = json.dumps({"user_id": "tenant-1"}) + _set_request(module, args={"state": "sid"}) + invalid_state = _run(callback()) + assert "Authorization session was invalid" in invalid_state.body + assert module._web_state_cache_key("sid", source) in redis.deleted + + redis.store[module._web_state_cache_key("sid", source)] = json.dumps({ + "user_id": "tenant-1", + "client_config": {"web": {"client_id": "cid"}}, + }) + _set_request(module, args={"state": "sid", "error": "denied", "error_description": "permission denied"}) + oauth_error = _run(callback()) + assert "permission denied" in oauth_error.body + + redis.store[module._web_state_cache_key("sid", source)] = json.dumps({ + "user_id": "tenant-1", + "client_config": {"web": {"client_id": "cid"}}, + }) + _set_request(module, args={"state": "sid"}) + missing_code = _run(callback()) + assert "Missing authorization code" in missing_code.body + + redis.store[module._web_state_cache_key("sid", source)] = json.dumps({ + "user_id": "tenant-1", + "client_config": {"web": {"client_id": "cid"}}, + "code_verifier": "state-code-verifier", + }) + _set_request(module, args={"state": "sid", "code": "code-123"}) + success = _run(callback()) + assert "Authorization completed successfully." in success.body + + result_key = module._web_result_cache_key("sid", source) + assert result_key in redis.store + assert module._web_state_cache_key("sid", source) in redis.deleted + + assert flow_calls[-1].redirect_uri == expected_redirect + assert flow_calls[-1].scopes == expected_scopes + assert flow_calls[-1].token_code == "code-123" + assert flow_calls[-1].token_code_verifier == "state-code-verifier" + + +@pytest.mark.p2 +def test_poll_google_web_result_matrix(monkeypatch): + module = _load_connector_app(monkeypatch) + redis = _FakeRedis() + monkeypatch.setattr(module, "REDIS_CONN", redis) + + _set_request(module, args={"type": "invalid"}, json_body={"flow_id": "flow-1"}) + invalid_type = _run(module.poll_google_web_result()) + assert invalid_type["code"] == module.RetCode.ARGUMENT_ERROR + + _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"}) + pending = _run(module.poll_google_web_result()) + assert pending["code"] == module.RetCode.RUNNING + + redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps( + {"user_id": "another-user", "credentials": "token-x"} + ) + _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"}) + permission_error = _run(module.poll_google_web_result()) + assert permission_error["code"] == module.RetCode.PERMISSION_ERROR + + redis.store[module._web_result_cache_key("flow-1", "gmail")] = json.dumps( + {"user_id": "tenant-1", "credentials": "token-ok"} + ) + _set_request(module, args={"type": "gmail"}, json_body={"flow_id": "flow-1"}) + success = _run(module.poll_google_web_result()) + assert success["code"] == 0 + assert success["data"] == {"credentials": "token-ok"} + assert module._web_result_cache_key("flow-1", "gmail") in redis.deleted + + +@pytest.mark.p2 +def test_box_oauth_start_callback_and_poll_matrix(monkeypatch): + module = _load_connector_app(monkeypatch) + redis = _FakeRedis() + monkeypatch.setattr(module, "REDIS_CONN", redis) + + created_auth = [] + + class _TrackingBoxOAuth(_FakeBoxOAuth): + def __init__(self, config): + super().__init__(config) + created_auth.append(self) + + monkeypatch.setattr(module, "BoxOAuth", _TrackingBoxOAuth) + monkeypatch.setattr(module.uuid, "uuid4", lambda: "flow-box") + monkeypatch.setattr(module.time, "time", lambda: 1800000000) + + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({})) + missing_params = _run(module.start_box_web_oauth()) + assert missing_params["code"] == module.RetCode.ARGUMENT_ERROR + + monkeypatch.setattr( + module, + "get_request_json", + lambda: _AwaitableValue({"client_id": "cid", "client_secret": "sec", "redirect_uri": "https://box.local/callback"}), + ) + start_ok = _run(module.start_box_web_oauth()) + assert start_ok["code"] == 0 + assert start_ok["data"]["flow_id"] == "flow-box" + assert "authorization_url" in start_ok["data"] + assert module._web_state_cache_key("flow-box", "box") in redis.store + + _set_request(module, args={}) + missing_state = _run(module.box_web_oauth_callback()) + assert "Missing OAuth parameters." in missing_state.body + + _set_request(module, args={"state": "flow-box"}) + missing_code = _run(module.box_web_oauth_callback()) + assert "Missing authorization code from Box." in missing_code.body + + redis.store[module._web_state_cache_key("flow-null", "box")] = "null" + _set_request(module, args={"state": "flow-null", "code": "abc"}) + invalid_session = _run(module.box_web_oauth_callback()) + assert invalid_session["code"] == module.RetCode.ARGUMENT_ERROR + + redis.store[module._web_state_cache_key("flow-box", "box")] = json.dumps( + {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"} + ) + _set_request(module, args={"state": "flow-box", "code": "abc", "error": "access_denied", "error_description": "denied"}) + callback_error = _run(module.box_web_oauth_callback()) + assert "denied" in callback_error.body + + redis.store[module._web_state_cache_key("flow-ok", "box")] = json.dumps( + {"user_id": "tenant-1", "client_id": "cid", "client_secret": "sec"} + ) + _set_request(module, args={"state": "flow-ok", "code": "code-ok"}) + callback_success = _run(module.box_web_oauth_callback()) + assert "Authorization completed successfully." in callback_success.body + assert created_auth[-1].exchange_code == "code-ok" + assert module._web_result_cache_key("flow-ok", "box") in redis.store + assert module._web_state_cache_key("flow-ok", "box") in redis.deleted + + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"flow_id": "flow-ok"})) + redis.store.pop(module._web_result_cache_key("flow-ok", "box"), None) + pending = _run(module.poll_box_web_result()) + assert pending["code"] == module.RetCode.RUNNING + + redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps({"user_id": "another-user"}) + permission_error = _run(module.poll_box_web_result()) + assert permission_error["code"] == module.RetCode.PERMISSION_ERROR + + redis.store[module._web_result_cache_key("flow-ok", "box")] = json.dumps( + {"user_id": "tenant-1", "access_token": "at", "refresh_token": "rt"} + ) + poll_success = _run(module.poll_box_web_result()) + assert poll_success["code"] == 0 + assert poll_success["data"]["credentials"]["access_token"] == "at" + assert module._web_result_cache_key("flow-ok", "box") in redis.deleted diff --git a/test/testcases/restful_api/test_datasets.py b/test/testcases/restful_api/test_datasets.py new file mode 100644 index 00000000000..a0f17261027 --- /dev/null +++ b/test/testcases/restful_api/test_datasets.py @@ -0,0 +1,335 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +from configs import DATASET_NAME_LIMIT + + +@pytest.mark.p1 +class TestDatasetsAuthorization: + def test_create_requires_auth(self, rest_client_noauth): + res = rest_client_noauth.post("/datasets", json={"name": "auth_test"}) + assert res.status_code == 401 + payload = res.json() + assert payload["code"] == 401, payload + + +@pytest.mark.p1 +def test_dataset_crud_cycle(rest_client, clear_datasets): + create_res = rest_client.post("/datasets", json={"name": "restful_dataset_crud"}) + assert create_res.status_code == 200 + create_payload = create_res.json() + assert create_payload["code"] == 0, create_payload + dataset_id = create_payload["data"]["id"] + + get_res = rest_client.get(f"/datasets/{dataset_id}") + assert get_res.status_code == 200 + get_payload = get_res.json() + assert get_payload["code"] == 0, get_payload + assert get_payload["data"]["id"] == dataset_id, get_payload + + update_res = rest_client.put( + f"/datasets/{dataset_id}", + json={"name": "restful_dataset_crud_updated"}, + ) + assert update_res.status_code == 200 + update_payload = update_res.json() + assert update_payload["code"] == 0, update_payload + assert update_payload["data"]["name"] == "restful_dataset_crud_updated", update_payload + + list_res = rest_client.get("/datasets", params={"id": dataset_id}) + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert len(list_payload["data"]) == 1, list_payload + assert list_payload["data"][0]["id"] == dataset_id, list_payload + assert list_payload.get("total_datasets", 0) >= 1, list_payload + + delete_res = rest_client.delete("/datasets", json={"ids": [dataset_id]}) + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + + list_after_delete = rest_client.get("/datasets") + assert list_after_delete.status_code == 200 + list_after_delete_payload = list_after_delete.json() + assert list_after_delete_payload["code"] == 0, list_after_delete_payload + assert all(dataset["id"] != dataset_id for dataset in list_after_delete_payload["data"]), list_after_delete_payload + + +@pytest.mark.p2 +@pytest.mark.parametrize( + "name, expected_fragment", + [ + ("", "String should have at least 1 character"), + (" ", "String should have at least 1 character"), + ("a" * (DATASET_NAME_LIMIT + 1), f"String should have at most {DATASET_NAME_LIMIT} characters"), + ], + ids=["empty", "spaces", "too_long"], +) +def test_dataset_create_name_validation(rest_client, clear_datasets, name, expected_fragment): + res = rest_client.post("/datasets", json={"name": name}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert expected_fragment in payload["message"], payload + + +@pytest.mark.p2 +def test_dataset_list_ordering_and_pagination(rest_client, clear_datasets): + for i in range(3): + res = rest_client.post("/datasets", json={"name": f"dataset_page_{i}"}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + + list_res = rest_client.get( + "/datasets", + params={"page": 1, "page_size": 2, "orderby": "create_time", "desc": "true"}, + ) + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert len(list_payload["data"]) == 2, list_payload + assert list_payload.get("total_datasets", 0) >= 3, list_payload + + +@pytest.mark.p2 +def test_dataset_search_endpoint(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + res = rest_client.post( + f"/datasets/{dataset_id}/search", + json={"question": "test TXT file", "page": 1, "size": 10}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert "chunks" in payload["data"], payload + + +@pytest.mark.p2 +def test_dataset_search_requires_question(rest_client, create_dataset): + dataset_id = create_dataset("dataset_search_missing_question") + res = rest_client.post(f"/datasets/{dataset_id}/search", json={}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert "question" in payload["message"], payload + + +@pytest.mark.p2 +def test_dataset_tags_and_aggregation(rest_client, create_dataset): + dataset_id = create_dataset("dataset_tags") + second_dataset_id = create_dataset("dataset_tags_second") + + list_tags_res = rest_client.get(f"/datasets/{dataset_id}/tags") + assert list_tags_res.status_code == 200 + list_tags_payload = list_tags_res.json() + # Known env/runtime behavior: this route can return 102 when retriever tag + # backend is unavailable for an empty dataset. Keep route-contract coverage. + assert list_tags_payload["code"] in (0, 102), list_tags_payload + + aggregate_res = rest_client.get( + "/datasets/tags/aggregation", + params={"dataset_ids": f"{dataset_id},{second_dataset_id}"}, + ) + assert aggregate_res.status_code == 200 + aggregate_payload = aggregate_res.json() + assert aggregate_payload["code"] in (0, 102), aggregate_payload + + empty_aggregate_res = rest_client.get("/datasets/tags/aggregation") + assert empty_aggregate_res.status_code == 200 + empty_aggregate_payload = empty_aggregate_res.json() + assert empty_aggregate_payload["code"] != 0, empty_aggregate_payload + + +@pytest.mark.p2 +def test_dataset_tags_delete_and_rename_validation(rest_client, create_dataset): + dataset_id = create_dataset("dataset_tag_mutation") + + delete_missing_tags = rest_client.delete(f"/datasets/{dataset_id}/tags", json={}) + assert delete_missing_tags.status_code == 200 + delete_missing_tags_payload = delete_missing_tags.json() + assert delete_missing_tags_payload["code"] != 0, delete_missing_tags_payload + + delete_invalid_tags_type = rest_client.delete(f"/datasets/{dataset_id}/tags", json={"tags": "wrong"}) + assert delete_invalid_tags_type.status_code == 200 + delete_invalid_tags_type_payload = delete_invalid_tags_type.json() + assert delete_invalid_tags_type_payload["code"] != 0, delete_invalid_tags_type_payload + + rename_empty = rest_client.put( + f"/datasets/{dataset_id}/tags", + json={"from_tag": "", "to_tag": ""}, + ) + assert rename_empty.status_code == 200 + rename_empty_payload = rename_empty.json() + assert rename_empty_payload["code"] != 0, rename_empty_payload + + rename_invalid_dataset = rest_client.put( + "/datasets/invalid_id/tags", + json={"from_tag": "old", "to_tag": "new"}, + ) + assert rename_invalid_dataset.status_code == 200 + rename_invalid_dataset_payload = rename_invalid_dataset.json() + assert rename_invalid_dataset_payload["code"] != 0, rename_invalid_dataset_payload + + +@pytest.mark.p2 +def test_dataset_flattened_metadata(rest_client, create_dataset): + first_dataset_id = create_dataset("flattened_meta_1") + second_dataset_id = create_dataset("flattened_meta_2") + + flattened_res = rest_client.get( + "/datasets/metadata/flattened", + params={"dataset_ids": f"{first_dataset_id},{second_dataset_id}"}, + ) + assert flattened_res.status_code == 200 + flattened_payload = flattened_res.json() + assert flattened_payload["code"] == 0, flattened_payload + + empty_ids_res = rest_client.get("/datasets/metadata/flattened") + assert empty_ids_res.status_code == 200 + empty_ids_payload = empty_ids_res.json() + assert empty_ids_payload["code"] != 0, empty_ids_payload + + invalid_dataset_res = rest_client.get( + "/datasets/metadata/flattened", + params={"dataset_ids": "invalid_id"}, + ) + assert invalid_dataset_res.status_code == 200 + invalid_dataset_payload = invalid_dataset_res.json() + assert invalid_dataset_payload["code"] != 0, invalid_dataset_payload + + +@pytest.mark.p2 +def test_dataset_ingestion_summary_and_logs(rest_client, create_dataset): + dataset_id = create_dataset("dataset_ingestions") + + summary_res = rest_client.get(f"/datasets/{dataset_id}/ingestions/summary") + assert summary_res.status_code == 200 + summary_payload = summary_res.json() + assert summary_payload["code"] == 0, summary_payload + assert "doc_num" in summary_payload["data"], summary_payload + assert "chunk_num" in summary_payload["data"], summary_payload + assert "token_num" in summary_payload["data"], summary_payload + assert "status" in summary_payload["data"], summary_payload + + logs_res = rest_client.get( + f"/datasets/{dataset_id}/ingestions", + params={"page": 1, "page_size": 10}, + ) + assert logs_res.status_code == 200 + logs_payload = logs_res.json() + assert logs_payload["code"] == 0, logs_payload + assert "total" in logs_payload["data"], logs_payload + assert "logs" in logs_payload["data"], logs_payload + + not_found_log_res = rest_client.get(f"/datasets/{dataset_id}/ingestions/nonexistent_log") + assert not_found_log_res.status_code == 200 + not_found_log_payload = not_found_log_res.json() + assert not_found_log_payload["code"] != 0, not_found_log_payload + + +@pytest.mark.p2 +def test_dataset_ingestion_invalid_dataset(rest_client): + summary_res = rest_client.get("/datasets/invalid_id/ingestions/summary") + assert summary_res.status_code == 200 + summary_payload = summary_res.json() + assert summary_payload["code"] != 0, summary_payload + + logs_res = rest_client.get("/datasets/invalid_id/ingestions") + assert logs_res.status_code == 200 + logs_payload = logs_res.json() + assert logs_payload["code"] != 0, logs_payload + + log_res = rest_client.get("/datasets/invalid_id/ingestions/some_log_id") + assert log_res.status_code == 200 + log_payload = log_res.json() + assert log_payload["code"] != 0, log_payload + + +@pytest.mark.p2 +def test_dataset_index_endpoints(rest_client, create_dataset): + dataset_id = create_dataset("dataset_index_endpoints") + + run_invalid_type = rest_client.post( + f"/datasets/{dataset_id}/index", + params={"type": "invalid_type"}, + ) + assert run_invalid_type.status_code == 200 + run_invalid_type_payload = run_invalid_type.json() + assert run_invalid_type_payload["code"] != 0, run_invalid_type_payload + + run_no_docs = rest_client.post( + f"/datasets/{dataset_id}/index", + params={"type": "graph"}, + ) + assert run_no_docs.status_code == 200 + run_no_docs_payload = run_no_docs.json() + assert run_no_docs_payload["code"] == 102, run_no_docs_payload + + trace_no_task = rest_client.get( + f"/datasets/{dataset_id}/index", + params={"type": "graph"}, + ) + assert trace_no_task.status_code == 200 + trace_no_task_payload = trace_no_task.json() + assert trace_no_task_payload["code"] == 0, trace_no_task_payload + assert trace_no_task_payload["data"] == {}, trace_no_task_payload + + delete_graph = rest_client.delete(f"/datasets/{dataset_id}/graph") + assert delete_graph.status_code == 200 + delete_graph_payload = delete_graph.json() + assert delete_graph_payload["code"] == 0, delete_graph_payload + + delete_invalid_type = rest_client.delete(f"/datasets/{dataset_id}/invalid_type") + assert delete_invalid_type.status_code == 200 + delete_invalid_type_payload = delete_invalid_type.json() + assert delete_invalid_type_payload["code"] != 0, delete_invalid_type_payload + + +@pytest.mark.p2 +@pytest.mark.parametrize("index_type", ["graph", "raptor", "mindmap"]) +def test_dataset_index_run_with_document_creates_task(rest_client, create_document, index_type): + dataset_id, _ = create_document("dataset_index_graph_source.txt") + run_graph = rest_client.post( + f"/datasets/{dataset_id}/index", + params={"type": index_type}, + ) + assert run_graph.status_code == 200 + run_graph_payload = run_graph.json() + assert run_graph_payload["code"] == 0, run_graph_payload + assert run_graph_payload["data"].get("task_id"), run_graph_payload + + +@pytest.mark.p2 +def test_dataset_embedding_endpoints(rest_client, create_dataset): + dataset_id = create_dataset("dataset_embedding_endpoints") + + run_no_docs_res = rest_client.post(f"/datasets/{dataset_id}/embedding") + assert run_no_docs_res.status_code == 200 + run_no_docs_payload = run_no_docs_res.json() + assert run_no_docs_payload["code"] == 102, run_no_docs_payload + + missing_embd_id_res = rest_client.post(f"/datasets/{dataset_id}/embedding/check", json={}) + assert missing_embd_id_res.status_code == 200 + missing_embd_id_payload = missing_embd_id_res.json() + assert missing_embd_id_payload["code"] != 0, missing_embd_id_payload + + invalid_dataset_res = rest_client.post("/datasets/invalid_id/embedding") + assert invalid_dataset_res.status_code == 200 + invalid_dataset_payload = invalid_dataset_res.json() + assert invalid_dataset_payload["code"] != 0, invalid_dataset_payload diff --git a/test/testcases/restful_api/test_dify_retrieval_routes_unit.py b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py new file mode 100644 index 00000000000..3187846a7e2 --- /dev/null +++ b/test/testcases/restful_api/test_dify_retrieval_routes_unit.py @@ -0,0 +1,425 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import importlib.util +import inspect +import sys +from copy import deepcopy +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +class _AwaitableValue: + def __init__(self, value): + self._value = value + + def __await__(self): + async def _co(): + return self._value + + return _co().__await__() + + +class _DummyKB: + def __init__(self, tenant_id="tenant-1", embd_id="embd-1", tenant_embd_id=1): + self.tenant_id = tenant_id + self.embd_id = embd_id + self.tenant_embd_id = tenant_embd_id + + +class _DummyRetriever: + async def retrieval(self, *_args, **_kwargs): + return { + "chunks": [ + {"doc_id": "doc-1", "content_with_weight": "chunk-content", "similarity": 0.8, "docnm_kwd": "doc-title", "vector": [0.1]} + ] + } + + def retrieval_by_children(self, chunks, _tenant_ids): + return chunks + + +def _run(coro): + return asyncio.run(coro) + + +def _load_dify_retrieval_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + deepdoc_pkg = ModuleType("deepdoc") + deepdoc_parser_pkg = ModuleType("deepdoc.parser") + deepdoc_parser_pkg.__path__ = [] + + class _StubPdfParser: + pass + + class _StubExcelParser: + pass + + class _StubDocxParser: + pass + + deepdoc_parser_pkg.PdfParser = _StubPdfParser + deepdoc_parser_pkg.ExcelParser = _StubExcelParser + deepdoc_parser_pkg.DocxParser = _StubDocxParser + deepdoc_pkg.parser = deepdoc_parser_pkg + monkeypatch.setitem(sys.modules, "deepdoc", deepdoc_pkg) + monkeypatch.setitem(sys.modules, "deepdoc.parser", deepdoc_parser_pkg) + + deepdoc_excel_module = ModuleType("deepdoc.parser.excel_parser") + deepdoc_excel_module.RAGFlowExcelParser = _StubExcelParser + monkeypatch.setitem(sys.modules, "deepdoc.parser.excel_parser", deepdoc_excel_module) + + deepdoc_parser_utils = ModuleType("deepdoc.parser.utils") + deepdoc_parser_utils.get_text = lambda *_args, **_kwargs: "" + monkeypatch.setitem(sys.modules, "deepdoc.parser.utils", deepdoc_parser_utils) + monkeypatch.setitem(sys.modules, "xgboost", ModuleType("xgboost")) + + tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service") + + class _MockModelConfig: + def __init__(self, tenant_id, model_name): + self.tenant_id = tenant_id + self.llm_name = model_name + self.llm_factory = "Builtin" + self.api_key = "fake-api-key" + self.api_base = "https://api.example.com" + self.model_type = "chat" + self.max_tokens = 8192 + self.used_tokens = 0 + self.status = 1 + self.id = 1 + + def to_dict(self): + return { + "tenant_id": self.tenant_id, + "llm_name": self.llm_name, + "llm_factory": self.llm_factory, + "api_key": self.api_key, + "api_base": self.api_base, + "model_type": self.model_type, + "max_tokens": self.max_tokens, + "used_tokens": self.used_tokens, + "status": self.status, + "id": self.id, + } + + class _StubTenantService: + @staticmethod + def get_by_id(tenant_id): + return True, SimpleNamespace( + id=tenant_id, + llm_id="chat-model", + embd_id="embd-model", + asr_id="asr-model", + img2txt_id="img2txt-model", + rerank_id="rerank-model", + tts_id="tts-model", + ) + + class _StubTenantLLMService: + @staticmethod + def get_api_key(tenant_id, model_name): + return _MockModelConfig(tenant_id, model_name) + + @staticmethod + def split_model_name_and_factory(model_name): + if "@" in model_name: + parts = model_name.split("@") + return parts[0], parts[1] + return model_name, None + + tenant_llm_service_mod.TenantService = _StubTenantService + tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService + + class _StubLLMFactoriesService: + pass + + tenant_llm_service_mod.LLMFactoriesService = _StubLLMFactoriesService + monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod) + + llm_service_mod = ModuleType("api.db.services.llm_service") + + class _StubLLM: + def __init__(self, llm_name): + self.llm_name = llm_name + self.is_tools = False + + class _StubLLMBundle: + def __init__(self, tenant_id: str, model_config: dict, lang="Chinese", **kwargs): + self.tenant_id = tenant_id + self.model_config = model_config + self.lang = lang + + def encode(self, texts: list): + import numpy as np + + return [np.array([0.1, 0.2, 0.3]) for _ in texts], len(texts) * 10 + + llm_service_mod.LLMService = SimpleNamespace(query=lambda llm_name: [_StubLLM(llm_name)] if llm_name else []) + llm_service_mod.LLMBundle = _StubLLMBundle + monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod) + + tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service") + + class _MockModelConfig2: + def __init__(self, tenant_id, model_name): + self.tenant_id = tenant_id + self.llm_name = model_name + self.llm_factory = "Builtin" + self.api_key = "fake-api-key" + self.api_base = "https://api.example.com" + self.model_type = "chat" + self.max_tokens = 8192 + self.used_tokens = 0 + self.status = 1 + self.id = 1 + + def to_dict(self): + return { + "tenant_id": self.tenant_id, + "llm_name": self.llm_name, + "llm_factory": self.llm_factory, + "api_key": self.api_key, + "api_base": self.api_base, + "model_type": self.model_type, + "max_tokens": self.max_tokens, + "used_tokens": self.used_tokens, + "status": self.status, + "id": self.id, + } + + def _get_model_config_by_id(tenant_model_id: int, allowed_tenant_ids=None, requester_tenant_id=None) -> dict: + mock_tenant_id = "tenant-1" + if allowed_tenant_ids is not None: + if isinstance(allowed_tenant_ids, str): + allowed_tenant_ids = {allowed_tenant_ids} + else: + allowed_tenant_ids = {str(tenant_id) for tenant_id in allowed_tenant_ids if tenant_id} + if mock_tenant_id not in allowed_tenant_ids and str(requester_tenant_id) != mock_tenant_id: + raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") + return _MockModelConfig2(mock_tenant_id, "model-1").to_dict() + + def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str): + if not model_name: + raise Exception("Model Name is required") + return _MockModelConfig2(tenant_id, model_name).to_dict() + + def _get_tenant_default_model_by_type(tenant_id: str, model_type): + return _MockModelConfig2(tenant_id, "chat-model").to_dict() + + tenant_model_service_mod.get_model_config_by_id = _get_model_config_by_id + tenant_model_service_mod.get_model_config_by_type_and_name = _get_model_config_by_type_and_name + tenant_model_service_mod.get_tenant_default_model_by_type = _get_tenant_default_model_by_type + monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) + + module_name = "test_dify_retrieval_routes_unit_module" + module_path = repo_root / "api" / "apps" / "sdk" / "dify_retrieval.py" + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + monkeypatch.setitem(sys.modules, module_name, module) + spec.loader.exec_module(module) + return module + + +def _set_request_json(monkeypatch, module, payload): + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(deepcopy(payload))) + + +@pytest.mark.p2 +def test_retrieval_success_with_metadata_and_kg(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + _set_request_json( + monkeypatch, + module, + { + "knowledge_id": "kb-1", + "query": "hello", + "use_kg": True, + "retrieval_setting": {"score_threshold": 0.1, "top_k": 3}, + "metadata_condition": {"conditions": [{"name": "author", "comparison_operator": "is", "value": "alice"}], "logic": "and"}, + }, + ) + + monkeypatch.setattr(module, "jsonify", lambda payload: payload) + monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: [{"doc_id": "doc-1"}]) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB())) + monkeypatch.setattr(module, "convert_conditions", lambda cond: cond.get("conditions", [])) + monkeypatch.setattr(module, "meta_filter", lambda *_args, **_kwargs: []) + + retriever = _DummyRetriever() + monkeypatch.setattr(module.settings, "retriever", retriever) + + class _DummyKgRetriever: + async def retrieval(self, *_args, **_kwargs): + return { + "doc_id": "doc-2", + "content_with_weight": "kg-content", + "similarity": 0.9, + "docnm_kwd": "kg-title", + } + + monkeypatch.setattr(module.settings, "kg_retriever", _DummyKgRetriever()) + monkeypatch.setattr(module.DocumentService, "get_by_id", lambda doc_id: (True, SimpleNamespace(meta_fields={"origin": f"meta-{doc_id}"}))) + monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: []) + + res = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert "records" in res, res + assert len(res["records"]) == 2, res + top = res["records"][0] + assert top["title"] == "kg-title", res + assert top["metadata"]["doc_id"] == "doc-2", res + assert "score" in top, res + + +@pytest.mark.p2 +def test_retrieval_kb_not_found(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + _set_request_json(monkeypatch, module, {"knowledge_id": "kb-missing", "query": "hello"}) + monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: []) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (False, None)) + + res = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res["code"] == module.RetCode.NOT_FOUND, res + assert "Knowledgebase not found" in res["message"], res + + +@pytest.mark.p2 +def test_retrieval_not_found_exception_mapping(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + _set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"}) + monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: []) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB())) + monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: []) + + class _BrokenRetriever: + async def retrieval(self, *_args, **_kwargs): + raise RuntimeError("chunk_not_found_error") + + monkeypatch.setattr(module.settings, "retriever", _BrokenRetriever()) + + res = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res["code"] == module.RetCode.NOT_FOUND, res + assert "No chunk found" in res["message"], res + + +@pytest.mark.p2 +def test_retrieval_generic_exception_mapping(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + _set_request_json(monkeypatch, module, {"knowledge_id": "kb-1", "query": "hello"}) + monkeypatch.setattr(module.DocMetadataService, "get_flatted_meta_by_kbs", lambda _kbs: []) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, _DummyKB())) + monkeypatch.setattr(module, "label_question", lambda *_args, **_kwargs: []) + + class _BrokenRetriever: + async def retrieval(self, *_args, **_kwargs): + raise RuntimeError("boom") + + monkeypatch.setattr(module.settings, "retriever", _BrokenRetriever()) + + res = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res["code"] == module.RetCode.SERVER_ERROR, res + assert "boom" in res["message"], res + + +@pytest.mark.p2 +def test_read_retrieval_request_from_get_args(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + method="GET", + args={ + "knowledge_id": "kb-1", + "query": "hello", + "use_kg": "true", + "top_k": "12", + "score_threshold": "0.66", + }, + ), + ) + + req = _run(module._read_retrieval_request()) + assert req["knowledge_id"] == "kb-1", req + assert req["query"] == "hello", req + assert req["use_kg"] is True, req + assert req["retrieval_setting"]["top_k"] == 12, req + assert req["retrieval_setting"]["score_threshold"] == 0.66, req + + +@pytest.mark.p2 +def test_read_retrieval_request_from_post_json(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + payload = {"knowledge_id": "kb-1", "query": "hello"} + monkeypatch.setattr(module, "request", SimpleNamespace(method="POST", args={})) + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(payload)) + + req = _run(module._read_retrieval_request()) + assert req == payload, req + + +@pytest.mark.p2 +def test_retrieval_argument_error_messages(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + + _set_request_json( + monkeypatch, + module, + { + "knowledge_id": "kb-1", + "query": "hello", + "retrieval_setting": {"top_k": "not-int", "score_threshold": "not-float"}, + }, + ) + res = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res["code"] == module.RetCode.ARGUMENT_ERROR, res + assert "invalid or malformed arguments:" in res["message"], res + + _set_request_json(monkeypatch, module, {}) + res_missing = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res_missing["code"] == module.RetCode.ARGUMENT_ERROR, res_missing + assert "required arguments are missing:" in res_missing["message"], res_missing + + _set_request_json(monkeypatch, module, {"knowledge_id": "kb-1"}) + res_missing_query = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res_missing_query["code"] == module.RetCode.ARGUMENT_ERROR, res_missing_query + assert "query" in res_missing_query["message"], res_missing_query + + _set_request_json( + monkeypatch, + module, + {"knowledge_id": "kb-1", "query": "hello", "retrieval_setting": "bad-type"}, + ) + res_wrong_type = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res_wrong_type["code"] == module.RetCode.ARGUMENT_ERROR, res_wrong_type + assert "retrieval_setting must be an object" in res_wrong_type["message"], res_wrong_type diff --git a/test/testcases/restful_api/test_document_raw_routes.py b/test/testcases/restful_api/test_document_raw_routes.py new file mode 100644 index 00000000000..07f65230dff --- /dev/null +++ b/test/testcases/restful_api/test_document_raw_routes.py @@ -0,0 +1,43 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + + +@pytest.mark.p2 +def test_document_image_invalid_id_contract(rest_client_noauth): + res = rest_client_noauth.get("/documents/images/not-a-valid-image-id") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert payload["message"] == "Image not found.", payload + + +@pytest.mark.p2 +def test_document_artifact_requires_auth(rest_client_noauth): + res = rest_client_noauth.get("/documents/artifact/not-an-artifact.txt") + assert res.status_code == 401 + payload = res.json() + assert payload["code"] == 401, payload + + +@pytest.mark.p2 +def test_document_artifact_rejects_unsafe_filename(rest_client): + res = rest_client.get("/documents/artifact/not-an-artifact.exe") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert payload["message"] == "Invalid file type.", payload diff --git a/test/testcases/restful_api/test_documents.py b/test/testcases/restful_api/test_documents.py new file mode 100644 index 00000000000..59575fc0fca --- /dev/null +++ b/test/testcases/restful_api/test_documents.py @@ -0,0 +1,122 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +from utils.file_utils import create_txt_file + + +@pytest.mark.p1 +def test_documents_upload_and_list(rest_client, create_dataset, tmp_path): + dataset_id = create_dataset("dataset_upload_list") + fp = create_txt_file(tmp_path / "upload_and_list.txt") + with fp.open("rb") as file_obj: + res = rest_client.post( + f"/datasets/{dataset_id}/documents", + files=[("file", (fp.name, file_obj))], + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert payload["data"][0]["dataset_id"] == dataset_id, payload + + list_res = rest_client.get(f"/datasets/{dataset_id}/documents") + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert list_payload["data"]["total"] >= 1, list_payload + assert any(doc["name"] == fp.name for doc in list_payload["data"]["docs"]), list_payload + + +@pytest.mark.p2 +def test_documents_upload_missing_file(rest_client, create_dataset): + dataset_id = create_dataset("dataset_upload_missing") + res = rest_client.post(f"/datasets/{dataset_id}/documents") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert payload["message"] == "No file part!", payload + + +@pytest.mark.p2 +def test_documents_update_patch_and_delete(rest_client, create_document): + dataset_id, document_id = create_document("update_target.txt") + + patch_res = rest_client.patch( + f"/datasets/{dataset_id}/documents/{document_id}", + json={"name": "updated_target.txt"}, + ) + assert patch_res.status_code == 200 + patch_payload = patch_res.json() + assert patch_payload["code"] == 0, patch_payload + assert patch_payload["data"]["name"] == "updated_target.txt", patch_payload + + delete_res = rest_client.delete( + f"/datasets/{dataset_id}/documents", + json={"ids": [document_id]}, + ) + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + assert delete_payload["data"]["deleted"] == 1, delete_payload + + list_res = rest_client.get(f"/datasets/{dataset_id}/documents") + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert all(doc["id"] != document_id for doc in list_payload["data"]["docs"]), list_payload + + +@pytest.mark.p2 +def test_documents_parse_and_stop(rest_client, create_document): + dataset_id, document_id = create_document("parse_target.txt") + + parse_res = rest_client.post( + f"/datasets/{dataset_id}/documents/parse", + json={"document_ids": [document_id]}, + ) + assert parse_res.status_code == 200 + parse_payload = parse_res.json() + assert parse_payload["code"] == 0, parse_payload + + stop_res = rest_client.post( + f"/datasets/{dataset_id}/documents/stop", + json={"document_ids": [document_id]}, + ) + assert stop_res.status_code == 200 + stop_payload = stop_res.json() + # Depending on timing this can be immediate stop success or "already completed". + assert stop_payload["code"] in (0, 102), stop_payload + if stop_payload["code"] == 102: + assert "already completed" in stop_payload["message"], stop_payload + + +@pytest.mark.p2 +def test_documents_metadata_update_path(rest_client, create_document): + dataset_id, document_id = create_document("metadata_target.txt") + + res = rest_client.patch( + f"/datasets/{dataset_id}/documents/metadatas", + json={ + "selector": {"document_ids": [document_id]}, + "updates": [{"key": "author", "value": "qa"}], + "deletes": [], + }, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert payload["data"]["matched_docs"] == 1, payload + assert payload["data"]["updated"] >= 1, payload diff --git a/test/testcases/restful_api/test_file_routes_unit.py b/test/testcases/restful_api/test_file_routes_unit.py new file mode 100644 index 00000000000..39246e97a08 --- /dev/null +++ b/test/testcases/restful_api/test_file_routes_unit.py @@ -0,0 +1,632 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import importlib.util +import sys +from enum import Enum +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +class _AwaitableValue: + def __init__(self, value): + self._value = value + + def __await__(self): + async def _co(): + return self._value + + return _co().__await__() + + +class _DummyFiles(dict): + def __init__(self, file_objs=None): + super().__init__() + self._file_objs = list(file_objs or []) + if file_objs is not None: + self["file"] = self._file_objs + + def getlist(self, key): + if key == "file": + return list(self._file_objs) + return [] + + +class _DummyUploadFile: + def __init__(self, filename, blob=b"blob"): + self.filename = filename + self._blob = blob + + def read(self): + return self._blob + + +class _DummyRequest: + def __init__(self, *, content_type="", form=None, files=None, args=None): + self.content_type = content_type + self.form = _AwaitableValue(form or {}) + self.files = _AwaitableValue(files if files is not None else _DummyFiles()) + self.args = args or {} + + +class _DummyResponse: + def __init__(self, data): + self.data = data + self.headers = {} + + +def _run(coro): + return asyncio.run(coro) + + +def _load_file_api_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + quart_mod = ModuleType("quart") + quart_mod.request = _DummyRequest() + + async def _make_response(data): + return _DummyResponse(data) + + quart_mod.make_response = _make_response + monkeypatch.setitem(sys.modules, "quart", quart_mod) + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + + apps_pkg = ModuleType("api.apps") + apps_pkg.__path__ = [str(repo_root / "api" / "apps")] + apps_pkg.login_required = lambda func: func + monkeypatch.setitem(sys.modules, "api.apps", apps_pkg) + api_pkg.apps = apps_pkg + + services_pkg = ModuleType("api.apps.services") + services_pkg.__path__ = [str(repo_root / "api" / "apps" / "services")] + monkeypatch.setitem(sys.modules, "api.apps.services", services_pkg) + apps_pkg.services = services_pkg + + file_api_service_mod = ModuleType("api.apps.services.file_api_service") + + async def _upload_file(_tenant_id, _pf_id, _file_objs): + return True, [{"id": "f1"}] + + async def _create_folder(_tenant_id, _name, _parent_id=None, _file_type=None): + return True, {"id": "folder1"} + + async def _delete_files(_tenant_id, _ids): + return True, True + + async def _move_files(_tenant_id, _src_file_ids, _dest_file_id=None, _new_name=None): + return True, True + + file_api_service_mod.upload_file = _upload_file + file_api_service_mod.create_folder = _create_folder + file_api_service_mod.list_files = lambda _tenant_id, _args: (True, {"files": [], "total": 0}) + file_api_service_mod.delete_files = _delete_files + file_api_service_mod.move_files = _move_files + file_api_service_mod.get_file_content = lambda _tenant_id, _file_id: ( + True, + SimpleNamespace(parent_id="bucket1", location="path1", name="doc.txt", type="doc"), + ) + file_api_service_mod.get_parent_folder = lambda _file_id, user_id=None: (True, {"parent_folder": {"id": "parent1"}}) + file_api_service_mod.get_all_parent_folders = lambda _file_id, user_id=None: (True, {"parent_folders": [{"id": "root"}]}) + monkeypatch.setitem(sys.modules, "api.apps.services.file_api_service", file_api_service_mod) + services_pkg.file_api_service = file_api_service_mod + + db_pkg = ModuleType("api.db") + db_pkg.__path__ = [] + + class _FileType(Enum): + DOC = "doc" + VISUAL = "visual" + + db_pkg.FileType = _FileType + monkeypatch.setitem(sys.modules, "api.db", db_pkg) + api_pkg.db = db_pkg + + file2doc_mod = ModuleType("api.db.services.file2document_service") + file2doc_mod.File2DocumentService = SimpleNamespace(get_storage_address=lambda **_kwargs: ("bucket2", "path2")) + monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2doc_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + api_utils_mod.add_tenant_id_to_kwargs = lambda func: func + api_utils_mod.get_error_argument_result = lambda message: {"code": 400, "data": None, "message": message} + api_utils_mod.get_error_data_result = lambda message: {"code": 500, "data": None, "message": message} + api_utils_mod.get_result = lambda data=None: {"code": 0, "data": data, "message": ""} + api_utils_mod.get_json_result = lambda code=0, message="success", data=None: {"code": code, "data": data, "message": message} + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + validation_mod = ModuleType("api.utils.validation_utils") + validation_mod.CreateFolderReq = object + validation_mod.DeleteFileReq = object + validation_mod.ListFileReq = object + validation_mod.MoveFileReq = object + + async def _validate_json_request(_request, _schema): + return {}, None + + validation_mod.validate_and_parse_json_request = _validate_json_request + validation_mod.validate_and_parse_request_args = lambda _request, _schema: ({}, None) + monkeypatch.setitem(sys.modules, "api.utils.validation_utils", validation_mod) + + web_utils_mod = ModuleType("api.utils.web_utils") + web_utils_mod.CONTENT_TYPE_MAP = {"txt": "text/plain"} + web_utils_mod.apply_safe_file_response_headers = lambda response, content_type, ext: response.headers.update({"content_type": content_type, "ext": ext}) + monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod) + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + common_pkg.settings = SimpleNamespace( + STORAGE_IMPL=SimpleNamespace( + get=lambda *_args, **_kwargs: b"blob", + ) + ) + monkeypatch.setitem(sys.modules, "common", common_pkg) + + misc_utils_mod = ModuleType("common.misc_utils") + + async def thread_pool_exec(func, *args, **kwargs): + return func(*args, **kwargs) + + misc_utils_mod.thread_pool_exec = thread_pool_exec + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod) + + module_path = repo_root / "api" / "apps" / "restful_apis" / "file_api.py" + spec = importlib.util.spec_from_file_location("api.apps.restful_apis.file_api", module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + monkeypatch.setitem(sys.modules, "api.apps.restful_apis.file_api", module) + spec.loader.exec_module(module) + return module + + +@pytest.mark.p2 +def test_create_or_upload_multipart_requires_file(monkeypatch): + module = _load_file_api_module(monkeypatch) + monkeypatch.setattr(module, "request", _DummyRequest(content_type="multipart/form-data", form={}, files=_DummyFiles())) + + res = _run(module.create_or_upload("tenant1")) + assert res["code"] == 400 + assert res["message"] == "No file part!" + + +@pytest.mark.p2 +def test_create_or_upload_uploads_via_new_service(monkeypatch): + module = _load_file_api_module(monkeypatch) + files = _DummyFiles([_DummyUploadFile("a.txt")]) + monkeypatch.setattr(module, "request", _DummyRequest(content_type="multipart/form-data", form={"parent_id": "pf1"}, files=files)) + + seen = {} + + async def _upload_file(tenant_id, pf_id, file_objs): + seen["args"] = (tenant_id, pf_id, [f.filename for f in file_objs]) + return True, [{"id": "f1"}] + + monkeypatch.setattr(module.file_api_service, "upload_file", _upload_file) + res = _run(module.create_or_upload("tenant1")) + + assert seen["args"] == ("tenant1", "pf1", ["a.txt"]) + assert res["code"] == 0 + assert res["data"] == [{"id": "f1"}] + + +@pytest.mark.p2 +def test_create_or_upload_creates_folder_from_json(monkeypatch): + module = _load_file_api_module(monkeypatch) + monkeypatch.setattr(module, "request", _DummyRequest(content_type="application/json")) + + async def _validate(_request, _schema): + return {"name": "folder-a", "parent_id": "pf1", "type": "folder"}, None + + async def _create_folder(tenant_id, name, parent_id=None, file_type=None): + return True, {"tenant_id": tenant_id, "name": name, "parent_id": parent_id, "type": file_type} + + monkeypatch.setattr(module, "validate_and_parse_json_request", _validate) + monkeypatch.setattr(module.file_api_service, "create_folder", _create_folder) + + res = _run(module.create_or_upload("tenant1")) + assert res["code"] == 0 + assert res["data"]["tenant_id"] == "tenant1" + assert res["data"]["name"] == "folder-a" + + +@pytest.mark.p2 +def test_list_files_validation_error(monkeypatch): + module = _load_file_api_module(monkeypatch) + monkeypatch.setattr(module, "validate_and_parse_request_args", lambda _request, _schema: (None, "bad args")) + + res = _run(module.list_files("tenant1")) + assert res["code"] == 400 + assert res["message"] == "bad args" + + +@pytest.mark.p2 +def test_move_uses_new_payload_shape(monkeypatch): + module = _load_file_api_module(monkeypatch) + + async def _validate(_request, _schema): + return {"src_file_ids": ["f1"], "dest_file_id": "pf2"}, None + + seen = {} + + async def _move_files(tenant_id, src_file_ids, dest_file_id=None, new_name=None): + seen["args"] = (tenant_id, src_file_ids, dest_file_id, new_name) + return True, True + + monkeypatch.setattr(module, "validate_and_parse_json_request", _validate) + monkeypatch.setattr(module.file_api_service, "move_files", _move_files) + + res = _run(module.move("tenant1")) + assert seen["args"] == ("tenant1", ["f1"], "pf2", None) + assert res["code"] == 0 + assert res["data"] is True + + +@pytest.mark.p2 +def test_rename_via_move_route(monkeypatch): + module = _load_file_api_module(monkeypatch) + + async def _validate(_request, _schema): + return {"src_file_ids": ["file1"], "new_name": "renamed.txt"}, None + + seen = {} + + async def _move_files(tenant_id, src_file_ids, dest_file_id=None, new_name=None): + seen["args"] = (tenant_id, src_file_ids, dest_file_id, new_name) + return True, True + + monkeypatch.setattr(module, "validate_and_parse_json_request", _validate) + monkeypatch.setattr(module.file_api_service, "move_files", _move_files) + + res = _run(module.move("tenant1")) + assert seen["args"] == ("tenant1", ["file1"], None, "renamed.txt") + assert res["code"] == 0 + assert res["data"] is True + + +@pytest.mark.p2 +def test_download_falls_back_to_document_storage(monkeypatch): + module = _load_file_api_module(monkeypatch) + storage_calls = [] + + def _get(bucket, location): + storage_calls.append((bucket, location)) + return b"" if len(storage_calls) == 1 else b"fallback-blob" + + monkeypatch.setattr(module.settings, "STORAGE_IMPL", SimpleNamespace(get=_get)) + res = _run(module.download("tenant1", "file1")) + + assert storage_calls == [("bucket1", "path1"), ("bucket2", "path2")] + assert res.data == b"fallback-blob" + assert res.headers["content_type"] == "text/plain" + assert res.headers["ext"] == "txt" + + +@pytest.mark.p2 +def test_parent_and_ancestors_use_new_routes(monkeypatch): + module = _load_file_api_module(monkeypatch) + + parent_res = _run(module.parent_folder("tenant1", "file1")) + ancestors_res = _run(module.ancestors("tenant1", "file1")) + + assert parent_res["code"] == 0 + assert parent_res["data"]["parent_folder"]["id"] == "parent1" + assert ancestors_res["code"] == 0 + assert ancestors_res["data"]["parent_folders"][0]["id"] == "root" + +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import functools +from copy import deepcopy + +import pytest + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +class _DummyFile: + def __init__(self, file_id, file_type, *, name="file.txt", location="loc", size=1): + self.id = file_id + self.type = file_type + self.name = name + self.location = location + self.size = size + + +class _FalsyFile(_DummyFile): + def __bool__(self): + return False + + +def _run(coro): + return asyncio.run(coro) + + +def _set_request_json(monkeypatch, module, payload_state): + async def _req_json(): + return deepcopy(payload_state) + + monkeypatch.setattr(module, "get_request_json", _req_json) + + +@pytest.fixture(scope="session") +def auth(): + return "unit-auth" + + +@pytest.fixture(scope="session", autouse=True) +def set_tenant_info(): + return None + + +def _load_file2document_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + + apps_mod = ModuleType("api.apps") + apps_mod.__path__ = [str(repo_root / "api" / "apps")] + apps_mod.current_user = SimpleNamespace(id="user-1") + apps_mod.login_required = lambda func: func + monkeypatch.setitem(sys.modules, "api.apps", apps_mod) + api_pkg.apps = apps_mod + + db_pkg = ModuleType("api.db") + db_pkg.__path__ = [] + + class _FileType(Enum): + FOLDER = "folder" + DOC = "doc" + + db_pkg.FileType = _FileType + monkeypatch.setitem(sys.modules, "api.db", db_pkg) + api_pkg.db = db_pkg + + services_pkg = ModuleType("api.db.services") + services_pkg.__path__ = [] + monkeypatch.setitem(sys.modules, "api.db.services", services_pkg) + + common_pkg = ModuleType("api.common") + common_pkg.__path__ = [] + monkeypatch.setitem(sys.modules, "api.common", common_pkg) + + permission_mod = ModuleType("api.common.check_team_permission") + permission_mod.check_file_team_permission = lambda *_args, **_kwargs: True + permission_mod.check_kb_team_permission = lambda *_args, **_kwargs: True + monkeypatch.setitem(sys.modules, "api.common.check_team_permission", permission_mod) + common_pkg.check_team_permission = permission_mod + + file2document_mod = ModuleType("api.db.services.file2document_service") + + class _StubFile2DocumentService: + @staticmethod + def get_by_file_id(_file_id): + return [] + + @staticmethod + def delete_by_file_id(*_args, **_kwargs): + return None + + @staticmethod + def insert(_payload): + return SimpleNamespace(to_json=lambda: {}) + + file2document_mod.File2DocumentService = _StubFile2DocumentService + monkeypatch.setitem(sys.modules, "api.db.services.file2document_service", file2document_mod) + services_pkg.file2document_service = file2document_mod + + file_service_mod = ModuleType("api.db.services.file_service") + + class _StubFileService: + @staticmethod + def get_by_ids(_file_ids): + return [] + + @staticmethod + def get_all_innermost_file_ids(_file_id, _acc): + return [] + + @staticmethod + def get_by_id(_file_id): + return True, _DummyFile(_file_id, _FileType.DOC.value) + + @staticmethod + def get_parser(_file_type, _file_name, parser_id): + return parser_id + + file_service_mod.FileService = _StubFileService + monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod) + services_pkg.file_service = file_service_mod + + kb_service_mod = ModuleType("api.db.services.knowledgebase_service") + + class _StubKnowledgebaseService: + @staticmethod + def get_by_id(_kb_id): + return False, None + + kb_service_mod.KnowledgebaseService = _StubKnowledgebaseService + monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod) + services_pkg.knowledgebase_service = kb_service_mod + + document_service_mod = ModuleType("api.db.services.document_service") + + class _StubDocumentService: + @staticmethod + def get_by_id(doc_id): + return True, SimpleNamespace(id=doc_id) + + @staticmethod + def get_tenant_id(_doc_id): + return "tenant-1" + + @staticmethod + def remove_document(*_args, **_kwargs): + return True + + @staticmethod + def insert(_payload): + return SimpleNamespace(id="doc-1") + + document_service_mod.DocumentService = _StubDocumentService + monkeypatch.setitem(sys.modules, "api.db.services.document_service", document_service_mod) + services_pkg.document_service = document_service_mod + + api_utils_mod = ModuleType("api.utils.api_utils") + + def get_json_result(data=None, message="", code=0): + return {"code": code, "data": data, "message": message} + + def get_data_error_result(message=""): + return {"code": 102, "data": None, "message": message} + + async def get_request_json(): + return {} + + def server_error_response(err): + return {"code": 500, "data": None, "message": str(err)} + + def validate_request(*_keys): + def _decorator(func): + @functools.wraps(func) + async def _wrapper(*args, **kwargs): + return await func(*args, **kwargs) + + return _wrapper + + return _decorator + + api_utils_mod.get_json_result = get_json_result + api_utils_mod.get_data_error_result = get_data_error_result + api_utils_mod.get_request_json = get_request_json + api_utils_mod.server_error_response = server_error_response + api_utils_mod.validate_request = validate_request + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + misc_utils_mod = ModuleType("common.misc_utils") + misc_utils_mod.get_uuid = lambda: "uuid" + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod) + + constants_mod = ModuleType("common.constants") + + class _RetCode: + ARGUMENT_ERROR = 101 + + constants_mod.RetCode = _RetCode + monkeypatch.setitem(sys.modules, "common.constants", constants_mod) + + module_name = "test_file2document_routes_unit_module" + module_path = repo_root / "api" / "apps" / "restful_apis" / "file2document_api.py" + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + monkeypatch.setitem(sys.modules, module_name, module) + spec.loader.exec_module(module) + return module + + +@pytest.mark.p2 +def test_convert_branch_matrix_unit(monkeypatch): + module = _load_file2document_module(monkeypatch) + req_state = {"kb_ids": ["kb-1"], "file_ids": ["f1"]} + _set_request_json(monkeypatch, module, req_state) + + # Falsy file returns "File not found!" during synchronous validation. + monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_FalsyFile("f1", module.FileType.DOC.value)]) + res = _run(module.convert()) + assert res["code"] == 102 + assert res["message"] == "File not found!" + + # Valid file but invalid kb returns "Can't find this dataset!" during synchronous validation. + monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_DummyFile("f1", module.FileType.DOC.value)]) + res = _run(module.convert()) + assert res["code"] == 102 + assert res["message"] == "Can't find this dataset!" + + kb = SimpleNamespace(id="kb-1", parser_id="naive", pipeline_id="p1", parser_config={}) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _kb_id: (True, kb)) + + # Unauthorized file access is rejected before scheduling background work. + monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: False) + res = _run(module.convert()) + assert res["code"] == 102 + assert res["message"] == "No authorization." + + # Unauthorized dataset access is rejected before scheduling background work. + monkeypatch.setattr(module, "check_file_team_permission", lambda *_args, **_kwargs: True) + monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: False) + res = _run(module.convert()) + assert res["code"] == 102 + assert res["message"] == "No authorization." + + # Valid file and kb schedule background work and return data=True immediately. + monkeypatch.setattr(module, "check_kb_team_permission", lambda *_args, **_kwargs: True) + res = _run(module.convert()) + assert res["code"] == 0 + assert res["data"] is True + + # Folder expansion schedules background work and returns data=True immediately. + req_state["file_ids"] = ["folder-1"] + monkeypatch.setattr(module.FileService, "get_by_ids", lambda _ids: [_DummyFile("folder-1", module.FileType.FOLDER.value, name="folder")]) + monkeypatch.setattr(module.FileService, "get_all_innermost_file_ids", lambda _file_id, _acc: ["inner-1"]) + res = _run(module.convert()) + assert res["code"] == 0 + assert res["data"] is True + + # Exception in file lookup returns 500. + req_state["file_ids"] = ["f1"] + monkeypatch.setattr( + module.FileService, + "get_by_ids", + lambda _ids: (_ for _ in ()).throw(RuntimeError("convert boom")), + ) + res = _run(module.convert()) + assert res["code"] == 500 + assert "convert boom" in res["message"] diff --git a/test/testcases/restful_api/test_langfuse_routes.py b/test/testcases/restful_api/test_langfuse_routes.py new file mode 100644 index 00000000000..deda7fbe3e3 --- /dev/null +++ b/test/testcases/restful_api/test_langfuse_routes.py @@ -0,0 +1,37 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + + +@pytest.mark.p2 +def test_langfuse_api_key_routes_require_auth(rest_client_noauth): + for method in ("get", "post", "put", "delete"): + requester = getattr(rest_client_noauth, method) + kwargs = {"json": {"secret_key": "s", "public_key": "p", "host": "http://example.com"}} if method in {"post", "put"} else {} + res = requester("/langfuse/api-key", **kwargs) + assert res.status_code == 401 + payload = res.json() + assert payload["code"] == 401, (method, payload) + + +@pytest.mark.p2 +def test_langfuse_api_key_missing_required_fields(rest_client): + res = rest_client.post("/langfuse/api-key", json={"secret_key": "", "public_key": "pub", "host": "http://host"}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] in (101, 102), payload + assert "required" in payload["message"].lower() or "missing" in payload["message"].lower(), payload diff --git a/test/testcases/restful_api/test_mcp_routes_unit.py b/test/testcases/restful_api/test_mcp_routes_unit.py new file mode 100644 index 00000000000..ccd628f0fda --- /dev/null +++ b/test/testcases/restful_api/test_mcp_routes_unit.py @@ -0,0 +1,745 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import asyncio +import importlib.util +import inspect +import json +import sys +from functools import wraps +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +class _Args(dict): + def getlist(self, key): + value = self.get(key, []) + if isinstance(value, list): + return value + return [value] + + +class _Field: + def __init__(self, name): + self.name = name + + def __eq__(self, other): + return (self.name, other) + + +class _DummyMCPServer: + id = _Field("id") + tenant_id = _Field("tenant_id") + + def __init__(self, **kwargs): + self.id = kwargs.get("id", "") + self.name = kwargs.get("name", "") + self.url = kwargs.get("url", "") + self.server_type = kwargs.get("server_type", "sse") + self.tenant_id = kwargs.get("tenant_id", "tenant_1") + self.variables = kwargs.get("variables", {}) + self.headers = kwargs.get("headers", {}) + + def to_dict(self): + return { + "id": self.id, + "name": self.name, + "url": self.url, + "server_type": self.server_type, + "tenant_id": self.tenant_id, + "variables": self.variables, + "headers": self.headers, + } + + +class _DummyMCPServerService: + @staticmethod + def get_servers(*_args, **_kwargs): + return [] + + @staticmethod + def get_or_none(*_args, **_kwargs): + return None + + @staticmethod + def get_by_id(*_args, **_kwargs): + return False, None + + @staticmethod + def get_by_name_and_tenant(*_args, **_kwargs): + return False, None + + @staticmethod + def insert(**_kwargs): + return True + + @staticmethod + def filter_update(*_args, **_kwargs): + return True + + @staticmethod + def delete_by_ids(*_args, **_kwargs): + return True + + +class _DummyTenantService: + @staticmethod + def get_by_id(*_args, **_kwargs): + return True, SimpleNamespace(id="tenant_1") + + +class _DummyTool: + def __init__(self, name): + self._name = name + + def model_dump(self): + return {"name": self._name} + + +class _DummyMCPToolCallSession: + def __init__(self, _mcp_server, _variables): + self._tools = [_DummyTool("tool_a"), _DummyTool("tool_b")] + + def get_tools(self, _timeout): + return self._tools + + def tool_call(self, _name, _arguments, _timeout): + return "ok" + + +def _run(coro): + return asyncio.run(coro) + + +def _set_request_json(monkeypatch, module, payload): + async def _request_json(): + return payload + + monkeypatch.setattr(module, "get_request_json", _request_json) + + +@pytest.fixture(scope="session") +def auth(): + return "unit-auth" + + +@pytest.fixture(scope="session", autouse=True) +def set_tenant_info(): + return None + + +def _load_mcp_api(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + quart_mod = ModuleType("quart") + quart_mod.Response = object + quart_mod.request = SimpleNamespace(args=_Args({})) + monkeypatch.setitem(sys.modules, "quart", quart_mod) + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + constants_mod = ModuleType("common.constants") + constants_mod.VALID_MCP_SERVER_TYPES = {"sse", "streamable-http"} + monkeypatch.setitem(sys.modules, "common.constants", constants_mod) + + apps_mod = ModuleType("api.apps") + apps_mod.current_user = SimpleNamespace(id="tenant_1") + apps_mod.login_required = lambda func: func + monkeypatch.setitem(sys.modules, "api.apps", apps_mod) + + db_models_mod = ModuleType("api.db.db_models") + db_models_mod.MCPServer = _DummyMCPServer + monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod) + + mcp_service_mod = ModuleType("api.db.services.mcp_server_service") + mcp_service_mod.MCPServerService = _DummyMCPServerService + monkeypatch.setitem(sys.modules, "api.db.services.mcp_server_service", mcp_service_mod) + + user_service_mod = ModuleType("api.db.services.user_service") + user_service_mod.TenantService = _DummyTenantService + monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) + + mcp_conn_mod = ModuleType("common.mcp_tool_call_conn") + mcp_conn_mod.MCPToolCallSession = _DummyMCPToolCallSession + mcp_conn_mod.close_multiple_mcp_toolcall_sessions = lambda _sessions: None + monkeypatch.setitem(sys.modules, "common.mcp_tool_call_conn", mcp_conn_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + + async def _default_request_json(): + return {} + + def _get_json_result(code=0, message="success", data=None): + return {"code": code, "message": message, "data": data} + + def _get_data_error_result(code=102, message="Sorry! Data missing!"): + return {"code": code, "message": message} + + def _server_error_response(error): + return {"code": 100, "message": repr(error)} + + async def _get_mcp_tools(*_args, **_kwargs): + return {} + + def _validate_request(*_args, **_kwargs): + def _decorator(func): + @wraps(func) + async def _wrapped(*func_args, **func_kwargs): + if inspect.iscoroutinefunction(func): + return await func(*func_args, **func_kwargs) + return func(*func_args, **func_kwargs) + + return _wrapped + + return _decorator + + api_utils_mod.get_request_json = _default_request_json + api_utils_mod.get_json_result = _get_json_result + api_utils_mod.get_data_error_result = _get_data_error_result + api_utils_mod.server_error_response = _server_error_response + api_utils_mod.validate_request = _validate_request + api_utils_mod.get_mcp_tools = _get_mcp_tools + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + web_utils_mod = ModuleType("api.utils.web_utils") + + def _get_float(data, key, default): + try: + return float(data.get(key, default)) + except (TypeError, ValueError): + return default + + def _safe_json_parse(value): + if isinstance(value, (dict, list)): + return value + if value in (None, ""): + return {} + try: + return json.loads(value) + except (TypeError, ValueError): + return {} + + web_utils_mod.get_float = _get_float + web_utils_mod.safe_json_parse = _safe_json_parse + monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod) + + module_name = "test_mcp_api_unit_module" + module_path = repo_root / "api" / "apps" / "restful_apis" / "mcp_api.py" + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + monkeypatch.setitem(sys.modules, module_name, module) + spec.loader.exec_module(module) + return module + + +@pytest.mark.p2 +def test_list_mcp_desc_pagination_and_exception(monkeypatch): + module = _load_mcp_api(monkeypatch) + + monkeypatch.setattr( + module, + "request", + SimpleNamespace(args=_Args({"keywords": "k", "page": "2", "page_size": "1", "orderby": "create_time", "desc": "false"})), + ) + _set_request_json(monkeypatch, module, {"mcp_ids": []}) + monkeypatch.setattr(module.MCPServerService, "get_servers", lambda *_args, **_kwargs: [{"id": "a"}, {"id": "b"}]) + + res = _run(module.list_mcp()) + assert res["code"] == 0 + assert res["data"]["total"] == 2 + assert res["data"]["mcp_servers"] == [{"id": "b"}] + + monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({}))) + _set_request_json(monkeypatch, module, {"mcp_ids": []}) + + def _raise_list(*_args, **_kwargs): + raise RuntimeError("list explode") + + monkeypatch.setattr(module.MCPServerService, "get_servers", _raise_list) + res = _run(module.list_mcp()) + assert res["code"] == 100 + assert "list explode" in res["message"] + + +@pytest.mark.p2 +def test_detail_not_found_success_and_exception(monkeypatch): + module = _load_mcp_api(monkeypatch) + monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({}))) + + monkeypatch.setattr(module.MCPServerService, "get_or_none", lambda **_kwargs: None) + res = module.detail("mcp-1") + assert res["code"] == 102 + assert "Cannot find MCP server mcp-1 for user tenant_1" in res["message"] + + monkeypatch.setattr( + module.MCPServerService, + "get_or_none", + lambda **_kwargs: _DummyMCPServer(id="mcp-1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1"), + ) + res = module.detail("mcp-1") + assert res["code"] == 0 + assert res["data"]["id"] == "mcp-1" + + def _raise_detail(**_kwargs): + raise RuntimeError("detail explode") + + monkeypatch.setattr(module.MCPServerService, "get_or_none", _raise_detail) + res = module.detail("mcp-1") + assert res["code"] == 100 + assert "detail explode" in res["message"] + + +@pytest.mark.p2 +def test_create_validation_guards(monkeypatch): + module = _load_mcp_api(monkeypatch) + + monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None)) + + _set_request_json(monkeypatch, module, {"name": "srv", "url": "http://a", "server_type": "invalid"}) + res = _run(module.create.__wrapped__()) + assert "Unsupported MCP server type" in res["message"] + + _set_request_json(monkeypatch, module, {"name": "", "url": "http://a", "server_type": "sse"}) + res = _run(module.create.__wrapped__()) + assert "Invalid MCP name" in res["message"] + + monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (True, object())) + _set_request_json(monkeypatch, module, {"name": "srv", "url": "http://a", "server_type": "sse"}) + res = _run(module.create.__wrapped__()) + assert "Duplicated MCP server name" in res["message"] + + monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None)) + _set_request_json(monkeypatch, module, {"name": "srv", "url": "", "server_type": "sse"}) + res = _run(module.create.__wrapped__()) + assert "Invalid url" in res["message"] + + +@pytest.mark.p2 +def test_create_service_paths(monkeypatch): + module = _load_mcp_api(monkeypatch) + + base_payload = { + "name": "srv", + "url": "http://server", + "server_type": "sse", + "headers": '{"Authorization": "x"}', + "variables": '{"tools": {"old": 1}, "token": "abc"}', + "timeout": "2.5", + } + + monkeypatch.setattr(module, "get_uuid", lambda: "uuid-create") + monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", lambda **_kwargs: (False, None)) + + _set_request_json(monkeypatch, module, dict(base_payload)) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda *_args, **_kwargs: (False, None)) + res = _run(module.create.__wrapped__()) + assert "Tenant not found" in res["message"] + + _set_request_json(monkeypatch, module, dict(base_payload)) + monkeypatch.setattr(module.TenantService, "get_by_id", lambda *_args, **_kwargs: (True, object())) + + async def _thread_pool_tools_error(_func, _servers, _timeout): + return None, "tools error" + + monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error) + res = _run(module.create.__wrapped__()) + assert res["code"] == 102 + assert "tools error" in res["message"] + + _set_request_json(monkeypatch, module, dict(base_payload)) + + async def _thread_pool_ok(_func, servers, _timeout): + return {servers[0].name: [{"name": "tool_a"}, {"invalid": True}]}, None + + monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok) + monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: False) + res = _run(module.create.__wrapped__()) + assert res["code"] == 102 + assert "Failed to create MCP server" in res["message"] + + _set_request_json(monkeypatch, module, dict(base_payload)) + monkeypatch.setattr(module.MCPServerService, "insert", lambda **_kwargs: True) + res = _run(module.create.__wrapped__()) + assert res["code"] == 0 + assert res["data"]["id"] == "uuid-create" + assert res["data"]["tenant_id"] == "tenant_1" + assert res["data"]["variables"]["tools"] == {"tool_a": {"name": "tool_a"}} + + _set_request_json(monkeypatch, module, dict(base_payload)) + + async def _thread_pool_raises(_func, _servers, _timeout): + raise RuntimeError("create explode") + + monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises) + res = _run(module.create.__wrapped__()) + assert res["code"] == 100 + assert "create explode" in res["message"] + + +@pytest.mark.p2 +def test_update_validation_guards(monkeypatch): + module = _load_mcp_api(monkeypatch) + + existing = _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="tenant_1", variables={}, headers={}) + + _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"}) + monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None)) + res = _run(module.update("mcp-1")) + assert "Cannot find MCP server" in res["message"] + + _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1"}) + monkeypatch.setattr( + module.MCPServerService, + "get_by_id", + lambda _mcp_id: (True, _DummyMCPServer(id="mcp-1", name="srv", url="http://server", server_type="sse", tenant_id="other", variables={}, headers={})), + ) + res = _run(module.update("mcp-1")) + assert "Cannot find MCP server" in res["message"] + + _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "server_type": "invalid"}) + monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing)) + res = _run(module.update("mcp-1")) + assert "Unsupported MCP server type" in res["message"] + + _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "name": "a" * 256}) + res = _run(module.update("mcp-1")) + assert "Invalid MCP name" in res["message"] + + _set_request_json(monkeypatch, module, {"mcp_id": "mcp-1", "url": ""}) + res = _run(module.update("mcp-1")) + assert "Invalid url" in res["message"] + + +@pytest.mark.p2 +def test_update_service_paths(monkeypatch): + module = _load_mcp_api(monkeypatch) + + existing = _DummyMCPServer( + id="mcp-1", + name="srv", + url="http://server", + server_type="sse", + tenant_id="tenant_1", + variables={"tools": {"old": {"enabled": True}}, "token": "abc"}, + headers={"Authorization": "old"}, + ) + updated = _DummyMCPServer( + id="mcp-1", + name="srv-new", + url="http://server-new", + server_type="sse", + tenant_id="tenant_1", + variables={"tools": {"tool_a": {"name": "tool_a"}}}, + headers={"Authorization": "new"}, + ) + + base_payload = { + "mcp_id": "mcp-1", + "name": "srv-new", + "url": "http://server-new", + "server_type": "sse", + "headers": '{"Authorization": "new"}', + "variables": '{"tools": {"ignore": 1}, "token": "new"}', + "timeout": "3.0", + } + + _set_request_json(monkeypatch, module, dict(base_payload)) + monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing)) + + async def _thread_pool_tools_error(_func, _servers, _timeout): + return None, "update tools error" + + monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_tools_error) + res = _run(module.update("mcp-1")) + assert res["code"] == 102 + assert "update tools error" in res["message"] + + _set_request_json(monkeypatch, module, dict(base_payload)) + + async def _thread_pool_ok(_func, servers, _timeout): + return {servers[0].name: [{"name": "tool_a"}, {"bad": True}]}, None + + monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_ok) + monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: False) + res = _run(module.update("mcp-1")) + assert "Failed to updated MCP server" in res["message"] + + _set_request_json(monkeypatch, module, dict(base_payload)) + monkeypatch.setattr(module.MCPServerService, "filter_update", lambda *_args, **_kwargs: True) + + def _get_by_id_fetch_fail(_mcp_id): + if _get_by_id_fetch_fail.calls == 0: + _get_by_id_fetch_fail.calls += 1 + return True, existing + return False, None + + _get_by_id_fetch_fail.calls = 0 + monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_fetch_fail) + res = _run(module.update("mcp-1")) + assert "Failed to fetch updated MCP server" in res["message"] + + _set_request_json(monkeypatch, module, dict(base_payload)) + + def _get_by_id_success(_mcp_id): + if _get_by_id_success.calls == 0: + _get_by_id_success.calls += 1 + return True, existing + return True, updated + + _get_by_id_success.calls = 0 + monkeypatch.setattr(module.MCPServerService, "get_by_id", _get_by_id_success) + res = _run(module.update("mcp-1")) + assert res["code"] == 0 + assert res["data"]["id"] == "mcp-1" + + _set_request_json(monkeypatch, module, dict(base_payload)) + monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, existing)) + + async def _thread_pool_raises(_func, _servers, _timeout): + raise RuntimeError("update explode") + + monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_raises) + res = _run(module.update("mcp-1")) + assert res["code"] == 100 + assert "update explode" in res["message"] + + +@pytest.mark.p2 +def test_rm_failure_success_and_exception(monkeypatch): + module = _load_mcp_api(monkeypatch) + server = _DummyMCPServer(id="id1", name="srv", url="http://a", server_type="sse", tenant_id="tenant_1", variables={}) + monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (True, server)) + + _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]}) + monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: False) + res = _run(module.rm("id1")) + assert "Failed to delete MCP servers" in res["message"] + + _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]}) + monkeypatch.setattr(module.MCPServerService, "delete_by_ids", lambda _ids: True) + res = _run(module.rm("id1")) + assert res["code"] == 0 + assert res["data"] is True + + _set_request_json(monkeypatch, module, {"mcp_ids": ["a", "b"]}) + + def _raise_rm(_ids): + raise RuntimeError("rm explode") + + monkeypatch.setattr(module.MCPServerService, "delete_by_ids", _raise_rm) + res = _run(module.rm("id1")) + assert res["code"] == 100 + assert "rm explode" in res["message"] + + +@pytest.mark.p2 +def test_import_multiple_missing_servers_and_exception(monkeypatch): + module = _load_mcp_api(monkeypatch) + + _set_request_json(monkeypatch, module, {"mcpServers": {}}) + res = _run(module.import_multiple.__wrapped__()) + assert "No MCP servers provided" in res["message"] + + _set_request_json(monkeypatch, module, {"mcpServers": {"srv": {"type": "sse", "url": "http://x"}}, "timeout": "1"}) + + def _raise_import(**_kwargs): + raise RuntimeError("import explode") + + monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _raise_import) + res = _run(module.import_multiple.__wrapped__()) + assert res["code"] == 100 + assert "import explode" in res["message"] + + +@pytest.mark.p2 +def test_import_multiple_mixed_results(monkeypatch): + module = _load_mcp_api(monkeypatch) + + payload = { + "mcpServers": { + "missing_fields": {"type": "sse"}, + "": {"type": "sse", "url": "http://empty"}, + "dup": {"type": "sse", "url": "http://dup", "authorization_token": "dup-token"}, + "tool_err": {"type": "sse", "url": "http://err"}, + "insert_fail": {"type": "sse", "url": "http://fail"}, + }, + "timeout": "3", + } + _set_request_json(monkeypatch, module, payload) + + monkeypatch.setattr(module, "get_uuid", lambda: "uuid-import") + + def _get_by_name_and_tenant(name, tenant_id): + if name == "dup" and not _get_by_name_and_tenant.first_dup_seen: + _get_by_name_and_tenant.first_dup_seen = True + return True, object() + return False, None + + _get_by_name_and_tenant.first_dup_seen = False + monkeypatch.setattr(module.MCPServerService, "get_by_name_and_tenant", _get_by_name_and_tenant) + + async def _thread_pool_exec(func, servers, _timeout): + mcp_server = servers[0] + if mcp_server.name == "tool_err": + return None, "tool call failed" + return {mcp_server.name: [{"name": "tool_a"}, {"invalid": True}]}, None + + monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec) + + def _insert(**kwargs): + return kwargs["name"] != "insert_fail" + + monkeypatch.setattr(module.MCPServerService, "insert", _insert) + + res = _run(module.import_multiple.__wrapped__()) + assert res["code"] == 0 + + results = {item["server"]: item for item in res["data"]["results"]} + assert results["missing_fields"]["success"] is False + assert "Missing required fields" in results["missing_fields"]["message"] + assert results[""]["success"] is False + assert "Invalid MCP name" in results[""]["message"] + assert results["tool_err"]["success"] is False + assert "tool call failed" in results["tool_err"]["message"] + assert results["insert_fail"]["success"] is False + assert "Failed to create MCP server" in results["insert_fail"]["message"] + assert results["dup"]["success"] is True + assert results["dup"]["new_name"] == "dup_0" + assert "Renamed from 'dup' to 'dup_0' avoid duplication" == results["dup"]["message"] + + +@pytest.mark.p2 +def test_detail_download_success_and_exception(monkeypatch): + module = _load_mcp_api(monkeypatch) + monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({"mode": "download"}))) + + monkeypatch.setattr( + module.MCPServerService, + "get_by_id", + lambda _mcp_id: ( + True, + _DummyMCPServer( + id="id1", + name="srv-one", + url="http://one", + server_type="sse", + tenant_id="tenant_1", + variables={"authorization_token": "tok", "tools": {"tool_a": {"enabled": True}}}, + ), + ), + ) + res = module.detail("id1") + assert res["code"] == 0 + assert list(res["data"]["mcpServers"].keys()) == ["srv-one"] + + monkeypatch.setattr(module.MCPServerService, "get_by_id", lambda _mcp_id: (False, None)) + res = module.detail("missing") + assert res["code"] == 102 + assert "Cannot find MCP server missing for user tenant_1" in res["message"] + + monkeypatch.setattr( + module.MCPServerService, + "get_by_id", + lambda _mcp_id: ( + True, + _DummyMCPServer( + id="id2", + name="srv-two", + url="http://two", + server_type="sse", + tenant_id="other", + variables={}, + ), + ), + ) + res = module.detail("id2") + assert res["code"] == 102 + assert "Cannot find MCP server id2 for user tenant_1" in res["message"] + + def _raise_export(_mcp_id): + raise RuntimeError("export explode") + + monkeypatch.setattr(module.MCPServerService, "get_by_id", _raise_export) + res = module.detail("id1") + assert res["code"] == 100 + assert "export explode" in res["message"] + + +@pytest.mark.p2 +def test_test_mcp_route_matrix_unit(monkeypatch): + module = _load_mcp_api(monkeypatch) + + _set_request_json(monkeypatch, module, {"url": "", "server_type": "sse"}) + res = _run(module.test_mcp("mcp-1")) + assert "Invalid MCP url" in res["message"] + + _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "invalid"}) + res = _run(module.test_mcp("mcp-1")) + assert "Unsupported MCP server type" in res["message"] + + close_calls = [] + + async def _thread_pool_exec_inner_error(func, *args): + if func is module.close_multiple_mcp_toolcall_sessions: + close_calls.append(args[0]) + return None + if getattr(func, "__name__", "") == "get_tools": + raise RuntimeError("get tools explode") + return func(*args) + + monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_inner_error) + _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"}) + res = _run(module.test_mcp("mcp-1")) + assert res["code"] == 102 + assert "Test MCP error: get tools explode" in res["message"] + assert close_calls and len(close_calls[-1]) == 1 + + close_calls_success = [] + + async def _thread_pool_exec_success(func, *args): + if func is module.close_multiple_mcp_toolcall_sessions: + close_calls_success.append(args[0]) + return None + return func(*args) + + monkeypatch.setattr(module, "thread_pool_exec", _thread_pool_exec_success) + _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"}) + res = _run(module.test_mcp("mcp-1")) + assert res["code"] == 0 + assert res["data"][0]["name"] == "tool_a" + assert all(tool["enabled"] is True for tool in res["data"]) + assert close_calls_success and len(close_calls_success[-1]) == 1 + + def _raise_session(*_args, **_kwargs): + raise RuntimeError("session explode") + + monkeypatch.setattr(module, "MCPToolCallSession", _raise_session) + _set_request_json(monkeypatch, module, {"url": "http://a", "server_type": "sse"}) + res = _run(module.test_mcp("mcp-1")) + assert res["code"] == 100 + assert "session explode" in res["message"] diff --git a/test/testcases/restful_api/test_memories_messages.py b/test/testcases/restful_api/test_memories_messages.py new file mode 100644 index 00000000000..12fcdc0df27 --- /dev/null +++ b/test/testcases/restful_api/test_memories_messages.py @@ -0,0 +1,210 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import time +import uuid + +import pytest + + +@pytest.fixture +def memory_cleanup(rest_client): + created_ids: list[str] = [] + + def _cleanup(): + cleanup_errors = [] + for memory_id in created_ids: + delete_res = rest_client.delete(f"/memories/{memory_id}") + if delete_res.status_code != 200: + cleanup_errors.append((memory_id, delete_res.status_code, delete_res.text)) + continue + delete_payload = delete_res.json() + if delete_payload["code"] not in (0, 404): + cleanup_errors.append((memory_id, delete_res.status_code, delete_payload)) + assert not cleanup_errors, f"Memory cleanup failed: {cleanup_errors}" + + yield created_ids + _cleanup() + + +@pytest.fixture +def create_memory_resource(rest_client, memory_cleanup): + def _create(name_prefix: str = "restful_memory") -> str: + payload = { + "name": f"{name_prefix}_{uuid.uuid4().hex[:8]}", + "memory_type": ["raw"], + "embd_id": "BAAI/bge-small-en-v1.5@Builtin", + "llm_id": "glm-4-flash@ZHIPU-AI", + } + res = rest_client.post("/memories", json=payload) + assert res.status_code == 200 + res_payload = res.json() + assert res_payload["code"] == 0, res_payload + memory_id = res_payload["data"]["id"] + memory_cleanup.append(memory_id) + return memory_id + + yield _create + + +def _add_message(rest_client, memory_id: str, user_input: str, agent_response: str) -> None: + add_res = rest_client.post( + "/messages", + json={ + "memory_id": [memory_id], + "agent_id": uuid.uuid4().hex, + "session_id": uuid.uuid4().hex, + "user_id": uuid.uuid4().hex, + "user_input": user_input, + "agent_response": agent_response, + }, + ) + assert add_res.status_code == 200 + add_payload = add_res.json() + assert add_payload["code"] == 0, add_payload + + +def _wait_for_memory_messages(rest_client, memory_id: str, timeout: float = 10, interval: float = 0.2) -> list[dict]: + deadline = time.time() + timeout + last_payload = None + while time.time() < deadline: + res = rest_client.get(f"/memories/{memory_id}") + if res.status_code == 200: + payload = res.json() + last_payload = payload + if payload.get("code") == 0: + message_list = payload.get("data", {}).get("messages", {}).get("message_list", []) + if message_list: + return message_list + time.sleep(interval) + pytest.fail(f"Timed out waiting for memory messages: {last_payload}") + + +@pytest.mark.p1 +def test_memory_crud_cycle(rest_client, create_memory_resource): + memory_id = create_memory_resource("restful_memory_crud") + + list_res = rest_client.get("/memories") + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert any(item["id"] == memory_id for item in list_payload["data"]["memory_list"]), list_payload + + config_res = rest_client.get(f"/memories/{memory_id}/config") + assert config_res.status_code == 200 + config_payload = config_res.json() + assert config_payload["code"] == 0, config_payload + assert config_payload["data"]["id"] == memory_id, config_payload + + update_res = rest_client.put( + f"/memories/{memory_id}", + json={"name": f"updated_{uuid.uuid4().hex[:6]}", "permissions": "me"}, + ) + assert update_res.status_code == 200 + update_payload = update_res.json() + assert update_payload["code"] == 0, update_payload + + delete_res = rest_client.delete(f"/memories/{memory_id}") + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + + +@pytest.mark.p2 +def test_memory_create_missing_required_fields(rest_client): + res = rest_client.post("/memories", json={"name": "missing_models", "memory_type": ["raw"]}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + + +@pytest.mark.p1 +def test_messages_add_list_recent_content_update_forget(rest_client, create_memory_resource): + memory_id = create_memory_resource("restful_message_memory") + _add_message( + rest_client, + memory_id, + user_input="what is coriander?", + agent_response="coriander can refer to leaves or seeds", + ) + + message_list = _wait_for_memory_messages(rest_client, memory_id) + + message_id = message_list[0]["message_id"] + + recent_res = rest_client.get("/messages", params={"memory_id": memory_id, "limit": 10}) + assert recent_res.status_code == 200 + recent_payload = recent_res.json() + assert recent_payload["code"] == 0, recent_payload + assert any(item["message_id"] == message_id for item in recent_payload["data"]), recent_payload + + content_res = rest_client.get(f"/messages/{memory_id}:{message_id}/content") + assert content_res.status_code == 200 + content_payload = content_res.json() + assert content_payload["code"] == 0, content_payload + assert content_payload["data"]["content"], content_payload + + update_res = rest_client.put(f"/messages/{memory_id}:{message_id}", json={"status": False}) + assert update_res.status_code == 200 + update_payload = update_res.json() + assert update_payload["code"] == 0, update_payload + + forget_res = rest_client.delete(f"/messages/{memory_id}:{message_id}") + assert forget_res.status_code == 200 + forget_payload = forget_res.json() + assert forget_payload["code"] == 0, forget_payload + + +@pytest.mark.p2 +def test_message_status_validation_requires_boolean(rest_client, create_memory_resource): + memory_id = create_memory_resource("restful_message_status_validation") + _add_message(rest_client, memory_id, user_input="hello", agent_response="hello") + + message_id = _wait_for_memory_messages(rest_client, memory_id)[0]["message_id"] + + invalid_update = rest_client.put(f"/messages/{memory_id}:{message_id}", json={"status": "false"}) + assert invalid_update.status_code == 200 + invalid_payload = invalid_update.json() + assert invalid_payload["code"] == 101, invalid_payload + assert "Status must be a boolean." in invalid_payload["message"], invalid_payload + + +@pytest.mark.p2 +def test_messages_recent_requires_memory_ids(rest_client): + res = rest_client.get("/messages") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert "memory_ids is required" in payload["message"], payload + + +@pytest.mark.p2 +def test_message_search_route_contract(rest_client, create_memory_resource): + memory_id = create_memory_resource("restful_message_search") + _add_message( + rest_client, + memory_id, + user_input="what is pineapple?", + agent_response="pineapple is a tropical fruit", + ) + + _wait_for_memory_messages(rest_client, memory_id) + + res = rest_client.get("/messages/search", params={"memory_id": memory_id, "query": "pineapple", "top_n": 3}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert isinstance(payload["data"], list), payload diff --git a/test/testcases/restful_api/test_memory_messages.py b/test/testcases/restful_api/test_memory_messages.py new file mode 100644 index 00000000000..dcf5a3704f0 --- /dev/null +++ b/test/testcases/restful_api/test_memory_messages.py @@ -0,0 +1,165 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import uuid + +import pytest + + +def _memory_payload(name: str) -> dict: + return { + "name": name, + "memory_type": ["raw"], + "embd_id": "BAAI/bge-small-en-v1.5@Builtin", + "llm_id": "glm-4-flash@ZHIPU-AI", + } + + +def _create_memory(rest_client, name: str) -> dict: + res = rest_client.post("/memories", json=_memory_payload(name)) + assert res.status_code == 200 + payload = res.json() + if payload["code"] == 0: + return payload["data"] + + pytest.fail(f"Failed to create memory: {payload}") + + +@pytest.fixture +def memory_resource(rest_client): + memory = _create_memory(rest_client, f"restful_memory_{uuid.uuid4().hex[:8]}") + memory_id = memory["id"] + try: + yield memory + finally: + delete_res = rest_client.delete(f"/memories/{memory_id}") + assert delete_res.status_code == 200, delete_res.text + delete_payload = delete_res.json() + assert delete_payload["code"] in (0, 404), delete_payload + + +@pytest.mark.p2 +def test_memory_and_message_routes_require_auth(rest_client_noauth): + memory_res = rest_client_noauth.get("/memories") + assert memory_res.status_code == 401 + memory_payload = memory_res.json() + assert memory_payload["code"] == 401, memory_payload + + msg_list_res = rest_client_noauth.get("/messages") + assert msg_list_res.status_code == 401 + msg_list_payload = msg_list_res.json() + assert msg_list_payload["code"] == 401, msg_list_payload + + msg_search_res = rest_client_noauth.get("/messages/search") + assert msg_search_res.status_code == 401 + msg_search_payload = msg_search_res.json() + assert msg_search_payload["code"] == 401, msg_search_payload + + +@pytest.mark.p2 +def test_memory_crud_and_config(rest_client): + memory = _create_memory(rest_client, f"restful_memory_crud_{uuid.uuid4().hex[:8]}") + memory_id = memory["id"] + try: + config_res = rest_client.get(f"/memories/{memory_id}/config") + assert config_res.status_code == 200 + config_payload = config_res.json() + assert config_payload["code"] == 0, config_payload + assert config_payload["data"]["id"] == memory_id, config_payload + + list_res = rest_client.get("/memories", params={"keywords": memory["name"]}) + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert any(item["id"] == memory_id for item in list_payload["data"]["memory_list"]), list_payload + + update_res = rest_client.put(f"/memories/{memory_id}", json={"name": "restful_memory_updated"}) + assert update_res.status_code == 200 + update_payload = update_res.json() + assert update_payload["code"] == 0, update_payload + finally: + delete_res = rest_client.delete(f"/memories/{memory_id}") + assert delete_res.status_code == 200, delete_res.text + delete_payload = delete_res.json() + assert delete_payload["code"] in (0, 404), delete_payload + + +@pytest.mark.p2 +def test_memory_update_invalid_name(rest_client, memory_resource): + memory_id = memory_resource["id"] + res = rest_client.put(f"/memories/{memory_id}", json={"name": " "}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert "cannot be empty" in payload["message"], payload + + +@pytest.mark.p2 +def test_messages_list_and_search_validation_contracts(rest_client, memory_resource): + memory_id = memory_resource["id"] + + list_res = rest_client.get("/messages", params={"memory_id": memory_id, "limit": 10}) + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert isinstance(list_payload["data"], list), list_payload + + missing_memory_res = rest_client.get("/messages") + assert missing_memory_res.status_code == 200 + missing_memory_payload = missing_memory_res.json() + assert missing_memory_payload["code"] == 101, missing_memory_payload + assert "memory_ids is required" in missing_memory_payload["message"], missing_memory_payload + + search_res = rest_client.get("/messages/search", params={"memory_id": memory_id, "query": "coriander"}) + assert search_res.status_code == 200 + search_payload = search_res.json() + assert search_payload["code"] == 0, search_payload + assert isinstance(search_payload["data"], list), search_payload + + search_no_memory = rest_client.get("/messages/search", params={"query": "coriander"}) + assert search_no_memory.status_code == 200 + search_no_memory_payload = search_no_memory.json() + assert search_no_memory_payload["code"] == 0, search_no_memory_payload + assert isinstance(search_no_memory_payload["data"], list), search_no_memory_payload + + +@pytest.mark.p2 +def test_message_update_forget_and_content_error_contracts(rest_client, memory_resource): + memory_id = memory_resource["id"] + + invalid_status_res = rest_client.put( + f"/messages/{memory_id}:1", + json={"status": "false"}, + ) + assert invalid_status_res.status_code == 200 + invalid_status_payload = invalid_status_res.json() + assert invalid_status_payload["code"] == 101, invalid_status_payload + assert "Status must be a boolean" in invalid_status_payload["message"], invalid_status_payload + + missing_content_res = rest_client.get(f"/messages/{memory_id}:999999/content") + assert missing_content_res.status_code == 200 + missing_content_payload = missing_content_res.json() + assert missing_content_payload["code"] == 404, missing_content_payload + + invalid_memory_forget = rest_client.delete("/messages/missing_memory_id:1") + assert invalid_memory_forget.status_code == 200 + invalid_memory_forget_payload = invalid_memory_forget.json() + assert invalid_memory_forget_payload["code"] == 404, invalid_memory_forget_payload + + invalid_memory_update = rest_client.put("/messages/missing_memory_id:1", json={"status": False}) + assert invalid_memory_update.status_code == 200 + invalid_memory_update_payload = invalid_memory_update.json() + assert invalid_memory_update_payload["code"] == 404, invalid_memory_update_payload diff --git a/test/testcases/restful_api/test_openai_compatible.py b/test/testcases/restful_api/test_openai_compatible.py new file mode 100644 index 00000000000..49e2c55ca59 --- /dev/null +++ b/test/testcases/restful_api/test_openai_compatible.py @@ -0,0 +1,214 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json + +import pytest + + +def _sse_events(response_text: str) -> list[str]: + return [line[5:] for line in response_text.splitlines() if line.startswith("data:")] + + +@pytest.mark.p2 +@pytest.mark.parametrize( + "payload, expected_message", + [ + ( + { + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "extra_body": "invalid_extra_body", + }, + "extra_body must be an object.", + ), + ( + { + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "extra_body": {"reference_metadata": "invalid_reference_metadata"}, + }, + "reference_metadata must be an object.", + ), + ( + { + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "extra_body": {"reference_metadata": {"fields": "author"}}, + }, + "reference_metadata.fields must be an array.", + ), + ( + { + "model": "model", + "messages": [], + }, + "You have to provide messages.", + ), + ( + { + "model": "model", + "messages": [{"role": "assistant", "content": "hello"}], + }, + "The last content of this conversation is not from user.", + ), + ], +) +def test_openai_compatible_validation_payloads(rest_client, create_chat, payload, expected_message): + chat_id = create_chat("restful_openai_validation_chat") + res = rest_client.post(f"/openai/{chat_id}/chat/completions", json=payload) + assert res.status_code == 200 + data = res.json() + assert data["code"] != 0, data + assert expected_message in data.get("message", ""), data + + +@pytest.mark.p2 +def test_openai_compatible_metadata_condition_requires_object(rest_client, create_chat): + chat_id = create_chat("restful_openai_metadata_condition_chat") + res = rest_client.post( + f"/openai/{chat_id}/chat/completions", + json={ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "extra_body": {"metadata_condition": "invalid"}, + }, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert "metadata_condition must be an object." in payload["message"], payload + + +@pytest.mark.p2 +def test_openai_compatible_invalid_chat(rest_client): + res = rest_client.post( + "/openai/invalid_chat_id/chat/completions", + json={ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + }, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] != 0, payload + assert "don't own the chat" in payload["message"], payload + + +@pytest.mark.p2 +def test_openai_compatible_nonstream_shape(rest_client, create_chat): + chat_id = create_chat("restful_openai_nonstream_chat") + res = rest_client.post( + f"/openai/{chat_id}/chat/completions", + json={ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + }, + timeout=60, + ) + assert res.status_code == 200 + payload = res.json() + + assert payload["object"] == "chat.completion", payload + assert isinstance(payload["choices"], list) and payload["choices"], payload + first_choice = payload["choices"][0] + assert first_choice.get("finish_reason") == "stop", payload + assert first_choice.get("message", {}).get("role") == "assistant", payload + assert "content" in first_choice.get("message", {}), payload + + usage = payload.get("usage", {}) + assert "prompt_tokens" in usage, usage + assert "completion_tokens" in usage, usage + assert "total_tokens" in usage, usage + assert usage["prompt_tokens"] > 0, usage + assert usage["completion_tokens"] > 0, usage + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"], usage + + +@pytest.mark.p2 +def test_openai_compatible_nonstream_with_reference_output_shape(rest_client, create_chat): + chat_id = create_chat("restful_openai_reference_chat") + res = rest_client.post( + f"/openai/{chat_id}/chat/completions", + json={ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + "extra_body": { + "reference": True, + "reference_metadata": {"include": True, "fields": ["author"]}, + }, + }, + timeout=60, + ) + assert res.status_code == 200 + payload = res.json() + choice_msg = payload["choices"][0]["message"] + assert "reference" in choice_msg, payload + assert isinstance(choice_msg["reference"], list), payload + + +@pytest.mark.p2 +def test_openai_compatible_stream_shape_and_done_semantics(rest_client, create_chat): + chat_id = create_chat("restful_openai_stream_chat") + res = rest_client.post( + f"/openai/{chat_id}/chat/completions", + json={ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": True, + "extra_body": {"reference": True}, + }, + timeout=60, + ) + assert res.status_code == 200 + content_type = res.headers.get("Content-Type", "") + assert "text/event-stream" in content_type, content_type + + events = _sse_events(res.text) + assert events, res.text + assert events[-1].strip() == "[DONE]", events[-1] + + json_events = [json.loads(evt) for evt in events if evt.strip() != "[DONE]"] + assert json_events, events + assert any(evt.get("object") == "chat.completion.chunk" for evt in json_events), json_events + assert any(evt.get("choices", [{}])[0].get("finish_reason") == "stop" for evt in json_events), json_events + + +@pytest.mark.p2 +def test_openai_compatible_reference_metadata_fields_filter_accepts_array(rest_client, create_chat): + chat_id = create_chat("restful_openai_reference_fields_array_chat") + res = rest_client.post( + f"/openai/{chat_id}/chat/completions", + json={ + "model": "model", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + "extra_body": { + "reference": True, + "reference_metadata": {"include": True, "fields": ["author", "year"]}, + }, + }, + timeout=60, + ) + assert res.status_code == 200 + payload = res.json() + assert payload.get("choices"), payload + choice_msg = payload["choices"][0]["message"] + assert "reference" in choice_msg, payload + assert isinstance(choice_msg["reference"], list), payload diff --git a/test/testcases/restful_api/test_plugin_tools.py b/test/testcases/restful_api/test_plugin_tools.py new file mode 100644 index 00000000000..c151394c29e --- /dev/null +++ b/test/testcases/restful_api/test_plugin_tools.py @@ -0,0 +1,92 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType + +import pytest + + +@pytest.mark.p2 +def test_plugin_tools_requires_auth(rest_client_noauth): + res = rest_client_noauth.get("/plugin/tools") + assert res.status_code == 401 + payload = res.json() + assert payload["code"] == 401, payload + + +@pytest.mark.p2 +def test_plugin_tools_contract(rest_client): + res = rest_client.get("/plugin/tools") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert isinstance(payload["data"], list), payload + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +def _load_plugin_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + stub_apps = ModuleType("api.apps") + stub_apps.login_required = lambda func: func + monkeypatch.setitem(sys.modules, "api.apps", stub_apps) + + stub_plugin = ModuleType("agent.plugin") + + class _StubGlobalPluginManager: + @staticmethod + def get_llm_tools(): + return [] + + stub_plugin.GlobalPluginManager = _StubGlobalPluginManager + monkeypatch.setitem(sys.modules, "agent.plugin", stub_plugin) + + module_path = repo_root / "api" / "apps" / "restful_apis" / "plugin_api.py" + spec = importlib.util.spec_from_file_location("restful_plugin_api_unit", module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + spec.loader.exec_module(module) + return module + + +@pytest.mark.p2 +def test_plugin_tools_metadata_shape_unit(monkeypatch): + module = _load_plugin_module(monkeypatch) + + class _DummyTool: + def get_metadata(self): + return {"name": "dummy", "description": "test"} + + monkeypatch.setattr(module.GlobalPluginManager, "get_llm_tools", staticmethod(lambda: [_DummyTool()])) + res = module.llm_tools() + assert res["code"] == 0 + assert isinstance(res["data"], list) + assert res["data"][0]["name"] == "dummy" + assert res["data"][0]["description"] == "test" diff --git a/test/testcases/restful_api/test_retrieval.py b/test/testcases/restful_api/test_retrieval.py new file mode 100644 index 00000000000..5f6531a8c3c --- /dev/null +++ b/test/testcases/restful_api/test_retrieval.py @@ -0,0 +1,375 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from concurrent.futures import ThreadPoolExecutor +import pytest +import requests +from test.testcases.configs import HOST_ADDRESS, INVALID_API_TOKEN, VERSION +from test.testcases.restful_api.helpers.client import RestClient +from test.testcases.utils import wait_for + + +@pytest.mark.p1 +def test_dataset_search_rest_endpoint(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + res = rest_client.post( + f"/datasets/{dataset_id}/search", + json={"question": "test TXT file", "top_k": 5}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert "chunks" in payload["data"], payload + + +@pytest.mark.p2 +def test_multi_dataset_search_rest_endpoint(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + res = rest_client.post( + "/datasets/search", + json={"dataset_ids": [dataset_id], "question": "test TXT file", "top_k": 5}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert "chunks" in payload["data"], payload + + +@pytest.mark.p2 +def test_multi_dataset_search_with_metadata_filter(rest_client, ensure_parsed_document): + dataset_id, document_id = ensure_parsed_document() + meta_res = rest_client.patch( + f"/datasets/{dataset_id}/documents/metadatas", + json={ + "selector": {"document_ids": [document_id]}, + "updates": [{"key": "author", "value": "qa_batch2"}], + "deletes": [], + }, + ) + assert meta_res.status_code == 200 + meta_payload = meta_res.json() + assert meta_payload["code"] == 0, meta_payload + + res = rest_client.post( + "/datasets/search", + json={ + "dataset_ids": [dataset_id], + "question": "test TXT file", + "meta_data_filter": { + "method": "manual", + "logic": "and", + "manual": [{"key": "author", "op": "=", "value": "qa_batch2"}], + }, + }, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert "chunks" in payload["data"], payload + + +@pytest.mark.p2 +def test_retrieval_compatibility_endpoint(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + # /api/v1/retrieval is SDK compatibility endpoint from api/apps/sdk/doc.py. + res = rest_client.post( + "/retrieval", + json={"dataset_ids": [dataset_id], "question": "test TXT file", "top_k": 5}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert "chunks" in payload["data"], payload + + +@pytest.mark.p2 +def test_retrieval_compatibility_requires_dataset_ids(rest_client): + res = rest_client.post("/retrieval", json={"question": "test"}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert payload["message"] == "`dataset_ids` is required.", payload + + +@pytest.mark.p2 +def test_retrieval_compatibility_requires_auth(rest_client_noauth): + res = rest_client_noauth.post("/retrieval", json={"question": "test", "dataset_ids": ["x"]}) + assert res.status_code == 401 + payload = res.json() + # token_required preserves legacy payload code/message while returning HTTP 401. + assert payload["code"] == 0, payload + assert payload["message"] == "`Authorization` can't be empty", payload + + +@wait_for(20, 1, "Retrieval indexing timeout in RESTful batch 10 tests") +def _retrieval_has_question(rest_client, dataset_id, question): + res = rest_client.post("/retrieval", json={"question": question, "dataset_ids": [dataset_id]}) + if res.status_code != 200: + return False + payload = res.json() + if payload["code"] != 0: + return False + return len(payload["data"]["chunks"]) > 0 + + +@wait_for(20, 1, "Retrieval indexing timeout waiting for chunk presence in RESTful batch 10 tests") +def _retrieval_has_chunks(rest_client, dataset_id, question, chunk_ids): + res = rest_client.post("/retrieval", json={"question": question, "dataset_ids": [dataset_id]}) + if res.status_code != 200: + return False + payload = res.json() + if payload["code"] != 0: + return False + retrieved_ids = {chunk["id"] for chunk in payload["data"]["chunks"]} + expected_ids = set(chunk_ids) + return expected_ids.issubset(retrieved_ids) + + +@wait_for(20, 1, "Retrieval indexing timeout waiting for chunk deletion in RESTful batch 10 tests") +def _retrieval_lacks_chunks(rest_client, dataset_id, question, chunk_ids): + res = rest_client.post("/retrieval", json={"question": question, "dataset_ids": [dataset_id]}) + if res.status_code != 200: + return False + payload = res.json() + if payload["code"] != 0: + return False + retrieved_ids = {chunk["id"] for chunk in payload["data"]["chunks"]} + expected_ids = set(chunk_ids) + return expected_ids.isdisjoint(retrieved_ids) + + +@pytest.mark.p2 +def test_retrieval_requires_auth_contract(ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + for scenario_name, token, expected_code, expected_message in ( + ("missing token", None, 0, "`Authorization` can't be empty"), + ("invalid token", INVALID_API_TOKEN, 109, "Authentication error: API key is invalid!"), + ): + client = RestClient(token=token) + res = client.post("/retrieval", json={"question": "chunk", "dataset_ids": [dataset_id]}) + assert res.status_code == 401, (scenario_name, res.text) + payload = res.json() + assert payload["code"] == expected_code, (scenario_name, payload) + assert payload["message"] == expected_message, (scenario_name, payload) + + +@pytest.mark.p2 +def test_retrieval_page_and_page_size_contract(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + cases = [ + ("page none", {"question": "chunk", "dataset_ids": [dataset_id], "page": None, "page_size": 2}, 100, "TypeError"), + ("page zero", {"question": "chunk", "dataset_ids": [dataset_id], "page": 0, "page_size": 2}, 0, ""), + ("page two", {"question": "chunk", "dataset_ids": [dataset_id], "page": 2, "page_size": 2}, 0, ""), + ("page three", {"question": "chunk", "dataset_ids": [dataset_id], "page": 3, "page_size": 2}, 0, ""), + ("page str", {"question": "chunk", "dataset_ids": [dataset_id], "page": "3", "page_size": 2}, 0, ""), + ("page negative", {"question": "chunk", "dataset_ids": [dataset_id], "page": -1, "page_size": 2}, 0, ""), + ("page alpha", {"question": "chunk", "dataset_ids": [dataset_id], "page": "a", "page_size": 2}, 100, "invalid literal for int()"), + ("page_size none", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": None}, 100, "TypeError"), + ("page_size one", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": 1}, 0, ""), + ("page_size five", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": 5}, 0, ""), + ("page_size str", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": "1"}, 0, ""), + ("page_size alpha", {"question": "chunk", "dataset_ids": [dataset_id], "page_size": "a"}, 100, "invalid literal for int()"), + ] + for scenario_name, payload, expected_code, expected_message in cases: + res = rest_client.post("/retrieval", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code != 0: + assert expected_message in body["message"], (scenario_name, body) + + +@pytest.mark.p2 +def test_retrieval_highlight_keyword_and_invalid_params_contract(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + + highlight_cases = [ + ("highlight true", True, True), + ("highlight true str", "True", True), + ("highlight false", False, False), + ("highlight false str", "False", False), + ("highlight none", None, False), + ] + for scenario_name, highlight_value, expect_highlight in highlight_cases: + res = rest_client.post( + "/retrieval", + json={"question": "chunk", "dataset_ids": [dataset_id], "highlight": highlight_value}, + ) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == 0, (scenario_name, body) + for chunk in body["data"]["chunks"]: + if expect_highlight: + assert "highlight" in chunk, (scenario_name, body) + else: + assert "highlight" not in chunk, (scenario_name, body) + + invalid_highlight = rest_client.post( + "/retrieval", + json={"question": "chunk", "dataset_ids": [dataset_id], "highlight": "not_bool"}, + ) + assert invalid_highlight.status_code == 200 + invalid_highlight_payload = invalid_highlight.json() + assert invalid_highlight_payload["code"] == 102, invalid_highlight_payload + assert invalid_highlight_payload["message"] == "`highlight` should be a boolean", invalid_highlight_payload + + for scenario_name, keyword_value in ( + ("keyword true", True), + ("keyword true str", "True"), + ("keyword false", False), + ("keyword false str", "False"), + ("keyword none", None), + ): + keyword_res = rest_client.post( + "/retrieval", + json={"question": "chunk test", "dataset_ids": [dataset_id], "keyword": keyword_value}, + ) + assert keyword_res.status_code == 200, (scenario_name, keyword_res.text) + keyword_payload = keyword_res.json() + assert keyword_payload["code"] == 0, (scenario_name, keyword_payload) + assert isinstance(keyword_payload["data"]["chunks"], list), (scenario_name, keyword_payload) + + invalid_params_res = rest_client.post( + "/retrieval", + json={"question": "chunk", "dataset_ids": [dataset_id], "a": "b"}, + ) + assert invalid_params_res.status_code == 200 + invalid_params_payload = invalid_params_res.json() + assert invalid_params_payload["code"] == 0, invalid_params_payload + + +@pytest.mark.p2 +def test_retrieval_vector_similarity_and_top_k_contract(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + cases = [ + ("vector 0", {"vector_similarity_weight": 0}, 0, ""), + ("vector 0.5", {"vector_similarity_weight": 0.5}, 0, ""), + ("vector 10", {"vector_similarity_weight": 10}, 0, ""), + ("vector alpha", {"vector_similarity_weight": "a"}, 100, "could not convert string to float"), + ("top_k 10", {"top_k": 10}, 0, ""), + ("top_k 1", {"top_k": 1}, 0, ""), + ("top_k -1", {"top_k": -1}, 102, "`top_k` must be greater than 0"), + ("top_k alpha", {"top_k": "a"}, 100, "invalid literal for int()"), + ] + for scenario_name, updates, expected_code, expected_message in cases: + payload = {"question": "chunk", "dataset_ids": [dataset_id]} + payload.update(updates) + res = rest_client.post("/retrieval", json=payload) + assert res.status_code == 200, (scenario_name, res.text) + body = res.json() + assert body["code"] == expected_code, (scenario_name, body) + if expected_code != 0: + assert expected_message in body["message"], (scenario_name, body) + + +@pytest.mark.p2 +def test_retrieval_rerank_unknown_contract(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + res = rest_client.post( + "/retrieval", + json={"question": "chunk", "dataset_ids": [dataset_id], "rerank_id": "unknown"}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] != 0, payload + assert payload["message"], payload + + +@pytest.mark.p2 +def test_retrieval_concurrent_contract(rest_client, ensure_parsed_document): + dataset_id, _ = ensure_parsed_document() + payload = {"question": "chunk", "dataset_ids": [dataset_id]} + with ThreadPoolExecutor(max_workers=5) as executor: + results = list(executor.map(lambda _: rest_client.post("/retrieval", json=payload).json(), range(20))) + assert len(results) == 20, results + assert all(result["code"] == 0 for result in results), results + + +@pytest.mark.p2 +def test_deleted_chunk_not_in_retrieval_contract(rest_client, create_document): + dataset_id, document_id = create_document("retrieval_deleted_chunk.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + content = "UNIQUE_TEST_CONTENT_12520_REST" + + add_res = rest_client.post(base_path, json={"content": content}) + assert add_res.status_code == 200 + add_payload = add_res.json() + assert add_payload["code"] == 0, add_payload + chunk_id = add_payload["data"]["chunk"]["id"] + + _retrieval_has_chunks(rest_client, dataset_id, content, [chunk_id]) + + delete_res = rest_client.delete(base_path, json={"chunk_ids": [chunk_id]}) + assert delete_res.status_code == 200 + assert delete_res.json()["code"] == 0 + _retrieval_lacks_chunks(rest_client, dataset_id, content, [chunk_id]) + + +@pytest.mark.p2 +def test_deleted_chunks_batch_not_in_retrieval_contract(rest_client, create_document): + dataset_id, document_id = create_document("retrieval_deleted_chunks_batch.txt") + base_path = f"/datasets/{dataset_id}/documents/{document_id}/chunks" + chunk_ids = [] + for index in range(3): + content = f"BATCH_DELETE_TEST_CHUNK_{index}_REST_12520" + add_res = rest_client.post(base_path, json={"content": content}) + assert add_res.status_code == 200 + add_payload = add_res.json() + assert add_payload["code"] == 0, add_payload + chunk_ids.append(add_payload["data"]["chunk"]["id"]) + _retrieval_has_chunks(rest_client, dataset_id, "BATCH_DELETE_TEST_CHUNK", chunk_ids) + + delete_res = rest_client.delete(base_path, json={"chunk_ids": chunk_ids}) + assert delete_res.status_code == 200 + assert delete_res.json()["code"] == 0 + _retrieval_lacks_chunks(rest_client, dataset_id, "BATCH_DELETE_TEST_CHUNK", chunk_ids) + + +@pytest.mark.p2 +def test_related_questions_contract(auth, rest_client, rest_client_noauth): + tokens_res = requests.get( + f"{HOST_ADDRESS}/api/{VERSION}/system/tokens", + headers={"Authorization": auth}, + timeout=30, + ) + assert tokens_res.status_code == 200, tokens_res.text + tokens_payload = tokens_res.json() + assert tokens_payload["code"] == 0, tokens_payload + assert tokens_payload["data"], tokens_payload + beta_token = tokens_payload["data"][0]["beta"] + + success_client = RestClient(token=beta_token) + success_res = success_client.post("/searchbots/related_questions", json={"question": "ragflow", "industry": "search"}) + assert success_res.status_code == 200 + success_payload = success_res.json() + assert success_payload["code"] == 0, success_payload + assert isinstance(success_payload["data"], list), success_payload + + missing_res = rest_client.post("/searchbots/related_questions", json={"industry": "search"}) + assert missing_res.status_code == 200 + missing_payload = missing_res.json() + assert missing_payload["code"] == 101, missing_payload + assert "question" in missing_payload["message"], missing_payload + + invalid_auth_res = rest_client_noauth.post( + "/searchbots/related_questions", + json={"question": "ragflow", "industry": "search"}, + headers={"Authorization": "invalid"}, + ) + assert invalid_auth_res.status_code == 200 + invalid_auth_payload = invalid_auth_res.json() + assert invalid_auth_payload["code"] == 102, invalid_auth_payload + assert "Authorization is not valid!" in invalid_auth_payload["message"], invalid_auth_payload diff --git a/test/testcases/restful_api/test_router_contracts.py b/test/testcases/restful_api/test_router_contracts.py new file mode 100644 index 00000000000..72683bad8f8 --- /dev/null +++ b/test/testcases/restful_api/test_router_contracts.py @@ -0,0 +1,28 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest +from configs import VERSION + + +@pytest.mark.p1 +def test_route_not_found_returns_json(rest_client_noauth): + res = rest_client_noauth.get("/__missing_route__") + assert res.status_code == 404 + payload = res.json() + assert payload["code"] == 404, payload + assert payload["error"] == "Not Found", payload + assert payload["message"] == f"Not Found: /api/{VERSION}/__missing_route__", payload diff --git a/test/testcases/restful_api/test_searches.py b/test/testcases/restful_api/test_searches.py new file mode 100644 index 00000000000..1a6923fb509 --- /dev/null +++ b/test/testcases/restful_api/test_searches.py @@ -0,0 +1,155 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json +import uuid + +import pytest + + +@pytest.fixture +def search_resource(rest_client): + name = f"restful_search_{uuid.uuid4().hex[:8]}" + create_res = rest_client.post("/searches", json={"name": name, "description": "restful search"}) + assert create_res.status_code == 200 + create_payload = create_res.json() + assert create_payload["code"] == 0, create_payload + search_id = create_payload["data"]["search_id"] + + try: + yield search_id + finally: + delete_res = rest_client.delete(f"/searches/{search_id}") + assert delete_res.status_code == 200, delete_res.text + delete_payload = delete_res.json() + assert delete_payload["code"] in (0, 109), delete_payload + + +def _sse_events(response_text: str) -> list[str]: + return [line[5:] for line in response_text.splitlines() if line.startswith("data:")] + + +@pytest.mark.p2 +def test_search_routes_require_auth(rest_client_noauth): + create_res = rest_client_noauth.post("/searches", json={"name": "search_noauth"}) + assert create_res.status_code == 401 + create_payload = create_res.json() + assert create_payload["code"] == 401, create_payload + + list_res = rest_client_noauth.get("/searches") + assert list_res.status_code == 401 + list_payload = list_res.json() + assert list_payload["code"] == 401, list_payload + + +@pytest.mark.p2 +def test_search_crud_contract(rest_client, search_resource): + search_id = search_resource + + list_res = rest_client.get("/searches") + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert any(item.get("id") == search_id for item in list_payload["data"]["search_apps"]), list_payload + + detail_res = rest_client.get(f"/searches/{search_id}") + assert detail_res.status_code == 200 + detail_payload = detail_res.json() + assert detail_payload["code"] == 0, detail_payload + assert detail_payload["data"]["id"] == search_id, detail_payload + + new_name = f"search_updated_{uuid.uuid4().hex[:6]}" + update_res = rest_client.put( + f"/searches/{search_id}", + json={"name": new_name, "search_config": {"top_k": 3}}, + ) + assert update_res.status_code == 200 + update_payload = update_res.json() + assert update_payload["code"] == 0, update_payload + assert update_payload["data"]["name"] == new_name, update_payload + + +@pytest.mark.p2 +def test_search_create_invalid_name(rest_client): + res = rest_client.post("/searches", json={"name": ""}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert "empty" in payload["message"], payload + + +@pytest.mark.p2 +def test_search_update_invalid_search_id(rest_client): + res = rest_client.put( + "/searches/invalid_search_id", + json={"name": "invalid", "search_config": {}}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 109, payload + assert "No authorization" in payload["message"], payload + + +@pytest.mark.p2 +def test_search_completion_requires_question(rest_client, search_resource): + search_id = search_resource + + completion_res = rest_client.post(f"/searches/{search_id}/completion", json={}) + assert completion_res.status_code == 200 + completion_payload = completion_res.json() + assert completion_payload["code"] == 101, completion_payload + assert "required argument are missing: question" in completion_payload["message"], completion_payload + + completions_res = rest_client.post(f"/searches/{search_id}/completions", json={}) + assert completions_res.status_code == 200 + completions_payload = completions_res.json() + assert completions_payload["code"] == 101, completions_payload + assert "required argument are missing: question" in completions_payload["message"], completions_payload + + +@pytest.mark.p2 +def test_search_completion_requires_kb_ids(rest_client, search_resource): + search_id = search_resource + for path in ("completion", "completions"): + res = rest_client.post( + f"/searches/{search_id}/{path}", + json={"question": "what is coriander?"}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert "`kb_ids` is required" in payload["message"], payload + + +@pytest.mark.p2 +def test_search_completion_sse_shape_when_kb_ids_provided(rest_client, search_resource): + search_id = search_resource + # Even with kb_ids provided, runtime may return an error event in-stream, but + # contract remains SSE with JSON data lines and terminal boolean event. + res = rest_client.post( + f"/searches/{search_id}/completion", + json={"question": "what is coriander?", "kb_ids": ["nonexistent_dataset"]}, + timeout=60, + ) + assert res.status_code == 200 + content_type = res.headers.get("Content-Type", "") + assert "text/event-stream" in content_type, content_type + + events = _sse_events(res.text) + assert events, res.text + parsed = [json.loads(evt) for evt in events] + assert isinstance(parsed[0], dict), parsed + assert parsed[-1].get("data") is True, parsed[-1] diff --git a/test/testcases/restful_api/test_sessions.py b/test/testcases/restful_api/test_sessions.py new file mode 100644 index 00000000000..ca1c8ea5c6e --- /dev/null +++ b/test/testcases/restful_api/test_sessions.py @@ -0,0 +1,270 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import json + +import pytest + + +def _sse_events(response_text: str) -> list[str]: + return [line[5:] for line in response_text.splitlines() if line.startswith("data:")] + + +@pytest.mark.p1 +def test_session_crud_cycle(rest_client, create_chat): + chat_id = create_chat("restful_session_crud_chat") + + create_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_a"}) + assert create_res.status_code == 200 + create_payload = create_res.json() + assert create_payload["code"] == 0, create_payload + session_id = create_payload["data"]["id"] + assert create_payload["data"]["chat_id"] == chat_id, create_payload + + list_res = rest_client.get(f"/chats/{chat_id}/sessions") + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert any(item["id"] == session_id for item in list_payload["data"]), list_payload + + get_res = rest_client.get(f"/chats/{chat_id}/sessions/{session_id}") + assert get_res.status_code == 200 + get_payload = get_res.json() + assert get_payload["code"] == 0, get_payload + assert get_payload["data"]["id"] == session_id, get_payload + + patch_res = rest_client.patch( + f"/chats/{chat_id}/sessions/{session_id}", + json={"name": "session_a_updated"}, + ) + assert patch_res.status_code == 200 + patch_payload = patch_res.json() + assert patch_payload["code"] == 0, patch_payload + assert patch_payload["data"]["name"] == "session_a_updated", patch_payload + + delete_res = rest_client.delete(f"/chats/{chat_id}/sessions", json={"ids": [session_id]}) + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + + list_after_delete = rest_client.get(f"/chats/{chat_id}/sessions") + assert list_after_delete.status_code == 200 + list_after_delete_payload = list_after_delete.json() + assert list_after_delete_payload["code"] == 0, list_after_delete_payload + assert all(item["id"] != session_id for item in list_after_delete_payload["data"]), list_after_delete_payload + + +@pytest.mark.p2 +def test_session_create_name_validation(rest_client, create_chat): + chat_id = create_chat("restful_session_name_validation_chat") + + res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": " "}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert "`name` can not be empty." in payload["message"], payload + + +@pytest.mark.p2 +def test_session_update_blocks_messages_and_reference(rest_client, create_chat): + chat_id = create_chat("restful_session_guard_chat") + create_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_guard"}) + assert create_res.status_code == 200 + create_payload = create_res.json() + assert create_payload["code"] == 0, create_payload + session_id = create_payload["data"]["id"] + + msg_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"messages": []}) + assert msg_res.status_code == 200 + msg_payload = msg_res.json() + assert msg_payload["code"] == 102, msg_payload + assert "`messages` cannot be changed." in msg_payload["message"], msg_payload + + ref_res = rest_client.patch(f"/chats/{chat_id}/sessions/{session_id}", json={"reference": []}) + assert ref_res.status_code == 200 + ref_payload = ref_res.json() + assert ref_payload["code"] == 102, ref_payload + assert "`reference` cannot be changed." in ref_payload["message"], ref_payload + + +@pytest.mark.p2 +def test_chat_recommendation_requires_question(rest_client): + res = rest_client.post("/chat/recommendation", json={}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert "required argument are missing: question" in payload["message"], payload + + +@pytest.mark.p2 +def test_related_questions_compatibility_requires_auth(rest_client_noauth): + # /api/v1/searchbots/related_questions is an SDK compatibility endpoint. + res = rest_client_noauth.post( + "/searchbots/related_questions", + json={"question": "ragflow"}, + headers={"Authorization": "invalid"}, + ) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 102, payload + assert "Authorization is not valid!" in payload["message"], payload + + +@pytest.mark.p2 +def test_chat_completion_nonstream_with_session(rest_client, create_chat): + chat_id = create_chat("restful_completion_nonstream_chat") + create_session_res = rest_client.post(f"/chats/{chat_id}/sessions", json={"name": "session_for_completion"}) + assert create_session_res.status_code == 200 + create_session_payload = create_session_res.json() + assert create_session_payload["code"] == 0, create_session_payload + session_id = create_session_payload["data"]["id"] + + completion_res = rest_client.post( + "/chat/completions", + json={ + "chat_id": chat_id, + "session_id": session_id, + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + }, + timeout=60, + ) + assert completion_res.status_code == 200 + completion_payload = completion_res.json() + assert completion_payload["code"] == 0, completion_payload + assert isinstance(completion_payload["data"], dict), completion_payload + for key in ["answer", "reference", "audio_binary", "id", "session_id"]: + assert key in completion_payload["data"], completion_payload + assert completion_payload["data"]["session_id"] == session_id, completion_payload + + +@pytest.mark.p2 +def test_chat_completion_nonstream_with_chat_without_session(rest_client, create_chat): + chat_id = create_chat("restful_completion_nonstream_without_session_chat") + + completion_res = rest_client.post( + "/chat/completions", + json={ + "chat_id": chat_id, + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + }, + timeout=60, + ) + assert completion_res.status_code == 200 + completion_payload = completion_res.json() + assert completion_payload["code"] == 0, completion_payload + assert isinstance(completion_payload["data"], dict), completion_payload + assert completion_payload["data"]["session_id"], completion_payload + + +@pytest.mark.p2 +def test_chat_completion_nonstream_without_chat(rest_client): + completion_res = rest_client.post( + "/chat/completions", + json={ + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + }, + timeout=60, + ) + assert completion_res.status_code == 200 + completion_payload = completion_res.json() + assert completion_payload["code"] == 0, completion_payload + assert isinstance(completion_payload["data"], dict), completion_payload + assert "answer" in completion_payload["data"], completion_payload + + +@pytest.mark.p2 +def test_chat_completion_stream_events(rest_client, create_chat): + chat_id = create_chat("restful_completion_stream_chat") + stream_res = rest_client.post( + "/chat/completions", + json={ + "chat_id": chat_id, + "messages": [{"role": "user", "content": "hello"}], + "stream": True, + }, + timeout=60, + ) + assert stream_res.status_code == 200 + content_type = stream_res.headers.get("Content-Type", "") + assert "text/event-stream" in content_type, content_type + + events = _sse_events(stream_res.text) + assert events, stream_res.text + parsed_events = [] + for event in events: + parsed = json.loads(event) + parsed_events.append(parsed) + + assert any(evt.get("code") == 0 and isinstance(evt.get("data"), dict) for evt in parsed_events), parsed_events + assert parsed_events[-1].get("data") is True, parsed_events[-1] + + +@pytest.mark.p2 +def test_chat_completion_validation_errors(rest_client, create_chat): + chat_id = create_chat("restful_completion_validation_chat") + + missing_messages = rest_client.post( + "/chat/completions", + json={"chat_id": chat_id, "stream": False}, + ) + assert missing_messages.status_code == 200 + missing_messages_payload = missing_messages.json() + assert missing_messages_payload["code"] == 101, missing_messages_payload + assert "required argument are missing: messages" in missing_messages_payload["message"], missing_messages_payload + + missing_chat_for_session = rest_client.post( + "/chat/completions", + json={ + "session_id": "some_session", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + }, + ) + assert missing_chat_for_session.status_code == 200 + missing_chat_for_session_payload = missing_chat_for_session.json() + assert missing_chat_for_session_payload["code"] == 102, missing_chat_for_session_payload + assert "`chat_id` is required when `session_id` is provided." in missing_chat_for_session_payload["message"], missing_chat_for_session_payload + + invalid_session = rest_client.post( + "/chat/completions", + json={ + "chat_id": chat_id, + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + "session_id": "invalid_session", + }, + ) + assert invalid_session.status_code == 200 + invalid_session_payload = invalid_session.json() + assert invalid_session_payload["code"] == 102, invalid_session_payload + assert "Session not found!" in invalid_session_payload["message"], invalid_session_payload + + invalid_chat = rest_client.post( + "/chat/completions", + json={ + "chat_id": "invalid_chat_id", + "session_id": "invalid_session", + "messages": [{"role": "user", "content": "hello"}], + "stream": False, + }, + ) + assert invalid_chat.status_code == 200 + invalid_chat_payload = invalid_chat.json() + assert invalid_chat_payload["code"] == 109, invalid_chat_payload + assert "No authorization." in invalid_chat_payload["message"], invalid_chat_payload diff --git a/test/testcases/restful_api/test_system.py b/test/testcases/restful_api/test_system.py new file mode 100644 index 00000000000..e5022f4861d --- /dev/null +++ b/test/testcases/restful_api/test_system.py @@ -0,0 +1,159 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + + +@pytest.mark.p1 +def test_system_ping(rest_client): + res = rest_client.get("/system/ping") + assert res.status_code == 200 + assert res.text == "pong" + + +@pytest.mark.p1 +def test_system_version(rest_client): + res = rest_client.get("/system/version") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert payload["data"], payload + + +@pytest.mark.p2 +def test_system_status_requires_auth(rest_client_noauth): + res = rest_client_noauth.get("/system/status") + assert res.status_code == 401 + payload = res.json() + assert payload["code"] == 401, payload + assert "Unauthorized" in payload["message"], payload + + +@pytest.mark.p2 +def test_system_status_contract(rest_client): + res = rest_client.get("/system/status") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + for key in ("doc_engine", "storage", "database", "redis"): + assert key in payload["data"], payload + + +@pytest.mark.p2 +def test_system_config_no_auth_required(rest_client_noauth): + res = rest_client_noauth.get("/system/config") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert "registerEnabled" in payload["data"], payload + assert "disablePasswordLogin" in payload["data"], payload + + +@pytest.mark.p2 +def test_system_healthz_contract(rest_client_noauth): + res = rest_client_noauth.get("/system/healthz") + assert res.status_code in (200, 500) + payload = res.json() + assert isinstance(payload, dict), payload + assert payload, payload + + +@pytest.mark.p2 +def test_system_tokens_auth_and_crud(rest_client, rest_client_noauth): + unauth_list = rest_client_noauth.get("/system/tokens") + assert unauth_list.status_code == 401 + unauth_list_payload = unauth_list.json() + assert unauth_list_payload["code"] == 401, unauth_list_payload + + create_res = rest_client.post("/system/tokens") + assert create_res.status_code == 200 + create_payload = create_res.json() + assert create_payload["code"] == 0, create_payload + token = create_payload["data"]["token"] + + list_res = rest_client.get("/system/tokens") + assert list_res.status_code == 200 + list_payload = list_res.json() + assert list_payload["code"] == 0, list_payload + assert isinstance(list_payload["data"], list), list_payload + assert any(item.get("token") == token for item in list_payload["data"]), list_payload + + delete_res = rest_client.delete(f"/system/tokens/{token}") + assert delete_res.status_code == 200 + delete_payload = delete_res.json() + assert delete_payload["code"] == 0, delete_payload + assert delete_payload["data"] is True, delete_payload + + delete_missing = rest_client.delete("/system/tokens/missing_token") + assert delete_missing.status_code == 200 + delete_missing_payload = delete_missing.json() + assert delete_missing_payload["code"] == 0, delete_missing_payload + assert delete_missing_payload["data"] is True, delete_missing_payload + + +@pytest.mark.p2 +def test_system_stats_auth_and_shape(rest_client, rest_client_noauth): + unauth_res = rest_client_noauth.get("/system/stats") + assert unauth_res.status_code == 401 + unauth_payload = unauth_res.json() + assert unauth_payload["code"] == 401, unauth_payload + + res = rest_client.get("/system/stats") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + data = payload["data"] + for key in ("pv", "uv", "speed", "tokens", "round", "thumb_up"): + assert key in data, payload + assert isinstance(data[key], list), payload + + +@pytest.mark.p2 +def test_system_oceanbase_status_auth_contract(rest_client, rest_client_noauth): + unauth = rest_client_noauth.get("/system/oceanbase/status") + assert unauth.status_code == 401 + assert unauth.json()["code"] == 401 + + res = rest_client.get("/system/oceanbase/status") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] in (0, 500), payload + assert "data" in payload, payload + + +@pytest.mark.p2 +def test_system_log_config_routes_auth_and_validation(rest_client, rest_client_noauth): + unauth = rest_client_noauth.get("/system/config/log") + assert unauth.status_code == 401 + assert unauth.json()["code"] == 401 + + levels = rest_client.get("/system/config/log") + assert levels.status_code == 200 + levels_payload = levels.json() + assert levels_payload["code"] == 0, levels_payload + assert isinstance(levels_payload["data"], dict), levels_payload + + missing_body = rest_client.put("/system/config/log", json={}) + assert missing_body.status_code == 200 + missing_payload = missing_body.json() + assert missing_payload["code"] == 102, missing_payload + assert "pkg_name and level are required" in missing_payload["message"], missing_payload + + invalid_level = rest_client.put("/system/config/log", json={"pkg_name": "rag", "level": "NOT_A_LEVEL"}) + assert invalid_level.status_code == 200 + invalid_payload = invalid_level.json() + assert invalid_payload["code"] == 102, invalid_payload + assert "Invalid log level" in invalid_payload["message"], invalid_payload diff --git a/test/testcases/restful_api/test_task_routes.py b/test/testcases/restful_api/test_task_routes.py new file mode 100644 index 00000000000..13a0fc8a9d9 --- /dev/null +++ b/test/testcases/restful_api/test_task_routes.py @@ -0,0 +1,48 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import pytest + + +@pytest.mark.p2 +def test_task_routes_require_auth(rest_client_noauth): + cancel_res = rest_client_noauth.post("/tasks/missing_task/cancel") + assert cancel_res.status_code == 401 + cancel_payload = cancel_res.json() + assert cancel_payload["code"] == 401, cancel_payload + + patch_res = rest_client_noauth.patch("/tasks/missing_task", json={"action": "stop"}) + assert patch_res.status_code == 401 + patch_payload = patch_res.json() + assert patch_payload["code"] == 401, patch_payload + + +@pytest.mark.p2 +def test_patch_task_rejects_unsupported_action(rest_client): + res = rest_client.patch("/tasks/missing_task", json={"action": "pause"}) + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 101, payload + assert "Only 'stop' is supported" in payload["message"], payload + + +@pytest.mark.p2 +def test_cancel_missing_task_sets_cancel_contract(rest_client): + res = rest_client.post("/tasks/missing_task/cancel") + assert res.status_code == 200 + payload = res.json() + assert payload["code"] == 0, payload + assert payload["data"] is True, payload diff --git a/test/testcases/restful_api/test_user_tenant_routes_unit.py b/test/testcases/restful_api/test_user_tenant_routes_unit.py new file mode 100644 index 00000000000..4d006f66821 --- /dev/null +++ b/test/testcases/restful_api/test_user_tenant_routes_unit.py @@ -0,0 +1,1628 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import base64 +import importlib.util +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +class _DummyManager: + def route(self, *_args, **_kwargs): + def decorator(func): + return func + + return decorator + + +class _AwaitableValue: + def __init__(self, value): + self._value = value + + def __await__(self): + async def _co(): + return self._value + + return _co().__await__() + + +class _Field: + def __init__(self, name): + self.name = name + + def __eq__(self, other): + return (self.name, other) + + +class _Invitee: + def __init__(self, user_id="invitee-1", email="invitee@example.com"): + self.id = user_id + self.email = email + + def to_dict(self): + return { + "id": self.id, + "avatar": "avatar-url", + "email": self.email, + "nickname": "Invitee", + "password": "ignored", + } + + +def _run(coro): + return asyncio.run(coro) + + +def _passthrough_login_required(func): + async def _wrapper(*args, **kwargs): + return await func(*args, **kwargs) + + _wrapper.__wrapped__ = func + return _wrapper + + +def _set_request_json(monkeypatch, module, payload): + async def _request_json(): + return payload + + monkeypatch.setattr(module, "get_request_json", _request_json) + + +def _load_tenant_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + + apps_mod = ModuleType("api.apps") + apps_mod.__path__ = [str(repo_root / "api" / "apps")] + apps_mod.current_user = SimpleNamespace(id="tenant-1", email="owner@example.com") + apps_mod.login_required = lambda fn: fn + monkeypatch.setitem(sys.modules, "api.apps", apps_mod) + + db_mod = ModuleType("api.db") + db_mod.UserTenantRole = SimpleNamespace(NORMAL="normal", OWNER="owner", INVITE="invite") + monkeypatch.setitem(sys.modules, "api.db", db_mod) + + db_models_mod = ModuleType("api.db.db_models") + db_models_mod.UserTenant = type( + "UserTenant", + (), + { + "tenant_id": _Field("tenant_id"), + "user_id": _Field("user_id"), + }, + ) + monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod) + + services_pkg = ModuleType("api.db.services") + services_pkg.__path__ = [] + monkeypatch.setitem(sys.modules, "api.db.services", services_pkg) + + user_service_mod = ModuleType("api.db.services.user_service") + + class _UserTenantService: + @staticmethod + def get_by_tenant_id(_tenant_id): + return [] + + @staticmethod + def query(**_kwargs): + return [] + + @staticmethod + def save(**_kwargs): + return True + + @staticmethod + def filter_delete(_conditions): + return True + + @staticmethod + def get_tenants_by_user_id(_user_id): + return [] + + @staticmethod + def filter_update(_conditions, _payload): + return True + + class _UserService: + @staticmethod + def query(**_kwargs): + return [] + + @staticmethod + def get_by_id(_user_id): + return False, None + + user_service_mod.UserTenantService = _UserTenantService + user_service_mod.UserService = _UserService + monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + api_utils_mod.get_json_result = lambda data=None, message="", code=0: {"code": code, "message": message, "data": data} + api_utils_mod.get_data_error_result = lambda message="": {"code": 102, "message": message, "data": False} + api_utils_mod.server_error_response = lambda exc: {"code": 100, "message": repr(exc), "data": False} + api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda fn: fn) + api_utils_mod.get_request_json = lambda: _AwaitableValue({}) + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + web_utils_mod = ModuleType("api.utils.web_utils") + web_utils_mod.send_invite_email = lambda **_kwargs: {"ok": True} + monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod) + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + constants_mod = ModuleType("common.constants") + constants_mod.RetCode = SimpleNamespace(AUTHENTICATION_ERROR=401, SERVER_ERROR=500, DATA_ERROR=102) + constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value=1)) + monkeypatch.setitem(sys.modules, "common.constants", constants_mod) + + misc_utils_mod = ModuleType("common.misc_utils") + misc_utils_mod.get_uuid = lambda: "uuid-1" + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod) + + time_utils_mod = ModuleType("common.time_utils") + time_utils_mod.delta_seconds = lambda _value: 0 + monkeypatch.setitem(sys.modules, "common.time_utils", time_utils_mod) + + settings_mod = ModuleType("common.settings") + settings_mod.MAIL_FRONTEND_URL = "https://frontend.example/invite" + monkeypatch.setitem(sys.modules, "common.settings", settings_mod) + common_pkg.settings = settings_mod + + sys.modules.pop("test_tenant_app_unit_module", None) + module_path = repo_root / "api" / "apps" / "restful_apis" / "tenant_api.py" + spec = importlib.util.spec_from_file_location("test_tenant_app_unit_module", module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + monkeypatch.setitem(sys.modules, "test_tenant_app_unit_module", module) + spec.loader.exec_module(module) + return module + + +@pytest.mark.p2 +def test_user_list_auth_success_exception_matrix_unit(monkeypatch): + module = _load_tenant_module(monkeypatch) + + module.current_user.id = "other-user" + res = module.user_list("tenant-1") + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res + assert res["message"] == "No authorization.", res + + module.current_user.id = "tenant-1" + monkeypatch.setattr( + module.UserTenantService, + "get_by_tenant_id", + lambda _tenant_id: [{"id": "u1", "update_date": "2024-01-01 00:00:00"}], + ) + monkeypatch.setattr(module, "delta_seconds", lambda _value: 42) + res = module.user_list("tenant-1") + assert res["code"] == 0, res + assert res["data"][0]["delta_seconds"] == 42, res + + monkeypatch.setattr(module.UserTenantService, "get_by_tenant_id", lambda _tenant_id: (_ for _ in ()).throw(RuntimeError("list boom"))) + res = module.user_list("tenant-1") + assert res["code"] == 100, res + assert "list boom" in res["message"], res + + +@pytest.mark.p2 +def test_create_invite_role_and_email_failure_matrix_unit(monkeypatch): + module = _load_tenant_module(monkeypatch) + + module.current_user.id = "other-user" + _set_request_json(monkeypatch, module, {"email": "invitee@example.com"}) + res = _run(module.create("tenant-1")) + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res + assert res["message"] == "No authorization.", res + + module.current_user.id = "tenant-1" + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: []) + res = _run(module.create("tenant-1")) + assert res["message"] == "User not found.", res + + invitee = _Invitee() + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [invitee]) + monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role=module.UserTenantRole.NORMAL)]) + res = _run(module.create("tenant-1")) + assert "already in the team." in res["message"], res + + monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role=module.UserTenantRole.OWNER)]) + res = _run(module.create("tenant-1")) + assert "owner of the team." in res["message"], res + + monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(role="strange-role")]) + res = _run(module.create("tenant-1")) + assert "role: strange-role is invalid." in res["message"], res + + saved = [] + scheduled = [] + monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: []) + monkeypatch.setattr(module.UserTenantService, "save", lambda **kwargs: saved.append(kwargs) or True) + monkeypatch.setattr(module.UserService, "get_by_id", lambda _user_id: (True, SimpleNamespace(nickname="Inviter Nick"))) + monkeypatch.setattr(module, "send_invite_email", lambda **kwargs: kwargs) + monkeypatch.setattr(module.asyncio, "create_task", lambda payload: scheduled.append(payload) or SimpleNamespace()) + res = _run(module.create("tenant-1")) + assert res["code"] == 0, res + assert saved and saved[-1]["role"] == module.UserTenantRole.INVITE, saved + assert scheduled and scheduled[-1]["inviter"] == "Inviter Nick", scheduled + assert sorted(res["data"].keys()) == ["avatar", "email", "id", "nickname"], res + + monkeypatch.setattr(module.asyncio, "create_task", lambda _payload: (_ for _ in ()).throw(RuntimeError("send boom"))) + res = _run(module.create("tenant-1")) + assert res["code"] == module.RetCode.SERVER_ERROR, res + assert "Failed to send invite email." in res["message"], res + + +@pytest.mark.p2 +def test_rm_and_tenant_list_matrix_unit(monkeypatch): + module = _load_tenant_module(monkeypatch) + + module.current_user.id = "outsider" + _set_request_json(monkeypatch, module, {"user_id": "user-2"}) + res = _run(module.rm("tenant-1")) + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res + assert res["message"] == "No authorization.", res + + module.current_user.id = "tenant-1" + deleted = [] + monkeypatch.setattr(module.UserTenantService, "filter_delete", lambda conditions: deleted.append(conditions) or True) + res = _run(module.rm("tenant-1")) + assert res["code"] == 0, res + assert res["data"] is True, res + assert deleted, "filter_delete should be called" + + monkeypatch.setattr(module.UserTenantService, "filter_delete", lambda _conditions: (_ for _ in ()).throw(RuntimeError("rm boom"))) + res = _run(module.rm("tenant-1")) + assert res["code"] == 100, res + assert "rm boom" in res["message"], res + + monkeypatch.setattr( + module.UserTenantService, + "get_tenants_by_user_id", + lambda _user_id: [{"id": "tenant-1", "update_date": "2024-01-01 00:00:00"}], + ) + monkeypatch.setattr(module, "delta_seconds", lambda _value: 9) + res = module.tenant_list() + assert res["code"] == 0, res + assert res["data"][0]["delta_seconds"] == 9, res + + monkeypatch.setattr(module.UserTenantService, "get_tenants_by_user_id", lambda _user_id: (_ for _ in ()).throw(RuntimeError("tenant boom"))) + res = module.tenant_list() + assert res["code"] == 100, res + assert "tenant boom" in res["message"], res + + +@pytest.mark.p2 +def test_agree_success_and_exception_unit(monkeypatch): + module = _load_tenant_module(monkeypatch) + + calls = [] + monkeypatch.setattr(module.UserTenantService, "filter_update", lambda conditions, payload: calls.append((conditions, payload)) or True) + res = module.agree("tenant-1") + assert res["code"] == 0, res + assert res["data"] is True, res + assert calls and calls[-1][1]["role"] == module.UserTenantRole.NORMAL + + monkeypatch.setattr(module.UserTenantService, "filter_update", lambda _conditions, _payload: (_ for _ in ()).throw(RuntimeError("agree boom"))) + res = module.agree("tenant-1") + assert res["code"] == 100, res + assert "agree boom" in res["message"], res + + +class _Args(dict): + def get(self, key, default=None, type=None): + value = super().get(key, default) + if type is None: + return value + try: + return type(value) + except (TypeError, ValueError): + return default + + +class _DummyResponse: + def __init__(self, data): + self.data = data + self.headers = {} + + +class _DummyHTTPResponse: + def __init__(self, payload): + self._payload = payload + + def json(self): + return self._payload + + +class _DummyRedis: + def __init__(self): + self.store = {} + + def get(self, key): + return self.store.get(key) + + def set(self, key, value, _ttl=None): + self.store[key] = value + + def delete(self, key): + self.store.pop(key, None) + + +class _DummyUser: + def __init__(self, user_id, email, *, password="stored-password", is_active="1", nickname="nick"): + self.id = user_id + self.email = email + self.password = password + self.is_active = is_active + self.nickname = nickname + self.access_token = "" + self.save_calls = 0 + + def save(self): + self.save_calls += 1 + + def get_id(self): + return self.id + + def to_json(self): + return {"id": self.id, "email": self.email, "nickname": self.nickname} + + def to_dict(self): + return {"id": self.id, "email": self.email} + + +def _set_request_args(monkeypatch, module, args=None): + monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args(args or {}))) + + +@pytest.fixture(scope="session") +def auth(): + return "unit-auth" + + +@pytest.fixture(scope="session", autouse=True) +def set_tenant_info(): + return None + + +def _load_user_app(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + + quart_mod = ModuleType("quart") + quart_mod.session = {} + quart_mod.request = SimpleNamespace(args=_Args({})) + + async def _make_response(data): + return _DummyResponse(data) + + quart_mod.make_response = _make_response + quart_mod.redirect = lambda url: {"redirect": url} + monkeypatch.setitem(sys.modules, "quart", quart_mod) + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + + apps_mod = ModuleType("api.apps") + apps_mod.__path__ = [str(repo_root / "api" / "apps")] + apps_mod.current_user = _DummyUser("current-user", "current@example.com") + apps_mod.login_required = lambda fn: fn + apps_mod.login_user = lambda _user: True + apps_mod.logout_user = lambda: True + monkeypatch.setitem(sys.modules, "api.apps", apps_mod) + api_pkg.apps = apps_mod + + apps_auth_mod = ModuleType("api.apps.auth") + apps_auth_mod.get_auth_client = lambda _config: SimpleNamespace( + get_authorization_url=lambda state: f"https://oauth.example/{state}" + ) + monkeypatch.setitem(sys.modules, "api.apps.auth", apps_auth_mod) + + db_mod = ModuleType("api.db") + db_mod.FileType = SimpleNamespace(FOLDER=SimpleNamespace(value="folder")) + db_mod.UserTenantRole = SimpleNamespace(OWNER="owner") + monkeypatch.setitem(sys.modules, "api.db", db_mod) + api_pkg.db = db_mod + + db_models_mod = ModuleType("api.db.db_models") + + class _DummyTenantLLMModel: + tenant_id = _Field("tenant_id") + + @staticmethod + def delete(): + class _DeleteQuery: + def where(self, *_args, **_kwargs): + return self + + def execute(self): + return 1 + + return _DeleteQuery() + + db_models_mod.TenantLLM = _DummyTenantLLMModel + monkeypatch.setitem(sys.modules, "api.db.db_models", db_models_mod) + + services_pkg = ModuleType("api.db.services") + services_pkg.__path__ = [] + monkeypatch.setitem(sys.modules, "api.db.services", services_pkg) + + file_service_mod = ModuleType("api.db.services.file_service") + + class _StubFileService: + @staticmethod + def insert(_data): + return True + + file_service_mod.FileService = _StubFileService + monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod) + + llm_service_mod = ModuleType("api.db.services.llm_service") + llm_service_mod.get_init_tenant_llm = lambda _user_id: [] + monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod) + + tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service") + + class _MockTableObject: + def __init__(self, **kwargs): + for key, value in kwargs.items(): + setattr(self, key, value) + + def to_dict(self): + return {k: v for k, v in self.__dict__.items()} + + class _StubTenantLLMService: + @staticmethod + def insert_many(_payload): + return True + + @staticmethod + def get_api_key(tenant_id, model_name, model_type=None): + return _MockTableObject( + id=1, + tenant_id=tenant_id, + llm_factory="", + model_type="chat", + llm_name=model_name, + api_key="fake-api-key", + api_base="https://api.example.com", + max_tokens=8192, + used_tokens=0, + status=1 + ) + + tenant_llm_service_mod.TenantLLMService = _StubTenantLLMService + monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod) + + user_service_mod = ModuleType("api.db.services.user_service") + + class _StubTenantService: + @staticmethod + def insert(**_kwargs): + return True + + @staticmethod + def delete_by_id(_tenant_id): + return True + + @staticmethod + def get_by_id(_tenant_id): + return True, SimpleNamespace(id=_tenant_id) + + @staticmethod + def get_info_by(_user_id): + return [] + + @staticmethod + def update_by_id(_tenant_id, _payload): + return True + + class _StubUserService: + @staticmethod + def query(**_kwargs): + return [] + + @staticmethod + def query_user(_email, _password): + return None + + @staticmethod + def query_user_by_email(**_kwargs): + return [] + + @staticmethod + def save(**_kwargs): + return True + + @staticmethod + def delete_by_id(_user_id): + return True + + @staticmethod + def update_by_id(_user_id, _payload): + return True + + @staticmethod + def update_user_password(_user_id, _new_password): + return True + + class _StubUserTenantService: + @staticmethod + def insert(**_kwargs): + return True + + @staticmethod + def query(**_kwargs): + return [] + + @staticmethod + def delete_by_id(_user_tenant_id): + return True + + user_service_mod.TenantService = _StubTenantService + user_service_mod.UserService = _StubUserService + user_service_mod.UserTenantService = _StubUserTenantService + monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + + async def _default_request_json(): + return {} + + def _get_json_result(code=0, message="success", data=None): + return {"code": code, "message": message, "data": data} + + def _get_data_error_result(code=102, message="Sorry! Data missing!", data=None): + return {"code": code, "message": message, "data": data} + + def _server_error_response(error): + return {"code": 100, "message": repr(error)} + + def _validate_request(*_args, **_kwargs): + def _decorator(func): + return func + + return _decorator + + api_utils_mod.get_request_json = _default_request_json + api_utils_mod.get_json_result = _get_json_result + api_utils_mod.get_data_error_result = _get_data_error_result + api_utils_mod.server_error_response = _server_error_response + api_utils_mod.validate_request = _validate_request + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + tenant_utils_mod = ModuleType("api.utils.tenant_utils") + tenant_utils_mod.ensure_tenant_model_id_for_params = lambda _tenant_id, params: params + monkeypatch.setitem(sys.modules, "api.utils.tenant_utils", tenant_utils_mod) + + crypt_mod = ModuleType("api.utils.crypt") + crypt_mod.decrypt = lambda value: value + monkeypatch.setitem(sys.modules, "api.utils.crypt", crypt_mod) + + web_utils_mod = ModuleType("api.utils.web_utils") + web_utils_mod.send_email_html = lambda *_args, **_kwargs: _AwaitableValue(True) + web_utils_mod.OTP_LENGTH = 6 + web_utils_mod.OTP_TTL_SECONDS = 600 + web_utils_mod.ATTEMPT_LIMIT = 5 + web_utils_mod.ATTEMPT_LOCK_SECONDS = 600 + web_utils_mod.RESEND_COOLDOWN_SECONDS = 60 + web_utils_mod.otp_keys = lambda email: ( + f"otp:{email}:code", + f"otp:{email}:attempts", + f"otp:{email}:last", + f"otp:{email}:lock", + ) + web_utils_mod.hash_code = lambda code, _salt: f"hash:{code}" + web_utils_mod.captcha_key = lambda email: f"captcha:{email}" + monkeypatch.setitem(sys.modules, "api.utils.web_utils", web_utils_mod) + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + settings_mod = ModuleType("common.settings") + settings_mod.OAUTH_CONFIG = { + "github": {"display_name": "GitHub", "icon": "gh"}, + "feishu": {"display_name": "Feishu", "icon": "fs"}, + } + settings_mod.GITHUB_OAUTH = {"url": "https://github.example/oauth", "client_id": "cid", "secret_key": "sk"} + settings_mod.FEISHU_OAUTH = { + "app_access_token_url": "https://feishu.example/app_token", + "user_access_token_url": "https://feishu.example/user_token", + "app_id": "app-id", + "app_secret": "app-secret", + "grant_type": "authorization_code", + } + settings_mod.CHAT_MDL = "chat-mdl" + settings_mod.EMBEDDING_MDL = "embd-mdl" + settings_mod.ASR_MDL = "asr-mdl" + settings_mod.PARSERS = [] + settings_mod.IMAGE2TEXT_MDL = "img-mdl" + settings_mod.RERANK_MDL = "rerank-mdl" + settings_mod.REGISTER_ENABLED = True + monkeypatch.setitem(sys.modules, "common.settings", settings_mod) + common_pkg.settings = settings_mod + + constants_mod = ModuleType("common.constants") + constants_mod.RetCode = SimpleNamespace( + AUTHENTICATION_ERROR=401, + SERVER_ERROR=500, + FORBIDDEN=403, + EXCEPTION_ERROR=100, + OPERATING_ERROR=300, + ARGUMENT_ERROR=101, + DATA_ERROR=102, + NOT_EFFECTIVE=103, + SUCCESS=0, + ) + monkeypatch.setitem(sys.modules, "common.constants", constants_mod) + + connection_utils_mod = ModuleType("common.connection_utils") + + async def _construct_response(data=None, auth=None, message=""): + return {"code": 0, "message": message, "data": data, "auth": auth} + + connection_utils_mod.construct_response = _construct_response + monkeypatch.setitem(sys.modules, "common.connection_utils", connection_utils_mod) + + time_utils_mod = ModuleType("common.time_utils") + time_utils_mod.current_timestamp = lambda: 111 + time_utils_mod.datetime_format = lambda _dt: "2024-01-01 00:00:00" + time_utils_mod.get_format_time = lambda: "2024-01-01 00:00:00" + monkeypatch.setitem(sys.modules, "common.time_utils", time_utils_mod) + + misc_utils_mod = ModuleType("common.misc_utils") + misc_utils_mod.download_img = lambda _url: "avatar" + misc_utils_mod.get_uuid = lambda: "uuid-default" + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod) + + http_client_mod = ModuleType("common.http_client") + + async def _async_request(_method, _url, **_kwargs): + return _DummyHTTPResponse({}) + + http_client_mod.async_request = _async_request + monkeypatch.setitem(sys.modules, "common.http_client", http_client_mod) + + rag_pkg = ModuleType("rag") + rag_pkg.__path__ = [str(repo_root / "rag")] + monkeypatch.setitem(sys.modules, "rag", rag_pkg) + + rag_utils_pkg = ModuleType("rag.utils") + rag_utils_pkg.__path__ = [str(repo_root / "rag" / "utils")] + monkeypatch.setitem(sys.modules, "rag.utils", rag_utils_pkg) + + redis_mod = ModuleType("rag.utils.redis_conn") + redis_mod.REDIS_CONN = _DummyRedis() + monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", redis_mod) + + module_name = "test_user_app_unit_module" + module_path = repo_root / "api" / "apps" / "restful_apis" / "user_api.py" + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + monkeypatch.setitem(sys.modules, module_name, module) + spec.loader.exec_module(module) + return module + + +@pytest.mark.p2 +def test_login_route_branch_matrix_unit(monkeypatch): + module = _load_user_app(monkeypatch) + + _set_request_json(monkeypatch, module, {}) + res = _run(module.login()) + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR + assert "Unauthorized" in res["message"] + + _set_request_json(monkeypatch, module, {"email": "unknown@example.com", "password": "enc"}) + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: []) + res = _run(module.login()) + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR + assert "not registered" in res["message"] + + _set_request_json(monkeypatch, module, {"email": "known@example.com", "password": "enc"}) + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [SimpleNamespace(email="known@example.com")]) + + def _raise_decrypt(_value): + raise RuntimeError("decrypt explode") + + monkeypatch.setattr(module, "decrypt", _raise_decrypt) + res = _run(module.login()) + assert res["code"] == module.RetCode.SERVER_ERROR + assert "Fail to crypt password" in res["message"] + + user_inactive = _DummyUser("u-inactive", "known@example.com", is_active="0") + monkeypatch.setattr(module, "decrypt", lambda value: value) + monkeypatch.setattr(module.UserService, "query_user", lambda _email, _password: user_inactive) + res = _run(module.login()) + assert res["code"] == module.RetCode.FORBIDDEN + assert "disabled" in res["message"] + + monkeypatch.setattr(module.UserService, "query_user", lambda _email, _password: None) + res = _run(module.login()) + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR + assert "do not match" in res["message"] + + +@pytest.mark.p2 +def test_login_channels_and_oauth_login_matrix_unit(monkeypatch): + module = _load_user_app(monkeypatch) + + module.settings.OAUTH_CONFIG = {"github": {"display_name": "GitHub", "icon": "gh"}} + res = _run(module.get_login_channels()) + assert res["code"] == 0 + assert res["data"][0]["channel"] == "github" + + class _BrokenOAuthConfig: + @staticmethod + def items(): + raise RuntimeError("broken oauth config") + + module.settings.OAUTH_CONFIG = _BrokenOAuthConfig() + res = _run(module.get_login_channels()) + assert res["code"] == module.RetCode.EXCEPTION_ERROR + assert "Load channels failure" in res["message"] + + module.settings.OAUTH_CONFIG = {"github": {"display_name": "GitHub", "icon": "gh"}} + with pytest.raises(ValueError, match="Invalid channel name: missing"): + _run(module.oauth_login("missing")) + + module.session.clear() + monkeypatch.setattr(module, "get_uuid", lambda: "state-123") + + class _AuthClient: + @staticmethod + def get_authorization_url(state): + return f"https://oauth.example/{state}" + + monkeypatch.setattr(module, "get_auth_client", lambda _config: _AuthClient()) + res = _run(module.oauth_login("github")) + assert res["redirect"] == "https://oauth.example/state-123" + assert module.session["oauth_state"] == "state-123" + + +@pytest.mark.p2 +def test_oauth_callback_matrix_unit(monkeypatch): + module = _load_user_app(monkeypatch) + module.settings.OAUTH_CONFIG = {"github": {"display_name": "GitHub", "icon": "gh"}} + + class _SyncAuthClient: + def __init__(self, token_info, user_info): + self._token_info = token_info + self._user_info = user_info + + def exchange_code_for_token(self, _code): + return self._token_info + + def fetch_user_info(self, _token, id_token=None): + _ = id_token + return self._user_info + + class _AsyncAuthClient: + def __init__(self, token_info, user_info): + self._token_info = token_info + self._user_info = user_info + + async def async_exchange_code_for_token(self, _code): + return self._token_info + + async def async_fetch_user_info(self, _token, id_token=None): + _ = id_token + return self._user_info + + _set_request_args(monkeypatch, module, {"state": "x", "code": "c"}) + module.session.clear() + res = _run(module.oauth_callback("missing")) + assert "Invalid channel name: missing" in res["redirect"] + + sync_ok = _SyncAuthClient( + token_info={"access_token": "token-sync", "id_token": "id-sync"}, + user_info=SimpleNamespace(email="sync@example.com", avatar_url="http://img", nickname="sync"), + ) + monkeypatch.setattr(module, "get_auth_client", lambda _config: sync_ok) + + module.session.clear() + module.session["oauth_state"] = "expected" + _set_request_args(monkeypatch, module, {"state": "wrong", "code": "code"}) + res = _run(module.oauth_callback("github")) + assert res["redirect"] == "/?error=invalid_state" + + module.session.clear() + module.session["oauth_state"] = "ok-state" + _set_request_args(monkeypatch, module, {"state": "ok-state"}) + res = _run(module.oauth_callback("github")) + assert res["redirect"] == "/?error=missing_code" + + sync_missing_token = _SyncAuthClient( + token_info={"id_token": "id-only"}, + user_info=SimpleNamespace(email="sync@example.com", avatar_url="http://img", nickname="sync"), + ) + monkeypatch.setattr(module, "get_auth_client", lambda _config: sync_missing_token) + module.session.clear() + module.session["oauth_state"] = "token-state" + _set_request_args(monkeypatch, module, {"state": "token-state", "code": "code"}) + res = _run(module.oauth_callback("github")) + assert res["redirect"] == "/?error=token_failed" + + sync_missing_email = _SyncAuthClient( + token_info={"access_token": "token-sync", "id_token": "id-sync"}, + user_info=SimpleNamespace(email=None, avatar_url="http://img", nickname="sync"), + ) + monkeypatch.setattr(module, "get_auth_client", lambda _config: sync_missing_email) + module.session.clear() + module.session["oauth_state"] = "email-state" + _set_request_args(monkeypatch, module, {"state": "email-state", "code": "code"}) + res = _run(module.oauth_callback("github")) + assert res["redirect"] == "/?error=email_missing" + + async_new_user = _AsyncAuthClient( + token_info={"access_token": "token-async", "id_token": "id-async"}, + user_info=SimpleNamespace(email="new@example.com", avatar_url="http://img", nickname="new-user"), + ) + monkeypatch.setattr(module, "get_auth_client", lambda _config: async_new_user) + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: []) + + def _raise_download(_url): + raise RuntimeError("download explode") + + monkeypatch.setattr(module, "download_img", _raise_download) + monkeypatch.setattr(module, "user_register", lambda _user_id, _user: None) + rollback_calls = [] + monkeypatch.setattr(module, "rollback_user_registration", lambda user_id: rollback_calls.append(user_id)) + monkeypatch.setattr(module, "get_uuid", lambda: "new-user-id") + module.session.clear() + module.session["oauth_state"] = "new-user-state" + _set_request_args(monkeypatch, module, {"state": "new-user-state", "code": "code"}) + res = _run(module.oauth_callback("github")) + assert "Failed to register new@example.com" in res["redirect"] + assert rollback_calls == ["new-user-id"] + + monkeypatch.setattr(module, "download_img", lambda _url: "avatar") + monkeypatch.setattr( + module, + "user_register", + lambda _user_id, _user: [_DummyUser("dup-1", "new@example.com"), _DummyUser("dup-2", "new@example.com")], + ) + rollback_calls.clear() + module.session.clear() + module.session["oauth_state"] = "dup-user-state" + _set_request_args(monkeypatch, module, {"state": "dup-user-state", "code": "code"}) + res = _run(module.oauth_callback("github")) + assert "Same email: new@example.com exists!" in res["redirect"] + assert rollback_calls == ["new-user-id"] + + new_user = _DummyUser("new-user", "new@example.com") + login_calls = [] + monkeypatch.setattr(module, "login_user", lambda user: login_calls.append(user)) + monkeypatch.setattr(module, "user_register", lambda _user_id, _user: [new_user]) + module.session.clear() + module.session["oauth_state"] = "create-user-state" + _set_request_args(monkeypatch, module, {"state": "create-user-state", "code": "code"}) + res = _run(module.oauth_callback("github")) + assert res["redirect"] == "/?auth=new-user" + assert login_calls and login_calls[-1] is new_user + + async_existing_inactive = _AsyncAuthClient( + token_info={"access_token": "token-existing", "id_token": "id-existing"}, + user_info=SimpleNamespace(email="existing@example.com", avatar_url="http://img", nickname="existing"), + ) + monkeypatch.setattr(module, "get_auth_client", lambda _config: async_existing_inactive) + inactive_user = _DummyUser("existing-user", "existing@example.com", is_active="0") + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [inactive_user]) + module.session.clear() + module.session["oauth_state"] = "inactive-state" + _set_request_args(monkeypatch, module, {"state": "inactive-state", "code": "code"}) + res = _run(module.oauth_callback("github")) + assert res["redirect"] == "/?error=user_inactive" + + async_existing_ok = _AsyncAuthClient( + token_info={"access_token": "token-existing", "id_token": "id-existing"}, + user_info=SimpleNamespace(email="existing@example.com", avatar_url="http://img", nickname="existing"), + ) + monkeypatch.setattr(module, "get_auth_client", lambda _config: async_existing_ok) + existing_user = _DummyUser("existing-user", "existing@example.com") + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [existing_user]) + login_calls.clear() + monkeypatch.setattr(module, "login_user", lambda user: login_calls.append(user)) + monkeypatch.setattr(module, "get_uuid", lambda: "existing-token") + module.session.clear() + module.session["oauth_state"] = "existing-state" + _set_request_args(monkeypatch, module, {"state": "existing-state", "code": "code"}) + res = _run(module.oauth_callback("github")) + assert res["redirect"] == "/?auth=existing-user" + assert existing_user.access_token == "existing-token" + assert existing_user.save_calls == 1 + assert login_calls and login_calls[-1] is existing_user + + +@pytest.mark.p2 +def test_logout_setting_profile_matrix_unit(monkeypatch): + module = _load_user_app(monkeypatch) + + current_user = _DummyUser("current-user", "current@example.com", password="stored-password") + monkeypatch.setattr(module, "current_user", current_user) + monkeypatch.setattr(module.secrets, "token_hex", lambda _n: "abcdef") + logout_calls = [] + monkeypatch.setattr(module, "logout_user", lambda: logout_calls.append(True)) + + res = _run(module.log_out()) + assert res["code"] == 0 + assert current_user.access_token == "INVALID_abcdef" + assert current_user.save_calls == 1 + assert logout_calls == [True] + + _set_request_json(monkeypatch, module, {"password": "old-password", "new_password": "new-password"}) + monkeypatch.setattr(module, "decrypt", lambda value: value) + monkeypatch.setattr(module, "check_password_hash", lambda _hashed, _plain: False) + res = _run(module.setting_user()) + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR + assert "Password error" in res["message"] + + _set_request_json( + monkeypatch, + module, + { + "password": "old-password", + "new_password": "new-password", + "nickname": "neo", + "email": "blocked@example.com", + "status": "disabled", + "theme": "dark", + }, + ) + monkeypatch.setattr(module, "check_password_hash", lambda _hashed, _plain: True) + monkeypatch.setattr(module, "decrypt", lambda value: f"dec:{value}") + monkeypatch.setattr(module, "generate_password_hash", lambda value: f"hash:{value}") + update_calls = {} + + def _update_by_id(user_id, payload): + update_calls["user_id"] = user_id + update_calls["payload"] = payload + return True + + monkeypatch.setattr(module.UserService, "update_by_id", _update_by_id) + res = _run(module.setting_user()) + assert res["code"] == 0 + assert res["data"] is True + assert update_calls["user_id"] == "current-user" + assert update_calls["payload"]["password"] == "hash:dec:new-password" + assert update_calls["payload"]["nickname"] == "neo" + assert update_calls["payload"]["theme"] == "dark" + assert "email" not in update_calls["payload"] + assert "status" not in update_calls["payload"] + + _set_request_json(monkeypatch, module, {"nickname": "neo"}) + + def _raise_update(_user_id, _payload): + raise RuntimeError("update explode") + + monkeypatch.setattr(module.UserService, "update_by_id", _raise_update) + res = _run(module.setting_user()) + assert res["code"] == module.RetCode.EXCEPTION_ERROR + assert "Update failure" in res["message"] + + res = _run(module.user_profile()) + assert res["code"] == 0 + assert res["data"] == current_user.to_dict() + + +@pytest.mark.p2 +def test_registration_helpers_and_register_route_matrix_unit(monkeypatch): + module = _load_user_app(monkeypatch) + + deleted = {"user": 0, "tenant": 0, "user_tenant": 0, "tenant_llm": 0} + monkeypatch.setattr(module.UserService, "delete_by_id", lambda _user_id: deleted.__setitem__("user", deleted["user"] + 1)) + monkeypatch.setattr(module.TenantService, "delete_by_id", lambda _tenant_id: deleted.__setitem__("tenant", deleted["tenant"] + 1)) + monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: [SimpleNamespace(id="ut-1")]) + monkeypatch.setattr(module.UserTenantService, "delete_by_id", lambda _ut_id: deleted.__setitem__("user_tenant", deleted["user_tenant"] + 1)) + + class _DeleteQuery: + def where(self, *_args, **_kwargs): + return self + + def execute(self): + deleted["tenant_llm"] += 1 + return 1 + + monkeypatch.setattr(module.TenantLLM, "delete", lambda: _DeleteQuery()) + module.rollback_user_registration("user-1") + assert deleted == {"user": 1, "tenant": 1, "user_tenant": 1, "tenant_llm": 1}, deleted + + monkeypatch.setattr(module.UserService, "delete_by_id", lambda _user_id: (_ for _ in ()).throw(RuntimeError("u boom"))) + monkeypatch.setattr(module.TenantService, "delete_by_id", lambda _tenant_id: (_ for _ in ()).throw(RuntimeError("t boom"))) + monkeypatch.setattr(module.UserTenantService, "query", lambda **_kwargs: (_ for _ in ()).throw(RuntimeError("ut boom"))) + + class _RaisingDeleteQuery: + def where(self, *_args, **_kwargs): + raise RuntimeError("llm boom") + + monkeypatch.setattr(module.TenantLLM, "delete", lambda: _RaisingDeleteQuery()) + module.rollback_user_registration("user-2") + + monkeypatch.setattr(module.UserService, "save", lambda **_kwargs: False) + res = module.user_register( + "new-user", + { + "nickname": "new", + "email": "new@example.com", + "password": "pw", + "access_token": "tk", + "login_channel": "password", + "last_login_time": "2024-01-01 00:00:00", + "is_superuser": False, + }, + ) + assert res is None + + monkeypatch.setattr(module.settings, "REGISTER_ENABLED", False) + _set_request_json(monkeypatch, module, {"nickname": "neo", "email": "neo@example.com", "password": "enc"}) + res = _run(module.user_add()) + assert res["code"] == module.RetCode.OPERATING_ERROR, res + assert "disabled" in res["message"], res + + monkeypatch.setattr(module.settings, "REGISTER_ENABLED", True) + _set_request_json(monkeypatch, module, {"nickname": "neo", "email": "bad-email", "password": "enc"}) + res = _run(module.user_add()) + assert res["code"] == module.RetCode.OPERATING_ERROR, res + assert "Invalid email address" in res["message"], res + + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: []) + monkeypatch.setattr(module, "decrypt", lambda value: value) + monkeypatch.setattr(module, "get_uuid", lambda: "new-user-id") + rollback_calls = [] + monkeypatch.setattr(module, "rollback_user_registration", lambda user_id: rollback_calls.append(user_id)) + + _set_request_json(monkeypatch, module, {"nickname": "neo", "email": "neo@example.com", "password": "enc"}) + monkeypatch.setattr(module, "user_register", lambda _user_id, _payload: None) + res = _run(module.user_add()) + assert res["code"] == module.RetCode.EXCEPTION_ERROR, res + assert "Fail to register neo@example.com." in res["message"], res + assert rollback_calls == ["new-user-id"], rollback_calls + + rollback_calls.clear() + monkeypatch.setattr( + module, + "user_register", + lambda _user_id, _payload: [_DummyUser("dup-1", "neo@example.com"), _DummyUser("dup-2", "neo@example.com")], + ) + _set_request_json(monkeypatch, module, {"nickname": "neo", "email": "neo@example.com", "password": "enc"}) + res = _run(module.user_add()) + assert res["code"] == module.RetCode.EXCEPTION_ERROR, res + assert "Same email: neo@example.com exists!" in res["message"], res + assert rollback_calls == ["new-user-id"], rollback_calls + + +@pytest.mark.p2 +def test_tenant_info_and_set_tenant_info_exception_matrix_unit(monkeypatch): + module = _load_user_app(monkeypatch) + + monkeypatch.setattr(module.TenantService, "get_info_by", lambda _uid: []) + res = _run(module.tenant_info()) + assert res["code"] == module.RetCode.DATA_ERROR, res + assert "Tenant not found" in res["message"], res + + def _raise_tenant_info(_uid): + raise RuntimeError("tenant info boom") + + monkeypatch.setattr(module.TenantService, "get_info_by", _raise_tenant_info) + res = _run(module.tenant_info()) + assert res["code"] == module.RetCode.EXCEPTION_ERROR, res + assert "tenant info boom" in res["message"], res + + _set_request_json( + monkeypatch, + module, + {"tenant_id": "tenant-1", "llm_id": "l", "embd_id": "e", "asr_id": "a", "img2txt_id": "i"}, + ) + + def _raise_update(_tenant_id, _payload): + raise RuntimeError("tenant update boom") + + monkeypatch.setattr(module.TenantService, "update_by_id", _raise_update) + res = _run(module.set_tenant_info()) + assert res["code"] == module.RetCode.EXCEPTION_ERROR, res + assert "tenant update boom" in res["message"], res + + +@pytest.mark.p2 +def test_forget_captcha_and_send_otp_matrix_unit(monkeypatch): + module = _load_user_app(monkeypatch) + + class _Headers(dict): + def set(self, key, value): + self[key] = value + + async def _make_response(data): + return SimpleNamespace(data=data, headers=_Headers()) + + monkeypatch.setattr(module, "make_response", _make_response) + + captcha_pkg = ModuleType("captcha") + captcha_image_mod = ModuleType("captcha.image") + + class _ImageCaptcha: + def __init__(self, **_kwargs): + pass + + def generate(self, text): + return SimpleNamespace(read=lambda: f"img:{text}".encode()) + + captcha_image_mod.ImageCaptcha = _ImageCaptcha + monkeypatch.setitem(sys.modules, "captcha", captcha_pkg) + monkeypatch.setitem(sys.modules, "captcha.image", captcha_image_mod) + + _set_request_args(monkeypatch, module, {"email": ""}) + res = _run(module.forget_get_captcha()) + assert res["code"] == module.RetCode.ARGUMENT_ERROR, res + + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: []) + _set_request_args(monkeypatch, module, {"email": "nobody@example.com"}) + res = _run(module.forget_get_captcha()) + assert res["code"] == module.RetCode.DATA_ERROR, res + + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [_DummyUser("u1", "ok@example.com")]) + monkeypatch.setattr(module.secrets, "choice", lambda _allowed: "A") + _set_request_args(monkeypatch, module, {"email": "ok@example.com"}) + res = _run(module.forget_get_captcha()) + assert res.data.startswith(b"img:"), res + assert res.headers["Content-Type"] == "image/JPEG", res.headers + assert module.REDIS_CONN.get(module.captcha_key("ok@example.com")), module.REDIS_CONN.store + + _set_request_json(monkeypatch, module, {"email": "", "captcha": ""}) + res = _run(module.forget_send_otp()) + assert res["code"] == module.RetCode.ARGUMENT_ERROR, res + + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: []) + _set_request_json(monkeypatch, module, {"email": "none@example.com", "captcha": "AAAA"}) + res = _run(module.forget_send_otp()) + assert res["code"] == module.RetCode.DATA_ERROR, res + + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [_DummyUser("u1", "ok@example.com")]) + _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "AAAA"}) + module.REDIS_CONN.store.pop(module.captcha_key("ok@example.com"), None) + res = _run(module.forget_send_otp()) + assert res["code"] == module.RetCode.NOT_EFFECTIVE, res + + module.REDIS_CONN.store[module.captcha_key("ok@example.com")] = "ABCD" + _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "ZZZZ"}) + res = _run(module.forget_send_otp()) + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res + + monkeypatch.setattr(module.time, "time", lambda: 1000) + k_code, k_attempts, k_last, k_lock = module.otp_keys("ok@example.com") + module.REDIS_CONN.store[module.captcha_key("ok@example.com")] = "ABCD" + module.REDIS_CONN.store[k_last] = "990" + _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "ABCD"}) + res = _run(module.forget_send_otp()) + assert res["code"] == module.RetCode.NOT_EFFECTIVE, res + assert "wait" in res["message"], res + + module.REDIS_CONN.store[module.captcha_key("ok@example.com")] = "ABCD" + module.REDIS_CONN.store[k_last] = "bad-timestamp" + monkeypatch.setattr(module.secrets, "choice", lambda _allowed: "B") + monkeypatch.setattr(module.os, "urandom", lambda _n: b"\x00" * 16) + monkeypatch.setattr(module, "hash_code", lambda code, _salt: f"HASH_{code}") + + async def _raise_send_email(*_args, **_kwargs): + raise RuntimeError("send email boom") + + monkeypatch.setattr(module, "send_email_html", _raise_send_email) + _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "ABCD"}) + res = _run(module.forget_send_otp()) + assert res["code"] == module.RetCode.SERVER_ERROR, res + assert "failed to send email" in res["message"], res + + async def _ok_send_email(*_args, **_kwargs): + return True + + module.REDIS_CONN.store[module.captcha_key("ok@example.com")] = "ABCD" + module.REDIS_CONN.store.pop(k_last, None) + monkeypatch.setattr(module, "send_email_html", _ok_send_email) + _set_request_json(monkeypatch, module, {"email": "ok@example.com", "captcha": "ABCD"}) + res = _run(module.forget_send_otp()) + assert res["code"] == module.RetCode.SUCCESS, res + assert res["data"] is True, res + assert module.REDIS_CONN.get(k_code), module.REDIS_CONN.store + assert module.REDIS_CONN.get(k_attempts) == 0, module.REDIS_CONN.store + assert module.REDIS_CONN.get(k_lock) is None, module.REDIS_CONN.store + + +@pytest.mark.p2 +def test_forget_verify_otp_matrix_unit(monkeypatch): + module = _load_user_app(monkeypatch) + email = "ok@example.com" + k_code, k_attempts, k_last, k_lock = module.otp_keys(email) + salt = b"\x01" * 16 + monkeypatch.setattr(module, "hash_code", lambda code, _salt: f"HASH_{code}") + + _set_request_json(monkeypatch, module, {}) + res = _run(module.forget_verify_otp()) + assert res["code"] == module.RetCode.ARGUMENT_ERROR, res + + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: []) + _set_request_json(monkeypatch, module, {"email": email, "otp": "ABCDEF"}) + res = _run(module.forget_verify_otp()) + assert res["code"] == module.RetCode.DATA_ERROR, res + + monkeypatch.setattr(module.UserService, "query", lambda **_kwargs: [_DummyUser("u1", email)]) + module.REDIS_CONN.store[k_lock] = "1" + _set_request_json(monkeypatch, module, {"email": email, "otp": "ABCDEF"}) + res = _run(module.forget_verify_otp()) + assert res["code"] == module.RetCode.NOT_EFFECTIVE, res + module.REDIS_CONN.store.pop(k_lock, None) + + module.REDIS_CONN.store.pop(k_code, None) + _set_request_json(monkeypatch, module, {"email": email, "otp": "ABCDEF"}) + res = _run(module.forget_verify_otp()) + assert res["code"] == module.RetCode.NOT_EFFECTIVE, res + + module.REDIS_CONN.store[k_code] = "broken" + _set_request_json(monkeypatch, module, {"email": email, "otp": "ABCDEF"}) + res = _run(module.forget_verify_otp()) + assert res["code"] == module.RetCode.EXCEPTION_ERROR, res + + module.REDIS_CONN.store[k_code] = f"HASH_CORRECT:{salt.hex()}" + module.REDIS_CONN.store[k_attempts] = "bad-int" + _set_request_json(monkeypatch, module, {"email": email, "otp": "wrong"}) + res = _run(module.forget_verify_otp()) + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res + assert module.REDIS_CONN.get(k_attempts) == 1, module.REDIS_CONN.store + + module.REDIS_CONN.store[k_code] = f"HASH_CORRECT:{salt.hex()}" + module.REDIS_CONN.store[k_attempts] = str(module.ATTEMPT_LIMIT - 1) + _set_request_json(monkeypatch, module, {"email": email, "otp": "wrong"}) + res = _run(module.forget_verify_otp()) + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res + assert module.REDIS_CONN.get(k_lock) is not None, module.REDIS_CONN.store + module.REDIS_CONN.store.pop(k_lock, None) + + module.REDIS_CONN.store[k_code] = f"HASH_ABCDEF:{salt.hex()}" + module.REDIS_CONN.store[k_attempts] = "0" + module.REDIS_CONN.store[k_last] = "1000" + + def _set_with_verified_fail(key, value, _ttl=None): + if key == module._verified_key(email): + raise RuntimeError("verified set boom") + module.REDIS_CONN.store[key] = value + + monkeypatch.setattr(module.REDIS_CONN, "set", _set_with_verified_fail) + _set_request_json(monkeypatch, module, {"email": email, "otp": "abcdef"}) + res = _run(module.forget_verify_otp()) + assert res["code"] == module.RetCode.SERVER_ERROR, res + + monkeypatch.setattr(module.REDIS_CONN, "set", lambda key, value, _ttl=None: module.REDIS_CONN.store.__setitem__(key, value)) + module.REDIS_CONN.store[k_code] = f"HASH_ABCDEF:{salt.hex()}" + module.REDIS_CONN.store[k_attempts] = "0" + module.REDIS_CONN.store[k_last] = "1000" + _set_request_json(monkeypatch, module, {"email": email, "otp": "abcdef"}) + res = _run(module.forget_verify_otp()) + assert res["code"] == module.RetCode.SUCCESS, res + assert module.REDIS_CONN.get(k_code) is None, module.REDIS_CONN.store + assert module.REDIS_CONN.get(k_attempts) is None, module.REDIS_CONN.store + assert module.REDIS_CONN.get(k_last) is None, module.REDIS_CONN.store + assert module.REDIS_CONN.get(k_lock) is None, module.REDIS_CONN.store + assert module.REDIS_CONN.get(module._verified_key(email)) == "1", module.REDIS_CONN.store + + +@pytest.mark.p2 +def test_forget_reset_password_matrix_unit(monkeypatch): + module = _load_user_app(monkeypatch) + email = "reset@example.com" + v_key = module._verified_key(email) + user = _DummyUser("u-reset", email, nickname="reset-user") + pwd_a = base64.b64encode(b"new-password").decode() + pwd_b = base64.b64encode(b"confirm-password").decode() + pwd_same = base64.b64encode(b"same-password").decode() + monkeypatch.setattr(module, "decrypt", lambda value: value) + + _set_request_json(monkeypatch, module, {"email": email, "new_password": pwd_same, "confirm_new_password": pwd_same}) + module.REDIS_CONN.store.pop(v_key, None) + res = _run(module.forget_reset_password()) + assert res["code"] == module.RetCode.AUTHENTICATION_ERROR, res + + module.REDIS_CONN.store[v_key] = "1" + monkeypatch.setattr(module, "decrypt", lambda _value: "") + _set_request_json(monkeypatch, module, {"email": email, "new_password": "", "confirm_new_password": ""}) + res = _run(module.forget_reset_password()) + assert res["code"] == module.RetCode.ARGUMENT_ERROR, res + + monkeypatch.setattr(module, "decrypt", lambda value: value) + module.REDIS_CONN.store[v_key] = "1" + _set_request_json(monkeypatch, module, {"email": email, "new_password": pwd_a, "confirm_new_password": pwd_b}) + res = _run(module.forget_reset_password()) + assert res["code"] == module.RetCode.ARGUMENT_ERROR, res + assert "do not match" in res["message"], res + + module.REDIS_CONN.store[v_key] = "1" + monkeypatch.setattr(module.UserService, "query_user_by_email", lambda **_kwargs: []) + _set_request_json(monkeypatch, module, {"email": email, "new_password": pwd_same, "confirm_new_password": pwd_same}) + res = _run(module.forget_reset_password()) + assert res["code"] == module.RetCode.DATA_ERROR, res + + module.REDIS_CONN.store[v_key] = "1" + monkeypatch.setattr(module.UserService, "query_user_by_email", lambda **_kwargs: [user]) + + def _raise_update_password(_user_id, _new_pwd): + raise RuntimeError("reset boom") + + monkeypatch.setattr(module.UserService, "update_user_password", _raise_update_password) + _set_request_json(monkeypatch, module, {"email": email, "new_password": pwd_same, "confirm_new_password": pwd_same}) + res = _run(module.forget_reset_password()) + assert res["code"] == module.RetCode.EXCEPTION_ERROR, res + + module.REDIS_CONN.store[v_key] = "1" + monkeypatch.setattr(module.UserService, "update_user_password", lambda _user_id, _new_pwd: True) + monkeypatch.setattr(module.REDIS_CONN, "delete", lambda _key: (_ for _ in ()).throw(RuntimeError("delete boom"))) + _set_request_json(monkeypatch, module, {"email": email, "new_password": pwd_same, "confirm_new_password": pwd_same}) + res = _run(module.forget_reset_password()) + assert res["code"] == module.RetCode.SUCCESS, res + assert res["auth"] == user.get_id(), res + + monkeypatch.setattr(module.REDIS_CONN, "delete", lambda key: module.REDIS_CONN.store.pop(key, None)) + module.REDIS_CONN.store[v_key] = "1" + _set_request_json(monkeypatch, module, {"email": email, "new_password": pwd_same, "confirm_new_password": pwd_same}) + res = _run(module.forget_reset_password()) + assert res["code"] == module.RetCode.SUCCESS, res + assert res["auth"] == user.get_id(), res + assert module.REDIS_CONN.get(v_key) is None, module.REDIS_CONN.store + + +def _load_chat_routes_unit_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[3] + module_name = "test_chat_restful_routes_unit_module_for_tenant" + module_path = repo_root / "api" / "apps" / "restful_apis" / "chat_api.py" + + quart_mod = ModuleType("quart") + quart_mod.request = SimpleNamespace(args=SimpleNamespace(get=lambda _key, default=None: default, getlist=lambda _key: [])) + quart_mod.Response = type("_StubResponse", (), {}) + monkeypatch.setitem(sys.modules, "quart", quart_mod) + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + + apps_pkg = ModuleType("api.apps") + apps_pkg.__path__ = [str(repo_root / "api" / "apps")] + apps_pkg.current_user = SimpleNamespace(id="tenant-1") + apps_pkg.login_required = _passthrough_login_required + monkeypatch.setitem(sys.modules, "api.apps", apps_pkg) + api_pkg.apps = apps_pkg + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + settings_mod = ModuleType("common.settings") + settings_mod.STORAGE_IMPL = type("_StorageImpl", (), {"rm": staticmethod(lambda *_args, **_kwargs: None)})() + monkeypatch.setitem(sys.modules, "common.settings", settings_mod) + + constants_mod = ModuleType("common.constants") + constants_mod.LLMType = SimpleNamespace(CHAT="chat", IMAGE2TEXT="image2text", RERANK="rerank", SPEECH2TEXT="speech2text", TTS="tts") + constants_mod.RetCode = SimpleNamespace(SUCCESS=0, DATA_ERROR=102, OPERATING_ERROR=103, AUTHENTICATION_ERROR=109) + constants_mod.StatusEnum = SimpleNamespace(VALID=SimpleNamespace(value="1"), INVALID=SimpleNamespace(value="0")) + from common.constants import MAXIMUM_PAGE_NUMBER as _MPN, MAXIMUM_TASK_PAGE_NUMBER as _MTPN + constants_mod.MAXIMUM_PAGE_NUMBER = _MPN + constants_mod.MAXIMUM_TASK_PAGE_NUMBER = _MTPN + monkeypatch.setitem(sys.modules, "common.constants", constants_mod) + + misc_utils_mod = ModuleType("common.misc_utils") + misc_utils_mod.get_uuid = lambda: "generated-chat-id" + async def _thread_pool_exec(func, *args, **kwargs): + return func(*args, **kwargs) + misc_utils_mod.thread_pool_exec = _thread_pool_exec + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod) + + dialog_service_mod = ModuleType("api.db.services.dialog_service") + class _DialogService: + model = SimpleNamespace(_meta=SimpleNamespace(fields={ + "id": None, + "tenant_id": None, + "name": None, + "description": None, + "icon": None, + "kb_ids": None, + "llm_id": None, + "llm_setting": None, + "prompt_config": None, + "similarity_threshold": None, + "vector_similarity_weight": None, + "top_n": None, + "top_k": None, + "rerank_id": None, + "meta_data_filter": None, + "created_by": None, + "create_time": None, + "create_date": None, + "update_time": None, + "update_date": None, + "status": None, + })) + @staticmethod + def query(**_kwargs): + return [] + @staticmethod + def save(**_kwargs): + return True + @staticmethod + def get_by_id(_chat_id): + return False, None + @staticmethod + def get_by_tenant_ids(*_args, **_kwargs): + return [], 0 + dialog_service_mod.DialogService = _DialogService + dialog_service_mod.async_ask = lambda *_args, **_kwargs: None + dialog_service_mod.async_chat = lambda *_args, **_kwargs: None + dialog_service_mod.gen_mindmap = lambda *_args, **_kwargs: None + monkeypatch.setitem(sys.modules, "api.db.services.dialog_service", dialog_service_mod) + + conversation_service_mod = ModuleType("api.db.services.conversation_service") + conversation_service_mod.ConversationService = type("ConversationService", (), {}) + conversation_service_mod.structure_answer = lambda *_args, **_kwargs: {} + monkeypatch.setitem(sys.modules, "api.db.services.conversation_service", conversation_service_mod) + + kb_service_mod = ModuleType("api.db.services.knowledgebase_service") + class _KB: + def __init__(self): + self.id = "kb-1" + self.embd_id = "embd@factory" + self.chunk_num = 1 + self.name = "Dataset A" + self.status = "1" + kb_service_mod.KnowledgebaseService = type('KnowledgebaseService', (), { + 'accessible': staticmethod(lambda **_kwargs: [SimpleNamespace(id='kb-1')]), + 'query': staticmethod(lambda **_kwargs: [_KB()]), + 'get_by_id': staticmethod(lambda _id: (True, _KB())), + }) + monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod) + + tenant_llm_service_mod = ModuleType("api.db.services.tenant_llm_service") + tenant_llm_service_mod.TenantLLMService = type('TenantLLMService', (), { + 'split_model_name_and_factory': staticmethod(lambda model: (model.split('@', 1)[0], model.split('@', 1)[1] if '@' in model else None)), + 'query': staticmethod(lambda **_kwargs: [SimpleNamespace(id='llm-1')]), + 'get_api_key': staticmethod(lambda *_args, **_kwargs: SimpleNamespace(id=1)), + }) + monkeypatch.setitem(sys.modules, "api.db.services.tenant_llm_service", tenant_llm_service_mod) + + llm_service_mod = ModuleType("api.db.services.llm_service") + llm_service_mod.LLMBundle = lambda *_args, **_kwargs: None + monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service_mod) + + search_service_mod = ModuleType("api.db.services.search_service") + search_service_mod.SearchService = SimpleNamespace() + monkeypatch.setitem(sys.modules, "api.db.services.search_service", search_service_mod) + + tenant_model_service_mod = ModuleType("api.db.joint_services.tenant_model_service") + tenant_model_service_mod.get_model_config_by_type_and_name = lambda *_args, **_kwargs: {} + tenant_model_service_mod.get_tenant_default_model_by_type = lambda *_args, **_kwargs: {} + monkeypatch.setitem(sys.modules, "api.db.joint_services.tenant_model_service", tenant_model_service_mod) + + user_service_mod = ModuleType("api.db.services.user_service") + user_service_mod.UserService = type('UserService', (), {}) + user_service_mod.TenantService = type('TenantService', (), { + 'get_by_id': staticmethod(lambda _tenant_id: (True, SimpleNamespace(llm_id='glm-4'))), + 'get_joined_tenants_by_user_id': staticmethod(lambda _user_id: [{'tenant_id': 'tenant-1'}, {'tenant_id': 'team-tenant-2'}]), + }) + user_service_mod.UserTenantService = type('UserTenantService', (), {'query': staticmethod(lambda **_kwargs: [])}) + monkeypatch.setitem(sys.modules, "api.db.services.user_service", user_service_mod) + + chunk_feedback_service_mod = ModuleType("api.db.services.chunk_feedback_service") + chunk_feedback_service_mod.ChunkFeedbackService = type('ChunkFeedbackService', (), {'apply_feedback': staticmethod(lambda **_kwargs: {'success_count': 0, 'fail_count': 0, 'chunk_ids': []})}) + monkeypatch.setitem(sys.modules, "api.db.services.chunk_feedback_service", chunk_feedback_service_mod) + + api_utils_mod = ModuleType("api.utils.api_utils") + api_utils_mod.check_duplicate_ids = lambda ids, _label: (list(dict.fromkeys(ids or [])), []) + api_utils_mod.get_data_error_result = lambda message='': {'code': 102, 'data': None, 'message': message} + api_utils_mod.get_json_result = lambda data=None, message='', code=0: {'code': code, 'data': data, 'message': message} + api_utils_mod.server_error_response = lambda ex: {'code': 500, 'data': None, 'message': str(ex)} + api_utils_mod.validate_request = lambda *_args, **_kwargs: (lambda func: func) + api_utils_mod.get_request_json = lambda: _AwaitableValue({}) + monkeypatch.setitem(sys.modules, "api.utils.api_utils", api_utils_mod) + + tenant_utils_mod = ModuleType("api.utils.tenant_utils") + tenant_utils_mod.ensure_tenant_model_id_for_params = lambda _tenant_id, req: req + monkeypatch.setitem(sys.modules, "api.utils.tenant_utils", tenant_utils_mod) + + rag_pkg = ModuleType("rag") + rag_pkg.__path__ = [str(repo_root / 'rag')] + monkeypatch.setitem(sys.modules, 'rag', rag_pkg) + rag_prompts_pkg = ModuleType('rag.prompts') + rag_prompts_pkg.__path__ = [str(repo_root / 'rag' / 'prompts')] + monkeypatch.setitem(sys.modules, 'rag.prompts', rag_prompts_pkg) + rag_prompts_generator_mod = ModuleType('rag.prompts.generator') + rag_prompts_generator_mod.chunks_format = lambda reference: reference.get('chunks', []) if isinstance(reference, dict) else [] + monkeypatch.setitem(sys.modules, 'rag.prompts.generator', rag_prompts_generator_mod) + rag_prompts_template_mod = ModuleType('rag.prompts.template') + rag_prompts_template_mod.load_prompt = lambda *_args, **_kwargs: '' + monkeypatch.setitem(sys.modules, 'rag.prompts.template', rag_prompts_template_mod) + + spec = importlib.util.spec_from_file_location(module_name, module_path) + module = importlib.util.module_from_spec(spec) + module.manager = _DummyManager() + monkeypatch.setitem(sys.modules, module_name, module) + spec.loader.exec_module(module) + return module + + +@pytest.mark.p1 +def test_create_chat_uses_tenant_default_llm_when_llm_id_is_null_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + saved = {} + + async def _request_json(): + return { + 'name': 'chat-a', + 'dataset_ids': ['kb-1'], + 'llm_id': None, + 'llm_setting': {'temperature': 0.8}, + 'prompt_config': {'system': 'Answer with {knowledge}', 'parameters': [{'key': 'knowledge', 'optional': False}]}, + } + + monkeypatch.setattr(module, 'get_request_json', _request_json) + monkeypatch.setattr(module.DialogService, 'query', lambda **_kwargs: []) + + def _save(**kwargs): + saved.update(kwargs) + return True + + monkeypatch.setattr(module.DialogService, 'save', _save) + monkeypatch.setattr(module.DialogService, 'get_by_id', lambda _id: (True, SimpleNamespace(to_dict=lambda: saved))) + + res = _run(module.create.__wrapped__()) + assert res['code'] == 0 + assert saved['llm_id'] == 'glm-4' + assert saved['llm_setting']['temperature'] == 0.8 + + +@pytest.mark.p2 +def test_list_chats_authorized_multi_tenant_unit(monkeypatch): + module = _load_chat_routes_unit_module(monkeypatch) + captured = {} + monkeypatch.setattr( + module, + 'request', + SimpleNamespace( + args=SimpleNamespace( + get=lambda key, default=None: { + 'keywords': '', 'page': '1', 'page_size': '10', 'orderby': 'create_time', 'desc': 'true', 'id': None, 'name': None, + }.get(key, default), + getlist=lambda key: ['tenant-1', 'team-tenant-2'] if key == 'owner_ids' else [], + ) + ), + ) + + def _get_by_tenant_ids(owner_ids, user_id, *args, **kwargs): + captured['owner_ids'] = owner_ids + captured['user_id'] = user_id + return ([{'id': 'c1', 'tenant_id': 'tenant-1'}, {'id': 'c2', 'tenant_id': 'team-tenant-2'}], 2) + + monkeypatch.setattr(module.DialogService, 'get_by_tenant_ids', _get_by_tenant_ids) + monkeypatch.setattr(module.KnowledgebaseService, 'get_by_id', lambda _id: (False, None)) + res = _run(module.list_chats.__wrapped__()) + assert res['code'] == 0 + assert res['data']['total'] == 2 + assert {c['id'] for c in res['data']['chats']} == {'c1', 'c2'} + assert set(captured['owner_ids']) == {'tenant-1', 'team-tenant-2'} + assert captured['user_id'] == 'tenant-1' diff --git a/test/testcases/test_http_api/test_chat_assistant_management/conftest.py b/test/testcases/test_http_api/test_chat_assistant_management/conftest.py index 330732db6d1..60d5e432105 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/conftest.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/conftest.py @@ -18,7 +18,7 @@ from utils import wait_for -@wait_for(30, 1, "Document parsing timeout") +@wait_for(200, 1, "Document parsing timeout") def condition(_auth, _dataset_id): res = list_documents(_auth, _dataset_id) for doc in res["data"]["docs"]: diff --git a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py index a8d4f95cbaf..fa0894f1427 100644 --- a/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_chat_assistant_management/test_chat_sdk_routes_unit.py @@ -201,6 +201,7 @@ class _StubLLMType(str, Enum): class _StubRetCode(int, Enum): SUCCESS = 0 DATA_ERROR = 102 + OPERATING_ERROR = 103 AUTHENTICATION_ERROR = 109 class _StubStatusEnum(str, Enum): @@ -218,6 +219,11 @@ class _StubStatusEnum(str, Enum): misc_utils_mod = ModuleType("common.misc_utils") misc_utils_mod.get_uuid = lambda: "generated-chat-id" + + async def _thread_pool_exec(func, *args, **kwargs): + return func(*args, **kwargs) + + misc_utils_mod.thread_pool_exec = _thread_pool_exec monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_mod) dialog_service_mod = ModuleType("api.db.services.dialog_service") @@ -371,6 +377,10 @@ class _StubTenantService: def get_by_id(_tenant_id): return True, SimpleNamespace(llm_id="glm-4") + @staticmethod + def get_joined_tenants_by_user_id(_user_id): + return [{"tenant_id": "tenant-1"}, {"tenant_id": "team-tenant-2"}] + class _StubUserTenantService: @staticmethod def query(**_kwargs): @@ -808,7 +818,7 @@ def test_list_chats_returns_old_business_fields(monkeypatch): ) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) - res = module.list_chats.__wrapped__() + res = _run(module.list_chats.__wrapped__()) assert res["code"] == 0 chat = res["data"]["chats"][0] @@ -851,7 +861,7 @@ def _get_by_tenant_ids(_owner_ids, _user_id, page_number, items_per_page, *_args monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _get_by_tenant_ids) monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) - res = module.list_chats.__wrapped__() + res = _run(module.list_chats.__wrapped__()) assert res["code"] == 0 assert calls[-1] == (0, 0) @@ -874,13 +884,119 @@ def _get_by_tenant_ids(_owner_ids, _user_id, page_number, items_per_page, *_args ), ) - res = module.list_chats.__wrapped__() + res = _run(module.list_chats.__wrapped__()) assert res["code"] == 0 assert calls[-1] == (0, 2) assert len(res["data"]["chats"]) == 1 +@pytest.mark.p2 +def test_list_chats_rejects_unauthorized_owner_ids(monkeypatch): + module = _load_chat_module(monkeypatch) + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + args=SimpleNamespace( + get=lambda key, default=None: { + "keywords": "", + "page": "0", + "page_size": "0", + "orderby": "create_time", + "desc": "true", + "id": None, + "name": None, + }.get(key, default), + getlist=lambda key: ["foreign-tenant-id"] if key == "owner_ids" else [], + ) + ), + ) + res = _run(module.list_chats.__wrapped__()) + assert res["code"] == module.RetCode.OPERATING_ERROR + assert "authorized owner_ids" in res["message"] + + +@pytest.mark.p2 +def test_list_chats_authorized_multi_tenant(monkeypatch): + module = _load_chat_module(monkeypatch) + captured = {} + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + args=SimpleNamespace( + get=lambda key, default=None: { + "keywords": "", + "page": "1", + "page_size": "10", + "orderby": "create_time", + "desc": "true", + "id": None, + "name": None, + }.get(key, default), + getlist=lambda key: ["tenant-1", "team-tenant-2"] if key == "owner_ids" else [], + ) + ), + ) + + def _get_by_tenant_ids(owner_ids, user_id, *args, **kwargs): + captured["owner_ids"] = owner_ids + captured["user_id"] = user_id + return ( + [ + {**_DummyDialogRecord().to_dict(), "tenant_id": "tenant-1", "id": "c1"}, + {**_DummyDialogRecord().to_dict(), "tenant_id": "team-tenant-2", "id": "c2"}, + ], + 2, + ) + + monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _get_by_tenant_ids) + monkeypatch.setattr(module.KnowledgebaseService, "get_by_id", lambda _id: (True, _DummyKB())) + + res = _run(module.list_chats.__wrapped__()) + assert res["code"] == 0 + assert res["data"]["total"] == 2 + assert {c["id"] for c in res["data"]["chats"]} == {"c1", "c2"} + assert set(captured["owner_ids"]) == {"tenant-1", "team-tenant-2"} + assert captured["user_id"] == "tenant-1" + + +@pytest.mark.p2 +def test_list_chats_defaults_to_authorized_owner_ids_when_omitted(monkeypatch): + module = _load_chat_module(monkeypatch) + captured = {} + + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + args=SimpleNamespace( + get=lambda key, default=None: { + "keywords": "", + "page": "1", + "page_size": "10", + "orderby": "create_time", + "desc": "true", + "id": None, + "name": None, + }.get(key, default), + getlist=lambda _key: [], + ) + ), + ) + + def _get_by_tenant_ids(owner_ids, *_args, **_kwargs): + captured["owner_ids"] = owner_ids + return ([], 0) + + monkeypatch.setattr(module.DialogService, "get_by_tenant_ids", _get_by_tenant_ids) + res = _run(module.list_chats.__wrapped__()) + + assert res["code"] == 0 + assert set(captured["owner_ids"]) == {"tenant-1", "team-tenant-2"} + + @pytest.mark.p2 def test_chat_session_create_and_update_guard_matrix_unit(monkeypatch): module = _load_chat_module(monkeypatch) @@ -962,7 +1078,7 @@ def test_chat_session_list_projection_unit(monkeypatch): ], ) - res = module.list_sessions.__wrapped__("chat-1") + res = _run(module.list_sessions.__wrapped__("chat-1")) assert res["data"][0]["chat_id"] == "chat-1" assert res["data"][0]["messages"][0]["content"] == "hello" @@ -983,7 +1099,7 @@ def test_chat_session_list_projection_unit(monkeypatch): ) ), ) - res = module.list_sessions.__wrapped__("chat-1") + res = _run(module.list_sessions.__wrapped__("chat-1")) assert res["data"] == [] diff --git a/test/testcases/test_http_api/test_dataset_management/test_create_dataset.py b/test/testcases/test_http_api/test_dataset_management/test_create_dataset.py index 5cada305fb9..46b6e8891c9 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_create_dataset.py +++ b/test/testcases/test_http_api/test_dataset_management/test_create_dataset.py @@ -556,8 +556,8 @@ def test_parser_config(self, HttpApiAuth, name, parser_config): ("graphrag_type_invalid", {"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean"), ("graphrag_entity_types_not_list", {"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"), ("graphrag_entity_types_not_str_in_list", {"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"), - ("graphrag_method_unknown", {"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"), - ("graphrag_method_none", {"graphrag": {"method": None}}, "Input should be 'light' or 'general'"), + ("graphrag_method_unknown", {"graphrag": {"method": "unknown"}}, "Input should be 'light', 'general' or 'ner'"), + ("graphrag_method_none", {"graphrag": {"method": None}}, "Input should be 'light', 'general' or 'ner'"), ("graphrag_community_type_invalid", {"graphrag": {"community": "string"}}, "Input should be a valid boolean"), ("graphrag_resolution_type_invalid", {"graphrag": {"resolution": "string"}}, "Input should be a valid boolean"), ("raptor_type_invalid", {"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean"), diff --git a/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py b/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py index ac98d9e1d33..6f4927b8d06 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py +++ b/test/testcases/test_http_api/test_dataset_management/test_dify_retrieval_routes_unit.py @@ -223,8 +223,20 @@ def to_dict(self): "id": self.id } - def _get_model_config_by_id(tenant_model_id: int) -> dict: - return _MockModelConfig2("tenant-1", "model-1").to_dict() + def _get_model_config_by_id( + tenant_model_id: int, + allowed_tenant_ids=None, + requester_tenant_id=None, + ) -> dict: + mock_tenant_id = "tenant-1" + if allowed_tenant_ids is not None: + if isinstance(allowed_tenant_ids, str): + allowed_tenant_ids = {allowed_tenant_ids} + else: + allowed_tenant_ids = {str(tenant_id) for tenant_id in allowed_tenant_ids if tenant_id} + if mock_tenant_id not in allowed_tenant_ids and str(requester_tenant_id) != mock_tenant_id: + raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") + return _MockModelConfig2(mock_tenant_id, "model-1").to_dict() def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str): if not model_name: @@ -352,3 +364,82 @@ async def retrieval(self, *_args, **_kwargs): res = _run(inspect.unwrap(module.retrieval)("tenant-1")) assert res["code"] == module.RetCode.SERVER_ERROR, res assert "boom" in res["message"], res + + +@pytest.mark.p2 +def test_read_retrieval_request_from_get_args(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + method="GET", + args={ + "knowledge_id": "kb-1", + "query": "hello", + "use_kg": "true", + "top_k": "12", + "score_threshold": "0.66", + }, + ), + ) + + req = _run(module._read_retrieval_request()) + assert req["knowledge_id"] == "kb-1", req + assert req["query"] == "hello", req + assert req["use_kg"] is True, req + assert req["retrieval_setting"]["top_k"] == 12, req + assert req["retrieval_setting"]["score_threshold"] == 0.66, req + + +@pytest.mark.p2 +def test_read_retrieval_request_from_post_json(monkeypatch): + module = _load_dify_retrieval_module(monkeypatch) + payload = {"knowledge_id": "kb-1", "query": "hello"} + monkeypatch.setattr(module, "request", SimpleNamespace(method="POST", args={})) + monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue(payload)) + + req = _run(module._read_retrieval_request()) + assert req == payload, req + + +@pytest.mark.p2 +def test_retrieval_argument_error_messages(monkeypatch): + """Guard: distinguish malformed vs missing argument errors.""" + module = _load_dify_retrieval_module(monkeypatch) + + # Case 1: malformed numeric options in retrieval_setting + _set_request_json( + monkeypatch, + module, + { + "knowledge_id": "kb-1", + "query": "hello", + "retrieval_setting": {"top_k": "not-int", "score_threshold": "not-float"}, + }, + ) + res = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res["code"] == module.RetCode.ARGUMENT_ERROR, res + assert "invalid or malformed arguments:" in res["message"], res + + # Case 2: missing required fields (knowledge_id, query) + _set_request_json(monkeypatch, module, {}) + res_missing = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res_missing["code"] == module.RetCode.ARGUMENT_ERROR, res_missing + assert "required arguments are missing:" in res_missing["message"], res_missing + + # Case 3: partially missing required field (query) + _set_request_json(monkeypatch, module, {"knowledge_id": "kb-1"}) + res_missing_query = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res_missing_query["code"] == module.RetCode.ARGUMENT_ERROR, res_missing_query + assert "query" in res_missing_query["message"], res_missing_query + + # Case 4: retrieval_setting wrong type + _set_request_json( + monkeypatch, + module, + {"knowledge_id": "kb-1", "query": "hello", "retrieval_setting": "bad-type"}, + ) + res_wrong_type = _run(inspect.unwrap(module.retrieval)("tenant-1")) + assert res_wrong_type["code"] == module.RetCode.ARGUMENT_ERROR, res_wrong_type + assert "retrieval_setting must be an object" in res_wrong_type["message"], res_wrong_type diff --git a/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py b/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py index 0398f772390..f2ada816a5e 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py +++ b/test/testcases/test_http_api/test_dataset_management/test_knowledge_graph.py @@ -50,4 +50,4 @@ def test_delete_knowledge_graph(self, HttpApiAuth, add_dataset_func): dataset_id = add_dataset_func res = delete_knowledge_graph(HttpApiAuth, dataset_id) assert res["code"] == 0, res - assert res["data"] is True, res + assert res["data"] is not None, res diff --git a/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py b/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py index 0847a181c14..c3cd9ac3de0 100644 --- a/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py +++ b/test/testcases/test_http_api/test_dataset_management/test_update_dataset.py @@ -583,6 +583,10 @@ def test_pagerank_none(self, HttpApiAuth, add_dataset_func): {"raptor": {"max_cluster": 512}}, {"raptor": {"max_cluster": 1024}}, {"raptor": {"random_seed": 0}}, + {"raptor": {"clustering_method": "gmm"}}, + {"raptor": {"clustering_method": "ahc"}}, + {"raptor": {"tree_builder": "raptor"}}, + {"raptor": {"tree_builder": "psi"}}, ], ids=[ "auto_keywords_min", @@ -633,6 +637,10 @@ def test_pagerank_none(self, HttpApiAuth, add_dataset_func): "raptor_max_cluster_mid", "raptor_max_cluster_max", "raptor_random_seed_min", + "raptor_clustering_method_gmm", + "raptor_clustering_method_ahc", + "raptor_tree_builder_raptor", + "raptor_tree_builder_psi", ], ) def test_parser_config(self, HttpApiAuth, add_dataset_func, parser_config): @@ -686,8 +694,8 @@ def test_parser_config(self, HttpApiAuth, add_dataset_func, parser_config): ({"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean"), ({"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"), ({"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"), - ({"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"), - ({"graphrag": {"method": None}}, "Input should be 'light' or 'general'"), + ({"graphrag": {"method": "unknown"}}, "Input should be 'light', 'general' or 'ner'"), + ({"graphrag": {"method": None}}, "Input should be 'light', 'general' or 'ner'"), ({"graphrag": {"community": "string"}}, "Input should be a valid boolean"), ({"graphrag": {"resolution": "string"}}, "Input should be a valid boolean"), ({"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean"), @@ -707,6 +715,10 @@ def test_parser_config(self, HttpApiAuth, add_dataset_func, parser_config): ({"raptor": {"random_seed": -1}}, "Input should be greater than or equal to 0"), ({"raptor": {"random_seed": 3.14}}, "Input should be a valid integer"), ({"raptor": {"random_seed": "string"}}, "Input should be a valid integer"), + ({"raptor": {"clustering_method": "unknown"}}, "Input should be 'gmm' or 'ahc'"), + ({"raptor": {"clustering_method": None}}, "Input should be 'gmm' or 'ahc'"), + ({"raptor": {"tree_builder": "ahc"}}, "Input should be 'raptor' or 'psi'"), + ({"raptor": {"tree_builder": None}}, "Input should be 'raptor' or 'psi'"), ({"delimiter": "a" * 65536}, "Parser config exceeds size limit (max 65,535 characters)"), ], ids=[ @@ -763,6 +775,10 @@ def test_parser_config(self, HttpApiAuth, add_dataset_func, parser_config): "raptor_random_seed_min_limit", "raptor_random_seed_float_not_allowed", "raptor_random_seed_type_invalid", + "raptor_clustering_method_invalid", + "raptor_clustering_method_none_invalid", + "raptor_tree_builder_invalid", + "raptor_tree_builder_none_invalid", "parser_config_type_invalid", ], ) diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py index ca440d4ae0f..08055a57e66 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_doc_sdk_routes_unit.py @@ -417,8 +417,20 @@ def to_dict(self): "id": self.id } - def _get_model_config_by_id(tenant_model_id: int) -> dict: - return _MockModelConfig2("tenant-1", "model-1").to_dict() + def _get_model_config_by_id( + tenant_model_id: int, + allowed_tenant_ids=None, + requester_tenant_id=None, + ) -> dict: + mock_tenant_id = "tenant-1" + if allowed_tenant_ids is not None: + if isinstance(allowed_tenant_ids, str): + allowed_tenant_ids = {allowed_tenant_ids} + else: + allowed_tenant_ids = {str(tenant_id) for tenant_id in allowed_tenant_ids if tenant_id} + if mock_tenant_id not in allowed_tenant_ids and str(requester_tenant_id) != mock_tenant_id: + raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") + return _MockModelConfig2(mock_tenant_id, "model-1").to_dict() def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str): if not model_name: @@ -706,6 +718,36 @@ def test_list_chunks_branches(self, monkeypatch): assert res["data"]["total"] == 1 assert res["data"]["chunks"][0]["id"] == "chunk-1" + def test_list_chunks_uses_dataset_owner_index_for_team_dataset(self, monkeypatch): + module = _load_restful_chunk_module(monkeypatch) + seen = {} + monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: True) + monkeypatch.setattr( + module.KnowledgebaseService, + "get_by_id", + lambda _dataset_id: (True, SimpleNamespace(tenant_id="owner-tenant")), + ) + monkeypatch.setattr(module.DocumentService, "query", lambda **_kwargs: [_DummyDoc(kb_id="ds-1")]) + monkeypatch.setattr(module, "request", SimpleNamespace(args=_DummyArgs({}))) + + def _index_exist(index_name, dataset_id): + seen["index_exist"] = (index_name, dataset_id) + return True + + class _Retriever: + async def search(self, _query, index_name, dataset_ids, *_args, **_kwargs): + seen["search"] = (index_name, dataset_ids) + return SimpleNamespace(total=0, ids=[], field={}, highlight={}) + + _patch_docstore(monkeypatch, module, index_exist=_index_exist) + monkeypatch.setattr(module.settings, "retriever", _Retriever()) + + res = _run(_route_core(module.list_chunks)("member-tenant", "ds-1", "doc-1")) + + assert res["code"] == 0 + assert seen["index_exist"] == ("idx-owner-tenant", "ds-1") + assert seen["search"] == ("idx-owner-tenant", ["ds-1"]) + def test_add_chunk_access_guard(self, monkeypatch): module = _load_restful_chunk_module(monkeypatch) monkeypatch.setattr(module.KnowledgebaseService, "accessible", lambda **_kwargs: False) diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py index 5b9e5ad314a..4411cd43ccc 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_parse_documents.py @@ -16,7 +16,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed import pytest -from common import bulk_upload_documents, list_documents, parse_documents +from common import bulk_upload_documents, delete_documents, list_chunks, list_documents, parse_documents from configs import INVALID_API_TOKEN from libs.auth import RAGFlowHttpApiAuth from utils import wait_for @@ -165,6 +165,37 @@ def test_duplicate_parse(self, HttpApiAuth, add_documents_func): validate_document_details(HttpApiAuth, dataset_id, document_ids) + @pytest.mark.p2 + def test_chunks_retrievable_after_parse_status_done(self, HttpApiAuth, add_dataset_func, ragflow_tmp_dir): + @wait_for(30, 0.1, "Document parsing timeout") + def wait_until_done(ids): + r = list_documents(HttpApiAuth, dataset_id) + target_ids = set(ids) + for doc in r["data"]["docs"]: + if doc["id"] in target_ids and doc.get("run") != "DONE": + return False + return True + + dataset_id = add_dataset_func + + # if there is a bug it can be non-deterministic, so repeat 10 times + iterations = 10 + for i in range(1, iterations + 1): + document_ids = bulk_upload_documents(HttpApiAuth, dataset_id, 1, ragflow_tmp_dir) + + res = parse_documents(HttpApiAuth, dataset_id, {"document_ids": document_ids}) + assert res["code"] == 0, f"parse_documents failed: {res}" + + wait_until_done(document_ids) + + for document_id in document_ids: + res = list_chunks(HttpApiAuth, dataset_id, document_id) + assert res["code"] == 0, f"list_chunks failed: {res}" + assert res["data"]["doc"]["chunk_count"] > 0, f"Document {document_id} has run=DONE but chunk_count is 0" + assert len(res["data"]["chunks"]) > 0, f"Document {document_id} has run=DONE but no chunks returned" + + delete_documents(HttpApiAuth, dataset_id, {"ids": document_ids}) + @pytest.mark.p3 def test_parse_100_files(HttpApiAuth, add_dataset_func, tmp_path): diff --git a/test/testcases/test_http_api/test_file_management_within_dataset/test_update_document.py b/test/testcases/test_http_api/test_file_management_within_dataset/test_update_document.py index b24d9deeacf..de0b4189b96 100644 --- a/test/testcases/test_http_api/test_file_management_within_dataset/test_update_document.py +++ b/test/testcases/test_http_api/test_file_management_within_dataset/test_update_document.py @@ -387,6 +387,7 @@ def test_update_doc_guards_and_error_paths(self, HttpApiAuth, add_documents, pay "category", ], "method": "light", + "batch_chunk_token_size": 4096, }, } diff --git a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py index 77ac86232b5..889de4ba1fa 100644 --- a/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py +++ b/test/testcases/test_http_api/test_session_management/test_session_sdk_routes_unit.py @@ -123,7 +123,7 @@ def _load_session_module(monkeypatch): # Mock common.constants module from enum import Enum - from strenum import StrEnum + from enum import StrEnum class _StubLLMType(StrEnum): CHAT = "chat" @@ -466,8 +466,20 @@ def to_dict(self): "id": self.id } - def _get_model_config_by_id(tenant_model_id: int) -> dict: - return _MockModelConfig2("tenant-1", "model-1").to_dict() + def _get_model_config_by_id( + tenant_model_id: int, + allowed_tenant_ids=None, + requester_tenant_id=None, + ) -> dict: + mock_tenant_id = "tenant-1" + if allowed_tenant_ids is not None: + if isinstance(allowed_tenant_ids, str): + allowed_tenant_ids = {allowed_tenant_ids} + else: + allowed_tenant_ids = {str(tenant_id) for tenant_id in allowed_tenant_ids if tenant_id} + if mock_tenant_id not in allowed_tenant_ids and str(requester_tenant_id) != mock_tenant_id: + raise LookupError(f"Tenant Model with id {tenant_model_id} not authorized") + return _MockModelConfig2(mock_tenant_id, "model-1").to_dict() def _get_model_config_by_type_and_name(tenant_id: str, model_type: str, model_name: str): if not model_name: @@ -758,7 +770,10 @@ def commit_after_run(**_kwargs): monkeypatch.setitem(sys.modules, "api.apps.services.canvas_replica_service", canvas_replica_mod) file_service_mod = ModuleType("api.db.services.file_service") - file_service_mod.FileService = SimpleNamespace(upload_info=lambda *_args, **_kwargs: {}) + file_service_mod.FileService = SimpleNamespace( + upload_info=lambda *_args, **_kwargs: {}, + get_blob=lambda *_args, **_kwargs: b"", + ) monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service_mod) api_service_mod = ModuleType("api.db.services.api_service") @@ -1201,7 +1216,119 @@ async def _agent_nonstream(*_args, **_kwargs): "c4": {}, } assert [item["component_id"] for item in res["data"]["data"]["trace"]] == ["c2", "c3", "c4"] - + + +class _FakeUploadFileField: + def __init__(self, filename: str): + self.filename = filename + + +class _FakeRequestFiles: + def __init__(self, filenames: list[str]): + self._filenames = filenames + + def get(self, key, default=None): + if key == "file" and self._filenames: + return _FakeUploadFileField(self._filenames[0]) + return default + + def getlist(self, key): + if key == "file": + return [_FakeUploadFileField(n) for n in self._filenames] + return [] + + +@pytest.mark.p2 +def test_agent_file_download_and_upload_unit(monkeypatch): + module = _load_agent_api_module(monkeypatch) + monkeypatch.setattr(module, "Response", _StubResponse) + + get_blob_calls = [] + + def _get_blob(tenant_id, file_id): + get_blob_calls.append((tenant_id, file_id)) + return b"file-bytes" + + monkeypatch.setattr(module.FileService, "get_blob", _get_blob) + monkeypatch.setattr(module, "request", SimpleNamespace(args=_Args({"id": "doc-99"}))) + + resp = _run(inspect.unwrap(module.download_agent_file)("tenant-1")) + assert isinstance(resp, _StubResponse) + assert resp.body == b"file-bytes" + assert get_blob_calls == [("tenant-1", "doc-99")] + + upload_calls = [] + + def _upload_info(tenant_id, file_obj, url=None): + upload_calls.append((tenant_id, getattr(file_obj, "filename", None), url)) + return {"id": tenant_id, "file": getattr(file_obj, "filename", None), "url": url} + + monkeypatch.setattr(module.FileService, "upload_info", _upload_info) + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + args=_Args({"url": "https://example.com/a.png"}), + files=_AwaitableValue(_FakeRequestFiles(["one.png"])), + ), + ) + res = _run( + inspect.unwrap(module.upload_agent_file)( + agent_id="agent-1", + tenant_id="tenant-1", + ) + ) + assert res["code"] == 0 + assert res["data"]["file"] == "one.png" + assert upload_calls == [("tenant-1", "one.png", "https://example.com/a.png")] + + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + args=_Args({}), + files=_AwaitableValue(_FakeRequestFiles(["a.png", "b.png"])), + ), + ) + upload_calls.clear() + res = _run( + inspect.unwrap(module.upload_agent_file)( + agent_id="agent-1", + tenant_id="tenant-1", + ) + ) + assert res["code"] == 0 + assert len(res["data"]) == 2 + assert set(upload_calls) == { + ("tenant-1", "a.png", None), + ("tenant-1", "b.png", None), + } + + def _boom(*_a, **_k): + raise ValueError("upload failed") + + monkeypatch.setattr(module.FileService, "upload_info", _boom) + monkeypatch.setattr( + module, + "request", + SimpleNamespace( + args=_Args({}), + files=_AwaitableValue(_FakeRequestFiles(["bad.png"])), + ), + ) + res = _run( + inspect.unwrap(module.upload_agent_file)( + agent_id="agent-1", + tenant_id="tenant-1", + ) + ) + assert res["code"] != 0 + + monkeypatch.setattr(module.UserCanvasService, "accessible", lambda *_a, **_k: False) + res = _run(module.upload_agent_file(agent_id="agent-1")) + assert res["code"] == module.RetCode.OPERATING_ERROR + assert "permission" in res["message"].lower() + @pytest.mark.p2 def test_delete_routes_partial_duplicate_unit(monkeypatch): diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py index 8f8f9bfeb6f..92505aec5d5 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_create_dataset.py @@ -494,8 +494,8 @@ def test_parser_config(self, client, name, parser_config): ("graphrag_type_invalid", {"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean"), ("graphrag_entity_types_not_list", {"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"), ("graphrag_entity_types_not_str_in_list", {"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"), - ("graphrag_method_unknown", {"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"), - ("graphrag_method_none", {"graphrag": {"method": None}}, "Input should be 'light' or 'general'"), + ("graphrag_method_unknown", {"graphrag": {"method": "unknown"}}, "Input should be 'light', 'general' or 'ner'"), + ("graphrag_method_none", {"graphrag": {"method": None}}, "Input should be 'light', 'general' or 'ner'"), ("graphrag_community_type_invalid", {"graphrag": {"community": "string"}}, "Input should be a valid boolean"), ("graphrag_resolution_type_invalid", {"graphrag": {"resolution": "string"}}, "Input should be a valid boolean"), ("raptor_type_invalid", {"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean"), diff --git a/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py b/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py index 6207e31db1f..d32d8fd9b3d 100644 --- a/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py +++ b/test/testcases/test_sdk_api/test_dataset_mangement/test_update_dataset.py @@ -550,8 +550,8 @@ def test_parser_config(self, client, add_dataset_func, parser_config): ({"graphrag": {"use_graphrag": "string"}}, "Input should be a valid boolean"), ({"graphrag": {"entity_types": "1,2"}}, "Input should be a valid list"), ({"graphrag": {"entity_types": [1, 2]}}, "nput should be a valid string"), - ({"graphrag": {"method": "unknown"}}, "Input should be 'light' or 'general'"), - ({"graphrag": {"method": None}}, "Input should be 'light' or 'general'"), + ({"graphrag": {"method": "unknown"}}, "Input should be 'light', 'general' or 'ner'"), + ({"graphrag": {"method": None}}, "Input should be 'light', 'general' or 'ner'"), ({"graphrag": {"community": "string"}}, "Input should be a valid boolean"), ({"graphrag": {"resolution": "string"}}, "Input should be a valid boolean"), ({"raptor": {"use_raptor": "string"}}, "Input should be a valid boolean"), diff --git a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py index f174f0e5462..2b02c0b19cc 100644 --- a/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py +++ b/test/testcases/test_sdk_api/test_file_management_within_dataset/test_update_document.py @@ -313,6 +313,7 @@ def test_immutable_fields_progress(self, add_documents, payload, expected_messag "category", ], "method": "light", + "batch_chunk_token_size": 4096, }, } diff --git a/test/testcases/test_web_api/test_agent_app/test_agents_webhook_unit.py b/test/testcases/test_web_api/test_agent_app/test_agents_webhook_unit.py index 1022a9b45a0..e93c48249ae 100644 --- a/test/testcases/test_web_api/test_agent_app/test_agents_webhook_unit.py +++ b/test/testcases/test_web_api/test_agent_app/test_agents_webhook_unit.py @@ -514,7 +514,7 @@ def test_agents_crud_unit_branches(monkeypatch): captured = {} - def fake_get_by_tenant_ids(owner_ids, tenant_id, page, page_size, orderby, desc, keywords, canvas_category): + def fake_get_by_tenant_ids(owner_ids, tenant_id, page, page_size, orderby, desc, keywords, canvas_category, tags): captured["owner_ids"] = owner_ids captured["tenant_id"] = tenant_id captured["page"] = page @@ -523,6 +523,7 @@ def fake_get_by_tenant_ids(owner_ids, tenant_id, page, page_size, orderby, desc, captured["desc"] = desc captured["keywords"] = keywords captured["canvas_category"] = canvas_category + captured["tags"] = tags return [{"id": "agent-1"}], 1 monkeypatch.setattr(module.UserCanvasService, "get_by_tenant_ids", fake_get_by_tenant_ids) diff --git a/test/testcases/test_web_api/test_canvas_app/test_code_exec_contract_unit.py b/test/testcases/test_web_api/test_canvas_app/test_code_exec_contract_unit.py index ff171c3b00e..19921054743 100644 --- a/test/testcases/test_web_api/test_canvas_app/test_code_exec_contract_unit.py +++ b/test/testcases/test_web_api/test_canvas_app/test_code_exec_contract_unit.py @@ -140,7 +140,7 @@ def test_select_business_output_ignores_system_outputs(): "actual_type": {"value": "", "type": "string"}, "_ERROR": {"value": "", "type": "string"}, "_ARTIFACTS": {"value": [], "type": "Array"}, - "_ATTACHMENT_CONTENT": {"value": "", "type": "string"}, + "attachments": {"value": [], "type": "Array"}, "raw_result": {"value": None, "type": "Any"}, "_created_time": {"value": 1.0, "type": "Number"}, "_elapsed_time": {"value": 2.0, "type": "Number"}, @@ -297,7 +297,7 @@ def test_legacy_multi_output_schema_is_rejected(): ) -@pytest.mark.parametrize("name", ["content", "actual_type", "_ERROR", "_ARTIFACTS", "_ATTACHMENT_CONTENT", "raw_result"]) +@pytest.mark.parametrize("name", ["content", "actual_type", "attachments", "_ERROR", "_ARTIFACTS", "raw_result"]) def test_reserved_business_output_names_are_rejected(name): module = _load_module() with pytest.raises(module.ContractError, match="reserved output name"): @@ -387,7 +387,6 @@ def test_process_execution_result_returns_early_for_stderr_only_without_artifact def test_process_execution_result_appends_artifact_content_to_canonical_content(): tool = _build_code_exec("Object") tool._upload_artifacts = lambda _artifacts: [{"name": "chart.png", "url": "/artifact/chart.png", "mime_type": "image/png", "size": 12}] - tool._build_attachment_content = lambda _artifacts, _artifact_urls: "attachment_count: 1\n\nattachment1 (image): chart.png\nparsed artifact" result = tool._process_execution_result( '{"foo": "bar"}', @@ -400,8 +399,7 @@ def test_process_execution_result_appends_artifact_content_to_canonical_content( assert result["content"] == '{\n "foo": "bar"\n}\n\nattachment_count: 1\n\nattachment1 (image): chart.png\nparsed artifact' assert result["_ARTIFACTS"] == [{"name": "chart.png", "url": "/artifact/chart.png", "mime_type": "image/png", "size": 12}] assert result["_ARTIFACTS"][0]["mime_type"] == "image/png" - assert result["_ATTACHMENT_CONTENT"] == "attachment_count: 1\n\nattachment1 (image): chart.png\nparsed artifact" - assert "attachment1 (image): chart.png" in result["_ATTACHMENT_CONTENT"] + assert result["attachments"] == ["![chart.png](/artifact/chart.png)"] def test_process_execution_result_without_artifacts_clears_stale_artifacts_output(): diff --git a/test/testcases/test_web_api/test_canvas_app/test_iteration_runtime_unit.py b/test/testcases/test_web_api/test_canvas_app/test_iteration_runtime_unit.py new file mode 100644 index 00000000000..e73139ec267 --- /dev/null +++ b/test/testcases/test_web_api/test_canvas_app/test_iteration_runtime_unit.py @@ -0,0 +1,391 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import asyncio +import importlib.util +import json +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +def _load_canvas_runtime(monkeypatch): + repo_root = Path(__file__).resolve().parents[4] + + quart = ModuleType("quart") + quart.make_response = lambda *a, **kw: None + quart.jsonify = lambda *a, **kw: None + monkeypatch.setitem(sys.modules, "quart", quart) + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + common_constants = ModuleType("common.constants") + common_constants.LLMType = SimpleNamespace(TTS="tts") + monkeypatch.setitem(sys.modules, "common.constants", common_constants) + + common_misc = ModuleType("common.misc_utils") + common_misc.get_uuid = lambda: "uuid" + common_misc.hash_str2int = lambda x: 1 + + async def _thread_pool_exec(fn, *args, **kwargs): + return fn(*args, **kwargs) + + common_misc.thread_pool_exec = _thread_pool_exec + monkeypatch.setitem(sys.modules, "common.misc_utils", common_misc) + + common_conn = ModuleType("common.connection_utils") + + def timeout(_seconds): + def decorator(fn): + return fn + + return decorator + + common_conn.timeout = timeout + monkeypatch.setitem(sys.modules, "common.connection_utils", common_conn) + + common_ex = ModuleType("common.exceptions") + + class TaskCanceledException(Exception): + pass + + common_ex.TaskCanceledException = TaskCanceledException + monkeypatch.setitem(sys.modules, "common.exceptions", common_ex) + + api_pkg = ModuleType("api") + api_pkg.__path__ = [str(repo_root / "api")] + monkeypatch.setitem(sys.modules, "api", api_pkg) + api_db_pkg = ModuleType("api.db") + api_db_pkg.__path__ = [str(repo_root / "api" / "db")] + monkeypatch.setitem(sys.modules, "api.db", api_db_pkg) + api_db_services_pkg = ModuleType("api.db.services") + api_db_services_pkg.__path__ = [str(repo_root / "api" / "db" / "services")] + monkeypatch.setitem(sys.modules, "api.db.services", api_db_services_pkg) + api_db_joint_pkg = ModuleType("api.db.joint_services") + api_db_joint_pkg.__path__ = [str(repo_root / "api" / "db" / "joint_services")] + monkeypatch.setitem(sys.modules, "api.db.joint_services", api_db_joint_pkg) + + file_service = ModuleType("api.db.services.file_service") + file_service.FileService = object + monkeypatch.setitem(sys.modules, "api.db.services.file_service", file_service) + + llm_service = ModuleType("api.db.services.llm_service") + llm_service.LLMBundle = object + monkeypatch.setitem(sys.modules, "api.db.services.llm_service", llm_service) + + task_service = ModuleType("api.db.services.task_service") + task_service.has_canceled = lambda _task_id: False + monkeypatch.setitem(sys.modules, "api.db.services.task_service", task_service) + + tenant_model_service = ModuleType("api.db.joint_services.tenant_model_service") + tenant_model_service.get_tenant_default_model_by_type = lambda *_a, **_kw: None + monkeypatch.setitem( + sys.modules, + "api.db.joint_services.tenant_model_service", + tenant_model_service, + ) + + rag_pkg = ModuleType("rag") + rag_pkg.__path__ = [str(repo_root / "rag")] + monkeypatch.setitem(sys.modules, "rag", rag_pkg) + rag_prompts_pkg = ModuleType("rag.prompts") + rag_prompts_pkg.__path__ = [str(repo_root / "rag" / "prompts")] + monkeypatch.setitem(sys.modules, "rag.prompts", rag_prompts_pkg) + rag_prompts = ModuleType("rag.prompts.generator") + rag_prompts.chunks_format = lambda *_a, **_kw: "" + monkeypatch.setitem(sys.modules, "rag.prompts.generator", rag_prompts) + + rag_utils_pkg = ModuleType("rag.utils") + rag_utils_pkg.__path__ = [str(repo_root / "rag" / "utils")] + monkeypatch.setitem(sys.modules, "rag.utils", rag_utils_pkg) + rag_redis = ModuleType("rag.utils.redis_conn") + rag_redis.REDIS_CONN = SimpleNamespace(delete=lambda *_a, **_kw: None, set=lambda *_a, **_kw: None) + monkeypatch.setitem(sys.modules, "rag.utils.redis_conn", rag_redis) + + agent_pkg = ModuleType("agent") + agent_pkg.__path__ = [str(repo_root / "agent")] + monkeypatch.setitem(sys.modules, "agent", agent_pkg) + + agent_settings = ModuleType("agent.settings") + agent_settings.FLOAT_ZERO = 1e-8 + agent_settings.PARAM_MAXDEPTH = 5 + monkeypatch.setitem(sys.modules, "agent.settings", agent_settings) + + dsl_migration = ModuleType("agent.dsl_migration") + dsl_migration.normalize_chunker_dsl = lambda dsl: dsl + monkeypatch.setitem(sys.modules, "agent.dsl_migration", dsl_migration) + + component_pkg = ModuleType("agent.component") + component_pkg.__path__ = [str(repo_root / "agent" / "component")] + monkeypatch.setitem(sys.modules, "agent.component", component_pkg) + + base_spec = importlib.util.spec_from_file_location( + "agent.component.base", repo_root / "agent" / "component" / "base.py" + ) + base_mod = importlib.util.module_from_spec(base_spec) + monkeypatch.setitem(sys.modules, "agent.component.base", base_mod) + base_spec.loader.exec_module(base_mod) + + iteration_spec = importlib.util.spec_from_file_location( + "agent.component.iteration", repo_root / "agent" / "component" / "iteration.py" + ) + iteration_mod = importlib.util.module_from_spec(iteration_spec) + monkeypatch.setitem(sys.modules, "agent.component.iteration", iteration_mod) + iteration_spec.loader.exec_module(iteration_mod) + + iterationitem_spec = importlib.util.spec_from_file_location( + "agent.component.iterationitem", + repo_root / "agent" / "component" / "iterationitem.py", + ) + iterationitem_mod = importlib.util.module_from_spec(iterationitem_spec) + monkeypatch.setitem(sys.modules, "agent.component.iterationitem", iterationitem_mod) + iterationitem_spec.loader.exec_module(iterationitem_mod) + + class BeginParam(base_mod.ComponentParamBase): + def check(self): + return True + + class Begin(base_mod.ComponentBase): + component_name = "Begin" + + def _invoke(self, **kwargs): + return + + def thoughts(self): + return "begin" + + class ProbeParam(base_mod.ComponentParamBase): + def __init__(self): + super().__init__() + self.query = "" + self.inputs = {"query": {"value": None}} + + def get_input_form(self): + return {"query": {"name": "Query", "type": "line"}} + + def check(self): + return True + + class Probe(base_mod.ComponentBase): + component_name = "Probe" + + def _invoke(self, **kwargs): + query_text = kwargs.get("query") + vars_map = self.get_input_elements_from_text(query_text) + query = self.string_format( + query_text, {key: value["value"] for key, value in vars_map.items()} + ) + calls = self._canvas.globals.setdefault("probe.calls", []) + calls.append(query) + self.set_output("result", query) + + def thoughts(self): + return "probe" + + class SinkParam(base_mod.ComponentParamBase): + def check(self): + return True + + class Sink(base_mod.ComponentBase): + component_name = "Sink" + + def _invoke(self, **kwargs): + self.set_output("done", True) + + def thoughts(self): + return "sink" + + class_map = { + "Begin": Begin, + "BeginParam": BeginParam, + "Iteration": iteration_mod.Iteration, + "IterationParam": iteration_mod.IterationParam, + "IterationItem": iterationitem_mod.IterationItem, + "IterationItemParam": iterationitem_mod.IterationItemParam, + "Probe": Probe, + "ProbeParam": ProbeParam, + "Sink": Sink, + "SinkParam": SinkParam, + } + + component_pkg.component_class = lambda name: class_map[name] + + canvas_spec = importlib.util.spec_from_file_location( + "agent.canvas", repo_root / "agent" / "canvas.py" + ) + canvas_mod = importlib.util.module_from_spec(canvas_spec) + monkeypatch.setitem(sys.modules, "agent.canvas", canvas_mod) + canvas_spec.loader.exec_module(canvas_mod) + + return canvas_mod + + +async def _collect_events(canvas): + events = [] + async for event in canvas.run(): + events.append(event) + return events + + +@pytest.mark.p2 +def test_iteration_runtime_processes_all_array_items(monkeypatch): + canvas_mod = _load_canvas_runtime(monkeypatch) + + dsl = { + "components": { + "begin": { + "obj": {"component_name": "Begin", "params": {}}, + "downstream": ["Iteration:1"], + "upstream": [], + }, + "Iteration:1": { + "obj": { + "component_name": "Iteration", + "params": {"items_ref": "env.items"}, + }, + "downstream": ["Sink:1"], + "upstream": ["begin"], + }, + "IterationItem:1": { + "obj": {"component_name": "IterationItem", "params": {}}, + "parent_id": "Iteration:1", + "downstream": ["Probe:1"], + "upstream": [], + }, + "Probe:1": { + "obj": { + "component_name": "Probe", + "params": {"query": "IterationItem:1@result"}, + }, + "parent_id": "Iteration:1", + "downstream": [], + "upstream": ["IterationItem:1"], + }, + "Sink:1": { + "obj": {"component_name": "Sink", "params": {}}, + "downstream": [], + "upstream": ["Iteration:1"], + }, + }, + "graph": { + "nodes": [ + {"id": "begin", "data": {"name": "Begin"}}, + {"id": "Iteration:1", "data": {"name": "Iteration"}}, + {"id": "IterationItem:1", "data": {"name": "IterationItem"}}, + {"id": "Probe:1", "data": {"name": "Probe"}}, + {"id": "Sink:1", "data": {"name": "Sink"}}, + ] + }, + "history": [], + "path": [], + "retrieval": [], + "globals": { + "sys.query": "", + "sys.user_id": "", + "sys.conversation_turns": 0, + "sys.files": [], + "sys.history": [], + "sys.date": "", + "env.items": ["a", "b", "c"], + }, + } + + canvas = canvas_mod.Canvas(json.dumps(dsl)) + events = asyncio.run(_collect_events(canvas)) + + assert canvas.globals["probe.calls"] == ["a", "b", "c"] + assert any(event["event"] == "workflow_finished" for event in events) + + +@pytest.mark.parametrize( + ("query", "expected_calls"), + [ + ("{item}", ["a", "b", "c"]), + ("{index}", ["0", "1", "2"]), + ("{result}", ["a", "b", "c"]), + ], +) +@pytest.mark.p2 +def test_iteration_runtime_supports_bare_iteration_aliases(monkeypatch, query, expected_calls): + canvas_mod = _load_canvas_runtime(monkeypatch) + + dsl = { + "components": { + "begin": { + "obj": {"component_name": "Begin", "params": {}}, + "downstream": ["Iteration:1"], + "upstream": [], + }, + "Iteration:1": { + "obj": { + "component_name": "Iteration", + "params": {"items_ref": "env.items"}, + }, + "downstream": ["Sink:1"], + "upstream": ["begin"], + }, + "IterationItem:1": { + "obj": {"component_name": "IterationItem", "params": {}}, + "parent_id": "Iteration:1", + "downstream": ["Probe:1"], + "upstream": [], + }, + "Probe:1": { + "obj": { + "component_name": "Probe", + "params": {"query": query}, + }, + "parent_id": "Iteration:1", + "downstream": [], + "upstream": ["IterationItem:1"], + }, + "Sink:1": { + "obj": {"component_name": "Sink", "params": {}}, + "downstream": [], + "upstream": ["Iteration:1"], + }, + }, + "graph": { + "nodes": [ + {"id": "begin", "data": {"name": "Begin"}}, + {"id": "Iteration:1", "data": {"name": "Iteration"}}, + {"id": "IterationItem:1", "data": {"name": "IterationItem"}}, + {"id": "Probe:1", "data": {"name": "Probe"}}, + {"id": "Sink:1", "data": {"name": "Sink"}}, + ] + }, + "history": [], + "path": [], + "retrieval": [], + "globals": { + "sys.query": "", + "sys.user_id": "", + "sys.conversation_turns": 0, + "sys.files": [], + "sys.history": [], + "sys.date": "", + "env.items": ["a", "b", "c"], + }, + } + + canvas = canvas_mod.Canvas(json.dumps(dsl)) + asyncio.run(_collect_events(canvas)) + + assert canvas.globals["probe.calls"] == expected_calls diff --git a/test/testcases/test_web_api/test_canvas_app/test_iterationitem_unit.py b/test/testcases/test_web_api/test_canvas_app/test_iterationitem_unit.py new file mode 100644 index 00000000000..1151bb60dc9 --- /dev/null +++ b/test/testcases/test_web_api/test_canvas_app/test_iterationitem_unit.py @@ -0,0 +1,148 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib.util +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace +from unittest.mock import MagicMock + +import pytest + + +def _load_iterationitem_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[4] + + quart = ModuleType("quart") + quart.make_response = lambda *a, **kw: None + quart.jsonify = lambda *a, **kw: None + monkeypatch.setitem(sys.modules, "quart", quart) + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + constants = ModuleType("common.constants") + + class _RetCode: + SUCCESS = 0 + EXCEPTION_ERROR = 100 + + constants.RetCode = _RetCode + monkeypatch.setitem(sys.modules, "common.constants", constants) + + conn_spec = importlib.util.spec_from_file_location( + "common.connection_utils", repo_root / "common" / "connection_utils.py" + ) + conn_mod = importlib.util.module_from_spec(conn_spec) + monkeypatch.setitem(sys.modules, "common.connection_utils", conn_mod) + conn_spec.loader.exec_module(conn_mod) + + misc_spec = importlib.util.spec_from_file_location( + "common.misc_utils", repo_root / "common" / "misc_utils.py" + ) + misc_mod = importlib.util.module_from_spec(misc_spec) + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_mod) + misc_spec.loader.exec_module(misc_mod) + + agent_pkg = ModuleType("agent") + agent_pkg.__path__ = [str(repo_root / "agent")] + monkeypatch.setitem(sys.modules, "agent", agent_pkg) + + agent_settings = ModuleType("agent.settings") + agent_settings.FLOAT_ZERO = 1e-8 + agent_settings.PARAM_MAXDEPTH = 5 + monkeypatch.setitem(sys.modules, "agent.settings", agent_settings) + + component_pkg = ModuleType("agent.component") + component_pkg.__path__ = [str(repo_root / "agent" / "component")] + monkeypatch.setitem(sys.modules, "agent.component", component_pkg) + + canvas_mod = ModuleType("agent.canvas") + + class Graph: + pass + + canvas_mod.Graph = Graph + monkeypatch.setitem(sys.modules, "agent.canvas", canvas_mod) + + base_spec = importlib.util.spec_from_file_location( + "agent.component.base", repo_root / "agent" / "component" / "base.py" + ) + base_mod = importlib.util.module_from_spec(base_spec) + monkeypatch.setitem(sys.modules, "agent.component.base", base_mod) + base_spec.loader.exec_module(base_mod) + + iterationitem_spec = importlib.util.spec_from_file_location( + "agent.component.iterationitem", + repo_root / "agent" / "component" / "iterationitem.py", + ) + iterationitem_mod = importlib.util.module_from_spec(iterationitem_spec) + monkeypatch.setitem( + sys.modules, "agent.component.iterationitem", iterationitem_mod + ) + iterationitem_spec.loader.exec_module(iterationitem_mod) + + return iterationitem_mod + + +def _make_iterationitem(module, values): + canvas = MagicMock() + canvas.is_canceled = MagicMock(return_value=False) + canvas.get_variable_value = MagicMock(return_value=values) + canvas.components = {} + + param = module.IterationItemParam() + param.outputs = {} + param.inputs = {} + + inst = module.IterationItem.__new__(module.IterationItem) + inst._canvas = canvas + inst._id = "IterationItem:test" + inst._param = param + inst._idx = 0 + inst.get_parent = MagicMock( + return_value=SimpleNamespace( + _id="Iteration:test", + _param=SimpleNamespace(items_ref="code:1@tempList"), + component_name="Iteration", + ) + ) + return inst + + +@pytest.mark.p2 +def test_iterationitem_exposes_result_alias_for_each_item(monkeypatch): + module = _load_iterationitem_module(monkeypatch) + item = _make_iterationitem(module, ["a", "b", "c"]) + + item._invoke() + assert item.output("item") == "a" + assert item.output("result") == "a" + assert item.output("index") == 0 + + item._invoke() + assert item.output("item") == "b" + assert item.output("result") == "b" + assert item.output("index") == 1 + + item._invoke() + assert item.output("item") == "c" + assert item.output("result") == "c" + assert item.output("index") == 2 + + item._invoke() + assert item.end() is True diff --git a/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py b/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py index 339bd19bd0d..52c1ea5de66 100644 --- a/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py +++ b/test/testcases/test_web_api/test_chunk_app/test_chunk_routes_unit.py @@ -377,7 +377,7 @@ def accessible(**_kwargs): @staticmethod def get_by_id(_kb_id): - return True, SimpleNamespace(pagerank=0.6, tenant_embd_id=2, tenant_llm_id=1) + return True, SimpleNamespace(pagerank=0.6, tenant_id="tenant-1", tenant_embd_id=2, tenant_llm_id=1) kb_service_mod.KnowledgebaseService = _KnowledgebaseService monkeypatch.setitem(sys.modules, "api.db.services.knowledgebase_service", kb_service_mod) @@ -653,4 +653,3 @@ def test_restful_chunk_guard_branches_unit(monkeypatch): res = _run(_route_core(module.switch_chunks)("tenant-1", "kb-1", "doc-1")) assert res["message"] == "`available_int` or `available` is required.", res - diff --git a/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py b/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py index 9d9e1c9c14a..605ec415f15 100644 --- a/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py +++ b/test/testcases/test_web_api/test_connector_app/test_connector_routes_unit.py @@ -201,7 +201,11 @@ def list(_tenant_id): return [] @staticmethod - def resume(*_args, **_kwargs): + def accessible(*_args, **_kwargs): + return True + + @staticmethod + def cancel_tasks(*_args, **_kwargs): return True @staticmethod @@ -246,6 +250,7 @@ async def _get_request_json(): SERVER_ERROR=500, RUNNING=102, PERMISSION_ERROR=403, + AUTHENTICATION_ERROR=109, ) constants_mod.TaskStatus = SimpleNamespace(SCHEDULE="schedule", CANCEL="cancel") monkeypatch.setitem(sys.modules, "common.constants", constants_mod) @@ -344,7 +349,7 @@ async def _no_sleep(_secs): records = {"conn-1": _FakeConnectorRecord({"id": "conn-1", "source": "drive"})} update_calls = [] save_calls = [] - resume_calls = [] + cancel_calls = [] delete_calls = [] monkeypatch.setattr(module.ConnectorService, "update_by_id", lambda cid, payload: update_calls.append((cid, payload))) @@ -357,7 +362,7 @@ def _save(**payload): monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda cid: (True, records[cid])) monkeypatch.setattr(module.ConnectorService, "list", lambda tenant_id: [{"id": "listed", "tenant": tenant_id}]) monkeypatch.setattr(module.SyncLogsService, "list_sync_tasks", lambda cid, page, page_size: ([{"id": "log-1"}], 9)) - monkeypatch.setattr(module.ConnectorService, "resume", lambda cid, status: resume_calls.append((cid, status))) + monkeypatch.setattr(module.ConnectorService, "cancel_tasks", lambda cid: cancel_calls.append(cid)) monkeypatch.setattr(module.ConnectorService, "delete_by_id", lambda cid: delete_calls.append(cid)) monkeypatch.setattr(module, "get_uuid", lambda: "generated-id") @@ -396,14 +401,6 @@ def _save(**payload): logs_res = module.list_logs("conn-log") assert logs_res["data"] == {"total": 9, "logs": [{"id": "log-1"}]} - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": True})) - assert _run(module.resume("conn-r1"))["data"] is True - - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"resume": False})) - assert _run(module.resume("conn-r2"))["data"] is True - assert ("conn-r1", module.TaskStatus.SCHEDULE) in resume_calls - assert ("conn-r2", module.TaskStatus.CANCEL) in resume_calls - monkeypatch.setattr(module, "get_request_json", lambda: _AwaitableValue({"kb_id": "kb-1"})) monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: "rebuild-failed") failed_rebuild = _run(module.rebuild("conn-rb")) @@ -416,10 +413,45 @@ def _save(**payload): rm_res = module.rm_connector("conn-rm") assert rm_res["data"] is True - assert ("conn-rm", module.TaskStatus.CANCEL) in resume_calls + assert cancel_calls == ["conn-rm"] assert delete_calls == ["conn-rm"] +@pytest.mark.p2 +def test_connector_by_id_routes_reject_cross_tenant_access(monkeypatch): + """Verify per-id connector routes stop before body parsing or service access.""" + module = _load_connector_app(monkeypatch) + + touched = [] + monkeypatch.setattr(module.ConnectorService, "accessible", lambda cid, uid: False) + monkeypatch.setattr(module.ConnectorService, "get_by_id", lambda *_args: touched.append("get_by_id")) + monkeypatch.setattr(module.SyncLogsService, "list_sync_tasks", lambda *_args: touched.append("list_sync_tasks")) + monkeypatch.setattr(module.ConnectorService, "cancel_tasks", lambda *_args: touched.append("cancel_tasks")) + monkeypatch.setattr(module.ConnectorService, "delete_by_id", lambda *_args: touched.append("delete_by_id")) + monkeypatch.setattr(module.ConnectorService, "update_by_id", lambda *_args: touched.append("update_by_id")) + monkeypatch.setattr(module.ConnectorService, "rebuild", lambda *_args: touched.append("rebuild")) + + def _get_request_json(): + touched.append("get_request_json") + return _AwaitableValue({"config": {"x": 1}}) + + monkeypatch.setattr(module, "get_request_json", _get_request_json) + + responses = [ + _run(module.update_connector("conn-victim")), + module.get_connector("conn-victim"), + module.list_logs("conn-victim"), + _run(module.rebuild("conn-victim")), + module.rm_connector("conn-victim"), + _run(module.test_connector("conn-victim")), + ] + + assert all(res["code"] == module.RetCode.AUTHENTICATION_ERROR for res in responses) + assert all(res["message"] == "No authorization." for res in responses) + assert all(res["data"] is False for res in responses) + assert touched == [] + + @pytest.mark.p2 def test_connector_oauth_helper_functions(monkeypatch): module = _load_connector_app(monkeypatch) diff --git a/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py b/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py index 8bf9227a5d2..1b4dd47a6a8 100644 --- a/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py +++ b/test/testcases/test_web_api/test_llm_app/test_llm_list_unit.py @@ -252,6 +252,28 @@ async def _get_request_json(): return module +@pytest.mark.p2 +def test_openai_catalog_contains_latest_gpt_models_unit(): + repo_root = Path(__file__).resolve().parents[4] + + openai_provider_path = repo_root / "conf" / "llm_factories.json" + openai_model_path = repo_root / "conf" / "models" / "openai.json" + + with open(openai_provider_path, "r", encoding="utf-8") as f: + factories = json.load(f)["factory_llm_infos"] + + openai_factory = next(item for item in factories if item["name"] == "OpenAI") + factory_model_names = {item["llm_name"] for item in openai_factory["llm"]} + + with open(openai_model_path, "r", encoding="utf-8") as f: + openai_models = json.load(f)["models"] + model_file_names = {item["name"] for item in openai_models} + + for model_name in ["gpt-5.5", "gpt-5.4", "gpt-5.4-mini", "gpt-5.4-nano"]: + assert model_name in factory_model_names + assert model_name in model_file_names + + @pytest.mark.p2 def test_list_app_grouping_availability_and_merge(monkeypatch): module = _load_llm_app(monkeypatch) @@ -262,12 +284,16 @@ def test_list_app_grouping_availability_and_merge(monkeypatch): tenant_rows = [ _TenantLLMRow(id=1, llm_name="fast-emb", llm_factory="FastEmbed", model_type="embedding", api_key="k1", status="1"), _TenantLLMRow(id=2, llm_name="tenant-only", llm_factory="CustomFactory", model_type="chat", api_key="k2", status="1"), + _TenantLLMRow(id=3, llm_name="gpt-5.5", llm_factory="OpenAI", model_type="chat", api_key="k3", status="1"), + _TenantLLMRow(id=4, llm_name="gpt-5.4", llm_factory="OpenAI", model_type="chat", api_key="k4", status="1"), ] monkeypatch.setattr(module.TenantLLMService, "query", lambda **_kwargs: tenant_rows) all_llms = [ _LLMRow(llm_name="tei-embed", fid="Builtin", model_type="embedding", status="1"), _LLMRow(llm_name="fast-emb", fid="FastEmbed", model_type="embedding", status="1"), + _LLMRow(llm_name="gpt-5.5", fid="OpenAI", model_type="chat", status="1"), + _LLMRow(llm_name="gpt-5.4", fid="OpenAI", model_type="chat", status="1"), _LLMRow(llm_name="not-in-status", fid="Other", model_type="chat", status="1"), ] monkeypatch.setattr(module.LLMService, "get_all", lambda: all_llms) @@ -281,7 +307,7 @@ def test_list_app_grouping_availability_and_merge(monkeypatch): assert ensure_calls == ["tenant-1"] data = res["data"] - assert {"Builtin", "FastEmbed", "CustomFactory"}.issubset(set(data.keys())) + assert {"Builtin", "FastEmbed", "CustomFactory", "OpenAI"}.issubset(set(data.keys())) builtin = data["Builtin"][0] assert builtin["llm_name"] == "tei-embed" @@ -295,6 +321,10 @@ def test_list_app_grouping_availability_and_merge(monkeypatch): assert tenant_only["llm_name"] == "tenant-only" assert tenant_only["available"] is True + # Response-level assertion: /llm/list output includes latest OpenAI IDs. + openai_names = {item["llm_name"] for item in data["OpenAI"]} + assert {"gpt-5.5", "gpt-5.4"}.issubset(openai_names) + @pytest.mark.p2 def test_list_app_model_type_filter(monkeypatch): @@ -783,7 +813,7 @@ def _call(req): res = _call({"llm_factory": "FRKey", "llm_name": "m", "model_type": module.LLMType.RERANK.value, "verify": True}) assert res["code"] == 0 - assert "dose not support this model(FRKey/m)" in res["data"]["message"] + assert "does not support this model(FRKey/m)" in res["data"]["message"] res = _call({"llm_factory": "FRFail", "llm_name": "m", "model_type": module.LLMType.RERANK.value, "verify": True}) assert res["code"] == 0 diff --git a/test/testcases/test_web_api/test_memory_app/test_update_memory.py b/test/testcases/test_web_api/test_memory_app/test_update_memory.py index 1fa92b8e448..72ecfaa8ec3 100644 --- a/test/testcases/test_web_api/test_memory_app/test_update_memory.py +++ b/test/testcases/test_web_api/test_memory_app/test_update_memory.py @@ -106,6 +106,14 @@ def test_llm(self, WebApiAuth, add_memory_func): assert res["code"] == 0, res assert res["data"]["llm_id"] == llm_id, res + @pytest.mark.p2 + def test_reject_direct_tenant_model_ids(self, WebApiAuth, add_memory_func): + memory_ids = add_memory_func + payload = {"tenant_llm_id": 999999, "tenant_embd_id": 999998} + res = update_memory(WebApiAuth, memory_ids[0], payload) + assert res["code"] == 101, res + assert "Do not set tenant_llm_id or tenant_embd_id directly" in res["message"], res + @pytest.mark.p2 @pytest.mark.parametrize( "permission", diff --git a/test/testcases/test_web_api/test_search_app/test_search_routes_unit.py b/test/testcases/test_web_api/test_search_app/test_search_routes_unit.py index 3de9f3c1565..9ea8f0f3482 100644 --- a/test/testcases/test_web_api/test_search_app/test_search_routes_unit.py +++ b/test/testcases/test_web_api/test_search_app/test_search_routes_unit.py @@ -225,6 +225,10 @@ class _TenantService: def get_by_id(_tenant_id): return True, SimpleNamespace(id=_tenant_id) + @staticmethod + def get_joined_tenants_by_user_id(_user_id): + return [{"tenant_id": "tenant-1"}, {"tenant_id": "team-tenant-2"}] + class _UserTenantService: @staticmethod def query(**_kwargs): @@ -491,19 +495,30 @@ def test_list_and_delete_route_matrix_unit(monkeypatch): module, {"keywords": "k", "page": "1", "page_size": "1", "orderby": "create_time", "desc": "true", "owner_ids": ["tenant-1"]}, ) - monkeypatch.setattr( - module.SearchService, - "get_by_tenant_ids", - lambda _tenants, _uid, _page, _size, _orderby, _desc, _keywords: ( - [{"id": "x", "tenant_id": "tenant-1"}, {"id": "y", "tenant_id": "tenant-2"}], - 2, - ), - ) + + def _get_by_tenant_ids_filtered(tenants, _uid, page, size, _orderby, _desc, _keywords): + all_items = [{"id": "x", "tenant_id": "tenant-1"}, {"id": "y", "tenant_id": "tenant-1"}] + filtered = [item for item in all_items if item["tenant_id"] in set(tenants)] + total = len(filtered) + if page and size: + filtered = filtered[(page - 1) * size : page * size] + return filtered, total + + monkeypatch.setattr(module.SearchService, "get_by_tenant_ids", _get_by_tenant_ids_filtered) res = module.list_searches() assert res["code"] == 0 - assert res["data"]["total"] == 1 + assert res["data"]["total"] == 2 assert len(res["data"]["search_apps"]) == 1 - assert res["data"]["search_apps"][0]["tenant_id"] == "tenant-1" + + # list: unauthorized owner_ids + _set_request_args( + monkeypatch, + module, + {"keywords": "", "page": "0", "page_size": "10", "orderby": "create_time", "desc": "true", "owner_ids": ["other-tenant"]}, + ) + res = module.list_searches() + assert res["code"] == module.RetCode.OPERATING_ERROR + assert "authorized owner_ids" in res["message"] # list: exception def _raise_list(*_args, **_kwargs): @@ -542,3 +557,63 @@ def _raise_delete(_search_id): res = module.delete_search(search_id="search-1") assert res["code"] == module.RetCode.EXCEPTION_ERROR assert "rm boom" in res["message"] + + +@pytest.mark.p2 +def test_list_searches_authorized_multi_tenant(monkeypatch): + module = _load_search_api(monkeypatch) + captured = {} + + _set_request_args( + monkeypatch, + module, + { + "keywords": "", + "page": "1", + "page_size": "10", + "orderby": "create_time", + "desc": "true", + "owner_ids": ["tenant-1", "team-tenant-2"], + }, + ) + + def _get_by_tenant_ids(owner_ids, user_id, *args, **kwargs): + captured["owner_ids"] = owner_ids + captured["user_id"] = user_id + return ( + [ + {"id": "s1", "tenant_id": "tenant-1"}, + {"id": "s2", "tenant_id": "team-tenant-2"}, + ], + 2, + ) + + monkeypatch.setattr(module.SearchService, "get_by_tenant_ids", _get_by_tenant_ids) + res = module.list_searches() + assert res["code"] == 0 + assert res["data"]["total"] == 2 + assert {s["id"] for s in res["data"]["search_apps"]} == {"s1", "s2"} + assert set(captured["owner_ids"]) == {"tenant-1", "team-tenant-2"} + assert captured["user_id"] == "tenant-1" + + +@pytest.mark.p2 +def test_list_searches_defaults_to_authorized_owner_ids_when_omitted(monkeypatch): + module = _load_search_api(monkeypatch) + captured = {} + + _set_request_args( + monkeypatch, + module, + {"keywords": "", "page": "1", "page_size": "10", "orderby": "create_time", "desc": "true"}, + ) + + def _get_by_tenant_ids(owner_ids, *_args, **_kwargs): + captured["owner_ids"] = owner_ids + return ([], 0) + + monkeypatch.setattr(module.SearchService, "get_by_tenant_ids", _get_by_tenant_ids) + res = module.list_searches() + + assert res["code"] == 0 + assert set(captured["owner_ids"]) == {"tenant-1", "team-tenant-2"} diff --git a/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py b/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py index e183100cd3e..c7d951270ae 100644 --- a/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py +++ b/test/testcases/test_web_api/test_system_app/test_apps_init_unit.py @@ -175,6 +175,96 @@ def _raise_api_token(**_kwargs): assert "api token fallback failed" in caplog.text +@pytest.mark.p2 +def test_load_user_session_fallback(monkeypatch, caplog): + quart_app, apps_module = _load_apps_module(monkeypatch) + + valid_token = "a" * 32 + valid_user = SimpleNamespace(id="user-1", email="oidc@example.com", access_token=valid_token) + invalid_token_user = SimpleNamespace(id="user-1", email="oidc@example.com", access_token="INVALID_deadbeef") + short_token_user = SimpleNamespace(id="user-1", email="oidc@example.com", access_token="too-short") + + async def _case(): + # No Authorization header but a valid session: helper resolves the user. + async with quart_app.test_request_context("/"): + from quart import session + + session["_user_id"] = "user-1" + monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [valid_user]) + assert apps_module._load_user() is valid_user + + # Malformed bearer header still falls back to session. + async with quart_app.test_request_context("/", headers={"Authorization": "Bearer"}): + from quart import session + + session["_user_id"] = "user-1" + monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [valid_user]) + assert apps_module._load_user() is valid_user + + # Logout-revoked tokens (INVALID_ prefix) are rejected even with a session. + async with quart_app.test_request_context("/"): + from quart import session + + session["_user_id"] = "user-1" + monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [invalid_token_user]) + assert apps_module._load_user() is None + + # Short tokens are rejected (matches the JWT-path length floor). + async with quart_app.test_request_context("/"): + from quart import session + + session["_user_id"] = "user-1" + monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [short_token_user]) + assert apps_module._load_user() is None + + # No session and no header → still None. + async with quart_app.test_request_context("/"): + assert apps_module._load_user() is None + + # Database errors during the session lookup are swallowed and logged. + async with quart_app.test_request_context("/"): + from quart import session + + session["_user_id"] = "user-1" + + def _raise(**_kw): + raise RuntimeError("db down") + + monkeypatch.setattr(apps_module.UserService, "query", _raise) + with caplog.at_level(logging.ERROR): + assert apps_module._load_user() is None + + _run(_case()) + assert "load_user from session failed" in caplog.text + + +@pytest.mark.p2 +def test_load_user_session_fallback_after_token_paths_fail(monkeypatch): + """JWT-decode failures and API-token exhaustion must still fall through + to the session and return the user, not None.""" + quart_app, apps_module = _load_apps_module(monkeypatch) + + valid_token = "b" * 32 + valid_user = SimpleNamespace(id="user-1", email="oidc@example.com", access_token=valid_token) + + def _raise_decode(_self, _auth): + raise RuntimeError("jwt decode boom") + + monkeypatch.setattr(apps_module.Serializer, "loads", _raise_decode) + monkeypatch.setattr(apps_module.APIToken, "query", lambda **_kw: []) + + async def _case(): + # JWT decode fails AND API-token query returns nothing → session wins. + async with quart_app.test_request_context("/", headers={"Authorization": "Bearer junk"}): + from quart import session + + session["_user_id"] = "user-1" + monkeypatch.setattr(apps_module.UserService, "query", lambda **_kw: [valid_user]) + assert apps_module._load_user() is valid_user + + _run(_case()) + + @pytest.mark.p2 def test_login_required_timing_and_login_user_inactive(monkeypatch, caplog): quart_app, apps_module = _load_apps_module(monkeypatch) @@ -227,6 +317,7 @@ async def _case(): assert "Not Found:" in payload["message"] async with quart_app.test_request_context("/protected"): + @apps_module.login_required async def _protected(): return {"ok": True} diff --git a/test/unit_test/agent/sandbox/test_local_provider.py b/test/unit_test/agent/sandbox/test_local_provider.py index e3bcd14865f..25fbe7e03cd 100644 --- a/test/unit_test/agent/sandbox/test_local_provider.py +++ b/test/unit_test/agent/sandbox/test_local_provider.py @@ -3,12 +3,10 @@ import pytest -from agent.sandbox.providers.base import SandboxProviderConfigError from agent.sandbox.providers.local import LocalProvider -def _make_provider(monkeypatch, tmp_path, **overrides): - monkeypatch.setenv("SANDBOX_LOCAL_ENABLED", "true") +def _make_provider(tmp_path, **overrides): config = { "python_bin": sys.executable, "work_dir": str(tmp_path), @@ -24,16 +22,14 @@ def _make_provider(monkeypatch, tmp_path, **overrides): return provider -def test_local_provider_requires_explicit_env_enable(monkeypatch, tmp_path): - monkeypatch.delenv("SANDBOX_LOCAL_ENABLED", raising=False) +def test_local_provider_initializes_from_config(tmp_path): provider = LocalProvider() + provider.initialize({"python_bin": sys.executable, "work_dir": str(tmp_path)}) + assert provider.health_check() is True - with pytest.raises(SandboxProviderConfigError): - provider.initialize({"work_dir": str(tmp_path)}) - -def test_local_provider_executes_python_main(monkeypatch, tmp_path): - provider = _make_provider(monkeypatch, tmp_path) +def test_local_provider_executes_python_main(tmp_path): + provider = _make_provider(tmp_path) instance = provider.create_instance("python") try: @@ -53,8 +49,8 @@ def test_local_provider_executes_python_main(monkeypatch, tmp_path): assert result.metadata["result_value"] == {"message": "hello ragflow"} -def test_local_provider_collects_artifacts(monkeypatch, tmp_path): - provider = _make_provider(monkeypatch, tmp_path) +def test_local_provider_collects_artifacts(tmp_path): + provider = _make_provider(tmp_path) instance = provider.create_instance("python") try: @@ -82,8 +78,8 @@ def test_local_provider_collects_artifacts(monkeypatch, tmp_path): ] -def test_local_provider_times_out(monkeypatch, tmp_path): - provider = _make_provider(monkeypatch, tmp_path, timeout=1) +def test_local_provider_times_out(tmp_path): + provider = _make_provider(tmp_path, timeout=1) instance = provider.create_instance("python") try: diff --git a/test/unit_test/agent/sandbox/test_sandbox_client.py b/test/unit_test/agent/sandbox/test_sandbox_client.py new file mode 100644 index 00000000000..43e944ca62a --- /dev/null +++ b/test/unit_test/agent/sandbox/test_sandbox_client.py @@ -0,0 +1,43 @@ +import pytest +from agent.sandbox import client as sandbox_client +from agent.sandbox.providers.self_managed import SelfManagedProvider + +pytestmark = pytest.mark.p2 + + +def test_client_defaults_to_self_managed(monkeypatch): + class FakeSettingsService: + @staticmethod + def get_by_name(name): + return [] + + monkeypatch.setattr(sandbox_client, "SystemSettingsService", FakeSettingsService) + monkeypatch.setattr(SelfManagedProvider, "initialize", lambda self, config: True) + monkeypatch.setattr(sandbox_client, "_provider_manager", None) + + provider_manager = sandbox_client.get_provider_manager() + + assert provider_manager.get_provider_name() == "self_managed" + assert isinstance(provider_manager.get_provider(), SelfManagedProvider) + + +def test_self_managed_schema_uses_env_for_deployment_defaults(monkeypatch): + monkeypatch.setenv("SANDBOX_EXECUTOR_MANAGER_IMAGE", "custom-executor:latest") + monkeypatch.setenv("SANDBOX_EXECUTOR_MANAGER_POOL_SIZE", "7") + monkeypatch.setenv("SANDBOX_BASE_PYTHON_IMAGE", "custom-python:latest") + monkeypatch.setenv("SANDBOX_BASE_NODEJS_IMAGE", "custom-node:latest") + monkeypatch.setenv("SANDBOX_EXECUTOR_MANAGER_PORT", "19485") + monkeypatch.setenv("SANDBOX_ENABLE_SECCOMP", "true") + monkeypatch.setenv("SANDBOX_MAX_MEMORY", "512m") + monkeypatch.setenv("SANDBOX_TIMEOUT", "25s") + + schema = SelfManagedProvider.get_config_schema() + + assert schema["executor_manager_image"]["default"] == "custom-executor:latest" + assert schema["executor_manager_pool_size"]["default"] == 7 + assert schema["base_python_image"]["default"] == "custom-python:latest" + assert schema["base_nodejs_image"]["default"] == "custom-node:latest" + assert schema["executor_manager_port"]["default"] == 19485 + assert schema["enable_seccomp"]["default"] is True + assert schema["max_memory"]["default"] == "512m" + assert schema["sandbox_timeout"]["default"] == "25s" diff --git a/test/unit_test/agent/sandbox/test_ssh_provider.py b/test/unit_test/agent/sandbox/test_ssh_provider.py new file mode 100644 index 00000000000..74787313dd0 --- /dev/null +++ b/test/unit_test/agent/sandbox/test_ssh_provider.py @@ -0,0 +1,174 @@ +import base64 +from types import SimpleNamespace + +import pytest + +from agent.sandbox.providers.ssh import SSHProvider +from agent.sandbox.result_protocol import RESULT_MARKER_PREFIX + +pytestmark = pytest.mark.p3 + + +class _FakeWritableFile: + def __init__(self, sftp, path: str): + self._sftp = sftp + self._path = path + self._chunks: list[str] = [] + + def write(self, content: str): + self._chunks.append(content) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + self._sftp.files[self._path] = "".join(self._chunks).encode("utf-8") + return False + + +class _FakeReadableFile: + def __init__(self, payload: bytes): + self._payload = payload + + def read(self): + return self._payload + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + +class _FakeSFTP: + def __init__(self): + self.files: dict[str, bytes] = {} + self.closed = False + + def file(self, path: str, mode: str): + if "w" in mode: + return _FakeWritableFile(self, path) + return _FakeReadableFile(self.files[path]) + + def listdir_attr(self, path: str): + prefix = path.rstrip("/") + "/" + names = [] + for file_path, payload in self.files.items(): + if not file_path.startswith(prefix): + continue + relative = file_path[len(prefix):] + if "/" in relative: + continue + names.append( + SimpleNamespace( + filename=relative, + st_mode=0o100644, + st_size=len(payload), + ) + ) + return names + + def close(self): + self.closed = True + + +class _FakeClient: + def __init__(self, sftp: _FakeSFTP): + self._sftp = sftp + self.closed = False + + def open_sftp(self): + return self._sftp + + def close(self): + self.closed = True + + +def _build_provider(): + provider = SSHProvider() + provider.host = "example.com" + provider.port = 22 + provider.username = "ragflow" + provider.password = "secret" + provider.work_dir = "/tmp" + provider.command_template = "cd {workspace} && python3 {script_path}" + provider.timeout = 5 + provider.max_output_bytes = 1024 * 1024 + provider.max_artifacts = 20 + provider.max_artifact_bytes = 1024 * 1024 + provider._initialized = True + return provider + + +def test_ssh_provider_executes_python_main_and_collects_artifacts(monkeypatch): + provider = _build_provider() + fake_sftp = _FakeSFTP() + fake_client = _FakeClient(fake_sftp) + executed_commands: list[str] = [] + + monkeypatch.setattr(provider, "_create_ssh_client", lambda: fake_client) + monkeypatch.setattr(provider, "_create_remote_workspace", lambda client: "/tmp/ws-123") + + def _run_remote_command(client, command: str, timeout: int): + executed_commands.append(command) + if command.startswith("mkdir -p "): + return "", "", 0 + if command.startswith("cd /tmp/ws-123 && python3 /tmp/ws-123/main.py"): + fake_sftp.files["/tmp/ws-123/artifacts/chart.png"] = b"PNGDATA" + payload = base64.b64encode( + b'{"present":true,"value":{"message":"hello ssh"},"type":"json"}' + ).decode("ascii") + return f"debug line\n{RESULT_MARKER_PREFIX}{payload}\n", "", 0 + if command.startswith("rm -rf "): + return "", "", 0 + raise AssertionError(f"Unexpected command: {command}") + + monkeypatch.setattr(provider, "_run_remote_command", _run_remote_command) + + instance = provider.create_instance("python") + result = provider.execute_code( + instance.instance_id, + 'def main() -> dict:\n return {"message": "hello ssh"}\n', + "python", + timeout=5, + ) + provider.destroy_instance(instance.instance_id) + + assert result.exit_code == 0 + assert result.stdout == "debug line\n" + assert result.metadata["result_present"] is True + assert result.metadata["result_value"] == {"message": "hello ssh"} + assert result.metadata["artifacts"] == [ + { + "name": "chart.png", + "content_b64": base64.b64encode(b"PNGDATA").decode("ascii"), + "mime_type": "image/png", + "size": 7, + } + ] + assert "cd /tmp/ws-123 && python3 /tmp/ws-123/main.py" in executed_commands + assert fake_sftp.closed is True + assert fake_client.closed is True + + +def test_ssh_provider_propagates_timeouts(): + provider = _build_provider() + provider._instances["instance-1"] = { + "client": object(), + "sftp": _FakeSFTP(), + "remote_work_dir": "/tmp/ws-123", + "language": "python", + } + + def _timeout(*args, **kwargs): + raise TimeoutError("Execution timed out after 5 seconds") + + provider._run_remote_command = _timeout # type: ignore[method-assign] + + with pytest.raises(TimeoutError, match="Execution timed out"): + provider.execute_code( + "instance-1", + 'def main() -> dict:\n return {"ok": True}\n', + "python", + timeout=5, + ) diff --git a/test/unit_test/api/db/services/test_dialog_service_final_answer.py b/test/unit_test/api/db/services/test_dialog_service_final_answer.py index d38d157059f..30fb1e4c300 100644 --- a/test/unit_test/api/db/services/test_dialog_service_final_answer.py +++ b/test/unit_test/api/db/services/test_dialog_service_final_answer.py @@ -140,6 +140,41 @@ def insert_citations(self, answer, content_ltks, vectors, embd_mdl, **_kwargs): return answer, set() +class _FakeLangfuseObservation: + def __init__(self): + self.updates = [] + self.ended = False + + def update(self, **kwargs): + self.updates.append(kwargs) + + def end(self): + self.ended = True + + +class _FakeLangfuseClient: + instances = [] + fail_start_observation = False + + def __init__(self, **kwargs): + self.init_kwargs = kwargs + self.observation_kwargs = None + self.observation = _FakeLangfuseObservation() + self.instances.append(self) + + def auth_check(self): + return True + + def create_trace_id(self): + return "trace-id" + + def start_observation(self, **kwargs): + if self.fail_start_observation: + raise RuntimeError("langfuse unavailable") + self.observation_kwargs = kwargs + return self.observation + + def _collect(async_gen): async def _run(): return [ev async for ev in async_gen] @@ -356,3 +391,132 @@ def test_async_chat_final_event_carries_decorated_answer(monkeypatch): assert llm_answer in final["answer"], ( f"LLM answer text expected in final event, got: {final['answer']!r}" ) + + +@pytest.mark.p2 +def test_async_chat_langfuse_uses_start_observation(monkeypatch): + """ + Langfuse v4 exposes start_observation(as_type="generation"), not + start_generation(). Keep async_chat() on the migrated API. + """ + _FakeLangfuseClient.instances = [] + monkeypatch.setattr(_FakeLangfuseClient, "fail_start_observation", False) + llm_answer = "RAGFlow traces chat answers through Langfuse." + chat_mdl = _StreamingChatModel(llm_answer) + retriever = _StubRetriever() + + monkeypatch.setattr( + dialog_service.TenantLLMService, "llm_id2llm_type", lambda _llm_id: "chat" + ) + monkeypatch.setattr( + dialog_service.TenantLLMService, "get_model_config", + lambda _tid, _type, _llm_id: _LLM_CONFIG, + ) + monkeypatch.setattr( + dialog_service.TenantLangfuseService, "filter_by_tenant", + lambda tenant_id: SimpleNamespace( + public_key="public", + secret_key="secret", + host="http://langfuse.local", + ), + ) + monkeypatch.setattr(dialog_service, "Langfuse", _FakeLangfuseClient) + monkeypatch.setattr( + dialog_service, + "get_models", + lambda _dialog: ([_KB], chat_mdl, None, chat_mdl, None), + ) + monkeypatch.setattr( + dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {} + ) + monkeypatch.setattr( + dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB] + ) + monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) + monkeypatch.setattr(dialog_service, "label_question", lambda _q, _kbs: "") + monkeypatch.setattr( + dialog_service, + "kb_prompt", + lambda _kbinfos, _max_tokens, **_kw: ["RAGFlow is a RAG engine."], + ) + + dialog = _make_dialog(chat_mdl) + messages = [{"role": "user", "content": "What is RAGFlow?"}] + + events = _collect(dialog_service.async_chat(dialog, messages, stream=True, quote=True)) + + assert any(e.get("final") is True for e in events) + assert len(_FakeLangfuseClient.instances) == 1 + langfuse = _FakeLangfuseClient.instances[0] + assert langfuse.observation_kwargs["as_type"] == "generation" + assert langfuse.observation_kwargs["trace_context"] == {"trace_id": "trace-id"} + assert langfuse.observation_kwargs["name"] == "chat" + assert langfuse.observation_kwargs["model"] == _LLM_CONFIG["llm_name"] + input_payload = langfuse.observation_kwargs["input"] + assert set(input_payload.keys()) == {"prompt", "prompt4citation", "messages"} + assert input_payload["prompt"] == "You are helpful. \n------\nRAGFlow is a RAG engine." + assert input_payload["prompt4citation"] == dialog_service.citation_prompt() + assert input_payload["messages"][0]["role"] == "system" + assert input_payload["messages"][0]["content"] == input_payload["prompt"] + assert input_payload["messages"][1] == {"role": "user", "content": "What is RAGFlow?"} + assert langfuse.observation.ended is True + + +@pytest.mark.p2 +def test_async_chat_continues_when_langfuse_observation_start_fails(monkeypatch): + """ + Langfuse tracing is best-effort; observation startup errors must not break + chat responses. + """ + _FakeLangfuseClient.instances = [] + monkeypatch.setattr(_FakeLangfuseClient, "fail_start_observation", True) + llm_answer = "RAGFlow still answers when tracing is unavailable." + chat_mdl = _StreamingChatModel(llm_answer) + retriever = _StubRetriever() + + monkeypatch.setattr( + dialog_service.TenantLLMService, "llm_id2llm_type", lambda _llm_id: "chat" + ) + monkeypatch.setattr( + dialog_service.TenantLLMService, "get_model_config", + lambda _tid, _type, _llm_id: _LLM_CONFIG, + ) + monkeypatch.setattr( + dialog_service.TenantLangfuseService, "filter_by_tenant", + lambda tenant_id: SimpleNamespace( + public_key="public", + secret_key="secret", + host="http://langfuse.local", + ), + ) + monkeypatch.setattr(dialog_service, "Langfuse", _FakeLangfuseClient) + monkeypatch.setattr( + dialog_service, + "get_models", + lambda _dialog: ([_KB], chat_mdl, None, chat_mdl, None), + ) + monkeypatch.setattr( + dialog_service.KnowledgebaseService, "get_field_map", lambda _kb_ids: {} + ) + monkeypatch.setattr( + dialog_service.KnowledgebaseService, "get_by_ids", lambda _ids: [_KB] + ) + monkeypatch.setattr(dialog_service.settings, "retriever", retriever, raising=False) + monkeypatch.setattr(dialog_service, "label_question", lambda _q, _kbs: "") + monkeypatch.setattr( + dialog_service, + "kb_prompt", + lambda _kbinfos, _max_tokens, **_kw: ["RAGFlow is a RAG engine."], + ) + + dialog = _make_dialog(chat_mdl) + messages = [{"role": "user", "content": "What is RAGFlow?"}] + + events = _collect(dialog_service.async_chat(dialog, messages, stream=True, quote=True)) + + final_events = [e for e in events if e.get("final") is True] + assert len(final_events) == 1 + assert llm_answer in final_events[0]["answer"] + assert len(_FakeLangfuseClient.instances) == 1 + assert _FakeLangfuseClient.instances[0].observation_kwargs is None + assert _FakeLangfuseClient.instances[0].observation.ended is False diff --git a/test/unit_test/api/utils/test_doc_validation.py b/test/unit_test/api/utils/test_doc_validation.py index 25e115c4292..b068e2b4999 100644 --- a/test/unit_test/api/utils/test_doc_validation.py +++ b/test/unit_test/api/utils/test_doc_validation.py @@ -18,14 +18,15 @@ from unittest.mock import Mock from api.utils.validation_utils import ( - validate_immutable_fields, + ParserConfig, + UpdateDocumentReq, + validate_chunk_method, validate_document_name, - validate_chunk_method + validate_immutable_fields, ) from api.constants import FILE_NAME_LEN_LIMIT from api.db import FileType from common.constants import RetCode -from api.utils.validation_utils import UpdateDocumentReq def test_validate_immutable_fields_no_changes(): @@ -299,4 +300,15 @@ def test_validate_chunk_method_other_extensions_still_valid(): error_msg, error_code = validate_chunk_method(doc) assert error_msg is None - assert error_code is None \ No newline at end of file + assert error_code is None + + +def test_parser_config_normalizes_legacy_vectorize_table_column_role(): + p = ParserConfig( + table_column_roles={"title": "vectorize", "country": "metadata", "x": "both"}, + ) + assert p.table_column_roles == { + "title": "indexing", + "country": "metadata", + "x": "both", + } \ No newline at end of file diff --git a/test/unit_test/common/test_blob_connector_fingerprint.py b/test/unit_test/common/test_blob_connector_fingerprint.py new file mode 100644 index 00000000000..ec133fd697b --- /dev/null +++ b/test/unit_test/common/test_blob_connector_fingerprint.py @@ -0,0 +1,347 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tests for the FingerprintConnector bypass path in BlobStorageConnector.""" + +import importlib.util +import sys +from datetime import datetime, timezone +from pathlib import Path +from types import ModuleType + +import pytest +import xxhash + + +def _load_blob_connector_module(): + repo_root = Path(__file__).resolve().parents[3] + package_name = "common.data_source" + saved_modules = {name: module for name, module in sys.modules.items() if name == package_name or name.startswith(f"{package_name}.")} + package_stub = ModuleType(package_name) + package_stub.__path__ = [str(repo_root / "common" / "data_source")] + sys.modules[package_name] = package_stub + + try: + spec = importlib.util.spec_from_file_location( + "_blob_connector_under_test", + repo_root / "common" / "data_source" / "blob_connector.py", + ) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + finally: + for name in list(sys.modules): + if name == package_name or name.startswith(f"{package_name}."): + if name in saved_modules: + sys.modules[name] = saved_modules[name] + else: + sys.modules.pop(name, None) + + +blob_connector = _load_blob_connector_module() +BlobStorageConnector = blob_connector.BlobStorageConnector +_normalize_etag = blob_connector._normalize_etag + + +# --------------------------------------------------------------------------- +# Fake S3 client wired through a paginator-style interface. +# --------------------------------------------------------------------------- + + +class _FakePaginator: + def __init__(self, pages: list[dict]) -> None: + self._pages = pages + + def paginate(self, **_kwargs): + for page in self._pages: + yield page + + +class _FakeS3Client: + """Captures every call on the connector's S3 client. + + Tests assert against `get_object_calls` to verify that the fingerprint + bypass actually skips downloads when ETags haven't changed. + """ + + def __init__(self, objects: list[dict]) -> None: + self._objects = objects + self.get_object_calls: list[tuple[str, str]] = [] + # Hand objects to the paginator unmodified so the connector exercises + # its own directory-placeholder filtering logic. + self._paginator = _FakePaginator([{"Contents": list(objects)}]) + + def get_paginator(self, name: str): + assert name == "list_objects_v2" + return self._paginator + + def list_objects_v2(self, **_kwargs): + return {"Contents": self._objects, "KeyCount": len(self._objects)} + + def get_object(self, Bucket: str, Key: str): # noqa: N803 (boto3 API) + self.get_object_calls.append((Bucket, Key)) + body_text = f"body-of-{Key}".encode() + return { + "Body": _FakeBody(body_text), + "ContentLength": len(body_text), + } + + +class _FakeBody: + """Minimal stand-in for botocore's StreamingBody. + + The real downloader (common.data_source.utils.download_object) consumes + the body via iter_chunks() and then calls close(); fake out both. + """ + + def __init__(self, payload: bytes) -> None: + self._payload = payload + + def read(self) -> bytes: + return self._payload + + def iter_chunks(self, chunk_size: int = 65536): + for i in range(0, len(self._payload), chunk_size): + yield self._payload[i : i + chunk_size] + + def close(self) -> None: + return None + + +def _make_connector(s3_client) -> BlobStorageConnector: + connector = BlobStorageConnector(bucket_type="s3", bucket_name="test-bucket") + connector.s3_client = s3_client + return connector + + +def _s3_object(key: str, etag: str, size: int = 12) -> dict: + return { + "Key": key, + "ETag": f'"{etag}"', + "LastModified": datetime(2026, 1, 1, 12, tzinfo=timezone.utc), + "Size": size, + } + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_normalize_etag_returns_32_char_hex_for_singlepart_etag(): + fp = _normalize_etag('"d41d8cd98f00b204e9800998ecf8427e"') + assert fp is not None + assert len(fp) == 32 + assert all(c in "0123456789abcdef" for c in fp) + + +def test_normalize_etag_returns_32_char_hex_for_multipart_etag(): + """Multipart ETags are 34+ chars; hashing normalizes them to 32.""" + fp = _normalize_etag('"d41d8cd98f00b204e9800998ecf8427e-7"') + assert fp is not None + assert len(fp) == 32 + + +def test_normalize_etag_is_deterministic(): + raw = '"abc123def456abc123def456abc123de"' + assert _normalize_etag(raw) == _normalize_etag(raw) + + +def test_normalize_etag_strips_quotes_so_quoted_and_unquoted_match(): + quoted = '"d41d8cd98f00b204e9800998ecf8427e"' + unquoted = "d41d8cd98f00b204e9800998ecf8427e" + assert _normalize_etag(quoted) == _normalize_etag(unquoted) + + +def test_normalize_etag_returns_none_for_empty_input(): + assert _normalize_etag("") is None + assert _normalize_etag(None) is None + + +def test_list_keys_yields_one_keyrecord_per_object_with_fingerprint(): + s3 = _FakeS3Client( + [ + _s3_object("foo.txt", "etag-foo"), + _s3_object("bar/baz.txt", "etag-baz"), + ] + ) + connector = _make_connector(s3) + + records = list(connector.list_keys()) + + assert len(records) == 2 + assert {r.key for r in records} == { + "BlobType.S3:test-bucket:foo.txt", + "BlobType.S3:test-bucket:bar/baz.txt", + } + for record in records: + assert record.fingerprint is not None + assert len(record.fingerprint) == 32 + assert record.deleted is False + + +def test_list_keys_does_not_call_get_object(): + """list_keys() must be cheap -- no body downloads during enumeration.""" + s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")]) + connector = _make_connector(s3) + + list(connector.list_keys()) + + assert s3.get_object_calls == [] + + +def test_list_keys_skips_directory_placeholder_keys(): + """S3 'folders' are zero-byte keys ending in '/'; they shouldn't yield records.""" + s3 = _FakeS3Client( + [ + _s3_object("real-file.txt", "etag-real"), + _s3_object("folder/", "etag-folder"), + ] + ) + connector = _make_connector(s3) + + keys = [r.key for r in connector.list_keys()] + + assert keys == ["BlobType.S3:test-bucket:real-file.txt"] + + +def test_get_value_returns_document_with_fingerprint_set(): + s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")]) + connector = _make_connector(s3) + [record] = list(connector.list_keys()) + + doc = connector.get_value(record.key) + + assert doc.id == "BlobType.S3:test-bucket:foo.txt" + assert doc.fingerprint == record.fingerprint + assert doc.fingerprint == xxhash.xxh128(b"etag-foo").hexdigest() + + +def test_get_value_calls_get_object_exactly_once_per_key(): + s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")]) + connector = _make_connector(s3) + [record] = list(connector.list_keys()) + + connector.get_value(record.key) + + assert s3.get_object_calls == [("test-bucket", "foo.txt")] + + +def test_get_value_raises_keyerror_when_called_before_list_keys(): + s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")]) + connector = _make_connector(s3) + + with pytest.raises(KeyError): + connector.get_value("BlobType.S3:test-bucket:foo.txt") + + +def test_singlepart_and_multipart_etags_yield_different_fingerprints(): + """Sanity: distinct ETags must produce distinct fingerprints.""" + s3 = _FakeS3Client( + [ + _s3_object("a.bin", "d41d8cd98f00b204e9800998ecf8427e"), + _s3_object("b.bin", "d41d8cd98f00b204e9800998ecf8427e-3"), + ] + ) + connector = _make_connector(s3) + + records = list(connector.list_keys()) + + assert records[0].fingerprint != records[1].fingerprint + + +def test_fingerprint_stable_across_repeated_listings(): + """Same ETag in two list_keys() calls yields the same fingerprint.""" + s3 = _FakeS3Client([_s3_object("foo.txt", "etag-stable")]) + connector = _make_connector(s3) + + fp_first = next(connector.list_keys()).fingerprint + fp_second = next(connector.list_keys()).fingerprint + + assert fp_first == fp_second + + +# --------------------------------------------------------------------------- +# Bypass-logic test: simulates what the orchestrator does in +# _BlobLikeBase._fingerprint_filtered_generator. Verifies that a key whose +# fingerprint matches the persisted content_hash is NOT fetched. +# --------------------------------------------------------------------------- + + +def test_orchestrator_pattern_skips_get_object_when_fingerprint_matches(): + # Use distinct base names: "unchanged.txt".endswith("changed.txt") is True, + # which would silently break endswith-based lookups in the test setup. + s3 = _FakeS3Client( + [ + _s3_object("static.txt", "etag-static"), + _s3_object("modified.txt", "etag-modified"), + ] + ) + connector = _make_connector(s3) + + # Pre-compute the fingerprints the connector would emit, then pretend the + # DB already stores the one for static.txt but a stale value for + # modified.txt. This is the steady-state bypass scenario. + listed = list(connector.list_keys()) + static_record = next(r for r in listed if r.key.endswith(":static.txt")) + modified_record = next(r for r in listed if r.key.endswith(":modified.txt")) + persisted = { + static_record.key: static_record.fingerprint, + modified_record.key: "stale-fingerprint", + } + + # Reset the call log so we only count get_object during the bypass loop. + s3.get_object_calls = [] + + fetched = [] + for record in connector.list_keys(): + if record.fingerprint and persisted.get(record.key) == record.fingerprint: + continue + fetched.append(connector.get_value(record.key)) + + assert [doc.id for doc in fetched] == ["BlobType.S3:test-bucket:modified.txt"] + assert s3.get_object_calls == [("test-bucket", "modified.txt")] + + +def test_orchestrator_pattern_skips_deleted_records_without_calling_get_value(): + """KeyRecord(deleted=True) must short-circuit before get_value(). + + Reach KeyRecord through the already-loaded blob_connector module to avoid + triggering common.data_source.__init__'s circular imports. + """ + KeyRecord = blob_connector.KeyRecord + + s3 = _FakeS3Client([_s3_object("foo.txt", "etag-foo")]) + connector = _make_connector(s3) + + # Manually feed a deleted KeyRecord through the bypass logic to assert the + # short-circuit holds even when a connector emits one. (BlobStorageConnector + # itself doesn't yield deleted records yet -- that's PR-4 -- but the + # orchestrator must already be defensive.) + deleted_record = KeyRecord( + key="BlobType.S3:test-bucket:gone.txt", + fingerprint=None, + deleted=True, + ) + + # Mirror the orchestrator's loop body verbatim. + fetched = [] + for record in [deleted_record]: + if record.deleted: + continue + fetched.append(connector.get_value(record.key)) + + assert fetched == [] + assert s3.get_object_calls == [] diff --git a/test/unit_test/common/test_metadata_es_filter.py b/test/unit_test/common/test_metadata_es_filter.py deleted file mode 100644 index eb8217909e3..00000000000 --- a/test/unit_test/common/test_metadata_es_filter.py +++ /dev/null @@ -1,473 +0,0 @@ -"""Unit tests for the Elasticsearch push-down translator. - -These tests cover the public surface of ``common.metadata_es_filter`` without -touching the live ES cluster. They verify the shape of the produced query DSL -operator-by-operator and confirm that the parity rules with the in-memory -``meta_filter`` (lower-casing, list-membership coercion, date detection) hold. -""" - -import pytest - -from common.metadata_es_filter import ( - META_FIELDS_PREFIX, - MetaFilterPushdownPlan, - MetaFilterTranslator, - SUPPORTED_OPERATORS, - UnsupportedMetaFilter, - build_meta_filter_query, - extract_doc_ids, - is_pushdown_supported, - plan_pushdown, -) - - -# --------------------------------------------------------------------------- -# Fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def translator() -> MetaFilterTranslator: - return MetaFilterTranslator() - - -def _field(key: str) -> str: - return f"{META_FIELDS_PREFIX}.{key}" - - -# --------------------------------------------------------------------------- -# Translator: per-operator shape -# --------------------------------------------------------------------------- - - -def test_equal_translates_to_term_with_lowercased_value(translator): - """String equality runs against ``.keyword`` so multi-word phrases match. - - Querying the analyzed parent field with ``term`` only matches docs whose - inverted index contains the literal phrase token, which never happens for - multi-word values. The ``.keyword`` sub-field stores the unmodified string, - and ``case_insensitive: true`` keeps the lower-cased compare semantics from - the in-memory ``meta_filter``. - """ - clauses = translator.translate({"key": "tag", "op": "=", "value": "Alpha"}).to_clauses() - assert clauses == [ - {"term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}}} - ] - - -def test_equal_parses_numeric_literal(translator): - """Numeric values stay on the parent path — no ``.keyword`` sub-field exists for ``long``.""" - clauses = translator.translate({"key": "score", "op": "=", "value": "5"}).to_clauses() - assert clauses == [{"term": {_field("score"): 5}}] - - -def test_equal_multiword_uses_keyword_subfield(translator): - """Regression for qinling0210's report: multi-word string values must match. - - Before the keyword-routing fix this emitted - ``term: meta_fields.author = "alice wonderland"`` against an analyzed text - field, which never matched (inverted index only contained per-token - entries). Routing through ``.keyword`` preserves the full phrase. - """ - clauses = translator.translate( - {"key": "author", "op": "=", "value": "Alice Wonderland"} - ).to_clauses() - assert clauses == [ - { - "term": { - _field("author") + ".keyword": { - "value": "alice wonderland", - "case_insensitive": True, - } - } - } - ] - - -def test_not_equal_requires_field_to_exist(translator): - clauses = translator.translate({"key": "tag", "op": "≠", "value": "alpha"}).to_clauses() - assert clauses == [ - { - "bool": { - "must": [{"exists": {"field": _field("tag")}}], - "must_not": [ - {"term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}}} - ], - } - } - ] - - -@pytest.mark.parametrize( - "op,es_key", - [(">", "gt"), ("<", "lt"), ("≥", "gte"), ("≤", "lte")], -) -def test_range_operator_translation(translator, op, es_key): - # Multi-clause positive filters wrap into a single bool so OR-logic - # parents can't match on just the ``exists`` half of the range. - clauses = translator.translate({"key": "score", "op": op, "value": "10"}).to_clauses() - assert clauses == [ - { - "bool": { - "must": [ - {"exists": {"field": _field("score")}}, - {"range": {_field("score"): {es_key: 10}}}, - ] - } - } - ] - - -def test_range_passes_iso_date_through_unparsed(translator): - clauses = translator.translate({"key": "published", "op": "≥", "value": "2025-01-15"}).to_clauses() - range_clause = clauses[0]["bool"]["must"][1] - assert range_clause == {"range": {_field("published"): {"gte": "2025-01-15"}}} - - -def _string_terms_should(field_path: str, members): - """``in``/``not in`` over string members expands per-element so each ``term`` - can carry ``case_insensitive`` (``terms`` does not accept that flag).""" - return { - "bool": { - "should": [ - {"term": {field_path + ".keyword": {"value": m, "case_insensitive": True}}} - for m in members - ], - "minimum_should_match": 1, - } - } - - -def test_in_operator_csv_value_lowercased(translator): - clauses = translator.translate({"key": "status", "op": "in", "value": "Active,Pending"}).to_clauses() - assert clauses == [_string_terms_should(_field("status"), ["active", "pending"])] - - -def test_in_operator_python_list_literal(translator): - clauses = translator.translate({"key": "status", "op": "in", "value": "['Open', 'Closed']"}).to_clauses() - assert clauses == [_string_terms_should(_field("status"), ["open", "closed"])] - - -def test_in_operator_numeric_members_keep_terms(translator): - """All-numeric member lists keep the cheaper ``terms`` form on the parent path.""" - clauses = translator.translate({"key": "year", "op": "in", "value": "[2024, 2025]"}).to_clauses() - assert clauses == [{"terms": {_field("year"): [2024, 2025]}}] - - -def test_not_in_negates_with_existence_guard(translator): - clauses = translator.translate({"key": "status", "op": "not in", "value": "active,pending"}).to_clauses() - assert clauses == [ - { - "bool": { - "must": [{"exists": {"field": _field("status")}}], - "must_not": [_string_terms_should(_field("status"), ["active", "pending"])], - } - } - ] - - -def test_contains_uses_case_insensitive_wildcard(translator): - clauses = translator.translate({"key": "version", "op": "contains", "value": "earth"}).to_clauses() - assert clauses == [ - { - "wildcard": { - _field("version") + ".keyword": { - "value": "*earth*", - "case_insensitive": True, - } - } - } - ] - - -def test_contains_escapes_user_wildcards(translator): - clauses = translator.translate({"key": "title", "op": "contains", "value": "a*b?c"}).to_clauses() - pattern = clauses[0]["wildcard"][_field("title") + ".keyword"]["value"] - assert pattern == "*a\\*b\\?c*" - - -def test_not_contains_negates_with_exists(translator): - clauses = translator.translate({"key": "version", "op": "not contains", "value": "earth"}).to_clauses() - assert clauses == [ - { - "bool": { - "must": [{"exists": {"field": _field("version")}}], - "must_not": [ - { - "wildcard": { - _field("version") + ".keyword": { - "value": "*earth*", - "case_insensitive": True, - } - } - } - ], - } - } - ] - - -def test_start_with_uses_prefix(translator): - clauses = translator.translate({"key": "name", "op": "start with", "value": "pre"}).to_clauses() - assert clauses == [ - {"prefix": {_field("name") + ".keyword": {"value": "pre", "case_insensitive": True}}} - ] - - -def test_end_with_uses_trailing_wildcard(translator): - clauses = translator.translate({"key": "file", "op": "end with", "value": ".pdf"}).to_clauses() - pattern = clauses[0]["wildcard"][_field("file") + ".keyword"]["value"] - assert pattern == "*.pdf" - - -def test_empty_matches_missing_or_blank(translator): - clauses = translator.translate({"key": "notes", "op": "empty", "value": ""}).to_clauses() - assert clauses == [ - { - "bool": { - "should": [ - {"bool": {"must_not": [{"exists": {"field": _field("notes")}}]}}, - {"term": {_field("notes") + ".keyword": ""}}, - ], - "minimum_should_match": 1, - } - } - ] - - -def test_not_empty_requires_exists_and_excludes_blank(translator): - clauses = translator.translate({"key": "notes", "op": "not empty", "value": ""}).to_clauses() - assert clauses == [ - { - "bool": { - "must": [{"exists": {"field": _field("notes")}}], - "must_not": [{"term": {_field("notes") + ".keyword": ""}}], - } - } - ] - - -# --------------------------------------------------------------------------- -# Translator: validation paths -# --------------------------------------------------------------------------- - - -def test_unknown_operator_raises(translator): - with pytest.raises(UnsupportedMetaFilter) as exc: - translator.translate({"key": "tag", "op": "regex", "value": "^foo"}) - assert "regex" in exc.value.reason - - -def test_missing_key_raises(translator): - with pytest.raises(UnsupportedMetaFilter): - translator.translate({"op": "=", "value": "x"}) - - -def test_scalar_op_with_list_value_raises(translator): - with pytest.raises(UnsupportedMetaFilter): - translator.translate({"key": "tag", "op": "=", "value": ["a", "b"]}) - - -def test_string_op_with_empty_value_raises(translator): - with pytest.raises(UnsupportedMetaFilter): - translator.translate({"key": "tag", "op": "contains", "value": ""}) - - -def test_membership_with_empty_csv_raises(translator): - with pytest.raises(UnsupportedMetaFilter): - translator.translate({"key": "tag", "op": "in", "value": ""}) - - -def test_supported_operator_set_matches_documentation(): - expected = { - "=", - "≠", - ">", - "<", - "≥", - "≤", - "in", - "not in", - "contains", - "not contains", - "start with", - "end with", - "empty", - "not empty", - } - assert SUPPORTED_OPERATORS == expected - - -# --------------------------------------------------------------------------- -# Plan composition -# --------------------------------------------------------------------------- - - -def test_plan_emits_must_clauses_for_and_logic(): - plan = plan_pushdown( - [ - {"key": "tag", "op": "=", "value": "alpha"}, - {"key": "score", "op": ">", "value": "5"}, - ], - logic="and", - ) - assert isinstance(plan, MetaFilterPushdownPlan) - body = plan.to_query(["kb1"]) - bool_root = body["query"]["bool"] - assert bool_root["filter"][0] == {"terms": {"kb_id": ["kb1"]}} - inner = bool_root["filter"][1]["bool"] - assert "must" in inner - # Each translated filter contributes exactly one clause to the parent bool: - # ``=`` is a single ``term``; ``>`` is wrapped into one atomic ``bool``. - assert len(inner["must"]) == 2 - expected_tag_term = { - "term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}} - } - assert expected_tag_term in inner["must"] - range_wrap = { - "bool": { - "must": [ - {"exists": {"field": _field("score")}}, - {"range": {_field("score"): {"gt": 5}}}, - ] - } - } - assert range_wrap in inner["must"] - - -def test_range_filter_under_or_stays_atomic(): - """An OR'd range must not split into independent ``exists`` + ``range`` should branches.""" - body = build_meta_filter_query( - [ - {"key": "tag", "op": "=", "value": "alpha"}, - {"key": "score", "op": ">", "value": "5"}, - ], - logic="or", - kb_ids=["kb1"], - ) - should = body["query"]["bool"]["filter"][1]["bool"]["should"] - # Two filters → two should branches, not three or four. - assert len(should) == 2 - assert { - "term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}} - } in should - - -def test_plan_emits_should_clauses_for_or_logic(): - plan = plan_pushdown( - [ - {"key": "tag", "op": "=", "value": "alpha"}, - {"key": "tag", "op": "=", "value": "beta"}, - ], - logic="or", - ) - inner = plan.to_query(["kb1"])["query"]["bool"]["filter"][1]["bool"] - assert inner["minimum_should_match"] == 1 - assert len(inner["should"]) == 2 - - -def test_unknown_logic_rejected(): - with pytest.raises(UnsupportedMetaFilter): - plan_pushdown([{"key": "k", "op": "=", "value": "v"}], logic="xor") - - -def test_empty_filter_list_returns_kb_only_query(): - body = build_meta_filter_query([], "and", ["kb1", "kb2"]) - assert body == {"query": {"bool": {"filter": [{"terms": {"kb_id": ["kb1", "kb2"]}}]}}} - - -def test_negative_filter_in_or_logic_keeps_negation_scope(): - """Wrapping ``≠`` in an OR should not let the ``must_not`` swallow other branches. - - ``≠`` is rejected by :func:`is_pushdown_supported` for multi-value safety, so - this test exercises the translator directly to confirm the per-filter - wrapping invariant. The same shape protects ``not contains`` (which IS - pushed down) from leaking its ``must_not`` into a parent should. - """ - body = build_meta_filter_query( - [ - {"key": "tag", "op": "=", "value": "alpha"}, - {"key": "tag", "op": "≠", "value": "beta"}, - ], - logic="or", - kb_ids=["kb1"], - ) - inner = body["query"]["bool"]["filter"][1]["bool"] - should = inner["should"] - assert should[0] == { - "term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}} - } - # The ≠ branch is wrapped so its must_not does not bleed into the OR set. - assert "bool" in should[1] - assert "must_not" in should[1]["bool"] - - -# --------------------------------------------------------------------------- -# is_pushdown_supported pre-check -# --------------------------------------------------------------------------- - - -def test_pushdown_check_accepts_known_ops(): - assert is_pushdown_supported( - [ - {"key": "tag", "op": "=", "value": "v"}, - {"key": "tag", "op": "contains", "value": "x"}, - ] - ) - - -def test_pushdown_check_rejects_unknown_op(): - assert not is_pushdown_supported([{"key": "tag", "op": "regex", "value": "^v"}]) - - -def test_pushdown_check_rejects_missing_key(): - assert not is_pushdown_supported([{"op": "=", "value": "v"}]) - - -@pytest.mark.parametrize("op", ["≠", "not in"]) -def test_pushdown_check_rejects_multivalue_unsafe_negatives(op): - """Negatives that diverge on multi-valued fields force the in-memory fallback.""" - assert not is_pushdown_supported([{"key": "tag", "op": op, "value": "x"}]) - - -def test_pushdown_check_one_unsafe_op_rejects_whole_request(): - """Mixing one unsafe op with safe ones still falls back, preserving correctness.""" - assert not is_pushdown_supported( - [ - {"key": "tag", "op": "=", "value": "v"}, - {"key": "tag", "op": "≠", "value": "w"}, - ] - ) - - -def test_pushdown_check_accepts_not_contains(): - """``not contains`` stays in push-down; ``all(not contains)`` ≡ ``not any(contains)``.""" - assert is_pushdown_supported([{"key": "tag", "op": "not contains", "value": "x"}]) - - -# --------------------------------------------------------------------------- -# extract_doc_ids -# --------------------------------------------------------------------------- - - -def test_extract_doc_ids_from_dict_response(): - response = { - "hits": { - "hits": [ - {"_id": "doc1", "_source": {"id": "doc1"}}, - {"_id": "doc2", "_source": {"id": "doc2"}}, - ] - } - } - assert extract_doc_ids(response) == ["doc1", "doc2"] - - -def test_extract_doc_ids_falls_back_to_source_id(): - response = {"hits": {"hits": [{"_source": {"id": "src-id"}}]}} - assert extract_doc_ids(response) == ["src-id"] - - -def test_extract_doc_ids_empty_response(): - assert extract_doc_ids({}) == [] - assert extract_doc_ids({"hits": {}}) == [] - assert extract_doc_ids({"hits": {"hits": []}}) == [] diff --git a/test/unit_test/common/test_metadata_filter.py b/test/unit_test/common/test_metadata_filter.py new file mode 100644 index 00000000000..d48b30fb6cd --- /dev/null +++ b/test/unit_test/common/test_metadata_filter.py @@ -0,0 +1,659 @@ +"""Unit tests for the metadata filter push-down translators (ES and Infinity). + +Verifies the shape of the produced filter expressions for both ES DSL and +Infinity SQL, and confirms that coercion rules (lower-casing, list-membership, +date detection) are consistent between the two backends. +""" + +import pytest + +pytestmark = pytest.mark.p2 + +from common.metadata_es_filter import MetaFilterTranslator as ESMetaFilterTranslator +from common.metadata_infinity_filter import ( + MetaFilterTranslator as InfinityMetaFilterTranslator, + SUPPORTED_OPERATORS, + build_infinity_filter, + is_pushdown_supported, + plan_pushdown, + extract_doc_ids, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture +def es_translator() -> ESMetaFilterTranslator: + return ESMetaFilterTranslator() + + +@pytest.fixture +def infinity_translator() -> InfinityMetaFilterTranslator: + return InfinityMetaFilterTranslator() + + +# --------------------------------------------------------------------------- +# Shared: is_pushdown_supported pre-check (same logic for both backends) +# --------------------------------------------------------------------------- + + +def test_pushdown_check_accepts_known_ops(): + assert is_pushdown_supported( + [ + {"key": "tag", "op": "=", "value": "v"}, + {"key": "tag", "op": "contains", "value": "x"}, + ] + ) + + +def test_pushdown_check_rejects_unknown_op(): + assert not is_pushdown_supported([{"key": "tag", "op": "regex", "value": "^v"}]) + + +def test_pushdown_check_rejects_missing_key(): + assert not is_pushdown_supported([{"op": "=", "value": "v"}]) + + +def test_pushdown_check_accepts_not_contains(): + assert is_pushdown_supported([{"key": "tag", "op": "not contains", "value": "x"}]) + + +# --------------------------------------------------------------------------- +# Shared: plan_pushdown (same logic for both backends) +# --------------------------------------------------------------------------- + + +def test_plan_pushdown_and_logic(): + fragments = plan_pushdown( + [ + {"key": "tag", "op": "=", "value": "alpha"}, + {"key": "score", "op": ">", "value": "5"}, + ], + logic="and", + ) + assert len(fragments) == 2 + + +def test_plan_pushdown_or_logic(): + fragments = plan_pushdown( + [ + {"key": "tag", "op": "=", "value": "alpha"}, + {"key": "tag", "op": "=", "value": "beta"}, + ], + logic="or", + ) + assert len(fragments) == 2 + + +def test_unknown_logic_rejected(): + with pytest.raises(ValueError): + plan_pushdown([{"key": "k", "op": "=", "value": "v"}], logic="xor") + + +# --------------------------------------------------------------------------- +# Shared: extract_doc_ids (same implementation) +# --------------------------------------------------------------------------- + + +def test_extract_doc_ids_from_dataframe(): + import pandas as pd + + df = pd.DataFrame({"id": ["doc1", "doc2", "doc3"]}) + assert extract_doc_ids(df) == ["doc1", "doc2", "doc3"] + + +def test_extract_doc_ids_empty_dataframe(): + import pandas as pd + + df = pd.DataFrame({"id": []}) + assert extract_doc_ids(df) == [] + + +def test_extract_doc_ids_none_input(): + assert extract_doc_ids(None) == [] + + +def test_extract_doc_ids_non_dataframe(): + assert extract_doc_ids("not a dataframe") == [] + + +# --------------------------------------------------------------------------- +# Shared: SUPPORTED_OPERATORS +# --------------------------------------------------------------------------- + + +def test_supported_operator_set_matches_documentation(): + expected = { + "=", + "≠", + ">", + "<", + "≥", + "≤", + "in", + "not in", + "contains", + "not contains", + "start with", + "end with", + "empty", + "not empty", + } + assert SUPPORTED_OPERATORS == expected + + +# =========================================================================== +# ES-only tests +# =========================================================================== + + +def test_equal_translates_to_term_with_lowercased_value(es_translator): + """String equality runs against ``.keyword`` so multi-word phrases match.""" + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "tag", "op": "=", "value": "Alpha"}).to_clauses() + assert clauses == [ + {"term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}}} + ] + + +def test_equal_parses_numeric_literal(es_translator): + """Numeric values stay on the parent path — no ``.keyword`` sub-field exists for ``long``.""" + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "score", "op": "=", "value": "5"}).to_clauses() + assert clauses == [{"term": {_field("score"): 5}}] + + +def test_equal_multiword_uses_keyword_subfield(es_translator): + """Regression: multi-word string values must match via .keyword sub-field.""" + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate( + {"key": "author", "op": "=", "value": "Alice Wonderland"} + ).to_clauses() + assert clauses == [ + { + "term": { + _field("author") + ".keyword": { + "value": "alice wonderland", + "case_insensitive": True, + } + } + } + ] + + +def test_not_equal_requires_field_to_exist(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "tag", "op": "≠", "value": "alpha"}).to_clauses() + assert clauses == [ + { + "bool": { + "must": [{"exists": {"field": _field("tag")}}], + "must_not": [ + {"term": {_field("tag") + ".keyword": {"value": "alpha", "case_insensitive": True}}} + ], + } + } + ] + + +@pytest.mark.parametrize( + "op,es_key", + [(">", "gt"), ("<", "lt"), ("≥", "gte"), ("≤", "lte")], +) +def test_range_operator_translation(es_translator, op, es_key): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "score", "op": op, "value": "10"}).to_clauses() + assert clauses == [ + { + "bool": { + "must": [ + {"exists": {"field": _field("score")}}, + {"range": {_field("score"): {es_key: 10}}}, + ] + } + } + ] + + +def test_range_passes_iso_date_through_unparsed(es_translator): + clauses = es_translator.translate({"key": "published", "op": "≥", "value": "2025-01-15"}).to_clauses() + range_clause = clauses[0]["bool"]["must"][1] + assert range_clause == {"range": {"meta_fields.published": {"gte": "2025-01-15"}}} + + +def test_in_operator_csv_value_lowercased(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + def _string_terms_should(field_path: str, members): + return { + "bool": { + "should": [ + {"term": {field_path + ".keyword": {"value": m, "case_insensitive": True}}} + for m in members + ], + "minimum_should_match": 1, + } + } + + clauses = es_translator.translate({"key": "status", "op": "in", "value": "Active,Pending"}).to_clauses() + assert clauses == [_string_terms_should(_field("status"), ["active", "pending"])] + + +def test_in_operator_python_list_literal(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + def _string_terms_should(field_path: str, members): + return { + "bool": { + "should": [ + {"term": {field_path + ".keyword": {"value": m, "case_insensitive": True}}} + for m in members + ], + "minimum_should_match": 1, + } + } + + clauses = es_translator.translate({"key": "status", "op": "in", "value": "['Open', 'Closed']"}).to_clauses() + assert clauses == [_string_terms_should(_field("status"), ["open", "closed"])] + + +def test_in_operator_numeric_members_keep_terms(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "year", "op": "in", "value": "[2024, 2025]"}).to_clauses() + assert clauses == [{"terms": {_field("year"): [2024, 2025]}}] + + +def test_not_in_negates_with_existence_guard(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + def _string_terms_should(field_path: str, members): + return { + "bool": { + "should": [ + {"term": {field_path + ".keyword": {"value": m, "case_insensitive": True}}} + for m in members + ], + "minimum_should_match": 1, + } + } + + clauses = es_translator.translate({"key": "status", "op": "not in", "value": "active,pending"}).to_clauses() + assert clauses == [ + { + "bool": { + "must": [{"exists": {"field": _field("status")}}], + "must_not": [_string_terms_should(_field("status"), ["active", "pending"])], + } + } + ] + + +def test_contains_uses_case_insensitive_wildcard(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "version", "op": "contains", "value": "earth"}).to_clauses() + assert clauses == [ + { + "wildcard": { + _field("version") + ".keyword": { + "value": "*earth*", + "case_insensitive": True, + } + } + } + ] + + +def test_contains_escapes_user_wildcards(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "title", "op": "contains", "value": "a*b?c"}).to_clauses() + pattern = clauses[0]["wildcard"][_field("title") + ".keyword"]["value"] + assert pattern == "*a\\*b\\?c*" + + +def test_not_contains_negates_with_exists(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "version", "op": "not contains", "value": "earth"}).to_clauses() + assert clauses == [ + { + "bool": { + "must": [{"exists": {"field": _field("version")}}], + "must_not": [ + { + "wildcard": { + _field("version") + ".keyword": { + "value": "*earth*", + "case_insensitive": True, + } + } + } + ], + } + } + ] + + +def test_start_with_uses_prefix(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "name", "op": "start with", "value": "pre"}).to_clauses() + assert clauses == [ + {"prefix": {_field("name") + ".keyword": {"value": "pre", "case_insensitive": True}}} + ] + + +def test_end_with_uses_trailing_wildcard(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "file", "op": "end with", "value": ".pdf"}).to_clauses() + pattern = clauses[0]["wildcard"][_field("file") + ".keyword"]["value"] + assert pattern == "*.pdf" + + +def test_empty_matches_missing_or_blank(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "notes", "op": "empty", "value": ""}).to_clauses() + assert clauses == [ + { + "bool": { + "should": [ + {"bool": {"must_not": [{"exists": {"field": _field("notes")}}]}}, + {"term": {_field("notes") + ".keyword": ""}}, + ], + "minimum_should_match": 1, + } + } + ] + + +def test_not_empty_requires_exists_and_excludes_blank(es_translator): + from common.metadata_es_filter import META_FIELDS_PREFIX + + def _field(key: str) -> str: + return f"{META_FIELDS_PREFIX}.{key}" + + clauses = es_translator.translate({"key": "notes", "op": "not empty", "value": ""}).to_clauses() + assert clauses == [ + { + "bool": { + "must": [{"exists": {"field": _field("notes")}}], + "must_not": [{"term": {_field("notes") + ".keyword": ""}}], + } + } + ] + + +def test_unknown_operator_raises(es_translator): + from common.metadata_es_filter import UnsupportedMetaFilter + + with pytest.raises(UnsupportedMetaFilter) as exc: + es_translator.translate({"key": "tag", "op": "regex", "value": "^foo"}) + assert "regex" in exc.value.reason + + +def test_missing_key_raises(es_translator): + from common.metadata_es_filter import UnsupportedMetaFilter + + with pytest.raises(UnsupportedMetaFilter): + es_translator.translate({"op": "=", "value": "x"}) + + +def test_scalar_op_with_list_value_raises(es_translator): + from common.metadata_es_filter import UnsupportedMetaFilter + + with pytest.raises(UnsupportedMetaFilter): + es_translator.translate({"key": "tag", "op": "=", "value": ["a", "b"]}) + + +def test_string_op_with_empty_value_raises(es_translator): + from common.metadata_es_filter import UnsupportedMetaFilter + + with pytest.raises(UnsupportedMetaFilter): + es_translator.translate({"key": "tag", "op": "contains", "value": ""}) + + +def test_membership_with_empty_csv_raises(es_translator): + from common.metadata_es_filter import UnsupportedMetaFilter + + with pytest.raises(UnsupportedMetaFilter): + es_translator.translate({"key": "tag", "op": "in", "value": ""}) + + +# =========================================================================== +# Infinity-only tests +# =========================================================================== + + +def test_build_infinity_filter_and_logic(): + body = build_infinity_filter( + [ + {"key": "tag", "op": "=", "value": "alpha"}, + {"key": "score", "op": ">", "value": "5"}, + ], + logic="and", + ) + assert " AND " in body + assert "alpha" in body + + +def test_build_infinity_filter_or_logic(): + body = build_infinity_filter( + [ + {"key": "tag", "op": "=", "value": "alpha"}, + {"key": "tag", "op": "=", "value": "beta"}, + ], + logic="or", + ) + assert " OR " in body + assert "alpha" in body + assert "beta" in body + + +def test_empty_filter_list_returns_1eq1(): + body = build_infinity_filter([], "and") + assert body == "1=1" + + +def test_infinity_equal_string_uses_lowercase(infinity_translator): + cond = infinity_translator.translate({"key": "tag", "op": "=", "value": "Alpha"}) + assert cond == "JSON_CONTAINS(meta_fields, '$.tag', '\"Alpha\"')" + + +def test_infinity_equal_numeric_keeps_number(infinity_translator): + cond = infinity_translator.translate({"key": "score", "op": "=", "value": "5"}) + assert cond == "JSON_CONTAINS(meta_fields, '$.score', 5)" + + +def test_infinity_equal_date_passes_unparsed(infinity_translator): + cond = infinity_translator.translate({"key": "published", "op": "=", "value": "2025-01-15"}) + assert cond == "JSON_CONTAINS(meta_fields, '$.published', '\"2025-01-15\"')" + + +def test_infinity_not_equal_string(infinity_translator): + cond = infinity_translator.translate({"key": "tag", "op": "≠", "value": "alpha"}) + assert "JSON_CONTAINS" in cond + assert "alpha" in cond + assert "NOT" in cond + + +def test_infinity_not_equal_numeric(infinity_translator): + cond = infinity_translator.translate({"key": "score", "op": "≠", "value": "5"}) + assert "JSON_CONTAINS" in cond and "NOT" in cond and "5" in cond + + +@pytest.mark.parametrize("op,sql_op", [(">", ">"), ("<", "<"), ("≥", ">="), ("≤", "<=")]) +def test_infinity_range_operators(infinity_translator, op, sql_op): + cond = infinity_translator.translate({"key": "score", "op": op, "value": "10"}) + assert sql_op in cond + assert "JSON_EXTRACT_DOUBLE(meta_fields, '$.score')" in cond + + +def test_infinity_range_string_value(infinity_translator): + cond = infinity_translator.translate({"key": "published", "op": "≥", "value": "2025-01-15"}) + assert ">=" in cond + assert "2025-01-15" in cond + + +def test_infinity_in_csv_lowercased(infinity_translator): + cond = infinity_translator.translate({"key": "status", "op": "in", "value": "Active,Pending"}) + assert "JSON_CONTAINS" in cond + assert "active" in cond + assert "pending" in cond + + +def test_infinity_in_python_list(infinity_translator): + cond = infinity_translator.translate({"key": "status", "op": "in", "value": "['Open', 'Closed']"}) + assert "JSON_CONTAINS" in cond + assert "open" in cond + assert "closed" in cond + + +def test_infinity_in_numeric_members(infinity_translator): + cond = infinity_translator.translate({"key": "year", "op": "in", "value": "[2024, 2025]"}) + assert "JSON_CONTAINS" in cond + assert "2024" in cond + assert "2025" in cond + + +def test_infinity_not_in_csv(infinity_translator): + cond = infinity_translator.translate({"key": "status", "op": "not in", "value": "active,pending"}) + assert "NOT JSON_CONTAINS" in cond + + +def test_infinity_contains_uses_JSON_CONTAINS(infinity_translator): + """Infinity 'contains' uses JSON_CONTAINS for JSON array membership.""" + cond = infinity_translator.translate({"key": "version", "op": "contains", "value": "earth"}) + assert "JSON_CONTAINS" in cond + assert "earth" in cond + + +def test_infinity_contains_escapes_quotes(infinity_translator): + """Special characters in contains value are escaped for JSON_CONTAINS.""" + cond = infinity_translator.translate({"key": "title", "op": "contains", "value": "a%b_c"}) + assert "JSON_CONTAINS" in cond + assert "a%b_c" in cond + + +def test_infinity_not_contains_uses_JSON_CONTAINS(infinity_translator): + """Infinity 'not contains' uses JSON_CONTAINS with NOT.""" + cond = infinity_translator.translate({"key": "version", "op": "not contains", "value": "earth"}) + assert "JSON_CONTAINS" in cond + assert "NOT" in cond or "not" in cond.lower() + + +def test_infinity_start_with(infinity_translator): + cond = infinity_translator.translate({"key": "name", "op": "start with", "value": "pre"}) + assert "LIKE" in cond + assert "'pre%" in cond + + +def test_infinity_end_with(infinity_translator): + """Infinity 'end with' uses LIKE with trailing wildcard.""" + cond = infinity_translator.translate({"key": "file", "op": "end with", "value": ".pdf"}) + assert "LIKE" in cond + assert "%.pdf" in cond + + +def test_infinity_empty(infinity_translator): + cond = infinity_translator.translate({"key": "notes", "op": "empty", "value": ""}) + assert "JSON_EXTRACT_STRING" in cond + assert '""' in cond + + +def test_infinity_not_empty(infinity_translator): + cond = infinity_translator.translate({"key": "notes", "op": "not empty", "value": ""}) + assert "JSON_EXTRACT_STRING" in cond + assert "!=" in cond + + +def test_infinity_unknown_operator_raises(infinity_translator): + with pytest.raises(ValueError) as exc: + infinity_translator.translate({"key": "tag", "op": "regex", "value": "^foo"}) + assert "regex" in str(exc.value) + + +def test_infinity_missing_key_raises(infinity_translator): + with pytest.raises(ValueError): + infinity_translator.translate({"op": "=", "value": "x"}) + + +def test_infinity_invalid_key_format_raises(infinity_translator): + with pytest.raises(ValueError, match="invalid key format"): + infinity_translator.translate({"key": "a;b", "op": "=", "value": "x"}) + + +def test_infinity_key_with_brace_raises(infinity_translator): + with pytest.raises(ValueError, match="invalid key format"): + infinity_translator.translate({"key": "field$}", "op": "=", "value": "x"}) + + +def test_infinity_scalar_op_with_list_value_raises(infinity_translator): + with pytest.raises(ValueError): + infinity_translator.translate({"key": "tag", "op": "=", "value": ["a", "b"]}) + + +def test_infinity_string_op_with_empty_value_raises(infinity_translator): + with pytest.raises(ValueError): + infinity_translator.translate({"key": "tag", "op": "contains", "value": ""}) + + +def test_infinity_membership_with_empty_csv_raises(infinity_translator): + with pytest.raises(ValueError): + infinity_translator.translate({"key": "tag", "op": "in", "value": ""}) \ No newline at end of file diff --git a/test/unit_test/data_source/__init__.py b/test/unit_test/data_source/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/unit_test/data_source/conftest.py b/test/unit_test/data_source/conftest.py new file mode 100644 index 00000000000..0ba0f86071b --- /dev/null +++ b/test/unit_test/data_source/conftest.py @@ -0,0 +1,36 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Pre-register the ``common.data_source`` package namespace so that +importing individual sub-modules (config, exceptions, rest_api_connector, …) +does **not** trigger ``common/data_source/__init__.py``, which pulls in every +connector and their heavy transitive dependencies (numpy, xgboost, etc.). + +This file is executed by pytest before any test module in this directory is +collected, so the lightweight namespace is always in place. +""" + +import os +import sys +import types + +import common # lightweight top-level package + +if "common.data_source" not in sys.modules: + _pkg = types.ModuleType("common.data_source") + _pkg.__path__ = [os.path.join(p, "data_source") for p in common.__path__] + _pkg.__package__ = "common.data_source" + sys.modules["common.data_source"] = _pkg diff --git a/test/unit_test/data_source/test_rest_api_connector.py b/test/unit_test/data_source/test_rest_api_connector.py new file mode 100644 index 00000000000..1e7d737eaa7 --- /dev/null +++ b/test/unit_test/data_source/test_rest_api_connector.py @@ -0,0 +1,607 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from contextlib import contextmanager +from unittest.mock import MagicMock, patch + +import pytest +import requests + +from common.data_source import utils as _ds_utils +from common.data_source.exceptions import ( + ConnectorMissingCredentialError, + ConnectorValidationError, +) +from common.data_source.rest_api_connector import ( + AuthType, + PaginationType, + RestAPIConnector, + RestAPIConnectorConfig, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +VALID_URL = "https://api.example.com/v1/items" + +_MOCK_DNS_ADDRINFO = [(2, 1, 6, "", ("93.184.216.34", 0))] + + +@contextmanager +def _mocked_rest_api_requests_and_dns(): + """Block real DNS/TCP: mock SSRF getaddrinfo and HTTP at the class layer. + + `RestAPIConnector` calls `rl_requests.get` / `.post` on + `utils._RateLimitedRequest`. Replacing only module-level `rl_requests` is not + reliable everywhere (import/rebind quirks), so we patch the class methods + that wrap `requests.get` / `requests.post` and avoid retry backoff delays. + """ + mock_rl = MagicMock() + with patch( + "common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=_MOCK_DNS_ADDRINFO, + ), patch.object(_ds_utils._RateLimitedRequest, "get", mock_rl.get), patch.object( + _ds_utils._RateLimitedRequest, + "post", + mock_rl.post, + ): + yield mock_rl + + +def _make_paged_connector(**overrides) -> RestAPIConnector: + defaults = dict( + url=VALID_URL, + content_fields=["title"], + pagination_type=PaginationType.PAGE, + pagination_config={"page_param": "page"}, + max_pages=100, + request_delay=0, + ) + defaults.update(overrides) + return RestAPIConnector(**defaults) + + +def _make_connector(**overrides) -> RestAPIConnector: + """Build a RestAPIConnector with sensible defaults, applying *overrides*.""" + defaults = dict( + url=VALID_URL, + content_fields=["title", "body"], + ) + defaults.update(overrides) + return RestAPIConnector(**defaults) + + +def _mock_response(json_data, status_code=200): + """Return a ``requests.Response``-like mock.""" + resp = MagicMock(spec=requests.Response) + resp.status_code = status_code + resp.url = VALID_URL + resp.json.return_value = json_data + + if status_code >= 400: + http_error = requests.HTTPError(response=resp) + resp.raise_for_status.side_effect = http_error + resp.status_code = status_code + else: + resp.raise_for_status.return_value = None + + return resp + + +# ===================================================================== # +# 1. Config schema validation # +# ===================================================================== # + +class TestRestAPIConfig: + """Test Pydantic RestAPIConnectorConfig schema validation.""" + + def test_missing_url_raises_validation_error(self): + """Missing url should fail Pydantic validation.""" + with pytest.raises(Exception): + RestAPIConnectorConfig(content_fields=["title"]) + + def test_missing_content_fields_detected(self): + """An empty content_fields list should be caught by ensure_required_fields.""" + cfg = RestAPIConnectorConfig(url=VALID_URL, content_fields=[]) + with pytest.raises(ConnectorValidationError): + cfg.ensure_required_fields() + + def test_valid_minimal_config(self): + """Minimal valid config: url + content_fields.""" + cfg = RestAPIConnectorConfig(url=VALID_URL, content_fields=["title"]) + assert str(cfg.url).startswith("https://api.example.com") + assert cfg.content_fields == ["title"] + + def test_auth_type_defaults_to_none(self): + """auth_type should default to 'none'.""" + cfg = RestAPIConnectorConfig(url=VALID_URL, content_fields=["t"]) + assert cfg.auth_type == AuthType.NONE + + def test_pagination_type_defaults_to_none(self): + """pagination_type should default to 'none'.""" + cfg = RestAPIConnectorConfig(url=VALID_URL, content_fields=["t"]) + assert cfg.pagination_type == PaginationType.NONE + + def test_string_to_dict_coercion_for_headers(self): + """A key=value string should be coerced to a dict.""" + cfg = RestAPIConnectorConfig( + url=VALID_URL, content_fields=["t"], headers="X-Custom=hello" + ) + assert cfg.headers == {"X-Custom": "hello"} + + def test_string_to_list_coercion_for_content_fields(self): + """A comma-separated string should be coerced to a list.""" + cfg = RestAPIConnectorConfig(url=VALID_URL, content_fields="title,content") + assert cfg.content_fields == ["title", "content"] + + +# ===================================================================== # +# 2. SSRF URL validation # +# ===================================================================== # + +class TestSSRFValidation: + """Test that unsafe URLs are blocked before any HTTP request is made.""" + + def test_localhost_blocked(self): + """localhost should be rejected.""" + with pytest.raises(ConnectorValidationError, match="localhost"): + _make_connector(url="http://localhost/api") + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo") + def test_loopback_ip_blocked(self, mock_dns): + """127.0.0.1 should be rejected.""" + mock_dns.return_value = [(2, 1, 6, "", ("127.0.0.1", 0))] + with pytest.raises(ConnectorValidationError, match="disallowed"): + _make_connector(url="http://127.0.0.1/api") + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo") + def test_cloud_metadata_ip_blocked(self, mock_dns): + """169.254.169.254 (cloud metadata endpoint) should be rejected.""" + mock_dns.return_value = [(2, 1, 6, "", ("169.254.169.254", 0))] + with pytest.raises(ConnectorValidationError, match="disallowed"): + _make_connector(url="http://169.254.169.254/latest/meta-data/") + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo") + def test_private_ip_192_blocked(self, mock_dns): + """192.168.x.x should be rejected.""" + mock_dns.return_value = [(2, 1, 6, "", ("192.168.1.1", 0))] + with pytest.raises(ConnectorValidationError, match="disallowed"): + _make_connector(url="http://192.168.1.1/api") + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo") + def test_private_ip_10_blocked(self, mock_dns): + """10.x.x.x should be rejected.""" + mock_dns.return_value = [(2, 1, 6, "", ("10.0.0.1", 0))] + with pytest.raises(ConnectorValidationError, match="disallowed"): + _make_connector(url="http://10.0.0.1/api") + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo") + def test_public_url_passes(self, mock_dns): + """A public IP should pass validation.""" + mock_dns.return_value = [(2, 1, 6, "", ("93.184.216.34", 0))] + c = _make_connector(url="https://example.com/api") + assert c.url.startswith("https://") + + def test_ftp_scheme_blocked(self): + """ftp:// should be rejected.""" + with pytest.raises(ConnectorValidationError, match="scheme"): + _make_connector(url="ftp://example.com/file") + + def test_file_scheme_blocked(self): + """file:// should be rejected.""" + with pytest.raises(ConnectorValidationError, match="scheme"): + _make_connector(url="file:///etc/passwd") + + +# ===================================================================== # +# 3. Authentication setup # +# ===================================================================== # + +class TestAuthSetup: + """Test _build_auth produces the correct headers / auth objects.""" + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + def test_auth_none(self, _dns): + """auth_type=none should produce no auth headers.""" + c = _make_connector(auth_type=AuthType.NONE) + c.load_credentials({}) + assert c._auth_headers == {} + assert c._basic_auth is None + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + def test_api_key_header(self, _dns): + """api_key_header should set the specified header.""" + c = _make_connector( + auth_type=AuthType.API_KEY_HEADER, + auth_config={"header_name": "X-API-Key"}, + ) + c.load_credentials({"api_key": "secret123"}) + assert c._auth_headers == {"X-API-Key": "secret123"} + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + def test_bearer_token(self, _dns): + """bearer should set Authorization: Bearer .""" + c = _make_connector(auth_type=AuthType.BEARER) + c.load_credentials({"token": "tok_abc"}) + assert c._auth_headers == {"Authorization": "Bearer tok_abc"} + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + def test_basic_auth(self, _dns): + """basic should produce an HTTPBasicAuth object.""" + c = _make_connector(auth_type=AuthType.BASIC) + c.load_credentials({"username": "user", "password": "pass"}) + assert c._basic_auth is not None + assert c._basic_auth.username == "user" + assert c._basic_auth.password == "pass" + + +# ===================================================================== # +# 4. Field extraction # +# ===================================================================== # + +class TestFieldExtraction: + """Test _extract_field / _extract_field_values dot-notation paths.""" + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + def setup_method(self, method, _dns=None): + with patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): + self.connector = _make_connector() + + def test_simple_field(self): + """Top-level field extraction.""" + assert self.connector._extract_field({"title": "Hello"}, "title") == "Hello" + + def test_dot_notation_nested(self): + """Dot-notation nested field.""" + item = {"country": {"name": "Kuwait"}} + assert self.connector._extract_field(item, "country.name") == "Kuwait" + + def test_array_wildcard(self): + """Wildcard [*] returns all array elements.""" + item = {"tags": [{"name": "A"}, {"name": "B"}]} + result = self.connector._extract_field(item, "tags[*].name") + assert result == ["A", "B"] + + def test_missing_field_returns_none(self): + """Missing field returns None.""" + assert self.connector._extract_field({"a": 1}, "nonexistent") is None + + def test_missing_field_with_default(self): + """Missing field returns configured default value.""" + with patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): + c = _make_connector(field_default_values={"missing": "fallback"}) + result = c._get_typed_field_value("missing", {"other": 1}) + assert result == "fallback" + + def test_deeply_nested_path(self): + """Multi-level dot-notation path.""" + item = {"a": {"b": {"c": {"d": 42}}}} + assert self.connector._extract_field(item, "a.b.c.d") == 42 + + +# ===================================================================== # +# 5. Items array detection # +# ===================================================================== # + +class TestItemsArrayDetection: + """Test _extract_items auto-detection of the items array.""" + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + def setup_method(self, method, _dns=None): + with patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): + self.connector = _make_connector() + + def test_items_key(self): + """Detect 'items' key.""" + resp = {"items": [{"id": 1}]} + assert self.connector._extract_items(resp) == [{"id": 1}] + + def test_results_key(self): + """Detect 'results' key.""" + resp = {"results": [{"id": 2}]} + assert self.connector._extract_items(resp) == [{"id": 2}] + + def test_data_key(self): + """Detect 'data' key.""" + resp = {"data": [{"id": 3}]} + assert self.connector._extract_items(resp) == [{"id": 3}] + + def test_records_key(self): + """Detect 'records' key.""" + resp = {"records": [{"id": 4}]} + assert self.connector._extract_items(resp) == [{"id": 4}] + + def test_custom_key_fallback(self): + """Fall back to the first list value in the dict.""" + resp = {"totalCount": 5, "stories": [{"id": 5}]} + assert self.connector._extract_items(resp) == [{"id": 5}] + + def test_response_is_list(self): + """Response that is directly a list.""" + resp = [{"id": 6}, {"id": 7}] + assert self.connector._extract_items(resp) == [{"id": 6}, {"id": 7}] + + def test_empty_response(self): + """Empty dict returns empty list.""" + assert self.connector._extract_items({}) == [] + + def test_no_list_in_response(self): + """Dict with no list values returns empty list.""" + assert self.connector._extract_items({"count": 0}) == [] + + +# ===================================================================== # +# 6. HTML stripping # +# ===================================================================== # + +class TestHTMLStripping: + """Test the _strip_html static method.""" + + def test_basic_tag_removal(self): + """Remove simple HTML tags.""" + assert RestAPIConnector._strip_html("

Hello

") == "Hello" + + def test_whitespace_collapsing(self): + """Multiple whitespace chars collapse to single space.""" + assert RestAPIConnector._strip_html("

Hello

World

") == "Hello World" + + def test_empty_string(self): + """Empty input returns empty output.""" + assert RestAPIConnector._strip_html("") == "" + + def test_plain_text_passthrough(self): + """Text without HTML passes through unchanged.""" + assert RestAPIConnector._strip_html("Hello World") == "Hello World" + + def test_nested_tags(self): + """Nested HTML tags are all stripped.""" + result = RestAPIConnector._strip_html("

Bold text

") + assert result == "Bold text" + + def test_html_with_attributes(self): + """Tags with attributes are stripped.""" + result = RestAPIConnector._strip_html('Link') + assert result == "Link" + + +# ===================================================================== # +# 7. Document creation # +# ===================================================================== # + +class TestDocumentCreation: + """Test _item_to_document mapping.""" + + @patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]) + def setup_method(self, method, _dns=None): + with patch("common.data_source.rest_api_connector.socket.getaddrinfo", + return_value=[(2, 1, 6, "", ("93.184.216.34", 0))]): + self.connector = _make_connector( + id_field="id", + content_fields=["title", "body"], + metadata_fields=["author"], + ) + + def test_document_id_from_configured_field(self): + """Document ID uses the configured id_field.""" + item = {"id": "abc", "title": "T", "body": "B", "author": "A"} + doc = self.connector._item_to_document(item) + assert doc.id is not None and len(doc.id) > 0 + + def test_semantic_identifier_from_first_content_field(self): + """semantic_identifier comes from the first content field.""" + item = {"id": "1", "title": "My Title", "body": "Body", "author": "A"} + doc = self.connector._item_to_document(item) + assert "My Title" in doc.semantic_identifier + + def test_content_blob_contains_all_fields(self): + """Blob should contain both content fields.""" + item = {"id": "1", "title": "Title", "body": "Body text", "author": "A"} + doc = self.connector._item_to_document(item) + content = doc.blob.decode("utf-8") + assert "Title" in content + assert "Body text" in content + + def test_metadata_populated(self): + """Metadata dict is populated from configured metadata_fields.""" + item = {"id": "1", "title": "T", "body": "B", "author": "Jane"} + doc = self.connector._item_to_document(item) + assert doc.metadata is not None + assert doc.metadata["author"] == "Jane" + + def test_html_stripped_from_content(self): + """HTML tags are removed from content fields.""" + item = {"id": "1", "title": "T", "body": "

Clean

", "author": "A"} + doc = self.connector._item_to_document(item) + content = doc.blob.decode("utf-8") + assert "

" not in content + assert "Clean" in content + + def test_extension_is_txt(self): + """Document extension should be .txt.""" + item = {"id": "1", "title": "T", "body": "B", "author": "A"} + doc = self.connector._item_to_document(item) + assert doc.extension == ".txt" + + def test_missing_content_fields_graceful(self): + """Missing content fields produce an empty blob gracefully.""" + item = {"id": "1", "author": "A"} + doc = self.connector._item_to_document(item) + assert doc.blob == b"" + + +# ===================================================================== # +# 8. Pagination behaviour # +# ===================================================================== # + +class TestPaginationBehavior: + """Test pagination iteration with mocked HTTP responses.""" + + def test_page_pagination_increments(self): + """Page-based pagination should increment the page param.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + page1 = _mock_response({"items": [{"title": "A"}, {"title": "B"}]}) + page2 = _mock_response({"items": []}) + mock_rl.get.side_effect = [page1, page2] + + c = _make_paged_connector() + items = list(c._iter_items()) + assert len(items) == 2 + assert mock_rl.get.call_count == 2 + + def test_offset_pagination_increments(self): + """Offset-based pagination should increment offset by limit.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + page1 = _mock_response({"items": [{"title": "A"}]}) + page2 = _mock_response({"items": []}) + mock_rl.get.side_effect = [page1, page2] + + c = _make_connector( + pagination_type=PaginationType.OFFSET, + pagination_config={ + "offset_param": "offset", + "limit_param": "limit", + "limit": 10, + }, + request_delay=0, + ) + items = list(c._iter_items()) + assert len(items) == 1 + + def test_stops_on_empty_results(self): + """Pagination stops when empty items are returned.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response({"items": []}) + + c = _make_paged_connector() + items = list(c._iter_items()) + assert items == [] + assert mock_rl.get.call_count == 1 + + def test_stops_when_fewer_items_than_page_size(self): + """Pagination stops when fewer items than page_size are returned.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + page1 = _mock_response({"items": [{"title": "A"}]}) + mock_rl.get.return_value = page1 + + c = _make_paged_connector( + pagination_config={"page_param": "page", "page_size": 10}, + ) + items = list(c._iter_items()) + assert len(items) == 1 + assert mock_rl.get.call_count == 1 + + def test_max_pages_cap(self): + """Pagination respects the max_pages safety cap.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response( + {"items": [{"title": "A"}, {"title": "B"}]} + ) + + c = _make_paged_connector( + max_pages=3, + pagination_config={"page_param": "page", "page_size": 2}, + ) + list(c._iter_items()) + assert mock_rl.get.call_count == 3 + + def test_request_delay_applied(self): + """request_delay should cause a sleep between pages.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + with patch("common.data_source.rest_api_connector.time.sleep") as mock_sleep: + page1 = _mock_response({"items": [{"title": "A"}, {"title": "B"}]}) + page2 = _mock_response({"items": []}) + mock_rl.get.side_effect = [page1, page2] + + c = _make_paged_connector( + pagination_config={"page_param": "page", "page_size": 2}, + ) + c.request_delay = 1.5 + list(c._iter_items()) + mock_sleep.assert_called_once_with(1.5) + + +# ===================================================================== # +# 9. Non-retriable HTTP errors # +# ===================================================================== # + +class TestNonRetriableErrors: + """Test that HTTP errors are classified correctly in _fetch_page.""" + + def test_401_raises_credential_error(self): + """401 should raise ConnectorMissingCredentialError immediately.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response({}, status_code=401) + c = _make_connector(request_delay=0) + c.load_credentials({}) + with pytest.raises(ConnectorMissingCredentialError): + c._fetch_page({}) + + def test_403_raises_credential_error(self): + """403 should raise ConnectorMissingCredentialError immediately.""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response({}, status_code=403) + c = _make_connector(request_delay=0) + c.load_credentials({}) + with pytest.raises(ConnectorMissingCredentialError): + c._fetch_page({}) + + def test_404_raises_validation_error(self): + """404 should raise ConnectorValidationError (no retry).""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response({}, status_code=404) + c = _make_connector(request_delay=0) + c.load_credentials({}) + with pytest.raises(ConnectorValidationError, match="non-retriable"): + c._fetch_page({}) + + def test_400_raises_validation_error(self): + """400 should raise ConnectorValidationError (no retry).""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response({}, status_code=400) + c = _make_connector(request_delay=0) + c.load_credentials({}) + with pytest.raises(ConnectorValidationError, match="non-retriable"): + c._fetch_page({}) + + def test_500_triggers_retry(self): + """500 should raise HTTPError (which the retry decorator catches).""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response({}, status_code=500) + c = _make_connector(request_delay=0) + c.load_credentials({}) + with pytest.raises(requests.HTTPError): + c._fetch_page({}) + + def test_429_triggers_retry(self): + """429 should raise HTTPError (retriable, not ConnectorValidationError).""" + with _mocked_rest_api_requests_and_dns() as mock_rl: + mock_rl.get.return_value = _mock_response({}, status_code=429) + c = _make_connector(request_delay=0) + c.load_credentials({}) + with pytest.raises(requests.HTTPError): + c._fetch_page({}) diff --git a/test/unit_test/memory/utils/test_ob_conn_aggregation.py b/test/unit_test/memory/utils/test_ob_conn_aggregation.py index cf136eb2087..a409a5c2556 100644 --- a/test/unit_test/memory/utils/test_ob_conn_aggregation.py +++ b/test/unit_test/memory/utils/test_ob_conn_aggregation.py @@ -20,6 +20,8 @@ without requiring a real OceanBase instance or heavy dependencies. """ +import pytest + from memory.utils.aggregation_utils import aggregate_by_field @@ -53,3 +55,24 @@ def test_pre_aggregated_value_count_rows(self): ] out = aggregate_by_field(messages, "message_type_kwd") assert set(out) == {("user", 2), ("assistant", 1)} + + @pytest.mark.p2 + def test_aggregates_list_values_and_trims_whitespace(self): + messages = [ + {"id": "m1", "tags_kwd": [" alpha ", "beta", ""]}, + {"id": "m2", "tags_kwd": ["alpha", " beta "]}, + {"id": "m3", "tags_kwd": ["gamma", None, 1]}, + ] + out = aggregate_by_field(messages, "tags_kwd") + assert set(out) == {("alpha", 2), ("beta", 2), ("gamma", 1)} + + @pytest.mark.p2 + def test_ignores_non_string_and_blank_scalar_values(self): + messages = [ + {"id": "m1", "message_type_kwd": " "}, + {"id": "m2", "message_type_kwd": None}, + {"id": "m3", "message_type_kwd": 1}, + {"id": "m4", "message_type_kwd": "assistant"}, + ] + out = aggregate_by_field(messages, "message_type_kwd") + assert out == [("assistant", 1)] diff --git a/test/unit_test/rag/app/__init__.py b/test/unit_test/rag/app/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test/unit_test/rag/app/test_table_chunk_column_roles.py b/test/unit_test/rag/app/test_table_chunk_column_roles.py new file mode 100644 index 00000000000..40eed2ae5b6 --- /dev/null +++ b/test/unit_test/rag/app/test_table_chunk_column_roles.py @@ -0,0 +1,235 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License on an "AS IS" BASIS, WITHOUT WARRANTIES +# OR CONDITIONS OF ANY KIND, either express or implied. See the License +# for the specific language governing permissions and limitations under +# the License. +# + +"""Integration-style tests for rag.app.table.chunk() column roles (mocked KB + tokenizer).""" + +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +# Mock heavy modules that trigger ONNX model loading at import time +# table.py -> deepdoc.parser.figure_parser -> rag.app.picture -> OCR() +for mod in [ + "deepdoc.vision.ocr", + "deepdoc.parser.figure_parser", + "rag.app.picture", +]: + if mod not in sys.modules: + sys.modules[mod] = MagicMock() + +import warnings + +# Importing rag.app.table pulls api -> rag.llm -> deepdoc -> xgboost; xgboost may warn on +# pkg_resources in a way that breaks its compat shim unless pkg_resources loads first. +warnings.filterwarnings("ignore", message=".*pkg_resources is deprecated.*", category=UserWarning) +import pkg_resources # noqa: F401 — stabilize xgboost import during collection + +import pytest + +import common.settings as settings +from rag.app.table import chunk + +# chunk() removes columns named id, _id, index, idx — use row_id instead of id. +TEST_CSV = b"""row_id,title,content,country,category +1,Earthquake hits Turkey,A 5.8 magnitude earthquake struck Konya,Turkey,Disaster +2,Oil prices surge,Brent crude jumped 4.2 percent,Global,Economy +3,AI regulation proposed,EU unveiled a draft regulation,EU,Technology +""" + +FILENAME = "test.csv" +KB_ID = "test_kb_id" + + +def _noop_callback(*_a, **_k): + pass + + +@pytest.fixture(autouse=True) +def _es_doc_engine(monkeypatch): + monkeypatch.setattr(settings, "DOC_ENGINE_INFINITY", False) + monkeypatch.setattr(settings, "DOC_ENGINE_OCEANBASE", False) + + +@pytest.fixture(autouse=True) +def _stub_rag_tokenizer(monkeypatch): + """Avoid NLTK / infinity tokenizer deps; keep string content inspectable.""" + + def fake_tokenize(line): + return str(line) + + monkeypatch.setattr("rag.nlp.rag_tokenizer.tokenize", fake_tokenize) + monkeypatch.setattr("rag.nlp.rag_tokenizer.fine_grained_tokenize", fake_tokenize) + + +@pytest.fixture +def mock_update_kb(): + with patch("rag.app.table.KnowledgebaseService.update_parser_config") as m: + yield m + + +def _run_chunk(parser_config: dict, mock_update_kb: MagicMock): + return chunk( + FILENAME, + binary=TEST_CSV, + callback=_noop_callback, + kb_id=KB_ID, + parser_config=parser_config, + lang="Chinese", + ) + + +def test_chunk_auto_mode_all_columns_in_text_and_stored(mock_update_kb: MagicMock): + parser_config: dict = {} + chunks = _run_chunk(parser_config, mock_update_kb) + assert len(chunks) == 3 + first = chunks[0] + cww = first["content_with_weight"] + assert "Earthquake hits Turkey" in cww + assert "Konya" in cww + assert "Turkey" in cww + assert "Disaster" in cww + assert "1" in cww or "row_id" in cww + # ES path: stored typed fields for text columns include *_tks and *_raw; row_id is int -> *_long + assert "row_id_long" in first + assert "title_raw" in first and "country_raw" in first + + +def test_chunk_manual_mode_indexing_only(mock_update_kb: MagicMock): + parser_config = { + "table_column_mode": "manual", + "table_column_roles": { + "title": "indexing", + "content": "indexing", + "row_id": "metadata", + "country": "metadata", + "category": "metadata", + }, + } + chunks = _run_chunk(parser_config, mock_update_kb) + first = chunks[0] + cww = first["content_with_weight"] + assert "- title:" in cww and "Earthquake" in cww + assert "- content:" in cww and "Konya" in cww + assert "- country:" not in cww + assert "- category:" not in cww + assert "- row_id:" not in cww + # Column title/content not stored as table fields + assert "title_raw" not in first + assert "content_raw" not in first + assert "country_raw" in first and "category_raw" in first + assert "row_id_long" in first + + +def test_chunk_manual_mode_legacy_vectorize_role(mock_update_kb: MagicMock): + """Stored configs may still use role *vectorize*; chunking treats it like *indexing*.""" + parser_config = { + "table_column_mode": "manual", + "table_column_roles": { + "title": "vectorize", + "content": "indexing", + "row_id": "metadata", + "country": "metadata", + "category": "metadata", + }, + } + chunks = _run_chunk(parser_config, mock_update_kb) + first = chunks[0] + cww = first["content_with_weight"] + assert "- title:" in cww and "Earthquake" in cww + assert "- content:" in cww and "Konya" in cww + assert "- country:" not in cww + + +def test_chunk_manual_mode_metadata_only(mock_update_kb: MagicMock): + parser_config = { + "table_column_mode": "manual", + "table_column_roles": { + "title": "metadata", + "content": "metadata", + "row_id": "metadata", + "country": "metadata", + "category": "metadata", + }, + } + chunks = _run_chunk(parser_config, mock_update_kb) + first = chunks[0] + assert (first.get("content_with_weight") or "").strip() == "" + assert "country_raw" in first and "title_raw" in first + + +def test_chunk_manual_mode_both(mock_update_kb: MagicMock): + parser_config = { + "table_column_mode": "manual", + "table_column_roles": {c: "both" for c in ["title", "content", "country", "category", "row_id"]}, + } + chunks = _run_chunk(parser_config, mock_update_kb) + first = chunks[0] + cww = first["content_with_weight"] + assert "Earthquake hits Turkey" in cww + assert "Turkey" in cww + assert "Disaster" in cww + assert "row_id_long" in first + assert "title_raw" in first and "country_raw" in first + + +def test_chunk_manual_mode_partial_roles_default_to_both(mock_update_kb: MagicMock): + parser_config = { + "table_column_mode": "manual", + "table_column_roles": { + "title": "indexing", + "country": "metadata", + }, + } + chunks = _run_chunk(parser_config, mock_update_kb) + first = chunks[0] + cww = first["content_with_weight"] + assert "- title:" in cww and "Earthquake" in cww + assert "- country:" not in cww + assert "- row_id:" in cww + assert "- content:" in cww + assert "- category:" in cww + assert "title_raw" not in first + assert "country_raw" in first and "country_tks" in first + assert "content_raw" in first and "category_raw" in first + + +def test_chunk_manual_mode_raw_fields_for_es(mock_update_kb: MagicMock): + parser_config = { + "table_column_mode": "manual", + "table_column_roles": {c: "both" for c in ["title", "content", "country", "category", "row_id"]}, + } + chunks = _run_chunk(parser_config, mock_update_kb) + first = chunks[0] + for col in ("title", "content", "country", "category"): + assert f"{col}_raw" in first + assert f"{col}_tks" in first + + +def test_chunk_updates_table_column_names(mock_update_kb: MagicMock): + _run_chunk({}, mock_update_kb) + mock_update_kb.assert_called_once() + args, kwargs = mock_update_kb.call_args + assert args[0] == KB_ID + payload = args[1] + names = payload["table_column_names"] + assert names == ["row_id", "title", "content", "country", "category"] + + +def test_chunk_count_matches_row_count(mock_update_kb: MagicMock): + chunks = _run_chunk({}, mock_update_kb) + assert len(chunks) == 3 diff --git a/test/unit_test/rag/conftest.py b/test/unit_test/rag/conftest.py new file mode 100644 index 00000000000..3ca5e289e9a --- /dev/null +++ b/test/unit_test/rag/conftest.py @@ -0,0 +1,58 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Restore the real ``common.data_source`` package before importing rag unit tests. + +``test/unit_test/data_source/conftest.py`` registers a lightweight +``sys.modules["common.data_source"]`` stub so submodule imports skip the heavy +package ``__init__.py``. Pytest collection order visits ``data_source/`` before +``rag/``, so without this hook ``rag.svr.sync_data_source`` fails on +``from common.data_source import BlobStorageConnector``. +""" + +from __future__ import annotations + +import importlib +import sys +import types + + +def _restore_common_data_source_package() -> None: + mod = sys.modules.get("common.data_source") + if mod is None: + return + # Stub is a bare types.ModuleType with __path__ and no __file__; real package has __init__.py. + if getattr(mod, "__file__", None) is not None: + return + if not isinstance(mod, types.ModuleType) or not getattr(mod, "__path__", None): + return + keys = [ + key + for key in sys.modules + if key == "common.data_source" or key.startswith("common.data_source.") + ] + for key in keys: + del sys.modules[key] + importlib.invalidate_caches() + try: + importlib.import_module("common.data_source") + except Exception as exc: # pragma: no cover + raise ImportError( + "conftest: failed to restore real common.data_source package" + ) from exc + + +_restore_common_data_source_package() diff --git a/test/unit_test/rag/prompts/test_generator_message_fit_in.py b/test/unit_test/rag/prompts/test_generator_message_fit_in.py new file mode 100644 index 00000000000..925c203e68a --- /dev/null +++ b/test/unit_test/rag/prompts/test_generator_message_fit_in.py @@ -0,0 +1,151 @@ +# +# Copyright 2024 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import importlib.util +import sys +from pathlib import Path +from types import ModuleType, SimpleNamespace + +import pytest + + +class _CharEncoder: + @staticmethod + def encode(text): + return list(text) + + @staticmethod + def decode(tokens): + return "".join(tokens) + + +def _load_generator_module(monkeypatch): + repo_root = Path(__file__).resolve().parents[4] + + json_repair = ModuleType("json_repair") + json_repair.repair_json = lambda text, **_kwargs: text + monkeypatch.setitem(sys.modules, "json_repair", json_repair) + + common_pkg = ModuleType("common") + common_pkg.__path__ = [str(repo_root / "common")] + monkeypatch.setitem(sys.modules, "common", common_pkg) + + misc_utils = ModuleType("common.misc_utils") + misc_utils.hash_str2int = lambda value, _mod=500: 0 + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils) + + constants = ModuleType("common.constants") + constants.TAG_FLD = "tag" + monkeypatch.setitem(sys.modules, "common.constants", constants) + + token_utils = ModuleType("common.token_utils") + token_utils.encoder = _CharEncoder() + token_utils.num_tokens_from_string = lambda text: len(text) + monkeypatch.setitem(sys.modules, "common.token_utils", token_utils) + + rag_pkg = ModuleType("rag") + rag_pkg.__path__ = [str(repo_root / "rag")] + monkeypatch.setitem(sys.modules, "rag", rag_pkg) + + rag_nlp = ModuleType("rag.nlp") + rag_nlp.rag_tokenizer = SimpleNamespace(tokenize=lambda text: text.split()) + monkeypatch.setitem(sys.modules, "rag.nlp", rag_nlp) + + rag_prompts_pkg = ModuleType("rag.prompts") + rag_prompts_pkg.__path__ = [str(repo_root / "rag" / "prompts")] + monkeypatch.setitem(sys.modules, "rag.prompts", rag_prompts_pkg) + + template_mod = ModuleType("rag.prompts.template") + template_mod.load_prompt = lambda *_args, **_kwargs: "" + monkeypatch.setitem(sys.modules, "rag.prompts.template", template_mod) + + spec = importlib.util.spec_from_file_location( + "rag.prompts.generator", repo_root / "rag" / "prompts" / "generator.py" + ) + module = importlib.util.module_from_spec(spec) + monkeypatch.setitem(sys.modules, "rag.prompts.generator", module) + spec.loader.exec_module(module) + return module + + +@pytest.mark.p1 +def test_message_fit_in_truncates_user_message_by_system_token_budget(monkeypatch): + generator = _load_generator_module(monkeypatch) + monkeypatch.setattr(generator, "num_tokens_from_string", lambda text: len(text)) + monkeypatch.setattr(generator, "encoder", _CharEncoder()) + + messages = [ + {"role": "system", "content": "1234"}, + {"role": "user", "content": "abcdefghij"}, + ] + + used_tokens, trimmed = generator.message_fit_in(messages, max_length=8) + + assert used_tokens == 8 + assert trimmed[0]["content"] == "1234" + assert trimmed[-1]["content"] == "abcd" + + +@pytest.mark.p1 +def test_message_fit_in_handles_zero_token_messages(monkeypatch): + generator = _load_generator_module(monkeypatch) + monkeypatch.setattr(generator, "num_tokens_from_string", lambda _text: 0) + monkeypatch.setattr(generator, "encoder", _CharEncoder()) + + messages = [ + {"role": "system", "content": ""}, + {"role": "user", "content": ""}, + ] + + used_tokens, trimmed = generator.message_fit_in(messages, max_length=0) + + assert used_tokens == 0 + assert trimmed == messages + + +@pytest.mark.p1 +def test_message_fit_in_clamps_negative_slice_lengths(monkeypatch): + generator = _load_generator_module(monkeypatch) + monkeypatch.setattr(generator, "num_tokens_from_string", lambda text: len(text)) + monkeypatch.setattr(generator, "encoder", _CharEncoder()) + + messages = [ + {"role": "system", "content": "1234"}, + {"role": "user", "content": "abcdefghij"}, + ] + + used_tokens, trimmed = generator.message_fit_in(messages, max_length=2) + + assert used_tokens == 2 + assert trimmed[0]["content"] == "12" + assert trimmed[-1]["content"] == "" + + +@pytest.mark.p1 +def test_message_fit_in_clamps_dominant_last_message_to_budget(monkeypatch): + generator = _load_generator_module(monkeypatch) + monkeypatch.setattr(generator, "num_tokens_from_string", lambda text: len(text)) + monkeypatch.setattr(generator, "encoder", _CharEncoder()) + + messages = [ + {"role": "system", "content": "s" * 41}, + {"role": "user", "content": "abcdefghij"}, + ] + + used_tokens, trimmed = generator.message_fit_in(messages, max_length=8) + + assert used_tokens == 8 + assert trimmed[0]["content"] == "" + assert trimmed[-1]["content"] == "abcdefgh" diff --git a/test/unit_test/rag/svr/__init__.py b/test/unit_test/rag/svr/__init__.py new file mode 100644 index 00000000000..895bd9cee4c --- /dev/null +++ b/test/unit_test/rag/svr/__init__.py @@ -0,0 +1 @@ +# Unit tests for rag/svr diff --git a/test/unit_test/rag/svr/test_table_column_roles_helpers.py b/test/unit_test/rag/svr/test_table_column_roles_helpers.py new file mode 100644 index 00000000000..fe4eed27fe9 --- /dev/null +++ b/test/unit_test/rag/svr/test_table_column_roles_helpers.py @@ -0,0 +1,132 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for ES table metadata helpers (rag.utils.table_es_metadata).""" + +from rag.utils.table_es_metadata import ( + _es_field_value_to_doc_metadata, + _es_raw_field_key_from_typed, + _probe_es_typed_key_for_column, + _resolve_es_chunk_field_key, + merge_table_parser_config_from_kb, + table_parser_strip_doc_metadata_keys, +) + + +class TestProbeEsTypedKeyForColumn: + def test_probe_es_typed_key_tks(self): + chunk = {"country_tks": "tok", "other": 1} + assert _probe_es_typed_key_for_column("country", chunk) == "country_tks" + + def test_probe_es_typed_key_dt(self): + chunk = {"published_date_dt": "2024-01-01"} + assert _probe_es_typed_key_for_column("published_date", chunk) == "published_date_dt" + + def test_probe_es_typed_key_raw(self): + # Only raw field present (no _tks) — probe returns the raw key + chunk = {"country_raw": "Brazil"} + assert _probe_es_typed_key_for_column("country", chunk) == "country_raw" + + def test_probe_es_typed_key_no_match(self): + chunk = {"other_kwd": "x"} + assert _probe_es_typed_key_for_column("country", chunk) is None + + def test_probe_es_typed_key_empty_col(self): + assert _probe_es_typed_key_for_column("", {"a_tks": "x"}) is None + assert _probe_es_typed_key_for_column(None, {"a_tks": "x"}) is None + + +class TestResolveEsChunkFieldKey: + def test_resolve_es_field_empty_fieldmap_uses_probe(self): + sample = {"country_tks": ["tok"]} + tk, src = _resolve_es_chunk_field_key("country", {}, sample) + assert tk == "country_tks" + assert src == "probe" + + def test_resolve_es_field_fieldmap_priority(self): + fm = {"guojia_tks": "country"} + sample = {"guojia_tks": ["x"], "country_tks": ["y"]} + tk, src = _resolve_es_chunk_field_key("country", fm, sample) + assert tk == "guojia_tks" + assert src == "field_map" + + +class TestEsRawFieldKeyFromTyped: + def test_es_raw_field_key_from_tks(self): + assert _es_raw_field_key_from_typed("country_tks") == "country_raw" + + def test_es_raw_field_key_from_non_tks(self): + assert _es_raw_field_key_from_typed("country_dt") is None + + def test_es_raw_field_key_from_none(self): + assert _es_raw_field_key_from_typed(None) is None + + +class TestEsFieldValueToDocMetadata: + def test_es_field_value_string(self): + assert _es_field_value_to_doc_metadata("Brazil", from_tks_fallback=False) == "Brazil" + + def test_es_field_value_list_joined(self): + assert ( + _es_field_value_to_doc_metadata(["hello", "world"], from_tks_fallback=True) + == "hello world" + ) + + def test_es_field_value_empty(self): + assert _es_field_value_to_doc_metadata(None, from_tks_fallback=True) is None + assert _es_field_value_to_doc_metadata("", from_tks_fallback=True) is None + assert _es_field_value_to_doc_metadata([], from_tks_fallback=True) is None + + +class TestMergeTableParserConfigFromKb: + def test_merge_table_parser_config_from_kb(self): + task = { + "parser_id": "table", + "parser_config": {"llm_id": "x"}, + "kb_parser_config": { + "table_column_mode": "manual", + "table_column_roles": {"a": "metadata"}, + "table_column_names": ["a", "b"], + }, + } + merged = merge_table_parser_config_from_kb(task) + assert merged["table_column_mode"] == "manual" + assert merged["table_column_roles"] == {"a": "metadata"} + assert merged["table_column_names"] == ["a", "b"] + assert merged["llm_id"] == "x" + + def test_merge_table_parser_config_auto_default(self): + task = { + "parser_id": "table", + "parser_config": {"foo": 1}, + "kb_parser_config": {"llm_id": "abc"}, + } + merged = merge_table_parser_config_from_kb(task) + assert merged == {"foo": 1} # no table_* keys copied from kb without kb_parser_config keys + + +class TestTableParserStripDocMetadataKeys: + def test_uses_table_column_names_when_present(self): + eff = {"table_column_names": ["Region", " SKU "]} + assert table_parser_strip_doc_metadata_keys(eff) == frozenset({"Region", "SKU"}) + + def test_falls_back_to_role_keys_when_no_names(self): + eff = {"table_column_roles": {"x": "metadata", "y": "indexing"}} + assert table_parser_strip_doc_metadata_keys(eff) == frozenset({"x", "y"}) + + def test_empty_names_falls_back_to_roles(self): + eff = {"table_column_names": [], "table_column_roles": {"only": "both"}} + assert table_parser_strip_doc_metadata_keys(eff) == frozenset({"only"}) diff --git a/test/unit_test/rag/svr/test_table_metadata_aggregation.py b/test/unit_test/rag/svr/test_table_metadata_aggregation.py new file mode 100644 index 00000000000..59d2f7ee472 --- /dev/null +++ b/test/unit_test/rag/svr/test_table_metadata_aggregation.py @@ -0,0 +1,230 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Unit tests for aggregate_table_manual_doc_metadata.""" + +import pytest + +from rag.utils.table_es_metadata import aggregate_table_manual_doc_metadata, merge_table_parser_config_from_kb + + +@pytest.fixture +def es_engine(monkeypatch): + monkeypatch.setattr("rag.utils.table_es_metadata.settings.DOC_ENGINE_INFINITY", False) + monkeypatch.setattr("rag.utils.table_es_metadata.settings.DOC_ENGINE_OCEANBASE", False) + + +@pytest.fixture +def infinity_engine(monkeypatch): + monkeypatch.setattr("rag.utils.table_es_metadata.settings.DOC_ENGINE_INFINITY", True) + monkeypatch.setattr("rag.utils.table_es_metadata.settings.DOC_ENGINE_OCEANBASE", False) + + +def _table_task(**kb_extra): + return { + "parser_id": "table", + "parser_config": {}, + "kb_parser_config": { + "table_column_mode": "manual", + "table_column_roles": {"country": "metadata", "category": "metadata"}, + "table_column_names": ["country", "category"], + "field_map": { + "country_tks": "country", + "category_tks": "category", + }, + **kb_extra, + }, + } + + +class TestAggregateTableManualDocMetadata: + def test_aggregate_manual_mode_happy_path(self, es_engine): + task = _table_task() + chunks = [ + { + "country_raw": "Brazil", + "category_raw": "Economy", + "country_tks": "x", + "category_tks": "y", + }, + { + "country_raw": "Turkey", + "category_raw": "Disaster", + "country_tks": "x", + "category_tks": "y", + }, + { + "country_raw": "Brazil", + "category_raw": "Economy", + "country_tks": "x", + "category_tks": "y", + }, + ] + out = aggregate_table_manual_doc_metadata(chunks, task) + assert out["country"] == ["Brazil", "Turkey"] + assert out["category"] == ["Economy", "Disaster"] + + def test_aggregate_auto_mode_returns_empty(self, es_engine): + task = { + "parser_id": "table", + "parser_config": {}, + "kb_parser_config": { + "table_column_mode": "auto", + "table_column_roles": {"country": "metadata"}, + }, + } + assert aggregate_table_manual_doc_metadata([{"country_tks": "x"}], task) == {} + + def test_aggregate_no_mode_returns_empty(self, es_engine): + task = { + "parser_id": "table", + "parser_config": {}, + "kb_parser_config": { + "table_column_roles": {"country": "metadata"}, + }, + } + assert aggregate_table_manual_doc_metadata([{}], task) == {} + + def test_aggregate_no_metadata_columns(self, es_engine): + task = { + "parser_id": "table", + "parser_config": {}, + "kb_parser_config": { + "table_column_mode": "manual", + "table_column_roles": {"country": "indexing"}, + "table_column_names": ["country"], + }, + } + assert aggregate_table_manual_doc_metadata([{"country_tks": "x"}], task) == {} + + def test_aggregate_prefers_raw_over_tks(self, es_engine): + task = _table_task() + task["kb_parser_config"]["table_column_roles"] = {"country": "metadata"} + task["kb_parser_config"]["table_column_names"] = ["country"] + chunks = [{"country_raw": "Brazil", "country_tks": ["brazil"]}] + out = aggregate_table_manual_doc_metadata(chunks, task) + assert out == {"country": ["Brazil"]} + + def test_aggregate_tks_fallback(self, es_engine): + task = _table_task() + task["kb_parser_config"]["table_column_roles"] = {"country": "metadata"} + task["kb_parser_config"]["table_column_names"] = ["country"] + chunks = [{"country_tks": ["brazil"]}] + out = aggregate_table_manual_doc_metadata(chunks, task) + assert out == {"country": ["brazil"]} + + def test_aggregate_partial_roles_defaults_to_both(self, es_engine): + task = { + "parser_id": "table", + "parser_config": {}, + "kb_parser_config": { + "table_column_mode": "manual", + "table_column_roles": {"country": "indexing"}, + "table_column_names": ["country", "city"], + "field_map": {"city_tks": "city"}, + }, + } + chunks = [{"city_raw": "SP", "city_tks": "t", "country_tks": "x"}] + out = aggregate_table_manual_doc_metadata(chunks, task) + assert out == {"city": ["SP"]} + assert "country" not in out + + def test_aggregate_empty_roles_all_columns_both(self, es_engine): + task = { + "parser_id": "table", + "parser_config": {}, + "kb_parser_config": { + "table_column_mode": "manual", + "table_column_roles": {}, + "table_column_names": ["country", "city"], + "field_map": {"country_tks": "country", "city_tks": "city"}, + }, + } + chunks = [ + {"country_raw": "BR", "city_raw": "SP", "country_tks": "x", "city_tks": "y"}, + ] + out = aggregate_table_manual_doc_metadata(chunks, task) + assert "country" in out and "city" in out + + def test_aggregate_deduplicates_values(self, es_engine): + task = _table_task() + task["kb_parser_config"]["table_column_roles"] = {"country": "metadata"} + task["kb_parser_config"]["table_column_names"] = ["country"] + chunks = [ + {"country_raw": "US", "country_tks": "x"}, + {"country_raw": "UK", "country_tks": "y"}, + {"country_raw": "US", "country_tks": "x"}, + ] + out = aggregate_table_manual_doc_metadata(chunks, task) + assert out["country"] == ["US", "UK"] + + def test_aggregate_kb_reload_field_map(self, es_engine, monkeypatch): + from unittest.mock import MagicMock + + class MockKBS: + @staticmethod + def get_by_id(kid): + kb = MagicMock() + kb.parser_config = {"field_map": {"country_tks": "country"}} + return True, kb + + monkeypatch.setattr( + "rag.utils.table_es_metadata._knowledgebase_service_cls", + lambda: MockKBS, + ) + + task = { + "parser_id": "table", + "parser_config": {}, + "kb_parser_config": { + "table_column_mode": "manual", + "table_column_roles": {"country": "metadata"}, + "table_column_names": ["country"], + }, + "kb_id": "kb-1", + } + chunks = [{"country_raw": "X", "country_tks": "t"}] + out = aggregate_table_manual_doc_metadata(chunks, task) + assert out == {"country": ["X"]} + + def test_merge_infinity_chunk_data(self, infinity_engine): + task = { + "parser_id": "table", + "parser_config": {}, + "kb_parser_config": { + "table_column_mode": "manual", + "table_column_roles": {"country": "both"}, + "table_column_names": ["country"], + }, + } + chunks = [ + {"chunk_data": {"country": "US"}}, + {"chunk_data": {"country": "UK"}}, + ] + out = aggregate_table_manual_doc_metadata(chunks, task) + assert out == {"country": ["US", "UK"]} + + +class TestMergeTableParserConfigFromKbExtra: + """Merge tests also covered in helpers file; keep one explicit case for aggregation module.""" + + def test_merge_preserves_parser_config_when_parser_not_table(self): + task = { + "parser_id": "naive", + "parser_config": {"a": 1}, + "kb_parser_config": {"table_column_mode": "manual"}, + } + assert merge_table_parser_config_from_kb(task) == {"a": 1} diff --git a/test/unit_test/rag/test_raptor_psi_tree_builder.py b/test/unit_test/rag/test_raptor_psi_tree_builder.py new file mode 100644 index 00000000000..1d0af20d960 --- /dev/null +++ b/test/unit_test/rag/test_raptor_psi_tree_builder.py @@ -0,0 +1,375 @@ +# +# Copyright 2026 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import importlib +import sys +import types + +import pytest + +np = pytest.importorskip("numpy") + +from api.utils.validation_utils import RaptorConfig +from pydantic import ValidationError + + +@pytest.fixture() +def raptor_module(monkeypatch): + class TaskCanceledException(Exception): + pass + + class DummyLimiter: + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return False + + class DummyGaussianMixture: + def __init__(self, *args, **kwargs): + pass + + def fit(self, embeddings): + return self + + def bic(self, embeddings): + return 0 + + def predict_proba(self, embeddings): + return np.ones((len(embeddings), 1)) + + class DummyAgglomerativeClustering: + def __init__(self, n_clusters=None, distance_threshold=None, compute_distances=False, linkage="ward"): + self.n_clusters = n_clusters + self.distance_threshold = distance_threshold + self.compute_distances = compute_distances + self.linkage = linkage + self.distances_ = np.array([0.1, 0.2, 1.0]) + + def fit(self, embeddings): + self.labels_ = self.fit_predict(embeddings) + return self + + def fit_predict(self, embeddings): + if self.n_clusters is None: + return np.zeros(len(embeddings), dtype=int) + return np.array([idx % self.n_clusters for idx in range(len(embeddings))]) + + class DummyUMAP: + def __init__(self, *args, **kwargs): + pass + + def fit_transform(self, embeddings): + raise AssertionError("Psi tree builder must use original embeddings, not UMAP") + + sklearn_module = types.ModuleType("sklearn") + mixture_module = types.ModuleType("sklearn.mixture") + mixture_module.GaussianMixture = DummyGaussianMixture + cluster_module = types.ModuleType("sklearn.cluster") + cluster_module.AgglomerativeClustering = DummyAgglomerativeClustering + umap_module = types.ModuleType("umap") + umap_module.UMAP = DummyUMAP + task_service_module = types.ModuleType("api.db.services.task_service") + task_service_module.has_canceled = lambda task_id: False + connection_utils_module = types.ModuleType("common.connection_utils") + connection_utils_module.timeout = lambda seconds: lambda fn: fn + exceptions_module = types.ModuleType("common.exceptions") + exceptions_module.TaskCanceledException = TaskCanceledException + token_utils_module = types.ModuleType("common.token_utils") + token_utils_module.truncate = lambda text, max_len: text[:max_len] + graphrag_utils_module = types.ModuleType("rag.graphrag.utils") + graphrag_utils_module.chat_limiter = DummyLimiter() + graphrag_utils_module.get_embed_cache = lambda *args, **kwargs: None + graphrag_utils_module.get_llm_cache = lambda *args, **kwargs: None + graphrag_utils_module.set_embed_cache = lambda *args, **kwargs: None + graphrag_utils_module.set_llm_cache = lambda *args, **kwargs: None + + async def thread_pool_exec(fn, *args, **kwargs): + return fn(*args, **kwargs) + + misc_utils_module = types.ModuleType("common.misc_utils") + misc_utils_module.thread_pool_exec = thread_pool_exec + + monkeypatch.setitem(sys.modules, "sklearn", sklearn_module) + monkeypatch.setitem(sys.modules, "sklearn.mixture", mixture_module) + monkeypatch.setitem(sys.modules, "sklearn.cluster", cluster_module) + monkeypatch.setitem(sys.modules, "umap", umap_module) + monkeypatch.setitem(sys.modules, "api.db.services.task_service", task_service_module) + monkeypatch.setitem(sys.modules, "common.connection_utils", connection_utils_module) + monkeypatch.setitem(sys.modules, "common.exceptions", exceptions_module) + monkeypatch.setitem(sys.modules, "common.token_utils", token_utils_module) + monkeypatch.setitem(sys.modules, "rag.graphrag.utils", graphrag_utils_module) + monkeypatch.setitem(sys.modules, "common.misc_utils", misc_utils_module) + monkeypatch.delitem(sys.modules, "rag.raptor", raising=False) + module = importlib.import_module("rag.raptor") + yield module + monkeypatch.delitem(sys.modules, "rag.raptor", raising=False) + + +class FakeChatModel: + llm_name = "fake-chat" + max_length = 4096 + + def __init__(self): + self.calls = [] + + async def async_chat(self, system, history, gen_conf): + self.calls.append(history[0]["content"]) + return f"summary-{len(self.calls)}" + + +class FakeEmbeddingModel: + llm_name = "fake-embedding" + + def encode(self, texts): + embeddings = [] + for text in texts: + checksum = sum(ord(ch) for ch in text) + embeddings.append(np.array([len(text), checksum % 17 + 1], dtype=float)) + return embeddings, len(texts) + + +_DEFAULT_TREE_BUILDER = object() + + +def _make_raptor(raptor_module, max_cluster=64, tree_builder=_DEFAULT_TREE_BUILDER, **kwargs): + if tree_builder is _DEFAULT_TREE_BUILDER: + kwargs["tree_builder"] = raptor_module.PSI_TREE_BUILDER + else: + kwargs["tree_builder"] = tree_builder + return raptor_module.RecursiveAbstractiveProcessing4TreeOrganizedRetrieval( + max_cluster, + FakeChatModel(), + FakeEmbeddingModel(), + "{cluster_content}", + max_token=32, + threshold=0.1, + **kwargs, + ) + + +def _chunks(): + return [ + ("alpha first", np.array([1.0, 0.0])), + ("alpha second", np.array([0.99, 0.01])), + ("alpha third", np.array([0.98, 0.02])), + ] + + +def test_default_tree_builder_remains_original_raptor(raptor_module): + raptor = _make_raptor(raptor_module, tree_builder=None) + + assert raptor._tree_builder == raptor_module.RAPTOR_TREE_BUILDER + + +def test_unknown_tree_builder_is_rejected(raptor_module): + with pytest.raises(ValueError, match="Unsupported RAPTOR tree builder"): + _make_raptor(raptor_module, tree_builder="ahc") + + +def test_raptor_config_accepts_hidden_psi_tree_builder(): + assert RaptorConfig().tree_builder == "raptor" + assert RaptorConfig().clustering_method == "gmm" + assert RaptorConfig(clustering_method="ahc").clustering_method == "ahc" + assert RaptorConfig(tree_builder="psi").tree_builder == "psi" + + with pytest.raises(ValidationError): + RaptorConfig(tree_builder="ahc") + with pytest.raises(ValidationError): + RaptorConfig(clustering_method="psi") + + +def test_ahc_clustering_method_is_supported_in_original_tree_builder(raptor_module): + raptor = _make_raptor(raptor_module, tree_builder=raptor_module.RAPTOR_TREE_BUILDER, clustering_method="ahc") + + labels = raptor._get_clusters_ahc(np.array([[0.0, 0.0], [0.1, 0.0], [10.0, 10.0], [10.1, 10.0]])) + + assert raptor._tree_builder == raptor_module.RAPTOR_TREE_BUILDER + assert raptor._clustering_method == "ahc" + assert len(labels) == 4 + + +def test_unknown_clustering_method_is_rejected(raptor_module): + with pytest.raises(ValueError, match="Unsupported RAPTOR clustering method"): + _make_raptor(raptor_module, clustering_method="psi") + + +def test_psi_tree_builder_ranks_all_leaf_pairs_by_original_cosine_similarity(raptor_module): + raptor = _make_raptor(raptor_module) + leaves = [ + raptor_module._PsiTreeNode(index=0, embedding=np.array([1.0, 0.0])), + raptor_module._PsiTreeNode(index=1, embedding=np.array([0.0, 1.0])), + raptor_module._PsiTreeNode(index=2, embedding=np.array([0.99, 0.01])), + raptor_module._PsiTreeNode(index=3, embedding=np.array([-1.0, 0.0])), + ] + + ranked_pairs = raptor._rank_leaf_pairs(leaves) + + assert len(ranked_pairs) == 6 + assert tuple(ranked_pairs[0]) == (2, 0) + + +def test_psi_tree_builder_uses_cosine_similarity_not_vector_magnitude(raptor_module): + raptor = _make_raptor(raptor_module) + leaves = [ + raptor_module._PsiTreeNode(index=0, embedding=np.array([100.0, 0.0])), + raptor_module._PsiTreeNode(index=1, embedding=np.array([1.0, 1.0])), + raptor_module._PsiTreeNode(index=2, embedding=np.array([0.1, 0.0])), + ] + + ranked_pairs = raptor._rank_leaf_pairs(leaves) + + assert tuple(ranked_pairs[0]) == (2, 0) + + +def test_psi_tree_builder_handles_zero_vectors_in_cosine_ranking(raptor_module): + raptor = _make_raptor(raptor_module) + leaves = [ + raptor_module._PsiTreeNode(index=0, embedding=np.array([0.0, 0.0])), + raptor_module._PsiTreeNode(index=1, embedding=np.array([1.0, 0.0])), + raptor_module._PsiTreeNode(index=2, embedding=np.array([0.9, 0.1])), + ] + + ranked_pairs = raptor._rank_leaf_pairs(leaves) + + assert tuple(ranked_pairs[0]) == (2, 1) + + +def test_psi_tree_builder_collapses_leaf_into_ranked_pair_parent(raptor_module): + raptor = _make_raptor(raptor_module, max_cluster=64) + + root, leaves = raptor._build_psi_structure(_chunks()) + + assert len(root.children) == 3 + assert {child.index for child in root.children} == {0, 1, 2} + assert all(leaf.parent is root for leaf in leaves) + + +def test_psi_tree_builder_collapses_leaf_at_matching_rank(monkeypatch, raptor_module): + raptor = _make_raptor(raptor_module, max_cluster=64) + chunks = [ + ("node 0", np.array([1.0, 0.0])), + ("node 1", np.array([0.9, 0.1])), + ("node 2", np.array([-1.0, 0.0])), + ("node 3", np.array([-0.9, -0.1])), + ("node 4", np.array([0.8, 0.2])), + ] + monkeypatch.setattr( + raptor, + "_rank_leaf_pairs", + lambda _leaves: np.array([[0, 1], [2, 3], [0, 2], [4, 0]]), + ) + + root, leaves = raptor._build_psi_structure(chunks) + + assert leaves[4].parent is leaves[0].parent + assert leaves[4].parent is not root + assert len(root.children) == 2 + + +def test_psi_union_find_clamps_out_of_bounds_parent_rank(caplog, raptor_module): + union_find = raptor_module._PsiUnionFind(2) + union_find._node_ids[1] = [1] + union_find._rank[0] = 2 + + with caplog.at_level("WARNING"): + union_find._build(0, 1, insert_point=1) + + assert union_find.tree[0] == 1 + assert "rank index" in caplog.text + + +def test_psi_tree_builder_rebalances_nodes_over_max_children(raptor_module): + raptor = _make_raptor(raptor_module, max_cluster=2) + + root, _ = raptor._build_psi_structure(_chunks()) + + assert all(len(node.children) <= 2 for node in raptor._iter_nodes(root)) + assert len(root.children) == 2 + assert any(child.children for child in root.children) + + +def test_psi_tree_builder_uses_bucketed_structure_for_large_inputs(monkeypatch, raptor_module): + chunks = [(f"node {idx}", np.array([float(idx), float(idx % 3 + 1)])) for idx in range(8)] + raptor = _make_raptor( + raptor_module, + max_cluster=3, + psi_exact_max_leaves=3, + psi_bucket_size=2, + ) + ranked_sizes = [] + original_rank = raptor._rank_leaf_pairs + + def track_rank(nodes): + ranked_sizes.append(len(nodes)) + return original_rank(nodes) + + monkeypatch.setattr(raptor, "_rank_leaf_pairs", track_rank) + + root, leaves = raptor._build_psi_structure(chunks) + + assert len(leaves) == len(chunks) + assert all(leaf.parent is not None for leaf in leaves) + assert all(len(node.children) <= 3 for node in raptor._iter_nodes(root)) + assert max(ranked_sizes) <= 3 + + +@pytest.mark.asyncio +async def test_psi_tree_builder_materializes_rebalanced_summary_layers_without_umap(monkeypatch, raptor_module): + def fail_umap(*args, **kwargs): + raise AssertionError("Psi tree builder must use original embeddings, not UMAP") + + monkeypatch.setattr(raptor_module.umap, "UMAP", fail_umap) + raptor = _make_raptor(raptor_module, max_cluster=2) + + chunks, layers = await raptor(_chunks(), random_state=0) + + assert len(chunks) == 5 + assert layers == [(0, 3), (3, 4), (4, 5)] + assert [chunk[0] for chunk in chunks[3:]] == ["summary-1", "summary-2"] + + +@pytest.mark.asyncio +async def test_psi_tree_builder_skips_failed_node_summary(monkeypatch, raptor_module): + raptor = _make_raptor(raptor_module, max_cluster=2) + + async def fail_summary(*args, **kwargs): + return None + + monkeypatch.setattr(raptor, "_summarize_texts", fail_summary) + + chunks, layers = await raptor(_chunks(), random_state=0) + + assert len(chunks) == 3 + assert [chunk[0] for chunk in chunks] == [chunk[0] for chunk in _chunks()] + assert layers == [(0, 3)] + + +@pytest.mark.asyncio +async def test_original_raptor_stops_when_transient_summary_fails(monkeypatch, raptor_module): + raptor = _make_raptor(raptor_module, tree_builder=raptor_module.RAPTOR_TREE_BUILDER) + + async def fail_summary(*args, **kwargs): + return None + + monkeypatch.setattr(raptor, "_summarize_texts", fail_summary) + + input_chunks = _chunks()[:2] + chunks, layers = await raptor(input_chunks, random_state=0) + + assert len(chunks) == 2 + assert [chunk[0] for chunk in chunks] == [chunk[0] for chunk in input_chunks] + assert layers == [(0, 2)] diff --git a/test/unit_test/rag/test_sync_data_source.py b/test/unit_test/rag/test_sync_data_source.py index be9d89372a3..8bb5e4cd437 100644 --- a/test/unit_test/rag/test_sync_data_source.py +++ b/test/unit_test/rag/test_sync_data_source.py @@ -133,7 +133,53 @@ def _patch_common_dependencies(monkeypatch): @pytest.mark.anyio @pytest.mark.p2 -async def test_run_task_logic_cleans_up_for_empty_snapshot(monkeypatch): +async def test_run_task_logic_skips_empty_sync_batches(monkeypatch): + _patch_common_dependencies(monkeypatch) + monkeypatch.setattr( + sync_data_source.SyncLogsService, + "increase_docs", + lambda *_args, **_kwargs: pytest.fail("increase_docs should not be called for empty batches"), + ) + monkeypatch.setattr( + sync_data_source.KnowledgebaseService, + "get_by_id", + lambda *_args, **_kwargs: pytest.fail("get_by_id should not be called for empty batches"), + ) + monkeypatch.setattr( + sync_data_source.SyncLogsService, + "duplicate_and_parse", + lambda *_args, **_kwargs: pytest.fail("duplicate_and_parse should not be called for empty batches"), + ) + + await _FakeSync(iter(([],)))._run_task_logic(_make_task()) + + +@pytest.mark.anyio +@pytest.mark.p2 +async def test_run_task_logic_skips_multiple_empty_sync_batches(monkeypatch): + _patch_common_dependencies(monkeypatch) + monkeypatch.setattr( + sync_data_source.SyncLogsService, + "increase_docs", + lambda *_args, **_kwargs: pytest.fail("increase_docs should not be called for empty batches"), + ) + monkeypatch.setattr( + sync_data_source.KnowledgebaseService, + "get_by_id", + lambda *_args, **_kwargs: pytest.fail("get_by_id should not be called for empty batches"), + ) + monkeypatch.setattr( + sync_data_source.SyncLogsService, + "duplicate_and_parse", + lambda *_args, **_kwargs: pytest.fail("duplicate_and_parse should not be called for empty batches"), + ) + + await _FakeSync(iter(([], [],)))._run_task_logic(_make_task()) + + +@pytest.mark.anyio +@pytest.mark.p2 +async def test_run_prune_task_logic_cleans_up_for_empty_snapshot(monkeypatch): cleanup_calls = [] _patch_common_dependencies(monkeypatch) @@ -148,7 +194,14 @@ def _fake_cleanup(*args, **kwargs): _fake_cleanup, ) - await _FakeSync((iter(()), []))._run_task_logic(_make_task()) + task = {**_make_task(), "task_type": sync_data_source.ConnectorTaskType.PRUNE} + sync = _FakeSync(iter(())) + sync.conf["sync_deleted_files"] = True + sync.connector = types.SimpleNamespace( + retrieve_all_slim_docs_perm_sync=lambda: iter(([],)) + ) + + await sync._run_task_logic(task) assert cleanup_calls == [ ( @@ -166,7 +219,7 @@ def _fake_cleanup(*args, **kwargs): @pytest.mark.anyio @pytest.mark.p2 -async def test_run_task_logic_cleans_up_for_non_empty_snapshot(monkeypatch): +async def test_run_prune_task_logic_cleans_up_for_non_empty_snapshot(monkeypatch): cleanup_calls = [] _patch_common_dependencies(monkeypatch) @@ -182,7 +235,14 @@ def _fake_cleanup(*args, **kwargs): ) file_list = [types.SimpleNamespace(id="doc-1")] - await _FakeSync((iter(()), file_list))._run_task_logic(_make_task()) + task = {**_make_task(), "task_type": sync_data_source.ConnectorTaskType.PRUNE} + sync = _FakeSync(iter(())) + sync.conf["sync_deleted_files"] = True + sync.connector = types.SimpleNamespace( + retrieve_all_slim_docs_perm_sync=lambda: iter((file_list,)) + ) + + await sync._run_task_logic(task) assert cleanup_calls == [ ( @@ -285,12 +345,13 @@ async def test_rdbms_generate_keeps_deleted_file_snapshot_without_timestamp_colu } ) - document_generator, file_list = await sync._generate(task) + document_generator = await sync._generate(task) connector = _FakeRDBMSConnector.instance assert connector is not None assert connector.load_from_state_called is True assert connector.load_from_cursor_range_called is False + file_list = sync._collect_prune_snapshot(task) assert connector.retrieve_all_slim_docs_perm_sync_called is True assert file_list is not None assert [doc.id for doc in file_list] == ["row-1"] @@ -447,14 +508,15 @@ async def test_dropbox_generate_returns_snapshot_when_sync_deleted_enabled(monke } ) - document_generator, file_list = await sync._generate(task) + document_generator = await sync._generate(task) connector = _FakeDropboxConnector.instance assert list(document_generator) == [["poll-sync"]] + file_list = sync._collect_prune_snapshot(task) assert [doc.id for doc in file_list] == ["dropbox:id-1", "dropbox:id-2"] assert connector.credentials == {"dropbox_access_token": "token-1"} assert connector.retrieve_all_slim_docs_perm_sync_called is True - assert connector.snapshot_called_before_poll is True + assert connector.snapshot_called_before_poll is False assert connector.poll_source_call[0] == poll_start.timestamp() assert connector.poll_source_call[1] >= poll_start.timestamp() @@ -477,11 +539,12 @@ async def test_dropbox_generate_skips_snapshot_for_full_reindex(monkeypatch): } ) - document_generator, file_list = await sync._generate(task) + document_generator = await sync._generate(task) connector = _FakeDropboxConnector.instance assert list(document_generator) == [["full-sync"]] - assert file_list is None assert connector.load_from_state_called is True - assert connector.retrieve_all_slim_docs_perm_sync_called is False + file_list = sync._collect_prune_snapshot(task) + assert [doc.id for doc in file_list] == ["dropbox:id-1", "dropbox:id-2"] + assert connector.retrieve_all_slim_docs_perm_sync_called is True assert connector.poll_source_called is False diff --git a/test/unit_test/rag/utils/test_opensearch_doc_meta.py b/test/unit_test/rag/utils/test_opensearch_doc_meta.py new file mode 100644 index 00000000000..ead97f6f8be --- /dev/null +++ b/test/unit_test/rag/utils/test_opensearch_doc_meta.py @@ -0,0 +1,288 @@ +# +# Copyright 2025 The InfiniFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +""" +Unit tests for the document-metadata helpers added to OSConnection. + +Covers issue #14570: PATCH /api/v1/datasets/{ds}/documents/{doc} with +{"meta_fields": {...}} previously raised +``'OSConnection' object has no attribute 'create_doc_meta_idx'`` when the +backend was OpenSearch. These tests pin the new dispatch surface so the same +regression cannot return: every helper that DocMetadataService dispatches to +on the ES path must exist on OSConnection too, with semantically equivalent +behaviour. + +The OpenSearch and Elasticsearch SDKs are imported at module load; mocking +the underlying client lets us exercise OSConnection methods in isolation +without a live cluster. +""" +from __future__ import annotations + +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + + +# Importing OSConnection touches opensearchpy at module load, so guard for +# environments where the package isn't installed. +opensearchpy = pytest.importorskip("opensearchpy") + + +def _install_module(name: str, **attrs) -> types.ModuleType: + mod = sys.modules.get(name) + if mod is None: + mod = types.ModuleType(name) + sys.modules[name] = mod + for key, value in attrs.items(): + if not hasattr(mod, key): + setattr(mod, key, value) + return mod + + +def _install_module_stubs() -> None: + """Bypass heavy optional backends for connection-only tests. + + ``rag.utils.opensearch_conn`` imports ``common.settings`` and ``rag.nlp`` + at module load. ``common.settings`` in turn pulls every storage backend + (Infinity, OceanBase, Azure, MinIO, GCS …), which is more surface than + these connection-only tests need. We replace just the modules opensearch_conn + captures so the real ``OSConnection`` class loads. + """ + _install_module( + "common.settings", + OS={"hosts": "stub", "username": "u", "password": "p"}, + ES={}, + DOC_ENGINE_INFINITY=False, + DOC_ENGINE_OCEANBASE=False, + DOC_ENGINE="opensearch", + docStoreConn=None, + ) + _install_module( + "rag.nlp", + is_english=lambda *_args, **_kwargs: False, + rag_tokenizer=MagicMock(), + ) + + +_install_module_stubs() + + +class _FakeFile: + """Minimal file-like stand-in supporting ``json.load``.""" + + def __init__(self, content: str) -> None: + self._content = content + + def read(self, *_args, **_kwargs) -> str: + return self._content + + +def _open_returning_payload(payload: dict): + """Build a context-manager mock for ``open`` that yields the JSON payload.""" + import json as _json + + fake_handle = MagicMock() + fake_handle.__enter__ = MagicMock(return_value=_FakeFile(_json.dumps(payload))) + fake_handle.__exit__ = MagicMock(return_value=False) + return MagicMock(return_value=fake_handle) + + +def _resolve_os_connection_class(): + """Return the real OSConnection class. + + ``@singleton`` from ``common.decorator`` wraps the class with a closure + that returns the cached instance on call. ``OSConnection`` at module + scope is therefore a function, not a type. We unwrap it to recover the + underlying class so we can call ``__new__`` directly without going through + ``__init__`` (which would attempt a real OpenSearch handshake). + """ + from rag.utils import opensearch_conn + + candidate = opensearch_conn.OSConnection + if isinstance(candidate, type): + return candidate + closure = getattr(candidate, "__closure__", None) or () + for cell in closure: + contents = cell.cell_contents + if isinstance(contents, type): + return contents + raise RuntimeError("Could not locate the OSConnection class in module scope") + + +def _make_os_connection(): + """Build an OSConnection without invoking its real network-dependent __init__.""" + cls = _resolve_os_connection_class() + instance = cls.__new__(cls) + instance.os = MagicMock() + instance.info = {"version": {"number": "2.18.0"}} + instance.mapping = {"settings": {}, "mappings": {}} + return instance + + +class TestOSConnectionMetaSurface: + """The OSConnection class must expose the dispatch surface + DocMetadataService relies on.""" + + def test_create_doc_meta_idx_exists(self): + cls = _resolve_os_connection_class() + assert callable(getattr(cls, "create_doc_meta_idx", None)), ( + "OSConnection.create_doc_meta_idx is required so the metadata " + "PATCH path does not raise AttributeError on OpenSearch backends " + "(issue #14570)." + ) + + def test_refresh_idx_exists(self): + cls = _resolve_os_connection_class() + assert callable(getattr(cls, "refresh_idx", None)) + + def test_count_idx_exists(self): + cls = _resolve_os_connection_class() + assert callable(getattr(cls, "count_idx", None)) + + def test_replace_meta_fields_exists(self): + cls = _resolve_os_connection_class() + assert callable(getattr(cls, "replace_meta_fields", None)) + + +class TestCreateDocMetaIdx: + """Behavioural tests for OSConnection.create_doc_meta_idx.""" + + def test_returns_true_when_index_already_exists(self): + conn = _make_os_connection() + with patch.object(_resolve_os_connection_class(), "index_exist", return_value=True) as exist: + assert conn.create_doc_meta_idx("ragflow_doc_meta_t1") is True + exist.assert_called_once_with("ragflow_doc_meta_t1", "") + + def test_creates_index_with_doc_meta_mapping(self): + conn = _make_os_connection() + fake_indices = MagicMock() + fake_indices.create.return_value = {"acknowledged": True} + cls = _resolve_os_connection_class() + + with patch.object(cls, "index_exist", return_value=False), \ + patch("rag.utils.opensearch_conn.os.path.exists", return_value=True), \ + patch( + "rag.utils.opensearch_conn.open", + new=_open_returning_payload({ + "settings": {"index": {"number_of_shards": 2}}, + "mappings": {"properties": {"meta_fields": {"type": "object"}}}, + }), + create=True, + ), \ + patch("opensearchpy.client.IndicesClient", return_value=fake_indices): + result = conn.create_doc_meta_idx("ragflow_doc_meta_t1") + + assert result == {"acknowledged": True} + fake_indices.create.assert_called_once() + kwargs = fake_indices.create.call_args.kwargs + assert kwargs["index"] == "ragflow_doc_meta_t1" + body = kwargs["body"] + assert "settings" in body and "mappings" in body + assert body["mappings"]["properties"]["meta_fields"]["type"] == "object" + + def test_returns_false_when_mapping_file_missing(self): + conn = _make_os_connection() + cls = _resolve_os_connection_class() + with patch.object(cls, "index_exist", return_value=False), \ + patch("rag.utils.opensearch_conn.os.path.exists", return_value=False): + assert conn.create_doc_meta_idx("ragflow_doc_meta_t1") is False + + def test_returns_false_when_create_call_explodes(self): + """If the underlying IndicesClient.create raises, the helper must + swallow the exception and return False so the service layer can fall + back gracefully (mirrors ESConnectionBase.create_doc_meta_idx).""" + conn = _make_os_connection() + cls = _resolve_os_connection_class() + fake_indices = MagicMock() + fake_indices.create.side_effect = RuntimeError("opensearch unreachable") + + with patch.object(cls, "index_exist", return_value=False), \ + patch("rag.utils.opensearch_conn.os.path.exists", return_value=True), \ + patch( + "rag.utils.opensearch_conn.open", + new=_open_returning_payload({"settings": {}, "mappings": {}}), + create=True, + ), \ + patch("opensearchpy.client.IndicesClient", return_value=fake_indices): + assert conn.create_doc_meta_idx("ragflow_doc_meta_t1") is False + + +class TestRefreshIdx: + def test_calls_indices_refresh(self): + conn = _make_os_connection() + assert conn.refresh_idx("ragflow_doc_meta_t1") is True + conn.os.indices.refresh.assert_called_once_with(index="ragflow_doc_meta_t1") + + def test_returns_false_on_not_found(self): + conn = _make_os_connection() + conn.os.indices.refresh.side_effect = opensearchpy.NotFoundError( + 404, "index_not_found_exception", {} + ) + assert conn.refresh_idx("missing_idx") is False + + def test_swallows_other_errors_and_returns_false(self): + conn = _make_os_connection() + conn.os.indices.refresh.side_effect = RuntimeError("transient") + assert conn.refresh_idx("ragflow_doc_meta_t1") is False + + +class TestCountIdx: + def test_returns_count_value(self): + conn = _make_os_connection() + conn.os.count.return_value = {"count": 42} + assert conn.count_idx("ragflow_doc_meta_t1") == 42 + conn.os.count.assert_called_once_with(index="ragflow_doc_meta_t1") + + def test_missing_index_reads_as_zero(self): + conn = _make_os_connection() + conn.os.count.side_effect = opensearchpy.NotFoundError( + 404, "index_not_found_exception", {} + ) + assert conn.count_idx("ragflow_doc_meta_t1") == 0 + + def test_other_failure_returns_negative_one(self): + conn = _make_os_connection() + conn.os.count.side_effect = RuntimeError("bad") + assert conn.count_idx("ragflow_doc_meta_t1") == -1 + + +class TestReplaceMetaFields: + def test_emits_full_assignment_script(self): + conn = _make_os_connection() + conn.os.update.return_value = {"_id": "doc-1", "result": "updated"} + meta = {"author": "alice", "year": 2026} + + ok = conn.replace_meta_fields("ragflow_doc_meta_t1", "doc-1", meta) + + assert ok is True + conn.os.update.assert_called_once() + kwargs = conn.os.update.call_args.kwargs + assert kwargs["index"] == "ragflow_doc_meta_t1" + assert kwargs["id"] == "doc-1" + assert kwargs["refresh"] is True + body = kwargs["body"] + # The script must fully assign meta_fields, otherwise removed keys + # would persist via deep merge. + assert body["script"]["source"] == "ctx._source.meta_fields = params.meta_fields" + assert body["script"]["params"]["meta_fields"] == meta + + def test_returns_false_when_doc_missing(self): + conn = _make_os_connection() + conn.os.update.side_effect = opensearchpy.NotFoundError( + 404, "document_missing_exception", {} + ) + assert conn.replace_meta_fields("ragflow_doc_meta_t1", "absent", {"a": 1}) is False diff --git a/test/unit_test/rag/utils/test_raptor_utils.py b/test/unit_test/rag/utils/test_raptor_utils.py index 5138ccda7aa..95abe21097b 100644 --- a/test/unit_test/rag/utils/test_raptor_utils.py +++ b/test/unit_test/rag/utils/test_raptor_utils.py @@ -18,15 +18,22 @@ Unit tests for Raptor utility functions. """ +import logging + import pytest from rag.utils.raptor_utils import ( + CSV_EXTENSIONS, + EXCEL_EXTENSIONS, + STRUCTURED_EXTENSIONS, + collect_raptor_chunk_ids, + collect_raptor_methods, + get_raptor_clustering_method, + get_raptor_tree_builder, + get_skip_reason, is_structured_file_type, is_tabular_pdf, + make_raptor_summary_chunk_id, should_skip_raptor, - get_skip_reason, - EXCEL_EXTENSIONS, - CSV_EXTENSIONS, - STRUCTURED_EXTENSIONS ) @@ -283,5 +290,117 @@ def test_override_for_special_excel(self): assert should_skip_raptor(file_type, raptor_config=raptor_config) is False +class TestRaptorTreeBuilderConfig: + """Test RAPTOR tree builder config resolution""" + + def test_defaults_to_original_raptor_builder(self): + assert get_raptor_tree_builder({}) == "raptor" + assert get_raptor_tree_builder(None) == "raptor" + + def test_reads_top_level_tree_builder(self): + assert get_raptor_tree_builder({"tree_builder": "psi"}) == "psi" + + def test_reads_legacy_ext_tree_builder(self): + assert get_raptor_tree_builder({"ext": {"tree_builder": "psi"}}) == "psi" + + def test_ext_tree_builder_overrides_stale_top_level_value(self): + assert get_raptor_tree_builder({"tree_builder": "psi", "ext": {"tree_builder": "raptor"}}) == "raptor" + + def test_rejects_unknown_tree_builder(self): + with pytest.raises(ValueError, match="Unsupported RAPTOR tree builder"): + get_raptor_tree_builder({"tree_builder": "ahc"}) + + +class TestRaptorClusteringMethodConfig: + """Test RAPTOR clustering method config resolution""" + + def test_defaults_to_gmm(self): + assert get_raptor_clustering_method({}) == "gmm" + assert get_raptor_clustering_method(None) == "gmm" + + def test_reads_top_level_clustering_method(self): + assert get_raptor_clustering_method({"clustering_method": "gmm"}) == "gmm" + assert get_raptor_clustering_method({"clustering_method": "ahc"}) == "ahc" + + def test_reads_legacy_ext_clustering_method(self): + assert get_raptor_clustering_method({"ext": {"clustering_method": "ahc"}}) == "ahc" + + def test_ext_clustering_method_overrides_stale_top_level_value(self): + assert get_raptor_clustering_method({"clustering_method": "gmm", "ext": {"clustering_method": "ahc"}}) == "ahc" + + def test_rejects_unknown_clustering_method(self): + with pytest.raises(ValueError, match="Unsupported RAPTOR clustering method"): + get_raptor_clustering_method({"clustering_method": "unknown"}) + + +class TestRaptorMethodCollection: + """Test RAPTOR summary method extraction from doc-store fields""" + + def test_legacy_summary_without_method_is_original_raptor(self): + field_map = {"chunk_1": {"raptor_kwd": "raptor"}} + + assert collect_raptor_methods(field_map) == {"raptor"} + assert collect_raptor_chunk_ids(field_map) == {"chunk_1"} + + def test_extra_method_is_preserved(self): + field_map = {"chunk_1": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}}} + + assert collect_raptor_methods(field_map) == {"psi"} + assert collect_raptor_chunk_ids(field_map) == {"chunk_1"} + + def test_extra_field_supports_oceanbase_legacy_rows(self): + field_map = { + "chunk_1": { + "extra": { + "raptor_kwd": "raptor", + "raptor_method": "psi", + } + }, + "chunk_2": { + "extra": "{\"raptor_kwd\": \"raptor\"}", + }, + "chunk_3": { + "extra": {"raptor_kwd": ""}, + }, + } + + assert collect_raptor_methods(field_map) == {"psi", "raptor"} + assert collect_raptor_chunk_ids(field_map) == {"chunk_1", "chunk_2"} + + def test_non_raptor_rows_are_ignored(self): + field_map = { + "chunk_1": {"raptor_kwd": ""}, + "chunk_2": {"extra": {"raptor_kwd": "graph"}}, + "chunk_3": {}, + } + + assert collect_raptor_methods(field_map) == set() + assert collect_raptor_chunk_ids(field_map) == set() + + def test_malformed_extra_payload_is_logged_and_ignored(self, caplog): + field_map = {"chunk_1": {"extra": "{bad json"}} + + with caplog.at_level(logging.WARNING): + assert collect_raptor_methods(field_map) == set() + assert collect_raptor_chunk_ids(field_map) == set() + + assert "Ignoring malformed RAPTOR extra payload" in caplog.text + + def test_chunk_id_collection_can_preserve_current_method(self): + field_map = { + "legacy": {"raptor_kwd": "raptor"}, + "old": {"raptor_kwd": "raptor", "extra": {"raptor_method": "raptor"}}, + "current": {"raptor_kwd": "raptor", "extra": {"raptor_method": "psi"}}, + } + + assert collect_raptor_chunk_ids(field_map, exclude_methods={"psi"}) == {"legacy", "old"} + assert collect_raptor_chunk_ids(field_map, exclude_methods={"raptor"}) == {"current"} + + def test_summary_chunk_ids_include_real_document_id(self): + content = "same generated summary" + + assert make_raptor_summary_chunk_id(content, "doc-a") != make_raptor_summary_chunk_id(content, "doc-b") + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tools/scripts/README.md b/tools/scripts/README.md index fc05d12fbb5..53091730857 100644 --- a/tools/scripts/README.md +++ b/tools/scripts/README.md @@ -275,8 +275,8 @@ python db_schema_sync.py [OPTIONS] ### Version Format Version must be in format `vxx.xx.xx` where `xx` are digits: -- Valid: `v0.25.2`, `v1.0.0`, `v10.20.30` -- Invalid: `0.25.2`, `v0.25`, `v0.25.2.1` +- Valid: `v0.25.5`, `v1.0.0`, `v10.20.30` +- Invalid: `0.25.5`, `v0.25`, `v0.25.5.1` ### Migration File Location @@ -287,7 +287,7 @@ tools/migrate/{version_dir}/ Where `{version_dir}` is the version with `.` replaced by `_`. -Example: Version `v0.25.2` → Directory `tools/migrate/v0_25_2/` +Example: Version `v0.25.5` → Directory `tools/migrate/v0_25_5/` ### Examples @@ -295,32 +295,32 @@ Example: Version `v0.25.2` → Directory `tools/migrate/v0_25_2/` # List all migrations python db_schema_sync.py --list \ --host localhost --port 3306 --user root --password xxx --database rag_flow \ - --version v0.25.2 + --version v0.25.5 # Create a new auto-detected migration (new tables, new fields, type changes only) python db_schema_sync.py --create \ --host localhost --port 3306 --user root --password xxx --database rag_flow \ - --version v0.25.2 + --version v0.25.5 # Create a migration including dropped fields (destructive!) python db_schema_sync.py --create --drop \ --host localhost --port 3306 --user root --password xxx --database rag_flow \ - --version v0.25.2 + --version v0.25.5 # Create a named migration python db_schema_sync.py --create --name add_user_table \ --host localhost --port 3306 --user root --password xxx --database rag_flow \ - --version v0.25.2 + --version v0.25.5 # Run all pending migrations python db_schema_sync.py --migrate \ --host localhost --port 3306 --user root --password xxx --database rag_flow \ - --version v0.25.2 + --version v0.25.5 # Show schema differences (including removed fields) python db_schema_sync.py --diff \ --host localhost --port 3306 --user root --password xxx --database rag_flow \ - --version v0.25.2 + --version v0.25.5 ``` ## How It Works diff --git a/tools/scripts/db_schema_sync.py b/tools/scripts/db_schema_sync.py index 175fc9e61fc..1e85ffe44e0 100644 --- a/tools/scripts/db_schema_sync.py +++ b/tools/scripts/db_schema_sync.py @@ -55,7 +55,7 @@ def validate_version(version: str) -> bool: def version_to_dirname(version: str) -> str: - """Convert version string to valid directory name (e.g., 'v0.25.2' -> 'v0_25_2')""" + """Convert version string to valid directory name (e.g., 'v0.25.5' -> 'v0_25_5')""" return version.replace('.', '_') @@ -839,19 +839,19 @@ def main(): epilog=""" Examples: # List all migrations - python db_schema_sync.py --list --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.2 + python db_schema_sync.py --list --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.5 # Create migration from model changes - python db_schema_sync.py --create --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.2 + python db_schema_sync.py --create --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.5 # Create migration including dropped fields (destructive!) - python db_schema_sync.py --create --drop --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.2 + python db_schema_sync.py --create --drop --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.5 # Run all pending migrations - python db_schema_sync.py --migrate --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.2 + python db_schema_sync.py --migrate --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.5 # Show schema differences - python db_schema_sync.py --diff --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.2 + python db_schema_sync.py --diff --host localhost --port 3306 --user root --password xxx --database rag_flow --version v0.25.5 """ ) @@ -864,7 +864,7 @@ def main(): # Version option parser.add_argument('--version', '-v', type=str, required=True, - help='Version number in format vxx.xx.xx (e.g., v0.25.2)') + help='Version number in format vxx.xx.xx (e.g., v0.25.5)') # Action options parser.add_argument('--list', '-l', action='store_true', help='List all migrations') @@ -882,7 +882,7 @@ def main(): # Validate version format if not validate_version(args.version): - logger.error(f"Invalid version format: {args.version}. Expected format: vxx.xx.xx (e.g., v0.25.2)") + logger.error(f"Invalid version format: {args.version}. Expected format: vxx.xx.xx (e.g., v0.25.5)") sys.exit(1) # Validate at least one action is specified diff --git a/uv.lock b/uv.lock index abb33e17734..8faf560e497 100644 --- a/uv.lock +++ b/uv.lock @@ -1,29 +1,21 @@ version = 1 revision = 3 -requires-python = ">=3.12, <3.15" +requires-python = ">=3.13, <3.15" resolution-markers = [ "python_full_version >= '3.14' and sys_platform == 'darwin'", - "python_full_version == '3.13.*' and sys_platform == 'darwin'", - "python_full_version < '3.13' and sys_platform == 'darwin'", + "python_full_version < '3.14' and sys_platform == 'darwin'", "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", + "python_full_version < '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", "(python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'linux')", - "(python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'linux')", - "(python_full_version < '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')", + "(python_full_version < '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.14' and sys_platform != 'darwin' and sys_platform != 'linux')", ] [manifest] -constraints = [{ name = "pyasn1", specifier = ">=0.6.3" }] - -[[package]] -name = "absl-py" -version = "2.4.0" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/64/c7/8de93764ad66968d19329a7e0c147a2bb3c7054c554d4a119111b8f9440f/absl_py-2.4.0.tar.gz", hash = "sha256:8c6af82722b35cf71e0f4d1d47dcaebfff286e27110a99fc359349b247dfb5d4" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/18/a6/907a406bb7d359e6a63f99c313846d9eec4f7e6f7437809e03aa00fa3074/absl_py-2.4.0-py3-none-any.whl", hash = "sha256:88476fd881ca8aab94ffa78b7b6c632a782ab3ba1cd19c9bd423abc4fb4cd28d" }, +constraints = [ + { name = "pyasn1", specifier = ">=0.6.3" }, + { name = "trio", specifier = ">=0.26.0", index = "https://pypi.org/simple" }, ] +overrides = [{ name = "attrs", specifier = ">=23.2.0" }] [[package]] name = "agentrun-mem0ai" @@ -103,23 +95,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/50/42/32cf8e7704ceb4481406eb87161349abb46a57fee3f008ba9cb610968646/aiohttp-3.13.3.tar.gz", hash = "sha256:a949eee43d3782f2daae4f4a2819b2cb9b0c5d3b7f7a927067cc84dafdbb9f88" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/a0/be/4fc11f202955a69e0db803a12a062b8379c970c7c84f4882b6da17337cc1/aiohttp-3.13.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:b903a4dfee7d347e2d87697d0713be59e0b87925be030c9178c5faa58ea58d5c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/97/2c/621d5b851f94fa0bb7430d6089b3aa970a9d9b75196bc93bb624b0db237a/aiohttp-3.13.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:a45530014d7a1e09f4a55f4f43097ba0fd155089372e105e4bff4ca76cb1b168" }, - { url = "https://mirrors.aliyun.com/pypi/packages/5d/43/4be01406b78e1be8320bb8316dc9c42dbab553d281c40364e0f862d5661c/aiohttp-3.13.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:27234ef6d85c914f9efeb77ff616dbf4ad2380be0cda40b4db086ffc7ddd1b7d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8d/a8/5a35dc56a06a2c90d4742cbf35294396907027f80eea696637945a106f25/aiohttp-3.13.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d32764c6c9aafb7fb55366a224756387cd50bfa720f32b88e0e6fa45b27dcf29" }, - { url = "https://mirrors.aliyun.com/pypi/packages/bf/62/4b9eeb331da56530bf2e198a297e5303e1c1ebdceeb00fe9b568a65c5a0c/aiohttp-3.13.3-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:b1a6102b4d3ebc07dad44fbf07b45bb600300f15b552ddf1851b5390202ea2e3" }, - { url = "https://mirrors.aliyun.com/pypi/packages/7c/f6/af16887b5d419e6a367095994c0b1332d154f647e7dc2bd50e61876e8e3d/aiohttp-3.13.3-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c014c7ea7fb775dd015b2d3137378b7be0249a448a1612268b5a90c2d81de04d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ce/83/397c634b1bcc24292fa1e0c7822800f9f6569e32934bdeef09dae7992dfb/aiohttp-3.13.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2b8d8ddba8f95ba17582226f80e2de99c7a7948e66490ef8d947e272a93e9463" }, - { url = "https://mirrors.aliyun.com/pypi/packages/86/f6/a62cbbf13f0ac80a70f71b1672feba90fdb21fd7abd8dbf25c0105fb6fa3/aiohttp-3.13.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:9ae8dd55c8e6c4257eae3a20fd2c8f41edaea5992ed67156642493b8daf3cecc" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0a/87/20a35ad487efdd3fba93d5843efdfaa62d2f1479eaafa7453398a44faf13/aiohttp-3.13.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:01ad2529d4b5035578f5081606a465f3b814c542882804e2e8cda61adf5c71bf" }, - { url = "https://mirrors.aliyun.com/pypi/packages/de/95/8fd69a66682012f6716e1bc09ef8a1a2a91922c5725cb904689f112309c4/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bb4f7475e359992b580559e008c598091c45b5088f28614e855e42d39c2f1033" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e5/66/7b94b3b5ba70e955ff597672dad1691333080e37f50280178967aff68657/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:c19b90316ad3b24c69cd78d5c9b4f3aa4497643685901185b65166293d36a00f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/47/71/6f72f77f9f7d74719692ab65a2a0252584bf8d5f301e2ecb4c0da734530a/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:96d604498a7c782cb15a51c406acaea70d8c027ee6b90c569baa6e7b93073679" }, - { url = "https://mirrors.aliyun.com/pypi/packages/fa/b4/75ec16cbbd5c01bdaf4a05b19e103e78d7ce1ef7c80867eb0ace42ff4488/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:084911a532763e9d3dd95adf78a78f4096cd5f58cdc18e6fdbc1b58417a45423" }, - { url = "https://mirrors.aliyun.com/pypi/packages/52/8f/bc518c0eea29f8406dcf7ed1f96c9b48e3bc3995a96159b3fc11f9e08321/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:7a4a94eb787e606d0a09404b9c38c113d3b099d508021faa615d70a0131907ce" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9d/f2/a07a75173124f31f11ea6f863dc44e6f09afe2bca45dd4e64979490deab1/aiohttp-3.13.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:87797e645d9d8e222e04160ee32aa06bc5c163e8499f24db719e7852ec23093a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/3c/4a/1a3fee7c21350cac78e5c5cef711bac1b94feca07399f3d406972e2d8fcd/aiohttp-3.13.3-cp312-cp312-win32.whl", hash = "sha256:b04be762396457bef43f3597c991e192ee7da460a4953d7e647ee4b1c28e7046" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d9/b7/76175c7cb4eb73d91ad63c34e29fc4f77c9386bba4a65b53ba8e05ee3c39/aiohttp-3.13.3-cp312-cp312-win_amd64.whl", hash = "sha256:e3531d63d3bdfa7e3ac5e9b27b2dd7ec9df3206a98e0b3445fa906f233264c57" }, { url = "https://mirrors.aliyun.com/pypi/packages/97/8a/12ca489246ca1faaf5432844adbfce7ff2cc4997733e0af120869345643a/aiohttp-3.13.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:5dff64413671b0d3e7d5918ea490bdccb97a4ad29b3f311ed423200b2203e01c" }, { url = "https://mirrors.aliyun.com/pypi/packages/32/08/de43984c74ed1fca5c014808963cc83cb00d7bb06af228f132d33862ca76/aiohttp-3.13.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:87b9aab6d6ed88235aa2970294f496ff1a1f9adcd724d800e9b952395a80ffd9" }, { url = "https://mirrors.aliyun.com/pypi/packages/17/f8/8dd2cf6112a5a76f81f81a5130c57ca829d101ad583ce57f889179accdda/aiohttp-3.13.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:425c126c0dc43861e22cb1c14ba4c8e45d09516d0a3ae0a3f7494b79f5f233a3" }, @@ -200,7 +175,6 @@ version = "1.4.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } dependencies = [ { name = "frozenlist" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/61/62/06741b579156360248d1ec624842ad0edf697050bbaf7c3e46394e106ad1/aiosignal-1.4.0.tar.gz", hash = "sha256:f47eecd9468083c2029cc99945502cb7708b082c232f9aca65da147157b251c7" } wheels = [ @@ -523,7 +497,6 @@ version = "4.13.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } dependencies = [ { name = "idna" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/19/14/2c5dd9f512b66549ae92767a9c7b330ae88e1932ca57876909410251fe13/anyio-4.13.0.tar.gz", hash = "sha256:334b70e641fd2221c1505b3890c69882fe4a2df910cba14d97019b90b24439dc" } wheels = [ @@ -635,19 +608,6 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/96/3a/2baa6a2a3319bfcc0bc490a26c9057eba2412502eb6ab16e55533dd511a7/asana-5.2.3-py3-none-any.whl", hash = "sha256:543e928aadf1a0f05769bfab14e1d9dbb7c6183ce75c451aea0fd2196e392e7e" }, ] -[[package]] -name = "astunparse" -version = "1.6.3" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -dependencies = [ - { name = "six" }, - { name = "wheel" }, -] -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/f3/af/4182184d3c338792894f34a62672919db7ca008c89abee9b564dd34d8029/astunparse-1.6.3.tar.gz", hash = "sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/2b/03/13dde6512ad7b4557eb792fbcf0c653af6076b81e5941d36ec61f7ce6028/astunparse-1.6.3-py2.py3-none-any.whl", hash = "sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8" }, -] - [[package]] name = "atlassian-python-api" version = "4.0.7" @@ -668,11 +628,67 @@ wheels = [ [[package]] name = "attrs" -version = "22.2.0" +version = "26.1.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/21/31/3f468da74c7de4fcf9b25591e682856389b3400b4b62f201e65f15ea3e07/attrs-22.2.0.tar.gz", hash = "sha256:c9227bfc2f01993c03f68db37d1d15c9690188323c067c641f1a35ca58185f99" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9a/8e/82a0fe20a541c03148528be8cac2408564a6c9a0cc7e9171802bc1d26985/attrs-26.1.0.tar.gz", hash = "sha256:d03ceb89cb322a8fd706d4fb91940737b6642aa36998fe130a9bc96c985eff32" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/fb/6e/6f83bf616d2becdf333a1640f1d463fef3150e2e926b7010cb0f81c95e88/attrs-22.2.0-py3-none-any.whl", hash = "sha256:29e95c7f6778868dbd49170f98f8818f78f3dc5e0e37c0b1f474e3561b240836" }, + { url = "https://mirrors.aliyun.com/pypi/packages/64/b4/17d4b0b2a2dc85a6df63d1157e028ed19f90d4cd97c36717afef2bc2f395/attrs-26.1.0-py3-none-any.whl", hash = "sha256:c647aa4a12dfbad9333ca4e71fe62ddc36f4e63b2d260a37a8b83d2f043ac309" }, +] + +[[package]] +name = "audioop-lts" +version = "0.2.2" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/38/53/946db57842a50b2da2e0c1e34bd37f36f5aadba1a929a3971c5d7841dbca/audioop_lts-0.2.2.tar.gz", hash = "sha256:64d0c62d88e67b98a1a5e71987b7aa7b5bcffc7dcee65b635823dbdd0a8dbbd0" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/de/d4/94d277ca941de5a507b07f0b592f199c22454eeaec8f008a286b3fbbacd6/audioop_lts-0.2.2-cp313-abi3-macosx_10_13_universal2.whl", hash = "sha256:fd3d4602dc64914d462924a08c1a9816435a2155d74f325853c1f1ac3b2d9800" }, + { url = "https://mirrors.aliyun.com/pypi/packages/f8/5a/656d1c2da4b555920ce4177167bfeb8623d98765594af59702c8873f60ec/audioop_lts-0.2.2-cp313-abi3-macosx_10_13_x86_64.whl", hash = "sha256:550c114a8df0aafe9a05442a1162dfc8fec37e9af1d625ae6060fed6e756f303" }, + { url = "https://mirrors.aliyun.com/pypi/packages/1b/83/ea581e364ce7b0d41456fb79d6ee0ad482beda61faf0cab20cbd4c63a541/audioop_lts-0.2.2-cp313-abi3-macosx_11_0_arm64.whl", hash = "sha256:9a13dc409f2564de15dd68be65b462ba0dde01b19663720c68c1140c782d1d75" }, + { url = "https://mirrors.aliyun.com/pypi/packages/b8/3b/e8964210b5e216e5041593b7d33e97ee65967f17c282e8510d19c666dab4/audioop_lts-0.2.2-cp313-abi3-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:51c916108c56aa6e426ce611946f901badac950ee2ddaf302b7ed35d9958970d" }, + { url = "https://mirrors.aliyun.com/pypi/packages/c7/2e/0a1c52faf10d51def20531a59ce4c706cb7952323b11709e10de324d6493/audioop_lts-0.2.2-cp313-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:47eba38322370347b1c47024defbd36374a211e8dd5b0dcbce7b34fdb6f8847b" }, + { url = "https://mirrors.aliyun.com/pypi/packages/75/e8/cd95eef479656cb75ab05dfece8c1f8c395d17a7c651d88f8e6e291a63ab/audioop_lts-0.2.2-cp313-abi3-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:ba7c3a7e5f23e215cb271516197030c32aef2e754252c4c70a50aaff7031a2c8" }, + { url = "https://mirrors.aliyun.com/pypi/packages/5c/1e/a0c42570b74f83efa5cca34905b3eef03f7ab09fe5637015df538a7f3345/audioop_lts-0.2.2-cp313-abi3-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:def246fe9e180626731b26e89816e79aae2276f825420a07b4a647abaa84becc" }, + { url = "https://mirrors.aliyun.com/pypi/packages/50/d5/8a0ae607ca07dbb34027bac8db805498ee7bfecc05fd2c148cc1ed7646e7/audioop_lts-0.2.2-cp313-abi3-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e160bf9df356d841bb6c180eeeea1834085464626dc1b68fa4e1d59070affdc3" }, + { url = "https://mirrors.aliyun.com/pypi/packages/12/17/0d28c46179e7910bfb0bb62760ccb33edb5de973052cb2230b662c14ca2e/audioop_lts-0.2.2-cp313-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:4b4cd51a57b698b2d06cb9993b7ac8dfe89a3b2878e96bc7948e9f19ff51dba6" }, + { url = "https://mirrors.aliyun.com/pypi/packages/84/ba/bd5d3806641564f2024e97ca98ea8f8811d4e01d9b9f9831474bc9e14f9e/audioop_lts-0.2.2-cp313-abi3-musllinux_1_2_ppc64le.whl", hash = "sha256:4a53aa7c16a60a6857e6b0b165261436396ef7293f8b5c9c828a3a203147ed4a" }, + { url = "https://mirrors.aliyun.com/pypi/packages/f9/5e/435ce8d5642f1f7679540d1e73c1c42d933331c0976eb397d1717d7f01a3/audioop_lts-0.2.2-cp313-abi3-musllinux_1_2_riscv64.whl", hash = "sha256:3fc38008969796f0f689f1453722a0f463da1b8a6fbee11987830bfbb664f623" }, + { url = "https://mirrors.aliyun.com/pypi/packages/ae/3b/b909e76b606cbfd53875693ec8c156e93e15a1366a012f0b7e4fb52d3c34/audioop_lts-0.2.2-cp313-abi3-musllinux_1_2_s390x.whl", hash = "sha256:15ab25dd3e620790f40e9ead897f91e79c0d3ce65fe193c8ed6c26cffdd24be7" }, + { url = "https://mirrors.aliyun.com/pypi/packages/30/e7/8f1603b4572d79b775f2140d7952f200f5e6c62904585d08a01f0a70393a/audioop_lts-0.2.2-cp313-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:03f061a1915538fd96272bac9551841859dbb2e3bf73ebe4a23ef043766f5449" }, + { url = "https://mirrors.aliyun.com/pypi/packages/b5/96/c37846df657ccdda62ba1ae2b6534fa90e2e1b1742ca8dcf8ebd38c53801/audioop_lts-0.2.2-cp313-abi3-win32.whl", hash = "sha256:3bcddaaf6cc5935a300a8387c99f7a7fbbe212a11568ec6cf6e4bc458c048636" }, + { url = "https://mirrors.aliyun.com/pypi/packages/34/a5/9d78fdb5b844a83da8a71226c7bdae7cc638861085fff7a1d707cb4823fa/audioop_lts-0.2.2-cp313-abi3-win_amd64.whl", hash = "sha256:a2c2a947fae7d1062ef08c4e369e0ba2086049a5e598fda41122535557012e9e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/34/25/20d8fde083123e90c61b51afb547bb0ea7e77bab50d98c0ab243d02a0e43/audioop_lts-0.2.2-cp313-abi3-win_arm64.whl", hash = "sha256:5f93a5db13927a37d2d09637ccca4b2b6b48c19cd9eda7b17a2e9f77edee6a6f" }, + { url = "https://mirrors.aliyun.com/pypi/packages/58/a7/0a764f77b5c4ac58dc13c01a580f5d32ae8c74c92020b961556a43e26d02/audioop_lts-0.2.2-cp313-cp313t-macosx_10_13_universal2.whl", hash = "sha256:73f80bf4cd5d2ca7814da30a120de1f9408ee0619cc75da87d0641273d202a09" }, + { url = "https://mirrors.aliyun.com/pypi/packages/aa/ed/ebebedde1a18848b085ad0fa54b66ceb95f1f94a3fc04f1cd1b5ccb0ed42/audioop_lts-0.2.2-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:106753a83a25ee4d6f473f2be6b0966fc1c9af7e0017192f5531a3e7463dce58" }, + { url = "https://mirrors.aliyun.com/pypi/packages/cb/6e/11ca8c21af79f15dbb1c7f8017952ee8c810c438ce4e2b25638dfef2b02c/audioop_lts-0.2.2-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:fbdd522624141e40948ab3e8cdae6e04c748d78710e9f0f8d4dae2750831de19" }, + { url = "https://mirrors.aliyun.com/pypi/packages/84/52/0022f93d56d85eec5da6b9da6a958a1ef09e80c39f2cc0a590c6af81dcbb/audioop_lts-0.2.2-cp313-cp313t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:143fad0311e8209ece30a8dbddab3b65ab419cbe8c0dde6e8828da25999be911" }, + { url = "https://mirrors.aliyun.com/pypi/packages/87/1d/48a889855e67be8718adbc7a01f3c01d5743c325453a5e81cf3717664aad/audioop_lts-0.2.2-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:dfbbc74ec68a0fd08cfec1f4b5e8cca3d3cd7de5501b01c4b5d209995033cde9" }, + { url = "https://mirrors.aliyun.com/pypi/packages/98/a6/94b7213190e8077547ffae75e13ed05edc488653c85aa5c41472c297d295/audioop_lts-0.2.2-cp313-cp313t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:cfcac6aa6f42397471e4943e0feb2244549db5c5d01efcd02725b96af417f3fe" }, + { url = "https://mirrors.aliyun.com/pypi/packages/e9/e9/78450d7cb921ede0cfc33426d3a8023a3bda755883c95c868ee36db8d48d/audioop_lts-0.2.2-cp313-cp313t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:752d76472d9804ac60f0078c79cdae8b956f293177acd2316cd1e15149aee132" }, + { url = "https://mirrors.aliyun.com/pypi/packages/4f/e2/cd5439aad4f3e34ae1ee852025dc6aa8f67a82b97641e390bf7bd9891d3e/audioop_lts-0.2.2-cp313-cp313t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:83c381767e2cc10e93e40281a04852facc4cd9334550e0f392f72d1c0a9c5753" }, + { url = "https://mirrors.aliyun.com/pypi/packages/68/4b/9d853e9076c43ebba0d411e8d2aa19061083349ac695a7d082540bad64d0/audioop_lts-0.2.2-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c0022283e9556e0f3643b7c3c03f05063ca72b3063291834cca43234f20c60bb" }, + { url = "https://mirrors.aliyun.com/pypi/packages/58/26/4bae7f9d2f116ed5593989d0e521d679b0d583973d203384679323d8fa85/audioop_lts-0.2.2-cp313-cp313t-musllinux_1_2_ppc64le.whl", hash = "sha256:a2d4f1513d63c795e82948e1305f31a6d530626e5f9f2605408b300ae6095093" }, + { url = "https://mirrors.aliyun.com/pypi/packages/b2/67/a9f4fb3e250dda9e9046f8866e9fa7d52664f8985e445c6b4ad6dfb55641/audioop_lts-0.2.2-cp313-cp313t-musllinux_1_2_riscv64.whl", hash = "sha256:c9c8e68d8b4a56fda8c025e538e639f8c5953f5073886b596c93ec9b620055e7" }, + { url = "https://mirrors.aliyun.com/pypi/packages/70/f7/3de86562db0121956148bcb0fe5b506615e3bcf6e63c4357a612b910765a/audioop_lts-0.2.2-cp313-cp313t-musllinux_1_2_s390x.whl", hash = "sha256:96f19de485a2925314f5020e85911fb447ff5fbef56e8c7c6927851b95533a1c" }, + { url = "https://mirrors.aliyun.com/pypi/packages/f1/32/fd772bf9078ae1001207d2df1eef3da05bea611a87dd0e8217989b2848fa/audioop_lts-0.2.2-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:e541c3ef484852ef36545f66209444c48b28661e864ccadb29daddb6a4b8e5f5" }, + { url = "https://mirrors.aliyun.com/pypi/packages/4f/41/affea7181592ab0ab560044632571a38edaf9130b84928177823fbf3176a/audioop_lts-0.2.2-cp313-cp313t-win32.whl", hash = "sha256:d5e73fa573e273e4f2e5ff96f9043858a5e9311e94ffefd88a3186a910c70917" }, + { url = "https://mirrors.aliyun.com/pypi/packages/28/2b/0372842877016641db8fc54d5c88596b542eec2f8f6c20a36fb6612bf9ee/audioop_lts-0.2.2-cp313-cp313t-win_amd64.whl", hash = "sha256:9191d68659eda01e448188f60364c7763a7ca6653ed3f87ebb165822153a8547" }, + { url = "https://mirrors.aliyun.com/pypi/packages/ee/ca/baf2b9cc7e96c179bb4a54f30fcd83e6ecb340031bde68f486403f943768/audioop_lts-0.2.2-cp313-cp313t-win_arm64.whl", hash = "sha256:c174e322bb5783c099aaf87faeb240c8d210686b04bd61dfd05a8e5a83d88969" }, + { url = "https://mirrors.aliyun.com/pypi/packages/5c/73/413b5a2804091e2c7d5def1d618e4837f1cb82464e230f827226278556b7/audioop_lts-0.2.2-cp314-cp314t-macosx_10_13_universal2.whl", hash = "sha256:f9ee9b52f5f857fbaf9d605a360884f034c92c1c23021fb90b2e39b8e64bede6" }, + { url = "https://mirrors.aliyun.com/pypi/packages/ae/8c/daa3308dc6593944410c2c68306a5e217f5c05b70a12e70228e7dd42dc5c/audioop_lts-0.2.2-cp314-cp314t-macosx_10_13_x86_64.whl", hash = "sha256:49ee1a41738a23e98d98b937a0638357a2477bc99e61b0f768a8f654f45d9b7a" }, + { url = "https://mirrors.aliyun.com/pypi/packages/4e/86/c2e0f627168fcf61781a8f72cab06b228fe1da4b9fa4ab39cfb791b5836b/audioop_lts-0.2.2-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5b00be98ccd0fc123dcfad31d50030d25fcf31488cde9e61692029cd7394733b" }, + { url = "https://mirrors.aliyun.com/pypi/packages/c7/bd/35dce665255434f54e5307de39e31912a6f902d4572da7c37582809de14f/audioop_lts-0.2.2-cp314-cp314t-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:a6d2e0f9f7a69403e388894d4ca5ada5c47230716a03f2847cfc7bd1ecb589d6" }, + { url = "https://mirrors.aliyun.com/pypi/packages/2d/d2/deeb9f51def1437b3afa35aeb729d577c04bcd89394cb56f9239a9f50b6f/audioop_lts-0.2.2-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f9b0b8a03ef474f56d1a842af1a2e01398b8f7654009823c6d9e0ecff4d5cfbf" }, + { url = "https://mirrors.aliyun.com/pypi/packages/76/3b/09f8b35b227cee28cc8231e296a82759ed80c1a08e349811d69773c48426/audioop_lts-0.2.2-cp314-cp314t-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2b267b70747d82125f1a021506565bdc5609a2b24bcb4773c16d79d2bb260bbd" }, + { url = "https://mirrors.aliyun.com/pypi/packages/0b/15/05b48a935cf3b130c248bfdbdea71ce6437f5394ee8533e0edd7cfd93d5e/audioop_lts-0.2.2-cp314-cp314t-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0337d658f9b81f4cd0fdb1f47635070cc084871a3d4646d9de74fdf4e7c3d24a" }, + { url = "https://mirrors.aliyun.com/pypi/packages/83/80/186b7fce6d35b68d3d739f228dc31d60b3412105854edb975aa155a58339/audioop_lts-0.2.2-cp314-cp314t-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:167d3b62586faef8b6b2275c3218796b12621a60e43f7e9d5845d627b9c9b80e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/49/89/c78cc5ac6cb5828f17514fb12966e299c850bc885e80f8ad94e38d450886/audioop_lts-0.2.2-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:0d9385e96f9f6da847f4d571ce3cb15b5091140edf3db97276872647ce37efd7" }, + { url = "https://mirrors.aliyun.com/pypi/packages/4c/4b/6401888d0c010e586c2ca50fce4c903d70a6bb55928b16cfbdfd957a13da/audioop_lts-0.2.2-cp314-cp314t-musllinux_1_2_ppc64le.whl", hash = "sha256:48159d96962674eccdca9a3df280e864e8ac75e40a577cc97c5c42667ffabfc5" }, + { url = "https://mirrors.aliyun.com/pypi/packages/de/f8/c874ca9bb447dae0e2ef2e231f6c4c2b0c39e31ae684d2420b0f9e97ee68/audioop_lts-0.2.2-cp314-cp314t-musllinux_1_2_riscv64.whl", hash = "sha256:8fefe5868cd082db1186f2837d64cfbfa78b548ea0d0543e9b28935ccce81ce9" }, + { url = "https://mirrors.aliyun.com/pypi/packages/3e/c0/0323e66f3daebc13fd46b36b30c3be47e3fc4257eae44f1e77eb828c703f/audioop_lts-0.2.2-cp314-cp314t-musllinux_1_2_s390x.whl", hash = "sha256:58cf54380c3884fb49fdd37dfb7a772632b6701d28edd3e2904743c5e1773602" }, + { url = "https://mirrors.aliyun.com/pypi/packages/98/6b/acc7734ac02d95ab791c10c3f17ffa3584ccb9ac5c18fd771c638ed6d1f5/audioop_lts-0.2.2-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:088327f00488cdeed296edd9215ca159f3a5a5034741465789cad403fcf4bec0" }, + { url = "https://mirrors.aliyun.com/pypi/packages/13/c3/c3dc3f564ce6877ecd2a05f8d751b9b27a8c320c2533a98b0c86349778d0/audioop_lts-0.2.2-cp314-cp314t-win32.whl", hash = "sha256:068aa17a38b4e0e7de771c62c60bbca2455924b67a8814f3b0dee92b5820c0b3" }, + { url = "https://mirrors.aliyun.com/pypi/packages/72/bb/b4608537e9ffcb86449091939d52d24a055216a36a8bf66b936af8c3e7ac/audioop_lts-0.2.2-cp314-cp314t-win_amd64.whl", hash = "sha256:a5bf613e96f49712073de86f20dbdd4014ca18efd4d34ed18c75bd808337851b" }, + { url = "https://mirrors.aliyun.com/pypi/packages/f6/22/91616fe707a5c5510de2cac9b046a30defe7007ba8a0c04f9c08f27df312/audioop_lts-0.2.2-cp314-cp314t-win_arm64.whl", hash = "sha256:b492c3b040153e68b9fdaff5913305aaaba5bb433d8a7f73d5cf6a64ed3cc1dd" }, ] [[package]] @@ -779,6 +795,72 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/d4/a9/a58a63e2756e5d01901595af58c673f68de7621f28d71007479e00f45a6c/bce_python_sdk-0.9.67-py3-none-any.whl", hash = "sha256:3054879d098a92ceeb4b9ac1e64d2c658120a5a10e8e630f22410564b2170bf0" }, ] +[[package]] +name = "bcrypt" +version = "5.0.0" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/d4/36/3329e2518d70ad8e2e5817d5a4cac6bba05a47767ec416c7d020a965f408/bcrypt-5.0.0.tar.gz", hash = "sha256:f748f7c2d6fd375cc93d3fba7ef4a9e3a092421b8dbf34d8d4dc06be9492dfdd" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/13/85/3e65e01985fddf25b64ca67275bb5bdb4040bd1a53b66d355c6c37c8a680/bcrypt-5.0.0-cp313-cp313t-macosx_10_12_universal2.whl", hash = "sha256:f3c08197f3039bec79cee59a606d62b96b16669cff3949f21e74796b6e3cd2be" }, + { url = "https://mirrors.aliyun.com/pypi/packages/44/dc/01eb79f12b177017a726cbf78330eb0eb442fae0e7b3dfd84ea2849552f3/bcrypt-5.0.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:200af71bc25f22006f4069060c88ed36f8aa4ff7f53e67ff04d2ab3f1e79a5b2" }, + { url = "https://mirrors.aliyun.com/pypi/packages/8c/cf/e82388ad5959c40d6afd94fb4743cc077129d45b952d46bdc3180310e2df/bcrypt-5.0.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:baade0a5657654c2984468efb7d6c110db87ea63ef5a4b54732e7e337253e44f" }, + { url = "https://mirrors.aliyun.com/pypi/packages/ec/86/7134b9dae7cf0efa85671651341f6afa695857fae172615e960fb6a466fa/bcrypt-5.0.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:c58b56cdfb03202b3bcc9fd8daee8e8e9b6d7e3163aa97c631dfcfcc24d36c86" }, + { url = "https://mirrors.aliyun.com/pypi/packages/cc/82/6296688ac1b9e503d034e7d0614d56e80c5d1a08402ff856a4549cb59207/bcrypt-5.0.0-cp313-cp313t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:4bfd2a34de661f34d0bda43c3e4e79df586e4716ef401fe31ea39d69d581ef23" }, + { url = "https://mirrors.aliyun.com/pypi/packages/d1/18/884a44aa47f2a3b88dd09bc05a1e40b57878ecd111d17e5bba6f09f8bb77/bcrypt-5.0.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:ed2e1365e31fc73f1825fa830f1c8f8917ca1b3ca6185773b349c20fd606cec2" }, + { url = "https://mirrors.aliyun.com/pypi/packages/0e/8f/371a3ab33c6982070b674f1788e05b656cfbf5685894acbfef0c65483a59/bcrypt-5.0.0-cp313-cp313t-manylinux_2_34_aarch64.whl", hash = "sha256:83e787d7a84dbbfba6f250dd7a5efd689e935f03dd83b0f919d39349e1f23f83" }, + { url = "https://mirrors.aliyun.com/pypi/packages/b1/34/7e4e6abb7a8778db6422e88b1f06eb07c47682313997ee8a8f9352e5a6f1/bcrypt-5.0.0-cp313-cp313t-manylinux_2_34_x86_64.whl", hash = "sha256:137c5156524328a24b9fac1cb5db0ba618bc97d11970b39184c1d87dc4bf1746" }, + { url = "https://mirrors.aliyun.com/pypi/packages/c0/1b/54f416be2499bd72123c70d98d36c6cd61a4e33d9b89562c22481c81bb30/bcrypt-5.0.0-cp313-cp313t-musllinux_1_1_aarch64.whl", hash = "sha256:38cac74101777a6a7d3b3e3cfefa57089b5ada650dce2baf0cbdd9d65db22a9e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/13/62/062c24c7bcf9d2826a1a843d0d605c65a755bc98002923d01fd61270705a/bcrypt-5.0.0-cp313-cp313t-musllinux_1_1_x86_64.whl", hash = "sha256:d8d65b564ec849643d9f7ea05c6d9f0cd7ca23bdd4ac0c2dbef1104ab504543d" }, + { url = "https://mirrors.aliyun.com/pypi/packages/d5/c8/1fdbfc8c0f20875b6b4020f3c7dc447b8de60aa0be5faaf009d24242aec9/bcrypt-5.0.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:741449132f64b3524e95cd30e5cd3343006ce146088f074f31ab26b94e6c75ba" }, + { url = "https://mirrors.aliyun.com/pypi/packages/a6/c1/8b84545382d75bef226fbc6588af0f7b7d095f7cd6a670b42a86243183cd/bcrypt-5.0.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:212139484ab3207b1f0c00633d3be92fef3c5f0af17cad155679d03ff2ee1e41" }, + { url = "https://mirrors.aliyun.com/pypi/packages/10/a6/ffb49d4254ed085e62e3e5dd05982b4393e32fe1e49bb1130186617c29cd/bcrypt-5.0.0-cp313-cp313t-win32.whl", hash = "sha256:9d52ed507c2488eddd6a95bccee4e808d3234fa78dd370e24bac65a21212b861" }, + { url = "https://mirrors.aliyun.com/pypi/packages/48/a9/259559edc85258b6d5fc5471a62a3299a6aa37a6611a169756bf4689323c/bcrypt-5.0.0-cp313-cp313t-win_amd64.whl", hash = "sha256:f6984a24db30548fd39a44360532898c33528b74aedf81c26cf29c51ee47057e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/2d/df/9714173403c7e8b245acf8e4be8876aac64a209d1b392af457c79e60492e/bcrypt-5.0.0-cp313-cp313t-win_arm64.whl", hash = "sha256:9fffdb387abe6aa775af36ef16f55e318dcda4194ddbf82007a6f21da29de8f5" }, + { url = "https://mirrors.aliyun.com/pypi/packages/f8/14/c18006f91816606a4abe294ccc5d1e6f0e42304df5a33710e9e8e95416e1/bcrypt-5.0.0-cp314-cp314t-macosx_10_12_universal2.whl", hash = "sha256:4870a52610537037adb382444fefd3706d96d663ac44cbb2f37e3919dca3d7ef" }, + { url = "https://mirrors.aliyun.com/pypi/packages/67/49/dd074d831f00e589537e07a0725cf0e220d1f0d5d8e85ad5bbff251c45aa/bcrypt-5.0.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:48f753100931605686f74e27a7b49238122aa761a9aefe9373265b8b7aa43ea4" }, + { url = "https://mirrors.aliyun.com/pypi/packages/f5/91/50ccba088b8c474545b034a1424d05195d9fcbaaf802ab8bfe2be5a4e0d7/bcrypt-5.0.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f70aadb7a809305226daedf75d90379c397b094755a710d7014b8b117df1ebbf" }, + { url = "https://mirrors.aliyun.com/pypi/packages/aa/e7/d7dba133e02abcda3b52087a7eea8c0d4f64d3e593b4fffc10c31b7061f3/bcrypt-5.0.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:744d3c6b164caa658adcb72cb8cc9ad9b4b75c7db507ab4bc2480474a51989da" }, + { url = "https://mirrors.aliyun.com/pypi/packages/33/fc/5b145673c4b8d01018307b5c2c1fc87a6f5a436f0ad56607aee389de8ee3/bcrypt-5.0.0-cp314-cp314t-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a28bc05039bdf3289d757f49d616ab3efe8cf40d8e8001ccdd621cd4f98f4fc9" }, + { url = "https://mirrors.aliyun.com/pypi/packages/27/d7/1ff22703ec6d4f90e62f1a5654b8867ef96bafb8e8102c2288333e1a6ca6/bcrypt-5.0.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:7f277a4b3390ab4bebe597800a90da0edae882c6196d3038a73adf446c4f969f" }, + { url = "https://mirrors.aliyun.com/pypi/packages/c8/88/815b6d558a1e4d40ece04a2f84865b0fef233513bd85fd0e40c294272d62/bcrypt-5.0.0-cp314-cp314t-manylinux_2_34_aarch64.whl", hash = "sha256:79cfa161eda8d2ddf29acad370356b47f02387153b11d46042e93a0a95127493" }, + { url = "https://mirrors.aliyun.com/pypi/packages/51/8c/e0db387c79ab4931fc89827d37608c31cc57b6edc08ccd2386139028dc0d/bcrypt-5.0.0-cp314-cp314t-manylinux_2_34_x86_64.whl", hash = "sha256:a5393eae5722bcef046a990b84dff02b954904c36a194f6cfc817d7dca6c6f0b" }, + { url = "https://mirrors.aliyun.com/pypi/packages/06/83/1570edddd150f572dbe9fc00f6203a89fc7d4226821f67328a85c330f239/bcrypt-5.0.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:7f4c94dec1b5ab5d522750cb059bb9409ea8872d4494fd152b53cca99f1ddd8c" }, + { url = "https://mirrors.aliyun.com/pypi/packages/c9/f2/ea64e51a65e56ae7a8a4ec236c2bfbdd4b23008abd50ac33fbb2d1d15424/bcrypt-5.0.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:0cae4cb350934dfd74c020525eeae0a5f79257e8a201c0c176f4b84fdbf2a4b4" }, + { url = "https://mirrors.aliyun.com/pypi/packages/d7/d4/1a388d21ee66876f27d1a1f41287897d0c0f1712ef97d395d708ba93004c/bcrypt-5.0.0-cp314-cp314t-win32.whl", hash = "sha256:b17366316c654e1ad0306a6858e189fc835eca39f7eb2cafd6aaca8ce0c40a2e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/3f/61/3291c2243ae0229e5bca5d19f4032cecad5dfb05a2557169d3a69dc0ba91/bcrypt-5.0.0-cp314-cp314t-win_amd64.whl", hash = "sha256:92864f54fb48b4c718fc92a32825d0e42265a627f956bc0361fe869f1adc3e7d" }, + { url = "https://mirrors.aliyun.com/pypi/packages/3e/89/4b01c52ae0c1a681d4021e5dd3e45b111a8fb47254a274fa9a378d8d834b/bcrypt-5.0.0-cp314-cp314t-win_arm64.whl", hash = "sha256:dd19cf5184a90c873009244586396a6a884d591a5323f0e8a5922560718d4993" }, + { url = "https://mirrors.aliyun.com/pypi/packages/84/29/6237f151fbfe295fe3e074ecc6d44228faa1e842a81f6d34a02937ee1736/bcrypt-5.0.0-cp38-abi3-macosx_10_12_universal2.whl", hash = "sha256:fc746432b951e92b58317af8e0ca746efe93e66555f1b40888865ef5bf56446b" }, + { url = "https://mirrors.aliyun.com/pypi/packages/45/b6/4c1205dde5e464ea3bd88e8742e19f899c16fa8916fb8510a851fae985b5/bcrypt-5.0.0-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c2388ca94ffee269b6038d48747f4ce8df0ffbea43f31abfa18ac72f0218effb" }, + { url = "https://mirrors.aliyun.com/pypi/packages/3b/71/427945e6ead72ccffe77894b2655b695ccf14ae1866cd977e185d606dd2f/bcrypt-5.0.0-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:560ddb6ec730386e7b3b26b8b4c88197aaed924430e7b74666a586ac997249ef" }, + { url = "https://mirrors.aliyun.com/pypi/packages/17/72/c344825e3b83c5389a369c8a8e58ffe1480b8a699f46c127c34580c4666b/bcrypt-5.0.0-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:d79e5c65dcc9af213594d6f7f1fa2c98ad3fc10431e7aa53c176b441943efbdd" }, + { url = "https://mirrors.aliyun.com/pypi/packages/0b/7e/d4e47d2df1641a36d1212e5c0514f5291e1a956a7749f1e595c07a972038/bcrypt-5.0.0-cp38-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:2b732e7d388fa22d48920baa267ba5d97cca38070b69c0e2d37087b381c681fd" }, + { url = "https://mirrors.aliyun.com/pypi/packages/0f/c3/0ae57a68be2039287ec28bc463b82e4b8dc23f9d12c0be331f4782e19108/bcrypt-5.0.0-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:0c8e093ea2532601a6f686edbc2c6b2ec24131ff5c52f7610dd64fa4553b5464" }, + { url = "https://mirrors.aliyun.com/pypi/packages/45/2b/77424511adb11e6a99e3a00dcc7745034bee89036ad7d7e255a7e47be7d8/bcrypt-5.0.0-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:5b1589f4839a0899c146e8892efe320c0fa096568abd9b95593efac50a87cb75" }, + { url = "https://mirrors.aliyun.com/pypi/packages/43/0a/405c753f6158e0f3f14b00b462d8bca31296f7ecfc8fc8bc7919c0c7d73a/bcrypt-5.0.0-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:89042e61b5e808b67daf24a434d89bab164d4de1746b37a8d173b6b14f3db9ff" }, + { url = "https://mirrors.aliyun.com/pypi/packages/62/83/b3efc285d4aadc1fa83db385ec64dcfa1707e890eb42f03b127d66ac1b7b/bcrypt-5.0.0-cp38-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:e3cf5b2560c7b5a142286f69bde914494b6d8f901aaa71e453078388a50881c4" }, + { url = "https://mirrors.aliyun.com/pypi/packages/95/7d/47ee337dacecde6d234890fe929936cb03ebc4c3a7460854bbd9c97780b8/bcrypt-5.0.0-cp38-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:f632fd56fc4e61564f78b46a2269153122db34988e78b6be8b32d28507b7eaeb" }, + { url = "https://mirrors.aliyun.com/pypi/packages/d6/3a/43d494dfb728f55f4e1cf8fd435d50c16a2d75493225b54c8d06122523c6/bcrypt-5.0.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:801cad5ccb6b87d1b430f183269b94c24f248dddbbc5c1f78b6ed231743e001c" }, + { url = "https://mirrors.aliyun.com/pypi/packages/55/ab/a0727a4547e383e2e22a630e0f908113db37904f58719dc48d4622139b5c/bcrypt-5.0.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:3cf67a804fc66fc217e6914a5635000259fbbbb12e78a99488e4d5ba445a71eb" }, + { url = "https://mirrors.aliyun.com/pypi/packages/1b/bb/461f352fdca663524b4643d8b09e8435b4990f17fbf4fea6bc2a90aa0cc7/bcrypt-5.0.0-cp38-abi3-win32.whl", hash = "sha256:3abeb543874b2c0524ff40c57a4e14e5d3a66ff33fb423529c88f180fd756538" }, + { url = "https://mirrors.aliyun.com/pypi/packages/41/aa/4190e60921927b7056820291f56fc57d00d04757c8b316b2d3c0d1d6da2c/bcrypt-5.0.0-cp38-abi3-win_amd64.whl", hash = "sha256:35a77ec55b541e5e583eb3436ffbbf53b0ffa1fa16ca6782279daf95d146dcd9" }, + { url = "https://mirrors.aliyun.com/pypi/packages/54/12/cd77221719d0b39ac0b55dbd39358db1cd1246e0282e104366ebbfb8266a/bcrypt-5.0.0-cp38-abi3-win_arm64.whl", hash = "sha256:cde08734f12c6a4e28dc6755cd11d3bdfea608d93d958fffbe95a7026ebe4980" }, + { url = "https://mirrors.aliyun.com/pypi/packages/5d/ba/2af136406e1c3839aea9ecadc2f6be2bcd1eff255bd451dd39bcf302c47a/bcrypt-5.0.0-cp39-abi3-macosx_10_12_universal2.whl", hash = "sha256:0c418ca99fd47e9c59a301744d63328f17798b5947b0f791e9af3c1c499c2d0a" }, + { url = "https://mirrors.aliyun.com/pypi/packages/ac/ee/2f4985dbad090ace5ad1f7dd8ff94477fe089b5fab2040bd784a3d5f187b/bcrypt-5.0.0-cp39-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ddb4e1500f6efdd402218ffe34d040a1196c072e07929b9820f363a1fd1f4191" }, + { url = "https://mirrors.aliyun.com/pypi/packages/e4/6e/b77ade812672d15cf50842e167eead80ac3514f3beacac8902915417f8b7/bcrypt-5.0.0-cp39-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7aeef54b60ceddb6f30ee3db090351ecf0d40ec6e2abf41430997407a46d2254" }, + { url = "https://mirrors.aliyun.com/pypi/packages/36/c4/ed00ed32f1040f7990dac7115f82273e3c03da1e1a1587a778d8cea496d8/bcrypt-5.0.0-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:f0ce778135f60799d89c9693b9b398819d15f1921ba15fe719acb3178215a7db" }, + { url = "https://mirrors.aliyun.com/pypi/packages/e7/c4/fa6e16145e145e87f1fa351bbd54b429354fd72145cd3d4e0c5157cf4c70/bcrypt-5.0.0-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:a71f70ee269671460b37a449f5ff26982a6f2ba493b3eabdd687b4bf35f875ac" }, + { url = "https://mirrors.aliyun.com/pypi/packages/24/b4/11f8a31d8b67cca3371e046db49baa7c0594d71eb40ac8121e2fc0888db0/bcrypt-5.0.0-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:f8429e1c410b4073944f03bd778a9e066e7fad723564a52ff91841d278dfc822" }, + { url = "https://mirrors.aliyun.com/pypi/packages/ac/31/79f11865f8078e192847d2cb526e3fa27c200933c982c5b2869720fa5fce/bcrypt-5.0.0-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:edfcdcedd0d0f05850c52ba3127b1fce70b9f89e0fe5ff16517df7e81fa3cbb8" }, + { url = "https://mirrors.aliyun.com/pypi/packages/d4/8d/5e43d9584b3b3591a6f9b68f755a4da879a59712981ef5ad2a0ac1379f7a/bcrypt-5.0.0-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:611f0a17aa4a25a69362dcc299fda5c8a3d4f160e2abb3831041feb77393a14a" }, + { url = "https://mirrors.aliyun.com/pypi/packages/89/48/44590e3fc158620f680a978aafe8f87a4c4320da81ed11552f0323aa9a57/bcrypt-5.0.0-cp39-abi3-musllinux_1_1_aarch64.whl", hash = "sha256:db99dca3b1fdc3db87d7c57eac0c82281242d1eabf19dcb8a6b10eb29a2e72d1" }, + { url = "https://mirrors.aliyun.com/pypi/packages/5f/85/e4fbfc46f14f47b0d20493669a625da5827d07e8a88ee460af6cd9768b44/bcrypt-5.0.0-cp39-abi3-musllinux_1_1_x86_64.whl", hash = "sha256:5feebf85a9cefda32966d8171f5db7e3ba964b77fdfe31919622256f80f9cf42" }, + { url = "https://mirrors.aliyun.com/pypi/packages/25/ae/479f81d3f4594456a01ea2f05b132a519eff9ab5768a70430fa1132384b1/bcrypt-5.0.0-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:3ca8a166b1140436e058298a34d88032ab62f15aae1c598580333dc21d27ef10" }, + { url = "https://mirrors.aliyun.com/pypi/packages/df/d2/36a086dee1473b14276cd6ea7f61aef3b2648710b5d7f1c9e032c29b859f/bcrypt-5.0.0-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:61afc381250c3182d9078551e3ac3a41da14154fbff647ddf52a769f588c4172" }, + { url = "https://mirrors.aliyun.com/pypi/packages/c0/f6/688d2cd64bfd0b14d805ddb8a565e11ca1fb0fd6817175d58b10052b6d88/bcrypt-5.0.0-cp39-abi3-win32.whl", hash = "sha256:64d7ce196203e468c457c37ec22390f1a61c85c6f0b8160fd752940ccfb3a683" }, + { url = "https://mirrors.aliyun.com/pypi/packages/9f/b9/9d9a641194a730bda138b3dfe53f584d61c58cd5230e37566e83ec2ffa0d/bcrypt-5.0.0-cp39-abi3-win_amd64.whl", hash = "sha256:64ee8434b0da054d830fa8e89e1c8bf30061d539044a39524ff7dec90481e5c2" }, + { url = "https://mirrors.aliyun.com/pypi/packages/27/44/d2ef5e87509158ad2187f4dd0852df80695bb1ee0cfe0a684727b01a69e0/bcrypt-5.0.0-cp39-abi3-win_arm64.whl", hash = "sha256:f2347d3534e76bf50bca5500989d6c1d05ed64b440408057a37673282c654927" }, +] + [[package]] name = "beartype" version = "0.22.9" @@ -837,15 +919,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9d/61/c59a849bd457c8a1b408ae828dbcc15e674962b5a29705e869e15b32bf25/biopython-1.86.tar.gz", hash = "sha256:93a50b586a4d2cec68ab2f99d03ef583c5761d8fba5535cb8e81da781d0d92ff" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/98/e2/199b8ccbd4b9bf234157db0668177b5b7784d62f29d9096fd0d3a70e3b86/biopython-1.86-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f8d372aae21d79b11613751c6ae23c88db0e94d25b7567b1f67aa0304fb61667" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d8/2f/1a7da2a55212b3d0a03866d22213f91273fee3722b5364575419fbe574a5/biopython-1.86-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:baf19d9237aaaa387a68f8f055f978af5c80338d7e037ab028e8d768928f1250" }, - { url = "https://mirrors.aliyun.com/pypi/packages/5b/e9/4057d4c2aa22ca25c180ecbed2ce9e7d65bf787999778bc63b41df0d03b5/biopython-1.86-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:04f9abdf6cbf0087850de5f8148da0d420c4cb87905bf4de3145ad24a8d55dcd" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a7/b2/3e6862720d7c51f0fbe7d6d25be72a95486779d9d98122283b4e8032fb40/biopython-1.86-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:187c3c24dd2255e7328f3e0523ab5d6350b73ff562517de0c1922385617101d2" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d7/cb/61877367bf08670573d62513b239dc65cf2b7488dc74322cc6051da2e55e/biopython-1.86-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1859830b8262785c6b59dfe0c82cddb643974f63b9d2779bb9f3e2c47c0a95da" }, - { url = "https://mirrors.aliyun.com/pypi/packages/84/1a/3182a77776b76f3f5c64825ee1acf9355f665bed72ee9e8ff49e48f25d98/biopython-1.86-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dfd906c47b6fb38e3abb9f52e0c06822e6e82a043d38c2000773692c29db1ed8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1a/22/828b08fac8dbc8c1dbc1ad03815137cebc9c78303ec7d21b568544028119/biopython-1.86-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4a6ab2c60742f1c8494cfbbe3b7a8b45f0400c8f2b36b686b895d5e4d625f04e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/36/7a/122aea7653fa93d7eb72978928e80759082efffa70afe0c25a17e18521da/biopython-1.86-cp312-cp312-win32.whl", hash = "sha256:192c61bc3d782c171b7d50bb7d8189d84790d6e3c4b24fd41d1d7ffc7d303efe" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a9/13/00db03b01e54070d5b0ec9c71eef86e61afa733d9af76e5b9b09f5dc9165/biopython-1.86-cp312-cp312-win_amd64.whl", hash = "sha256:35a6b9c5dcdfb5c2631a313a007f3f41a7d72573ba2b68c962e10ea92096ff3b" }, { url = "https://mirrors.aliyun.com/pypi/packages/fd/6e/84d6c66ab93095aa7adb998a8eef045328470eafd36b9237c4db213e587c/biopython-1.86-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fb3a11a98e49428720dca227e2a5bdd57c973ee7c4df3cf6734c0aa13fd134c7" }, { url = "https://mirrors.aliyun.com/pypi/packages/12/75/60386f2640f13765b1651f2f26d8b4f893c46ee663df3ca76eda966d4f6a/biopython-1.86-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e161f3d3b6e65fbfd1ce22a01c3e9fa9da789adde4972fd0cc2370795ea5357b" }, { url = "https://mirrors.aliyun.com/pypi/packages/dd/de/a39adb98a0552a257219503c236ef17f007598af55326c0d143db52e5a92/biopython-1.86-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5aa8c9e92ee6fe59dfe0d2c2daf9a9eec6b812c78328caad038f79163c500218" }, @@ -890,6 +963,31 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/10/cb/f2ad4230dc2eb1a74edf38f1a38b9b52277f75bef262d8908e60d957e13c/blinker-1.9.0-py3-none-any.whl", hash = "sha256:ba0efaa9080b619ff2f3459d1d500c57bddea4a6b424b60a91141db6fd2f08bc" }, ] +[[package]] +name = "blis" +version = "1.3.3" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "numpy" }, +] +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/d0/d0/d8cc8c9a4488a787e7fa430f6055e5bd1ddb22c340a751d9e901b82e2efe/blis-1.3.3.tar.gz", hash = "sha256:034d4560ff3cc43e8aa37e188451b0440e3261d989bb8a42ceee865607715ecd" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/e6/f7/d26e62d9be3d70473a63e0a5d30bae49c2fe138bebac224adddcdef8a7ce/blis-1.3.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:1e647341f958421a86b028a2efe16ce19c67dba2a05f79e8f7e80b1ff45328aa" }, + { url = "https://mirrors.aliyun.com/pypi/packages/4a/78/750d12da388f714958eb2f2fd177652323bbe7ec528365c37129edd6eb84/blis-1.3.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d563160f874abb78a57e346f07312c5323f7ad67b6370052b6b17087ef234a8e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/e8/36/eac4199c5b200a5f3e93cad197da8d26d909f218eb444c4f552647c95240/blis-1.3.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:30b8a5b90cb6cb81d1ada9ae05aa55fb8e70d9a0ae9db40d2401bb9c1c8f14c4" }, + { url = "https://mirrors.aliyun.com/pypi/packages/bf/51/472e7b36a6bedb5242a9757e7486f702c3619eff76e256735d0c8b1679c6/blis-1.3.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9f5c53b277f6ac5b3ca30bc12ebab7ea16c8f8c36b14428abb56924213dc127" }, + { url = "https://mirrors.aliyun.com/pypi/packages/84/da/d0dfb6d6e6321ae44df0321384c32c322bd07b15740d7422727a1a49fc5d/blis-1.3.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6297e7616c158b305c9a8a4e47ca5fc9b0785194dd96c903b1a1591a7ca21ddf" }, + { url = "https://mirrors.aliyun.com/pypi/packages/20/c5/2b0b5e556fa0364ed671051ea078a6d6d7b979b1cfef78d64ad3ca5f0c7f/blis-1.3.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:3f966ca74f89f8a33e568b9a1d71992fc9a0d29a423e047f0a212643e21b5458" }, + { url = "https://mirrors.aliyun.com/pypi/packages/31/07/4cdc81a47bf862c0b06d91f1bc6782064e8b69ac9b5d4ff51d97e4ff03da/blis-1.3.3-cp313-cp313-win_amd64.whl", hash = "sha256:7a0fc4b237a3a453bdc3c7ab48d91439fcd2d013b665c46948d9eaf9c3e45a97" }, + { url = "https://mirrors.aliyun.com/pypi/packages/5f/8a/80f7c68fbc24a76fc9c18522c46d6d69329c320abb18e26a707a5d874083/blis-1.3.3-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:c3e33cfbf22a418373766816343fcfcd0556012aa3ffdf562c29cddec448a415" }, + { url = "https://mirrors.aliyun.com/pypi/packages/e5/52/d1aa3a51a7fc299b0c89dcaa971922714f50b1202769eebbdaadd1b5cff7/blis-1.3.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:6f165930e8d3a85c606d2003211497e28d528c7416fbfeafb6b15600963f7c9b" }, + { url = "https://mirrors.aliyun.com/pypi/packages/99/4f/badc7bd7f74861b26c10123bba7b9d16f99cd9535ad0128780360713820f/blis-1.3.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:878d4d96d8f2c7a2459024f013f2e4e5f46d708b23437dae970d998e7bff14a0" }, + { url = "https://mirrors.aliyun.com/pypi/packages/72/a6/f62a3bd814ca19ec7e29ac889fd354adea1217df3183e10217de51e2eb8b/blis-1.3.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f36c0ca84a05ee5d3dbaa38056c4423c1fc29948b17a7923dd2fed8967375d74" }, + { url = "https://mirrors.aliyun.com/pypi/packages/d4/6c/671af79ee42bc4c968cae35c091ac89e8721c795bfa4639100670dc59139/blis-1.3.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:e5a662c48cd4aad5dae1a950345df23957524f071315837a4c6feb7d3b288990" }, + { url = "https://mirrors.aliyun.com/pypi/packages/be/92/7cd7f8490da7c98ee01557f2105885cc597217b0e7fd2eeb9e22cdd4ef23/blis-1.3.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:9de26fbd72bac900c273b76d46f0b45b77a28eace2e01f6ac6c2239531a413bb" }, + { url = "https://mirrors.aliyun.com/pypi/packages/0a/de/acae8e9f9a1f4bb393d41c8265898b0f29772e38eac14e9f69d191e2c006/blis-1.3.3-cp314-cp314-win_amd64.whl", hash = "sha256:9e5fdf4211b1972400f8ff6dafe87cb689c5d84f046b4a76b207c0bd2270faaf" }, +] + [[package]] name = "boto3" version = "1.42.74" @@ -937,16 +1035,6 @@ version = "1.2.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/f7/16/c92ca344d646e71a43b8bb353f0a6490d7f6e06210f8554c8f874e454285/brotli-1.2.0.tar.gz", hash = "sha256:e310f77e41941c13340a95976fe66a8a95b01e783d430eeaf7a2f87e0a57dd0a" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/11/ee/b0a11ab2315c69bb9b45a2aaed022499c9c24a205c3a49c3513b541a7967/brotli-1.2.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:35d382625778834a7f3061b15423919aa03e4f5da34ac8e02c074e4b75ab4f84" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e1/2f/29c1459513cd35828e25531ebfcbf3e92a5e49f560b1777a9af7203eb46e/brotli-1.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7a61c06b334bd99bc5ae84f1eeb36bfe01400264b3c352f968c6e30a10f9d08b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/3d/6f/feba03130d5fceadfa3a1bb102cb14650798c848b1df2a808356f939bb16/brotli-1.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:acec55bb7c90f1dfc476126f9711a8e81c9af7fb617409a9ee2953115343f08d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2b/38/f3abb554eee089bd15471057ba85f47e53a44a462cfce265d9bf7088eb09/brotli-1.2.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:260d3692396e1895c5034f204f0db022c056f9e2ac841593a4cf9426e2a3faca" }, - { url = "https://mirrors.aliyun.com/pypi/packages/03/a7/03aa61fbc3c5cbf99b44d158665f9b0dd3d8059be16c460208d9e385c837/brotli-1.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:072e7624b1fc4d601036ab3f4f27942ef772887e876beff0301d261210bca97f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/21/1b/0374a89ee27d152a5069c356c96b93afd1b94eae83f1e004b57eb6ce2f10/brotli-1.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:adedc4a67e15327dfdd04884873c6d5a01d3e3b6f61406f99b1ed4865a2f6d28" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cf/57/69d4fe84a67aef4f524dcd075c6eee868d7850e85bf01d778a857d8dbe0a/brotli-1.2.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:7a47ce5c2288702e09dc22a44d0ee6152f2c7eda97b3c8482d826a1f3cfc7da7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d5/3b/39e13ce78a8e9a621c5df3aeb5fd181fcc8caba8c48a194cd629771f6828/brotli-1.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:af43b8711a8264bb4e7d6d9a6d004c3a2019c04c01127a868709ec29962b6036" }, - { url = "https://mirrors.aliyun.com/pypi/packages/62/28/4d00cb9bd76a6357a66fcd54b4b6d70288385584063f4b07884c1e7286ac/brotli-1.2.0-cp312-cp312-win32.whl", hash = "sha256:e99befa0b48f3cd293dafeacdd0d191804d105d279e0b387a32054c1180f3161" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1c/4e/bc1dcac9498859d5e353c9b153627a3752868a9d5f05ce8dedd81a2354ab/brotli-1.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:b35c13ce241abdd44cb8ca70683f20c0c079728a36a996297adb5334adfc1c44" }, { url = "https://mirrors.aliyun.com/pypi/packages/6c/d4/4ad5432ac98c73096159d9ce7ffeb82d151c2ac84adcc6168e476bb54674/brotli-1.2.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:9e5825ba2c9998375530504578fd4d5d1059d09621a02065d1b6bfc41a8e05ab" }, { url = "https://mirrors.aliyun.com/pypi/packages/91/9f/9cc5bd03ee68a85dc4bc89114f7067c056a3c14b3d95f171918c088bf88d/brotli-1.2.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0cf8c3b8ba93d496b2fae778039e2f5ecc7cff99df84df337ca31d8f2252896c" }, { url = "https://mirrors.aliyun.com/pypi/packages/2e/b6/fe84227c56a865d16a6614e2c4722864b380cb14b13f3e6bef441e73a85a/brotli-1.2.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c8565e3cdc1808b1a34714b553b262c5de5fbda202285782173ec137fd13709f" }, @@ -999,6 +1087,15 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/da/ff/3f0982ecd37c2d6a7266c22e7ea2e47d0773fe449984184c5316459d2776/captcha-0.7.1-py3-none-any.whl", hash = "sha256:8b73b5aba841ad1e5bdb856205bf5f09560b728ee890eb9dae42901219c8c599" }, ] +[[package]] +name = "catalogue" +version = "2.0.10" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/38/b4/244d58127e1cdf04cf2dc7d9566f0d24ef01d5ce21811bab088ecc62b5ea/catalogue-2.0.10.tar.gz", hash = "sha256:4f56daa940913d3f09d589c191c74e5a6d51762b3a9e37dd53b7437afd6cda15" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/9e/96/d32b941a501ab566a16358d68b6eb4e4acc373fab3c3c4d7d9e649f7b4bb/catalogue-2.0.10-py3-none-any.whl", hash = "sha256:58c2de0020aa90f4a2da7dfad161bf7b3b054c86a5f09fcedc0b2b740c109a9f" }, +] + [[package]] name = "cattrs" version = "22.2.0" @@ -1023,13 +1120,6 @@ version = "5.9.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/bd/cb/09939728be094d155b5d4ac262e39877875f5f7e36eea66beb359f647bd0/cbor2-5.9.0.tar.gz", hash = "sha256:85c7a46279ac8f226e1059275221e6b3d0e370d2bb6bd0500f9780781615bcea" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/ee/39/72d8a5a4b06565561ec28f4fcb41aff7bb77f51705c01f00b8254a2aca4f/cbor2-5.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1f223dffb1bcdd2764665f04c1152943d9daa4bc124a576cd8dee1cad4264313" }, - { url = "https://mirrors.aliyun.com/pypi/packages/09/fd/7ddf3d3153b54c69c3be77172b8d9aa3a9d74f62a7fbde614d53eaeed9a4/cbor2-5.9.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ae6c706ac1d85a0b3cb3395308fd0c4d55e3202b4760773675957e93cdff45fc" }, - { url = "https://mirrors.aliyun.com/pypi/packages/db/9d/7ede2cc42f9bb4260492e7d29d2aab781eacbbcfb09d983de1e695077199/cbor2-5.9.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4cd43d8fc374b31643b2830910f28177a606a7bc84975a62675dd3f2e320fc7b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ce/9d/588ebc7c5bc5843f609b05fe07be8575c7dec987735b0bbc908ac9c1264a/cbor2-5.9.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4aa07b392cc3d76fb31c08a46a226b58c320d1c172ff3073e864409ced7bc50f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f7/a1/6fc8f4b15c6a27e7fbb7966c30c2b4b18c274a3221fa2f5e6235502d34bc/cbor2-5.9.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:971d425b3a23b75953d8853d5f9911bdeefa09d759ee3b5e6b07b5ff3cbd9073" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cf/20/9a22cfe08be16ddfeef2542cf4eeed1b29f3f57ddbba0b42f7e0bb8331fd/cbor2-5.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:34a6cb15e6ab6a8eae94ad2041731cd3ef786af43a8df99f847969af5b902ee7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c6/9e/695f92d09006614034e25a9f5b10620f3b219f79c1bec3c37b7c6f27a7a9/cbor2-5.9.0-cp312-cp312-win_arm64.whl", hash = "sha256:7d1ddc4541e7367ac58c2470cc0df847f7137167fe4f5729e2d3cc0b993d7da4" }, { url = "https://mirrors.aliyun.com/pypi/packages/81/c5/4901e21a8afe9448fd947b11e8f383903207cd6dd0800e5f5a386838de5b/cbor2-5.9.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:fbb06f34aa645b4deca66643bba3d400d20c15312d1fe88d429be60c1ab50f27" }, { url = "https://mirrors.aliyun.com/pypi/packages/1b/10/df643a381aebc3f05486de4813662bc58accb640fc3275cb276a75e89694/cbor2-5.9.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ac684fe195c39821fca70d18afbf748f728aefbfbf88456018d299e559b8cae0" }, { url = "https://mirrors.aliyun.com/pypi/packages/c6/0c/8aa6b766059ae4a0ca1ec3ff96fe3823a69a7be880dba2e249f7fbe2700b/cbor2-5.9.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2a54fbb32cb828c214f7f333a707e4aec61182e7efdc06ea5d9596d3ecee624a" }, @@ -1065,18 +1155,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/eb/56/b1ba7935a17738ae8453301356628e8147c79dbb825bcbc73dc7401f9846/cffi-2.0.0.tar.gz", hash = "sha256:44d1b5909021139fe36001ae048dbdde8214afa20200eda0f64c068cac5d5529" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/ea/47/4f61023ea636104d4f16ab488e268b93008c3d0bb76893b1b31db1f96802/cffi-2.0.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6d02d6655b0e54f54c4ef0b94eb6be0607b70853c45ce98bd278dc7de718be5d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/df/a2/781b623f57358e360d62cdd7a8c681f074a71d445418a776eef0aadb4ab4/cffi-2.0.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8eca2a813c1cb7ad4fb74d368c2ffbbb4789d377ee5bb8df98373c2cc0dee76c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ff/df/a4f0fbd47331ceeba3d37c2e51e9dfc9722498becbeec2bd8bc856c9538a/cffi-2.0.0-cp312-cp312-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:21d1152871b019407d8ac3985f6775c079416c282e431a4da6afe7aefd2bccbe" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d5/72/12b5f8d3865bf0f87cf1404d8c374e7487dcf097a1c91c436e72e6badd83/cffi-2.0.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b21e08af67b8a103c71a250401c78d5e0893beff75e28c53c98f4de42f774062" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c2/95/7a135d52a50dfa7c882ab0ac17e8dc11cec9d55d2c18dda414c051c5e69e/cffi-2.0.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:1e3a615586f05fc4065a8b22b8152f0c1b00cdbc60596d187c2a74f9e3036e4e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/3a/c8/15cb9ada8895957ea171c62dc78ff3e99159ee7adb13c0123c001a2546c1/cffi-2.0.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:81afed14892743bbe14dacb9e36d9e0e504cd204e0b165062c488942b9718037" }, - { url = "https://mirrors.aliyun.com/pypi/packages/78/2d/7fa73dfa841b5ac06c7b8855cfc18622132e365f5b81d02230333ff26e9e/cffi-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3e17ed538242334bf70832644a32a7aae3d83b57567f9fd60a26257e992b79ba" }, - { url = "https://mirrors.aliyun.com/pypi/packages/07/e0/267e57e387b4ca276b90f0434ff88b2c2241ad72b16d31836adddfd6031b/cffi-2.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3925dd22fa2b7699ed2617149842d2e6adde22b262fcbfada50e3d195e4b3a94" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b6/75/1f2747525e06f53efbd878f4d03bac5b859cbc11c633d0fb81432d98a795/cffi-2.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2c8f814d84194c9ea681642fd164267891702542f028a15fc97d4674b6206187" }, - { url = "https://mirrors.aliyun.com/pypi/packages/7b/2b/2b6435f76bfeb6bbf055596976da087377ede68df465419d192acf00c437/cffi-2.0.0-cp312-cp312-win32.whl", hash = "sha256:da902562c3e9c550df360bfa53c035b2f241fed6d9aef119048073680ace4a18" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f8/ed/13bd4418627013bec4ed6e54283b1959cf6db888048c7cf4b4c3b5b36002/cffi-2.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:da68248800ad6320861f129cd9c1bf96ca849a2771a59e0344e88681905916f5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/95/31/9f7f93ad2f8eff1dbc1c3656d7ca5bfd8fb52c9d786b4dcf19b2d02217fa/cffi-2.0.0-cp312-cp312-win_arm64.whl", hash = "sha256:4671d9dd5ec934cb9a73e7ee9676f9362aba54f7f34910956b84d727b0d73fb6" }, { url = "https://mirrors.aliyun.com/pypi/packages/4b/8d/a0a47a0c9e413a658623d014e91e74a50cdd2c423f7ccfd44086ef767f90/cffi-2.0.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:00bdf7acc5f795150faa6957054fbbca2439db2f775ce831222b66f192f03beb" }, { url = "https://mirrors.aliyun.com/pypi/packages/4a/d2/a6c0296814556c68ee32009d9c2ad4f85f2707cdecfd7727951ec228005d/cffi-2.0.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:45d5e886156860dc35862657e1494b9bae8dfa63bf56796f2fb56e1679fc0bca" }, { url = "https://mirrors.aliyun.com/pypi/packages/b0/1e/d22cc63332bd59b06481ceaac49d6c507598642e2230f201649058a7e704/cffi-2.0.0-cp313-cp313-manylinux1_i686.manylinux2014_i686.manylinux_2_17_i686.manylinux_2_5_i686.whl", hash = "sha256:07b271772c100085dd28b74fa0cd81c8fb1a3ba18b21e03d7c27f3436a10606b" }, @@ -1128,22 +1206,6 @@ version = "3.4.6" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/7b/60/e3bec1881450851b087e301bedc3daa9377a4d45f1c26aa90b0b235e38aa/charset_normalizer-3.4.6.tar.gz", hash = "sha256:1ae6b62897110aa7c79ea2f5dd38d1abca6db663687c0b1ad9aed6f6bae3d9d6" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/e5/62/c0815c992c9545347aeea7859b50dc9044d147e2e7278329c6e02ac9a616/charset_normalizer-3.4.6-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:2ef7fedc7a6ecbe99969cd09632516738a97eeb8bd7258bf8a0f23114c057dab" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a8/37/bdca6613c2e3c58c7421891d80cc3efa1d32e882f7c4a7ee6039c3fc951a/charset_normalizer-3.4.6-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a4ea868bc28109052790eb2b52a9ab33f3aa7adc02f96673526ff47419490e21" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6c/92/9934d1bbd69f7f398b38c5dae1cbf9cc672e7c34a4adf7b17c0a9c17d15d/charset_normalizer-3.4.6-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:836ab36280f21fc1a03c99cd05c6b7af70d2697e374c7af0b61ed271401a72a2" }, - { url = "https://mirrors.aliyun.com/pypi/packages/af/90/25f6ab406659286be929fd89ab0e78e38aa183fc374e03aa3c12d730af8a/charset_normalizer-3.4.6-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f1ce721c8a7dfec21fcbdfe04e8f68174183cf4e8188e0645e92aa23985c57ff" }, - { url = "https://mirrors.aliyun.com/pypi/packages/4e/ef/79a463eb0fff7f96afa04c1d4c51f8fc85426f918db467854bfb6a569ce3/charset_normalizer-3.4.6-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0e28d62a8fc7a1fa411c43bd65e346f3bce9716dc51b897fbe930c5987b402d5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f7/72/d0426afec4b71dc159fa6b4e68f868cd5a3ecd918fec5813a15d292a7d10/charset_normalizer-3.4.6-cp312-cp312-manylinux_2_31_armv7l.whl", hash = "sha256:530d548084c4a9f7a16ed4a294d459b4f229db50df689bfe92027452452943a0" }, - { url = "https://mirrors.aliyun.com/pypi/packages/bf/18/c82b06a68bfcb6ce55e508225d210c7e6a4ea122bfc0748892f3dc4e8e11/charset_normalizer-3.4.6-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:30f445ae60aad5e1f8bdbb3108e39f6fbc09f4ea16c815c66578878325f8f15a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/44/d6/0c25979b92f8adafdbb946160348d8d44aa60ce99afdc27df524379875cb/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ac2393c73378fea4e52aa56285a3d64be50f1a12395afef9cce47772f60334c2" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2e/3d/7fea3e8fe84136bebbac715dd1221cc25c173c57a699c030ab9b8900cbb7/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:90ca27cd8da8118b18a52d5f547859cc1f8354a00cd1e8e5120df3e30d6279e5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/57/8a/d6f7fd5cb96c58ef2f681424fbca01264461336d2a7fc875e4446b1f1346/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:8e5a94886bedca0f9b78fecd6afb6629142fd2605aa70a125d49f4edc6037ee6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/16/50/478cdda782c8c9c3fb5da3cc72dd7f331f031e7f1363a893cdd6ca0f8de0/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:695f5c2823691a25f17bc5d5ffe79fa90972cc34b002ac6c843bb8a1720e950d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/75/fc/cc2fcac943939c8e4d8791abfa139f685e5150cae9f94b60f12520feaa9b/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:231d4da14bcd9301310faf492051bee27df11f2bc7549bc0bb41fef11b82daa2" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a8/b7/a4add1d9a5f68f3d037261aecca83abdb0ab15960a3591d340e829b37298/charset_normalizer-3.4.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a056d1ad2633548ca18ffa2f85c202cfb48b68615129143915b8dc72a806a923" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6c/18/c094561b5d64a24277707698e54b7f67bd17a4f857bbfbb1072bba07c8bf/charset_normalizer-3.4.6-cp312-cp312-win32.whl", hash = "sha256:c2274ca724536f173122f36c98ce188fd24ce3dad886ec2b7af859518ce008a4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ab/20/0567efb3a8fd481b8f34f739ebddc098ed062a59fed41a8d193a61939e8f/charset_normalizer-3.4.6-cp312-cp312-win_amd64.whl", hash = "sha256:c8ae56368f8cc97c7e40a7ee18e1cedaf8e780cd8bc5ed5ac8b81f238614facb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/15/57/28d79b44b51933119e21f65479d0864a8d5893e494cf5daab15df0247c17/charset_normalizer-3.4.6-cp312-cp312-win_arm64.whl", hash = "sha256:899d28f422116b08be5118ef350c292b36fc15ec2daeb9ea987c89281c7bb5c4" }, { url = "https://mirrors.aliyun.com/pypi/packages/1e/1d/4fdabeef4e231153b6ed7567602f3b68265ec4e5b76d6024cf647d43d981/charset_normalizer-3.4.6-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:11afb56037cbc4b1555a34dd69151e8e069bee82e613a73bef6e714ce733585f" }, { url = "https://mirrors.aliyun.com/pypi/packages/47/7b/20e809b89c69d37be748d98e84dce6820bf663cf19cf6b942c951a3e8f41/charset_normalizer-3.4.6-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:423fb7e748a08f854a08a222b983f4df1912b1daedce51a72bd24fe8f26a1843" }, { url = "https://mirrors.aliyun.com/pypi/packages/37/a6/4f8d27527d59c039dce6f7622593cdcd3d70a8504d87d09eb11e9fdc6062/charset_normalizer-3.4.6-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d73beaac5e90173ac3deb9928a74763a6d230f494e4bfb422c217a0ad8e629bf" }, @@ -1219,6 +1281,15 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/ae/5a/4f025bc751087833686892e17e7564828e409c43b632878afeae554870cd/click_log-0.4.0-py2.py3-none-any.whl", hash = "sha256:a43e394b528d52112af599f2fc9e4b7cf3c15f94e53581f74fa6867e68c91756" }, ] +[[package]] +name = "cloudpathlib" +version = "0.24.0" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/06/19/58bc6b5d7d0f81c7209b05445af477e147c486552f96665a5912211839b9/cloudpathlib-0.24.0.tar.gz", hash = "sha256:c521a984e77b47e656fe78e20a7e3e260e0ab45fc69e33ac01094227c979e34a" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/c2/5b/ba933f896d9b0b07608d575a8501e2b4e32166b60d84c430a4a7285ebe64/cloudpathlib-0.24.0-py3-none-any.whl", hash = "sha256:b1c51e2d2ec7dc4fed6538991f4aea849d6cf11a7e6b9069f86e461aa1f9b5b4" }, +] + [[package]] name = "cn2an" version = "0.5.22" @@ -1314,6 +1385,15 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/07/1d/62f5bf92e12335eb63517f42671ed78512d48bbc69e02a942dd7b90f03f0/compressed_rtf-1.0.7-py3-none-any.whl", hash = "sha256:b7904921d78c67a0a4b7fff9fb361a00ae2b447b6edca010ce321cd98fa0fcc0" }, ] +[[package]] +name = "confection" +version = "1.3.3" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/ca/65/efd0fe8a936fc8ca2978cb7b82581fb20d901c6039e746a808f746b7647b/confection-1.3.3.tar.gz", hash = "sha256:f0f6810d567ff73993fe74d218ca5e1ffb6a44fb03f391257fc5d033546cbfaa" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/8d/e4/d66708bdf0d92fb4d49b22cdff4b10cec38aca5dcd7e81d909bb55c65cd7/confection-1.3.3-py3-none-any.whl", hash = "sha256:b9fef9ee84b237ef4611ec3eb5797b70e13063e6310ad9f15536373f5e313c82" }, +] + [[package]] name = "contourpy" version = "1.3.3" @@ -1323,17 +1403,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/58/01/1253e6698a07380cd31a736d248a3f2a50a7c88779a1813da27503cadc2a/contourpy-1.3.3.tar.gz", hash = "sha256:083e12155b210502d0bca491432bb04d56dc3432f95a979b429f2848c3dbe880" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/be/45/adfee365d9ea3d853550b2e735f9d66366701c65db7855cd07621732ccfc/contourpy-1.3.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b08a32ea2f8e42cf1d4be3169a98dd4be32bafe4f22b6c4cb4ba810fa9e5d2cb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/53/3e/405b59cfa13021a56bba395a6b3aca8cec012b45bf177b0eaf7a202cde2c/contourpy-1.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:556dba8fb6f5d8742f2923fe9457dbdd51e1049c4a43fd3986a0b14a1d815fc6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d4/1c/a12359b9b2ca3a845e8f7f9ac08bdf776114eb931392fcad91743e2ea17b/contourpy-1.3.3-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:92d9abc807cf7d0e047b95ca5d957cf4792fcd04e920ca70d48add15c1a90ea7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/63/12/897aeebfb475b7748ea67b61e045accdfcf0d971f8a588b67108ed7f5512/contourpy-1.3.3-cp312-cp312-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b2e8faa0ed68cb29af51edd8e24798bb661eac3bd9f65420c1887b6ca89987c8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/43/8a/a8c584b82deb248930ce069e71576fc09bd7174bbd35183b7943fb1064fd/contourpy-1.3.3-cp312-cp312-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:626d60935cf668e70a5ce6ff184fd713e9683fb458898e4249b63be9e28286ea" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cc/8f/ec6289987824b29529d0dfda0d74a07cec60e54b9c92f3c9da4c0ac732de/contourpy-1.3.3-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4d00e655fcef08aba35ec9610536bfe90267d7ab5ba944f7032549c55a146da1" }, - { url = "https://mirrors.aliyun.com/pypi/packages/05/0a/a3fe3be3ee2dceb3e615ebb4df97ae6f3828aa915d3e10549ce016302bd1/contourpy-1.3.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:451e71b5a7d597379ef572de31eeb909a87246974d960049a9848c3bc6c41bf7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/33/1d/acad9bd4e97f13f3e2b18a3977fe1b4a37ecf3d38d815333980c6c72e963/contourpy-1.3.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:459c1f020cd59fcfe6650180678a9993932d80d44ccde1fa1868977438f0b411" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cf/8f/5847f44a7fddf859704217a99a23a4f6417b10e5ab1256a179264561540e/contourpy-1.3.3-cp312-cp312-win32.whl", hash = "sha256:023b44101dfe49d7d53932be418477dba359649246075c996866106da069af69" }, - { url = "https://mirrors.aliyun.com/pypi/packages/19/e8/6026ed58a64563186a9ee3f29f41261fd1828f527dd93d33b60feca63352/contourpy-1.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:8153b8bfc11e1e4d75bcb0bff1db232f9e10b274e0929de9d608027e0d34ff8b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d1/e2/f05240d2c39a1ed228d8328a78b6f44cd695f7ef47beb3e684cf93604f86/contourpy-1.3.3-cp312-cp312-win_arm64.whl", hash = "sha256:07ce5ed73ecdc4a03ffe3e1b3e3c1166db35ae7584be76f65dbbe28a7791b0cc" }, { url = "https://mirrors.aliyun.com/pypi/packages/68/35/0167aad910bbdb9599272bd96d01a9ec6852f36b9455cf2ca67bd4cc2d23/contourpy-1.3.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:177fb367556747a686509d6fef71d221a4b198a3905fe824430e5ea0fda54eb5" }, { url = "https://mirrors.aliyun.com/pypi/packages/96/e4/7adcd9c8362745b2210728f209bfbcf7d91ba868a2c5f40d8b58f54c509b/contourpy-1.3.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:d002b6f00d73d69333dac9d0b8d5e84d9724ff9ef044fd63c5986e62b7c9e1b1" }, { url = "https://mirrors.aliyun.com/pypi/packages/73/23/90e31ceeed1de63058a02cb04b12f2de4b40e3bef5e082a7c18d9c8ae281/contourpy-1.3.3-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:348ac1f5d4f1d66d3322420f01d42e43122f43616e0f194fc1c9f5d830c5b286" }, @@ -1386,21 +1455,6 @@ version = "7.13.5" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9d/e0/70553e3000e345daff267cec284ce4cbf3fc141b6da229ac52775b5428f1/coverage-7.13.5.tar.gz", hash = "sha256:c81f6515c4c40141f83f502b07bbfa5c240ba25bbe73da7b33f1e5b6120ff179" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/a0/c3/a396306ba7db865bf96fc1fb3b7fd29bcbf3d829df642e77b13555163cd6/coverage-7.13.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:460cf0114c5016fa841214ff5564aa4864f11948da9440bc97e21ad1f4ba1e01" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a6/16/a68a19e5384e93f811dccc51034b1fd0b865841c390e3c931dcc4699e035/coverage-7.13.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0e223ce4b4ed47f065bfb123687686512e37629be25cc63728557ae7db261422" }, - { url = "https://mirrors.aliyun.com/pypi/packages/29/72/20b917c6793af3a5ceb7fb9c50033f3ec7865f2911a1416b34a7cfa0813b/coverage-7.13.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:6e3370441f4513c6252bf042b9c36d22491142385049243253c7e48398a15a9f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8c/49/cd14b789536ac6a4778c453c6a2338bc0a2fb60c5a5a41b4008328b9acc1/coverage-7.13.5-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:03ccc709a17a1de074fb1d11f217342fb0d2b1582ed544f554fc9fc3f07e95f5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9d/00/7b0edcfe64e2ed4c0340dac14a52ad0f4c9bd0b8b5e531af7d55b703db7c/coverage-7.13.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3f4818d065964db3c1c66dc0fbdac5ac692ecbc875555e13374fdbe7eedb4376" }, - { url = "https://mirrors.aliyun.com/pypi/packages/93/89/7ffc4ba0f5d0a55c1e84ea7cee39c9fc06af7b170513d83fbf3bbefce280/coverage-7.13.5-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:012d5319e66e9d5a218834642d6c35d265515a62f01157a45bcc036ecf947256" }, - { url = "https://mirrors.aliyun.com/pypi/packages/81/bd/73ddf85f93f7e6fa83e77ccecb6162d9415c79007b4bc124008a4995e4a7/coverage-7.13.5-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:8dd02af98971bdb956363e4827d34425cb3df19ee550ef92855b0acb9c7ce51c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a0/81/278aff4e8dec4926a0bcb9486320752811f543a3ce5b602cc7a29978d073/coverage-7.13.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:f08fd75c50a760c7eb068ae823777268daaf16a80b918fa58eea888f8e3919f5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/70/ee/fe1621488e2e0a58d7e94c4800f0d96f79671553488d401a612bebae324b/coverage-7.13.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:843ea8643cf967d1ac7e8ecd4bb00c99135adf4816c0c0593fdcc47b597fcf09" }, - { url = "https://mirrors.aliyun.com/pypi/packages/37/a6/f79fb37aa104b562207cc23cb5711ab6793608e246cae1e93f26b2236ed9/coverage-7.13.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:9d44d7aa963820b1b971dbecd90bfe5fe8f81cff79787eb6cca15750bd2f79b9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/75/f0/ed15262a58ec81ce457ceb717b7f78752a1713556b19081b76e90896e8d4/coverage-7.13.5-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:7132bed4bd7b836200c591410ae7d97bf7ae8be6fc87d160b2bd881df929e7bf" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0f/e9/9129958f20e7e9d4d56d51d42ccf708d15cac355ff4ac6e736e97a9393d2/coverage-7.13.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:a698e363641b98843c517817db75373c83254781426e94ada3197cabbc2c919c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a4/d7/0ad9b15812d81272db94379fe4c6df8fd17781cc7671fdfa30c76ba5ff7b/coverage-7.13.5-cp312-cp312-win32.whl", hash = "sha256:bdba0a6b8812e8c7df002d908a9a2ea3c36e92611b5708633c50869e6d922fdf" }, - { url = "https://mirrors.aliyun.com/pypi/packages/29/3d/821a9a5799fac2556bcf0bd37a70d1d11fa9e49784b6d22e92e8b2f85f18/coverage-7.13.5-cp312-cp312-win_amd64.whl", hash = "sha256:d2c87e0c473a10bffe991502eac389220533024c8082ec1ce849f4218dded810" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d4/fa/2238c2ad08e35cf4f020ea721f717e09ec3152aea75d191a7faf3ef009a8/coverage-7.13.5-cp312-cp312-win_arm64.whl", hash = "sha256:bf69236a9a81bdca3bff53796237aab096cdbf8d78a66ad61e992d9dac7eb2de" }, { url = "https://mirrors.aliyun.com/pypi/packages/74/8c/74fedc9663dcf168b0a059d4ea756ecae4da77a489048f94b5f512a8d0b3/coverage-7.13.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5ec4af212df513e399cf11610cc27063f1586419e814755ab362e50a85ea69c1" }, { url = "https://mirrors.aliyun.com/pypi/packages/0c/c9/44fb661c55062f0818a6ffd2685c67aa30816200d5f2817543717d4b92eb/coverage-7.13.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:941617e518602e2d64942c88ec8499f7fbd49d3f6c4327d3a71d43a1973032f3" }, { url = "https://mirrors.aliyun.com/pypi/packages/5f/13/93419671cee82b780bab7ea96b67c8ef448f5f295f36bf5031154ec9a790/coverage-7.13.5-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:da305e9937617ee95c2e39d8ff9f040e0487cbf1ac174f777ed5eddd7a7c1f26" }, @@ -1470,21 +1524,6 @@ version = "2.11.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/14/12/34bf6e840a79130dfd0da7badfb6f7810b8fcfd60e75b0539372667b41b6/cramjam-2.11.0.tar.gz", hash = "sha256:5c82500ed91605c2d9781380b378397012e25127e89d64f460fea6aeac4389b4" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/0b/0d/7c84c913a5fae85b773a9dcf8874390f9d68ba0fcc6630efa7ff1541b950/cramjam-2.11.0-cp312-cp312-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:dba5c14b8b4f73ea1e65720f5a3fe4280c1d27761238378be8274135c60bbc6e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2b/cc/4f6d185d8a744776f53035e72831ff8eefc2354f46ab836f4bd3c4f6c138/cramjam-2.11.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:11eb40722b3fcf3e6890fba46c711bf60f8dc26360a24876c85e52d76c33b25b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1c/a8/626c76263085c6d5ded0e71823b411e9522bfc93ba6cc59855a5869296e7/cramjam-2.11.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aeb26e2898994b6e8319f19a4d37c481512acdcc6d30e1b5ecc9d8ec57e835cb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e9/52/0851a16a62447532e30ba95a80e638926fdea869a34b4b5b9d0a020083ba/cramjam-2.11.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.whl", hash = "sha256:4f8d82081ed7d8fe52c982bd1f06e4c7631a73fe1fb6d4b3b3f2404f87dc40fe" }, - { url = "https://mirrors.aliyun.com/pypi/packages/98/76/122e444f59dbc216451d8e3d8282c9665dc79eaf822f5f1470066be1b695/cramjam-2.11.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:092a3ec26e0a679305018380e4f652eae1b6dfe3fc3b154ee76aa6b92221a17c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a3/bc/3a0189aef1af2b29632c039c19a7a1b752bc21a4053582a5464183a0ad3d/cramjam-2.11.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:529d6d667c65fd105d10bd83d1cd3f9869f8fd6c66efac9415c1812281196a92" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2e/80/8a6343b13778ce52d94bb8d5365a30c3aa951276b1857201fe79d7e2ad25/cramjam-2.11.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:555eb9c90c450e0f76e27d9ff064e64a8b8c6478ab1a5594c91b7bc5c82fd9f0" }, - { url = "https://mirrors.aliyun.com/pypi/packages/df/6b/cd1778a207c29eda10791e3dfa018b588001928086e179fc71254793c625/cramjam-2.11.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5edf4c9e32493035b514cf2ba0c969d81ccb31de63bd05490cc8bfe3b431674e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/dc/f0/5c2a5cd5711032f3b191ca50cb786c17689b4a9255f9f768866e6c9f04d9/cramjam-2.11.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2fa2fe41f48c4d58d923803383b0737f048918b5a0d10390de9628bb6272b107" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f9/8b/b363a5fb2c3347504fe9a64f8d0f1e276844f0e532aa7162c061cd1ffee4/cramjam-2.11.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9ca14cf1cabdb0b77d606db1bb9e9ca593b1dbd421fcaf251ec9a5431ec449f3" }, - { url = "https://mirrors.aliyun.com/pypi/packages/78/7b/d83dad46adb6c988a74361f81ad9c5c22642be53ad88616a19baedd06243/cramjam-2.11.0-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:309e95bf898829476bccf4fd2c358ec00e7ff73a12f95a3cdeeba4bb1d3683d5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1a/be/60d9be4cb33d8740a4aa94c7513f2ef3c4eba4fd13536f086facbafade71/cramjam-2.11.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:86dca35d2f15ef22922411496c220f3c9e315d5512f316fe417461971cc1648d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/11/b0/4a595f01a243aec8ad272b160b161c44351190c35d98d7787919d962e9e5/cramjam-2.11.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:193c6488bd2f514cbc0bef5c18fad61a5f9c8d059dd56edf773b3b37f0e85496" }, - { url = "https://mirrors.aliyun.com/pypi/packages/38/47/7776659aaa677046b77f527106e53ddd47373416d8fcdb1e1a881ec5dc06/cramjam-2.11.0-cp312-cp312-win32.whl", hash = "sha256:514e2c008a8b4fa823122ca3ecab896eac41d9aa0f5fc881bd6264486c204e32" }, - { url = "https://mirrors.aliyun.com/pypi/packages/75/b1/d53002729cfd94c5844ddfaf1233c86d29f2dbfc1b764a6562c41c044199/cramjam-2.11.0-cp312-cp312-win_amd64.whl", hash = "sha256:53fed080476d5f6ad7505883ec5d1ec28ba36c2273db3b3e92d7224fe5e463db" }, { url = "https://mirrors.aliyun.com/pypi/packages/0a/8b/406c5dc0f8e82385519d8c299c40fd6a56d97eca3fcd6f5da8dad48de75b/cramjam-2.11.0-cp313-cp313-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:2c289729cc1c04e88bafa48b51082fb462b0a57dbc96494eab2be9b14dca62af" }, { url = "https://mirrors.aliyun.com/pypi/packages/00/ad/4186884083d6e4125b285903e17841827ab0d6d0cffc86216d27ed91e91d/cramjam-2.11.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:045201ee17147e36cf43d8ae2fa4b4836944ac672df5874579b81cf6d40f1a1f" }, { url = "https://mirrors.aliyun.com/pypi/packages/54/01/91b485cf76a7efef638151e8a7d35784dae2c4ff221b1aec2c083e4b106d/cramjam-2.11.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:619cd195d74c9e1d2a3ad78d63451d35379c84bd851aec552811e30842e1c67a" }, @@ -1582,15 +1621,6 @@ version = "2.8" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/e3/66/7e97aa77af7cf6afbff26e3651b564fe41932599bc2d3dce0b2f73d4829a/crc32c-2.8.tar.gz", hash = "sha256:578728964e59c47c356aeeedee6220e021e124b9d3e8631d95d9a5e5f06e261c" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/b6/36/fd18ef23c42926b79c7003e16cb0f79043b5b179c633521343d3b499e996/crc32c-2.8-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:572ffb1b78cce3d88e8d4143e154d31044a44be42cb3f6fbbf77f1e7a941c5ab" }, - { url = "https://mirrors.aliyun.com/pypi/packages/7f/b8/c584958e53f7798dd358f5bdb1bbfc97483134f053ee399d3eeb26cca075/crc32c-2.8-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:cf827b3758ee0c4aacd21ceca0e2da83681f10295c38a10bfeb105f7d98f7a68" }, - { url = "https://mirrors.aliyun.com/pypi/packages/62/e6/6f2af0ec64a668a46c861e5bc778ea3ee42171fedfc5440f791f470fd783/crc32c-2.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:106fbd79013e06fa92bc3b51031694fcc1249811ed4364ef1554ee3dd2c7f5a2" }, - { url = "https://mirrors.aliyun.com/pypi/packages/17/8b/4a04bd80a024f1a23978f19ae99407783e06549e361ab56e9c08bba3c1d3/crc32c-2.8-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:6dde035f91ffbfe23163e68605ee5a4bb8ceebd71ed54bb1fb1d0526cdd125a2" }, - { url = "https://mirrors.aliyun.com/pypi/packages/21/8f/01c7afdc76ac2007d0e6a98e7300b4470b170480f8188475b597d1f4b4c6/crc32c-2.8-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e41ebe7c2f0fdcd9f3a3fd206989a36b460b4d3f24816d53e5be6c7dba72c5e1" }, - { url = "https://mirrors.aliyun.com/pypi/packages/32/2b/8f78c5a8cc66486be5f51b6f038fc347c3ba748d3ea68be17a014283c331/crc32c-2.8-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ecf66cf90266d9c15cea597d5cc86c01917cd1a238dc3c51420c7886fa750d7e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/db/86/fad1a94cdeeeb6b6e2323c87f970186e74bfd6fbfbc247bf5c88ad0873d5/crc32c-2.8-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:59eee5f3a69ad0793d5fa9cdc9b9d743b0cd50edf7fccc0a3988a821fef0208c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d5/db/1a7cb6757a1e32376fa2dfce00c815ea4ee614a94f9bff8228e37420c183/crc32c-2.8-cp312-cp312-win32.whl", hash = "sha256:a73d03ce3604aa5d7a2698e9057a0eef69f529c46497b27ee1c38158e90ceb76" }, - { url = "https://mirrors.aliyun.com/pypi/packages/bf/8e/2024de34399b2e401a37dcb54b224b56c747b0dc46de4966886827b4d370/crc32c-2.8-cp312-cp312-win_amd64.whl", hash = "sha256:56b3b7d015247962cf58186e06d18c3d75a1a63d709d3233509e1c50a2d36aa2" }, { url = "https://mirrors.aliyun.com/pypi/packages/e8/d8/3ae227890b3be40955a7144106ef4dd97d6123a82c2a5310cdab58ca49d8/crc32c-2.8-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:36f1e03ee9e9c6938e67d3bcb60e36f260170aa5f37da1185e04ef37b56af395" }, { url = "https://mirrors.aliyun.com/pypi/packages/bd/8b/178d3f987cd0e049b484615512d3f91f3d2caeeb8ff336bb5896ae317438/crc32c-2.8-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b2f3226b94b85a8dd9b3533601d7a63e9e3e8edf03a8a169830ee8303a199aeb" }, { url = "https://mirrors.aliyun.com/pypi/packages/f2/a1/48145ae2545ebc0169d3283ebe882da580ea4606bfb67cf4ca922ac3cfc3/crc32c-2.8-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6e08628bc72d5b6bc8e0730e8f142194b610e780a98c58cb6698e665cb885a5b" }, @@ -1711,6 +1741,46 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/e7/05/c19819d5e3d95294a6f5947fb9b9629efb316b96de511b418c53d245aae6/cycler-0.12.1-py3-none-any.whl", hash = "sha256:85cef7cff222d8644161529808465972e51340599459b8ac3ccbac5a854e0d30" }, ] +[[package]] +name = "cymem" +version = "2.0.13" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/c0/8f/2f0fbb32535c3731b7c2974c569fb9325e0a38ed5565a08e1139a3b71e82/cymem-2.0.13.tar.gz", hash = "sha256:1c91a92ae8c7104275ac26bd4d29b08ccd3e7faff5893d3858cb6fadf1bc1588" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/ce/0f/95a4d1e3bebfdfa7829252369357cf9a764f67569328cd9221f21e2c952e/cymem-2.0.13-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:891fd9030293a8b652dc7fb9fdc79a910a6c76fc679cd775e6741b819ffea476" }, + { url = "https://mirrors.aliyun.com/pypi/packages/bf/a0/8fc929cc29ae466b7b4efc23ece99cbd3ea34992ccff319089c624d667fd/cymem-2.0.13-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:89c4889bd16513ce1644ccfe1e7c473ba7ca150f0621e66feac3a571bde09e7e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/4a/b3/deeb01354ebaf384438083ffe0310209ef903db3e7ba5a8f584b06d28387/cymem-2.0.13-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:45dcaba0f48bef9cc3d8b0b92058640244a95a9f12542210b51318da97c2cf28" }, + { url = "https://mirrors.aliyun.com/pypi/packages/36/36/bc980b9a14409f3356309c45a8d88d58797d02002a9d794dd6c84e809d3a/cymem-2.0.13-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e96848faaafccc0abd631f1c5fb194eac0caee4f5a8777fdbb3e349d3a21741c" }, + { url = "https://mirrors.aliyun.com/pypi/packages/fd/dd/a12522952624685bd0f8968e26d2ed6d059c967413ce6eb52292f538f1b0/cymem-2.0.13-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:e02d3e2c3bfeb21185d5a4a70790d9df40629a87d8d7617dc22b4e864f665fa3" }, + { url = "https://mirrors.aliyun.com/pypi/packages/08/11/5dc933ddfeb2dfea747a0b935cb965b9a7580b324d96fc5f5a1b5ff8df29/cymem-2.0.13-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:fece5229fd5ecdcd7a0738affb8c59890e13073ae5626544e13825f26c019d3c" }, + { url = "https://mirrors.aliyun.com/pypi/packages/70/66/d23b06166864fa94e13a98e5922986ce774832936473578febce64448d75/cymem-2.0.13-cp313-cp313-win_amd64.whl", hash = "sha256:38aefeb269597c1a0c2ddf1567dd8605489b661fa0369c6406c1acd433b4c7ba" }, + { url = "https://mirrors.aliyun.com/pypi/packages/2f/9e/c7b21271ab88a21760f3afdec84d2bc09ffa9e6c8d774ad9d4f1afab0416/cymem-2.0.13-cp313-cp313-win_arm64.whl", hash = "sha256:717270dcfd8c8096b479c42708b151002ff98e434a7b6f1f916387a6c791e2ad" }, + { url = "https://mirrors.aliyun.com/pypi/packages/7f/28/d3b03427edc04ae04910edf1c24b993881c3ba93a9729a42bcbb816a1808/cymem-2.0.13-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:7e1a863a7f144ffb345397813701509cfc74fc9ed360a4d92799805b4b865dd1" }, + { url = "https://mirrors.aliyun.com/pypi/packages/35/a9/7ed53e481f47ebfb922b0b42e980cec83e98ccb2137dc597ea156642440c/cymem-2.0.13-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:c16cb80efc017b054f78998c6b4b013cef509c7b3d802707ce1f85a1d68361bf" }, + { url = "https://mirrors.aliyun.com/pypi/packages/61/39/a3d6ad073cf7f0fbbb8bbf09698c3c8fac11be3f791d710239a4e8dd3438/cymem-2.0.13-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0d78a27c88b26c89bd1ece247d1d5939dba05a1dae6305aad8fd8056b17ddb51" }, + { url = "https://mirrors.aliyun.com/pypi/packages/36/0c/20697c8bc19f624a595833e566f37d7bcb9167b0ce69de896eba7cfc9c2d/cymem-2.0.13-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6d36710760f817194dacb09d9fc45cb6a5062ed75e85f0ef7ad7aeeb13d80cc3" }, + { url = "https://mirrors.aliyun.com/pypi/packages/82/d4/9326e3422d1c2d2b4a8fb859bdcce80138f6ab721ddafa4cba328a505c71/cymem-2.0.13-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:c8f30971cadd5dcf73bcfbbc5849b1f1e1f40db8cd846c4aa7d3b5e035c7b583" }, + { url = "https://mirrors.aliyun.com/pypi/packages/ed/bc/68da7dd749b72884dc22e898562f335002d70306069d496376e5ff3b6153/cymem-2.0.13-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:9d441d0e45798ec1fd330373bf7ffa6b795f229275f64016b6a193e6e2a51522" }, + { url = "https://mirrors.aliyun.com/pypi/packages/50/23/dbf2ad6ecd19b99b3aab6203b1a06608bbd04a09c522d836b854f2f30f73/cymem-2.0.13-cp313-cp313t-win_amd64.whl", hash = "sha256:d1c950eebb9f0f15e3ef3591313482a5a611d16fc12d545e2018cd607f40f472" }, + { url = "https://mirrors.aliyun.com/pypi/packages/54/3f/35701c13e1fc7b0895198c8b20068c569a841e0daf8e0b14d1dc0816b28f/cymem-2.0.13-cp313-cp313t-win_arm64.whl", hash = "sha256:042e8611ef862c34a97b13241f5d0da86d58aca3cecc45c533496678e75c5a1f" }, + { url = "https://mirrors.aliyun.com/pypi/packages/a7/2e/f0e1596010a9a57fa9ebd124a678c07c5b2092283781ae51e79edcf5cb98/cymem-2.0.13-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:d2a4bf67db76c7b6afc33de44fb1c318207c3224a30da02c70901936b5aafdf1" }, + { url = "https://mirrors.aliyun.com/pypi/packages/bc/45/8ccc21df08fcbfa6aa3efeb7efc11a1c81c90e7476e255768bb9c29ba02a/cymem-2.0.13-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:92a2ce50afa5625fb5ce7c9302cee61e23a57ccac52cd0410b4858e572f8614b" }, + { url = "https://mirrors.aliyun.com/pypi/packages/01/8c/fe16531631f051d3d1226fa42e2d76fd2c8d5cfa893ec93baee90c7a9d90/cymem-2.0.13-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:bc116a70cc3a5dc3d1684db5268eff9399a0be8603980005e5b889564f1ea42f" }, + { url = "https://mirrors.aliyun.com/pypi/packages/47/4b/39d67b80ffb260457c05fcc545de37d82e9e2dbafc93dd6b64f17e09b933/cymem-2.0.13-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:68489bf0035c4c280614067ab6a82815b01dc9fcd486742a5306fe9f68deb7ef" }, + { url = "https://mirrors.aliyun.com/pypi/packages/53/0e/76f6531f74dfdfe7107899cce93ab063bb7ee086ccd3910522b31f623c08/cymem-2.0.13-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:03cb7bdb55718d5eb6ef0340b1d2430ba1386db30d33e9134d01ba9d6d34d705" }, + { url = "https://mirrors.aliyun.com/pypi/packages/c7/7c/eee56757db81f0aefc2615267677ae145aff74228f529838425057003c0d/cymem-2.0.13-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:1710390e7fb2510a8091a1991024d8ae838fd06b02cdfdcd35f006192e3c6b0e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/77/e0/a4b58ec9e53c836dce07ef39837a64a599f4a21a134fc7ca57a3a8f9a4b5/cymem-2.0.13-cp314-cp314-win_amd64.whl", hash = "sha256:ac699c8ec72a3a9de8109bd78821ab22f60b14cf2abccd970b5ff310e14158ed" }, + { url = "https://mirrors.aliyun.com/pypi/packages/61/81/9931d1f83e5aeba175440af0b28f0c2e6f71274a5a7b688bc3e907669388/cymem-2.0.13-cp314-cp314-win_arm64.whl", hash = "sha256:90c2d0c04bcda12cd5cebe9be93ce3af6742ad8da96e1b1907e3f8e00291def1" }, + { url = "https://mirrors.aliyun.com/pypi/packages/b7/ef/af447c2184dec6dec973be14614df8ccb4d16d1c74e0784ab4f02538433c/cymem-2.0.13-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:ff036bbc1464993552fd1251b0a83fe102af334b301e3896d7aa05a4999ad042" }, + { url = "https://mirrors.aliyun.com/pypi/packages/8c/95/e10f33a8d4fc17f9b933d451038218437f9326c2abb15a3e7f58ce2a06ec/cymem-2.0.13-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:fb8291691ba7ff4e6e000224cc97a744a8d9588418535c9454fd8436911df612" }, + { url = "https://mirrors.aliyun.com/pypi/packages/e7/7a/5efeb2d2ea6ebad2745301ad33a4fa9a8f9a33b66623ee4d9185683007a6/cymem-2.0.13-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:d8d06ea59006b1251ad5794bcc00121e148434826090ead0073c7b7fedebe431" }, + { url = "https://mirrors.aliyun.com/pypi/packages/0b/28/2a3f65842cc8443c2c0650cf23d525be06c8761ab212e0a095a88627be1b/cymem-2.0.13-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:c0046a619ecc845ccb4528b37b63426a0cbcb4f14d7940add3391f59f13701e6" }, + { url = "https://mirrors.aliyun.com/pypi/packages/98/73/dd5f9729398f0108c2e71d942253d0d484d299d08b02e474d7cfc43ed0b0/cymem-2.0.13-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:18ad5b116a82fa3674bc8838bd3792891b428971e2123ae8c0fd3ca472157c5e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/5a/01/ffe51729a8f961a437920560659073e47f575d4627445216c1177ecd4a41/cymem-2.0.13-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:666ce6146bc61b9318aa70d91ce33f126b6344a25cf0b925621baed0c161e9cc" }, + { url = "https://mirrors.aliyun.com/pypi/packages/fd/ac/c9e7d68607f71ef978c81e334ab2898b426944c71950212b1467186f69f9/cymem-2.0.13-cp314-cp314t-win_amd64.whl", hash = "sha256:84c1168c563d9d1e04546cb65e3e54fde2bf814f7c7faf11fc06436598e386d1" }, + { url = "https://mirrors.aliyun.com/pypi/packages/66/66/150e406a2db5535533aa3c946de58f0371f2e412e23f050c704588023e6e/cymem-2.0.13-cp314-cp314t-win_arm64.whl", hash = "sha256:e9027764dc5f1999fb4b4cabee1d0322c59e330c0a6485b436a68275f614277f" }, +] + [[package]] name = "darabonba-core" version = "1.0.5" @@ -1751,10 +1821,6 @@ version = "1.8.20" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/e0/b7/cd8080344452e4874aae67c40d8940e2b4d47b01601a8fd9f44786c757c7/debugpy-1.8.20.tar.gz", hash = "sha256:55bc8701714969f1ab89a6d5f2f3d40c36f91b2cbe2f65d98bf8196f6a6a2c33" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/14/57/7f34f4736bfb6e00f2e4c96351b07805d83c9a7b33d28580ae01374430f7/debugpy-1.8.20-cp312-cp312-macosx_15_0_universal2.whl", hash = "sha256:4ae3135e2089905a916909ef31922b2d733d756f66d87345b3e5e52b7a55f13d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ab/78/b193a3975ca34458f6f0e24aaf5c3e3da72f5401f6054c0dfd004b41726f/debugpy-1.8.20-cp312-cp312-manylinux_2_34_x86_64.whl", hash = "sha256:88f47850a4284b88bd2bfee1f26132147d5d504e4e86c22485dfa44b97e19b4b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c1/55/f14deb95eaf4f30f07ef4b90a8590fc05d9e04df85ee379712f6fb6736d7/debugpy-1.8.20-cp312-cp312-win32.whl", hash = "sha256:4057ac68f892064e5f98209ab582abfee3b543fb55d2e87610ddc133a954d390" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a1/39/2bef246368bd42f9bd7cba99844542b74b84dacbdbea0833e610f384fee8/debugpy-1.8.20-cp312-cp312-win_amd64.whl", hash = "sha256:a1a8f851e7cf171330679ef6997e9c579ef6dd33c9098458bd9986a0f4ca52e3" }, { url = "https://mirrors.aliyun.com/pypi/packages/15/e2/fc500524cc6f104a9d049abc85a0a8b3f0d14c0a39b9c140511c61e5b40b/debugpy-1.8.20-cp313-cp313-macosx_15_0_universal2.whl", hash = "sha256:5dff4bb27027821fdfcc9e8f87309a28988231165147c31730128b1c983e282a" }, { url = "https://mirrors.aliyun.com/pypi/packages/90/83/fb33dcea789ed6018f8da20c5a9bc9d82adc65c0c990faed43f7c955da46/debugpy-1.8.20-cp313-cp313-manylinux_2_34_x86_64.whl", hash = "sha256:84562982dd7cf5ebebfdea667ca20a064e096099997b175fe204e86817f64eaf" }, { url = "https://mirrors.aliyun.com/pypi/packages/a6/25/b1e4a01bfb824d79a6af24b99ef291e24189080c93576dfd9b1a2815cd0f/debugpy-1.8.20-cp313-cp313-win32.whl", hash = "sha256:da11dea6447b2cadbf8ce2bec59ecea87cc18d2c574980f643f2d2dfe4862393" }, @@ -1912,19 +1978,6 @@ name = "editdistance" version = "0.8.1" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/d5/18/9f4f975ca87a390832b1c22478f3702fcdf739f83211e24d054b7551270d/editdistance-0.8.1.tar.gz", hash = "sha256:d1cdf80a5d5014b0c9126a69a42ce55a457b457f6986ff69ca98e4fe4d2d8fed" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/cb/4c/7f195588949b4e72436dc7fc902632381f96e586af829685b56daebb38b8/editdistance-0.8.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b04af61b3fcdd287a07c15b6ae3b02af01c5e3e9c3aca76b8c1d13bd266b6f57" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8d/82/31dc1640d830cd7d36865098329f34e4dad3b77f31cfb9404b347e700196/editdistance-0.8.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:18fc8b6eaae01bfd9cf999af726c1e8dcf667d120e81aa7dbd515bea7427f62f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ea/2a/6b823e71cef694d6f070a1d82be2842706fa193541aab8856a8f42044cd0/editdistance-0.8.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6a87839450a5987028738d061ffa5ef6a68bac2ddc68c9147a8aae9806629c7f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e1/31/bfb8e590f922089dc3471ed7828a6da2fc9453eba38c332efa9ee8749fd7/editdistance-0.8.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:24b5f9c9673c823d91b5973d0af8b39f883f414a55ade2b9d097138acd10f31e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a9/c7/57423942b2f847cdbbb46494568d00cd8a45500904ea026f0aad6ca01bc7/editdistance-0.8.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c59248eabfad603f0fba47b0c263d5dc728fb01c2b6b50fb6ca187cec547fdb3" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1b/05/dfa4cdcce063596cbf0d7a32c46cd0f4fa70980311b7da64d35f33ad02a0/editdistance-0.8.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84e239d88ff52821cf64023fabd06a1d9a07654f364b64bf1284577fd3a79d0e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0e/14/39608ff724a9523f187c4e28926d78bc68f2798f74777ac6757981108345/editdistance-0.8.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:2f7f71698f83e8c83839ac0d876a0f4ef996c86c5460aebd26d85568d4afd0db" }, - { url = "https://mirrors.aliyun.com/pypi/packages/df/92/4a1c61d72da40dedfd0ff950fdc71ae83f478330c58a8bccfd776518bd67/editdistance-0.8.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:04e229d6f4ce0c12abc9f4cd4023a5b5fa9620226e0207b119c3c2778b036250" }, - { url = "https://mirrors.aliyun.com/pypi/packages/47/3d/9877566e724c8a37f2228a84ec5cbf66dbfd0673515baf68a0fe07caff40/editdistance-0.8.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:e16721636da6d6b68a2c09eaced35a94f4a4a704ec09f45756d4fd5e128ed18d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d2/f5/8c50757d198b8ca30ddb91e8b8f0247a8dca04ff2ec30755245f0ab1ff0c/editdistance-0.8.1-cp312-cp312-win32.whl", hash = "sha256:87533cf2ebc3777088d991947274cd7e1014b9c861a8aa65257bcdc0ee492526" }, - { url = "https://mirrors.aliyun.com/pypi/packages/28/f0/65101e51dc7c850e7b7581a5d8fa8721a1d7479a0dca6c08386328e19882/editdistance-0.8.1-cp312-cp312-win_amd64.whl", hash = "sha256:09f01ed51746d90178af7dd7ea4ebb41497ef19f53c7f327e864421743dffb0a" }, -] [[package]] name = "elastic-transport" @@ -1966,6 +2019,14 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/6b/ee/4699000ef357e476a3984fd1eff236f820e3346c4aef7c7772e580b81b31/elasticsearch_dsl-8.12.0-py3-none-any.whl", hash = "sha256:2ea9e6ded64d21a8f1ef72477a4d116c6fbeea631ac32a2e2490b9c0d09a99a6" }, ] +[[package]] +name = "en-core-web-sm" +version = "3.8.0" +source = { url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl" } +wheels = [ + { url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl", hash = "sha256:1932429db727d4bff3deed6b34cfc05df17794f4a52eeb26cf8928f7c1a0fb85" }, +] + [[package]] name = "et-xmlfile" version = "2.0.0" @@ -1987,9 +2048,6 @@ wheels = [ name = "exceptiongroup" version = "1.3.1" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -dependencies = [ - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, -] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/50/79/66800aadf48771f6b62f7eb014e352e5d06856655206165d775e675a02c9/exceptiongroup-1.3.1.tar.gz", hash = "sha256:8b412432c6055b0b7d14c310000ae93352ed6754f70fa8f7c34141f91c4e3219" } wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/8a/0e/97c33bf5009bdbac74fd2beace167cab3f978feb69cc36f1ef79360d6c4e/exceptiongroup-1.3.1-py3-none-any.whl", hash = "sha256:a7a39a3bd276781e98394987d3a5701d0c4edffb633bb7a5144577f82c773598" }, @@ -2045,12 +2103,6 @@ version = "1.12.1" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/65/8b/fa2d3287fd2267be6261d0177c6809a7fa12c5600ddb33490c8dc29e77b2/fastavro-1.12.1.tar.gz", hash = "sha256:2f285be49e45bc047ab2f6bed040bb349da85db3f3c87880e4b92595ea093b2b" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/7c/f0/10bd1a3d08667fa0739e2b451fe90e06df575ec8b8ba5d3135c70555c9bd/fastavro-1.12.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:509818cb24b98a804fc80be9c5fed90f660310ae3d59382fc811bfa187122167" }, - { url = "https://mirrors.aliyun.com/pypi/packages/78/ad/0d985bc99e1fa9e74c636658000ba38a5cd7f5ab2708e9c62eaf736ecf1a/fastavro-1.12.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:089e155c0c76e0d418d7e79144ce000524dd345eab3bc1e9c5ae69d500f71b14" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0d/9e/b4951dc84ebc34aac69afcbfbb22ea4a91080422ec2bfd2c06076ff1d419/fastavro-1.12.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:44cbff7518901c91a82aab476fcab13d102e4999499df219d481b9e15f61af34" }, - { url = "https://mirrors.aliyun.com/pypi/packages/af/f8/5a8df450a9f55ca8441f22ea0351d8c77809fc121498b6970daaaf667a21/fastavro-1.12.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a275e48df0b1701bb764b18a8a21900b24cf882263cb03d35ecdba636bbc830b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/99/b2/40f25299111d737e58b85696e91138a66c25b7334f5357e7ac2b0e8966f8/fastavro-1.12.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2de72d786eb38be6b16d556b27232b1bf1b2797ea09599507938cdb7a9fe3e7c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e0/07/85157a7c57c5f8b95507d7829b5946561e5ee656ff80e9dd9a757f53ddaf/fastavro-1.12.1-cp312-cp312-win_amd64.whl", hash = "sha256:9090f0dee63fe022ee9cc5147483366cc4171c821644c22da020d6b48f576b4f" }, { url = "https://mirrors.aliyun.com/pypi/packages/bb/57/26d5efef9182392d5ac9f253953c856ccb66e4c549fd3176a1e94efb05c9/fastavro-1.12.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:78df838351e4dff9edd10a1c41d1324131ffecbadefb9c297d612ef5363c049a" }, { url = "https://mirrors.aliyun.com/pypi/packages/33/cb/8ab55b21d018178eb126007a56bde14fd01c0afc11d20b5f2624fe01e698/fastavro-1.12.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:780476c23175d2ae457c52f45b9ffa9d504593499a36cd3c1929662bf5b7b14b" }, { url = "https://mirrors.aliyun.com/pypi/packages/fe/03/9c94ec9bf873eb1ffb0aa694f4e71940154e6e9728ddfdc46046d7e8ced4/fastavro-1.12.1-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0714b285160fcd515eb0455540f40dd6dac93bdeacdb03f24e8eac3d8aa51f8d" }, @@ -2086,15 +2138,6 @@ dependencies = [ { name = "pandas" }, ] wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/97/9c/f2c018807cab35716df732be6c09ec017ad9ee40dc2e876b10ed5d9a963e/fastparquet-2026.3.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:c413adcea221c11e8a14d096d825b42d4f0b4b6621f64d6c13f4a433574906e6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/93/bf/6470b62e3eabb46e5abc6ad4e0c13587e1448f2365f7c35079fe4d6602ab/fastparquet-2026.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:4623c12e5dc05f6164cad7a2f6962c1e8f69f4670abd6b19fe7b1f13b4f4937d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e8/ef/78a5db203e2e1d19249286f52ecb5531b8863e56a346d9d193633c3030fd/fastparquet-2026.3.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:41def5d94abf44830e58b4c2ed137b71f2f0e068c6241a1d2524595178880851" }, - { url = "https://mirrors.aliyun.com/pypi/packages/61/99/e43283ac6cc83269c8214b8ee57e7773ea5f39016a8e8fcfe4529fa2cc30/fastparquet-2026.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:6a5ddb40b58b62ef660ea9f0774d3b3cfe6d0b88c20b44b986e500439290de81" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cd/9a/f5aada3af89dcd3027e543fa39756f67790daef0c31f03973bb97c6171c9/fastparquet-2026.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:f6cc397aa8ca5bb2e84670270b46a89e6d6e426f8bfce5437d028a90cd2d8b3d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/07/92/bdd4d8dfd59a6ae92b33eab0f583fd5099188c8065d875d22782f26b79e0/fastparquet-2026.3.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:46c60acc4b5752cd883bd0cc9076a01698b1b5f28cb7f94449fba68f40758316" }, - { url = "https://mirrors.aliyun.com/pypi/packages/70/7d/d46abd9713f53d90ebc47c373d78ddb34c24e5fa6a02c5a974370f8a57b0/fastparquet-2026.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c61d734cff4d29f16bf1c813b4d1725dec3676cb82a2f617713a894b4e97546d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/28/16/6dda2bf60e830feb1a1bdecb01e8aa33b011058ee767418cef4bc68a1249/fastparquet-2026.3.0-cp312-cp312-win32.whl", hash = "sha256:8835b763f1843ecde3f7e8bc9deda4a7dc317b65b1dfc9a10e7e4f26eac73ce4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/45/5b/7cc76aa44962280e496f35715f172afbd6476fcde5ecfa8fdc1c30416b03/fastparquet-2026.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:db0698e4e34788baadb4d8871f93409c9803bc661b7d58d90f616ded889289bb" }, { url = "https://mirrors.aliyun.com/pypi/packages/af/aa/3dbde9b0592a7aca0489edefa368b861a7d85df1ec51d7f5f05d83c4ad0f/fastparquet-2026.3.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:4e0f5464bc0661b345e26aa7feab34bd21c9ca2d3c4f411278f50c76e7adb7f2" }, { url = "https://mirrors.aliyun.com/pypi/packages/b5/f1/d81496c2887f166ea7222ef81d489dcc139ff3dc0f4b0393c0d201bdfb47/fastparquet-2026.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:97d48ea111b0cc09bf99b97c2218c5fd24abac8b53879b4ce73eea55d5484a55" }, { url = "https://mirrors.aliyun.com/pypi/packages/6d/1d/dba2033c57087d74ec463fbf9fc23b57a1bd731db38877f2b002d8b8c05b/fastparquet-2026.3.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:ab0c62c1890def8a40f3d878fc75fbf725a21df4e3676da74a56195346824bb0" }, @@ -2130,17 +2173,6 @@ version = "0.14.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/c3/7d/d9daedf0f2ebcacd20d599928f8913e9d2aea1d56d2d355a93bfa2b611d7/fastuuid-0.14.0.tar.gz", hash = "sha256:178947fc2f995b38497a74172adee64fdeb8b7ec18f2a5934d037641ba265d26" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/02/a2/e78fcc5df65467f0d207661b7ef86c5b7ac62eea337c0c0fcedbeee6fb13/fastuuid-0.14.0-cp312-cp312-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:77e94728324b63660ebf8adb27055e92d2e4611645bf12ed9d88d30486471d0a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2b/b3/c846f933f22f581f558ee63f81f29fa924acd971ce903dab1a9b6701816e/fastuuid-0.14.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:caa1f14d2102cb8d353096bc6ef6c13b2c81f347e6ab9d6fbd48b9dea41c153d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/54/ea/682551030f8c4fa9a769d9825570ad28c0c71e30cf34020b85c1f7ee7382/fastuuid-0.14.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d23ef06f9e67163be38cece704170486715b177f6baae338110983f99a72c070" }, - { url = "https://mirrors.aliyun.com/pypi/packages/14/dd/5927f0a523d8e6a76b70968e6004966ee7df30322f5fc9b6cdfb0276646a/fastuuid-0.14.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0c9ec605ace243b6dbe3bd27ebdd5d33b00d8d1d3f580b39fdd15cd96fd71796" }, - { url = "https://mirrors.aliyun.com/pypi/packages/16/6e/c0fb547eef61293153348f12e0f75a06abb322664b34a1573a7760501336/fastuuid-0.14.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:808527f2407f58a76c916d6aa15d58692a4a019fdf8d4c32ac7ff303b7d7af09" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2d/b1/b9c75e03b768f61cf2e84ee193dc18601aeaf89a4684b20f2f0e9f52b62c/fastuuid-0.14.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2fb3c0d7fef6674bbeacdd6dbd386924a7b60b26de849266d1ff6602937675c8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/fc/fa/f7395fdac07c7a54f18f801744573707321ca0cee082e638e36452355a9d/fastuuid-0.14.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab3f5d36e4393e628a4df337c2c039069344db5f4b9d2a3c9cea48284f1dd741" }, - { url = "https://mirrors.aliyun.com/pypi/packages/66/49/c9fd06a4a0b1f0f048aacb6599e7d96e5d6bc6fa680ed0d46bf111929d1b/fastuuid-0.14.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:b9a0ca4f03b7e0b01425281ffd44e99d360e15c895f1907ca105854ed85e2057" }, - { url = "https://mirrors.aliyun.com/pypi/packages/be/9c/909e8c95b494e8e140e8be6165d5fc3f61fdc46198c1554df7b3e1764471/fastuuid-0.14.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3acdf655684cc09e60fb7e4cf524e8f42ea760031945aa8086c7eae2eeeabeb8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/90/eb/d29d17521976e673c55ef7f210d4cdd72091a9ec6755d0fd4710d9b3c871/fastuuid-0.14.0-cp312-cp312-win32.whl", hash = "sha256:9579618be6280700ae36ac42c3efd157049fe4dd40ca49b021280481c78c3176" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cc/fc/f5c799a6ea6d877faec0472d0b27c079b47c86b1cdc577720a5386483b36/fastuuid-0.14.0-cp312-cp312-win_amd64.whl", hash = "sha256:d9e4332dc4ba054434a9594cbfaf7823b57993d7d8e7267831c3e059857cf397" }, { url = "https://mirrors.aliyun.com/pypi/packages/a5/83/ae12dd39b9a39b55d7f90abb8971f1a5f3c321fd72d5aa83f90dc67fe9ed/fastuuid-0.14.0-cp313-cp313-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:77a09cb7427e7af74c594e409f7731a0cf887221de2f698e1ca0ebf0f3139021" }, { url = "https://mirrors.aliyun.com/pypi/packages/53/b0/a4b03ff5d00f563cc7546b933c28cb3f2a07344b2aec5834e874f7d44143/fastuuid-0.14.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:9bd57289daf7b153bfa3e8013446aa144ce5e8c825e9e366d455155ede5ea2dc" }, { url = "https://mirrors.aliyun.com/pypi/packages/9c/6d/64aee0a0f6a58eeabadd582e55d0d7d70258ffdd01d093b30c53d668303b/fastuuid-0.14.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:ac60fc860cdf3c3f327374db87ab8e064c86566ca8c49d2e30df15eda1b0c2d5" }, @@ -2296,14 +2328,6 @@ version = "4.62.1" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9a/08/7012b00a9a5874311b639c3920270c36ee0c445b69d9989a85e5c92ebcb0/fonttools-4.62.1.tar.gz", hash = "sha256:e54c75fd6041f1122476776880f7c3c3295ffa31962dc6ebe2543c00dca58b5d" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/47/d4/dbacced3953544b9a93088cc10ef2b596d348c983d5c67a404fa41ec51ba/fonttools-4.62.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:90365821debbd7db678809c7491ca4acd1e0779b9624cdc6ddaf1f31992bf974" }, - { url = "https://mirrors.aliyun.com/pypi/packages/66/9e/a769c8e99b81e5a87ab7e5e7236684de4e96246aae17274e5347d11ebd78/fonttools-4.62.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:12859ff0b47dd20f110804c3e0d0970f7b832f561630cd879969011541a464a9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/69/64/f19a9e3911968c37e1e620e14dfc5778299e1474f72f4e57c5ec771d9489/fonttools-4.62.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c125ffa00c3d9003cdaaf7f2c79e6e535628093e14b5de1dccb08859b680936" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9b/8a/99c8b3c3888c5c474c08dbfd7c8899786de9604b727fcefb055b42c84bba/fonttools-4.62.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:149f7d84afca659d1a97e39a4778794a2f83bf344c5ee5134e09995086cc2392" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d1/c6/0f904540d3e6ab463c1243a0d803504826a11604c72dd58c2949796a1762/fonttools-4.62.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0aa72c43a601cfa9273bb1ae0518f1acadc01ee181a6fc60cd758d7fdadffc04" }, - { url = "https://mirrors.aliyun.com/pypi/packages/29/0b/5cbef6588dc9bd6b5c9ad6a4d5a8ca384d0cea089da31711bbeb4f9654a6/fonttools-4.62.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:19177c8d96c7c36359266e571c5173bcee9157b59cfc8cb0153c5673dc5a3a7d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/4a/47/b3a5342d381595ef439adec67848bed561ab7fdb1019fa522e82101b7d9c/fonttools-4.62.1-cp312-cp312-win32.whl", hash = "sha256:a24decd24d60744ee8b4679d38e88b8303d86772053afc29b19d23bb8207803c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/28/b1/0c2ab56a16f409c6c8a68816e6af707827ad5d629634691ff60a52879792/fonttools-4.62.1-cp312-cp312-win_amd64.whl", hash = "sha256:9e7863e10b3de72376280b515d35b14f5eeed639d1aa7824f4cf06779ec65e42" }, { url = "https://mirrors.aliyun.com/pypi/packages/3b/56/6f389de21c49555553d6a5aeed5ac9767631497ac836c4f076273d15bd72/fonttools-4.62.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:c22b1014017111c401469e3acc5433e6acf6ebcc6aa9efb538a533c800971c79" }, { url = "https://mirrors.aliyun.com/pypi/packages/03/c5/0e3966edd5ec668d41dfe418787726752bc07e2f5fd8c8f208615e61fa89/fonttools-4.62.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:68959f5fc58ed4599b44aad161c2837477d7f35f5f79402d97439974faebfebe" }, { url = "https://mirrors.aliyun.com/pypi/packages/52/94/e6ac4b44026de7786fe46e3bfa0c87e51d5d70a841054065d49cd62bb909/fonttools-4.62.1-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ef46db46c9447103b8f3ff91e8ba009d5fe181b1920a83757a5762551e32bb68" }, @@ -2356,22 +2380,6 @@ version = "1.8.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/2d/f5/c831fac6cc817d26fd54c7eaccd04ef7e0288806943f7cc5bbf69f3ac1f0/frozenlist-1.8.0.tar.gz", hash = "sha256:3ede829ed8d842f6cd48fc7081d7a41001a56f1f38603f9d49bf3020d59a31ad" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/69/29/948b9aa87e75820a38650af445d2ef2b6b8a6fab1a23b6bb9e4ef0be2d59/frozenlist-1.8.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:78f7b9e5d6f2fdb88cdde9440dc147259b62b9d3b019924def9f6478be254ac1" }, - { url = "https://mirrors.aliyun.com/pypi/packages/64/80/4f6e318ee2a7c0750ed724fa33a4bdf1eacdc5a39a7a24e818a773cd91af/frozenlist-1.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:229bf37d2e4acdaf808fd3f06e854a4a7a3661e871b10dc1f8f1896a3b05f18b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2b/94/5c8a2b50a496b11dd519f4a24cb5496cf125681dd99e94c604ccdea9419a/frozenlist-1.8.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f833670942247a14eafbb675458b4e61c82e002a148f49e68257b79296e865c4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6a/bd/d91c5e39f490a49df14320f4e8c80161cfcce09f1e2cde1edd16a551abb3/frozenlist-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:494a5952b1c597ba44e0e78113a7266e656b9794eec897b19ead706bd7074383" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8f/83/f61505a05109ef3293dfb1ff594d13d64a2324ac3482be2cedc2be818256/frozenlist-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:96f423a119f4777a4a056b66ce11527366a8bb92f54e541ade21f2374433f6d4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d8/cb/cb6c7b0f7d4023ddda30cf56b8b17494eb3a79e3fda666bf735f63118b35/frozenlist-1.8.0-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3462dd9475af2025c31cc61be6652dfa25cbfb56cbbf52f4ccfe029f38decaf8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/31/c5/cd7a1f3b8b34af009fb17d4123c5a778b44ae2804e3ad6b86204255f9ec5/frozenlist-1.8.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:c4c800524c9cd9bac5166cd6f55285957fcfc907db323e193f2afcd4d9abd69b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c0/01/2f95d3b416c584a1e7f0e1d6d31998c4a795f7544069ee2e0962a4b60740/frozenlist-1.8.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:d6a5df73acd3399d893dafc71663ad22534b5aa4f94e8a2fabfe856c3c1b6a52" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ce/03/024bf7720b3abaebcff6d0793d73c154237b85bdf67b7ed55e5e9596dc9a/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:405e8fe955c2280ce66428b3ca55e12b3c4e9c336fb2103a4937e891c69a4a29" }, - { url = "https://mirrors.aliyun.com/pypi/packages/69/fa/f8abdfe7d76b731f5d8bd217827cf6764d4f1d9763407e42717b4bed50a0/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:908bd3f6439f2fef9e85031b59fd4f1297af54415fb60e4254a95f75b3cab3f3" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f5/3c/b051329f718b463b22613e269ad72138cc256c540f78a6de89452803a47d/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:294e487f9ec720bd8ffcebc99d575f7eff3568a08a253d1ee1a0378754b74143" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0f/ae/58282e8f98e444b3f4dd42448ff36fa38bef29e40d40f330b22e7108f565/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:74c51543498289c0c43656701be6b077f4b265868fa7f8a8859c197006efb608" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8f/96/007e5944694d66123183845a106547a15944fbbb7154788cbf7272789536/frozenlist-1.8.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:776f352e8329135506a1d6bf16ac3f87bc25b28e765949282dcc627af36123aa" }, - { url = "https://mirrors.aliyun.com/pypi/packages/66/bb/852b9d6db2fa40be96f29c0d1205c306288f0684df8fd26ca1951d461a56/frozenlist-1.8.0-cp312-cp312-win32.whl", hash = "sha256:433403ae80709741ce34038da08511d4a77062aa924baf411ef73d1146e74faf" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b8/af/38e51a553dd66eb064cdf193841f16f077585d4d28394c2fa6235cb41765/frozenlist-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:34187385b08f866104f0c0617404c8eb08165ab1272e884abc89c112e9c00746" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a7/06/1dc65480ab147339fecc70797e9c2f69d9cea9cf38934ce08df070fdb9cb/frozenlist-1.8.0-cp312-cp312-win_arm64.whl", hash = "sha256:fe3c58d2f5db5fbd18c2987cba06d51b0529f52bc3a6cdc33d3f4eab725104bd" }, { url = "https://mirrors.aliyun.com/pypi/packages/2d/40/0832c31a37d60f60ed79e9dfb5a92e1e2af4f40a16a29abcc7992af9edff/frozenlist-1.8.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8d92f1a84bb12d9e56f818b3a746f3efba93c1b63c8387a73dde655e1e42282a" }, { url = "https://mirrors.aliyun.com/pypi/packages/30/ba/b0b3de23f40bc55a7057bd38434e25c34fa48e17f20ee273bbde5e0650f3/frozenlist-1.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:96153e77a591c8adc2ee805756c61f59fef4cf4073a9275ee86fe8cba41241f7" }, { url = "https://mirrors.aliyun.com/pypi/packages/0c/ab/6e5080ee374f875296c4243c381bbdef97a9ac39c6e3ce1d5f7d42cb78d6/frozenlist-1.8.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f21f00a91358803399890ab167098c131ec2ddd5f8f5fd5fe9c9f2c6fcd91e40" }, @@ -2457,15 +2465,6 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/da/71/ae30dadffc90b9006d77af76b393cb9dfbfc9629f339fc1574a1c52e6806/future-1.0.0-py3-none-any.whl", hash = "sha256:929292d34f5872e70396626ef385ec22355a1fae8ad29e1a734c3e43f9fbc216" }, ] -[[package]] -name = "gast" -version = "0.7.0" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/91/f6/e73969782a2ecec280f8a176f2476149dd9dba69d5f8779ec6108a7721e6/gast-0.7.0.tar.gz", hash = "sha256:0bb14cd1b806722e91ddbab6fb86bba148c22b40e7ff11e248974e04c8adfdae" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/1d/33/f1c6a276de27b7d7339a34749cc33fa87f077f921969c47185d34a887ae2/gast-0.7.0-py3-none-any.whl", hash = "sha256:99cbf1365633a74099f69c59bd650476b96baa5ef196fec88032b00b31ba36f7" }, -] - [[package]] name = "gensim" version = "4.4.0" @@ -2477,11 +2476,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/1a/80/fe9d2e1ace968041814dbcfce4e8499a643a36c41267fa4b6c4f54cce420/gensim-4.4.0.tar.gz", hash = "sha256:a3f5b626da5518e79a479140361c663089fe7998df8ba52d56e1ded71ac5bdf5" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/4f/65/d5285865ca54b93d41ccd8683c2d79952434957c76b411283c7a6c66ca69/gensim-4.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:0845b2fa039dbea5667fb278b5414e70f6d48fd208ef51f33e84a78444288d8d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/32/59/f0ea443cbfb3b06e1d2e060217bb91f954845f6df38cbc9c5468b6c9c638/gensim-4.4.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1853fc5be730f692c444a826041fef9a2fc8d74c73bb59748904b2e3221daa86" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f0/b8/9b0ba15756e41ccfdd852f9c65cd2b552f240c201dc3237ad8c178642e80/gensim-4.4.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:23a2a4260f01c8f71bae5dd0e8a01bb247a2c789480c033e0eaba100b0ad4239" }, - { url = "https://mirrors.aliyun.com/pypi/packages/97/2c/c29701826c963b04a43d5d7b87573a74040387ab9219e65b10f377d22b5b/gensim-4.4.0-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4b73ff30af6ddd0d2ddf9473b1eb44603cd79ec14c87d93b75291802b991916c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/fd/f2/9ec6863143888bf390cdc5261f6d9e71d79bc95d98fb815679dba478d5f6/gensim-4.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:b3a3f9bc8d4178b01d114e1c58c5ab2333f131c7415fb3d8ec8f1ecfe4c5b544" }, { url = "https://mirrors.aliyun.com/pypi/packages/80/6c/4e522973e07ca491d33cc7829996b9e8c8663a16b3f87f580cbdc2732d97/gensim-4.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b8961b7a2bb5190b46bc6cd26c29d5bfea22f99123ed5f506ebd0aaf65996758" }, { url = "https://mirrors.aliyun.com/pypi/packages/cc/6a/593107ee98331128ed20e5d074865587558a0766659be787a40550ab66df/gensim-4.4.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:59d0d29099a76dd97d4563e002f3488a43e51f99d46387025da38007ebfeeff9" }, { url = "https://mirrors.aliyun.com/pypi/packages/d9/ef/1675e1a3a04f7d0293a21082f57f4a6a8bf0a9e387da58b71db648b663de/gensim-4.4.0-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3bec3e6a1ecaa6439b21a3e42ceb0ca67ffabc114b646f89b1aab5fe69a39ffc" }, @@ -2613,11 +2607,6 @@ version = "1.8.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/03/41/4b9c02f99e4c5fb477122cd5437403b552873f014616ac1d19ac8221a58d/google_crc32c-1.8.0.tar.gz", hash = "sha256:a428e25fb7691024de47fecfbff7ff957214da51eddded0da0ae0e0f03a2cf79" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/e9/5f/7307325b1198b59324c0fa9807cafb551afb65e831699f2ce211ad5c8240/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:4b8286b659c1335172e39563ab0a768b8015e88e08329fa5321f774275fc3113" }, - { url = "https://mirrors.aliyun.com/pypi/packages/21/8e/58c0d5d86e2220e6a37befe7e6a94dd2f6006044b1a33edf1ff6d9f7e319/google_crc32c-1.8.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:2a3dc3318507de089c5384cc74d54318401410f82aa65b2d9cdde9d297aca7cb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ce/a9/a780cc66f86335a6019f557a8aaca8fbb970728f0efd2430d15ff1beae0e/google_crc32c-1.8.0-cp312-cp312-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:14f87e04d613dfa218d6135e81b78272c3b904e2a7053b841481b38a7d901411" }, - { url = "https://mirrors.aliyun.com/pypi/packages/21/3f/3457ea803db0198c9aaca2dd373750972ce28a26f00544b6b85088811939/google_crc32c-1.8.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:cb5c869c2923d56cb0c8e6bcdd73c009c36ae39b652dbe46a05eb4ef0ad01454" }, - { url = "https://mirrors.aliyun.com/pypi/packages/df/c0/87c2073e0c72515bb8733d4eef7b21548e8d189f094b5dad20b0ecaf64f6/google_crc32c-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:3cc0c8912038065eafa603b238abf252e204accab2a704c63b9e14837a854962" }, { url = "https://mirrors.aliyun.com/pypi/packages/d1/db/000f15b41724589b0e7bc24bc7a8967898d8d3bc8caf64c513d91ef1f6c0/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3ebb04528e83b2634857f43f9bb8ef5b2bbe7f10f140daeb01b58f972d04736b" }, { url = "https://mirrors.aliyun.com/pypi/packages/d7/0d/8ebed0c39c53a7e838e2a486da8abb0e52de135f1b376ae2f0b160eb4c1a/google_crc32c-1.8.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:450dc98429d3e33ed2926fc99ee81001928d63460f8538f21a5d6060912a8e27" }, { url = "https://mirrors.aliyun.com/pypi/packages/ce/42/b468aec74a0354b34c8cbf748db20d6e350a68a2b0912e128cabee49806c/google_crc32c-1.8.0-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:3b9776774b24ba76831609ffbabce8cdf6fa2bd5e9df37b594221c7e333a81fa" }, @@ -2651,18 +2640,6 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/84/de/7d3ee9c94b74c3578ea4f88d45e8de9405902f857932334d81e89bce3dfa/google_genai-1.68.0-py3-none-any.whl", hash = "sha256:a1bc9919c0e2ea2907d1e319b65471d3d6d58c54822039a249fe1323e4178d15" }, ] -[[package]] -name = "google-pasta" -version = "0.2.0" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -dependencies = [ - { name = "six" }, -] -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/35/4a/0bd53b36ff0323d10d5f24ebd67af2de10a1117f5cf4d7add90df92756f1/google-pasta-0.2.0.tar.gz", hash = "sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/a3/de/c648ef6835192e6e2cc03f40b19eeda4382c49b5bafb43d88b931c4c74ac/google_pasta-0.2.0-py3-none-any.whl", hash = "sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed" }, -] - [[package]] name = "google-resumable-media" version = "2.8.0" @@ -2750,15 +2727,6 @@ version = "3.3.2" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/a3/51/1664f6b78fc6ebbd98019a1fd730e83fa78f2db7058f72b1463d3612b8db/greenlet-3.3.2.tar.gz", hash = "sha256:2eaf067fc6d886931c7962e8c6bede15d2f01965560f3359b27c80bde2d151f2" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/ea/ab/1608e5a7578e62113506740b88066bf09888322a311cff602105e619bd87/greenlet-3.3.2-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:ac8d61d4343b799d1e526db579833d72f23759c71e07181c2d2944e429eb09cd" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a5/23/0eae412a4ade4e6623ff7626e38998cb9b11e9ff1ebacaa021e4e108ec15/greenlet-3.3.2-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3ceec72030dae6ac0c8ed7591b96b70410a8be370b6a477b1dbc072856ad02bd" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f8/16/5b1678a9c07098ecb9ab2dd159fafaf12e963293e61ee8d10ecb55273e5e/greenlet-3.3.2-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:a2a5be83a45ce6188c045bcc44b0ee037d6a518978de9a5d97438548b953a1ac" }, - { url = "https://mirrors.aliyun.com/pypi/packages/5c/c5/cc09412a29e43406eba18d61c70baa936e299bc27e074e2be3806ed29098/greenlet-3.3.2-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:ae9e21c84035c490506c17002f5c8ab25f980205c3e61ddb3a2a2a2e6c411fcb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/50/1f/5155f55bd71cabd03765a4aac9ac446be129895271f73872c36ebd4b04b6/greenlet-3.3.2-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:43e99d1749147ac21dde49b99c9abffcbc1e2d55c67501465ef0930d6e78e070" }, - { url = "https://mirrors.aliyun.com/pypi/packages/fc/dd/845f249c3fcd69e32df80cdab059b4be8b766ef5830a3d0aa9d6cad55beb/greenlet-3.3.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4c956a19350e2c37f2c48b336a3afb4bff120b36076d9d7fb68cb44e05d95b79" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2a/50/2649fe21fcc2b56659a452868e695634722a6655ba245d9f77f5656010bf/greenlet-3.3.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:6c6f8ba97d17a1e7d664151284cb3315fc5f8353e75221ed4324f84eb162b395" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9b/40/cc802e067d02af8b60b6771cea7d57e21ef5e6659912814babb42b864713/greenlet-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:34308836d8370bddadb41f5a7ce96879b72e2fdfb4e87729330c6ab52376409f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/58/2e/fe7f36ff1982d6b10a60d5e0740c759259a7d6d2e1dc41da6d96de32fff6/greenlet-3.3.2-cp312-cp312-win_arm64.whl", hash = "sha256:d3a62fa76a32b462a97198e4c9e99afb9ab375115e74e9a83ce180e7a496f643" }, { url = "https://mirrors.aliyun.com/pypi/packages/ac/48/f8b875fa7dea7dd9b33245e37f065af59df6a25af2f9561efa8d822fde51/greenlet-3.3.2-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:aa6ac98bdfd716a749b84d4034486863fd81c3abde9aa3cf8eff9127981a4ae4" }, { url = "https://mirrors.aliyun.com/pypi/packages/49/8d/9771d03e7a8b1ee456511961e1b97a6d77ae1dea4a34a5b98eee706689d3/greenlet-3.3.2-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ab0c7e7901a00bc0a7284907273dc165b32e0d109a6713babd04471327ff7986" }, { url = "https://mirrors.aliyun.com/pypi/packages/59/0e/4223c2bbb63cd5c97f28ffb2a8aee71bdfb30b323c35d409450f51b91e3e/greenlet-3.3.2-cp313-cp313-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d248d8c23c67d2291ffd47af766e2a3aa9fa1c6703155c099feb11f526c63a92" }, @@ -2813,16 +2781,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/06/8a/3d098f35c143a89520e568e6539cc098fcd294495910e359889ce8741c84/grpcio-1.78.0.tar.gz", hash = "sha256:7382b95189546f375c174f53a5fa873cef91c4b8005faa05cc5b3beea9c4f1c5" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/4e/f4/7384ed0178203d6074446b3c4f46c90a22ddf7ae0b3aee521627f54cfc2a/grpcio-1.78.0-cp312-cp312-linux_armv7l.whl", hash = "sha256:f9ab915a267fc47c7e88c387a3a28325b58c898e23d4995f765728f4e3dedb97" }, - { url = "https://mirrors.aliyun.com/pypi/packages/81/ed/be1caa25f06594463f685b3790b320f18aea49b33166f4141bfdc2bfb236/grpcio-1.78.0-cp312-cp312-macosx_11_0_universal2.whl", hash = "sha256:3f8904a8165ab21e07e58bf3e30a73f4dffc7a1e0dbc32d51c61b5360d26f43e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/24/a7/f06d151afc4e64b7e3cc3e872d331d011c279aaab02831e40a81c691fb65/grpcio-1.78.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:859b13906ce098c0b493af92142ad051bf64c7870fa58a123911c88606714996" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8a/a8/4482922da832ec0082d0f2cc3a10976d84a7424707f25780b82814aafc0a/grpcio-1.78.0-cp312-cp312-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:b2342d87af32790f934a79c3112641e7b27d63c261b8b4395350dad43eff1dc7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/54/bf/f4a3b9693e35d25b24b0b39fa46d7d8a3c439e0a3036c3451764678fec20/grpcio-1.78.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:12a771591ae40bc65ba67048fa52ef4f0e6db8279e595fd349f9dfddeef571f9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c7/b9/521875265cc99fe5ad4c5a17010018085cae2810a928bf15ebe7d8bcd9cc/grpcio-1.78.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:185dea0d5260cbb2d224c507bf2a5444d5abbb1fa3594c1ed7e4c709d5eb8383" }, - { url = "https://mirrors.aliyun.com/pypi/packages/05/86/296a82844fd40a4ad4a95f100b55044b4f817dece732bf686aea1a284147/grpcio-1.78.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:51b13f9aed9d59ee389ad666b8c2214cc87b5de258fa712f9ab05f922e3896c6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f3/e4/ea3c0caf5468537f27ad5aab92b681ed7cc0ef5f8c9196d3fd42c8c2286b/grpcio-1.78.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fd5f135b1bd58ab088930b3c613455796dfa0393626a6972663ccdda5b4ac6ce" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d7/47/7f05f81e4bb6b831e93271fb12fd52ba7b319b5402cbc101d588f435df00/grpcio-1.78.0-cp312-cp312-win32.whl", hash = "sha256:94309f498bcc07e5a7d16089ab984d42ad96af1d94b5a4eb966a266d9fcabf68" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ad/e7/d6914822c88aa2974dbbd10903d801a28a19ce9cd8bad7e694cbbcf61528/grpcio-1.78.0-cp312-cp312-win_amd64.whl", hash = "sha256:9566fe4ababbb2610c39190791e5b829869351d14369603702e890ef3ad2d06e" }, { url = "https://mirrors.aliyun.com/pypi/packages/05/a9/8f75894993895f361ed8636cd9237f4ab39ef87fd30db17467235ed1c045/grpcio-1.78.0-cp313-cp313-linux_armv7l.whl", hash = "sha256:ce3a90455492bf8bfa38e56fbbe1dbd4f872a3d8eeaf7337dc3b1c8aa28c271b" }, { url = "https://mirrors.aliyun.com/pypi/packages/55/06/0b78408e938ac424100100fd081189451b472236e8a3a1f6500390dc4954/grpcio-1.78.0-cp313-cp313-macosx_11_0_universal2.whl", hash = "sha256:2bf5e2e163b356978b23652c4818ce4759d40f4712ee9ec5a83c4be6f8c23a3a" }, { url = "https://mirrors.aliyun.com/pypi/packages/88/93/b59fe7832ff6ae3c78b813ea43dac60e295fa03606d14d89d2e0ec29f4f3/grpcio-1.78.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8f2ac84905d12918e4e55a16da17939eb63e433dc11b677267c35568aa63fc84" }, @@ -2870,16 +2828,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/ad/9a/edfefb47f11ef6b0f39eea4d8f022c5bb05ac1d14fcc7058e84a51305b73/grpcio_tools-1.71.2.tar.gz", hash = "sha256:b5304d65c7569b21270b568e404a5a843cf027c66552a6a0978b23f137679c09" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/9c/d3/3ed30a9c5b2424627b4b8411e2cd6a1a3f997d3812dbc6a8630a78bcfe26/grpcio_tools-1.71.2-cp312-cp312-linux_armv7l.whl", hash = "sha256:bfc0b5d289e383bc7d317f0e64c9dfb59dc4bef078ecd23afa1a816358fb1473" }, - { url = "https://mirrors.aliyun.com/pypi/packages/54/61/e0b7295456c7e21ef777eae60403c06835160c8d0e1e58ebfc7d024c51d3/grpcio_tools-1.71.2-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:b4669827716355fa913b1376b1b985855d5cfdb63443f8d18faf210180199006" }, - { url = "https://mirrors.aliyun.com/pypi/packages/75/d7/7bcad6bcc5f5b7fab53e6bce5db87041f38ef3e740b1ec2d8c49534fa286/grpcio_tools-1.71.2-cp312-cp312-manylinux_2_17_aarch64.whl", hash = "sha256:d4071f9b44564e3f75cdf0f05b10b3e8c7ea0ca5220acbf4dc50b148552eef2f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b2/8a/e4c1c4cb8c9ff7f50b7b2bba94abe8d1e98ea05f52a5db476e7f1c1a3c70/grpcio_tools-1.71.2-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a28eda8137d587eb30081384c256f5e5de7feda34776f89848b846da64e4be35" }, - { url = "https://mirrors.aliyun.com/pypi/packages/fd/aa/95bc77fda5c2d56fb4a318c1b22bdba8914d5d84602525c99047114de531/grpcio_tools-1.71.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b19c083198f5eb15cc69c0a2f2c415540cbc636bfe76cea268e5894f34023b40" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c9/ff/ca11f930fe1daa799ee0ce1ac9630d58a3a3deed3dd2f465edb9a32f299d/grpcio_tools-1.71.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:784c284acda0d925052be19053d35afbf78300f4d025836d424cf632404f676a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/64/10/c6fc97914c7e19c9bb061722e55052fa3f575165da9f6510e2038d6e8643/grpcio_tools-1.71.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:381e684d29a5d052194e095546eef067201f5af30fd99b07b5d94766f44bf1ae" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e5/d6/965f36cfc367c276799b730d5dd1311b90a54a33726e561393b808339b04/grpcio_tools-1.71.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3e4b4801fabd0427fc61d50d09588a01b1cfab0ec5e8a5f5d515fbdd0891fd11" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8d/f0/c05d5c3d0c1d79ac87df964e9d36f1e3a77b60d948af65bec35d3e5c75a3/grpcio_tools-1.71.2-cp312-cp312-win32.whl", hash = "sha256:84ad86332c44572305138eafa4cc30040c9a5e81826993eae8227863b700b490" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e2/e9/c84c1078f0b7af7d8a40f5214a9bdd8d2a567ad6c09975e6e2613a08d29d/grpcio_tools-1.71.2-cp312-cp312-win_amd64.whl", hash = "sha256:8e1108d37eecc73b1c4a27350a6ed921b5dda25091700c1da17cfe30761cd462" }, { url = "https://mirrors.aliyun.com/pypi/packages/60/9c/bdf9c5055a1ad0a09123402d73ecad3629f75b9cf97828d547173b328891/grpcio_tools-1.71.2-cp313-cp313-linux_armv7l.whl", hash = "sha256:b0f0a8611614949c906e25c225e3360551b488d10a366c96d89856bcef09f729" }, { url = "https://mirrors.aliyun.com/pypi/packages/49/d0/6aaee4940a8fb8269c13719f56d69c8d39569bee272924086aef81616d4a/grpcio_tools-1.71.2-cp313-cp313-macosx_10_14_universal2.whl", hash = "sha256:7931783ea7ac42ac57f94c5047d00a504f72fbd96118bf7df911bb0e0435fc0f" }, { url = "https://mirrors.aliyun.com/pypi/packages/d9/11/50a471dcf301b89c0ed5ab92c533baced5bd8f796abfd133bbfadf6b60e5/grpcio_tools-1.71.2-cp313-cp313-manylinux_2_17_aarch64.whl", hash = "sha256:d188dc28e069aa96bb48cb11b1338e47ebdf2e2306afa58a8162cc210172d7a8" }, @@ -2914,49 +2862,6 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/69/b2/119f6e6dcbd96f9069ce9a2665e0146588dc9f88f29549711853645e736a/h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd" }, ] -[[package]] -name = "h5py" -version = "3.16.0" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -dependencies = [ - { name = "numpy" }, -] -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/db/33/acd0ce6863b6c0d7735007df01815403f5589a21ff8c2e1ee2587a38f548/h5py-3.16.0.tar.gz", hash = "sha256:a0dbaad796840ccaa67a4c144a0d0c8080073c34c76d5a6941d6818678ef2738" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/c8/c0/5d4119dba94093bbafede500d3defd2f5eab7897732998c04b54021e530b/h5py-3.16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c5313566f4643121a78503a473f0fb1e6dcc541d5115c44f05e037609c565c4d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b0/42/c84efcc1d4caebafb1ecd8be4643f39c85c47a80fe254d92b8b43b1eadaf/h5py-3.16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:42b012933a83e1a558c673176676a10ce2fd3759976a0fedee1e672d1e04fc9d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/89/84/06281c82d4d1686fde1ac6b0f307c50918f1c0151062445ab3b6fa5a921d/h5py-3.16.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:ff24039e2573297787c3063df64b60aab0591980ac898329a08b0320e0cf2527" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9e/e9/1a19e42cd43cc1365e127db6aae85e1c671da1d9a5d746f4d34a50edb577/h5py-3.16.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:dfc21898ff025f1e8e67e194965a95a8d4754f452f83454538f98f8a3fcb207e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b7/8e/9790c1655eabeb85b92b1ecab7d7e62a2069e53baefd58c98f0909c7a948/h5py-3.16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:698dd69291272642ffda44a0ecd6cd3bda5faf9621452d255f57ce91487b9794" }, - { url = "https://mirrors.aliyun.com/pypi/packages/51/d7/ab693274f1bd7e8c5f9fdd6c7003a88d59bedeaf8752716a55f532924fbb/h5py-3.16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:2b2c02b0a160faed5fb33f1ba8a264a37ee240b22e049ecc827345d0d9043074" }, - { url = "https://mirrors.aliyun.com/pypi/packages/03/c1/0976b235cf29ead553e22f2fb6385a8252b533715e00d0ae52ed7b900582/h5py-3.16.0-cp312-cp312-win_amd64.whl", hash = "sha256:96b422019a1c8975c2d5dadcf61d4ba6f01c31f92bbde6e4649607885fe502d6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/14/d9/866b7e570b39070f92d47b0ff1800f0f8239b6f9e45f02363d7112336c1f/h5py-3.16.0-cp312-cp312-win_arm64.whl", hash = "sha256:39c2838fb1e8d97bcf1755e60ad1f3dd76a7b2a475928dc321672752678b96db" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0f/9e/6142ebfda0cb6e9349c091eae73c2e01a770b7659255248d637bec54a88b/h5py-3.16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:370a845f432c2c9619db8eed334d1e610c6015796122b0e57aa46312c22617d9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b0/65/5e088a45d0f43cd814bc5bec521c051d42005a472e804b1a36c48dada09b/h5py-3.16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:42108e93326c50c2810025aade9eac9d6827524cdccc7d4b75a546e5ab308edb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/da/1e/6172269e18cc5a484e2913ced33339aad588e02ba407fafd00d369e22ef3/h5py-3.16.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:099f2525c9dcf28de366970a5fb34879aab20491589fa89ce2863a84218bb524" }, - { url = "https://mirrors.aliyun.com/pypi/packages/bd/98/ef2b6fe2903e377cbe870c3b2800d62552f1e3dbe81ce49e1923c53d1c5c/h5py-3.16.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:9300ad32dea9dfc5171f94d5f6948e159ed93e4701280b0f508773b3f582f402" }, - { url = "https://mirrors.aliyun.com/pypi/packages/bc/81/5b62d760039eed64348c98129d17061fdfc7839fc9c04eaaad6dee1004e4/h5py-3.16.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:171038f23bccddfc23f344cadabdfc9917ff554db6a0d417180d2747fe4c75a7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/28/c4/532123bcd9080e250696779c927f2cb906c8bf3447df98f5ceb8dcded539/h5py-3.16.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7e420b539fb6023a259a1b14d4c9f6df8cf50d7268f48e161169987a57b737ff" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c3/d9/a27997f84341fc0dfcdd1fe4179b6ba6c32a7aa880fdb8c514d4dad6fba3/h5py-3.16.0-cp313-cp313-win_amd64.whl", hash = "sha256:18f2bbcd545e6991412253b98727374c356d67caa920e68dc79eab36bf5fedad" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a5/23/bb8647521d4fd770c30a76cfc6cb6a2f5495868904054e92f2394c5a78ff/h5py-3.16.0-cp313-cp313-win_arm64.whl", hash = "sha256:656f00e4d903199a1d58df06b711cf3ca632b874b4207b7dbec86185b5c8c7d4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/48/3c/7fcd9b4c9eed82e91fb15568992561019ae7a829d1f696b2c844355d95dd/h5py-3.16.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:9c9d307c0ef862d1cd5714f72ecfafe0a5d7529c44845afa8de9f46e5ba8bd65" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6a/b7/9366ed44ced9b7ef357ab48c94205280276db9d7f064aa3012a97227e966/h5py-3.16.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:8c1eff849cdd53cbc73c214c30ebdb6f1bb8b64790b4b4fc36acdb5e43570210" }, - { url = "https://mirrors.aliyun.com/pypi/packages/58/a5/4964bc0e91e86340c2bbda83420225b2f770dcf1eb8a39464871ad769436/h5py-3.16.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:e2c04d129f180019e216ee5f9c40b78a418634091c8782e1f723a6ca3658b965" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f1/16/d905e7f53e661ce2c24686c38048d8e2b750ffc4350009d41c4e6c6c9826/h5py-3.16.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:e4360f15875a532bc7b98196c7592ed4fc92672a57c0a621355961cafb17a6dd" }, - { url = "https://mirrors.aliyun.com/pypi/packages/4b/f2/58f34cb74af46d39f4cd18ea20909a8514960c5a3e5b92fd06a28161e0a8/h5py-3.16.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:3fae9197390c325e62e0a1aa977f2f62d994aa87aab182abbea85479b791197c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ce/ca/934a39c24ce2e2db017268c08da0537c20fa0be7e1549be3e977313fc8f5/h5py-3.16.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:43259303989ac8adacc9986695b31e35dba6fd1e297ff9c6a04b7da5542139cc" }, - { url = "https://mirrors.aliyun.com/pypi/packages/3e/14/615a450205e1b56d16c6783f5ccd116cde05550faad70ae077c955654a75/h5py-3.16.0-cp314-cp314-win_amd64.whl", hash = "sha256:fa48993a0b799737ba7fd21e2350fa0a60701e58180fae9f2de834bc39a147ab" }, - { url = "https://mirrors.aliyun.com/pypi/packages/7b/48/a6faef5ed632cae0c65ac6b214a6614a0b510c3183532c521bdb0055e117/h5py-3.16.0-cp314-cp314-win_arm64.whl", hash = "sha256:1897a771a7f40d05c262fc8f37376ec37873218544b70216872876c627640f63" }, - { url = "https://mirrors.aliyun.com/pypi/packages/5d/32/0c8bb8aedb62c772cf7c1d427c7d1951477e8c2835f872bc0a13d1f85f86/h5py-3.16.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:15922e485844f77c0b9d275396d435db3baa58292a9c2176a386e072e0cf2491" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1d/1f/fcc5977d32d6387c5c9a694afee716a5e20658ac08b3ff24fdec79fb05f2/h5py-3.16.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:df02dd29bd247f98674634dfe41f89fd7c16ba3d7de8695ec958f58404a4e618" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f5/a1/af87f64b9f986889884243643621ebbd4ac72472ba8ec8cec891ac8e2ca1/h5py-3.16.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:0f456f556e4e2cebeebd9d66adf8dc321770a42593494a0b6f0af54a7567b242" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cc/d0/146f5eaff3dc246a9c7f6e5e4f42bd45cc613bce16693bcd4d1f7c958bf5/h5py-3.16.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:3e6cb3387c756de6a9492d601553dffea3fe11b5f22b443aac708c69f3f55e16" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a1/9d/12a13424f1e604fc7df9497b73c0356fb78c2fb206abd7465ce47226e8fd/h5py-3.16.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:8389e13a1fd745ad2856873e8187fd10268b2d9677877bb667b41aebd771d8b7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/41/8c/bbe98f813722b4873818a8db3e15aa3e625b59278566905ac439725e8070/h5py-3.16.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:346df559a0f7dcb31cf8e44805319e2ab24b8957c45e7708ce503b2ec79ba725" }, - { url = "https://mirrors.aliyun.com/pypi/packages/32/9e/87e6705b4d6890e7cecdf876e2a7d3e40654a2ae37482d79a6f1b87f7b92/h5py-3.16.0-cp314-cp314t-win_amd64.whl", hash = "sha256:4c6ab014ab704b4feaa719ae783b86522ed0bf1f82184704ed3c9e4e3228796e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/96/91/9fad90cfc5f9b2489c7c26ad897157bce82f0e9534a986a221b99760b23b/h5py-3.16.0-cp314-cp314t-win_arm64.whl", hash = "sha256:faca8fb4e4319c09d83337adc80b2ca7d5c5a343c2d6f1b6388f32cfecca13c1" }, -] - [[package]] name = "hanziconv" version = "0.3.2" @@ -3228,17 +3133,6 @@ version = "3.5.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/f4/57/60d1a6a512f2f0508d0bc8b4f1cc5616fd3196619b66bd6a01f9155a1292/ijson-3.5.0.tar.gz", hash = "sha256:94688760720e3f5212731b3cb8d30267f9a045fb38fb3870254e7b9504246f31" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/aa/17/9c63c7688025f3a8c47ea717b8306649c8c7244e49e20a2be4e3515dc75c/ijson-3.5.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1ebefbe149a6106cc848a3eaf536af51a9b5ccc9082de801389f152dba6ab755" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6f/dd/e15c2400244c117b06585452ebc63ae254f5a6964f712306afd1422daae0/ijson-3.5.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:19e30d9f00f82e64de689c0b8651b9cfed879c184b139d7e1ea5030cec401c21" }, - { url = "https://mirrors.aliyun.com/pypi/packages/77/a9/bf4fe3538a0c965f16b406f180a06105b875da83f0743e36246be64ef550/ijson-3.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a04a33ee78a6f27b9b8528c1ca3c207b1df3b8b867a4cf2fcc4109986f35c227" }, - { url = "https://mirrors.aliyun.com/pypi/packages/31/76/6f91bdb019dd978fce1bc5ea1cd620cfc096d258126c91db2c03a20a7f34/ijson-3.5.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:7d48dc2984af02eb3c56edfb3f13b3f62f2f3e4fe36f058c8cfc75d93adf4fed" }, - { url = "https://mirrors.aliyun.com/pypi/packages/11/be/bbc983059e48a54b0121ee60042979faed7674490bbe7b2c41560db3f436/ijson-3.5.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f1e73a44844d9adbca9cf2c4132cd875933e83f3d4b23881fcaf82be83644c7d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6d/81/2fee58f9024a3449aee83edfa7167fb5ccd7e1af2557300e28531bb68e16/ijson-3.5.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:7389a56b8562a19948bdf1d7bae3a2edc8c7f86fb59834dcb1c4c722818e645a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c7/56/f1706761fcc096c9d414b3dcd000b1e6e5c24364c21cfba429837f98ee8d/ijson-3.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3176f23f8ebec83f374ed0c3b4e5a0c4db7ede54c005864efebbed46da123608" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d9/6e/ee0d9c875a0193b632b3e9ccd1b22a50685fb510256ad57ba483b6529f77/ijson-3.5.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:6babd88e508630c6ef86c9bebaaf13bb2fb8ec1d8f8868773a03c20253f599bc" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d2/bf/f9d4399d0e6e3fd615035290a71e97c843f17f329b43638c0a01cf112d73/ijson-3.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:dc1b3836b174b6db2fa8319f1926fb5445abd195dc963368092103f8579cb8ed" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b2/71/a7254a065933c0e2ffd3586f46187d84830d3d7b6f41cfa5901820a4f87d/ijson-3.5.0-cp312-cp312-win32.whl", hash = "sha256:6673de9395fb9893c1c79a43becd8c8fbee0a250be6ea324bfd1487bb5e9ee4c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8f/7b/2edca79b359fc9f95d774616867a03ecccdf333797baf5b3eea79733918c/ijson-3.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:f4f7fabd653459dcb004175235f310435959b1bb5dfa8878578391c6cc9ad944" }, { url = "https://mirrors.aliyun.com/pypi/packages/a2/71/d67e764a712c3590627480643a3b51efcc3afa4ef3cb54ee4c989073c97e/ijson-3.5.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:e9cedc10e40dd6023c351ed8bfc7dcfce58204f15c321c3c1546b9c7b12562a4" }, { url = "https://mirrors.aliyun.com/pypi/packages/1a/39/f1c299371686153fa3cf5c0736b96247a87a1bee1b7145e6d21f359c505a/ijson-3.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3647649f782ee06c97490b43680371186651f3f69bebe64c6083ee7615d185e5" }, { url = "https://mirrors.aliyun.com/pypi/packages/16/94/b1438e204d75e01541bebe3e668fe3e68612d210e9931ae1611062dd0a56/ijson-3.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:90e74be1dce05fce73451c62d1118671f78f47c9f6be3991c82b91063bf01fc9" }, @@ -3322,7 +3216,7 @@ wheels = [ [[package]] name = "infinity-sdk" -version = "0.7.0.dev6" +version = "0.7.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } dependencies = [ { name = "datrie" }, @@ -3339,9 +3233,9 @@ dependencies = [ { name = "sqlglot", extra = ["rs"] }, { name = "thrift" }, ] -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/21/5c/27a1afab0d96200421706afc33eba3d34684e2055c63761b3700f52f7cbf/infinity_sdk-0.7.0.dev6.tar.gz", hash = "sha256:6d8b9be0ace7fa5c790ed8bee39dc28faef448c74d60bb3be7c86f244d5d9b46" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/46/c9/0bb4981acdaca79b864b26d998201f652ef6024fce7d0c39da9c11605890/infinity_sdk-0.7.0.tar.gz", hash = "sha256:42ba8c6acd4fad918b1ed189fab3383023e4750d46ef1bf1c9465ffcf3ff8335" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/8c/e5/6dbc74929b63c55e3bf867e9fcade32b5b7fbea2f8d2cf9d99e2df1ae70c/infinity_sdk-0.7.0.dev6-py3-none-any.whl", hash = "sha256:9cf97aaea0238881d6be2cb11585e57069c7dfb2fce0e8002868b6bf916dba51" }, + { url = "https://mirrors.aliyun.com/pypi/packages/27/56/01a0b4b816c70595a83ad2d7ec387d1d991f0ea4607d77799bac010be27c/infinity_sdk-0.7.0-py3-none-any.whl", hash = "sha256:4772fded64ff733eb3dff36df3a2c1c867049acf3c6e54d4e40b5a915c0f3c18" }, ] [[package]] @@ -3375,6 +3269,15 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/99/15/4ac989bf4019271fa688d9f95c7b01088958306b4fdc6929f9fa042a6e81/inscriptis-2.7.1-py3-none-any.whl", hash = "sha256:fd41d122e92b646527bca413e9e0270793d42c11fbe8045e388686199b6f30ca" }, ] +[[package]] +name = "invoke" +version = "3.0.3" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/33/f6/227c48c5fe47fa178ccf1fda8f047d16c97ba926567b661e9ce2045c600c/invoke-3.0.3.tar.gz", hash = "sha256:437b6a622223824380bfb4e64f612711a6b648c795f565efc8625af66fb57f0c" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/5a/de/bbc12563bbf979618d17625a4e753ff7a078523e28d870d3626daa97261a/invoke-3.0.3-py3-none-any.whl", hash = "sha256:f11327165e5cbb89b2ad1d88d3292b5113332c43b8553b494da435d6ec6f5053" }, +] + [[package]] name = "ir-datasets" version = "0.5.11" @@ -3454,19 +3357,6 @@ version = "0.13.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/0d/5e/4ec91646aee381d01cdb9974e30882c9cd3b8c5d1079d6b5ff4af522439a/jiter-0.13.0.tar.gz", hash = "sha256:f2839f9c2c7e2dffc1bc5929a510e14ce0a946be9365fd1219e7ef342dae14f4" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/2e/30/7687e4f87086829955013ca12a9233523349767f69653ebc27036313def9/jiter-0.13.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:0a2bd69fc1d902e89925fc34d1da51b2128019423d7b339a45d9e99c894e0663" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c3/27/e57f9a783246ed95481e6749cc5002a8a767a73177a83c63ea71f0528b90/jiter-0.13.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f917a04240ef31898182f76a332f508f2cc4b57d2b4d7ad2dbfebbfe167eb505" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cf/52/e5719a60ac5d4d7c5995461a94ad5ef962a37c8bf5b088390e6fad59b2ff/jiter-0.13.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c1e2b199f446d3e82246b4fd9236d7cb502dc2222b18698ba0d986d2fecc6152" }, - { url = "https://mirrors.aliyun.com/pypi/packages/61/db/c1efc32b8ba4c740ab3fc2d037d8753f67685f475e26b9d6536a4322bcdd/jiter-0.13.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04670992b576fa65bd056dbac0c39fe8bd67681c380cb2b48efa885711d9d726" }, - { url = "https://mirrors.aliyun.com/pypi/packages/55/8a/fb75556236047c8806995671a18e4a0ad646ed255276f51a20f32dceaeec/jiter-0.13.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5a1aff1fbdb803a376d4d22a8f63f8e7ccbce0b4890c26cc7af9e501ab339ef0" }, - { url = "https://mirrors.aliyun.com/pypi/packages/7e/16/43512e6ee863875693a8e6f6d532e19d650779d6ba9a81593ae40a9088ff/jiter-0.13.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3b3fb8c2053acaef8580809ac1d1f7481a0a0bdc012fd7f5d8b18fb696a5a089" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f8/4c/09b93e30e984a187bc8aaa3510e1ec8dcbdcd71ca05d2f56aac0492453aa/jiter-0.13.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bdaba7d87e66f26a2c45d8cbadcbfc4bf7884182317907baf39cfe9775bb4d93" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1a/1b/46c5e349019874ec5dfa508c14c37e29864ea108d376ae26d90bee238cd7/jiter-0.13.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:7b88d649135aca526da172e48083da915ec086b54e8e73a425ba50999468cc08" }, - { url = "https://mirrors.aliyun.com/pypi/packages/15/9e/26184760e85baee7162ad37b7912797d2077718476bf91517641c92b3639/jiter-0.13.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:e404ea551d35438013c64b4f357b0474c7abf9f781c06d44fcaf7a14c69ff9e2" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e9/34/2c9355247d6debad57a0a15e76ab1566ab799388042743656e566b3b7de1/jiter-0.13.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1f4748aad1b4a93c8bdd70f604d0f748cdc0e8744c5547798acfa52f10e79228" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ac/4a/9f2c23255d04a834398b9c2e0e665382116911dc4d06b795710503cdad25/jiter-0.13.0-cp312-cp312-win32.whl", hash = "sha256:0bf670e3b1445fc4d31612199f1744f67f889ee1bbae703c4b54dc097e5dd394" }, - { url = "https://mirrors.aliyun.com/pypi/packages/09/ee/f0ae675a957ae5a8f160be3e87acea6b11dc7b89f6b7ab057e77b2d2b13a/jiter-0.13.0-cp312-cp312-win_amd64.whl", hash = "sha256:15db60e121e11fe186c0b15236bd5d18381b9ddacdcf4e659feb96fc6c969c92" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1b/02/ae611edf913d3cbf02c97cdb90374af2082c48d7190d74c1111dde08bcdd/jiter-0.13.0-cp312-cp312-win_arm64.whl", hash = "sha256:41f92313d17989102f3cb5dd533a02787cdb99454d494344b0361355da52fcb9" }, { url = "https://mirrors.aliyun.com/pypi/packages/91/9c/7ee5a6ff4b9991e1a45263bfc46731634c4a2bde27dfda6c8251df2d958c/jiter-0.13.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:1f8a55b848cbabf97d861495cd65f1e5c590246fabca8b48e1747c4dfc8f85bf" }, { url = "https://mirrors.aliyun.com/pypi/packages/7c/02/be5b870d1d2be5dd6a91bdfb90f248fbb7dcbd21338f092c6b89817c3dbf/jiter-0.13.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f556aa591c00f2c45eb1b89f68f52441a016034d18b65da60e2d2875bbbf344a" }, { url = "https://mirrors.aliyun.com/pypi/packages/da/92/b25d2ec333615f5f284f3a4024f7ce68cfa0604c322c6808b2344c7f5d2b/jiter-0.13.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f7e1d61da332ec412350463891923f960c3073cf1aae93b538f0bb4c8cd46efb" }, @@ -3510,10 +3400,6 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/47/66/eea81dfff765ed66c68fd2ed8c96245109e13c896c2a5015c7839c92367e/jiter-0.13.0-cp314-cp314t-win32.whl", hash = "sha256:24dc96eca9f84da4131cdf87a95e6ce36765c3b156fc9ae33280873b1c32d5f6" }, { url = "https://mirrors.aliyun.com/pypi/packages/ff/32/4ac9c7a76402f8f00d00842a7f6b83b284d0cf7c1e9d4227bc95aa6d17fa/jiter-0.13.0-cp314-cp314t-win_amd64.whl", hash = "sha256:0a8d76c7524087272c8ae913f5d9d608bd839154b62c4322ef65723d2e5bb0b8" }, { url = "https://mirrors.aliyun.com/pypi/packages/f9/8e/7def204fea9f9be8b3c21a6f2dd6c020cf56c7d5ff753e0e23ed7f9ea57e/jiter-0.13.0-cp314-cp314t-win_arm64.whl", hash = "sha256:2c26cf47e2cad140fa23b6d58d435a7c0161f5c514284802f25e87fddfe11024" }, - { url = "https://mirrors.aliyun.com/pypi/packages/80/60/e50fa45dd7e2eae049f0ce964663849e897300433921198aef94b6ffa23a/jiter-0.13.0-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:3d744a6061afba08dd7ae375dcde870cffb14429b7477e10f67e9e6d68772a0a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d2/73/a009f41c5eed71c49bec53036c4b33555afcdee70682a18c6f66e396c039/jiter-0.13.0-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:ff732bd0a0e778f43d5009840f20b935e79087b4dc65bd36f1cd0f9b04b8ff7f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c4/10/528b439290763bff3d939268085d03382471b442f212dca4ff5f12802d43/jiter-0.13.0-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ab44b178f7981fcaea7e0a5df20e773c663d06ffda0198f1a524e91b2fde7e59" }, - { url = "https://mirrors.aliyun.com/pypi/packages/67/8a/a342b2f0251f3dac4ca17618265d93bf244a2a4d089126e81e4c1056ac50/jiter-0.13.0-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7bb00b6d26db67a05fe3e12c76edc75f32077fb51deed13822dc648fa373bc19" }, ] [[package]] @@ -3585,46 +3471,12 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/4a/4a/cf14bf3b1f5ffb13c69cf5f0ea78031247790558ee88984a8bdd22fae60d/kaitaistruct-0.11-py2.py3-none-any.whl", hash = "sha256:5c6ce79177b4e193a577ecd359e26516d1d6d000a0bffd6e1010f2a46a62a561" }, ] -[[package]] -name = "keras" -version = "3.14.0" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -dependencies = [ - { name = "absl-py" }, - { name = "h5py" }, - { name = "ml-dtypes" }, - { name = "namex" }, - { name = "numpy" }, - { name = "optree" }, - { name = "packaging" }, - { name = "rich" }, -] -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/88/ce/47874047a49eedc2a5d3b41bc4f1f572bb637f51e4351ef3538e49a63800/keras-3.14.0.tar.gz", hash = "sha256:86fcf8249a25264a566ac393c287c7ad657000e5e62615dcaad4b3472a17aeda" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/c0/20/78d26f81115d570bdf0e57d19b81de9ad8aa55ddb68eb10c8f0699fccb63/keras-3.14.0-py3-none-any.whl", hash = "sha256:19ce94b798caaba4d404ab6ef4753b44219170e5c2868156de8bb0494a260114" }, -] - [[package]] name = "kiwisolver" version = "1.5.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/d0/67/9c61eccb13f0bdca9307614e782fec49ffdde0f7a2314935d489fa93cd9c/kiwisolver-1.5.0.tar.gz", hash = "sha256:d4193f3d9dc3f6f79aaed0e5637f45d98850ebf01f7ca20e69457f3e8946b66a" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/4d/b2/818b74ebea34dabe6d0c51cb1c572e046730e64844da6ed646d5298c40ce/kiwisolver-1.5.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:4e9750bc21b886308024f8a54ccb9a2cc38ac9fa813bf4348434e3d54f337ff9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/bf/d9/405320f8077e8e1c5c4bd6adc45e1e6edf6d727b6da7f2e2533cf58bff71/kiwisolver-1.5.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:72ec46b7eba5b395e0a7b63025490d3214c11013f4aacb4f5e8d6c3041829588" }, - { url = "https://mirrors.aliyun.com/pypi/packages/99/9f/795fedf35634f746151ca8839d05681ceb6287fbed6cc1c9bf235f7887c2/kiwisolver-1.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ed3a984b31da7481b103f68776f7128a89ef26ed40f4dc41a2223cda7fb24819" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c4/13/680c54afe3e65767bed7ec1a15571e1a2f1257128733851ade24abcefbcc/kiwisolver-1.5.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:bb5136fb5352d3f422df33f0c879a1b0c204004324150cc3b5e3c4f310c9049f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c8/2f/cebfcdb60fd6a9b0f6b47a9337198bcbad6fbe15e68189b7011fd914911f/kiwisolver-1.5.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b2af221f268f5af85e776a73d62b0845fc8baf8ef0abfae79d29c77d0e776aaf" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f2/0d/9b782923aada3fafb1d6b84e13121954515c669b18af0c26e7d21f579855/kiwisolver-1.5.0-cp312-cp312-manylinux_2_24_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:b0f172dc8ffaccb8522d7c5d899de00133f2f1ca7b0a49b7da98e901de87bf2d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/27/70/83241b6634b04fe44e892688d5208332bde130f38e610c0418f9ede47ded/kiwisolver-1.5.0-cp312-cp312-manylinux_2_24_s390x.manylinux_2_28_s390x.whl", hash = "sha256:6ab8ba9152203feec73758dad83af9a0bbe05001eb4639e547207c40cfb52083" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e4/db/30ed226fb271ae1a6431fc0fe0edffb2efe23cadb01e798caeb9f2ceae8f/kiwisolver-1.5.0-cp312-cp312-manylinux_2_39_riscv64.whl", hash = "sha256:cdee07c4d7f6d72008d3f73b9bf027f4e11550224c7c50d8df1ae4a37c1402a6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ec/bd/c314595208e4c9587652d50959ead9e461995389664e490f4dce7ff0f782/kiwisolver-1.5.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:7c60d3c9b06fb23bd9c6139281ccbdc384297579ae037f08ae90c69f6845c0b1" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c1/43/0499cec932d935229b5543d073c2b87c9c22846aab48881e9d8d6e742a2d/kiwisolver-1.5.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:e315e5ec90d88e140f57696ff85b484ff68bb311e36f2c414aa4286293e6dee0" }, - { url = "https://mirrors.aliyun.com/pypi/packages/3d/6f/79b0d760907965acfd9d61826a3d41f8f093c538f55cd2633d3f0db269f6/kiwisolver-1.5.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:1465387ac63576c3e125e5337a6892b9e99e0627d52317f3ca79e6930d889d15" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ab/31/01d0537c41cb75a551a438c3c7a80d0c60d60b81f694dac83dd436aec0d0/kiwisolver-1.5.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:530a3fd64c87cffa844d4b6b9768774763d9caa299e9b75d8eca6a4423b31314" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e4/34/8aefdd0be9cfd00a44509251ba864f5caf2991e36772e61c408007e7f417/kiwisolver-1.5.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:1d9daea4ea6b9be74fe2f01f7fbade8d6ffab263e781274cffca0dba9be9eec9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ad/cf/0348374369ca588f8fe9c338fae49fa4e16eeb10ffb3d012f23a54578a9e/kiwisolver-1.5.0-cp312-cp312-win_amd64.whl", hash = "sha256:f18c2d9782259a6dc132fdc7a63c168cbc74b35284b6d75c673958982a378384" }, - { url = "https://mirrors.aliyun.com/pypi/packages/28/26/192b26196e2316e2bd29deef67e37cdf9870d9af8e085e521afff0fed526/kiwisolver-1.5.0-cp312-cp312-win_arm64.whl", hash = "sha256:f7c7553b13f69c1b29a5bde08ddc6d9d0c8bfb84f9ed01c30db25944aeb852a7" }, { url = "https://mirrors.aliyun.com/pypi/packages/9d/69/024d6711d5ba575aa65d5538042e99964104e97fa153a9f10bc369182bc2/kiwisolver-1.5.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:fd40bb9cd0891c4c3cb1ddf83f8bbfa15731a248fdc8162669405451e2724b09" }, { url = "https://mirrors.aliyun.com/pypi/packages/ce/48/adbb40df306f587054a348831220812b9b1d787aff714cfbc8556e38fccd/kiwisolver-1.5.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c0e1403fd7c26d77c1f03e096dc58a5c726503fa0db0456678b8668f76f521e3" }, { url = "https://mirrors.aliyun.com/pypi/packages/a8/3a/d0a972b34e1c63e2409413104216cd1caa02c5a37cb668d1687d466c1c45/kiwisolver-1.5.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:dda366d548e89a90d88a86c692377d18d8bd64b39c1fb2b92cb31370e2896bbd" }, @@ -3684,10 +3536,6 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/07/18/43a5f24608d8c313dd189cf838c8e68d75b115567c6279de7796197cfb6a/kiwisolver-1.5.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:e7a116ae737f0000343218c4edf5bd45893bfeaff0993c0b215d7124c9f77646" }, { url = "https://mirrors.aliyun.com/pypi/packages/3b/b5/98222136d839b8afabcaa943b09bd05888c2d36355b7e448550211d1fca4/kiwisolver-1.5.0-cp314-cp314t-win_amd64.whl", hash = "sha256:1dd9b0b119a350976a6d781e7278ec7aca0b201e1a9e2d23d9804afecb6ca681" }, { url = "https://mirrors.aliyun.com/pypi/packages/99/a2/ca7dc962848040befed12732dff6acae7fb3c4f6fc4272b3f6c9a30b8713/kiwisolver-1.5.0-cp314-cp314t-win_arm64.whl", hash = "sha256:58f812017cd2985c21fbffb4864d59174d4903dd66fa23815e74bbc7a0e2dd57" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1c/fa/2910df836372d8761bb6eff7d8bdcb1613b5c2e03f260efe7abe34d388a7/kiwisolver-1.5.0-graalpy312-graalpy250_312_native-macosx_10_13_x86_64.whl", hash = "sha256:5ae8e62c147495b01a0f4765c878e9bfdf843412446a247e28df59936e99e797" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0f/41/c5f71f9f00aabcc71fee8b7475e3f64747282580c2fe748961ba29b18385/kiwisolver-1.5.0-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:f6764a4ccab3078db14a632420930f6186058750df066b8ea2a7106df91d3203" }, - { url = "https://mirrors.aliyun.com/pypi/packages/fa/06/7399a607f434119c6e1fdc8ec89a8d51ccccadf3341dee4ead6bd14caaf5/kiwisolver-1.5.0-graalpy312-graalpy250_312_native-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c31c13da98624f957b0fb1b5bae5383b2333c2c3f6793d9825dd5ce79b525cb7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b5/91/53255615acd2a1eaca307ede3c90eb550bae9c94581f8c00081b6b1c8f44/kiwisolver-1.5.0-graalpy312-graalpy250_312_native-win_amd64.whl", hash = "sha256:1f1489f769582498610e015a8ef2d36f28f505ab3096d0e16b4858a9ec214f57" }, ] [[package]] @@ -3718,23 +3566,6 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/82/3d/14ce75ef66813643812f3093ab17e46d3a206942ce7376d31ec2d36229e7/lark-1.3.1-py3-none-any.whl", hash = "sha256:c629b661023a014c37da873b4ff58a817398d12635d3bbb2c5a03be7fe5d1e12" }, ] -[[package]] -name = "libclang" -version = "18.1.1" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/6e/5c/ca35e19a4f142adffa27e3d652196b7362fa612243e2b916845d801454fc/libclang-18.1.1.tar.gz", hash = "sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/4b/49/f5e3e7e1419872b69f6f5e82ba56e33955a74bd537d8a1f5f1eff2f3668a/libclang-18.1.1-1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e2/e5/fc61bbded91a8830ccce94c5294ecd6e88e496cc85f6704bf350c0634b70/libclang-18.1.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/db/ed/1df62b44db2583375f6a8a5e2ca5432bbdc3edb477942b9b7c848c720055/libclang-18.1.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1d/fc/716c1e62e512ef1c160e7984a73a5fc7df45166f2ff3f254e71c58076f7c/libclang-18.1.1-py2.py3-none-manylinux2010_x86_64.whl", hash = "sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/3c/3d/f0ac1150280d8d20d059608cf2d5ff61b7c3b7f7bcf9c0f425ab92df769a/libclang-18.1.1-py2.py3-none-manylinux2014_aarch64.whl", hash = "sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592" }, - { url = "https://mirrors.aliyun.com/pypi/packages/fe/2f/d920822c2b1ce9326a4c78c0c2b4aa3fde610c7ee9f631b600acb5376c26/libclang-18.1.1-py2.py3-none-manylinux2014_armv7l.whl", hash = "sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2d/c2/de1db8c6d413597076a4259cea409b83459b2db997c003578affdd32bf66/libclang-18.1.1-py2.py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0b/2d/3f480b1e1d31eb3d6de5e3ef641954e5c67430d5ac93b7fa7e07589576c7/libclang-18.1.1-py2.py3-none-win_amd64.whl", hash = "sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/71/cf/e01dc4cc79779cd82d77888a88ae2fa424d93b445ad4f6c02bfc18335b70/libclang-18.1.1-py2.py3-none-win_arm64.whl", hash = "sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8" }, -] - [[package]] name = "litellm" version = "1.82.6" @@ -3764,10 +3595,6 @@ version = "0.46.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/74/cd/08ae687ba099c7e3d21fe2ea536500563ef1943c5105bf6ab4ee3829f68e/llvmlite-0.46.0.tar.gz", hash = "sha256:227c9fd6d09dce2783c18b754b7cd9d9b3b3515210c46acc2d3c5badd9870ceb" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/2b/f8/4db016a5e547d4e054ff2f3b99203d63a497465f81ab78ec8eb2ff7b2304/llvmlite-0.46.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6b9588ad4c63b4f0175a3984b85494f0c927c6b001e3a246a3a7fb3920d9a137" }, - { url = "https://mirrors.aliyun.com/pypi/packages/aa/85/4890a7c14b4fa54400945cb52ac3cd88545bbdb973c440f98ca41591cdc5/llvmlite-0.46.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3535bd2bb6a2d7ae4012681ac228e5132cdb75fefb1bcb24e33f2f3e0c865ed4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6a/07/3d31d39c1a1a08cd5337e78299fca77e6aebc07c059fbd0033e3edfab45c/llvmlite-0.46.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4cbfd366e60ff87ea6cc62f50bc4cd800ebb13ed4c149466f50cf2163a473d1e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2a/6b/d139535d7590a1bba1ceb68751bef22fadaa5b815bbdf0e858e3875726b2/llvmlite-0.46.0-cp312-cp312-win_amd64.whl", hash = "sha256:398b39db462c39563a97b912d4f2866cd37cba60537975a09679b28fbbc0fb38" }, { url = "https://mirrors.aliyun.com/pypi/packages/e6/ff/3eba7eb0aed4b6fca37125387cd417e8c458e750621fce56d2c541f67fa8/llvmlite-0.46.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:30b60892d034bc560e0ec6654737aaa74e5ca327bd8114d82136aa071d611172" }, { url = "https://mirrors.aliyun.com/pypi/packages/0e/54/737755c0a91558364b9200702c3c9c15d70ed63f9b98a2c32f1c2aa1f3ba/llvmlite-0.46.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:6cc19b051753368a9c9f31dc041299059ee91aceec81bd57b0e385e5d5bf1a54" }, { url = "https://mirrors.aliyun.com/pypi/packages/e6/91/14f32e1d70905c1c0aa4e6609ab5d705c3183116ca02ac6df2091868413a/llvmlite-0.46.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bca185892908f9ede48c0acd547fe4dc1bafefb8a4967d47db6cf664f9332d12" }, @@ -3784,23 +3611,6 @@ version = "5.4.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/76/3d/14e82fc7c8fb1b7761f7e748fd47e2ec8276d137b6acfe5a4bb73853e08f/lxml-5.4.0.tar.gz", hash = "sha256:d12832e1dbea4be280b22fd0ea7c9b87f0d8fc51ba06e92dc62d52f804f78ebd" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/f8/4c/d101ace719ca6a4ec043eb516fcfcb1b396a9fccc4fcd9ef593df34ba0d5/lxml-5.4.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:b5aff6f3e818e6bdbbb38e5967520f174b18f539c2b9de867b1e7fde6f8d95a4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/11/84/beddae0cec4dd9ddf46abf156f0af451c13019a0fa25d7445b655ba5ccb7/lxml-5.4.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:942a5d73f739ad7c452bf739a62a0f83e2578afd6b8e5406308731f4ce78b16d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d0/25/d0d93a4e763f0462cccd2b8a665bf1e4343dd788c76dcfefa289d46a38a9/lxml-5.4.0-cp312-cp312-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:460508a4b07364d6abf53acaa0a90b6d370fafde5693ef37602566613a9b0779" }, - { url = "https://mirrors.aliyun.com/pypi/packages/31/ce/1df18fb8f7946e7f3388af378b1f34fcf253b94b9feedb2cec5969da8012/lxml-5.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:529024ab3a505fed78fe3cc5ddc079464e709f6c892733e3f5842007cec8ac6e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/4e/62/f4a6c60ae7c40d43657f552f3045df05118636be1165b906d3423790447f/lxml-5.4.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ca56ebc2c474e8f3d5761debfd9283b8b18c76c4fc0967b74aeafba1f5647f9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9e/aa/04f00009e1e3a77838c7fc948f161b5d2d5de1136b2b81c712a263829ea4/lxml-5.4.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a81e1196f0a5b4167a8dafe3a66aa67c4addac1b22dc47947abd5d5c7a3f24b5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c9/1f/e0b2f61fa2404bf0f1fdf1898377e5bd1b74cc9b2cf2c6ba8509b8f27990/lxml-5.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:00b8686694423ddae324cf614e1b9659c2edb754de617703c3d29ff568448df5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/24/a2/8263f351b4ffe0ed3e32ea7b7830f845c795349034f912f490180d88a877/lxml-5.4.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:c5681160758d3f6ac5b4fea370495c48aac0989d6a0f01bb9a72ad8ef5ab75c4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/05/00/41db052f279995c0e35c79d0f0fc9f8122d5b5e9630139c592a0b58c71b4/lxml-5.4.0-cp312-cp312-manylinux_2_28_ppc64le.whl", hash = "sha256:2dc191e60425ad70e75a68c9fd90ab284df64d9cd410ba8d2b641c0c45bc006e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1d/be/ee99e6314cdef4587617d3b3b745f9356d9b7dd12a9663c5f3b5734b64ba/lxml-5.4.0-cp312-cp312-manylinux_2_28_s390x.whl", hash = "sha256:67f779374c6b9753ae0a0195a892a1c234ce8416e4448fe1e9f34746482070a7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ad/36/239820114bf1d71f38f12208b9c58dec033cbcf80101cde006b9bde5cffd/lxml-5.4.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:79d5bfa9c1b455336f52343130b2067164040604e41f6dc4d8313867ed540079" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d4/e1/1b795cc0b174efc9e13dbd078a9ff79a58728a033142bc6d70a1ee8fc34d/lxml-5.4.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3d3c30ba1c9b48c68489dc1829a6eede9873f52edca1dda900066542528d6b20" }, - { url = "https://mirrors.aliyun.com/pypi/packages/72/48/3c198455ca108cec5ae3662ae8acd7fd99476812fd712bb17f1b39a0b589/lxml-5.4.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:1af80c6316ae68aded77e91cd9d80648f7dd40406cef73df841aa3c36f6907c8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d6/10/5bf51858971c51ec96cfc13e800a9951f3fd501686f4c18d7d84fe2d6352/lxml-5.4.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:4d885698f5019abe0de3d352caf9466d5de2baded00a06ef3f1216c1a58ae78f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2b/11/06710dd809205377da380546f91d2ac94bad9ff735a72b64ec029f706c85/lxml-5.4.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:aea53d51859b6c64e7c51d522c03cc2c48b9b5d6172126854cc7f01aa11f52bc" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f5/b0/15b6217834b5e3a59ebf7f53125e08e318030e8cc0d7310355e6edac98ef/lxml-5.4.0-cp312-cp312-win32.whl", hash = "sha256:d90b729fd2732df28130c064aac9bb8aff14ba20baa4aee7bd0795ff1187545f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/91/1e/05ddcb57ad2f3069101611bd5f5084157d90861a2ef460bf42f45cced944/lxml-5.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:1dc4ca99e89c335a7ed47d38964abcb36c5910790f9bd106f2a8fa2ee0b909d2" }, { url = "https://mirrors.aliyun.com/pypi/packages/87/cb/2ba1e9dd953415f58548506fa5549a7f373ae55e80c61c9041b7fd09a38a/lxml-5.4.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:773e27b62920199c6197130632c18fb7ead3257fce1ffb7d286912e56ddb79e0" }, { url = "https://mirrors.aliyun.com/pypi/packages/b5/3e/6602a4dca3ae344e8609914d6ab22e52ce42e3e1638c10967568c5c1450d/lxml-5.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ce9c671845de9699904b1e9df95acfe8dfc183f2310f163cdaa91a3535af95de" }, { url = "https://mirrors.aliyun.com/pypi/packages/4c/72/bf00988477d3bb452bef9436e45aeea82bb40cdfb4684b83c967c53909c7/lxml-5.4.0-cp313-cp313-manylinux_2_12_i686.manylinux2010_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9454b8d8200ec99a224df8854786262b1bd6461f4280064c807303c642c05e76" }, @@ -3843,14 +3653,6 @@ version = "4.4.5" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/57/51/f1b86d93029f418033dddf9b9f79c8d2641e7454080478ee2aab5123173e/lz4-4.4.5.tar.gz", hash = "sha256:5f0b9e53c1e82e88c10d7c180069363980136b9d7a8306c4dca4f760d60c39f0" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/1b/ac/016e4f6de37d806f7cc8f13add0a46c9a7cfc41a5ddc2bc831d7954cf1ce/lz4-4.4.5-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:df5aa4cead2044bab83e0ebae56e0944cc7fcc1505c7787e9e1057d6d549897e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8d/df/0fadac6e5bd31b6f34a1a8dbd4db6a7606e70715387c27368586455b7fc9/lz4-4.4.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6d0bf51e7745484d2092b3a51ae6eb58c3bd3ce0300cf2b2c14f76c536d5697a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b7/17/34e36cc49bb16ca73fb57fbd4c5eaa61760c6b64bce91fcb4e0f4a97f852/lz4-4.4.5-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:7b62f94b523c251cf32aa4ab555f14d39bd1a9df385b72443fd76d7c7fb051f5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/90/1c/b1d8e3741e9fc89ed3b5f7ef5f22586c07ed6bb04e8343c2e98f0fa7ff04/lz4-4.4.5-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2c3ea562c3af274264444819ae9b14dbbf1ab070aff214a05e97db6896c7597e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/55/d9/e3867222474f6c1b76e89f3bd914595af69f55bf2c1866e984c548afdc15/lz4-4.4.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:24092635f47538b392c4eaeff14c7270d2c8e806bf4be2a6446a378591c5e69e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b2/e7/d667d337367686311c38b580d1ca3d5a23a6617e129f26becd4f5dc458df/lz4-4.4.5-cp312-cp312-win32.whl", hash = "sha256:214e37cfe270948ea7eb777229e211c601a3e0875541c1035ab408fbceaddf50" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a5/0b/a54cd7406995ab097fceb907c7eb13a6ddd49e0b231e448f1a81a50af65c/lz4-4.4.5-cp312-cp312-win_amd64.whl", hash = "sha256:713a777de88a73425cf08eb11f742cd2c98628e79a8673d6a52e3c5f0c116f33" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6a/7e/dc28a952e4bfa32ca16fa2eb026e7a6ce5d1411fcd5986cd08c74ec187b9/lz4-4.4.5-cp312-cp312-win_arm64.whl", hash = "sha256:a88cbb729cc333334ccfb52f070463c21560fca63afcf636a9f160a55fac3301" }, { url = "https://mirrors.aliyun.com/pypi/packages/2f/46/08fd8ef19b782f301d56a9ccfd7dafec5fd4fc1a9f017cf22a1accb585d7/lz4-4.4.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:6bb05416444fafea170b07181bc70640975ecc2a8c92b3b658c554119519716c" }, { url = "https://mirrors.aliyun.com/pypi/packages/8f/3f/ea3334e59de30871d773963997ecdba96c4584c5f8007fd83cfc8f1ee935/lz4-4.4.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:b424df1076e40d4e884cfcc4c77d815368b7fb9ebcd7e634f937725cd9a8a72a" }, { url = "https://mirrors.aliyun.com/pypi/packages/41/7b/7b3a2a0feb998969f4793c650bb16eff5b06e80d1f7bff867feb332f2af2/lz4-4.4.5-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:216ca0c6c90719731c64f41cfbd6f27a736d7e50a10b70fad2a9c9b262ec923d" }, @@ -3938,17 +3740,6 @@ version = "3.0.3" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/7e/99/7690b6d4034fffd95959cbe0c02de8deb3098cc577c67bb6a24fe5d7caa7/markupsafe-3.0.3.tar.gz", hash = "sha256:722695808f4b6457b320fdc131280796bdceb04ab50fe1795cd540799ebe1698" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/5a/72/147da192e38635ada20e0a2e1a51cf8823d2119ce8883f7053879c2199b5/markupsafe-3.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d53197da72cc091b024dd97249dfc7794d6a56530370992a5e1a08983ad9230e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9a/81/7e4e08678a1f98521201c3079f77db69fb552acd56067661f8c2f534a718/markupsafe-3.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:1872df69a4de6aead3491198eaf13810b565bdbeec3ae2dc8780f14458ec73ce" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1e/2c/799f4742efc39633a1b54a92eec4082e4f815314869865d876824c257c1e/markupsafe-3.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3a7e8ae81ae39e62a41ec302f972ba6ae23a5c5396c8e60113e9066ef893da0d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/3c/2e/8d0c2ab90a8c1d9a24f0399058ab8519a3279d1bd4289511d74e909f060e/markupsafe-3.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d6dd0be5b5b189d31db7cda48b91d7e0a9795f31430b7f271219ab30f1d3ac9d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2c/54/887f3092a85238093a0b2154bd629c89444f395618842e8b0c41783898ea/markupsafe-3.0.3-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:94c6f0bb423f739146aec64595853541634bde58b2135f27f61c1ffd1cd4d16a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c9/2f/336b8c7b6f4a4d95e91119dc8521402461b74a485558d8f238a68312f11c/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:be8813b57049a7dc738189df53d69395eba14fb99345e0a5994914a3864c8a4b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/32/43/67935f2b7e4982ffb50a4d169b724d74b62a3964bc1a9a527f5ac4f1ee2b/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:83891d0e9fb81a825d9a6d61e3f07550ca70a076484292a70fde82c4b807286f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/89/e0/4486f11e51bbba8b0c041098859e869e304d1c261e59244baa3d295d47b7/markupsafe-3.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:77f0643abe7495da77fb436f50f8dab76dbc6e5fd25d39589a0f1fe6548bfa2b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2f/e1/78ee7a023dac597a5825441ebd17170785a9dab23de95d2c7508ade94e0e/markupsafe-3.0.3-cp312-cp312-win32.whl", hash = "sha256:d88b440e37a16e651bda4c7c2b930eb586fd15ca7406cb39e211fcff3bf3017d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/aa/5b/bec5aa9bbbb2c946ca2733ef9c4ca91c91b6a24580193e891b5f7dbe8e1e/markupsafe-3.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:26a5784ded40c9e318cfc2bdb30fe164bdb8665ded9cd64d500a34fb42067b1c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e5/f1/216fc1bbfd74011693a4fd837e7026152e89c4bcf3e77b6692fba9923123/markupsafe-3.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:35add3b638a5d900e807944a078b51922212fb3dedb01633a8defc4b01a3c85f" }, { url = "https://mirrors.aliyun.com/pypi/packages/38/2f/907b9c7bbba283e68f20259574b13d005c121a0fa4c175f9bed27c4597ff/markupsafe-3.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e1cf1972137e83c5d4c136c43ced9ac51d0e124706ee1c8aa8532c1287fa8795" }, { url = "https://mirrors.aliyun.com/pypi/packages/9c/d9/5f7756922cdd676869eca1c4e3c0cd0df60ed30199ffd775e319089cb3ed/markupsafe-3.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:116bb52f642a37c115f517494ea5feb03889e04df47eeff5b130b1808ce7c219" }, { url = "https://mirrors.aliyun.com/pypi/packages/00/07/575a68c754943058c78f30db02ee03a64b3c638586fba6a6dd56830b30a3/markupsafe-3.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:133a43e73a802c5562be9bbcd03d090aa5a1fe899db609c29e8c8d815c5f6de6" }, @@ -4012,13 +3803,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/8a/76/d3c6e3a13fe484ebe7718d14e269c9569c4eb0020a968a327acb3b9a8fe6/matplotlib-3.10.8.tar.gz", hash = "sha256:2299372c19d56bcd35cf05a2738308758d32b9eaed2371898d8f5bd33f084aa3" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/9e/67/f997cdcbb514012eb0d10cd2b4b332667997fb5ebe26b8d41d04962fa0e6/matplotlib-3.10.8-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:64fcc24778ca0404ce0cb7b6b77ae1f4c7231cdd60e6778f999ee05cbd581b9a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/7e/65/07d5f5c7f7c994f12c768708bd2e17a4f01a2b0f44a1c9eccad872433e2e/matplotlib-3.10.8-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b9a5ca4ac220a0cdd1ba6bcba3608547117d30468fefce49bb26f55c1a3d5c58" }, - { url = "https://mirrors.aliyun.com/pypi/packages/3e/f3/c5195b1ae57ef85339fd7285dfb603b22c8b4e79114bae5f4f0fcf688677/matplotlib-3.10.8-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:3ab4aabc72de4ff77b3ec33a6d78a68227bf1123465887f9905ba79184a1cc04" }, - { url = "https://mirrors.aliyun.com/pypi/packages/00/f9/7638f5cc82ec8a7aa005de48622eecc3ed7c9854b96ba15bd76b7fd27574/matplotlib-3.10.8-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:24d50994d8c5816ddc35411e50a86ab05f575e2530c02752e02538122613371f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/57/61/78cd5920d35b29fd2a0fe894de8adf672ff52939d2e9b43cb83cd5ce1bc7/matplotlib-3.10.8-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:99eefd13c0dc3b3c1b4d561c1169e65fe47aab7b8158754d7c084088e2329466" }, - { url = "https://mirrors.aliyun.com/pypi/packages/30/4e/c10f171b6e2f44d9e3a2b96efa38b1677439d79c99357600a62cc1e9594e/matplotlib-3.10.8-cp312-cp312-win_amd64.whl", hash = "sha256:dd80ecb295460a5d9d260df63c43f4afbdd832d725a531f008dad1664f458adf" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f1/76/934db220026b5fef85f45d51a738b91dea7d70207581063cd9bd8fafcf74/matplotlib-3.10.8-cp312-cp312-win_arm64.whl", hash = "sha256:3c624e43ed56313651bc18a47f838b60d7b8032ed348911c54906b130b20071b" }, { url = "https://mirrors.aliyun.com/pypi/packages/3d/b9/15fd5541ef4f5b9a17eefd379356cf12175fe577424e7b1d80676516031a/matplotlib-3.10.8-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:3f2e409836d7f5ac2f1c013110a4d50b9f7edc26328c108915f9075d7d7a91b6" }, { url = "https://mirrors.aliyun.com/pypi/packages/8d/a0/2ba3473c1b66b9c74dc7107c67e9008cb1782edbe896d4c899d39ae9cf78/matplotlib-3.10.8-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:56271f3dac49a88d7fca5060f004d9d22b865f743a12a23b1e937a0be4818ee1" }, { url = "https://mirrors.aliyun.com/pypi/packages/75/97/a471f1c3eb1fd6f6c24a31a5858f443891d5127e63a7788678d14e249aea/matplotlib-3.10.8-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:a0a7f52498f72f13d4a25ea70f35f4cb60642b466cbb0a9be951b5bc3f45a486" }, @@ -4134,21 +3918,6 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/9b/f7/4a5e785ec9fbd65146a27b6b70b6cdc161a66f2024e4b04ac06a67f5578b/mistune-3.2.0-py3-none-any.whl", hash = "sha256:febdc629a3c78616b94393c6580551e0e34cc289987ec6c35ed3f4be42d0eee1" }, ] -[[package]] -name = "ml-dtypes" -version = "0.4.1" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -dependencies = [ - { name = "numpy" }, -] -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/fd/15/76f86faa0902836cc133939732f7611ace68cf54148487a99c539c272dc8/ml_dtypes-0.4.1.tar.gz", hash = "sha256:fad5f2de464fd09127e49b7fd1252b9006fb43d2edc1ff112d390c324af5ca7a" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/ba/1a/99e924f12e4b62139fbac87419698c65f956d58de0dbfa7c028fa5b096aa/ml_dtypes-0.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:827d3ca2097085cf0355f8fdf092b888890bb1b1455f52801a2d7756f056f54b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8f/8c/7b610bd500617854c8cc6ed7c8cfb9d48d6a5c21a1437a36a4b9bc8a3598/ml_dtypes-0.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:772426b08a6172a891274d581ce58ea2789cc8abc1c002a27223f314aaf894e7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c7/c6/f89620cecc0581dc1839e218c4315171312e46c62a62da6ace204bda91c0/ml_dtypes-0.4.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:126e7d679b8676d1a958f2651949fbfa182832c3cd08020d8facd94e4114f3e9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ae/11/a742d3c31b2cc8557a48efdde53427fd5f9caa2fa3c9c27d826e78a66f51/ml_dtypes-0.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:df0fb650d5c582a9e72bb5bd96cfebb2cdb889d89daff621c8fbc60295eba66c" }, -] - [[package]] name = "moodlepy" version = "0.24.1" @@ -4204,14 +3973,6 @@ version = "0.20.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/ea/9c/bfbd12955a49180cbd234c5d29ec6f74fe641698f0cd9df154a854fc8a15/msgspec-0.20.0.tar.gz", hash = "sha256:692349e588fde322875f8d3025ac01689fead5901e7fb18d6870a44519d62a29" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/d9/6f/1e25eee957e58e3afb2a44b94fa95e06cebc4c236193ed0de3012fff1e19/msgspec-0.20.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2aba22e2e302e9231e85edc24f27ba1f524d43c223ef5765bd8624c7df9ec0a5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/7f/ee/af51d090ada641d4b264992a486435ba3ef5b5634bc27e6eb002f71cef7d/msgspec-0.20.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:716284f898ab2547fedd72a93bb940375de9fbfe77538f05779632dc34afdfde" }, - { url = "https://mirrors.aliyun.com/pypi/packages/49/d6/9709ee093b7742362c2934bfb1bbe791a1e09bed3ea5d8a18ce552fbfd73/msgspec-0.20.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:558ed73315efa51b1538fa8f1d3b22c8c5ff6d9a2a62eff87d25829b94fc5054" }, - { url = "https://mirrors.aliyun.com/pypi/packages/5c/a2/488517a43ccf5a4b6b6eca6dd4ede0bd82b043d1539dd6bb908a19f8efd3/msgspec-0.20.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:509ac1362a1d53aa66798c9b9fd76872d7faa30fcf89b2fba3bcbfd559d56eb0" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d5/e8/49b832808aa23b85d4f090d1d2e48a4e3834871415031ed7c5fe48723156/msgspec-0.20.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1353c2c93423602e7dea1aa4c92f3391fdfc25ff40e0bacf81d34dbc68adb870" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9f/56/1dc2fa53685dca9c3f243a6cbecd34e856858354e455b77f47ebd76cf5bf/msgspec-0.20.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:cb33b5eb5adb3c33d749684471c6a165468395d7aa02d8867c15103b81e1da3e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/5a/51/aba940212c23b32eedce752896205912c2668472ed5b205fc33da28a6509/msgspec-0.20.0-cp312-cp312-win_amd64.whl", hash = "sha256:fb1d934e435dd3a2b8cf4bbf47a8757100b4a1cfdc2afdf227541199885cdacb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/41/ad/3b9f259d94f183daa9764fef33fdc7010f7ecffc29af977044fa47440a83/msgspec-0.20.0-cp312-cp312-win_arm64.whl", hash = "sha256:00648b1e19cf01b2be45444ba9dc961bd4c056ffb15706651e64e5d6ec6197b7" }, { url = "https://mirrors.aliyun.com/pypi/packages/8a/d1/b902d38b6e5ba3bdddbec469bba388d647f960aeed7b5b3623a8debe8a76/msgspec-0.20.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:9c1ff8db03be7598b50dd4b4a478d6fe93faae3bd54f4f17aa004d0e46c14c46" }, { url = "https://mirrors.aliyun.com/pypi/packages/57/b6/eff0305961a1d9447ec2b02f8c73c8946f22564d302a504185b730c9a761/msgspec-0.20.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f6532369ece217fd37c5ebcfd7e981f2615628c21121b7b2df9d3adcf2fd69b8" }, { url = "https://mirrors.aliyun.com/pypi/packages/99/93/f2ec1ae1de51d3fdee998a1ede6b2c089453a2ee82b5c1b361ed9095064a/msgspec-0.20.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:f9a1697da2f85a751ac3cc6a97fceb8e937fc670947183fb2268edaf4016d1ee" }, @@ -4257,24 +4018,6 @@ version = "6.7.1" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/1a/c2/c2d94cbe6ac1753f3fc980da97b3d930efe1da3af3c9f5125354436c073d/multidict-6.7.1.tar.gz", hash = "sha256:ec6652a1bee61c53a3e5776b6049172c53b6aaba34f18c9ad04f82712bac623d" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/8d/9c/f20e0e2cf80e4b2e4b1c365bf5fe104ee633c751a724246262db8f1a0b13/multidict-6.7.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:a90f75c956e32891a4eda3639ce6dd86e87105271f43d43442a3aedf3cddf172" }, - { url = "https://mirrors.aliyun.com/pypi/packages/fe/cf/18ef143a81610136d3da8193da9d80bfe1cb548a1e2d1c775f26b23d024a/multidict-6.7.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:3fccb473e87eaa1382689053e4a4618e7ba7b9b9b8d6adf2027ee474597128cd" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a9/65/1caac9d4cd32e8433908683446eebc953e82d22b03d10d41a5f0fefe991b/multidict-6.7.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0fa96985700739c4c7853a43c0b3e169360d6855780021bfc6d0f1ce7c123e7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cf/3b/d6bd75dc4f3ff7c73766e04e705b00ed6dbbaccf670d9e05a12b006f5a21/multidict-6.7.1-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:cb2a55f408c3043e42b40cc8eecd575afa27b7e0b956dfb190de0f8499a57a53" }, - { url = "https://mirrors.aliyun.com/pypi/packages/fd/80/c959c5933adedb9ac15152e4067c702a808ea183a8b64cf8f31af8ad3155/multidict-6.7.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:eb0ce7b2a32d09892b3dd6cc44877a0d02a33241fafca5f25c8b6b62374f8b75" }, - { url = "https://mirrors.aliyun.com/pypi/packages/86/85/7ed40adafea3d4f1c8b916e3b5cc3a8e07dfcdcb9cd72800f4ed3ca1b387/multidict-6.7.1-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:c3a32d23520ee37bf327d1e1a656fec76a2edd5c038bf43eddfa0572ec49c60b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d2/57/b8565ff533e48595503c785f8361ff9a4fde4d67de25c207cd0ba3befd03/multidict-6.7.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:9c90fed18bffc0189ba814749fdcc102b536e83a9f738a9003e569acd540a733" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e0/50/9810c5c29350f7258180dfdcb2e52783a0632862eb334c4896ac717cebcb/multidict-6.7.1-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:da62917e6076f512daccfbbde27f46fed1c98fee202f0559adec8ee0de67f71a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f3/8d/5e5be3ced1d12966fefb5c4ea3b2a5b480afcea36406559442c6e31d4a48/multidict-6.7.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfde23ef6ed9db7eaee6c37dcec08524cb43903c60b285b172b6c094711b3961" }, - { url = "https://mirrors.aliyun.com/pypi/packages/31/6e/d8a26d81ac166a5592782d208dd90dfdc0a7a218adaa52b45a672b46c122/multidict-6.7.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:3758692429e4e32f1ba0df23219cd0b4fc0a52f476726fff9337d1a57676a582" }, - { url = "https://mirrors.aliyun.com/pypi/packages/59/4c/7c672c8aad41534ba619bcd4ade7a0dc87ed6b8b5c06149b85d3dd03f0cd/multidict-6.7.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:398c1478926eca669f2fd6a5856b6de9c0acf23a2cb59a14c0ba5844fa38077e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/7b/bd/84c24de512cbafbdbc39439f74e967f19570ce7924e3007174a29c348916/multidict-6.7.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c102791b1c4f3ab36ce4101154549105a53dc828f016356b3e3bcae2e3a039d3" }, - { url = "https://mirrors.aliyun.com/pypi/packages/fa/ba/f5449385510825b73d01c2d4087bf6d2fccc20a2d42ac34df93191d3dd03/multidict-6.7.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:a088b62bd733e2ad12c50dad01b7d0166c30287c166e137433d3b410add807a6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d7/11/afc7c677f68f75c84a69fe37184f0f82fce13ce4b92f49f3db280b7e92b3/multidict-6.7.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:3d51ff4785d58d3f6c91bdbffcb5e1f7ddfda557727043aa20d20ec4f65e324a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2b/17/ebb9644da78c4ab36403739e0e6e0e30ebb135b9caf3440825001a0bddcb/multidict-6.7.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fc5907494fccf3e7d3f94f95c91d6336b092b5fc83811720fae5e2765890dfba" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ca/a4/840f5b97339e27846c46307f2530a2805d9d537d8b8bd416af031cad7fa0/multidict-6.7.1-cp312-cp312-win32.whl", hash = "sha256:28ca5ce2fd9716631133d0e9a9b9a745ad7f60bac2bccafb56aa380fc0b6c511" }, - { url = "https://mirrors.aliyun.com/pypi/packages/80/31/0b2517913687895f5904325c2069d6a3b78f66cc641a86a2baf75a05dcbb/multidict-6.7.1-cp312-cp312-win_amd64.whl", hash = "sha256:fcee94dfbd638784645b066074b338bc9cc155d4b4bffa4adce1615c5a426c19" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0c/5b/aba28e4ee4006ae4c7df8d327d31025d760ffa992ea23812a601d226e682/multidict-6.7.1-cp312-cp312-win_arm64.whl", hash = "sha256:ba0a9fb644d0c1a2194cf7ffb043bd852cea63a57f66fbd33959f7dae18517bf" }, { url = "https://mirrors.aliyun.com/pypi/packages/f2/22/929c141d6c0dba87d3e1d38fbdf1ba8baba86b7776469f2bc2d3227a1e67/multidict-6.7.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:2b41f5fed0ed563624f1c17630cb9941cf2309d4df00e494b551b5f3e3d67a23" }, { url = "https://mirrors.aliyun.com/pypi/packages/c7/75/bc704ae15fee974f8fccd871305e254754167dce5f9e42d88a2def741a1d/multidict-6.7.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:84e61e3af5463c19b67ced91f6c634effb89ef8bfc5ca0267f954451ed4bb6a2" }, { url = "https://mirrors.aliyun.com/pypi/packages/79/76/55cd7186f498ed080a18440c9013011eb548f77ae1b297206d030eb1180a/multidict-6.7.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:935434b9853c7c112eee7ac891bc4cb86455aa631269ae35442cb316790c1445" }, @@ -4373,6 +4116,46 @@ version = "0.0.12" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/17/0d/74f0293dfd7dcc3837746d0138cbedd60b31701ecc75caec7d3f281feba0/multitasking-0.0.12.tar.gz", hash = "sha256:2fba2fa8ed8c4b85e227c5dd7dc41c7d658de3b6f247927316175a57349b84d1" } +[[package]] +name = "murmurhash" +version = "1.0.15" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/23/2e/88c147931ea9725d634840d538622e94122bceaf346233349b7b5c62964b/murmurhash-1.0.15.tar.gz", hash = "sha256:58e2b27b7847f9e2a6edf10b47a8c8dd70a4705f45dccb7bf76aeadacf56ba01" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/29/2f/ba300b5f04dae0409202d6285668b8a9d3ade43a846abee3ef611cb388d5/murmurhash-1.0.15-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:fe50dc70e52786759358fd1471e309b94dddfffb9320d9dfea233c7684c894ba" }, + { url = "https://mirrors.aliyun.com/pypi/packages/34/02/29c19d268e6f4ea1ed2a462c901eed1ed35b454e2cbc57da592fad663ac6/murmurhash-1.0.15-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:1349a7c23f6092e7998ddc5bd28546cc31a595afc61e9fdb3afc423feec3d7ad" }, + { url = "https://mirrors.aliyun.com/pypi/packages/e2/63/58e2de2b5232cd294c64092688c422196e74f9fa8b3958bdf02d33df24b9/murmurhash-1.0.15-cp313-cp313-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:b3ba6d05de2613535b5a9227d4ad8ef40a540465f64660d4a8800634ae10e04f" }, + { url = "https://mirrors.aliyun.com/pypi/packages/aa/9a/d13e2e9f8ba1ced06840921a50f7cece0a475453284158a3018b72679761/murmurhash-1.0.15-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fa1b70b3cc2801ab44179c65827bbd12009c68b34e9d9ce7125b6a0bd35af63c" }, + { url = "https://mirrors.aliyun.com/pypi/packages/b2/e1/47994f1813fa205c84977b0ff51ae6709f8539af052c7491a5f863d82bdc/murmurhash-1.0.15-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:213d710fb6f4ef3bc11abbfad0fa94a75ffb675b7dc158c123471e5de869f9af" }, + { url = "https://mirrors.aliyun.com/pypi/packages/b9/ea/90c1fd00b4aeb704fb5e84cd666b33ffd7f245155048071ffbb51d2bb57d/murmurhash-1.0.15-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:b65a5c4e7f5d71f7ccac2d2b60bdf7092d7976270878cfec59d5a66a533db823" }, + { url = "https://mirrors.aliyun.com/pypi/packages/00/db/da73462dbfa77f6433b128d2120ba7ba300f8c06dc4f4e022c38d240a5f5/murmurhash-1.0.15-cp313-cp313-win_amd64.whl", hash = "sha256:9aba94c5d841e1904cd110e94ceb7f49cfb60a874bbfb27e0373622998fb7c7c" }, + { url = "https://mirrors.aliyun.com/pypi/packages/bb/83/032729ef14971b938fbef41ee125fc8800020ee229bd35178b6ede8ee934/murmurhash-1.0.15-cp313-cp313-win_arm64.whl", hash = "sha256:263807eca40d08c7b702413e45cca75ecb5883aa337237dc5addb660f1483378" }, + { url = "https://mirrors.aliyun.com/pypi/packages/10/83/7547d9205e9bd2f8e5dfd0b682cc9277594f98909f228eb359489baec1df/murmurhash-1.0.15-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:694fd42a74b7ce257169d14c24aa616aa6cd4ccf8abe50eca0557e08da99d055" }, + { url = "https://mirrors.aliyun.com/pypi/packages/b7/c7/3afd5de7a5b3ae07fe2d3a3271b327ee1489c58ba2b2f2159bd31a25edb9/murmurhash-1.0.15-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a2ea4546ba426390beff3cd10db8f0152fdc9072c4f2583ec7d8aa9f3e4ac070" }, + { url = "https://mirrors.aliyun.com/pypi/packages/02/69/d6637ee67d78ebb2538c00411f28ea5c154886bbe1db16c49435a8a4ab16/murmurhash-1.0.15-cp313-cp313t-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:34e5a91139c40b10f98d0b297907f5d5267b4b1b2e5dd2eb74a021824f751b98" }, + { url = "https://mirrors.aliyun.com/pypi/packages/ab/4c/89e590165b4c7da6bf941441212a721a270195332d3aacfdfdf527d466ca/murmurhash-1.0.15-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:dc35606868a5961cf42e79314ca0bddf5a400ce377b14d83192057928d6252ec" }, + { url = "https://mirrors.aliyun.com/pypi/packages/07/7a/95c42df0c21d2e413b9fcd17317a7587351daeb264dc29c6aec1fdbd26f8/murmurhash-1.0.15-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:43cc6ac3b91ca0f7a5ae9c063ba4d6c26972c97fd7c25280ecc666413e4c5535" }, + { url = "https://mirrors.aliyun.com/pypi/packages/d0/22/9d02c880a88b83bb3ce7d6a38fb727373ab78d82e5f3d8d9fc5612219f90/murmurhash-1.0.15-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:847d712136cb462f0e4bd6229ee2d9eb996d8854eb8312dff3d20c8f5181fda5" }, + { url = "https://mirrors.aliyun.com/pypi/packages/9a/e3/750232524e0dc262e8dcede6536dafc766faadd9a52f1d23746b02948ad8/murmurhash-1.0.15-cp313-cp313t-win_amd64.whl", hash = "sha256:2680851af6901dbe66cc4aa7ef8e263de47e6e1b425ae324caa571bdf18f8d58" }, + { url = "https://mirrors.aliyun.com/pypi/packages/ff/89/4ad9d215ef6ade89f27a72dc4e86b98ef1a43534cc3e6a6900a362a0bf0a/murmurhash-1.0.15-cp313-cp313t-win_arm64.whl", hash = "sha256:189a8de4d657b5da9efd66601b0636330b08262b3a55431f2379097c986995d0" }, + { url = "https://mirrors.aliyun.com/pypi/packages/1c/69/726df275edf07688146966e15eaaa23168100b933a2e1a29b37eb56c6db8/murmurhash-1.0.15-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:7c4280136b738e85ff76b4bdc4341d0b867ee753e73fd8b6994288080c040d0b" }, + { url = "https://mirrors.aliyun.com/pypi/packages/59/8f/24ecf9061bc2b20933df8aba47c73e904274ea8811c8300cab92f6f82372/murmurhash-1.0.15-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:d4d681f474830489e2ec1d912095cfff027fbaf2baa5414c7e9d25b89f0fab68" }, + { url = "https://mirrors.aliyun.com/pypi/packages/ba/26/fff3caba25aa3c0622114e03c69fb66c839b22335b04d7cce91a3a126d44/murmurhash-1.0.15-cp314-cp314-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:d7e47c5746785db6a43b65fac47b9e63dd71dfbd89a8c92693425b9715e68c6e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/df/e4/0f2b9fc533467a27afb4e906c33f32d5f637477de87dd94690e0c44335a6/murmurhash-1.0.15-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:e8e674f02a99828c8a671ba99cd03299381b2f0744e6f25c29cadfc6151dc724" }, + { url = "https://mirrors.aliyun.com/pypi/packages/da/bf/9d1c107989728ec46e25773d503aa54070b32822a18cfa7f9d5f41bc17a5/murmurhash-1.0.15-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:26fd7c7855ac4850ad8737991d7b0e3e501df93ebaf0cf45aa5954303085fdba" }, + { url = "https://mirrors.aliyun.com/pypi/packages/0d/81/dcf27c71445c0e993b10e33169a098ca60ee702c5c58fcbde205fa6332a6/murmurhash-1.0.15-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:cb8ebafae60d5f892acff533cc599a359954d8c016a829514cb3f6e9ee10f322" }, + { url = "https://mirrors.aliyun.com/pypi/packages/bc/32/e874a14b2d2246bd2d16f80f49fad393a3865d4ee7d66d2cae939a67a29a/murmurhash-1.0.15-cp314-cp314-win_amd64.whl", hash = "sha256:898a629bf111f1aeba4437e533b5b836c0a9d2dd12d6880a9c75f6ca13e30e22" }, + { url = "https://mirrors.aliyun.com/pypi/packages/af/8e/4fca051ed8ae4d23a15aaf0a82b18cb368e8cf84f1e3b474d5749ec46069/murmurhash-1.0.15-cp314-cp314-win_arm64.whl", hash = "sha256:88dc1dd53b7b37c0df1b8b6bce190c12763014492f0269ff7620dc6027f470f4" }, + { url = "https://mirrors.aliyun.com/pypi/packages/38/9c/c72c2a4edd86aac829337ab9f83cf04cdb15e5d503e4c9a3a243f30a261c/murmurhash-1.0.15-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:6cb4e962ec4f928b30c271b2d84e6707eff6d942552765b663743cfa618b294b" }, + { url = "https://mirrors.aliyun.com/pypi/packages/ac/d7/72b47ebc86436cd0aa1fd4c6e8779521ec389397ac11389990278d0f7a47/murmurhash-1.0.15-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5678a3ea4fbf0cbaaca2bed9b445f556f294d5f799c67185d05ffcb221a77faf" }, + { url = "https://mirrors.aliyun.com/pypi/packages/64/bb/6d2f09135079c34dc2d26e961c52742d558b320c61503f273eab6ba743d9/murmurhash-1.0.15-cp314-cp314t-manylinux1_x86_64.manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:ef19f38c6b858eef83caf710773db98c8f7eb2193b4c324650c74f3d8ba299e0" }, + { url = "https://mirrors.aliyun.com/pypi/packages/b9/e2/9c1b462e33f9cb2d632056f07c90b502fc20bd7da50a15d0557343bd2fed/murmurhash-1.0.15-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22aa3ceaedd2e57078b491ed08852d512b84ff4ff9bb2ff3f9bf0eec7f214c9e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/e8/73/8694db1408fcdfa73589f7df6c445437ea146986fa1e393ec60d26d6e30c/murmurhash-1.0.15-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:bba0e0262c0d08682b028cb963ac477bd9839029486fa1333fc5c01fb6072749" }, + { url = "https://mirrors.aliyun.com/pypi/packages/2d/f9/8e360bdfc3c44e267e7e046f0e0b9922766da92da26959a6963f597e6bb5/murmurhash-1.0.15-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:4fd8189ee293a09f30f4931408f40c28ccd42d9de4f66595f8814879339378bc" }, + { url = "https://mirrors.aliyun.com/pypi/packages/f9/31/97649680595b1096803d877ababb9a67c07f4378f177ec885eea28b9db6d/murmurhash-1.0.15-cp314-cp314t-win_amd64.whl", hash = "sha256:66395b1388f7daa5103db92debe06842ae3be4c0749ef6db68b444518666cdcc" }, + { url = "https://mirrors.aliyun.com/pypi/packages/76/66/4fce8755f25d77324401886c00017c556be7ca3039575b94037aff905385/murmurhash-1.0.15-cp314-cp314t-win_arm64.whl", hash = "sha256:c22e56c6a0b70598a66e456de5272f76088bc623688da84ef403148a6d41851d" }, +] + [[package]] name = "mygene" version = "3.2.2" @@ -4400,11 +4183,6 @@ version = "9.6.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/6f/6e/c89babc7de3df01467d159854414659c885152579903a8220c8db02a3835/mysql_connector_python-9.6.0.tar.gz", hash = "sha256:c453bb55347174d87504b534246fb10c589daf5d057515bf615627198a3c7ef1" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/8f/d9/2a4b4d90b52f4241f0f71618cd4bd8779dd6d18db8058b0a4dd83ec0541c/mysql_connector_python-9.6.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:9664e217c72dd6fb700f4c8512af90261f72d2f5d7c00c4e13e4c1e09bfa3d5e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/33/91/2495835733a054e716a17dc28404748b33f2dc1da1ae4396fb45574adf40/mysql_connector_python-9.6.0-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:1ed4b5c4761e5333035293e746683890e4ef2e818e515d14023fd80293bc31fa" }, - { url = "https://mirrors.aliyun.com/pypi/packages/7a/69/e83abbbbf7f8eed855b5a5ff7285bc0afb1199418ac036c7691edf41e154/mysql_connector_python-9.6.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:5095758dcb89a6bce2379f349da336c268c407129002b595c5dba82ce387e2a5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/82/44/67bb61c71f398fbc739d07e8dcadad94e2f655874cb32ae851454066bea0/mysql_connector_python-9.6.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:4ae4e7780fad950a4f267dea5851048d160f5b71314a342cdbf30b154f1c74f7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ba/39/994c4f7e9c59d3ca534a831d18442ac4c529865db20aeaa4fd94e2af5efd/mysql_connector_python-9.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:c180e0b4100d7402e03993bfac5c97d18e01d7ca9d198d742fffc245077f8ffe" }, { url = "https://mirrors.aliyun.com/pypi/packages/2f/58/9521aa678708ec6cebfd40524c14c3d151e4f29e3774e6086aa0a30d203b/mysql_connector_python-9.6.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e86e45a7b540ca09af8a18ecfa761e0cdeccfdb62818331614ec030ae44bfd26" }, { url = "https://mirrors.aliyun.com/pypi/packages/39/8d/b108f9bcce9780f6a1f91decb2af54defdaf845e237ddc42f2b4578f1cd7/mysql_connector_python-9.6.0-cp313-cp313-macosx_14_0_x86_64.whl", hash = "sha256:8d3e9252384e1b7f95b07020664f2673d9c29c5e95eeda2e048b3331e190b9d4" }, { url = "https://mirrors.aliyun.com/pypi/packages/d6/28/735cd93d16e76dc2feb4abb3f1229a1d9475af34d80c26712fec6abe1d70/mysql_connector_python-9.6.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:0fa18ead33cb699ea92005695077cef09aa494eebf51164ee30c891c3eaea90c" }, @@ -4418,15 +4196,6 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/15/dd/b3250826c29cee7816de4409a2fe5e469a68b9a89f6bfaa5eed74f05532c/mysql_connector_python-9.6.0-py2.py3-none-any.whl", hash = "sha256:44b0fb57207ebc6ae05b5b21b7968a9ed33b29187fe87b38951bad2a334d75d5" }, ] -[[package]] -name = "namex" -version = "0.1.0" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/0c/c0/ee95b28f029c73f8d49d8f52edaed02a1d4a9acb8b69355737fdb1faa191/namex-0.1.0.tar.gz", hash = "sha256:117f03ccd302cc48e3f5c58a296838f6b89c83455ab8683a1e85f2a430aa4306" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/b2/bc/465daf1de06409cdd4532082806770ee0d8d7df434da79c76564d0f69741/namex-0.1.0-py3-none-any.whl", hash = "sha256:e2012a474502f1e2251267062aae3114611f07df4224b6e06334c57b0f2ce87c" }, -] - [[package]] name = "nest-asyncio" version = "1.6.0" @@ -4470,10 +4239,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/23/c9/a0fb41787d01d621046138da30f6c2100d80857bf34b3390dd68040f27a3/numba-0.64.0.tar.gz", hash = "sha256:95e7300af648baa3308127b1955b52ce6d11889d16e8cfe637b4f85d2fca52b1" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/70/a6/9fc52cb4f0d5e6d8b5f4d81615bc01012e3cf24e1052a60f17a68deb8092/numba-0.64.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:69440a8e8bc1a81028446f06b363e28635aa67bd51b1e498023f03b812e0ce68" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9b/89/1a74ea99b180b7a5587b0301ed1b183a2937c4b4b67f7994689b5d36fc34/numba-0.64.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f13721011f693ba558b8dd4e4db7f2640462bba1b855bdc804be45bbeb55031a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/91/e1/583c647404b15f807410510fec1eb9b80cb8474165940b7749f026f21cbc/numba-0.64.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0b180b1133f2b5d8b3f09d96b6d7a9e51a7da5dda3c09e998b5bcfac85d222c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/85/23/0fce5789b8a5035e7ace21216a468143f3144e02013252116616c58339aa/numba-0.64.0-cp312-cp312-win_amd64.whl", hash = "sha256:e63dc94023b47894849b8b106db28ccb98b49d5498b98878fac1a38f83ac007a" }, { url = "https://mirrors.aliyun.com/pypi/packages/52/80/2734de90f9300a6e2503b35ee50d9599926b90cbb7ac54f9e40074cd07f1/numba-0.64.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:3bab2c872194dcd985f1153b70782ec0fbbe348fffef340264eacd3a76d59fd6" }, { url = "https://mirrors.aliyun.com/pypi/packages/42/e8/14b5853ebefd5b37723ef365c5318a30ce0702d39057eaa8d7d76392859d/numba-0.64.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:703a246c60832cad231d2e73c1182f25bf3cc8b699759ec8fe58a2dbc689a70c" }, { url = "https://mirrors.aliyun.com/pypi/packages/8a/a2/f60dc6c96d19b7185144265a5fbf01c14993d37ff4cd324b09d0212aa7ce/numba-0.64.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7e2e49a7900ee971d32af7609adc0cfe6aa7477c6f6cccdf6d8138538cf7756f" }, @@ -4489,16 +4254,6 @@ name = "numpy" version = "1.26.4" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/65/6e/09db70a523a96d25e115e71cc56a6f9031e7b8cd166c1ac8438307c14058/numpy-1.26.4.tar.gz", hash = "sha256:2a02aba9ed12e4ac4eb3ea9421c420301a0c6460d9830d74a9df87efa4912010" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/95/12/8f2020a8e8b8383ac0177dc9570aad031a3beb12e38847f7129bacd96228/numpy-1.26.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:b3ce300f3644fb06443ee2222c2201dd3a89ea6040541412b8fa189341847218" }, - { url = "https://mirrors.aliyun.com/pypi/packages/75/5b/ca6c8bd14007e5ca171c7c03102d17b4f4e0ceb53957e8c44343a9546dcc/numpy-1.26.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:03a8c78d01d9781b28a6989f6fa1bb2c4f2d51201cf99d3dd875df6fbd96b23b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/79/f8/97f10e6755e2a7d027ca783f63044d5b1bc1ae7acb12afe6a9b4286eac17/numpy-1.26.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9fad7dcb1aac3c7f0584a5a8133e3a43eeb2fe127f47e3632d43d677c66c102b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0f/50/de23fde84e45f5c4fda2488c759b69990fd4512387a8632860f3ac9cd225/numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:675d61ffbfa78604709862923189bad94014bef562cc35cf61d3a07bba02a7ed" }, - { url = "https://mirrors.aliyun.com/pypi/packages/4c/0c/9c603826b6465e82591e05ca230dfc13376da512b25ccd0894709b054ed0/numpy-1.26.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:ab47dbe5cc8210f55aa58e4805fe224dac469cde56b9f731a4c098b91917159a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/76/8c/2ba3902e1a0fc1c74962ea9bb33a534bb05984ad7ff9515bf8d07527cadd/numpy-1.26.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:1dda2e7b4ec9dd512f84935c5f126c8bd8b9f2fc001e9f54af255e8c5f16b0e0" }, - { url = "https://mirrors.aliyun.com/pypi/packages/28/4a/46d9e65106879492374999e76eb85f87b15328e06bd1550668f79f7b18c6/numpy-1.26.4-cp312-cp312-win32.whl", hash = "sha256:50193e430acfc1346175fcbdaa28ffec49947a06918b7b92130744e81e640110" }, - { url = "https://mirrors.aliyun.com/pypi/packages/16/2e/86f24451c2d530c88daf997cb8d6ac622c1d40d19f5a031ed68a4b73a374/numpy-1.26.4-cp312-cp312-win_amd64.whl", hash = "sha256:08beddf13648eb95f8d867350f6a018a4be2e5ad54c8d8caed89ebca558b2818" }, -] [[package]] name = "oauthlib" @@ -4576,11 +4331,6 @@ dependencies = [ { name = "sympy" }, ] wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/1b/9e/f748cd64161213adeef83d0cb16cb8ace1e62fa501033acdd9f9341fff57/onnxruntime-1.23.2-cp312-cp312-macosx_13_0_arm64.whl", hash = "sha256:b8f029a6b98d3cf5be564d52802bb50a8489ab73409fa9db0bf583eabb7c2321" }, - { url = "https://mirrors.aliyun.com/pypi/packages/91/9d/a81aafd899b900101988ead7fb14974c8a58695338ab6a0f3d6b0100f30b/onnxruntime-1.23.2-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:218295a8acae83905f6f1aed8cacb8e3eb3bd7513a13fe4ba3b2664a19fc4a6b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/3c/35/4e40f2fba272a6698d62be2cd21ddc3675edfc1a4b9ddefcc4648f115315/onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:76ff670550dc23e58ea9bc53b5149b99a44e63b34b524f7b8547469aaa0dcb8c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ef/88/9cc25d2bafe6bc0d4d3c1db3ade98196d5b355c0b273e6a5dc09c5d5d0d5/onnxruntime-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:0f9b4ae77f8e3c9bee50c27bc1beede83f786fe1d52e99ac85aa8d65a01e9b77" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c0/b4/569d298f9fc4d286c11c45e85d9ffa9e877af12ace98af8cab52396e8f46/onnxruntime-1.23.2-cp312-cp312-win_amd64.whl", hash = "sha256:25de5214923ce941a3523739d34a520aac30f21e631de53bba9174dc9c004435" }, { url = "https://mirrors.aliyun.com/pypi/packages/3d/41/fba0cabccecefe4a1b5fc8020c44febb334637f133acefc7ec492029dd2c/onnxruntime-1.23.2-cp313-cp313-macosx_13_0_arm64.whl", hash = "sha256:2ff531ad8496281b4297f32b83b01cdd719617e2351ffe0dba5684fb283afa1f" }, { url = "https://mirrors.aliyun.com/pypi/packages/fe/f9/2d49ca491c6a986acce9f1d1d5fc2099108958cc1710c28e89a032c9cfe9/onnxruntime-1.23.2-cp313-cp313-macosx_13_0_x86_64.whl", hash = "sha256:162f4ca894ec3de1a6fd53589e511e06ecdc3ff646849b62a9da7489dee9ce95" }, { url = "https://mirrors.aliyun.com/pypi/packages/1c/a1/428ee29c6eaf09a6f6be56f836213f104618fb35ac6cc586ff0f477263eb/onnxruntime-1.23.2-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:45d127d6e1e9b99d1ebeae9bcd8f98617a812f53f46699eafeb976275744826b" }, @@ -4603,8 +4353,6 @@ dependencies = [ { name = "sympy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, ] wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/6c/d9/b7140a4f1615195938c7e358c0804bb84271f0d6886b5cbf105c6cb58aae/onnxruntime_gpu-1.23.2-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f2d1f720685d729b5258ec1b36dee1de381b8898189908c98cbeecdb2f2b5c2" }, - { url = "https://mirrors.aliyun.com/pypi/packages/87/da/2685c79e5ea587beddebe083601fead0bdf3620bc2f92d18756e7de8a636/onnxruntime_gpu-1.23.2-cp312-cp312-win_amd64.whl", hash = "sha256:fe925a84b00e291e0ad3fac29bfd8f8e06112abc760cdc82cb711b4f3935bd95" }, { url = "https://mirrors.aliyun.com/pypi/packages/03/05/40d561636e4114b54aa06d2371bfbca2d03e12cfdf5d4b85814802f18a75/onnxruntime_gpu-1.23.2-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:1e8f75af5da07329d0c3a5006087f4051d8abd133b4be7c9bae8cdab7bea4c26" }, { url = "https://mirrors.aliyun.com/pypi/packages/b6/3b/418300438063d403384c79eaef1cb13c97627042f2247b35a887276a355a/onnxruntime_gpu-1.23.2-cp313-cp313-win_amd64.whl", hash = "sha256:7f1b3f49e5e126b99e23ec86b4203db41c2a911f6165f7624f2bc8267aaca767" }, { url = "https://mirrors.aliyun.com/pypi/packages/b8/dc/80b145e3134d7eba31309b3299a2836e37c76e4c419a261ad9796f8f8d65/onnxruntime_gpu-1.23.2-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:20959cd4ae358aab6579ab9123284a7b1498f7d51ec291d429a5edc26511306f" }, @@ -4797,108 +4545,12 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/b2/37/cc6a55e448deaa9b27377d087da8615a3416d8ad523d5960b78dbeadd02a/opentelemetry_semantic_conventions-0.61b0-py3-none-any.whl", hash = "sha256:fa530a96be229795f8cef353739b618148b0fe2b4b3f005e60e262926c4d38e2" }, ] -[[package]] -name = "opt-einsum" -version = "3.4.0" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/8c/b9/2ac072041e899a52f20cf9510850ff58295003aa75525e58343591b0cbfb/opt_einsum-3.4.0.tar.gz", hash = "sha256:96ca72f1b886d148241348783498194c577fa30a8faac108586b14f1ba4473ac" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/23/cd/066e86230ae37ed0be70aae89aabf03ca8d9f39c8aea0dec8029455b5540/opt_einsum-3.4.0-py3-none-any.whl", hash = "sha256:69bb92469f86a1565195ece4ac0323943e83477171b91d24c35afe028a90d7cd" }, -] - -[[package]] -name = "optree" -version = "0.19.0" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -dependencies = [ - { name = "typing-extensions" }, -] -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/3d/63/7b078bc36d5a206c21b03565a818ede38ff0fbf014e92085ec467ef10adb/optree-0.19.0.tar.gz", hash = "sha256:bc1991a948590756409e76be4e29efd4a487a185056d35db6c67619c19ea27a1" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/2d/bf/5cbbf61a27f94797c3d9786f6230223023a943b60f5e893d52368f10b8b1/optree-0.19.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7ec4b2ce49622c6be2c8634712b6c63cc274835bac89a56e3ab2ca863a32ff4b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/00/9e/65899e6470f5df289ccdbe9e228fb0cd0ae45ccda8e32c92d6efae1530ef/optree-0.19.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f0978603623b4b1f794f05f6bbed0645cb7e219f4a5a349b2a2bd4514d84ac82" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d1/dc/f4826835be660181f1b4444ac92b51dda96d4634d3c2271e14598da7bf2a/optree-0.19.0-cp312-cp312-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8c9e52c50ed3f3f8b1cf4e47a20a7c5e77175b4f84b2ecf390a76f0d1dd91da6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ce/b0/89283ac1dd1ead3aa3d7a6b45a26846f457bded79a83b6828fc1ed9a6db3/optree-0.19.0-cp312-cp312-manylinux_2_26_i686.manylinux_2_28_i686.whl", hash = "sha256:3fe3e5f7a30a7d08ddba0a34e48f5483f6c4d7bb710375434ad3633170c73c48" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2a/a2/47f620f87b0544b2e0eb0b3c661682bd0ea1c79f6e38f9147bc0f835c973/optree-0.19.0-cp312-cp312-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:8315527e1f14a91173fe6871847da7b949048ec61ff8b3e507fc286e75b0aa3c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/84/e9/b9ae18404135de53809fb994b754ac0eac838d8c4dfa8a10a811d8dec91d/optree-0.19.0-cp312-cp312-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:938fb15d140ab65148f4e6975048facbef83a9210353fbedd471ac39e7544339" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0a/e5/a77df15a62b37bb14c81b5757e2a0573f57e7c06d125a410ad2cd7cefb72/optree-0.19.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:2b8209570340135a7e586c90f393f3c6359e8a49c40d783196721cc487e51d9c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8c/43/1aa431cee19cd98c4229e468767021f9a92195d9431857e28198a3a3ce2f/optree-0.19.0-cp312-cp312-manylinux_2_39_riscv64.whl", hash = "sha256:1397dc925026917531a43fda32054ae1e77e5ed9bf8284bcae6354c19c26e14a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/5b/b9/b94fd3a116b80951d692a82f4135ae84b3d78bd1b092250aff76a3366138/optree-0.19.0-cp312-cp312-win32.whl", hash = "sha256:68f58e8f8b75c76c51e61e3dc2d9e94609bafb0e1a6459e6d525ced905cd9a74" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9e/7f/31fa1b2311038bfc355ad6e4e4e63d028719cb67fb3ebe6fb76ff2124105/optree-0.19.0-cp312-cp312-win_amd64.whl", hash = "sha256:5c44ca0f579ed3e0ca777a5711d4a6c1b374feacf1bb4fe9cfe85297b0c8d237" }, - { url = "https://mirrors.aliyun.com/pypi/packages/09/86/863bc3f42f83113f5c6a5beaf4fec3c3481a76872f3244d0e64fb9ebd3b0/optree-0.19.0-cp312-cp312-win_arm64.whl", hash = "sha256:0461f796b4ade3fab519d821b0fa521f07e2af70206b76aac75fcfdc2e051fca" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ee/61/d79c7eeb87e98d08bc8d95ed08dee83bedb4e55371a7d2ae3c874ec02608/optree-0.19.0-cp313-cp313-android_24_arm64_v8a.whl", hash = "sha256:1eea5b7be833c6d555d08ff68046d3dd2112dfb39e6f1eb09887ab6c617a6d64" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2d/ed/e80504f65e7e80fdcd129258428d7976ea9f03bf9dad56a5293c44d563ad/optree-0.19.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:4d9cf9dfa0ac051e0ed82869d782f0affdbdb1daa5f2e851d37ea8625c60071a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/65/e5/d1926a2f0e0240f6800ff385c8486879f7da0a5a030b7aa5d84e44e9c9ca/optree-0.19.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:43c4f8ba5755d56d046be2cb1380cbc362234ad93fd9933384c6dd7fdebe6c4a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/61/88/9c598325e89bbed29b37a381ebb2b94f1d9d769c973b879b3e9766b4b16d/optree-0.19.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:36b1134680ee3f9768ede290da653e1604a8083bce69fef8fb4e46863346d5c8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6b/d2/fcba2a1826d362a64cb36ec9f675ed6dcddee47099948913122b0aafbe44/optree-0.19.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:c9f7e7e7bf2ef011d0be1c2e87c96f5dc543dad1ac34430c2f606938c9ec5135" }, - { url = "https://mirrors.aliyun.com/pypi/packages/eb/43/5e6d51d8c203a79cff084efa9f04a745b8ef5cf4c86dbb127e7b192f14d9/optree-0.19.0-cp313-cp313-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bb5752f17afa017b08b0cbac8a383d4bb90035b353bef7a25fe03cda69a21d33" }, - { url = "https://mirrors.aliyun.com/pypi/packages/4b/dc/dc09347136876287b463b8599239d6fa338298fd322ac629817bd2f4def4/optree-0.19.0-cp313-cp313-manylinux_2_26_i686.manylinux_2_28_i686.whl", hash = "sha256:e9b6245993494b1aa54529eb7356aeefa6704c8b436e6e5f20b25c30f7af7620" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ee/cc/5d2c9cf906bd3ae357e7221450bacefd0321d7b94e6171dec39552b346e6/optree-0.19.0-cp313-cp313-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:7351a24b30568c963a92b19f543c9562b36b3222caed2a5ac3209ef910972bec" }, - { url = "https://mirrors.aliyun.com/pypi/packages/64/7f/75b10f88da994fc3da3dc1ab7d54bab7bd3a6fa5eb81b586f13f8bd6ab0e/optree-0.19.0-cp313-cp313-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:2c6610a1d1d74af0f53c9bbabb7c265679a9a07e03783c8cc4a678ba3bb6f9a5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/78/fc/753bf69b907652d54b7c6012ccb320d8c1a3161454e415331058b6f04246/optree-0.19.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:37e07a5233be64329cbf41e20ab07c50da53bdc374109a2b376be49c4a34a37f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e2/a8/70640f9998438f50a0a1c57f2a12aac856cd937f2c4c4feef5a3cfe8e9c7/optree-0.19.0-cp313-cp313-manylinux_2_39_riscv64.whl", hash = "sha256:c23a25caff6b096b62379adb99e2c401805141497ebb8131f271a4c93f5ed5dc" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ad/05/0b8bf4abf5d1a7cd9a19ba680e1ec64ad38eec3204e4e16a769e8aeaa4a2/optree-0.19.0-cp313-cp313-win32.whl", hash = "sha256:045cf112adaebc76c9c7cabde857c01babfc9fae8aa0a28d48f7c565fadf0cb9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b1/c7/9ce83f115d7f4a47741827a037067b9026c29996ad7913bc40277924c773/optree-0.19.0-cp313-cp313-win_amd64.whl", hash = "sha256:bc0c6c9f99fb90e3a20a8b94c219e6b03e585f65ab9a11c9acd1511a5f885f79" }, - { url = "https://mirrors.aliyun.com/pypi/packages/17/fd/97c27d6e51c8b958b29f5c7b4cdcae4f2e7c9ef5b5465be459811a48876b/optree-0.19.0-cp313-cp313-win_arm64.whl", hash = "sha256:48f492363fa0f9ffe5029d0ecafd2fa30ffe0d5d52c8dd414123f47b743bd42e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/46/45/9a2f05b5d033482b58ca36df6f41b0b28af3ccfa43267a82254c973dcd14/optree-0.19.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d6362b9e9a0f4dd7c5b88debe182a90541aba7f1ad02d00922d01c4df4b3c933" }, - { url = "https://mirrors.aliyun.com/pypi/packages/20/b7/5d0a013c5461e0933ce7385a06eed625358de12216c80da935138e6af205/optree-0.19.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:381096a293d385fd3135e5c707bb7e58c584bc9bd50f458237b49da21a621df3" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d6/2c/d3f2674411c8e3338e91e7446af239597ae6efd23f14e2039f29ced3d73e/optree-0.19.0-cp313-cp313t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:a9675007cc54371be544bb33fd7eb07b0773d88deacf8aa4cc72fa735c4a4d33" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e9/e9/009964734f19d6996291e77f2c1da5d35a743defc4e89aefb01260e2f9d6/optree-0.19.0-cp313-cp313t-manylinux_2_26_i686.manylinux_2_28_i686.whl", hash = "sha256:406b355d6f29f99535efa97ea16eda70414968271a894c99f48cd91848723706" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2b/4c/96706f855c6b623259e754f751020acfb3452e412f7c85330629ab4b9ecc/optree-0.19.0-cp313-cp313t-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:d05e5bf6ce30258cda643ea50cc424038e5107905e9fc11d19a04453a8d2ee27" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b5/e4/9b23a27c9bd211d22a2e55a5a66e62afe5c75ff98b81fc7d000d879e75e6/optree-0.19.0-cp313-cp313t-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:b6e11479d98690fc9efd15d65195af37608269bb1e176b5a836b066440f9c52f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/15/3b/462582f0050508f1ce0734f1dffd19078fb013fa12ccf0761c208ab6f756/optree-0.19.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8d523ffc6d3e22851ed25bec806a6c78d68340259e79941059752209b07a75ec" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d6/c6/843c6a33b700ef88407bd5840813e53c6986b6130d94c75c49ff7a2e31f9/optree-0.19.0-cp313-cp313t-manylinux_2_39_riscv64.whl", hash = "sha256:ca148527b6e5d59c25c733e66d4165fbcf85102f4ea10f096370fda533fe77d1" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e3/ed/13f938444de70bec2ff0edef8917a08160d41436a3cad976e541d21747f5/optree-0.19.0-cp313-cp313t-win32.whl", hash = "sha256:40d067cf87e76ad21b8ee2e6ba0347c517c88c2ce7190d666b30b4057e4de5ba" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e1/a2/5074dedbc1be5deca76fe57285ec3e7d5d475922572f92a90f3b3a4f21c5/optree-0.19.0-cp313-cp313t-win_amd64.whl", hash = "sha256:b133e1b9a30ec0bca3f875cfa68c2ce88c0b9e08b21f97f687bb669266411f4a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/49/3a/ea23a29f63d8eadab4e030ebc1329906d44f631076cd1da4751388649960/optree-0.19.0-cp313-cp313t-win_arm64.whl", hash = "sha256:45184b3c73e2147b26b139f34f15c2111cde54b8893b1104a00281c3f283b209" }, - { url = "https://mirrors.aliyun.com/pypi/packages/81/46/643ea3d06c24d351888edfef387e611e550b64a14758169eaeb1d285e658/optree-0.19.0-cp314-cp314-android_24_arm64_v8a.whl", hash = "sha256:adf611b95d3159209c5d1eafcb2eb669733aaf75f9b6754f92d2d8b749192579" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d7/10/8717b93d93fcc3c42a6ee0e0a1a222fe25bc749b32a9e353b039dab836ce/optree-0.19.0-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:bad7bb78baa83f950bb3c59b09d7ca93d30f6bb975a1a7ce8c5f3dfe65fc834d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a1/5e/8263600ef51ae2decb3e31776c810b8c6b5f8927697046c4434b17346d9d/optree-0.19.0-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:73f122e8acf2f1fd346e9c08f771bc1f7394359793fe632a8e1040733bdbcbec" }, - { url = "https://mirrors.aliyun.com/pypi/packages/04/3c/40774378ebf423d7f074dfd7169f0466eb9de734f0ea5fbb368eddcb1e49/optree-0.19.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:36e426e96b3e1773e879189b12c306b58ae70052efc4087e3f14545701c7ac35" }, - { url = "https://mirrors.aliyun.com/pypi/packages/08/67/2e19866a03a6e75eb62194a5b55e1e3154ca1517478c300232b0229f8c2a/optree-0.19.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:d22b947603be4768c2bd73a59652c94d63465f928b3099e9035f9c48dfc61953" }, - { url = "https://mirrors.aliyun.com/pypi/packages/45/a5/7c059f643bc34c70cc5ebe63c82ae6c33b6b746219f96757d840ea1e2dcd/optree-0.19.0-cp314-cp314-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:14cc72d0c3a3c0d0b13c66801f2adc6583a01f8499fd151caaa649aabb7f99b9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/67/1a/2c5041cf476fb4b2a27f6644934ac2d079e3e4491f609cba411b3d890291/optree-0.19.0-cp314-cp314-manylinux_2_26_i686.manylinux_2_28_i686.whl", hash = "sha256:5369ac9584ef3fbb703699be694e84dbc78b730bd6d00c48c0c5a588617a1980" }, - { url = "https://mirrors.aliyun.com/pypi/packages/40/a0/abcd7bc3218e1108d253d6783f3e610f0ac3d0e63b2720bff94eb4ed4689/optree-0.19.0-cp314-cp314-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:80b3dca5607f04316a9dcb2bb46df2f04abf4da71731bd4a53a1559c0bee6181" }, - { url = "https://mirrors.aliyun.com/pypi/packages/82/49/7983e66210c78965bc75e386c329ec34854370d337a9ebdc4c8aede3a0b3/optree-0.19.0-cp314-cp314-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:1bb36da9b95b165c7b77fd3ff0af36a30b802cd1c020da3bcdc8aa029991c4ea" }, - { url = "https://mirrors.aliyun.com/pypi/packages/fe/16/00261f20f467b9e8950a76ec1749f01359bf47f2fc3dac5e206de99835c0/optree-0.19.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fb220bb85128c8de71aeffb9c38be817569e4bca413b38d5e0de11ba6471ef4a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/18/31/5e78a451ba9a6ed4b0903b10080dc028e3c9b9c5797cce0ca73990fb5604/optree-0.19.0-cp314-cp314-manylinux_2_39_riscv64.whl", hash = "sha256:5d2b83a37f150f827b8b0bc2c486056f9b2203e7b0bee699d2ee96a36c090f3a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9c/03/1516cb4fdb753cd76e5dc595217f84df48372bdabe1a7fb740a5b2530f5c/optree-0.19.0-cp314-cp314-win32.whl", hash = "sha256:b0c23d50b7f6a7c80f642307c87eee841cf513239706f2f60bd9480304170054" }, - { url = "https://mirrors.aliyun.com/pypi/packages/7e/c3/587cc9aa8d4742cd690da79460081e7d834499e07e8b2bd2ccc4c66928df/optree-0.19.0-cp314-cp314-win_amd64.whl", hash = "sha256:ff773c852122cef6dcae68b5e252a20aaf5d2986f78e278d747e226e7829d44e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e6/9b/c17c74ef6b85ad1a2687de8a08d1b56e3a27154b4db6c3ef1e9c2c53a96c/optree-0.19.0-cp314-cp314-win_arm64.whl", hash = "sha256:259ac2a426816d53d576c143b8dca87176af45fc8efd5dfe09db50d74a2fa0a5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ef/4c/e881fb840cef2cead7582ee36c0e0348e66730cb2a2af1938338c72b1bf3/optree-0.19.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:428fdc8cf5dc43fa32496be6aa84fc0d8f549f899062dd9dd0aa7e3aa7f77ae9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b6/6b/0a8538815abe28e4307dd98385d4991d36555b841b060df3295a8408b856/optree-0.19.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:d1b497032b5823a09625b118fd4df84199fb0895afb78af536d638ce7645beb6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/71/0c/d70a513fa93dbaa0e3e8c9b218b3805efb7083369cd14e1340bd2c0bc910/optree-0.19.0-cp314-cp314t-manylinux_2_26_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e5f05fecbca17b48451ba3455198cec9db20802c0ffbbba51eaeb421bd846a1c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/77/04/bd30c9f4e694f7b6585f333208ac7894578c1fa30dc5c938f22155df7859/optree-0.19.0-cp314-cp314t-manylinux_2_26_i686.manylinux_2_28_i686.whl", hash = "sha256:a51d0ad4e9dd089f317c94d95b7fa360e87491324e2bfa83d9c4f18dd928d4e1" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e5/17/aba83aa0e8bf31c00cdd3863c2a05854ce414426a69c094ae51210b76677/optree-0.19.0-cp314-cp314t-manylinux_2_26_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:108ab83937d91658ef96c4f70a6c76b36038754f4779907ee8f127780575740f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e0/da/52e684c42dc29d3b4d52f2029545742ef43e151cea112d9093d2ad164f53/optree-0.19.0-cp314-cp314t-manylinux_2_26_s390x.manylinux_2_28_s390x.whl", hash = "sha256:a39fdd614f46bcaf810b2bb1ed940e82b8a19e654bc325df0cc6554e25c3b7eb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2d/f7/0d41edf484e11ba5357f91dba8d85ce06ca9d840ac7d95e58b856a49b13b/optree-0.19.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfc1bcba22f182f39f1a80ae3ac511ebfa4daea62c3058edd021ce7a5cda3009" }, - { url = "https://mirrors.aliyun.com/pypi/packages/79/5e/a8f49cfd6c3ae0e59dcb1155cd49f1e5ba41889c9388360264c8369589c6/optree-0.19.0-cp314-cp314t-manylinux_2_39_riscv64.whl", hash = "sha256:afe595a052cc45d3addb6045f04a3ca7e1fb664de032ecbbb2bfd76dfe1fcb61" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9a/1b/4105e562d86b2de7eb3f240164a7dd3948e268878a9ee8925bfe1ad1da4f/optree-0.19.0-cp314-cp314t-win32.whl", hash = "sha256:b15ab972e2133e70570259386684624a17128daab7fb353a0a7435e9dd2c7354" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c4/43/bbc4c7a1f37f1a0ed6efe07a5c44b2835e81d1f6ce1cca6a395a2339e60f/optree-0.19.0-cp314-cp314t-win_amd64.whl", hash = "sha256:c90c15a80c325c2c6e03e20c95350df5db4591d35e8e4a35a40d2f865c260193" }, - { url = "https://mirrors.aliyun.com/pypi/packages/62/12/6758b43dbddc6911e3225a15ca686c913959fb63c267840b54f0002be503/optree-0.19.0-cp314-cp314t-win_arm64.whl", hash = "sha256:a1e7b358df8fc4b97a05380d446e87b08eac899c1f34d9846b9afa0be7f96bc7" }, -] - [[package]] name = "orjson" version = "3.10.18" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/81/0b/fea456a3ffe74e70ba30e01ec183a9b26bec4d497f61dcfce1b601059c60/orjson-3.10.18.tar.gz", hash = "sha256:e8da3947d92123eda795b68228cafe2724815621fe35e8e320a9e9593a4bcd53" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/21/1a/67236da0916c1a192d5f4ccbe10ec495367a726996ceb7614eaa687112f2/orjson-3.10.18-cp312-cp312-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:50c15557afb7f6d63bc6d6348e0337a880a04eaa9cd7c9d569bcb4e760a24753" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b3/bc/c7f1db3b1d094dc0c6c83ed16b161a16c214aaa77f311118a93f647b32dc/orjson-3.10.18-cp312-cp312-macosx_15_0_arm64.whl", hash = "sha256:356b076f1662c9813d5fa56db7d63ccceef4c271b1fb3dd522aca291375fcf17" }, - { url = "https://mirrors.aliyun.com/pypi/packages/af/84/664657cd14cc11f0d81e80e64766c7ba5c9b7fc1ec304117878cc1b4659c/orjson-3.10.18-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:559eb40a70a7494cd5beab2d73657262a74a2c59aff2068fdba8f0424ec5b39d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9a/bb/f50039c5bb05a7ab024ed43ba25d0319e8722a0ac3babb0807e543349978/orjson-3.10.18-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:f3c29eb9a81e2fbc6fd7ddcfba3e101ba92eaff455b8d602bf7511088bbc0eae" }, - { url = "https://mirrors.aliyun.com/pypi/packages/93/8c/ee74709fc072c3ee219784173ddfe46f699598a1723d9d49cbc78d66df65/orjson-3.10.18-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6612787e5b0756a171c7d81ba245ef63a3533a637c335aa7fcb8e665f4a0966f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6a/37/e6d3109ee004296c80426b5a62b47bcadd96a3deab7443e56507823588c5/orjson-3.10.18-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ac6bd7be0dcab5b702c9d43d25e70eb456dfd2e119d512447468f6405b4a69c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/4f/5d/387dafae0e4691857c62bd02839a3bf3fa648eebd26185adfac58d09f207/orjson-3.10.18-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9f72f100cee8dde70100406d5c1abba515a7df926d4ed81e20a9730c062fe9ad" }, - { url = "https://mirrors.aliyun.com/pypi/packages/27/6f/875e8e282105350b9a5341c0222a13419758545ae32ad6e0fcf5f64d76aa/orjson-3.10.18-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9dca85398d6d093dd41dc0983cbf54ab8e6afd1c547b6b8a311643917fbf4e0c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/48/b2/73a1f0b4790dcb1e5a45f058f4f5dcadc8a85d90137b50d6bbc6afd0ae50/orjson-3.10.18-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:22748de2a07fcc8781a70edb887abf801bb6142e6236123ff93d12d92db3d406" }, - { url = "https://mirrors.aliyun.com/pypi/packages/56/f5/7ed133a5525add9c14dbdf17d011dd82206ca6840811d32ac52a35935d19/orjson-3.10.18-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:3a83c9954a4107b9acd10291b7f12a6b29e35e8d43a414799906ea10e75438e6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/11/7c/439654221ed9c3324bbac7bdf94cf06a971206b7b62327f11a52544e4982/orjson-3.10.18-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:303565c67a6c7b1f194c94632a4a39918e067bd6176a48bec697393865ce4f06" }, - { url = "https://mirrors.aliyun.com/pypi/packages/48/e7/d58074fa0cc9dd29a8fa2a6c8d5deebdfd82c6cfef72b0e4277c4017563a/orjson-3.10.18-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:86314fdb5053a2f5a5d881f03fca0219bfdf832912aa88d18676a5175c6916b5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/57/4d/fe17581cf81fb70dfcef44e966aa4003360e4194d15a3f38cbffe873333a/orjson-3.10.18-cp312-cp312-win32.whl", hash = "sha256:187ec33bbec58c76dbd4066340067d9ece6e10067bb0cc074a21ae3300caa84e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e6/22/469f62d25ab5f0f3aee256ea732e72dc3aab6d73bac777bd6277955bceef/orjson-3.10.18-cp312-cp312-win_amd64.whl", hash = "sha256:f9f94cf6d3f9cd720d641f8399e390e7411487e493962213390d1ae45c7814fc" }, - { url = "https://mirrors.aliyun.com/pypi/packages/10/b0/1040c447fac5b91bc1e9c004b69ee50abb0c1ffd0d24406e1350c58a7fcb/orjson-3.10.18-cp312-cp312-win_arm64.whl", hash = "sha256:3d600be83fe4514944500fa8c2a0a77099025ec6482e8087d7659e891f23058a" }, { url = "https://mirrors.aliyun.com/pypi/packages/04/f0/8aedb6574b68096f3be8f74c0b56d36fd94bcf47e6c7ed47a7bd1474aaa8/orjson-3.10.18-cp313-cp313-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:69c34b9441b863175cc6a01f2935de994025e773f814412030f269da4f7be147" }, { url = "https://mirrors.aliyun.com/pypi/packages/bc/f7/7118f965541aeac6844fcb18d6988e111ac0d349c9b80cda53583e758908/orjson-3.10.18-cp313-cp313-macosx_15_0_arm64.whl", hash = "sha256:1ebeda919725f9dbdb269f59bc94f861afbe2a27dce5608cdba2d92772364d1c" }, { url = "https://mirrors.aliyun.com/pypi/packages/fb/d9/839637cc06eaf528dd8127b36004247bf56e064501f68df9ee6fd56a88ee/orjson-3.10.18-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5adf5f4eed520a4959d29ea80192fa626ab9a20b2ea13f8f6dc58644f6927103" }, @@ -4918,15 +4570,32 @@ wheels = [ [[package]] name = "ormsgpack" -version = "1.5.0" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/c5/70/11a6ab33136c2f98bb64e96743a55c7a87b87bae0413460cab7cc5764951/ormsgpack-1.5.0.tar.gz", hash = "sha256:00c0743ebaa8d21f1c868fbb609c99151ea79e67fec98b51a29077efd91ce348" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/47/19/df1626f9c149a20d2273eecf97ae913a026be2730264db86126ac3e594db/ormsgpack-1.5.0-cp312-cp312-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:a921b0d54b5fb5ba1ea4e87c65caa8992736224f1fc5ce8f46a882e918c8e22d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/82/cc/bad6d4a237ff0943cb1c8c4a12fe95bcd7ff81c0f8bca26340efd599aa1d/ormsgpack-1.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c6d423668e2c3abdbc474562b1c73360ff7326f06cb9532dcb73254b5b63dae4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c7/d1/3ed38a54923fe04eace750c0f0adbc149fb2b028375c71e864aee5e2d6d6/ormsgpack-1.5.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:eeb2dd4ed3e503a8266dcbfbb8d810a36baa34e4bb4229e90e9c213058a06d74" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9b/52/0261a80de2486793b4844c2668b17f49d03a20aba13a8d3d975831b1d866/ormsgpack-1.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f13bd643df1324e8797caba4c5c0168a87524df8424e8413ba29723e89a586a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b7/f0/2ebda08824d4f658c5ad048bcbe64e352b637b661b4d26c51d7403d30569/ormsgpack-1.5.0-cp312-none-win_amd64.whl", hash = "sha256:e016da381a126478c4bafab0ae19d3a2537f6471341ecced4bb61471e8841cad" }, +version = "1.12.2" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/12/0c/f1761e21486942ab9bb6feaebc610fa074f7c5e496e6962dea5873348077/ormsgpack-1.12.2.tar.gz", hash = "sha256:944a2233640273bee67521795a73cf1e959538e0dfb7ac635505010455e53b33" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/eb/29/bb0eba3288c0449efbb013e9c6f58aea79cf5cb9ee1921f8865f04c1a9d7/ormsgpack-1.12.2-cp313-cp313-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:5ea60cb5f210b1cfbad8c002948d73447508e629ec375acb82910e3efa8ff355" }, + { url = "https://mirrors.aliyun.com/pypi/packages/6e/31/5efa31346affdac489acade2926989e019e8ca98129658a183e3add7af5e/ormsgpack-1.12.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3601f19afdbea273ed70b06495e5794606a8b690a568d6c996a90d7255e51c1" }, + { url = "https://mirrors.aliyun.com/pypi/packages/eb/56/d0087278beef833187e0167f8527235ebe6f6ffc2a143e9de12a98b1ce87/ormsgpack-1.12.2-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:29a9f17a3dac6054c0dce7925e0f4995c727f7c41859adf9b5572180f640d172" }, + { url = "https://mirrors.aliyun.com/pypi/packages/1c/a2/072343e1413d9443e5a252a8eb591c2d5b1bffbe5e7bfc78c069361b92eb/ormsgpack-1.12.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39c1bd2092880e413902910388be8715f70b9f15f20779d44e673033a6146f2d" }, + { url = "https://mirrors.aliyun.com/pypi/packages/a2/8b/a0da3b98a91d41187a63b02dda14267eefc2a74fcb43cc2701066cf1510e/ormsgpack-1.12.2-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:50b7249244382209877deedeee838aef1542f3d0fc28b8fe71ca9d7e1896a0d7" }, + { url = "https://mirrors.aliyun.com/pypi/packages/19/bb/6d226bc4cf9fc20d8eb1d976d027a3f7c3491e8f08289a2e76abe96a65f3/ormsgpack-1.12.2-cp313-cp313-musllinux_1_2_armv7l.whl", hash = "sha256:5af04800d844451cf102a59c74a841324868d3f1625c296a06cc655c542a6685" }, + { url = "https://mirrors.aliyun.com/pypi/packages/fb/f1/bb2c7223398543dedb3dbf8bb93aaa737b387de61c5feaad6f908841b782/ormsgpack-1.12.2-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:cec70477d4371cd524534cd16472d8b9cc187e0e3043a8790545a9a9b296c258" }, + { url = "https://mirrors.aliyun.com/pypi/packages/7b/e8/0fb45f57a2ada1fed374f7494c8cd55e2f88ccd0ab0a669aa3468716bf5f/ormsgpack-1.12.2-cp313-cp313-win_amd64.whl", hash = "sha256:21f4276caca5c03a818041d637e4019bc84f9d6ca8baa5ea03e5cc8bf56140e9" }, + { url = "https://mirrors.aliyun.com/pypi/packages/7a/d4/0cfeea1e960d550a131001a7f38a5132c7ae3ebde4c82af1f364ccc5d904/ormsgpack-1.12.2-cp313-cp313-win_arm64.whl", hash = "sha256:baca4b6773d20a82e36d6fd25f341064244f9f86a13dead95dd7d7f996f51709" }, + { url = "https://mirrors.aliyun.com/pypi/packages/94/16/24d18851334be09c25e87f74307c84950f18c324a4d3c0b41dabdbf19c29/ormsgpack-1.12.2-cp314-cp314-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:bc68dd5915f4acf66ff2010ee47c8906dc1cf07399b16f4089f8c71733f6e36c" }, + { url = "https://mirrors.aliyun.com/pypi/packages/b5/a2/88b9b56f83adae8032ac6a6fa7f080c65b3baf9b6b64fd3d37bd202991d4/ormsgpack-1.12.2-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:46d084427b4132553940070ad95107266656cb646ea9da4975f85cb1a6676553" }, + { url = "https://mirrors.aliyun.com/pypi/packages/a9/80/43e4555963bf602e5bdc79cbc8debd8b6d5456c00d2504df9775e74b450b/ormsgpack-1.12.2-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c010da16235806cf1d7bc4c96bf286bfa91c686853395a299b3ddb49499a3e13" }, + { url = "https://mirrors.aliyun.com/pypi/packages/78/e1/7cfbf28de8bca6efe7e525b329c31277d1b64ce08dcba723971c241a9d60/ormsgpack-1.12.2-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:18867233df592c997154ff942a6503df274b5ac1765215bceba7a231bea2745d" }, + { url = "https://mirrors.aliyun.com/pypi/packages/95/f8/30ae5716e88d792a4e879debee195653c26ddd3964c968594ddef0a3cc7e/ormsgpack-1.12.2-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:b009049086ddc6b8f80c76b3955df1aa22a5fbd7673c525cd63bf91f23122ede" }, + { url = "https://mirrors.aliyun.com/pypi/packages/dc/81/aee5b18a3e3a0e52f718b37ab4b8af6fae0d9d6a65103036a90c2a8ffb5d/ormsgpack-1.12.2-cp314-cp314-musllinux_1_2_armv7l.whl", hash = "sha256:1dcc17d92b6390d4f18f937cf0b99054824a7815818012ddca925d6e01c2e49e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/bd/17/71c9ba472d5d45f7546317f467a5fc941929cd68fb32796ca3d13dcbaec2/ormsgpack-1.12.2-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:f04b5e896d510b07c0ad733d7fce2d44b260c5e6c402d272128f8941984e4285" }, + { url = "https://mirrors.aliyun.com/pypi/packages/2e/a6/ac99cd7fe77e822fed5250ff4b86fa66dd4238937dd178d2299f10b69816/ormsgpack-1.12.2-cp314-cp314-win_amd64.whl", hash = "sha256:ae3aba7eed4ca7cb79fd3436eddd29140f17ea254b91604aa1eb19bfcedb990f" }, + { url = "https://mirrors.aliyun.com/pypi/packages/3a/67/339872846a1ae4592535385a1c1f93614138566d7af094200c9c3b45d1e5/ormsgpack-1.12.2-cp314-cp314-win_arm64.whl", hash = "sha256:118576ea6006893aea811b17429bfc561b4778fad393f5f538c84af70b01260c" }, + { url = "https://mirrors.aliyun.com/pypi/packages/49/c2/6feb972dc87285ad381749d3882d8aecbde9f6ecf908dd717d33d66df095/ormsgpack-1.12.2-cp314-cp314t-macosx_10_12_x86_64.macosx_11_0_arm64.macosx_10_12_universal2.whl", hash = "sha256:7121b3d355d3858781dc40dafe25a32ff8a8242b9d80c692fd548a4b1f7fd3c8" }, + { url = "https://mirrors.aliyun.com/pypi/packages/a3/9a/900a6b9b413e0f8a471cf07830f9cf65939af039a362204b36bd5b581d8b/ormsgpack-1.12.2-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ee766d2e78251b7a63daf1cddfac36a73562d3ddef68cacfb41b2af64698033" }, + { url = "https://mirrors.aliyun.com/pypi/packages/87/4c/27a95466354606b256f24fad464d7c97ab62bce6cc529dd4673e1179b8fb/ormsgpack-1.12.2-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:292410a7d23de9b40444636b9b8f1e4e4b814af7f1ef476e44887e52a123f09d" }, + { url = "https://mirrors.aliyun.com/pypi/packages/73/cd/29cee6007bddf7a834e6cd6f536754c0535fcb939d384f0f37a38b1cddb8/ormsgpack-1.12.2-cp314-cp314t-win_amd64.whl", hash = "sha256:837dd316584485b72ef451d08dd3e96c4a11d12e4963aedb40e08f89685d8ec2" }, ] [[package]] @@ -4976,13 +4645,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/33/01/d40b85317f86cf08d853a4f495195c73815fdf205eef3993821720274518/pandas-2.3.3.tar.gz", hash = "sha256:e05e1af93b977f7eafa636d043f9f94c7ee3ac81af99c13508215942e64c993b" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/9c/fb/231d89e8637c808b997d172b18e9d4a4bc7bf31296196c260526055d1ea0/pandas-2.3.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:6d21f6d74eb1725c2efaa71a2bfc661a0689579b58e9c0ca58a739ff0b002b53" }, - { url = "https://mirrors.aliyun.com/pypi/packages/5c/bd/bf8064d9cfa214294356c2d6702b716d3cf3bb24be59287a6a21e24cae6b/pandas-2.3.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3fd2f887589c7aa868e02632612ba39acb0b8948faf5cc58f0850e165bd46f35" }, - { url = "https://mirrors.aliyun.com/pypi/packages/57/56/cf2dbe1a3f5271370669475ead12ce77c61726ffd19a35546e31aa8edf4e/pandas-2.3.3-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ecaf1e12bdc03c86ad4a7ea848d66c685cb6851d807a26aa245ca3d2017a1908" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e5/63/cd7d615331b328e287d8233ba9fdf191a9c2d11b6af0c7a59cfcec23de68/pandas-2.3.3-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b3d11d2fda7eb164ef27ffc14b4fcab16a80e1ce67e9f57e19ec0afaf715ba89" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a6/de/8b1895b107277d52f2b42d3a6806e69cfef0d5cf1d0ba343470b9d8e0a04/pandas-2.3.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a68e15f780eddf2b07d242e17a04aa187a7ee12b40b930bfdd78070556550e98" }, - { url = "https://mirrors.aliyun.com/pypi/packages/87/21/84072af3187a677c5893b170ba2c8fbe450a6ff911234916da889b698220/pandas-2.3.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:371a4ab48e950033bcf52b6527eccb564f52dc826c02afd9a1bc0ab731bba084" }, - { url = "https://mirrors.aliyun.com/pypi/packages/86/41/585a168330ff063014880a80d744219dbf1dd7a1c706e75ab3425a987384/pandas-2.3.3-cp312-cp312-win_amd64.whl", hash = "sha256:a16dcec078a01eeef8ee61bf64074b4e524a2a3f4b3be9326420cabe59c4778b" }, { url = "https://mirrors.aliyun.com/pypi/packages/cd/4b/18b035ee18f97c1040d94debd8f2e737000ad70ccc8f5513f4eefad75f4b/pandas-2.3.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:56851a737e3470de7fa88e6131f41281ed440d29a9268dcbf0002da5ac366713" }, { url = "https://mirrors.aliyun.com/pypi/packages/31/94/72fac03573102779920099bcac1c3b05975c2cb5f01eac609faf34bed1ca/pandas-2.3.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:bdcd9d1167f4885211e401b3036c0c8d9e274eee67ea8d0758a256d60704cfe8" }, { url = "https://mirrors.aliyun.com/pypi/packages/16/87/9472cf4a487d848476865321de18cc8c920b8cab98453ab79dbbc98db63a/pandas-2.3.3-cp313-cp313-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e32e7cc9af0f1cc15548288a51a3b681cc2a219faa838e995f7dc53dbab1062d" }, @@ -5020,6 +4682,21 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/00/2f/804f58f0b856ab3bf21617cccf5b39206e6c4c94c2cd227bde125ea6105f/parameterized-0.9.0-py2.py3-none-any.whl", hash = "sha256:4e0758e3d41bea3bbd05ec14fc2c24736723f243b28d702081aef438c9372b1b" }, ] +[[package]] +name = "paramiko" +version = "5.0.0" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "bcrypt" }, + { name = "cryptography" }, + { name = "invoke" }, + { name = "pynacl" }, +] +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/62/93/dcc25d52f49022ae6175d15e6bd751f1acc99b98bc61fc55e5155a7be2e7/paramiko-5.0.0.tar.gz", hash = "sha256:36763b5b95c2a0dcfdf1abc48e48156ee425b21efe2f0e787c2dd5a95c0e5e79" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/82/5b/eadf6d45de38d30ab603f49393b6cd2cbe7e233af8cf90197e32782b68a9/paramiko-5.0.0-py3-none-any.whl", hash = "sha256:b7044611c30140d9a75261653210e2002977b71a0497ff3ba0d98d7edbf62f7c" }, +] + [[package]] name = "patchright" version = "1.58.2" @@ -5106,17 +4783,6 @@ version = "12.2.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/8c/21/c2bcdd5906101a30244eaffc1b6e6ce71a31bd0742a01eb89e660ebfac2d/pillow-12.2.0.tar.gz", hash = "sha256:a830b1a40919539d07806aa58e1b114df53ddd43213d9c8b75847eee6c0182b5" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/58/be/7482c8a5ebebbc6470b3eb791812fff7d5e0216c2be3827b30b8bb6603ed/pillow-12.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2d192a155bbcec180f8564f693e6fd9bccff5a7af9b32e2e4bf8c9c69dbad6b5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d8/95/0a351b9289c2b5cbde0bacd4a83ebc44023e835490a727b2a3bd60ddc0f4/pillow-12.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f3f40b3c5a968281fd507d519e444c35f0ff171237f4fdde090dd60699458421" }, - { url = "https://mirrors.aliyun.com/pypi/packages/de/af/4e8e6869cbed569d43c416fad3dc4ecb944cb5d9492defaed89ddd6fe871/pillow-12.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:03e7e372d5240cc23e9f07deca4d775c0817bffc641b01e9c3af208dbd300987" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e9/9e/c05e19657fd57841e476be1ab46c4d501bffbadbafdc31a6d665f8b737b6/pillow-12.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b86024e52a1b269467a802258c25521e6d742349d760728092e1bc2d135b4d76" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2b/54/1789c455ed10176066b6e7e6da1b01e50e36f94ba584dc68d9eebfe9156d/pillow-12.2.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7371b48c4fa448d20d2714c9a1f775a81155050d383333e0a6c15b1123dda005" }, - { url = "https://mirrors.aliyun.com/pypi/packages/43/e3/fdc657359e919462369869f1c9f0e973f353f9a9ee295a39b1fea8ee1a77/pillow-12.2.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:62f5409336adb0663b7caa0da5c7d9e7bdbaae9ce761d34669420c2a801b2780" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8b/f8/2f6825e441d5b1959d2ca5adec984210f1ec086435b0ed5f52c19b3b8a6e/pillow-12.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:01afa7cf67f74f09523699b4e88c73fb55c13346d212a59a2db1f86b0a63e8c5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/67/f9/029a27095ad20f854f9dba026b3ea6428548316e057e6fc3545409e86651/pillow-12.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fc3d34d4a8fbec3e88a79b92e5465e0f9b842b628675850d860b8bd300b159f5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/be/42/025cfe05d1be22dbfdb4f264fe9de1ccda83f66e4fc3aac94748e784af04/pillow-12.2.0-cp312-cp312-win32.whl", hash = "sha256:58f62cc0f00fd29e64b29f4fd923ffdb3859c9f9e6105bfc37ba1d08994e8940" }, - { url = "https://mirrors.aliyun.com/pypi/packages/5d/7b/25a221d2c761c6a8ae21bfa3874988ff2583e19cf8a27bf2fee358df7942/pillow-12.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7f84204dee22a783350679a0333981df803dac21a0190d706a50475e361c93f5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/10/e1/542a474affab20fd4a0f1836cb234e8493519da6b76899e30bcc5d990b8b/pillow-12.2.0-cp312-cp312-win_arm64.whl", hash = "sha256:af73337013e0b3b46f175e79492d96845b16126ddf79c438d7ea7ff27783a414" }, { url = "https://mirrors.aliyun.com/pypi/packages/4a/01/53d10cf0dbad820a8db274d259a37ba50b88b24768ddccec07355382d5ad/pillow-12.2.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:8297651f5b5679c19968abefd6bb84d95fe30ef712eb1b2d9b2d31ca61267f4c" }, { url = "https://mirrors.aliyun.com/pypi/packages/0f/98/f3a6657ecb698c937f6c76ee564882945f29b79bad496abcba0e84659ec5/pillow-12.2.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:50d8520da2a6ce0af445fa6d648c4273c3eeefbc32d7ce049f22e8b5c3daecc2" }, { url = "https://mirrors.aliyun.com/pypi/packages/69/bc/8986948f05e3ea490b8442ea1c1d4d990b24a7e43d8a51b2c7d8b1dced36/pillow-12.2.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:766cef22385fa1091258ad7e6216792b156dc16d8d3fa607e7545b2b72061f1c" }, @@ -5294,13 +4960,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/42/8b/5f939eaf1fbeb7ff914fe540d659486951a056e5537b8f454362045b6c72/pot-0.9.6.post1.tar.gz", hash = "sha256:9b6cc14a8daecfe1268268168cf46548f9130976b22b24a9e8ec62a734be6c43" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/b9/28/13622807461f9f6082a8cd6768f9b4a810bc3a8fda474b81572da94b4d23/pot-0.9.6.post1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:f7c542fc20662e35c24dd82eeff8a737220757434d7f0038664a7322221452f7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c6/5c/b4e017560531f53d06798c681b0d0a9488bb8116bc98da9d399a3d096391/pot-0.9.6.post1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:c1755516a7354cbd6110ad2e5f341b98b9968240c2f0f67b0ff5e3ebcb3105bd" }, - { url = "https://mirrors.aliyun.com/pypi/packages/07/9f/57e49b3f7173359741053c5e2766a45dcf649d767c2e967ef93526c9045f/pot-0.9.6.post1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f3207362d3e3b5aaa783f452aa85f66e83edbefb5764f34662860af54ac72ee6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/30/60/fa72dd6094f7dbe6b38e2c6907af8cd0f18c6bd107e0cf4874deddaba883/pot-0.9.6.post1-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:05f6659c5657e6d7e9f98f4a82e0ed64f88e9fce69b2e557416d156343919ba3" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2f/3f/cc519c1176116271b6282268a705162fa042c16cc922bc56039445c9d697/pot-0.9.6.post1-cp312-cp312-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4f1b0148ae17bec0ed12264c6da3a05e13913b716e2a8c9043242b5d8349d8df" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f5/01/0132c94404cd0b1b2f21c4a49698db9dcd6107c47c02b22df1ed38206b2a/pot-0.9.6.post1-cp312-cp312-win32.whl", hash = "sha256:571e543cc2b0a462365002203595baf2b89c3d064cce4fce70fd1231e832c21f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c1/6d/23229c0e198a4f7fb27750b3ef8497e6ebed23fe531ed64b5194da8b2b02/pot-0.9.6.post1-cp312-cp312-win_amd64.whl", hash = "sha256:b1d8bd9a334c72baa37f9a2b268de5366c23c0f9c9e3d6dc25d150137ec2823c" }, { url = "https://mirrors.aliyun.com/pypi/packages/53/17/e4aebb8deef58b0d40ac339d952d12c63559801b50ae43c622d49bebda7e/pot-0.9.6.post1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:659fff750a162f58b52b33a64c4ac358f4ff44e9dff0841052c088e1b6a54430" }, { url = "https://mirrors.aliyun.com/pypi/packages/f7/b9/3646c153b13f999ac30112dcf85c5f233af79b0d98c37b52dda9a624c91b/pot-0.9.6.post1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4f54830e9f9cb78b1ff7abd5c5bf162625ed6aea903241267c64ea9f0fb73ddb" }, { url = "https://mirrors.aliyun.com/pypi/packages/53/e9/c7092f7aec8cb32739ad66ba1f1259626546e4893b61b905ce2da3987235/pot-0.9.6.post1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e9fd4b1fafacd37debdb984687ddb26f5c43d1429401847d388a6f1bd1f10e98" }, @@ -5310,6 +4969,42 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/f7/b1/8ca34418e7c4a2ec666e2204539577287223c4e78ab80b1c746cedb559c3/pot-0.9.6.post1-cp313-cp313-win_amd64.whl", hash = "sha256:a43e2b61389bd32f5b488da2488999ed55867e95fedb25dd64f9f390e40b4fab" }, ] +[[package]] +name = "preshed" +version = "3.0.13" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "cymem" }, + { name = "murmurhash" }, +] +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/43/75/fe6b7bbd0dea530a001b0e24c331b21a0be2786e402abf3c57f5dce43d4b/preshed-3.0.13.tar.gz", hash = "sha256:d75f718bbfd97e992f7827e0fa7faf6a91bdd9c922d5baa4b50d62731396cb89" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/0e/2a/401158195d6dc7f6aef0b354d74d0e95c9da124499448c2b3dbb95b71204/preshed-3.0.13-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c0d0c14187dc0078d8a63bf190ec045a4d13e7748b6caeb557a7d575e411410b" }, + { url = "https://mirrors.aliyun.com/pypi/packages/88/8f/e20e64573988528785447a6893b2e7ab287ecfd85b3888e978b28812fd20/preshed-3.0.13-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7770987c2e57497cd26124a9be5f652b5b3ccd0def89859ab0da8bca6144a3de" }, + { url = "https://mirrors.aliyun.com/pypi/packages/b9/72/18168f881359c4482d312f8dc196371bdd61c1583a52b34390da4c88bbea/preshed-3.0.13-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4a7bc48220de579be6bdb0a8715482cf36e2a625a6fd5ad26c9f43485a4a23b5" }, + { url = "https://mirrors.aliyun.com/pypi/packages/fd/3a/3543476091087102775568cea9885dde3453569e9aeee365809108de572f/preshed-3.0.13-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e5c8462472f790c16708306aef3a102a762bd19dfe3d2f8ee08bd5e12f51b835" }, + { url = "https://mirrors.aliyun.com/pypi/packages/cf/65/b13f01329decc44ef53cfb6b4601ba85382dcb2a4ec78d9250f03a418066/preshed-3.0.13-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:c046736239cc8d72670749b79b526e4111839a2fc461a58545d212797649129c" }, + { url = "https://mirrors.aliyun.com/pypi/packages/d1/c7/f1a996c6832234efd4d543041b582418d41ac480ee55c557ec9e65344637/preshed-3.0.13-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:7c333f18e9a81c8a6de0603fd8781e17115324b117c445ca91abdf7bfb1abe49" }, + { url = "https://mirrors.aliyun.com/pypi/packages/e3/b9/96fb71499049885ce19545903fdd38877bbc2be0da47e37c04d01f3e9f66/preshed-3.0.13-cp313-cp313-win_amd64.whl", hash = "sha256:461327f8dd36520dcf1fd55a671e0c3c2c97a2d95e22fc85faa31173f4785dda" }, + { url = "https://mirrors.aliyun.com/pypi/packages/ef/a7/32a4903019d936a2316fdd330bedddac287ac26326107d24fb76a1fbc60a/preshed-3.0.13-cp313-cp313-win_arm64.whl", hash = "sha256:35d6c5acb3ee3b12b87a551913063f0cec784055c2af16e028c19fe875f079d0" }, + { url = "https://mirrors.aliyun.com/pypi/packages/bb/b5/993886c98f5caaa6f07a648cac97a7c62a3093091cad65e1e43a1bd41cc4/preshed-3.0.13-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:d2f1efae396cadab5f3890a2fd43d2ee65373ef9096ccbb805e51e8d8bcc563b" }, + { url = "https://mirrors.aliyun.com/pypi/packages/c6/86/b7fd137cbf140afd6c45e895946068a15f5b55642916de0075e6eb18581c/preshed-3.0.13-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:8d6acc1f5031a535a55a6f7148e2f274554a8343a16309c700cebea0fe7aee8c" }, + { url = "https://mirrors.aliyun.com/pypi/packages/8b/ca/21a7e79625614134273dfed32bca5bb4c2ec1313e33fbd12d41657536f1f/preshed-3.0.13-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:7da9d931e7660dcdd757e5870269f0c159126d682ed73ed313971d199eb0f334" }, + { url = "https://mirrors.aliyun.com/pypi/packages/8f/3a/2dbd299516461831ae90e0d5b0637137bf28520c4e6dd0b01d6f1886659a/preshed-3.0.13-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d4ae5cfe075bb7a07982e382bca44f41ddf041f4d24cbd358e8cccfc049259b8" }, + { url = "https://mirrors.aliyun.com/pypi/packages/7c/d3/af654eba4f6587c4ee02c5043e62c194b0a1c4431ffef0c67b9518f6b61c/preshed-3.0.13-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7557963d0125a3a7bcdb2eb6948f3e45da31b5a7f066b55320de3dea22d7557f" }, + { url = "https://mirrors.aliyun.com/pypi/packages/bf/9b/ebcb2b9e8cb881e40b55b0bf450f8a6b187e2ef3ae0c685cce81d2d85026/preshed-3.0.13-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c4bc60dc994864095d784b7e4d77dba3e64188d169ac88722b699d175561fddb" }, + { url = "https://mirrors.aliyun.com/pypi/packages/97/f7/c6c012779edcaa6e2cd092c554e98dc53e77f41205b07208655ba77e2327/preshed-3.0.13-cp314-cp314-win_amd64.whl", hash = "sha256:208dcebbe294bf1881ce33fb015d56ab2a7587aece85a09147727174207892e4" }, + { url = "https://mirrors.aliyun.com/pypi/packages/f8/82/390ef87d732ef64e673ef6bf9e5d898453986e979efa50fb3a400e2c0766/preshed-3.0.13-cp314-cp314-win_arm64.whl", hash = "sha256:cf8e1a7a1823b2a7765121446c630140ac6e8650c07a6efbf375e168d1fef4f7" }, + { url = "https://mirrors.aliyun.com/pypi/packages/80/3a/a9dde3167bcecb27ae82ce4567b5ab1aa3989113ae6814c092ce223cc4ef/preshed-3.0.13-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:9ca43ecbc3783eda4d6ab3416ae2ecd9ef23dca5f53995843f69f7457bcd0677" }, + { url = "https://mirrors.aliyun.com/pypi/packages/74/d4/22d9355b50b6a13b407dcad0a81df83fb1d5602092d1f05834674dde8fda/preshed-3.0.13-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:c8596e41a258ff213553a441e0bb3eb388fd8158e84a7bf3aae6d8ede2c166d3" }, + { url = "https://mirrors.aliyun.com/pypi/packages/70/42/a225ee83fdb306d2a503f21a627953b820f4e079c90c8a84338957cb8ff5/preshed-3.0.13-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:4f8856ca3d88e9b250630d70abb4f260d8933151ddfb413024784b25b009868e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/40/ba/09a9dfe3d22d7e745483fd5d7f2a82cd4d39c161f7d2daa0faa4bd6402be/preshed-3.0.13-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0e5b2865aecbd2e1e10e5d19bb8bfad765863c1307c6c3e51f2a08bd64122409" }, + { url = "https://mirrors.aliyun.com/pypi/packages/6c/5c/e10e2e05133e7fcbd7c40536af1148c82dd24357b8f5726e2c7bc51cfd53/preshed-3.0.13-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:09f96b477c987755b3c945df214ea1c1c80bfb350e9f34e78da89585535b77e8" }, + { url = "https://mirrors.aliyun.com/pypi/packages/37/aa/51e5b4109a4cdfae28c3613eeeb10764a3794ebef8de93ffbb109465bea3/preshed-3.0.13-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:670db59a52e1823b5f088c764df474e65b686592d4093adbeef14581c95ee2cb" }, + { url = "https://mirrors.aliyun.com/pypi/packages/0e/6a/1d966f367a14c703dde629d150d996c1b727d442f620300b21c9ec1a24d1/preshed-3.0.13-cp314-cp314t-win_amd64.whl", hash = "sha256:b03e21b0bf95eb56e23973f32cabb930e94f352228652f81c0955dbd6967d904" }, + { url = "https://mirrors.aliyun.com/pypi/packages/22/80/368139067603e590a000122355f9c8576c8ebed4fb0b8849feaa2698489d/preshed-3.0.13-cp314-cp314t-win_arm64.whl", hash = "sha256:b980f3ea9bb74b7f94464bc3d6eb3c9162b6b79b531febd14c6465c24344d2cc" }, +] + [[package]] name = "primp" version = "1.1.3" @@ -5384,21 +5079,6 @@ version = "0.4.1" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9e/da/e9fc233cf63743258bff22b3dfa7ea5baef7b5bc324af47a0ad89b8ffc6f/propcache-0.4.1.tar.gz", hash = "sha256:f48107a8c637e80362555f37ecf49abe20370e557cc4ab374f04ec4423c97c3d" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/a2/0f/f17b1b2b221d5ca28b4b876e8bb046ac40466513960646bda8e1853cdfa2/propcache-0.4.1-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:e153e9cd40cc8945138822807139367f256f89c6810c2634a4f6902b52d3b4e2" }, - { url = "https://mirrors.aliyun.com/pypi/packages/76/47/8ccf75935f51448ba9a16a71b783eb7ef6b9ee60f5d14c7f8a8a79fbeed7/propcache-0.4.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:cd547953428f7abb73c5ad82cbb32109566204260d98e41e5dfdc682eb7f8403" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0a/b6/5c9a0e42df4d00bfb4a3cbbe5cf9f54260300c88a0e9af1f47ca5ce17ac0/propcache-0.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f048da1b4f243fc44f205dfd320933a951b8d89e0afd4c7cacc762a8b9165207" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9e/d3/6c7ee328b39a81ee877c962469f1e795f9db87f925251efeb0545e0020d0/propcache-0.4.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ec17c65562a827bba85e3872ead335f95405ea1674860d96483a02f5c698fa72" }, - { url = "https://mirrors.aliyun.com/pypi/packages/01/5d/1c53f4563490b1d06a684742cc6076ef944bc6457df6051b7d1a877c057b/propcache-0.4.1-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:405aac25c6394ef275dee4c709be43745d36674b223ba4eb7144bf4d691b7367" }, - { url = "https://mirrors.aliyun.com/pypi/packages/20/e1/ce4620633b0e2422207c3cb774a0ee61cac13abc6217763a7b9e2e3f4a12/propcache-0.4.1-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:0013cb6f8dde4b2a2f66903b8ba740bdfe378c943c4377a200551ceb27f379e4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/46/4b/3aae6835b8e5f44ea6a68348ad90f78134047b503765087be2f9912140ea/propcache-0.4.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:15932ab57837c3368b024473a525e25d316d8353016e7cc0e5ba9eb343fbb1cf" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6e/a5/8a5e8678bcc9d3a1a15b9a29165640d64762d424a16af543f00629c87338/propcache-0.4.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:031dce78b9dc099f4c29785d9cf5577a3faf9ebf74ecbd3c856a7b92768c3df3" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f1/63/b7b215eddeac83ca1c6b934f89d09a625aa9ee4ba158338854c87210cc36/propcache-0.4.1-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:ab08df6c9a035bee56e31af99be621526bd237bea9f32def431c656b29e41778" }, - { url = "https://mirrors.aliyun.com/pypi/packages/57/74/f580099a58c8af587cac7ba19ee7cb418506342fbbe2d4a4401661cca886/propcache-0.4.1-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:4d7af63f9f93fe593afbf104c21b3b15868efb2c21d07d8732c0c4287e66b6a6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c4/ee/542f1313aff7eaf19c2bb758c5d0560d2683dac001a1c96d0774af799843/propcache-0.4.1-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:cfc27c945f422e8b5071b6e93169679e4eb5bf73bbcbf1ba3ae3a83d2f78ebd9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8f/18/9c6b015dd9c6930f6ce2229e1f02fb35298b847f2087ea2b436a5bfa7287/propcache-0.4.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:35c3277624a080cc6ec6f847cbbbb5b49affa3598c4535a0a4682a697aaa5c75" }, - { url = "https://mirrors.aliyun.com/pypi/packages/80/9e/e7b85720b98c45a45e1fca6a177024934dc9bc5f4d5dd04207f216fc33ed/propcache-0.4.1-cp312-cp312-win32.whl", hash = "sha256:671538c2262dadb5ba6395e26c1731e1d52534bfe9ae56d0b5573ce539266aa8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/54/09/d19cff2a5aaac632ec8fc03737b223597b1e347416934c1b3a7df079784c/propcache-0.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:cb2d222e72399fcf5890d1d5cc1060857b9b236adff2792ff48ca2dfd46c81db" }, - { url = "https://mirrors.aliyun.com/pypi/packages/68/ab/6b5c191bb5de08036a8c697b265d4ca76148efb10fa162f14af14fb5f076/propcache-0.4.1-cp312-cp312-win_arm64.whl", hash = "sha256:204483131fb222bdaaeeea9f9e6c6ed0cac32731f75dfc1d4a567fc1926477c1" }, { url = "https://mirrors.aliyun.com/pypi/packages/bf/df/6d9c1b6ac12b003837dde8a10231a7344512186e87b36e855bef32241942/propcache-0.4.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:43eedf29202c08550aac1d14e0ee619b0430aaef78f85864c1a892294fbc28cf" }, { url = "https://mirrors.aliyun.com/pypi/packages/8b/e8/677a0025e8a2acf07d3418a2e7ba529c9c33caf09d3c1f25513023c1db56/propcache-0.4.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:d62cdfcfd89ccb8de04e0eda998535c406bf5e060ffd56be6c586cbcc05b3311" }, { url = "https://mirrors.aliyun.com/pypi/packages/89/a4/92380f7ca60f99ebae761936bc48a72a639e8a47b29050615eef757cb2a7/propcache-0.4.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:cae65ad55793da34db5f54e4029b89d3b9b9490d8abe1b4c7ab5d4b8ec7ebf74" }, @@ -5522,17 +5202,6 @@ version = "2.9.11" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/ac/6c/8767aaa597ba424643dc87348c6f1754dd9f48e80fdc1b9f7ca5c3a7c213/psycopg2-binary-2.9.11.tar.gz", hash = "sha256:b6aed9e096bf63f9e75edf2581aa9a7e7186d97ab5c177aa6c87797cd591236c" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/d8/91/f870a02f51be4a65987b45a7de4c2e1897dd0d01051e2b559a38fa634e3e/psycopg2_binary-2.9.11-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:be9b840ac0525a283a96b556616f5b4820e0526addb8dcf6525a0fa162730be4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/27/fa/cae40e06849b6c9a95eb5c04d419942f00d9eaac8d81626107461e268821/psycopg2_binary-2.9.11-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f090b7ddd13ca842ebfe301cd587a76a4cf0913b1e429eb92c1be5dbeb1a19bc" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2d/75/364847b879eb630b3ac8293798e380e441a957c53657995053c5ec39a316/psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ab8905b5dcb05bf3fb22e0cf90e10f469563486ffb6a96569e51f897c750a76a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6f/a0/567f7ea38b6e1c62aafd58375665a547c00c608a471620c0edc364733e13/psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:bf940cd7e7fec19181fdbc29d76911741153d51cab52e5c21165f3262125685e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/30/da/4e42788fb811bbbfd7b7f045570c062f49e350e1d1f3df056c3fb5763353/psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fa0f693d3c68ae925966f0b14b8edda71696608039f4ed61b1fe9ffa468d16db" }, - { url = "https://mirrors.aliyun.com/pypi/packages/3c/94/c1777c355bc560992af848d98216148be5f1be001af06e06fc49cbded578/psycopg2_binary-2.9.11-cp312-cp312-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a1cf393f1cdaf6a9b57c0a719a1068ba1069f022a59b8b1fe44b006745b59757" }, - { url = "https://mirrors.aliyun.com/pypi/packages/bd/42/c9a21edf0e3daa7825ed04a4a8588686c6c14904344344a039556d78aa58/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ef7a6beb4beaa62f88592ccc65df20328029d721db309cb3250b0aae0fa146c3" }, - { url = "https://mirrors.aliyun.com/pypi/packages/12/22/dedfbcfa97917982301496b6b5e5e6c5531d1f35dd2b488b08d1ebc52482/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:31b32c457a6025e74d233957cc9736742ac5a6cb196c6b68499f6bb51390bd6a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/66/ea/d3390e6696276078bd01b2ece417deac954dfdd552d2edc3d03204416c0c/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:edcb3aeb11cb4bf13a2af3c53a15b3d612edeb6409047ea0b5d6a21a9d744b34" }, - { url = "https://mirrors.aliyun.com/pypi/packages/12/9a/0402ded6cbd321da0c0ba7d34dc12b29b14f5764c2fc10750daa38e825fc/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:62b6d93d7c0b61a1dd6197d208ab613eb7dcfdcca0a49c42ceb082257991de9d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b1/d2/99b55e85832ccde77b211738ff3925a5d73ad183c0b37bcbbe5a8ff04978/psycopg2_binary-2.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:b33fabeb1fde21180479b2d4667e994de7bbf0eec22832ba5d9b5e4cf65b6c6d" }, { url = "https://mirrors.aliyun.com/pypi/packages/ff/a8/a2709681b3ac11b0b1786def10006b8995125ba268c9a54bea6f5ae8bd3e/psycopg2_binary-2.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b8fb3db325435d34235b044b199e56cdf9ff41223a4b9752e8576465170bb38c" }, { url = "https://mirrors.aliyun.com/pypi/packages/62/e1/c2b38d256d0dafd32713e9f31982a5b028f4a3651f446be70785f484f472/psycopg2_binary-2.9.11-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:366df99e710a2acd90efed3764bb1e28df6c675d33a7fb40df9b7281694432ee" }, { url = "https://mirrors.aliyun.com/pypi/packages/11/32/b2ffe8f3853c181e88f0a157c5fb4e383102238d73c52ac6d93a5c8bffe6/psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8c55b385daa2f92cb64b12ec4536c66954ac53654c7f15a203578da4e78105c0" }, @@ -5597,13 +5266,6 @@ version = "22.0.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/30/53/04a7fdc63e6056116c9ddc8b43bc28c12cdd181b85cbeadb79278475f3ae/pyarrow-22.0.0.tar.gz", hash = "sha256:3d600dc583260d845c7d8a6db540339dd883081925da2bd1c5cb808f720b3cd9" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/af/63/ba23862d69652f85b615ca14ad14f3bcfc5bf1b99ef3f0cd04ff93fdad5a/pyarrow-22.0.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:bea79263d55c24a32b0d79c00a1c58bb2ee5f0757ed95656b01c0fb310c5af3d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b1/d0/f9ad86fe809efd2bcc8be32032fa72e8b0d112b01ae56a053006376c5930/pyarrow-22.0.0-cp312-cp312-macosx_12_0_x86_64.whl", hash = "sha256:12fe549c9b10ac98c91cf791d2945e878875d95508e1a5d14091a7aaa66d9cf8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b4/a8/f910afcb14630e64d673f15904ec27dd31f1e009b77033c365c84e8c1e1d/pyarrow-22.0.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:334f900ff08ce0423407af97e6c26ad5d4e3b0763645559ece6fbf3747d6a8f5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/13/95/aec81f781c75cd10554dc17a25849c720d54feafb6f7847690478dcf5ef8/pyarrow-22.0.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:c6c791b09c57ed76a18b03f2631753a4960eefbbca80f846da8baefc6491fcfe" }, - { url = "https://mirrors.aliyun.com/pypi/packages/bb/d4/74ac9f7a54cfde12ee42734ea25d5a3c9a45db78f9def949307a92720d37/pyarrow-22.0.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:c3200cb41cdbc65156e5f8c908d739b0dfed57e890329413da2748d1a2cd1a4e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2e/71/fedf2499bf7a95062eafc989ace56572f3343432570e1c54e6599d5b88da/pyarrow-22.0.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:ac93252226cf288753d8b46280f4edf3433bf9508b6977f8dd8526b521a1bbb9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/68/ed/b202abd5a5b78f519722f3d29063dda03c114711093c1995a33b8e2e0f4b/pyarrow-22.0.0-cp312-cp312-win_amd64.whl", hash = "sha256:44729980b6c50a5f2bfcc2668d36c569ce17f8b17bccaf470c4313dcbbf13c9d" }, { url = "https://mirrors.aliyun.com/pypi/packages/a6/d6/d0fac16a2963002fc22c8fa75180a838737203d558f0ed3b564c4a54eef5/pyarrow-22.0.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:e6e95176209257803a8b3d0394f21604e796dadb643d2f7ca21b66c9c0b30c9a" }, { url = "https://mirrors.aliyun.com/pypi/packages/c6/9c/1d6357347fbae062ad3f17082f9ebc29cc733321e892c0d2085f42a2212b/pyarrow-22.0.0-cp313-cp313-macosx_12_0_x86_64.whl", hash = "sha256:001ea83a58024818826a9e3f89bf9310a114f7e26dfe404a4c32686f97bd7901" }, { url = "https://mirrors.aliyun.com/pypi/packages/ff/c0/782344c2ce58afbea010150df07e3a2f5fdad299cd631697ae7bd3bac6e3/pyarrow-22.0.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:ce20fe000754f477c8a9125543f1936ea5b8867c5406757c224d745ed033e691" }, @@ -5661,12 +5323,6 @@ version = "1.4.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/f6/21/3c06205bb407e1f79b73b7b4dfb3950bd9537c4f625a68ab5cc41177f5bc/pyclipper-1.4.0.tar.gz", hash = "sha256:9882bd889f27da78add4dd6f881d25697efc740bf840274e749988d25496c8e1" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/90/1b/7a07b68e0842324d46c03e512d8eefa9cb92ba2a792b3b4ebf939dafcac3/pyclipper-1.4.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:222ac96c8b8281b53d695b9c4fedc674f56d6d4320ad23f1bdbd168f4e316140" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6b/dd/8bd622521c05d04963420ae6664093f154343ed044c53ea260a310c8bb4d/pyclipper-1.4.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:f3672dbafbb458f1b96e1ee3e610d174acb5ace5bd2ed5d1252603bb797f2fc6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/7a/06/6e3e241882bf7d6ab23d9c69ba4e85f1ec47397cbbeee948a16cf75e21ed/pyclipper-1.4.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d1f807e2b4760a8e5c6d6b4e8c1d71ef52b7fe1946ff088f4fa41e16a881a5ca" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cf/f4/3418c1cd5eea640a9fa2501d4bc0b3655fa8d40145d1a4f484b987990a75/pyclipper-1.4.0-cp312-cp312-manylinux_2_24_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ce1f83c9a4e10ea3de1959f0ae79e9a5bd41346dff648fee6228ba9eaf8b3872" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ac/94/c85401d24be634af529c962dd5d781f3cb62a67cd769534df2cb3feee97a/pyclipper-1.4.0-cp312-cp312-win32.whl", hash = "sha256:3ef44b64666ebf1cb521a08a60c3e639d21b8c50bfbe846ba7c52a0415e936f4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/97/77/dfea08e3b230b82ee22543c30c35d33d42f846a77f96caf7c504dd54fab1/pyclipper-1.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:d1e5498d883b706a4ce636247f0d830c6eb34a25b843a1b78e2c969754ca9037" }, { url = "https://mirrors.aliyun.com/pypi/packages/67/d0/cbce7d47de1e6458f66a4d999b091640134deb8f2c7351eab993b70d2e10/pyclipper-1.4.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:d49df13cbb2627ccb13a1046f3ea6ebf7177b5504ec61bdef87d6a704046fd6e" }, { url = "https://mirrors.aliyun.com/pypi/packages/ce/cc/742b9d69d96c58ac156947e1b56d0f81cbacbccf869e2ac7229f2f86dc4e/pyclipper-1.4.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:37bfec361e174110cdddffd5ecd070a8064015c99383d95eb692c253951eee8a" }, { url = "https://mirrors.aliyun.com/pypi/packages/db/48/dd301d62c1529efdd721b47b9e5fb52120fcdac5f4d3405cfc0d2f391414/pyclipper-1.4.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:14c8bdb5a72004b721c4e6f448d2c2262d74a7f0c9e3076aeff41e564a92389f" }, @@ -5766,20 +5422,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/71/70/23b021c950c2addd24ec408e9ab05d59b035b39d97cdc1130e1bce647bb6/pydantic_core-2.41.5.tar.gz", hash = "sha256:08daa51ea16ad373ffd5e7606252cc32f07bc72b28284b6bc9c6df804816476e" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/5f/5d/5f6c63eebb5afee93bcaae4ce9a898f3373ca23df3ccaef086d0233a35a7/pydantic_core-2.41.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f41a7489d32336dbf2199c8c0a215390a751c5b014c2c1c5366e817202e9cdf7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/aa/32/9c2e8ccb57c01111e0fd091f236c7b371c1bccea0fa85247ac55b1e2b6b6/pydantic_core-2.41.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:070259a8818988b9a84a449a2a7337c7f430a22acc0859c6b110aa7212a6d9c0" }, - { url = "https://mirrors.aliyun.com/pypi/packages/68/b8/a01b53cb0e59139fbc9e4fda3e9724ede8de279097179be4ff31f1abb65a/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e96cea19e34778f8d59fe40775a7a574d95816eb150850a85a7a4c8f4b94ac69" }, - { url = "https://mirrors.aliyun.com/pypi/packages/38/de/8c36b5198a29bdaade07b5985e80a233a5ac27137846f3bc2d3b40a47360/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed2e99c456e3fadd05c991f8f437ef902e00eedf34320ba2b0842bd1c3ca3a75" }, - { url = "https://mirrors.aliyun.com/pypi/packages/00/b5/0e8e4b5b081eac6cb3dbb7e60a65907549a1ce035a724368c330112adfdd/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65840751b72fbfd82c3c640cff9284545342a4f1eb1586ad0636955b261b0b05" }, - { url = "https://mirrors.aliyun.com/pypi/packages/77/56/87a61aad59c7c5b9dc8caad5a41a5545cba3810c3e828708b3d7404f6cef/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e536c98a7626a98feb2d3eaf75944ef6f3dbee447e1f841eae16f2f0a72d8ddc" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0d/76/941cc9f73529988688a665a5c0ecff1112b3d95ab48f81db5f7606f522d3/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eceb81a8d74f9267ef4081e246ffd6d129da5d87e37a77c9bde550cb04870c1c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d3/43/ebef01f69baa07a482844faaa0a591bad1ef129253ffd0cdaa9d8a7f72d3/pydantic_core-2.41.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d38548150c39b74aeeb0ce8ee1d8e82696f4a4e16ddc6de7b1d8823f7de4b9b5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b1/87/41f3202e4193e3bacfc2c065fab7706ebe81af46a83d3e27605029c1f5a6/pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c23e27686783f60290e36827f9c626e63154b82b116d7fe9adba1fda36da706c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/49/7d/4c00df99cb12070b6bccdef4a195255e6020a550d572768d92cc54dba91a/pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:482c982f814460eabe1d3bb0adfdc583387bd4691ef00b90575ca0d2b6fe2294" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cc/6a/ebf4b1d65d458f3cda6a7335d141305dfa19bdc61140a884d165a8a1bbc7/pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bfea2a5f0b4d8d43adf9d7b8bf019fb46fdd10a2e5cde477fbcb9d1fa08c68e1" }, - { url = "https://mirrors.aliyun.com/pypi/packages/49/3b/774f2b5cd4192d5ab75870ce4381fd89cf218af999515baf07e7206753f0/pydantic_core-2.41.5-cp312-cp312-win32.whl", hash = "sha256:b74557b16e390ec12dca509bce9264c3bbd128f8a2c376eaa68003d7f327276d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/86/45/00173a033c801cacf67c190fef088789394feaf88a98a7035b0e40d53dc9/pydantic_core-2.41.5-cp312-cp312-win_amd64.whl", hash = "sha256:1962293292865bca8e54702b08a4f26da73adc83dd1fcf26fbc875b35d81c815" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f9/22/91fbc821fa6d261b376a3f73809f907cec5ca6025642c463d3488aad22fb/pydantic_core-2.41.5-cp312-cp312-win_arm64.whl", hash = "sha256:1746d4a3d9a794cacae06a5eaaccb4b8643a131d45fbc9af23e353dc0a5ba5c3" }, { url = "https://mirrors.aliyun.com/pypi/packages/87/06/8806241ff1f70d9939f9af039c6c35f2360cf16e93c2ca76f184e76b1564/pydantic_core-2.41.5-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:941103c9be18ac8daf7b7adca8228f8ed6bb7a1849020f643b3a14d15b1924d9" }, { url = "https://mirrors.aliyun.com/pypi/packages/94/02/abfa0e0bda67faa65fef1c84971c7e45928e108fe24333c81f3bfe35d5f5/pydantic_core-2.41.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:112e305c3314f40c93998e567879e887a3160bb8689ef3d2c04b6cc62c33ac34" }, { url = "https://mirrors.aliyun.com/pypi/packages/15/df/a4c740c0943e93e6500f9eb23f4ca7ec9bf71b19e608ae5b579678c8d02f/pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cbaad15cb0c90aa221d43c00e77bb33c93e8d36e0bf74760cd00e732d10a6a0" }, @@ -5822,10 +5464,6 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/5c/96/5fb7d8c3c17bc8c62fdb031c47d77a1af698f1d7a406b0f79aaa1338f9ad/pydantic_core-2.41.5-cp314-cp314t-win32.whl", hash = "sha256:b4ececa40ac28afa90871c2cc2b9ffd2ff0bf749380fbdf57d165fd23da353aa" }, { url = "https://mirrors.aliyun.com/pypi/packages/22/ed/182129d83032702912c2e2d8bbe33c036f342cc735737064668585dac28f/pydantic_core-2.41.5-cp314-cp314t-win_amd64.whl", hash = "sha256:80aa89cad80b32a912a65332f64a4450ed00966111b6615ca6816153d3585a8c" }, { url = "https://mirrors.aliyun.com/pypi/packages/9f/ed/068e41660b832bb0b1aa5b58011dea2a3fe0ba7861ff38c4d4904c1c1a99/pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008" }, - { url = "https://mirrors.aliyun.com/pypi/packages/09/32/59b0c7e63e277fa7911c2fc70ccfb45ce4b98991e7ef37110663437005af/pydantic_core-2.41.5-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:7da7087d756b19037bc2c06edc6c170eeef3c3bafcb8f532ff17d64dc427adfd" }, - { url = "https://mirrors.aliyun.com/pypi/packages/aa/81/05e400037eaf55ad400bcd318c05bb345b57e708887f07ddb2d20e3f0e98/pydantic_core-2.41.5-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:aabf5777b5c8ca26f7824cb4a120a740c9588ed58df9b2d196ce92fba42ff8dc" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6e/0d/e3549b2399f71d56476b77dbf3cf8937cec5cd70536bdc0e374a421d0599/pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c007fe8a43d43b3969e8469004e9845944f1a80e6acd47c150856bb87f230c56" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f7/07/34573da085946b6a313d7c42f82f16e8920bfd730665de2d11c0c37a74b5/pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76d0819de158cd855d1cbb8fcafdf6f5cf1eb8e470abe056d5d161106e38062b" }, ] [[package]] @@ -6006,15 +5644,6 @@ version = "5.3.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/8f/85/44b10070a769a56bd910009bb185c0c0a82daff8d567cd1a116d7d730c7d/pyodbc-5.3.0.tar.gz", hash = "sha256:2fe0e063d8fb66efd0ac6dc39236c4de1a45f17c33eaded0d553d21c199f4d05" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/f5/0c/7ecf8077f4b932a5d25896699ff5c394ffc2a880a9c2c284d6a3e6ea5949/pyodbc-5.3.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5ebf6b5d989395efe722b02b010cb9815698a4d681921bf5db1c0e1195ac1bde" }, - { url = "https://mirrors.aliyun.com/pypi/packages/03/78/9fbde156055d88c1ef3487534281a5b1479ee7a2f958a7e90714968749ac/pyodbc-5.3.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:197bb6ddafe356a916b8ee1b8752009057fce58e216e887e2174b24c7ab99269" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9f/f9/8c106dcd6946e95fee0da0f1ba58cd90eb872eebe8968996a2ea1f7ac3c1/pyodbc-5.3.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:c6ccb5315ec9e081f5cbd66f36acbc820ad172b8fa3736cf7f993cdf69bd8a96" }, - { url = "https://mirrors.aliyun.com/pypi/packages/4b/30/2c70f47a76a4fafa308d148f786aeb35a4d67a01d41002f1065b465d9994/pyodbc-5.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:5dd3d5e469f89a3112cf8b0658c43108a4712fad65e576071e4dd44d2bd763c7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/7d/b2/0631d84731606bfe40d3b03a436b80cbd16b63b022c7b13444fb30761ca8/pyodbc-5.3.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b180bc5e49b74fd40a24ef5b0fe143d0c234ac1506febe810d7434bf47cb925b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/74/b9/707c5314cca9401081b3757301241c167a94ba91b4bd55c8fa591bf35a4a/pyodbc-5.3.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e3c39de3005fff3ae79246f952720d44affc6756b4b85398da4c5ea76bf8f506" }, - { url = "https://mirrors.aliyun.com/pypi/packages/97/7c/893036c8b0c8d359082a56efdaa64358a38dda993124162c3faa35d1924d/pyodbc-5.3.0-cp312-cp312-win32.whl", hash = "sha256:d32c3259762bef440707098010035bbc83d1c73d81a434018ab8c688158bd3bb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c0/70/5e61b216cc13c7f833ef87f4cdeab253a7873f8709253f5076e9bb16c1b3/pyodbc-5.3.0-cp312-cp312-win_amd64.whl", hash = "sha256:fe77eb9dcca5fc1300c9121f81040cc9011d28cff383e2c35416e9ec06d4bc95" }, - { url = "https://mirrors.aliyun.com/pypi/packages/aa/85/e7d0629c9714a85eb4f85d21602ce6d8a1ec0f313fde8017990cf913e3b4/pyodbc-5.3.0-cp312-cp312-win_arm64.whl", hash = "sha256:afe7c4ac555a8d10a36234788fc6cfc22a86ce37fc5ba88a1f75b3e6696665dc" }, { url = "https://mirrors.aliyun.com/pypi/packages/0c/1d/9e74cbcc1d4878553eadfd59138364b38656369eb58f7e5b42fb344c0ce7/pyodbc-5.3.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7e9ab0b91de28a5ab838ac4db0253d7cc8ce2452efe4ad92ee6a57b922bf0c24" }, { url = "https://mirrors.aliyun.com/pypi/packages/37/c7/27d83f91b3144d3e275b5b387f0564b161ddbc4ce1b72bb3b3653e7f4f7a/pyodbc-5.3.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:6132554ffbd7910524d643f13ce17f4a72f3a6824b0adef4e9a7f66efac96350" }, { url = "https://mirrors.aliyun.com/pypi/packages/1b/33/2bb24e7fc95e98a7b11ea5ad1f256412de35d2e9cc339be198258c1d9a76/pyodbc-5.3.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:1629af4706e9228d79dabb4863c11cceb22a6dab90700db0ef449074f0150c0d" }, @@ -6050,7 +5679,6 @@ version = "25.1.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } dependencies = [ { name = "cryptography" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/04/8c/cd89ad05804f8e3c17dea8f178c3f40eeab5694c30e0c9f5bcd49f576fc3/pyopenssl-25.1.0.tar.gz", hash = "sha256:8d031884482e0c67ee92bf9a4d8cceb08d92aba7136432ffb0703c5280fc205b" } wheels = [ @@ -6153,7 +5781,6 @@ version = "1.3.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } dependencies = [ { name = "pytest" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/90/2c/8af215c0f776415f3590cac4f9086ccefd6fd463befeae41cd4d3f193e5a/pytest_asyncio-1.3.0.tar.gz", hash = "sha256:d7f52f36d231b80ee124cd216ffb19369aa168fc10095013c6b014a34d3ee9e5" } wheels = [ @@ -6221,20 +5848,6 @@ version = "0.6.2" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/01/18/e1e53ade001b30a3c6642d876e5defe8431da8c31fb7798909e6c8ab8c34/python_calamine-0.6.2.tar.gz", hash = "sha256:2c90e5224c5e92db9fcd8f22b6085ce63b935cfe7a893ac9a1c3c56793bafd9d" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/7a/ec/e111c1a3a4c138ebc41e416e33730ee6d7c54e714af21c2a4e59b41715a5/python_calamine-0.6.2-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:857e4cddadba9b55c76dc583c58c5dc101a6cd5320190c10f8b2ab98d66c9040" }, - { url = "https://mirrors.aliyun.com/pypi/packages/53/26/fe4c2138ff21542e2f1130a4d83c330d7f9486b62775196e998b88a03de6/python_calamine-0.6.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:cd89d6a53e4b22328cd685fc054c31d359cb3ae67bd24bc57e1c1db62a4cfc97" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6a/b0/bfeaf45ac5e2f6553723dd2fbe127d1d17c6f26496db5781de42a933776a/python_calamine-0.6.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4d6c9af39db39e0c70710ae79cd1b5d980f9c0aea55fc16d194460c1561a0c6a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/fe/6e/81106aa80609075015d400584030605b05f5e12931717160dcc58fdc4980/python_calamine-0.6.2-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9a2382dbc410dd48c99d89ee460662cc70892fe1b2901ab982604b923e8eb8f6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/4d/ba/6311b24f9889246be63b664630c5601039ef771f7ed04c8f51aace39b7a9/python_calamine-0.6.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3ebb93255709874ede5b5e62828cb5758e60097e5390b6c9a3eb7751b617b12e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/23/e4/027a1b046d30768872307ebe808dc4cdc5357295cdcda98b30b3ea924904/python_calamine-0.6.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:837bca19bd945cb83aded433f4cf76e80d70a5400404d876400ca7e88e5ea311" }, - { url = "https://mirrors.aliyun.com/pypi/packages/5c/4d/da8716a1b3a66938aaabe36873f6fa210fa063bab1b20c2ec236013de6b3/python_calamine-0.6.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:723990a47668cb819f307ccc634741370d3cd3804a0ee8cda392a522ae6d5016" }, - { url = "https://mirrors.aliyun.com/pypi/packages/36/40/9521e8da5496cbc4b18027626a40018301f546b3e9802ca2f3a6cb5b4739/python_calamine-0.6.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b067630d693e1d7de41e3d44a99c7dd3feebb52db8dda8636ac3f70d8b6a4ad6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cd/b0/7a63963512c5ba7e9539b7452e2b1561625e63e4e29c044e487e2e93dcbe/python_calamine-0.6.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:6ab09c9da53a2b33633e9f940aed11c08e083810a0fd6885826cdc52ba4f86a5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/22/81/e2bc38a5cf9629f656adcdabe8e134028f60c236e4bb96375dda90db3fdd/python_calamine-0.6.2-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:ae08e1308a0d0c6b8b4cc0a039ed8a85fc9ee2f8a3ca9ea57b1af9f97ed68fe4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b1/ea/513117015fd5903ca6dde9c8fb8502af60af6965642f4e3311623943e673/python_calamine-0.6.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:c441a20c7aff0e904ca01b5cdc1e5be2c6d4a41a24a0ea4d5ea6d211343bb95f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a2/14/8846478dacf31535f5f15448ade3bc688b51f3183f1b52844451aa27b0e6/python_calamine-0.6.2-cp312-cp312-win32.whl", hash = "sha256:39cae8e66f8bce499f5f965f4575ddf61e30184cc97f02e1c7031a57abe0903b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/bb/e2/2d2dcf4ec7e5ec08e33bf966ab010a7be178a4b623bd5f7601d47f2c734c/python_calamine-0.6.2-cp312-cp312-win_amd64.whl", hash = "sha256:1617efa24532f2420934a8cf77e6d33ff1740cae1d39355cab4f4cf141fdab49" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f1/eb/2f50f3395c0435e6186cab56c36d04c06581ba827264bca1f1acae523aa3/python_calamine-0.6.2-cp312-cp312-win_arm64.whl", hash = "sha256:c2b378db494740e540e8157a7e5fe61dadae69ad2d988a7c80f9583f434acf07" }, { url = "https://mirrors.aliyun.com/pypi/packages/15/db/f409c3ffa5d452b8184978c94440b48c933c79232c5e40fe9ce3608ff06d/python_calamine-0.6.2-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:4c6e68c233841604fa3f63899d13bd2e47cddf0787c4b4b8188f74c3be452045" }, { url = "https://mirrors.aliyun.com/pypi/packages/66/fe/8cf4309a00ad5628c45e69f13352d6a1e0e0a3148a2fc28d7a43a8cefec9/python_calamine-0.6.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0fd5bcbd904d05f8b9f127a93706fdbb0a5934efdc9677b402a82d91e6e3f920" }, { url = "https://mirrors.aliyun.com/pypi/packages/b8/cc/c5edfb89a99d19c66b029e2e6dc0db052709888753fc0a771bf28343c5e5/python_calamine-0.6.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cef6454aa1b3b2137d7a202c9f84b87dffdd187ff218f2cee459480c102c20a3" }, @@ -6388,9 +6001,6 @@ name = "pywin32" version = "311" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/e7/ab/01ea1943d4eba0f850c3c61e78e8dd59757ff815ff3ccd0a84de5f541f42/pywin32-311-cp312-cp312-win32.whl", hash = "sha256:750ec6e621af2b948540032557b10a2d43b0cee2ae9758c54154d711cc852d31" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d1/a8/a0e8d07d4d051ec7502cd58b291ec98dcc0c3fff027caad0470b72cfcc2f/pywin32-311-cp312-cp312-win_amd64.whl", hash = "sha256:b8c095edad5c211ff31c05223658e71bf7116daa0ecf3ad85f3201ea3190d067" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ba/3a/2ae996277b4b50f17d61f0603efd8253cb2d79cc7ae159468007b586396d/pywin32-311-cp312-cp312-win_arm64.whl", hash = "sha256:e286f46a9a39c4a18b319c28f59b61de793654af2f395c102b4f819e584b5852" }, { url = "https://mirrors.aliyun.com/pypi/packages/a5/be/3fd5de0979fcb3994bfee0d65ed8ca9506a8a1260651b86174f6a86f52b3/pywin32-311-cp313-cp313-win32.whl", hash = "sha256:f95ba5a847cba10dd8c4d8fefa9f2a6cf283b8b88ed6178fa8a6c1ab16054d0d" }, { url = "https://mirrors.aliyun.com/pypi/packages/e3/28/e0a1909523c6890208295a29e05c2adb2126364e289826c0a8bc7297bd5c/pywin32-311-cp313-cp313-win_amd64.whl", hash = "sha256:718a38f7e5b058e76aee1c56ddd06908116d35147e133427e59a3983f703a20d" }, { url = "https://mirrors.aliyun.com/pypi/packages/04/bf/90339ac0f55726dce7d794e6d79a18a91265bdf3aa70b6b9ca52f35e022a/pywin32-311-cp313-cp313-win_arm64.whl", hash = "sha256:7b4075d959648406202d92a2310cb990fea19b535c7f4a78d3f5e10b926eeb8a" }, @@ -6405,16 +6015,6 @@ version = "6.0.3" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/05/8e/961c0007c59b8dd7729d542c61a4d537767a59645b82a0b521206e1e25c2/pyyaml-6.0.3.tar.gz", hash = "sha256:d76623373421df22fb4cf8817020cbb7ef15c725b9d5e45f17e189bfc384190f" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/d1/33/422b98d2195232ca1826284a76852ad5a86fe23e31b009c9886b2d0fb8b2/pyyaml-6.0.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7f047e29dcae44602496db43be01ad42fc6f1cc0d8cd6c83d342306c32270196" }, - { url = "https://mirrors.aliyun.com/pypi/packages/89/a0/6cf41a19a1f2f3feab0e9c0b74134aa2ce6849093d5517a0c550fe37a648/pyyaml-6.0.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:fc09d0aa354569bc501d4e787133afc08552722d3ab34836a80547331bb5d4a0" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ed/23/7a778b6bd0b9a8039df8b1b1d80e2e2ad78aa04171592c8a5c43a56a6af4/pyyaml-6.0.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9149cad251584d5fb4981be1ecde53a1ca46c891a79788c0df828d2f166bda28" }, - { url = "https://mirrors.aliyun.com/pypi/packages/65/30/d7353c338e12baef4ecc1b09e877c1970bd3382789c159b4f89d6a70dc09/pyyaml-6.0.3-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:5fdec68f91a0c6739b380c83b951e2c72ac0197ace422360e6d5a959d8d97b2c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8b/9d/b3589d3877982d4f2329302ef98a8026e7f4443c765c46cfecc8858c6b4b/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:ba1cc08a7ccde2d2ec775841541641e4548226580ab850948cbfda66a1befcdc" }, - { url = "https://mirrors.aliyun.com/pypi/packages/05/c0/b3be26a015601b822b97d9149ff8cb5ead58c66f981e04fedf4e762f4bd4/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8dc52c23056b9ddd46818a57b78404882310fb473d63f17b07d5c40421e47f8e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/be/8e/98435a21d1d4b46590d5459a22d88128103f8da4c2d4cb8f14f2a96504e1/pyyaml-6.0.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:41715c910c881bc081f1e8872880d3c650acf13dfa8214bad49ed4cede7c34ea" }, - { url = "https://mirrors.aliyun.com/pypi/packages/74/93/7baea19427dcfbe1e5a372d81473250b379f04b1bd3c4c5ff825e2327202/pyyaml-6.0.3-cp312-cp312-win32.whl", hash = "sha256:96b533f0e99f6579b3d4d4995707cf36df9100d67e0c8303a0c55b27b5f99bc5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/86/bf/899e81e4cce32febab4fb42bb97dcdf66bc135272882d1987881a4b519e9/pyyaml-6.0.3-cp312-cp312-win_amd64.whl", hash = "sha256:5fcd34e47f6e0b794d17de1b4ff496c00986e1c83f7ab2fb8fcfe9616ff7477b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1a/08/67bd04656199bbb51dbed1439b7f27601dfb576fb864099c7ef0c3e55531/pyyaml-6.0.3-cp312-cp312-win_arm64.whl", hash = "sha256:64386e5e707d03a7e172c0701abfb7e10f0fb753ee1d773128192742712a98fd" }, { url = "https://mirrors.aliyun.com/pypi/packages/d1/11/0fd08f8192109f7169db964b5707a2f1e8b745d4e239b784a5a1dd80d1db/pyyaml-6.0.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:8da9669d359f02c0b91ccc01cac4a67f16afec0dac22c2ad09f46bee0697eba8" }, { url = "https://mirrors.aliyun.com/pypi/packages/b1/16/95309993f1d3748cd644e02e38b75d50cbc0d9561d21f390a76242ce073f/pyyaml-6.0.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:2283a07e2c21a2aa78d9c4442724ec1eb15f5e42a723b99cb3d822d48f5f7ad1" }, { url = "https://mirrors.aliyun.com/pypi/packages/50/31/b20f376d3f810b9b2371e72ef5adb33879b25edb7a6d072cb7ca0c486398/pyyaml-6.0.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:ee2922902c45ae8ccada2c5b501ab86c36525b883eff4255313a253a3160861c" }, @@ -6547,7 +6147,7 @@ wheels = [ [[package]] name = "ragflow" -version = "0.25.2" +version = "0.25.5" source = { virtual = "." } dependencies = [ { name = "agentrun-sdk" }, @@ -6558,10 +6158,12 @@ dependencies = [ { name = "arxiv" }, { name = "asana" }, { name = "atlassian-python-api" }, + { name = "audioop-lts" }, { name = "azure-identity" }, { name = "azure-storage-file-datalake" }, { name = "beartype" }, { name = "bio" }, + { name = "boto3" }, { name = "boxsdk" }, { name = "captcha" }, { name = "chardet" }, @@ -6577,6 +6179,7 @@ dependencies = [ { name = "duckduckgo-search" }, { name = "editdistance" }, { name = "elasticsearch-dsl" }, + { name = "en-core-web-sm" }, { name = "exceptiongroup" }, { name = "extract-msg" }, { name = "feedparser" }, @@ -6622,6 +6225,7 @@ dependencies = [ { name = "opendal" }, { name = "opensearch-py" }, { name = "ormsgpack" }, + { name = "paramiko" }, { name = "pdfplumber" }, { name = "peewee" }, { name = "pluginlib" }, @@ -6653,8 +6257,8 @@ dependencies = [ { name = "selenium-wire" }, { name = "slack-sdk" }, { name = "socksio" }, + { name = "spacy" }, { name = "sqlglotrs" }, - { name = "strenum" }, { name = "tavily-python" }, { name = "tencentcloud-sdk-python" }, { name = "tika" }, @@ -6688,8 +6292,6 @@ test = [ { name = "reportlab" }, { name = "requests" }, { name = "requests-toolbelt" }, - { name = "tensorflow-cpu", version = "2.18.0", source = { registry = "https://mirrors.aliyun.com/pypi/simple" }, marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "tensorflow-cpu", version = "2.18.1", source = { registry = "https://mirrors.aliyun.com/pypi/simple" }, marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, ] [package.metadata] @@ -6702,10 +6304,12 @@ requires-dist = [ { name = "arxiv", specifier = "==2.1.3" }, { name = "asana", specifier = ">=5.2.2" }, { name = "atlassian-python-api", specifier = "==4.0.7" }, + { name = "audioop-lts", specifier = ">=0.2.1" }, { name = "azure-identity", specifier = ">=1.25.3" }, { name = "azure-storage-file-datalake", specifier = "==12.16.0" }, { name = "beartype", specifier = ">=0.20.0,<1.0.0" }, { name = "bio", specifier = "==1.7.1" }, + { name = "boto3", specifier = ">=1.28.0" }, { name = "boxsdk", specifier = ">=10.1.0" }, { name = "captcha", specifier = ">=0.7.1" }, { name = "chardet", specifier = ">=5.2.0,<6.0.0" }, @@ -6721,6 +6325,7 @@ requires-dist = [ { name = "duckduckgo-search", specifier = ">=7.2.0,<8.0.0" }, { name = "editdistance", specifier = "==0.8.1" }, { name = "elasticsearch-dsl", specifier = "==8.12.0" }, + { name = "en-core-web-sm", url = "https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.8.0/en_core_web_sm-3.8.0-py3-none-any.whl" }, { name = "exceptiongroup", specifier = ">=1.3.0,<2.0.0" }, { name = "extract-msg", specifier = ">=0.39.0" }, { name = "feedparser", specifier = ">=6.0.11,<7.0.0" }, @@ -6740,7 +6345,7 @@ requires-dist = [ { name = "grpcio-status", specifier = "==1.67.1" }, { name = "html-text", specifier = "==0.6.2" }, { name = "infinity-emb", specifier = ">=0.0.66,<0.0.67" }, - { name = "infinity-sdk", specifier = "==0.7.0.dev6" }, + { name = "infinity-sdk", specifier = "==0.7.0" }, { name = "jira", specifier = "==3.10.5" }, { name = "json-repair", specifier = "==0.35.0" }, { name = "langfuse", specifier = ">=4.0.1" }, @@ -6765,7 +6370,8 @@ requires-dist = [ { name = "opencv-python-headless", specifier = "==4.10.0.84" }, { name = "opendal", specifier = ">=0.45.0,<0.46.0" }, { name = "opensearch-py", specifier = "==2.7.1" }, - { name = "ormsgpack", specifier = "==1.5.0" }, + { name = "ormsgpack", specifier = ">=1.5.0" }, + { name = "paramiko", specifier = ">=3.5.1" }, { name = "pdfplumber", specifier = "==0.10.4" }, { name = "peewee", specifier = ">=3.17.1,<4.0.0" }, { name = "pluginlib", specifier = ">=0.10.0" }, @@ -6797,8 +6403,8 @@ requires-dist = [ { name = "selenium-wire", specifier = "==5.1.0" }, { name = "slack-sdk", specifier = "==3.37.0" }, { name = "socksio", specifier = "==1.0.0" }, + { name = "spacy", specifier = "==3.8.14" }, { name = "sqlglotrs", specifier = "==0.9.0" }, - { name = "strenum", specifier = "==0.4.15" }, { name = "tavily-python", specifier = "==0.5.1" }, { name = "tencentcloud-sdk-python", specifier = "==3.0.1478" }, { name = "tika", specifier = "==2.6.0" }, @@ -6832,7 +6438,6 @@ test = [ { name = "reportlab", specifier = ">=4.4.1" }, { name = "requests", specifier = ">=2.32.2" }, { name = "requests-toolbelt", specifier = ">=1.0.0" }, - { name = "tensorflow-cpu", specifier = ">=2.17.0" }, ] [[package]] @@ -6910,7 +6515,6 @@ source = { registry = "https://mirrors.aliyun.com/pypi/simple" } dependencies = [ { name = "attrs" }, { name = "rpds-py" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/22/f5/df4e9027acead3ecc63e50fe1e36aca1523e1719559c499951bb4b53188f/referencing-0.37.0.tar.gz", hash = "sha256:44aefc3142c5b842538163acb373e24cce6632bd54bdb01b21ad5863489f50d8" } wheels = [ @@ -6923,22 +6527,6 @@ version = "2026.2.28" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/8b/71/41455aa99a5a5ac1eaf311f5d8efd9ce6433c03ac1e0962de163350d0d97/regex-2026.2.28.tar.gz", hash = "sha256:a729e47d418ea11d03469f321aaf67cdee8954cde3ff2cf8403ab87951ad10f2" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/07/42/9061b03cf0fc4b5fa2c3984cbbaed54324377e440a5c5a29d29a72518d62/regex-2026.2.28-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:fcf26c3c6d0da98fada8ae4ef0aa1c3405a431c0a77eb17306d38a89b02adcd7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/77/83/0c8a5623a233015595e3da499c5a1c13720ac63c107897a6037bb97af248/regex-2026.2.28-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:02473c954af35dd2defeb07e44182f5705b30ea3f351a7cbffa9177beb14da5d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9e/06/3ef1ac6910dc3295ebd71b1f9bfa737e82cfead211a18b319d45f85ddd09/regex-2026.2.28-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:9b65d33a17101569f86d9c5966a8b1d7fbf8afdda5a8aa219301b0a80f58cf7d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/dd/c9/8cc8d850b35ab5650ff6756a1cb85286e2000b66c97520b29c1587455344/regex-2026.2.28-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e71dcecaa113eebcc96622c17692672c2d104b1d71ddf7adeda90da7ddeb26fc" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e9/5d/57702597627fc23278ebf36fbb497ac91c0ce7fec89ac6c81e420ca3e38c/regex-2026.2.28-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:481df4623fa4969c8b11f3433ed7d5e3dc9cec0f008356c3212b3933fb77e3d8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/02/6d/f3ecad537ca2811b4d26b54ca848cf70e04fcfc138667c146a9f3157779c/regex-2026.2.28-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:64e7c6ad614573e0640f271e811a408d79a9e1fe62a46adb602f598df42a818d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9e/40/bb226f203caa22c1043c1ca79b36340156eca0f6a6742b46c3bb222a3a57/regex-2026.2.28-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:d6b08a06976ff4fb0d83077022fde3eca06c55432bb997d8c0495b9a4e9872f4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/44/7c/c6d91d8911ac6803b45ca968e8e500c46934e58c0903cbc6d760ee817a0a/regex-2026.2.28-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:864cdd1a2ef5716b0ab468af40139e62ede1b3a53386b375ec0786bb6783fc05" }, - { url = "https://mirrors.aliyun.com/pypi/packages/dc/8d/4a9368d168d47abd4158580b8c848709667b1cd293ff0c0c277279543bd0/regex-2026.2.28-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:511f7419f7afab475fd4d639d4aedfc54205bcb0800066753ef68a59f0f330b5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cc/bf/2c72ab5d8b7be462cb1651b5cc333da1d0068740342f350fcca3bca31947/regex-2026.2.28-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:b42f7466e32bf15a961cf09f35fa6323cc72e64d3d2c990b10de1274a5da0a59" }, - { url = "https://mirrors.aliyun.com/pypi/packages/7c/f4/6b65c979bb6d09f51bb2d2a7bc85de73c01ec73335d7ddd202dcb8cd1c8f/regex-2026.2.28-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:8710d61737b0c0ce6836b1da7109f20d495e49b3809f30e27e9560be67a257bf" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8e/32/29ea5e27400ee86d2cc2b4e80aa059df04eaf78b4f0c18576ae077aeff68/regex-2026.2.28-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:4390c365fd2d45278f45afd4673cb90f7285f5701607e3ad4274df08e36140ae" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1d/91/3233d03b5f865111cd517e1c95ee8b43e8b428d61fa73764a80c9bb6f537/regex-2026.2.28-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:cb3b1db8ff6c7b8bf838ab05583ea15230cb2f678e569ab0e3a24d1e8320940b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/76/92/abc706c1fb03b4580a09645b206a3fc032f5a9f457bc1a8038ac555658ab/regex-2026.2.28-cp312-cp312-win32.whl", hash = "sha256:f8ed9a5d4612df9d4de15878f0bc6aa7a268afbe5af21a3fdd97fa19516e978c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/fa/06/2a6f7dff190e5fa9df9fb4acf2fdf17a1aa0f7f54596cba8de608db56b3a/regex-2026.2.28-cp312-cp312-win_amd64.whl", hash = "sha256:01d65fd24206c8e1e97e2e31b286c59009636c022eb5d003f52760b0f42155d4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b7/f0/58a2484851fadf284458fdbd728f580d55c1abac059ae9f048c63b92f427/regex-2026.2.28-cp312-cp312-win_arm64.whl", hash = "sha256:c0b5ccbb8ffb433939d248707d4a8b31993cb76ab1a0187ca886bf50e96df952" }, { url = "https://mirrors.aliyun.com/pypi/packages/87/f6/dc9ef48c61b79c8201585bf37fa70cd781977da86e466cd94e8e95d2443b/regex-2026.2.28-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:6d63a07e5ec8ce7184452cb00c41c37b49e67dc4f73b2955b5b8e782ea970784" }, { url = "https://mirrors.aliyun.com/pypi/packages/95/c8/c20390f2232d3f7956f420f4ef1852608ad57aa26c3dd78516cb9f3dc913/regex-2026.2.28-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e59bc8f30414d283ae8ee1617b13d8112e7135cb92830f0ec3688cb29152585a" }, { url = "https://mirrors.aliyun.com/pypi/packages/d2/a6/ba1068a631ebd71a230e7d8013fcd284b7c89c35f46f34a7da02082141b1/regex-2026.2.28-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:de0cf053139f96219ccfabb4a8dd2d217c8c82cb206c91d9f109f3f552d6b43d" }, @@ -7128,21 +6716,6 @@ version = "0.30.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/20/af/3f2f423103f1113b36230496629986e0ef7e199d2aa8392452b484b38ced/rpds_py-0.30.0.tar.gz", hash = "sha256:dd8ff7cf90014af0c0f787eea34794ebf6415242ee1d6fa91eaba725cc441e84" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/03/e7/98a2f4ac921d82f33e03f3835f5bf3a4a40aa1bfdc57975e74a97b2b4bdd/rpds_py-0.30.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:a161f20d9a43006833cd7068375a94d035714d73a172b681d8881820600abfad" }, - { url = "https://mirrors.aliyun.com/pypi/packages/4d/a1/bca7fd3d452b272e13335db8d6b0b3ecde0f90ad6f16f3328c6fb150c889/rpds_py-0.30.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6abc8880d9d036ecaafe709079969f56e876fcf107f7a8e9920ba6d5a3878d05" }, - { url = "https://mirrors.aliyun.com/pypi/packages/65/1c/ae157e83a6357eceff62ba7e52113e3ec4834a84cfe07fa4b0757a7d105f/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ca28829ae5f5d569bb62a79512c842a03a12576375d5ece7d2cadf8abe96ec28" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d4/36/eb2eb8515e2ad24c0bd43c3ee9cd74c33f7ca6430755ccdb240fd3144c44/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:a1010ed9524c73b94d15919ca4d41d8780980e1765babf85f9a2f90d247153dd" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d6/65/ad8dc1784a331fabbd740ef6f71ce2198c7ed0890dab595adb9ea2d775a1/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f8d1736cfb49381ba528cd5baa46f82fdc65c06e843dab24dd70b63d09121b3f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/63/8e/0cfa7ae158e15e143fe03993b5bcd743a59f541f5952e1546b1ac1b5fd45/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d948b135c4693daff7bc2dcfc4ec57237a29bd37e60c2fabf5aff2bbacf3e2f1" }, - { url = "https://mirrors.aliyun.com/pypi/packages/60/1b/6f8f29f3f995c7ffdde46a626ddccd7c63aefc0efae881dc13b6e5d5bb16/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47f236970bccb2233267d89173d3ad2703cd36a0e2a6e92d0560d333871a3d23" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6d/d5/a266341051a7a3ca2f4b750a3aa4abc986378431fc2da508c5034d081b70/rpds_py-0.30.0-cp312-cp312-manylinux_2_31_riscv64.whl", hash = "sha256:2e6ecb5a5bcacf59c3f912155044479af1d0b6681280048b338b28e364aca1f6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/10/3b/71b725851df9ab7a7a4e33cf36d241933da66040d195a84781f49c50490c/rpds_py-0.30.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:a8fa71a2e078c527c3e9dc9fc5a98c9db40bcc8a92b4e8858e36d329f8684b51" }, - { url = "https://mirrors.aliyun.com/pypi/packages/00/2b/e59e58c544dc9bd8bd8384ecdb8ea91f6727f0e37a7131baeff8d6f51661/rpds_py-0.30.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:73c67f2db7bc334e518d097c6d1e6fed021bbc9b7d678d6cc433478365d1d5f5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/da/3e/a18e6f5b460893172a7d6a680e86d3b6bc87a54c1f0b03446a3c8c7b588f/rpds_py-0.30.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:5ba103fb455be00f3b1c2076c9d4264bfcb037c976167a6047ed82f23153f02e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/5c/e2/714694e4b87b85a18e2c243614974413c60aa107fd815b8cbc42b873d1d7/rpds_py-0.30.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:7cee9c752c0364588353e627da8a7e808a66873672bcb5f52890c33fd965b394" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6f/ab/d5d5e3bcedb0a77f4f613706b750e50a5a3ba1c15ccd3665ecc636c968fd/rpds_py-0.30.0-cp312-cp312-win32.whl", hash = "sha256:1ab5b83dbcf55acc8b08fc62b796ef672c457b17dbd7820a11d6c52c06839bdf" }, - { url = "https://mirrors.aliyun.com/pypi/packages/39/3b/f786af9957306fdc38a74cef405b7b93180f481fb48453a114bb6465744a/rpds_py-0.30.0-cp312-cp312-win_amd64.whl", hash = "sha256:a090322ca841abd453d43456ac34db46e8b05fd9b3b4ac0c78bcde8b089f959b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f3/d2/b91dc748126c1559042cfe41990deb92c4ee3e2b415f6b5234969ffaf0cc/rpds_py-0.30.0-cp312-cp312-win_arm64.whl", hash = "sha256:669b1805bd639dd2989b281be2cfd951c6121b65e729d9b843e9639ef1fd555e" }, { url = "https://mirrors.aliyun.com/pypi/packages/ed/dc/d61221eb88ff410de3c49143407f6f3147acf2538c86f2ab7ce65ae7d5f9/rpds_py-0.30.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:f83424d738204d9770830d35290ff3273fbb02b41f919870479fab14b9d303b2" }, { url = "https://mirrors.aliyun.com/pypi/packages/fd/32/55fb50ae104061dbc564ef15cc43c013dc4a9f4527a1f4d99baddf56fe5f/rpds_py-0.30.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e7536cd91353c5273434b4e003cbda89034d67e7710eab8761fd918ec6c69cf8" }, { url = "https://mirrors.aliyun.com/pypi/packages/58/70/faed8186300e3b9bdd138d0273109784eea2396c68458ed580f885dfe7ad/rpds_py-0.30.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2771c6c15973347f50fece41fc447c054b7ac2ae0502388ce3b6738cd366e3d4" }, @@ -7259,16 +6832,6 @@ version = "0.2.15" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/ea/97/60fda20e2fb54b83a61ae14648b0817c8f5d84a3821e40bfbdae1437026a/ruamel_yaml_clib-0.2.15.tar.gz", hash = "sha256:46e4cc8c43ef6a94885f72512094e482114a8a706d3c555a34ed4b0d20200600" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/72/4b/5fde11a0722d676e469d3d6f78c6a17591b9c7e0072ca359801c4bd17eee/ruamel_yaml_clib-0.2.15-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:cb15a2e2a90c8475df45c0949793af1ff413acfb0a716b8b94e488ea95ce7cff" }, - { url = "https://mirrors.aliyun.com/pypi/packages/85/82/4d08ac65ecf0ef3b046421985e66301a242804eb9a62c93ca3437dc94ee0/ruamel_yaml_clib-0.2.15-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:64da03cbe93c1e91af133f5bec37fd24d0d4ba2418eaf970d7166b0a26a148a2" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b9/cb/22366d68b280e281a932403b76da7a988108287adff2bfa5ce881200107a/ruamel_yaml_clib-0.2.15-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:f6d3655e95a80325b84c4e14c080b2470fe4f33b6846f288379ce36154993fb1" }, - { url = "https://mirrors.aliyun.com/pypi/packages/71/73/81230babf8c9e33770d43ed9056f603f6f5f9665aea4177a2c30ae48e3f3/ruamel_yaml_clib-0.2.15-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:71845d377c7a47afc6592aacfea738cc8a7e876d586dfba814501d8c53c1ba60" }, - { url = "https://mirrors.aliyun.com/pypi/packages/61/62/150c841f24cda9e30f588ef396ed83f64cfdc13b92d2f925bb96df337ba9/ruamel_yaml_clib-0.2.15-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11e5499db1ccbc7f4b41f0565e4f799d863ea720e01d3e99fa0b7b5fcd7802c9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/30/93/e79bd9cbecc3267499d9ead919bd61f7ddf55d793fb5ef2b1d7d92444f35/ruamel_yaml_clib-0.2.15-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4b293a37dc97e2b1e8a1aec62792d1e52027087c8eea4fc7b5abd2bdafdd6642" }, - { url = "https://mirrors.aliyun.com/pypi/packages/8d/06/1eb640065c3a27ce92d76157f8efddb184bd484ed2639b712396a20d6dce/ruamel_yaml_clib-0.2.15-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:512571ad41bba04eac7268fe33f7f4742210ca26a81fe0c75357fa682636c690" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a5/21/ee353e882350beab65fcc47a91b6bdc512cace4358ee327af2962892ff16/ruamel_yaml_clib-0.2.15-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e5e9f630c73a490b758bf14d859a39f375e6999aea5ddd2e2e9da89b9953486a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/57/34/cc1b94057aa867c963ecf9ea92ac59198ec2ee3a8d22a126af0b4d4be712/ruamel_yaml_clib-0.2.15-cp312-cp312-win32.whl", hash = "sha256:f4421ab780c37210a07d138e56dd4b51f8642187cdfb433eb687fe8c11de0144" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b3/e5/8925a4208f131b218f9a7e459c0d6fcac8324ae35da269cb437894576366/ruamel_yaml_clib-0.2.15-cp312-cp312-win_amd64.whl", hash = "sha256:2b216904750889133d9222b7b873c199d48ecbb12912aca78970f84a5aa1a4bc" }, { url = "https://mirrors.aliyun.com/pypi/packages/17/5e/2f970ce4c573dc30c2f95825f2691c96d55560268ddc67603dc6ea2dd08e/ruamel_yaml_clib-0.2.15-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4dcec721fddbb62e60c2801ba08c87010bd6b700054a09998c4d09c08147b8fb" }, { url = "https://mirrors.aliyun.com/pypi/packages/d6/03/a1baa5b94f71383913f21b96172fb3a2eb5576a4637729adbf7cd9f797f8/ruamel_yaml_clib-0.2.15-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:65f48245279f9bb301d1276f9679b82e4c080a1ae25e679f682ac62446fac471" }, { url = "https://mirrors.aliyun.com/pypi/packages/dc/19/40d676802390f85784235a05788fd28940923382e3f8b943d25febbb98b7/ruamel_yaml_clib-0.2.15-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:46895c17ead5e22bea5e576f1db7e41cb273e8d062c04a6a49013d9f60996c25" }, @@ -7338,12 +6901,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/0e/d4/40988bf3b8e34feec1d0e6a051446b1f66225f8529b9309becaeef62b6c4/scikit_learn-1.8.0.tar.gz", hash = "sha256:9bccbb3b40e3de10351f8f5068e105d0f4083b1a65fa07b6634fbc401a6287fd" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/90/74/e6a7cc4b820e95cc38cf36cd74d5aa2b42e8ffc2d21fe5a9a9c45c1c7630/scikit_learn-1.8.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:5fb63362b5a7ddab88e52b6dbb47dac3fd7dafeee740dc6c8d8a446ddedade8e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/49/d8/9be608c6024d021041c7f0b3928d4749a706f4e2c3832bbede4fb4f58c95/scikit_learn-1.8.0-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:5025ce924beccb28298246e589c691fe1b8c1c96507e6d27d12c5fadd85bfd76" }, - { url = "https://mirrors.aliyun.com/pypi/packages/dd/47/f187b4636ff80cc63f21cd40b7b2d177134acaa10f6bb73746130ee8c2e5/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4496bb2cf7a43ce1a2d7524a79e40bc5da45cf598dbf9545b7e8316ccba47bb4" }, - { url = "https://mirrors.aliyun.com/pypi/packages/97/74/b7a304feb2b49df9fafa9382d4d09061a96ee9a9449a7cbea7988dda0828/scikit_learn-1.8.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a0bcfe4d0d14aec44921545fd2af2338c7471de9cb701f1da4c9d85906ab847a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9f/c4/0ab22726a04ede56f689476b760f98f8f46607caecff993017ac1b64aa5d/scikit_learn-1.8.0-cp312-cp312-win_amd64.whl", hash = "sha256:35c007dedb2ffe38fe3ee7d201ebac4a2deccd2408e8621d53067733e3c74809" }, - { url = "https://mirrors.aliyun.com/pypi/packages/24/90/344a67811cfd561d7335c1b96ca21455e7e472d281c3c279c4d3f2300236/scikit_learn-1.8.0-cp312-cp312-win_arm64.whl", hash = "sha256:8c497fff237d7b4e07e9ef1a640887fa4fb765647f86fbe00f969ff6280ce2bb" }, { url = "https://mirrors.aliyun.com/pypi/packages/03/aa/e22e0768512ce9255eba34775be2e85c2048da73da1193e841707f8f039c/scikit_learn-1.8.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:0d6ae97234d5d7079dc0040990a6f7aeb97cb7fa7e8945f1999a429b23569e0a" }, { url = "https://mirrors.aliyun.com/pypi/packages/58/37/31b83b2594105f61a381fc74ca19e8780ee923be2d496fcd8d2e1147bd99/scikit_learn-1.8.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:edec98c5e7c128328124a029bceb09eda2d526997780fef8d65e9a69eead963e" }, { url = "https://mirrors.aliyun.com/pypi/packages/2d/5a/3f1caed8765f33eabb723596666da4ebbf43d11e96550fb18bdec42b467b/scikit_learn-1.8.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:74b66d8689d52ed04c271e1329f0c61635bcaf5b926db9b12d58914cdc01fe57" }, @@ -7379,16 +6936,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/7a/97/5a3609c4f8d58b039179648e62dd220f89864f56f7357f5d4f45c29eb2cc/scipy-1.17.1.tar.gz", hash = "sha256:95d8e012d8cb8816c226aef832200b1d45109ed4464303e997c5b13122b297c0" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/35/48/b992b488d6f299dbe3f11a20b24d3dda3d46f1a635ede1c46b5b17a7b163/scipy-1.17.1-cp312-cp312-macosx_10_14_x86_64.whl", hash = "sha256:35c3a56d2ef83efc372eaec584314bd0ef2e2f0d2adb21c55e6ad5b344c0dcb8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b2/02/cf107b01494c19dc100f1d0b7ac3cc08666e96ba2d64db7626066cee895e/scipy-1.17.1-cp312-cp312-macosx_12_0_arm64.whl", hash = "sha256:fcb310ddb270a06114bb64bbe53c94926b943f5b7f0842194d585c65eb4edd76" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cf/a9/599c28631bad314d219cf9ffd40e985b24d603fc8a2f4ccc5ae8419a535b/scipy-1.17.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:cc90d2e9c7e5c7f1a482c9875007c095c3194b1cfedca3c2f3291cdc2bc7c086" }, - { url = "https://mirrors.aliyun.com/pypi/packages/35/f5/906eda513271c8deb5af284e5ef0206d17a96239af79f9fa0aebfe0e36b4/scipy-1.17.1-cp312-cp312-macosx_14_0_x86_64.whl", hash = "sha256:c80be5ede8f3f8eded4eff73cc99a25c388ce98e555b17d31da05287015ffa5b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/da/34/16f10e3042d2f1d6b66e0428308ab52224b6a23049cb2f5c1756f713815f/scipy-1.17.1-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e19ebea31758fac5893a2ac360fedd00116cbb7628e650842a6691ba7ca28a21" }, - { url = "https://mirrors.aliyun.com/pypi/packages/01/8e/1e35281b8ab6d5d72ebe9911edcdffa3f36b04ed9d51dec6dd140396e220/scipy-1.17.1-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:02ae3b274fde71c5e92ac4d54bc06c42d80e399fec704383dcd99b301df37458" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c5/5c/9d7f4c88bea6e0d5a4f1bc0506a53a00e9fcb198de372bfe4d3652cef482/scipy-1.17.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:8a604bae87c6195d8b1045eddece0514d041604b14f2727bbc2b3020172045eb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/65/94/7698add8f276dbab7a9de9fb6b0e02fc13ee61d51c7c3f85ac28b65e1239/scipy-1.17.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:f590cd684941912d10becc07325a3eeb77886fe981415660d9265c4c418d0bea" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a2/84/dc08d77fbf3d87d3ee27f6a0c6dcce1de5829a64f2eae85a0ecc1f0daa73/scipy-1.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:41b71f4a3a4cab9d366cd9065b288efc4d4f3c0b37a91a8e0947fb5bd7f31d87" }, - { url = "https://mirrors.aliyun.com/pypi/packages/bc/98/fe9ae9ffb3b54b62559f52dedaebe204b408db8109a8c66fdd04869e6424/scipy-1.17.1-cp312-cp312-win_arm64.whl", hash = "sha256:f4115102802df98b2b0db3cce5cb9b92572633a1197c77b7553e5203f284a5b3" }, { url = "https://mirrors.aliyun.com/pypi/packages/76/27/07ee1b57b65e92645f219b37148a7e7928b82e2b5dbeccecb4dff7c64f0b/scipy-1.17.1-cp313-cp313-macosx_10_14_x86_64.whl", hash = "sha256:5e3c5c011904115f88a39308379c17f91546f77c1667cea98739fe0fccea804c" }, { url = "https://mirrors.aliyun.com/pypi/packages/ec/ae/db19f8ab842e9b724bf5dbb7db29302a91f1e55bc4d04b1025d6d605a2c5/scipy-1.17.1-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:6fac755ca3d2c3edcb22f479fceaa241704111414831ddd3bc6056e18516892f" }, { url = "https://mirrors.aliyun.com/pypi/packages/5b/58/3ce96251560107b381cbd6e8413c483bbb1228a6b919fa8652b0d4090e7f/scipy-1.17.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:7ff200bf9d24f2e4d5dc6ee8c3ac64d739d3a89e2326ba68aaf6c4a2b838fd7d" }, @@ -7511,14 +7058,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/4d/bc/0989043118a27cccb4e906a46b7565ce36ca7b57f5a18b78f4f1b0f72d9d/shapely-2.1.2.tar.gz", hash = "sha256:2ed4ecb28320a433db18a5bf029986aa8afcfd740745e78847e330d5d94922a9" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/24/c0/f3b6453cf2dfa99adc0ba6675f9aaff9e526d2224cbd7ff9c1a879238693/shapely-2.1.2-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fe2533caae6a91a543dec62e8360fe86ffcdc42a7c55f9dfd0128a977a896b94" }, - { url = "https://mirrors.aliyun.com/pypi/packages/86/07/59dee0bc4b913b7ab59ab1086225baca5b8f19865e6101db9ebb7243e132/shapely-2.1.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ba4d1333cc0bc94381d6d4308d2e4e008e0bd128bdcff5573199742ee3634359" }, - { url = "https://mirrors.aliyun.com/pypi/packages/26/29/a5397e75b435b9895cd53e165083faed5d12fd9626eadec15a83a2411f0f/shapely-2.1.2-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:0bd308103340030feef6c111d3eb98d50dc13feea33affc8a6f9fa549e9458a3" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b9/37/e781683abac55dde9771e086b790e554811a71ed0b2b8a1e789b7430dd44/shapely-2.1.2-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1e7d4d7ad262a48bb44277ca12c7c78cb1b0f56b32c10734ec9a1d30c0b0c54b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d8/f3/9876b64d4a5a321b9dc482c92bb6f061f2fa42131cba643c699f39317cb9/shapely-2.1.2-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:e9eddfe513096a71896441a7c37db72da0687b34752c4e193577a145c71736fc" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d1/a0/704c7292f7014c7e74ec84eddb7b109e1fbae74a16deae9c1504b1d15565/shapely-2.1.2-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:980c777c612514c0cf99bc8a9de6d286f5e186dcaf9091252fcd444e5638193d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/53/46/319c9dc788884ad0785242543cdffac0e6530e4d0deb6c4862bc4143dcf3/shapely-2.1.2-cp312-cp312-win32.whl", hash = "sha256:9111274b88e4d7b54a95218e243282709b330ef52b7b86bc6aaf4f805306f454" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ec/bf/cb6c1c505cb31e818e900b9312d514f381fbfa5c4363edfce0fcc4f8c1a4/shapely-2.1.2-cp312-cp312-win_amd64.whl", hash = "sha256:743044b4cfb34f9a67205cee9279feaf60ba7d02e69febc2afc609047cb49179" }, { url = "https://mirrors.aliyun.com/pypi/packages/c3/90/98ef257c23c46425dc4d1d31005ad7c8d649fe423a38b917db02c30f1f5a/shapely-2.1.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:b510dda1a3672d6879beb319bc7c5fd302c6c354584690973c838f46ec3e0fa8" }, { url = "https://mirrors.aliyun.com/pypi/packages/6d/ab/0bee5a830d209adcd3a01f2d4b70e587cdd9fd7380d5198c064091005af8/shapely-2.1.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:8cff473e81017594d20ec55d86b54bc635544897e13a7cfc12e36909c5309a2a" }, { url = "https://mirrors.aliyun.com/pypi/packages/2d/5e/7d7f54ba960c13302584c73704d8c4d15404a51024631adb60b126a4ae88/shapely-2.1.2-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:fe7b77dc63d707c09726b7908f575fc04ff1d1ad0f3fb92aec212396bc6cfe5e" }, @@ -7637,6 +7176,59 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/46/2c/1462b1d0a634697ae9e55b3cecdcb64788e8b7d63f54d923fcd0bb140aed/soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95" }, ] +[[package]] +name = "spacy" +version = "3.8.14" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "catalogue" }, + { name = "confection" }, + { name = "cymem" }, + { name = "jinja2" }, + { name = "murmurhash" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "preshed" }, + { name = "pydantic" }, + { name = "requests" }, + { name = "setuptools" }, + { name = "spacy-legacy" }, + { name = "spacy-loggers" }, + { name = "srsly" }, + { name = "thinc" }, + { name = "tqdm" }, + { name = "typer" }, + { name = "wasabi" }, + { name = "weasel" }, +] +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/1b/e5/822bbdfa459fee863ef2e9879a34b0ae5db7cd1e3eb76d32c766f19222e9/spacy-3.8.14-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:2b4f60fa8b9641a5e93e7a96db0cdd106d05d61756bf1d0ddcd1705ad347909a" }, + { url = "https://mirrors.aliyun.com/pypi/packages/7e/de/0e512154113e1f341567f2b9341835775e4180c180221e60faedaebb2f65/spacy-3.8.14-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:0860c57220c633ccb20468bcd64bfb0d28908990c371a8857951d093a148dc8e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/0c/4f/29c7e56afc7db07348a9e0efe0243b5eef465d5dc3d56433f164378c3fa6/spacy-3.8.14-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c24620b7dba879c69cebc51ef3b1107d4d4e44a1e0d4baa439372887d00c3fd9" }, + { url = "https://mirrors.aliyun.com/pypi/packages/1e/ce/cae678f664d5467016819253f5d6e52f8e68a12d8e799b651d73ec2a9a4b/spacy-3.8.14-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9699c1248d115d5825987c287a6f6acd66386ef3ebee7994ee67ba093e932c59" }, + { url = "https://mirrors.aliyun.com/pypi/packages/04/d4/419868afd449bdd367df005932537eea66c71e97c899ba278f3124933f3c/spacy-3.8.14-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:042d799e342fdb6bb5b02a4213a95acc9116c40ed3c849bb0a8296fbe648ec22" }, + { url = "https://mirrors.aliyun.com/pypi/packages/ec/53/df5c1fee45f200b749ba72eeb536fbb2c545fc56230324954263b2f3be00/spacy-3.8.14-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:69b2264294097336e86832e8663f1ab3a7215621184863c96c082ab17ee11937" }, + { url = "https://mirrors.aliyun.com/pypi/packages/12/c2/f1882ec2f5cc9c4e73cf2132997a03c397d7ceeb5ee7f7bb878b51a16365/spacy-3.8.14-cp313-cp313-win_amd64.whl", hash = "sha256:4b6d4f20e291a7c70e37de2f246622b44a0ce82efaa710c9801c6bd599e75177" }, +] + +[[package]] +name = "spacy-legacy" +version = "3.0.12" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/d9/79/91f9d7cc8db5642acad830dcc4b49ba65a7790152832c4eceb305e46d681/spacy-legacy-3.0.12.tar.gz", hash = "sha256:b37d6e0c9b6e1d7ca1cf5bc7152ab64a4c4671f59c85adaf7a3fcb870357a774" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/c3/55/12e842c70ff8828e34e543a2c7176dac4da006ca6901c9e8b43efab8bc6b/spacy_legacy-3.0.12-py2.py3-none-any.whl", hash = "sha256:476e3bd0d05f8c339ed60f40986c07387c0a71479245d6d0f4298dbd52cda55f" }, +] + +[[package]] +name = "spacy-loggers" +version = "1.0.5" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/67/3d/926db774c9c98acf66cb4ed7faf6c377746f3e00b84b700d0868b95d0712/spacy-loggers-1.0.5.tar.gz", hash = "sha256:d60b0bdbf915a60e516cc2e653baeff946f0cfc461b452d11a4d5458c6fe5f24" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/33/78/d1a1a026ef3af911159398c939b1509d5c36fe524c7b644f34a5146c4e16/spacy_loggers-1.0.5-py3-none-any.whl", hash = "sha256:196284c9c446cc0cdb944005384270d775fdeaf4f494d8e269466cfa497ef645" }, +] + [[package]] name = "sphinx" version = "9.1.0" @@ -7755,13 +7347,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/1f/73/b4a9737255583b5fa858e0bb8e116eb94b88c910164ed2ed719147bde3de/sqlalchemy-2.0.48.tar.gz", hash = "sha256:5ca74f37f3369b45e1f6b7b06afb182af1fd5dde009e4ffd831830d98cbe5fe7" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/ef/91/a42ae716f8925e9659df2da21ba941f158686856107a61cc97a95e7647a3/sqlalchemy-2.0.48-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:348174f228b99f33ca1f773e85510e08927620caa59ffe7803b37170df30332b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b9/52/f75f516a1f3888f027c1cfb5d22d4376f4b46236f2e8669dcb0cddc60275/sqlalchemy-2.0.48-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:53667b5f668991e279d21f94ccfa6e45b4e3f4500e7591ae59a8012d0f010dcb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/37/9a/0c28b6371e0cdcb14f8f1930778cb3123acfcbd2c95bb9cf6b4a2ba0cce3/sqlalchemy-2.0.48-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:34634e196f620c7a61d18d5cf7dc841ca6daa7961aed75d532b7e58b309ac894" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1c/46/0aee8f3ff20b1dcbceb46ca2d87fcc3d48b407925a383ff668218509d132/sqlalchemy-2.0.48-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:546572a1793cc35857a2ffa1fe0e58571af1779bcc1ffa7c9fb0839885ed69a9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ce/8c/a957bc91293b49181350bfd55e6dfc6e30b7f7d83dc6792d72043274a390/sqlalchemy-2.0.48-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:07edba08061bc277bfdc772dd2a1a43978f5a45994dd3ede26391b405c15221e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/4b/44/1d257d9f9556661e7bdc83667cc414ba210acfc110c82938cb3611eea58f/sqlalchemy-2.0.48-cp312-cp312-win32.whl", hash = "sha256:908a3fa6908716f803b86896a09a2c4dde5f5ce2bb07aacc71ffebb57986ce99" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f2/af/c3c7e1f3a2b383155a16454df62ae8c62a30dd238e42e68c24cebebbfae6/sqlalchemy-2.0.48-cp312-cp312-win_amd64.whl", hash = "sha256:68549c403f79a8e25984376480959975212a670405e3913830614432b5daa07a" }, { url = "https://mirrors.aliyun.com/pypi/packages/d1/c6/569dc8bf3cd375abc5907e82235923e986799f301cd79a903f784b996fca/sqlalchemy-2.0.48-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e3070c03701037aa418b55d36532ecb8f8446ed0135acb71c678dbdf12f5b6e4" }, { url = "https://mirrors.aliyun.com/pypi/packages/6d/ff/f4e04a4bd5a24304f38cb0d4aa2ad4c0fb34999f8b884c656535e1b2b74c/sqlalchemy-2.0.48-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2645b7d8a738763b664a12a1542c89c940daa55196e8d73e55b169cc5c99f65f" }, { url = "https://mirrors.aliyun.com/pypi/packages/fe/88/cb59509e4668d8001818d7355d9995be90c321313078c912420603a7cb95/sqlalchemy-2.0.48-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b19151e76620a412c2ac1c6f977ab1b9fa7ad43140178345136456d5265b32ed" }, @@ -7811,16 +7396,6 @@ version = "0.9.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/e0/c1/de7ee4729d49d15339717d6c4cc9aac06382c1161a8212dfdd266d51ffe5/sqlglotrs-0.9.0.tar.gz", hash = "sha256:72f61561d63607a8d88f5da608c11e21b2a57773ca631e6b89a4eed668da2db5" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/28/a2/c898fe0dffea8ea988fdd7a15bdb414488eca2f9c7def679bf69c490a0f6/sqlglotrs-0.9.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:1ae7b3b1fedd7b99f6a2c7d7ad1f2b23e433d69ed6e2a5ededa26fc9d74da626" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a8/17/344e5e600b61d177a7e535f078f04466097666120059a4a016d21fa1290c/sqlglotrs-0.9.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:938723a4ee7647f2a858ac581ac6cbbfe40320b843f9826f6b0d204579781466" }, - { url = "https://mirrors.aliyun.com/pypi/packages/da/0f/39d33a403416dc608c0dba31f1b8be5c6476ab7795043e73be4350974adf/sqlglotrs-0.9.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:816cdd9b5838c4df5c5206180508a87e6f2ef1860f9bc4655c8125257ef51484" }, - { url = "https://mirrors.aliyun.com/pypi/packages/39/c9/9971b2dd27c9781bec09c5c29676bf0c70cbf0345f1bc4c2315c1fcf68ab/sqlglotrs-0.9.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:080d58c906673c8905965af640cab16203b1e991f8f52a468c371e5f75b1ea04" }, - { url = "https://mirrors.aliyun.com/pypi/packages/bb/8b/3f61abd5844b65cab7085e4c9af3af0e01f7a21e9786125498d901a87a40/sqlglotrs-0.9.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e5241de862190e0c01830833d42bc58a479821d8bd07c51f1e74b5bddc0eb51b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ea/58/bd10f0ebd55f4d043922792dc1eb4b55ecbe9be323e749cd40586d3d6b0f/sqlglotrs-0.9.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:830198b4de0409e07fa82d2d515cb3b6f8e9627a966aacceb2c538e2bd4d2ceb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/60/34/7d2972e0c41747296b1ff29a671eac7ae6584cd1e29c012edbc4082b7ca7/sqlglotrs-0.9.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61011f8b28cb4b23abcc780c6a622aacd6b7acc546363c24501891e29a1950c7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/06/ce/37cf36d3765ecea1e5d22b1f107a3022ae5032bf319f805f3b918abdddeb/sqlglotrs-0.9.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:78eed1e668109ebc61771c0163bf9ff2d8073eea24034ba012edf71ba0759bf0" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b6/a6/faea946e386e29f066a476cbcadc091369ac356f9b24b3e2c7e539d8800b/sqlglotrs-0.9.0-cp312-cp312-win32.whl", hash = "sha256:136a5001e43401b81b678e6f3433edc317cba08af3e7098e0228deef87f23562" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2b/e2/9264dd3b2a4369fbcb7b911f5ddaa0bed73ab5ae2d910b4fa14b0f56879e/sqlglotrs-0.9.0-cp312-cp312-win_amd64.whl", hash = "sha256:b1c54ed249f16676fe8270738c8f05f08b1516d8b2975387b45bd67aa6f3b3a5" }, { url = "https://mirrors.aliyun.com/pypi/packages/f3/27/6d42c98f2f33fc6dbbc7d669bf99ea6f7898d8bcd0aaf87aa1a4c96cc9c9/sqlglotrs-0.9.0-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:e2a5a697dbfc9cfee5434433a4d698a26df94277e0916bbfc25e1e72436cd0c0" }, { url = "https://mirrors.aliyun.com/pypi/packages/50/53/d1f8f42ec14d69d8ba249036d83dcb4d6b51fe5b3ddb357499c737ae2a99/sqlglotrs-0.9.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:3a22d4064e923bbe07750f6e4b4b338e5b9fa0cbc2073bd503cc4b1c9280c2ac" }, { url = "https://mirrors.aliyun.com/pypi/packages/52/e0/a2aa5e533427af4b64f9a630000cfee3cbbf877f58dcd79bb931963adf8a/sqlglotrs-0.9.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2fbf6f211d4b0d091855984279be7a9d57b89a43db07aeaf6cabee075c08ac80" }, @@ -7843,6 +7418,41 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/ab/e3/5b7b4bb702691630d5b1f72470cdcfd8220bf32bc3ed9514af59904186bd/sqlglotrs-0.9.0-cp314-cp314-win_amd64.whl", hash = "sha256:41c8606a13a7284216dd3649521e0fe402e660f5e48acac6acf0facaa676d0bb" }, ] +[[package]] +name = "srsly" +version = "2.5.3" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "catalogue" }, +] +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/2b/db/f794f219a6c788b881252d2536a8c4a97d2bdaadc690391e1cb53d123d71/srsly-2.5.3.tar.gz", hash = "sha256:08f98dbecbff3a31466c4ae7c833131f59d3655a0ad8ac749e6e2c149e2b0680" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/9d/5c/12901e3794f4158abc6da750725aad6c2afddb1e4227b300fe7c71f66957/srsly-2.5.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:e67b6bbacbfadea5e100266d2797f2d4cec9883ea4dc84a5537673850036a8d8" }, + { url = "https://mirrors.aliyun.com/pypi/packages/04/61/181c26370995f96f56f1b64b801e3ca1e0d703fc36506ae28606d62369fb/srsly-2.5.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:348c231b4477d8fe86603131d0f166d2feac9c372704dfc4398be71cc5b6fb07" }, + { url = "https://mirrors.aliyun.com/pypi/packages/77/c6/35876c78889f8ffe11ed3521644e666c3aef20ea31527b70f47456cf35c2/srsly-2.5.3-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:b0938c2978c91ae1ef9c1f2ba35abb86330e198fb23469e356eba311e02233ee" }, + { url = "https://mirrors.aliyun.com/pypi/packages/3e/da/40b71ca9906c8eb8f8feb6ac11d33dad458c85a56e1de764b96d402168a0/srsly-2.5.3-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5f6a837954429ecbe6dcdd27390d2fb4c7d01a3f99c9ffcf9ce66b2a6dd1b738" }, + { url = "https://mirrors.aliyun.com/pypi/packages/dc/14/c0dd30cc8b93ce8137ff4766f743c882440ce49195fffc5d50eaeef311a6/srsly-2.5.3-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:3576c125c486ce2958c2047e8858fe3cfc9ea877adfa05203b0986f9badee355" }, + { url = "https://mirrors.aliyun.com/pypi/packages/08/f3/34354f183d8faafc631585571224b54d1b4b67e796972c36519c074ca355/srsly-2.5.3-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:5fb59c42922e095d1ea36085c55bc16e2adb06a7bfe57b24d381e0194ae699f2" }, + { url = "https://mirrors.aliyun.com/pypi/packages/a4/d9/5531f8a19492060b4e76e4ab06aca6f096fb5128fe18cc813d1772daf653/srsly-2.5.3-cp313-cp313-win_amd64.whl", hash = "sha256:111805927f05f5db440aeeacb85ce43da0b19ce7b2a09567a9ef8d30f3cc4d83" }, + { url = "https://mirrors.aliyun.com/pypi/packages/8e/8a/62fb7a971eca29e12f03fb9ddacb058548c14d33e5b5675ff0f85839cc7b/srsly-2.5.3-cp313-cp313-win_arm64.whl", hash = "sha256:0f106b0a700ab56e4a7c431b0f1444009ab6cb332edc7bbf6811c2a43f4722cb" }, + { url = "https://mirrors.aliyun.com/pypi/packages/e1/5b/e4ef43c2a381711230af98d4c94a5323df48d6a7899ee652e05bf889290e/srsly-2.5.3-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:39c13d552a9f9674a12cdcdc66b0c2f02f3430d0cd04c5f9cf598824c2bd3d65" }, + { url = "https://mirrors.aliyun.com/pypi/packages/92/2d/ebce7f3717e52cd0a01f4ec570f388f3b7098526794fcf1ad734e0b8f852/srsly-2.5.3-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:14c930767cc169611a2dc14e23bc7638cfb616d6f79029700ade033607343540" }, + { url = "https://mirrors.aliyun.com/pypi/packages/22/47/a8f3e9b214be2624c8e8a78d38ca7b1d4e26b92d57018412e4bfc4abe89a/srsly-2.5.3-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2f2d464f0d0237e32fb53f0ec6f05418652c550e772b50e9918e83a1577cba4d" }, + { url = "https://mirrors.aliyun.com/pypi/packages/d6/71/2a89dc3180a51e633a87a079ca064225f4aaf46c7b2a5fc720e28f261d98/srsly-2.5.3-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d18933248a5bb0ad56a1bae6003a9a7f37daac2ecb0c5bcbfaaf081b317e1c84" }, + { url = "https://mirrors.aliyun.com/pypi/packages/b8/36/72e5ce3153927ca404b6f5bf5280e6ff3399c11557df472b153945468e0a/srsly-2.5.3-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:7ea5412ea229e571ac9738cbe14f845cc06c8e4e956afb5f42061ccd087ef31f" }, + { url = "https://mirrors.aliyun.com/pypi/packages/04/b2/0895de109c28eca0d41a811ab7c076d4e4a505e8466f06bae22f5180a1dd/srsly-2.5.3-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:8d3988970b4cf7d03bdd5b5169302ff84562dd2e1e0f84aeb34df3e5b5dc19bf" }, + { url = "https://mirrors.aliyun.com/pypi/packages/c7/79/a37fa7759797fbdfe0a2e029ab13e78b1e81e191220d2bb8ff57d869aefb/srsly-2.5.3-cp314-cp314-win_amd64.whl", hash = "sha256:6a02d7dcc16126c8fae1c1c09b2072798a1dc482ab5f9c52b12c7114dac47325" }, + { url = "https://mirrors.aliyun.com/pypi/packages/d7/25/0dae019b3b90ad9037f91de4c390555cdaac9460a93ad62b02b03babdff5/srsly-2.5.3-cp314-cp314-win_arm64.whl", hash = "sha256:1c9129c4abe31903ff7996904a51afdd5428060de6c3d12af49a4da5e8df2821" }, + { url = "https://mirrors.aliyun.com/pypi/packages/3a/44/72dd5285b2e05435d98b0797f101d91d9b345d491ddc1fdb9bd09e27ccb8/srsly-2.5.3-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:29d5d01ba4c2e9c01f936e5e6d5babc4a47b38c9cbd6e1ec23f6d5a49df32605" }, + { url = "https://mirrors.aliyun.com/pypi/packages/d2/ad/002c71b87fc3f648c9bf0ec47de0c3822bf2c95c8896a589dd03e7fd3977/srsly-2.5.3-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:5c8df4039426d99f0148b5743542842ab96b82daded0b342555e15a639927757" }, + { url = "https://mirrors.aliyun.com/pypi/packages/2a/35/2cea3d5e80aeecfc4ece9e7e1783e7792cc3bad7ab85ab585882e1db4e38/srsly-2.5.3-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:06a43d63bde2e8cccadb953d7fff70b18196ca286b65dd2ad16006d65f3f8166" }, + { url = "https://mirrors.aliyun.com/pypi/packages/aa/38/8a4d7e86dd0370a2e5af251b646000197bb5b7e0f9aa360c71bbfb253d0d/srsly-2.5.3-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:808cfafc047f0dec507a34c8fa8e4cda5722737fd33577df73452f52f7aca644" }, + { url = "https://mirrors.aliyun.com/pypi/packages/99/05/340129de5ea7b237271b12f8a6962cfa7eb0c5a3056794626d348c5ae7c7/srsly-2.5.3-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:71d4cbe2b2a1335c76ed0acae2dc862163787d8b01a705e1949796907ed94ccd" }, + { url = "https://mirrors.aliyun.com/pypi/packages/01/cb/d7fee7ab27c6aa2e3f865fb7b50ba18c81a4c763bba12bdf53df246441bc/srsly-2.5.3-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:565f69083d33cb329cfc74317da937fb3270c0f40fabc1b4488702d8074b4a3e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/d8/d1/9bad3a0f2fa7b72f4e0cf1d267b00513092d20ef538c47f72823ae4f7656/srsly-2.5.3-cp314-cp314t-win_amd64.whl", hash = "sha256:8ac016ffaeac35bc010992b71bf8afdd39d458f201c8138d84cf78778a936e6c" }, + { url = "https://mirrors.aliyun.com/pypi/packages/2a/ae/57d1d7af907e20c077e113e0e4976f87b82c0a415403d99284a262229dd0/srsly-2.5.3-cp314-cp314t-win_arm64.whl", hash = "sha256:d822083fe26ec6728bd8c273ac121fc4ab3864a0fdf0cf0ff3efb188fcd209ed" }, +] + [[package]] name = "sse-starlette" version = "3.3.3" @@ -7862,7 +7472,6 @@ version = "1.0.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } dependencies = [ { name = "anyio" }, - { name = "typing-extensions", marker = "python_full_version < '3.13'" }, ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/81/69/17425771797c36cded50b7fe44e850315d039f28b15901ab44839e70b593/starlette-1.0.0.tar.gz", hash = "sha256:6a4beaf1f81bb472fd19ea9b918b50dc3a77a6f2e190a12954b25e6ed5eea149" } wheels = [ @@ -7882,12 +7491,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/0d/81/e8d74b34f85285f7335d30c5e3c2d7c0346997af9f3debf9a0a9a63de184/statsmodels-0.14.6.tar.gz", hash = "sha256:4d17873d3e607d398b85126cd4ed7aad89e4e9d89fc744cdab1af3189a996c2a" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/25/ce/308e5e5da57515dd7cab3ec37ea2d5b8ff50bef1fcc8e6d31456f9fae08e/statsmodels-0.14.6-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:fe76140ae7adc5ff0e60a3f0d56f4fffef484efa803c3efebf2fcd734d72ecb5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/05/30/affbabf3c27fb501ec7b5808230c619d4d1a4525c07301074eb4bda92fa9/statsmodels-0.14.6-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:26d4f0ed3b31f3c86f83a92f5c1f5cbe63fc992cd8915daf28ca49be14463a1c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/48/f5/3a73b51e6450c31652c53a8e12e24eac64e3824be816c0c2316e7dbdcb7d/statsmodels-0.14.6-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:d8c00a42863e4f4733ac9d078bbfad816249c01451740e6f5053ecc7db6d6368" }, - { url = "https://mirrors.aliyun.com/pypi/packages/81/68/dddd76117df2ef14c943c6bbb6618be5c9401280046f4ddfc9fb4596a1b8/statsmodels-0.14.6-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:19b58cf7474aa9e7e3b0771a66537148b2df9b5884fbf156096c0e6c1ff0469d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/56/4a/dce451c74c4050535fac1ec0c14b80706d8fc134c9da22db3c8a0ec62c33/statsmodels-0.14.6-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:81e7dcc5e9587f2567e52deaff5220b175bf2f648951549eae5fc9383b62bc37" }, - { url = "https://mirrors.aliyun.com/pypi/packages/60/15/3daba2df40be8b8a9a027d7f54c8dedf24f0d81b96e54b52293f5f7e3418/statsmodels-0.14.6-cp312-cp312-win_amd64.whl", hash = "sha256:b5eb07acd115aa6208b4058211138393a7e6c2cf12b6f213ede10f658f6a714f" }, { url = "https://mirrors.aliyun.com/pypi/packages/81/59/a5aad5b0cc266f5be013db8cde563ac5d2a025e7efc0c328d83b50c72992/statsmodels-0.14.6-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:47ee7af083623d2091954fa71c7549b8443168f41b7c5dce66510274c50fd73e" }, { url = "https://mirrors.aliyun.com/pypi/packages/53/dd/d8cfa7922fc6dc3c56fa6c59b348ea7de829a94cd73208c6f8202dd33f17/statsmodels-0.14.6-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:aa60d82e29fcd0a736e86feb63a11d2380322d77a9369a54be8b0965a3985f71" }, { url = "https://mirrors.aliyun.com/pypi/packages/ee/77/0ec96803eba444efd75dba32f2ef88765ae3e8f567d276805391ec2c98c6/statsmodels-0.14.6-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:89ee7d595f5939cc20bf946faedcb5137d975f03ae080f300ebb4398f16a5bd4" }, @@ -7915,15 +7518,6 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/5c/92/d0c83f63d3518e5f0b8a311937c31347349ec9a47b209ddc17f7566f58fc/stone-3.3.1-py3-none-any.whl", hash = "sha256:e15866fad249c11a963cce3bdbed37758f2e88c8ff4898616bc0caeb1e216047" }, ] -[[package]] -name = "strenum" -version = "0.4.15" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/85/ad/430fb60d90e1d112a62ff57bdd1f286ec73a2a0331272febfddd21f330e1/StrEnum-0.4.15.tar.gz", hash = "sha256:878fb5ab705442070e4dd1929bb5e2249511c0bcf2b0eeacf3bcd80875c82eff" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/81/69/297302c5f5f59c862faa31e6cb9a4cd74721cd1e052b38e464c5b402df8b/StrEnum-0.4.15-py3-none-any.whl", hash = "sha256:a30cda4af7cc6b5bf52c8055bc4bf4b2b6b14a93b574626da33df53cf7740659" }, -] - [[package]] name = "sympy" version = "1.14.0" @@ -8026,150 +7620,6 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/c5/db/daa85799b9af2aa50539b27eeb0d6a2a0ac35465f62683107847830dbe4d/tencentcloud_sdk_python-3.0.1478-py2.py3-none-any.whl", hash = "sha256:10ddee1c1348f49e2b54af606f978d4cb17fca656639e8d99b6527e6e4793833" }, ] -[[package]] -name = "tensorboard" -version = "2.18.0" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -dependencies = [ - { name = "absl-py" }, - { name = "grpcio" }, - { name = "markdown" }, - { name = "numpy" }, - { name = "packaging" }, - { name = "protobuf" }, - { name = "setuptools" }, - { name = "six" }, - { name = "tensorboard-data-server" }, - { name = "werkzeug" }, -] -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/b1/de/021c1d407befb505791764ad2cbd56ceaaa53a746baed01d2e2143f05f18/tensorboard-2.18.0-py3-none-any.whl", hash = "sha256:107ca4821745f73e2aefa02c50ff70a9b694f39f790b11e6f682f7d326745eab" }, -] - -[[package]] -name = "tensorboard-data-server" -version = "0.7.2" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/7a/13/e503968fefabd4c6b2650af21e110aa8466fe21432cd7c43a84577a89438/tensorboard_data_server-0.7.2-py3-none-any.whl", hash = "sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b7/85/dabeaf902892922777492e1d253bb7e1264cadce3cea932f7ff599e53fea/tensorboard_data_server-0.7.2-py3-none-macosx_10_9_x86_64.whl", hash = "sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60" }, - { url = "https://mirrors.aliyun.com/pypi/packages/73/c6/825dab04195756cf8ff2e12698f22513b3db2f64925bdd41671bfb33aaa5/tensorboard_data_server-0.7.2-py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530" }, -] - -[[package]] -name = "tensorflow-cpu" -version = "2.18.0" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -resolution-markers = [ - "(python_full_version >= '3.14' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.14' and sys_platform != 'darwin' and sys_platform != 'linux')", - "(python_full_version == '3.13.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.13.*' and sys_platform != 'darwin' and sys_platform != 'linux')", - "(python_full_version < '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')", -] -dependencies = [ - { name = "absl-py", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "astunparse", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "flatbuffers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "gast", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "google-pasta", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "grpcio", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "h5py", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "keras", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "libclang", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "ml-dtypes", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "opt-einsum", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "packaging", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "protobuf", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "requests", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "setuptools", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "six", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "tensorboard", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "tensorflow-intel", marker = "sys_platform == 'win32'" }, - { name = "termcolor", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "wrapt", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, -] -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/ca/3f/2ed163140237aefa72c761d56af8ba3fa5cb0fe37a9f53b14ad8bcd7ef87/tensorflow_cpu-2.18.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:39bd421ad125e4163d6e2d41ab0e158b583fb5c6f9254522fb87635b0e70b891" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0e/7a/1c99bb2bb7d24238b748f9f0244a198ee15d23782bb56dbf4e7b93a29c6a/tensorflow_cpu-2.18.0-cp312-cp312-win_amd64.whl", hash = "sha256:0b093b727c2f2a8cf4ee4f2c7352c8e958a2a1d27a452961b8d5f43a0798dcd2" }, -] - -[[package]] -name = "tensorflow-cpu" -version = "2.18.1" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -resolution-markers = [ - "python_full_version >= '3.14' and sys_platform == 'darwin'", - "python_full_version == '3.13.*' and sys_platform == 'darwin'", - "python_full_version < '3.13' and sys_platform == 'darwin'", - "python_full_version >= '3.14' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version == '3.13.*' and platform_machine == 'aarch64' and sys_platform == 'linux'", - "python_full_version < '3.13' and platform_machine == 'aarch64' and sys_platform == 'linux'", -] -dependencies = [ - { name = "absl-py", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "astunparse", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "flatbuffers", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "gast", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "google-pasta", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "grpcio", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "h5py", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "keras", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "libclang", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "ml-dtypes", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "numpy", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "opt-einsum", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "packaging", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "protobuf", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "requests", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "setuptools", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "six", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "tensorboard", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "termcolor", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "typing-extensions", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, - { name = "wrapt", marker = "(platform_machine == 'aarch64' and sys_platform == 'linux') or sys_platform == 'darwin'" }, -] - -[[package]] -name = "tensorflow-intel" -version = "2.18.0" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -dependencies = [ - { name = "absl-py", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "astunparse", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "flatbuffers", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "gast", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "google-pasta", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "grpcio", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "h5py", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "keras", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "libclang", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "ml-dtypes", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "numpy", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "opt-einsum", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "packaging", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "protobuf", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "requests", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "setuptools", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "six", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "tensorboard", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "termcolor", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "typing-extensions", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, - { name = "wrapt", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" }, -] -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/ae/4e/44ce609139065035c56fe570fe7f0ee8d06180c99a424bac588472052c5d/tensorflow_intel-2.18.0-cp312-cp312-win_amd64.whl", hash = "sha256:a5818043f565cf74179b67eb52fc060587ccecb9540141c39d84fbcb37ecff8c" }, -] - -[[package]] -name = "termcolor" -version = "3.3.0" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/46/79/cf31d7a93a8fdc6aa0fbb665be84426a8c5a557d9240b6239e9e11e35fc5/termcolor-3.3.0.tar.gz", hash = "sha256:348871ca648ec6a9a983a13ab626c0acce02f515b9e1983332b17af7979521c5" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/33/d1/8bb87d21e9aeb323cc03034f5eaf2c8f69841e40e4853c2627edf8111ed3/termcolor-3.3.0-py3-none-any.whl", hash = "sha256:cf642efadaf0a8ebbbf4bc7a31cec2f9b5f21a9f726f4ccbb08192c9c26f43a5" }, -] - [[package]] name = "text-unidecode" version = "1.3" @@ -8192,6 +7642,44 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/11/3d/2653f4cf49660bb44eeac8270617cc4c0287d61716f249f55053f0af0724/tf_playwright_stealth-1.2.0-py3-none-any.whl", hash = "sha256:26ee47ee89fa0f43c606fe37c188ea3ccd36f96ea90c01d167b768df457e7886" }, ] +[[package]] +name = "thinc" +version = "8.3.13" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "blis" }, + { name = "catalogue" }, + { name = "confection" }, + { name = "cymem" }, + { name = "murmurhash" }, + { name = "numpy" }, + { name = "packaging" }, + { name = "preshed" }, + { name = "pydantic" }, + { name = "setuptools" }, + { name = "srsly" }, + { name = "wasabi" }, +] +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/13/46/76df95f2c327f9a9cef30c1523bf285627897097163584dcf5f77b2ebce2/thinc-8.3.13.tar.gz", hash = "sha256:68e658549fc1eb3ff92aed5147fcbb9c15d6e9cc0e623b4d0998d16522ffb4f9" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/af/b9/7b46942176df459d1804a9e77b0976f7c56f3abf3ec7485d0e5f836a0382/thinc-8.3.13-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:c2811dfd8d46d8b5d3b39051b23e64006b2994a5143b1978b436938018792af8" }, + { url = "https://mirrors.aliyun.com/pypi/packages/a7/79/53085a72cd8f4fc4e6e313d05ea5aa98e870684f4a0fb318a9875fc0a964/thinc-8.3.13-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:5593e6300cb1ebe0c0e546e9c9fb49e7c2627a0aa688795cd4f995a8b820d2ec" }, + { url = "https://mirrors.aliyun.com/pypi/packages/9e/3e/d61b462b16da95ac6885f95bb395e672040ee594833e571a6edcffd234f5/thinc-8.3.13-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f697174d3fb474966ce50b430bbafa101a6d2f7ffb559dac4b5c59389ef72d22" }, + { url = "https://mirrors.aliyun.com/pypi/packages/78/4c/898cc654bb123734c71ec5a425c02ca34439517d01ce1c95a6563295580e/thinc-8.3.13-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e9c7c5c104737b414c8c4ec578e67d78b6c859afe25cbc0684402e721415bd7f" }, + { url = "https://mirrors.aliyun.com/pypi/packages/cd/56/1abdbf0a4ad628e8a05d6516fe0745969649d805367a3dccad8ee872981b/thinc-8.3.13-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:7a99d0e242d1ccd23f9ae6bea7cd502f8626efa65c156b91d84581d0356696c3" }, + { url = "https://mirrors.aliyun.com/pypi/packages/f1/22/b84dbdc6be5055bbdb2a7352e2c393f67e8593c137f1b83c82bf1e062b6e/thinc-8.3.13-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:e676edd21a747afbe3e6b9f3fca8b962e36d146ded03b070cb0c28e2dfbe9499" }, + { url = "https://mirrors.aliyun.com/pypi/packages/0f/a8/763cd7ba949334c9d2cddc92dadb68b344cb9546dc01b8d4a733dcaa16c1/thinc-8.3.13-cp313-cp313-win_amd64.whl", hash = "sha256:8ad40307f20e83f77af28ff5c6be0b86af7a8b251d1231c545508d2763157d8f" }, + { url = "https://mirrors.aliyun.com/pypi/packages/f5/15/a11f7bb3cbc97dfecf32a90552f5a8f8a5c99316a99c6c17bdabf5baf256/thinc-8.3.13-cp313-cp313-win_arm64.whl", hash = "sha256:723949cab11d1925c15447928513a718276316cec6e0de28337cca0a62be0521" }, + { url = "https://mirrors.aliyun.com/pypi/packages/80/40/f4937d113912c6d669ffe982356ab29dcb6c7fe3be926a15981dbbb6a91c/thinc-8.3.13-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:7badb0be4825535e6362c19e8a41872b65409e9da46d3453a391b843a0720865" }, + { url = "https://mirrors.aliyun.com/pypi/packages/d2/00/4d4ed1a11ba2920b85a03a0683b16d97dc5beb2e78078dbf0e13e43bcea7/thinc-8.3.13-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:565300b7e13de799e5abff00d445f537e9256cf7da4dcb0d0f005fc16748a29e" }, + { url = "https://mirrors.aliyun.com/pypi/packages/44/5d/dc33d6932be8721af2ef76b4a3a6e8020648630eabae61fb916d2a861d1d/thinc-8.3.13-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:c17cef1900a1aba7e1487493d16b8aa0a8633116f1b2a51c6649a4000697f17b" }, + { url = "https://mirrors.aliyun.com/pypi/packages/af/bc/a6d37d8dadc2c5b524f51192413481160c42c9dd6105e8d5551531623225/thinc-8.3.13-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:f4f26d1eec9b2a6a8f2e0298a5515d13eb06d70730d0d9e1040bb329e12bf3fb" }, + { url = "https://mirrors.aliyun.com/pypi/packages/7a/59/ce9c7067f1dfe5985875927de9cf7a79f9dae3e69487fd650dfba558029d/thinc-8.3.13-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:a61a31fd0ce3c2771cf4901ba6df70e774ffe32febf1024c5b43d63575cd58fe" }, + { url = "https://mirrors.aliyun.com/pypi/packages/4f/a8/f57819347fc4d8bef2204d15fcbb9d7dff2d6cdd5f83d5ed91456ddacc55/thinc-8.3.13-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:ba8119daf84a12259ae4d251d36426417bafa0b34108890b4b7e2b50966bd990" }, + { url = "https://mirrors.aliyun.com/pypi/packages/05/ef/a82214bb7c7c1e2d92b69e1a7654be90cfab180082c6108e45a98af2422c/thinc-8.3.13-cp314-cp314-win_amd64.whl", hash = "sha256:433e3826e018da489f1a8068e6de677f6eff3cc93991a599d90f12cd1bc26cdc" }, + { url = "https://mirrors.aliyun.com/pypi/packages/9f/ef/1648fda54e9689058335ff54f650a7a314db2a42e21af1b83949b2dc748e/thinc-8.3.13-cp314-cp314-win_arm64.whl", hash = "sha256:11754fada9ad5ba2e02d5f3f234f940e24015b82333db58372f4a6aedad9b43f" }, +] + [[package]] name = "threadpoolctl" version = "3.6.0" @@ -8227,13 +7715,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/7d/ab/4d017d0f76ec3171d469d80fc03dfbb4e48a4bcaddaa831b31d526f05edc/tiktoken-0.12.0.tar.gz", hash = "sha256:b18ba7ee2b093863978fcb14f74b3707cdc8d4d4d3836853ce7ec60772139931" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/a4/85/be65d39d6b647c79800fd9d29241d081d4eeb06271f383bb87200d74cf76/tiktoken-0.12.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b97f74aca0d78a1ff21b8cd9e9925714c15a9236d6ceacf5c7327c117e6e21e8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/4a/42/6573e9129bc55c9bf7300b3a35bef2c6b9117018acca0dc760ac2d93dffe/tiktoken-0.12.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2b90f5ad190a4bb7c3eb30c5fa32e1e182ca1ca79f05e49b448438c3e225a49b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/66/c5/ed88504d2f4a5fd6856990b230b56d85a777feab84e6129af0822f5d0f70/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_aarch64.whl", hash = "sha256:65b26c7a780e2139e73acc193e5c63ac754021f160df919add909c1492c0fb37" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f4/90/3dae6cc5436137ebd38944d396b5849e167896fc2073da643a49f372dc4f/tiktoken-0.12.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:edde1ec917dfd21c1f2f8046b86348b0f54a2c0547f68149d8600859598769ad" }, - { url = "https://mirrors.aliyun.com/pypi/packages/a3/fe/26df24ce53ffde419a42f5f53d755b995c9318908288c17ec3f3448313a3/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:35a2f8ddd3824608b3d650a000c1ef71f730d0c56486845705a8248da00f9fe5" }, - { url = "https://mirrors.aliyun.com/pypi/packages/20/cc/b064cae1a0e9fac84b0d2c46b89f4e57051a5f41324e385d10225a984c24/tiktoken-0.12.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:83d16643edb7fa2c99eff2ab7733508aae1eebb03d5dfc46f5565862810f24e3" }, - { url = "https://mirrors.aliyun.com/pypi/packages/81/10/b8523105c590c5b8349f2587e2fdfe51a69544bd5a76295fc20f2374f470/tiktoken-0.12.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffc5288f34a8bc02e1ea7047b8d041104791d2ddbf42d1e5fa07822cbffe16bd" }, { url = "https://mirrors.aliyun.com/pypi/packages/00/61/441588ee21e6b5cdf59d6870f86beb9789e532ee9718c251b391b70c68d6/tiktoken-0.12.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:775c2c55de2310cc1bc9a3ad8826761cbdc87770e586fd7b6da7d4589e13dab3" }, { url = "https://mirrors.aliyun.com/pypi/packages/1f/05/dcf94486d5c5c8d34496abe271ac76c5b785507c8eae71b3708f1ad9b45a/tiktoken-0.12.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a01b12f69052fbe4b080a2cfb867c4de12c704b56178edf1d1d7b273561db160" }, { url = "https://mirrors.aliyun.com/pypi/packages/a0/70/5163fe5359b943f8db9946b62f19be2305de8c3d78a16f629d4165e2f40e/tiktoken-0.12.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:01d99484dc93b129cd0964f9d34eee953f2737301f18b3c7257bf368d7615baa" }, @@ -8329,8 +7810,8 @@ wheels = [ [[package]] name = "trio" -version = "0.24.0" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +version = "0.33.0" +source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "attrs" }, { name = "cffi", marker = "(implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" }, @@ -8339,9 +7820,9 @@ dependencies = [ { name = "sniffio" }, { name = "sortedcontainers" }, ] -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/8a/f3/07c152213222c615fe2391b8e1fea0f5af83599219050a549c20fcbd9ba2/trio-0.24.0.tar.gz", hash = "sha256:ffa09a74a6bf81b84f8613909fb0beaee84757450183a7a2e0b47b455c0cac5d" } +sdist = { url = "https://files.pythonhosted.org/packages/52/b6/c744031c6f89b18b3f5f4f7338603ab381d740a7f45938c4607b2302481f/trio-0.33.0.tar.gz", hash = "sha256:a29b92b73f09d4b48ed249acd91073281a7f1063f09caba5dc70465b5c7aa970", size = 605109, upload-time = "2026-02-14T18:40:55.386Z" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/14/fb/9299cf74953f473a15accfdbe2c15218e766bae8c796f2567c83bae03e98/trio-0.24.0-py3-none-any.whl", hash = "sha256:c3bd3a4e3e3025cd9a2241eae75637c43fe0b9e88b4c97b9161a55b9e54cd72c" }, + { url = "https://files.pythonhosted.org/packages/1c/93/dab25dc87ac48da0fe0f6419e07d0bfd98799bed4e05e7b9e0f85a1a4b4b/trio-0.33.0-py3-none-any.whl", hash = "sha256:3bd5d87f781d9b0192d592aef28691f8951d6c2e41b7e1da4c25cde6c180ae9b", size = 510294, upload-time = "2026-02-14T18:40:53.313Z" }, ] [[package]] @@ -8544,6 +8025,18 @@ version = "0.2.5" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/9f/c1/dd817bf57e0274dacb10e0ac868cb6cd70876950cf361c41879c030a2b8b/warc3-wet-clueweb09-0.2.5.tar.gz", hash = "sha256:3054bfc07da525d5967df8ca3175f78fa3f78514c82643f8c81fbca96300b836" } +[[package]] +name = "wasabi" +version = "1.1.3" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/ac/f9/054e6e2f1071e963b5e746b48d1e3727470b2a490834d18ad92364929db3/wasabi-1.1.3.tar.gz", hash = "sha256:4bb3008f003809db0c3e28b4daf20906ea871a2bb43f9914197d540f4f2e0878" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/06/7c/34330a89da55610daa5f245ddce5aab81244321101614751e7537f125133/wasabi-1.1.3-py3-none-any.whl", hash = "sha256:f76e16e8f7e79f8c4c8be49b4024ac725713ab10cd7f19350ad18a8e3f71728c" }, +] + [[package]] name = "wcwidth" version = "0.6.0" @@ -8553,6 +8046,26 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/68/5a/199c59e0a824a3db2b89c5d2dade7ab5f9624dbf6448dc291b46d5ec94d3/wcwidth-0.6.0-py3-none-any.whl", hash = "sha256:1a3a1e510b553315f8e146c54764f4fb6264ffad731b3d78088cdb1478ffbdad" }, ] +[[package]] +name = "weasel" +version = "1.0.0" +source = { registry = "https://mirrors.aliyun.com/pypi/simple" } +dependencies = [ + { name = "cloudpathlib" }, + { name = "confection" }, + { name = "httpx" }, + { name = "packaging" }, + { name = "pydantic" }, + { name = "smart-open" }, + { name = "srsly" }, + { name = "typer" }, + { name = "wasabi" }, +] +sdist = { url = "https://mirrors.aliyun.com/pypi/packages/ce/e5/e272bb9a045105a1fdf4b798d8086f5932a178f4d738f17a74f5c9e0ae9a/weasel-1.0.0.tar.gz", hash = "sha256:7b129b44c90cc543b760532974ca1e4eb30dad2aa2026f57bdce66354ae610fc" } +wheels = [ + { url = "https://mirrors.aliyun.com/pypi/packages/0a/07/57ebf7a6798b016c064bd0ca81b4c6a99daa4dc377b898bc7b41eb6b5af0/weasel-1.0.0-py3-none-any.whl", hash = "sha256:89518acee027f49d743126c3502d35e6dd14f5768be5c37c9af47c171b6005cc" }, +] + [[package]] name = "webdav4" version = "0.10.0" @@ -8604,15 +8117,6 @@ version = "16.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/04/24/4b2031d72e840ce4c1ccb255f693b15c334757fc50023e4db9537080b8c4/websockets-16.0.tar.gz", hash = "sha256:5f6261a5e56e8d5c42a4497b364ea24d94d9563e8fbd44e78ac40879c60179b5" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/84/7b/bac442e6b96c9d25092695578dda82403c77936104b5682307bd4deb1ad4/websockets-16.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:71c989cbf3254fbd5e84d3bff31e4da39c43f884e64f2551d14bb3c186230f00" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b0/fe/136ccece61bd690d9c1f715baaeefd953bb2360134de73519d5df19d29ca/websockets-16.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:8b6e209ffee39ff1b6d0fa7bfef6de950c60dfb91b8fcead17da4ee539121a79" }, - { url = "https://mirrors.aliyun.com/pypi/packages/40/1e/9771421ac2286eaab95b8575b0cb701ae3663abf8b5e1f64f1fd90d0a673/websockets-16.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:86890e837d61574c92a97496d590968b23c2ef0aeb8a9bc9421d174cd378ae39" }, - { url = "https://mirrors.aliyun.com/pypi/packages/18/29/71729b4671f21e1eaa5d6573031ab810ad2936c8175f03f97f3ff164c802/websockets-16.0-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:9b5aca38b67492ef518a8ab76851862488a478602229112c4b0d58d63a7a4d5c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/97/bb/21c36b7dbbafc85d2d480cd65df02a1dc93bf76d97147605a8e27ff9409d/websockets-16.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:e0334872c0a37b606418ac52f6ab9cfd17317ac26365f7f65e203e2d0d0d359f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/4a/34/9bf8df0c0cf88fa7bfe36678dc7b02970c9a7d5e065a3099292db87b1be2/websockets-16.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a0b31e0b424cc6b5a04b8838bbaec1688834b2383256688cf47eb97412531da1" }, - { url = "https://mirrors.aliyun.com/pypi/packages/47/88/4dd516068e1a3d6ab3c7c183288404cd424a9a02d585efbac226cb61ff2d/websockets-16.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:485c49116d0af10ac698623c513c1cc01c9446c058a4e61e3bf6c19dff7335a2" }, - { url = "https://mirrors.aliyun.com/pypi/packages/91/d6/7d4553ad4bf1c0421e1ebd4b18de5d9098383b5caa1d937b63df8d04b565/websockets-16.0-cp312-cp312-win32.whl", hash = "sha256:eaded469f5e5b7294e2bdca0ab06becb6756ea86894a47806456089298813c89" }, - { url = "https://mirrors.aliyun.com/pypi/packages/c3/f0/f3a17365441ed1c27f850a80b2bc680a0fa9505d733fe152fdf5e98c1c0b/websockets-16.0-cp312-cp312-win_amd64.whl", hash = "sha256:5569417dc80977fc8c2d43a86f78e0a5a22fee17565d78621b6bb264a115d4ea" }, { url = "https://mirrors.aliyun.com/pypi/packages/cc/9c/baa8456050d1c1b08dd0ec7346026668cbc6f145ab4e314d707bb845bf0d/websockets-16.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:878b336ac47938b474c8f982ac2f7266a540adc3fa4ad74ae96fea9823a02cc9" }, { url = "https://mirrors.aliyun.com/pypi/packages/7e/0c/8811fc53e9bcff68fe7de2bcbe75116a8d959ac699a3200f4847a8925210/websockets-16.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:52a0fec0e6c8d9a784c2c78276a48a2bdf099e4ccc2a4cad53b27718dbfd0230" }, { url = "https://mirrors.aliyun.com/pypi/packages/aa/82/39a5f910cb99ec0b59e482971238c845af9220d3ab9fa76dd9162cda9d62/websockets-16.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:e6578ed5b6981005df1860a56e3617f14a6c307e6a71b4fff8c48fdc50f3ed2c" }, @@ -8655,18 +8159,6 @@ wheels = [ { url = "https://mirrors.aliyun.com/pypi/packages/7f/b2/0bba9bbb4596d2d2f285a16c2ab04118f6b957d8441566e1abb892e6a6b2/werkzeug-3.1.7-py3-none-any.whl", hash = "sha256:4b314d81163a3e1a169b6a0be2a000a0e204e8873c5de6586f453c55688d422f" }, ] -[[package]] -name = "wheel" -version = "0.46.3" -source = { registry = "https://mirrors.aliyun.com/pypi/simple" } -dependencies = [ - { name = "packaging" }, -] -sdist = { url = "https://mirrors.aliyun.com/pypi/packages/89/24/a2eb353a6edac9a0303977c4cb048134959dd2a51b48a269dfc9dde00c8a/wheel-0.46.3.tar.gz", hash = "sha256:e3e79874b07d776c40bd6033f8ddf76a7dad46a7b8aa1b2787a83083519a1803" } -wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/87/22/b76d483683216dde3d67cba61fb2444be8d5be289bf628c13fc0fd90e5f9/wheel-0.46.3-py3-none-any.whl", hash = "sha256:4b399d56c9d9338230118d705d9737a2a468ccca63d5e813e2a4fc7815d8bc4d" }, -] - [[package]] name = "wikipedia" version = "1.4.0" @@ -8695,16 +8187,6 @@ version = "1.17.3" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/95/8f/aeb76c5b46e273670962298c23e7ddde79916cb74db802131d49a85e4b7d/wrapt-1.17.3.tar.gz", hash = "sha256:f66eb08feaa410fe4eebd17f2a2c8e2e46d3476e9f8c783daa8e09e0faa666d0" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/9f/41/cad1aba93e752f1f9268c77270da3c469883d56e2798e7df6240dcb2287b/wrapt-1.17.3-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:ab232e7fdb44cdfbf55fc3afa31bcdb0d8980b9b95c38b6405df2acb672af0e0" }, - { url = "https://mirrors.aliyun.com/pypi/packages/60/f8/096a7cc13097a1869fe44efe68dace40d2a16ecb853141394047f0780b96/wrapt-1.17.3-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:9baa544e6acc91130e926e8c802a17f3b16fbea0fd441b5a60f5cf2cc5c3deba" }, - { url = "https://mirrors.aliyun.com/pypi/packages/33/df/bdf864b8997aab4febb96a9ae5c124f700a5abd9b5e13d2a3214ec4be705/wrapt-1.17.3-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6b538e31eca1a7ea4605e44f81a48aa24c4632a277431a6ed3f328835901f4fd" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9f/81/5d931d78d0eb732b95dc3ddaeeb71c8bb572fb01356e9133916cd729ecdd/wrapt-1.17.3-cp312-cp312-manylinux1_x86_64.manylinux_2_28_x86_64.manylinux_2_5_x86_64.whl", hash = "sha256:042ec3bb8f319c147b1301f2393bc19dba6e176b7da446853406d041c36c7828" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ca/38/2e1785df03b3d72d34fc6252d91d9d12dc27a5c89caef3335a1bbb8908ca/wrapt-1.17.3-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3af60380ba0b7b5aeb329bc4e402acd25bd877e98b3727b0135cb5c2efdaefe9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b3/8b/48cdb60fe0603e34e05cffda0b2a4adab81fd43718e11111a4b0100fd7c1/wrapt-1.17.3-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:0b02e424deef65c9f7326d8c19220a2c9040c51dc165cddb732f16198c168396" }, - { url = "https://mirrors.aliyun.com/pypi/packages/3c/51/d81abca783b58f40a154f1b2c56db1d2d9e0d04fa2d4224e357529f57a57/wrapt-1.17.3-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:74afa28374a3c3a11b3b5e5fca0ae03bef8450d6aa3ab3a1e2c30e3a75d023dc" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9e/b1/43b286ca1392a006d5336412d41663eeef1ad57485f3e52c767376ba7e5a/wrapt-1.17.3-cp312-cp312-win32.whl", hash = "sha256:4da9f45279fff3543c371d5ababc57a0384f70be244de7759c85a7f989cb4ebe" }, - { url = "https://mirrors.aliyun.com/pypi/packages/28/de/49493f962bd3c586ab4b88066e967aa2e0703d6ef2c43aa28cb83bf7b507/wrapt-1.17.3-cp312-cp312-win_amd64.whl", hash = "sha256:e71d5c6ebac14875668a1e90baf2ea0ef5b7ac7918355850c0908ae82bcb297c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f1/48/0f7102fe9cb1e8a5a77f80d4f0956d62d97034bbe88d33e94699f99d181d/wrapt-1.17.3-cp312-cp312-win_arm64.whl", hash = "sha256:604d076c55e2fdd4c1c03d06dc1a31b95130010517b5019db15365ec4a405fc6" }, { url = "https://mirrors.aliyun.com/pypi/packages/fc/f6/759ece88472157acb55fc195e5b116e06730f1b651b5b314c66291729193/wrapt-1.17.3-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:a47681378a0439215912ef542c45a783484d4dd82bac412b71e59cf9c0e1cea0" }, { url = "https://mirrors.aliyun.com/pypi/packages/4f/a9/49940b9dc6d47027dc850c116d79b4155f15c08547d04db0f07121499347/wrapt-1.17.3-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:54a30837587c6ee3cd1a4d1c2ec5d24e77984d44e2f34547e2323ddb4e22eb77" }, { url = "https://mirrors.aliyun.com/pypi/packages/45/35/6a08de0f2c96dcdd7fe464d7420ddb9a7655a6561150e5fc4da9356aeaab/wrapt-1.17.3-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:16ecf15d6af39246fe33e507105d67e4b81d8f8d2c6598ff7e3ca1b8a37213f7" }, @@ -8800,21 +8282,6 @@ version = "3.6.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/02/84/30869e01909fb37a6cc7e18688ee8bf1e42d57e7e0777636bd47524c43c7/xxhash-3.6.0.tar.gz", hash = "sha256:f0162a78b13a0d7617b2845b90c763339d1f1d82bb04a4b07f4ab535cc5e05d6" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/9a/07/d9412f3d7d462347e4511181dea65e47e0d0e16e26fbee2ea86a2aefb657/xxhash-3.6.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:01362c4331775398e7bb34e3ab403bc9ee9f7c497bc7dee6272114055277dd3c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/79/35/0429ee11d035fc33abe32dca1b2b69e8c18d236547b9a9b72c1929189b9a/xxhash-3.6.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b7b2df81a23f8cb99656378e72501b2cb41b1827c0f5a86f87d6b06b69f9f204" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b7/f2/57eb99aa0f7d98624c0932c5b9a170e1806406cdbcdb510546634a1359e0/xxhash-3.6.0-cp312-cp312-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:dc94790144e66b14f67b10ac8ed75b39ca47536bf8800eb7c24b50271ea0c490" }, - { url = "https://mirrors.aliyun.com/pypi/packages/4c/ed/6224ba353690d73af7a3f1c7cdb1fc1b002e38f783cb991ae338e1eb3d79/xxhash-3.6.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:93f107c673bccf0d592cdba077dedaf52fe7f42dcd7676eba1f6d6f0c3efffd2" }, - { url = "https://mirrors.aliyun.com/pypi/packages/38/86/fb6b6130d8dd6b8942cc17ab4d90e223653a89aa32ad2776f8af7064ed13/xxhash-3.6.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:2aa5ee3444c25b69813663c9f8067dcfaa2e126dc55e8dddf40f4d1c25d7effa" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ee/dc/e84875682b0593e884ad73b2d40767b5790d417bde603cceb6878901d647/xxhash-3.6.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:f7f99123f0e1194fa59cc69ad46dbae2e07becec5df50a0509a808f90a0f03f0" }, - { url = "https://mirrors.aliyun.com/pypi/packages/11/4f/426f91b96701ec2f37bb2b8cec664eff4f658a11f3fa9d94f0a887ea6d2b/xxhash-3.6.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:49e03e6fe2cac4a1bc64952dd250cf0dbc5ef4ebb7b8d96bce82e2de163c82a2" }, - { url = "https://mirrors.aliyun.com/pypi/packages/53/5a/ddbb83eee8e28b778eacfc5a85c969673e4023cdeedcfcef61f36731610b/xxhash-3.6.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:bd17fede52a17a4f9a7bc4472a5867cb0b160deeb431795c0e4abe158bc784e9" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1e/c2/ff69efd07c8c074ccdf0a4f36fcdd3d27363665bcdf4ba399abebe643465/xxhash-3.6.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:6fb5f5476bef678f69db04f2bd1efbed3030d2aba305b0fc1773645f187d6a4e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/58/ca/faa05ac19b3b622c7c9317ac3e23954187516298a091eb02c976d0d3dd45/xxhash-3.6.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:843b52f6d88071f87eba1631b684fcb4b2068cd2180a0224122fe4ef011a9374" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d4/7a/06aa7482345480cc0cb597f5c875b11a82c3953f534394f620b0be2f700c/xxhash-3.6.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:7d14a6cfaf03b1b6f5f9790f76880601ccc7896aff7ab9cd8978a939c1eb7e0d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/23/07/63ffb386cd47029aa2916b3d2f454e6cc5b9f5c5ada3790377d5430084e7/xxhash-3.6.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:418daf3db71e1413cfe211c2f9a528456936645c17f46b5204705581a45390ae" }, - { url = "https://mirrors.aliyun.com/pypi/packages/0f/93/14fde614cadb4ddf5e7cebf8918b7e8fac5ae7861c1875964f17e678205c/xxhash-3.6.0-cp312-cp312-win32.whl", hash = "sha256:50fc255f39428a27299c20e280d6193d8b63b8ef8028995323bf834a026b4fbb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/13/5d/0d125536cbe7565a83d06e43783389ecae0c0f2ed037b48ede185de477c0/xxhash-3.6.0-cp312-cp312-win_amd64.whl", hash = "sha256:c0f2ab8c715630565ab8991b536ecded9416d615538be8ecddce43ccf26cbc7c" }, - { url = "https://mirrors.aliyun.com/pypi/packages/54/85/6ec269b0952ec7e36ba019125982cf11d91256a778c7c3f98a4c5043d283/xxhash-3.6.0-cp312-cp312-win_arm64.whl", hash = "sha256:eae5c13f3bc455a3bbb68bdc513912dc7356de7e2280363ea235f71f54064829" }, { url = "https://mirrors.aliyun.com/pypi/packages/33/76/35d05267ac82f53ae9b0e554da7c5e281ee61f3cad44c743f0fcd354f211/xxhash-3.6.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:599e64ba7f67472481ceb6ee80fa3bd828fd61ba59fb11475572cc5ee52b89ec" }, { url = "https://mirrors.aliyun.com/pypi/packages/31/a8/3fbce1cd96534a95e35d5120637bf29b0d7f5d8fa2f6374e31b4156dd419/xxhash-3.6.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7d8b8aaa30fca4f16f0c84a5c8d7ddee0e25250ec2796c973775373257dde8f1" }, { url = "https://mirrors.aliyun.com/pypi/packages/0c/ea/d387530ca7ecfa183cb358027f1833297c6ac6098223fd14f9782cd0015c/xxhash-3.6.0-cp313-cp313-manylinux1_i686.manylinux_2_28_i686.manylinux_2_5_i686.whl", hash = "sha256:d597acf8506d6e7101a4a44a5e428977a51c0fadbbfd3c39650cca9253f6e5a6" }, @@ -8888,24 +8355,6 @@ dependencies = [ ] sdist = { url = "https://mirrors.aliyun.com/pypi/packages/23/6e/beb1beec874a72f23815c1434518bfc4ed2175065173fb138c3705f658d4/yarl-1.23.0.tar.gz", hash = "sha256:53b1ea6ca88ebd4420379c330aea57e258408dd0df9af0992e5de2078dc9f5d5" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/88/8a/94615bc31022f711add374097ad4144d569e95ff3c38d39215d07ac153a0/yarl-1.23.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:1932b6b8bba8d0160a9d1078aae5838a66039e8832d41d2992daa9a3a08f7860" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e3/6f/c6554045d59d64052698add01226bc867b52fe4a12373415d7991fdca95d/yarl-1.23.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:411225bae281f114067578891bc75534cfb3d92a3b4dfef7a6ca78ba354e6069" }, - { url = "https://mirrors.aliyun.com/pypi/packages/19/2a/725ecc166d53438bc88f76822ed4b1e3b10756e790bafd7b523fe97c322d/yarl-1.23.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:13a563739ae600a631c36ce096615fe307f131344588b0bc0daec108cdb47b25" }, - { url = "https://mirrors.aliyun.com/pypi/packages/99/30/58260ed98e6ff7f90ba84442c1ddd758c9170d70327394a6227b310cd60f/yarl-1.23.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9cbf44c5cb4a7633d078788e1b56387e3d3cf2b8139a3be38040b22d6c3221c8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/76/0a/8b08aac08b50682e65759f7f8dde98ae8168f72487e7357a5d684c581ef9/yarl-1.23.0-cp312-cp312-manylinux2014_armv7l.manylinux_2_17_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:53ad387048f6f09a8969631e4de3f1bf70c50e93545d64af4f751b2498755072" }, - { url = "https://mirrors.aliyun.com/pypi/packages/52/07/0b7179101fe5f8385ec6c6bb5d0cb9f76bd9fb4a769591ab6fb5cdbfc69a/yarl-1.23.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.manylinux_2_28_ppc64le.whl", hash = "sha256:4a59ba56f340334766f3a4442e0efd0af895fae9e2b204741ef885c446b3a1a8" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d3/8a/36d82869ab5ec829ca8574dfcb92b51286fcfb1e9c7a73659616362dc880/yarl-1.23.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:803a3c3ce4acc62eaf01eaca1208dcf0783025ef27572c3336502b9c232005e7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/66/3e/868e5c3364b6cee19ff3e1a122194fa4ce51def02c61023970442162859e/yarl-1.23.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a3d2bff8f37f8d0f96c7ec554d16945050d54462d6e95414babaa18bfafc7f51" }, - { url = "https://mirrors.aliyun.com/pypi/packages/cf/26/9c89acf82f08a52cb52d6d39454f8d18af15f9d386a23795389d1d423823/yarl-1.23.0-cp312-cp312-manylinux_2_31_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:c75eb09e8d55bceb4367e83496ff8ef2bc7ea6960efb38e978e8073ea59ecb67" }, - { url = "https://mirrors.aliyun.com/pypi/packages/6f/54/5b0db00d2cb056922356104468019c0a132e89c8d3ab67d8ede9f4483d2a/yarl-1.23.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877b0738624280e34c55680d6054a307aa94f7d52fa0e3034a9cc6e790871da7" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f6/40/10fa93811fd439341fad7e0718a86aca0de9548023bbb403668d6555acab/yarl-1.23.0-cp312-cp312-musllinux_1_2_armv7l.whl", hash = "sha256:b5405bb8f0e783a988172993cfc627e4d9d00432d6bbac65a923041edacf997d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/bc/d2/8ae2e6cd77d0805f4526e30ec43b6f9a3dfc542d401ac4990d178e4bf0cf/yarl-1.23.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:1c3a3598a832590c5a3ce56ab5576361b5688c12cb1d39429cf5dba30b510760" }, - { url = "https://mirrors.aliyun.com/pypi/packages/2f/0c/b3ceacf82c3fe21183ce35fa2acf5320af003d52bc1fcf5915077681142e/yarl-1.23.0-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:8419ebd326430d1cbb7efb5292330a2cf39114e82df5cc3d83c9a0d5ebeaf2f2" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9d/e0/12900edd28bdab91a69bd2554b85ad7b151f64e8b521fe16f9ad2f56477a/yarl-1.23.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:be61f6fff406ca40e3b1d84716fde398fc08bc63dd96d15f3a14230a0973ed86" }, - { url = "https://mirrors.aliyun.com/pypi/packages/15/61/74bb1182cf79c9bbe4eb6b1f14a57a22d7a0be5e9cedf8e2d5c2086474c3/yarl-1.23.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3ceb13c5c858d01321b5d9bb65e4cf37a92169ea470b70fec6f236b2c9dd7e34" }, - { url = "https://mirrors.aliyun.com/pypi/packages/69/7f/cd5ef733f2550de6241bd8bd8c3febc78158b9d75f197d9c7baa113436af/yarl-1.23.0-cp312-cp312-win32.whl", hash = "sha256:fffc45637bcd6538de8b85f51e3df3223e4ad89bccbfca0481c08c7fc8b7ed7d" }, - { url = "https://mirrors.aliyun.com/pypi/packages/f5/be/25216a49daeeb7af2bec0db22d5e7df08ed1d7c9f65d78b14f3b74fd72fc/yarl-1.23.0-cp312-cp312-win_amd64.whl", hash = "sha256:f69f57305656a4852f2a7203efc661d8c042e6cc67f7acd97d8667fb448a426e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/d2/35/aeab955d6c425b227d5b7247eafb24f2653fedc32f95373a001af5dfeb9e/yarl-1.23.0-cp312-cp312-win_arm64.whl", hash = "sha256:6e87a6e8735b44816e7db0b2fbc9686932df473c826b0d9743148432e10bb9b9" }, { url = "https://mirrors.aliyun.com/pypi/packages/9a/4b/a0a6e5d0ee8a2f3a373ddef8a4097d74ac901ac363eea1440464ccbe0898/yarl-1.23.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:16c6994ac35c3e74fb0ae93323bf8b9c2a9088d55946109489667c510a7d010e" }, { url = "https://mirrors.aliyun.com/pypi/packages/67/b6/8925d68af039b835ae876db5838e82e76ec87b9782ecc97e192b809c4831/yarl-1.23.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:4a42e651629dafb64fd5b0286a3580613702b5809ad3f24934ea87595804f2c5" }, { url = "https://mirrors.aliyun.com/pypi/packages/ae/50/06d511cc4b8e0360d3c94af051a768e84b755c5eb031b12adaaab6dec6e5/yarl-1.23.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:7c6b9461a2a8b47c65eef63bb1c76a4f1c119618ffa99ea79bc5bb1e46c5821b" }, @@ -9034,8 +8483,6 @@ version = "0.1.10" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/35/3e/dd482d5bf99d1dabcce0a20a479859cb7a6bd8a365b07b41ebf46b3c0f3d/zlib_state-0.1.10.tar.gz", hash = "sha256:c29b6b93cea1b80025fbc96fa91ceed8b5e7b54ef08f16d6e4c7f8fb56aad777" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/9e/c9/318a8fa73d41b94810816815e38372d75a8c83c02c9d10dd796443b74ccd/zlib_state-0.1.10-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:6d4f3196f84a4d504f4c04147ec7fd9132651883830f6f07be3702d82731f99e" }, - { url = "https://mirrors.aliyun.com/pypi/packages/38/d8/89a7e7fbea33b20dcdefa122afde7e79a9fdbe75cf5b48e13a110a2c8c8e/zlib_state-0.1.10-cp312-cp312-win_amd64.whl", hash = "sha256:8465b3ddb7fc11e30a49f38615426e369dd1ac5d3d780d89e759e731dfc7bbf4" }, { url = "https://mirrors.aliyun.com/pypi/packages/70/0c/2b0803cb9f30bddbc9eda87d251d958d21cfdde826bc1deb1e19ca0ff320/zlib_state-0.1.10-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:dfecba070cdeeab073573ac721459727d60e0b8ef7b38dac3c965459781b0eeb" }, { url = "https://mirrors.aliyun.com/pypi/packages/b1/d2/74ff59bb480801eae2731523f98be198eec135a9d37e27791b635f2c9124/zlib_state-0.1.10-cp313-cp313-win_amd64.whl", hash = "sha256:72e354f09c942055677ba59d76ca8c311a8129dfc98c3b44db33302843090204" }, { url = "https://mirrors.aliyun.com/pypi/packages/e1/b2/83cfa28037f152d623c1cf716013e5938513d414e8ac3c0312e1b839928f/zlib_state-0.1.10-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:c86d39c50e046547e23d2f0170556444f1f385c251ce0d5cc00c9d7ed6c0ef1e" }, @@ -9048,23 +8495,6 @@ version = "0.25.0" source = { registry = "https://mirrors.aliyun.com/pypi/simple" } sdist = { url = "https://mirrors.aliyun.com/pypi/packages/fd/aa/3e0508d5a5dd96529cdc5a97011299056e14c6505b678fd58938792794b1/zstandard-0.25.0.tar.gz", hash = "sha256:7713e1179d162cf5c7906da876ec2ccb9c3a9dcbdffef0cc7f70c3667a205f0b" } wheels = [ - { url = "https://mirrors.aliyun.com/pypi/packages/82/fc/f26eb6ef91ae723a03e16eddb198abcfce2bc5a42e224d44cc8b6765e57e/zstandard-0.25.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:7b3c3a3ab9daa3eed242d6ecceead93aebbb8f5f84318d82cee643e019c4b73b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/aa/1c/d920d64b22f8dd028a8b90e2d756e431a5d86194caa78e3819c7bf53b4b3/zstandard-0.25.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:913cbd31a400febff93b564a23e17c3ed2d56c064006f54efec210d586171c00" }, - { url = "https://mirrors.aliyun.com/pypi/packages/53/6c/288c3f0bd9fcfe9ca41e2c2fbfd17b2097f6af57b62a81161941f09afa76/zstandard-0.25.0-cp312-cp312-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:011d388c76b11a0c165374ce660ce2c8efa8e5d87f34996aa80f9c0816698b64" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1e/15/efef5a2f204a64bdb5571e6161d49f7ef0fffdbca953a615efbec045f60f/zstandard-0.25.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:6dffecc361d079bb48d7caef5d673c88c8988d3d33fb74ab95b7ee6da42652ea" }, - { url = "https://mirrors.aliyun.com/pypi/packages/b7/37/a6ce629ffdb43959e92e87ebdaeebb5ac81c944b6a75c9c47e300f85abdf/zstandard-0.25.0-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:7149623bba7fdf7e7f24312953bcf73cae103db8cae49f8154dd1eadc8a29ecb" }, - { url = "https://mirrors.aliyun.com/pypi/packages/e3/79/2bf870b3abeb5c070fe2d670a5a8d1057a8270f125ef7676d29ea900f496/zstandard-0.25.0-cp312-cp312-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:6a573a35693e03cf1d67799fd01b50ff578515a8aeadd4595d2a7fa9f3ec002a" }, - { url = "https://mirrors.aliyun.com/pypi/packages/53/60/7be26e610767316c028a2cbedb9a3beabdbe33e2182c373f71a1c0b88f36/zstandard-0.25.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5a56ba0db2d244117ed744dfa8f6f5b366e14148e00de44723413b2f3938a902" }, - { url = "https://mirrors.aliyun.com/pypi/packages/85/c7/3483ad9ff0662623f3648479b0380d2de5510abf00990468c286c6b04017/zstandard-0.25.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:10ef2a79ab8e2974e2075fb984e5b9806c64134810fac21576f0668e7ea19f8f" }, - { url = "https://mirrors.aliyun.com/pypi/packages/08/b3/206883dd25b8d1591a1caa44b54c2aad84badccf2f1de9e2d60a446f9a25/zstandard-0.25.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:aaf21ba8fb76d102b696781bddaa0954b782536446083ae3fdaa6f16b25a1c4b" }, - { url = "https://mirrors.aliyun.com/pypi/packages/9d/31/76c0779101453e6c117b0ff22565865c54f48f8bd807df2b00c2c404b8e0/zstandard-0.25.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:1869da9571d5e94a85a5e8d57e4e8807b175c9e4a6294e3b66fa4efb074d90f6" }, - { url = "https://mirrors.aliyun.com/pypi/packages/18/e1/97680c664a1bf9a247a280a053d98e251424af51f1b196c6d52f117c9720/zstandard-0.25.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:809c5bcb2c67cd0ed81e9229d227d4ca28f82d0f778fc5fea624a9def3963f91" }, - { url = "https://mirrors.aliyun.com/pypi/packages/1e/73/316e4010de585ac798e154e88fd81bb16afc5c5cb1a72eeb16dd37e8024a/zstandard-0.25.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:f27662e4f7dbf9f9c12391cb37b4c4c3cb90ffbd3b1fb9284dadbbb8935fa708" }, - { url = "https://mirrors.aliyun.com/pypi/packages/5b/60/dd0f8cfa8129c5a0ce3ea6b7f70be5b33d2618013a161e1ff26c2b39787c/zstandard-0.25.0-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:99c0c846e6e61718715a3c9437ccc625de26593fea60189567f0118dc9db7512" }, - { url = "https://mirrors.aliyun.com/pypi/packages/fc/5f/75aafd4b9d11b5407b641b8e41a57864097663699f23e9ad4dbb91dc6bfe/zstandard-0.25.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:474d2596a2dbc241a556e965fb76002c1ce655445e4e3bf38e5477d413165ffa" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ff/8d/0309daffea4fcac7981021dbf21cdb2e3427a9e76bafbcdbdf5392ff99a4/zstandard-0.25.0-cp312-cp312-win32.whl", hash = "sha256:23ebc8f17a03133b4426bcc04aabd68f8236eb78c3760f12783385171b0fd8bd" }, - { url = "https://mirrors.aliyun.com/pypi/packages/79/3b/fa54d9015f945330510cb5d0b0501e8253c127cca7ebe8ba46a965df18c5/zstandard-0.25.0-cp312-cp312-win_amd64.whl", hash = "sha256:ffef5a74088f1e09947aecf91011136665152e0b4b359c42be3373897fb39b01" }, - { url = "https://mirrors.aliyun.com/pypi/packages/ea/6b/8b51697e5319b1f9ac71087b0af9a40d8a6288ff8025c36486e0c12abcc4/zstandard-0.25.0-cp312-cp312-win_arm64.whl", hash = "sha256:181eb40e0b6a29b3cd2849f825e0fa34397f649170673d385f3598ae17cca2e9" }, { url = "https://mirrors.aliyun.com/pypi/packages/35/0b/8df9c4ad06af91d39e94fa96cc010a24ac4ef1378d3efab9223cc8593d40/zstandard-0.25.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:ec996f12524f88e151c339688c3897194821d7f03081ab35d31d1e12ec975e94" }, { url = "https://mirrors.aliyun.com/pypi/packages/3f/06/9ae96a3e5dcfd119377ba33d4c42a7d89da1efabd5cb3e366b156c45ff4d/zstandard-0.25.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a1a4ae2dec3993a32247995bdfe367fc3266da832d82f8438c8570f989753de1" }, { url = "https://mirrors.aliyun.com/pypi/packages/d9/14/933d27204c2bd404229c69f445862454dcc101cd69ef8c6068f15aaec12c/zstandard-0.25.0-cp313-cp313-manylinux2010_i686.manylinux2014_i686.manylinux_2_12_i686.manylinux_2_17_i686.whl", hash = "sha256:e96594a5537722fdfb79951672a2a63aec5ebfb823e7560586f7484819f2a08f" }, diff --git a/web/package-lock.json b/web/package-lock.json index bfb0aee4f27..38407cfbe33 100644 --- a/web/package-lock.json +++ b/web/package-lock.json @@ -100,6 +100,7 @@ "recharts": "^2.12.4", "rehype-katex": "^7.0.1", "rehype-raw": "^7.0.0", + "remark-breaks": "^4.0.0", "remark-gfm": "^4.0.0", "remark-math": "^6.0.0", "sonner": "^1.7.4", @@ -18439,6 +18440,20 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/mdast-util-newline-to-break": { + "version": "2.0.0", + "resolved": "https://registry.npmmirror.com/mdast-util-newline-to-break/-/mdast-util-newline-to-break-2.0.0.tgz", + "integrity": "sha512-MbgeFca0hLYIEx/2zGsszCSEJJ1JSCdiY5xQxRcLDDGa8EPvlLPupJ4DSajbMPAnC0je8jfb9TiUATnxxrHUog==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "mdast-util-find-and-replace": "^3.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, "node_modules/mdast-util-phrasing": { "version": "4.1.0", "resolved": "https://registry.npmmirror.com/mdast-util-phrasing/-/mdast-util-phrasing-4.1.0.tgz", @@ -22539,6 +22554,21 @@ "url": "https://opencollective.com/unified" } }, + "node_modules/remark-breaks": { + "version": "4.0.0", + "resolved": "https://registry.npmmirror.com/remark-breaks/-/remark-breaks-4.0.0.tgz", + "integrity": "sha512-IjEjJOkH4FuJvHZVIW0QCDWxcG96kCq7An/KVH2NfJe6rKZU2AsHeB3OEjPNRxi4QC34Xdx7I2KGYn6IpT7gxQ==", + "license": "MIT", + "dependencies": { + "@types/mdast": "^4.0.0", + "mdast-util-newline-to-break": "^2.0.0", + "unified": "^11.0.0" + }, + "funding": { + "type": "opencollective", + "url": "https://opencollective.com/unified" + } + }, "node_modules/remark-gfm": { "version": "4.0.1", "resolved": "https://registry.npmmirror.com/remark-gfm/-/remark-gfm-4.0.1.tgz", diff --git a/web/package.json b/web/package.json index 4e0485c6d98..6dbed295b41 100644 --- a/web/package.json +++ b/web/package.json @@ -122,6 +122,7 @@ "recharts": "^2.12.4", "rehype-katex": "^7.0.1", "rehype-raw": "^7.0.0", + "remark-breaks": "^4.0.0", "remark-gfm": "^4.0.0", "remark-math": "^6.0.0", "sonner": "^1.7.4", diff --git a/web/src/assets/svg/data-source/rest-api.svg b/web/src/assets/svg/data-source/rest-api.svg new file mode 100644 index 00000000000..f7d3e6d213a --- /dev/null +++ b/web/src/assets/svg/data-source/rest-api.svg @@ -0,0 +1,13 @@ + + + + + + + + + + + + + diff --git a/web/src/components/chunk-method-dialog/index.tsx b/web/src/components/chunk-method-dialog/index.tsx index aa6c2398354..21650d7e6d5 100644 --- a/web/src/components/chunk-method-dialog/index.tsx +++ b/web/src/components/chunk-method-dialog/index.tsx @@ -17,7 +17,7 @@ import { DocumentParserType, ParseType } from '@/constants/knowledge'; import { useFetchKnowledgeBaseConfiguration } from '@/hooks/use-knowledge-request'; import { IModalProps } from '@/interfaces/common'; import { IParserConfig } from '@/interfaces/database/document'; -import { IChangeParserConfigRequestBody } from '@/interfaces/request/document'; +import { IChangeParserRequestBody } from '@/interfaces/request/document'; import { MetadataType } from '@/pages/dataset/components/metedata/constant'; import { AutoMetadata, @@ -28,7 +28,6 @@ import { } from '@/pages/dataset/dataset-setting/configuration/common-item'; import { zodResolver } from '@hookform/resolvers/zod'; import omit from 'lodash/omit'; -import {} from 'module'; import { useEffect, useMemo } from 'react'; import { useForm, useWatch } from 'react-hook-form'; import { useTranslation } from 'react-i18next'; @@ -56,10 +55,7 @@ import { const FormId = 'ChunkMethodDialogForm'; -interface IProps extends IModalProps<{ - parserId: string; - parserConfig: IChangeParserConfigRequestBody; -}> { +interface IProps extends IModalProps { loading: boolean; parserId: string; pipelineId?: string; @@ -126,16 +122,19 @@ export function ChunkMethodDialog({ mineru_formula_enable: z.boolean().optional(), mineru_table_enable: z.boolean().optional(), mineru_lang: z.string().optional(), - // raptor: z - // .object({ - // use_raptor: z.boolean().optional(), - // prompt: z.string().optional().optional(), - // max_token: z.coerce.number().optional(), - // threshold: z.coerce.number().optional(), - // max_cluster: z.coerce.number().optional(), - // random_seed: z.coerce.number().optional(), - // }) - // .optional(), + raptor: z + .object({ + use_raptor: z.boolean().optional(), + prompt: z.string().optional(), + max_token: z.coerce.number().optional(), + threshold: z.coerce.number().optional(), + max_cluster: z.coerce.number().optional(), + random_seed: z.coerce.number().optional(), + scope: z.string().optional(), + clustering_method: z.enum(['gmm', 'ahc']).optional(), + tree_builder: z.enum(['raptor', 'psi']).optional(), + }) + .optional(), // graphrag: z.object({ // use_graphrag: z.boolean().optional(), // }), diff --git a/web/src/components/chunk-method-dialog/use-default-parser-values.ts b/web/src/components/chunk-method-dialog/use-default-parser-values.ts index 47af38771b9..84f7c9e3c3d 100644 --- a/web/src/components/chunk-method-dialog/use-default-parser-values.ts +++ b/web/src/components/chunk-method-dialog/use-default-parser-values.ts @@ -23,14 +23,17 @@ export function useDefaultParserValues() { mineru_formula_enable: true, mineru_table_enable: true, mineru_lang: 'English', - // raptor: { - // use_raptor: false, - // prompt: t('knowledgeConfiguration.promptText'), - // max_token: 256, - // threshold: 0.1, - // max_cluster: 64, - // random_seed: 0, - // }, + raptor: { + use_raptor: false, + prompt: t('knowledgeConfiguration.promptText'), + max_token: 256, + threshold: 0.1, + max_cluster: 64, + random_seed: 0, + scope: 'file', + clustering_method: 'gmm', + tree_builder: 'raptor', + }, // graphrag: { // use_graphrag: false, // }, diff --git a/web/src/components/document-preview/doc-preview.tsx b/web/src/components/document-preview/doc-preview.tsx index 147b457c6fe..67d956d9175 100644 --- a/web/src/components/document-preview/doc-preview.tsx +++ b/web/src/components/document-preview/doc-preview.tsx @@ -118,7 +118,7 @@ export const DocPreviewer: React.FC = ({ return (

diff --git a/web/src/components/document-preview/md/index.tsx b/web/src/components/document-preview/md/index.tsx index bdc30f91bc1..13f1af3c2f4 100644 --- a/web/src/components/document-preview/md/index.tsx +++ b/web/src/components/document-preview/md/index.tsx @@ -1,10 +1,10 @@ import { Authorization } from '@/constants/authorization'; +import { MarkdownRemarkPluginsLite } from '@/constants/markdown-remark-plugins'; import { cn } from '@/lib/utils'; import FileError from '@/pages/document-viewer/file-error'; import { getAuthorization } from '@/utils/authorization-util'; import React, { useEffect, useState } from 'react'; import ReactMarkdown from 'react-markdown'; -import remarkGfm from 'remark-gfm'; interface MdProps { // filePath: string; @@ -34,7 +34,9 @@ export const Md: React.FC = ({ url, className }) => { style={{ padding: 4, overflow: 'scroll' }} className={cn(className, 'markdown-body h-[calc(100vh - 200px)]')} > - {content} + + {content} +
); }; diff --git a/web/src/components/document-preview/ppt-preview.tsx b/web/src/components/document-preview/ppt-preview.tsx index 4895ad9fe33..d95e05c46d7 100644 --- a/web/src/components/document-preview/ppt-preview.tsx +++ b/web/src/components/document-preview/ppt-preview.tsx @@ -41,7 +41,7 @@ export const PptPreviewer: React.FC = ({ }); pptxPrviewer.preview(arrayBuffer); } - } catch (err) { + } catch { message.error('ppt parse failed'); } }; diff --git a/web/src/components/dynamic-form.tsx b/web/src/components/dynamic-form.tsx index 5c9fff5eaf4..0920e2422ef 100644 --- a/web/src/components/dynamic-form.tsx +++ b/web/src/components/dynamic-form.tsx @@ -111,10 +111,12 @@ interface DynamicFormProps { // Form ref interface export interface DynamicFormRef { submit: () => void; + isDirty: () => boolean; getValues: (name?: string) => any; reset: (values?: any) => void; trigger: UseFormTrigger; watch: (field: string, callback: (value: any) => void) => () => void; + watchDirty: (callback: (isDirty: boolean, values: any) => void) => () => void; updateFieldType: (fieldName: string, newType: FormFieldType) => void; onFieldUpdate: ( fieldName: string, @@ -347,7 +349,6 @@ export const RenderField = ({ field: FormFieldConfig; labelClassName?: string; }) => { - const form = useFormContext(); if (field.render) { if (field.type === FormFieldType.Custom && field.hideLabel) { return
{field.render({})}
; @@ -810,6 +811,7 @@ const DynamicForm = { onSubmit(filteredValues); })(); }, + isDirty: () => form.formState.isDirty, getValues: form.getValues, reset: (values?: T) => { if (values) { @@ -829,6 +831,12 @@ const DynamicForm = { }); return unsubscribe; }, + watchDirty: (callback: (isDirty: boolean, values: any) => void) => { + const { unsubscribe } = form.watch((values: any) => { + callback(form.formState.isDirty, values); + }); + return unsubscribe; + }, onFieldUpdate: ( fieldName: string, diff --git a/web/src/components/embed-dialog/index.tsx b/web/src/components/embed-dialog/index.tsx index d2656b2ae07..dbb45df2471 100644 --- a/web/src/components/embed-dialog/index.tsx +++ b/web/src/components/embed-dialog/index.tsx @@ -1,6 +1,6 @@ import CopyToClipboard from '@/components/copy-to-clipboard'; import { SelectWithSearch } from '@/components/originui/select-with-search'; -import { Button } from '@/components/ui/button'; +import { Button, ButtonLoading } from '@/components/ui/button'; import { Dialog, DialogContent, @@ -17,6 +17,7 @@ import { } from '@/components/ui/form'; import { Label } from '@/components/ui/label'; import { RadioGroup, RadioGroupItem } from '@/components/ui/radio-group'; +import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs'; import { SharedFrom } from '@/constants/chat'; import { LanguageAbbreviation, @@ -48,23 +49,78 @@ const FormSchema = z.object({ locale: z.string(), embedType: z.enum(['fullscreen', 'widget']), enableStreaming: z.boolean(), + muteWidget: z.boolean(), theme: z.enum([ThemeEnum.Light, ThemeEnum.Dark]), userId: z.string().optional(), + widgetTitle: z.string(), + widgetSubtitle: z.string(), + widgetFooterText: z.string(), + widgetFooterLink: z.string(), + widgetAccentColor: z.string(), + widgetBackgroundColor: z.string(), + widgetTextColor: z.string(), + widgetHeaderTextColor: z.string(), + widgetFooterTextColor: z.string(), }); +export type WidgetSettings = Pick< + z.infer, + | 'enableStreaming' + | 'muteWidget' + | 'widgetTitle' + | 'widgetSubtitle' + | 'widgetFooterText' + | 'widgetFooterLink' + | 'widgetAccentColor' + | 'widgetBackgroundColor' + | 'widgetTextColor' + | 'widgetHeaderTextColor' + | 'widgetFooterTextColor' +>; + +export const defaultWidgetSettings: WidgetSettings = { + enableStreaming: false, + muteWidget: false, + widgetTitle: '', + widgetSubtitle: '', + widgetFooterText: '', + widgetFooterLink: '', + widgetAccentColor: '#2563eb', + widgetBackgroundColor: '#ffffff', + widgetTextColor: '#111827', + widgetHeaderTextColor: '#ffffff', + widgetFooterTextColor: '#111827', +}; + type IProps = IModalProps & { token: string; from: SharedFrom; beta: string; isAgent: boolean; + initialWidgetSettings?: Partial; + onSaveWidgetSettings?: (settings: WidgetSettings) => Promise; + savingWidgetSettings?: boolean; +}; + +const normalizeHexColor = (value: string | undefined, fallback: string) => { + const normalizedValue = value?.trim() ?? ''; + return /^#([0-9a-f]{3}|[0-9a-f]{6})$/i.test(normalizedValue) + ? normalizedValue + : fallback; }; +/** + * Builds the embed code preview and customization UI for shared chat and agent widgets. + */ function EmbedDialog({ hideModal, token = '', from, beta = '', isAgent, + initialWidgetSettings, + onSaveWidgetSettings, + savingWidgetSettings, visible, }: IProps) { const { t } = useTranslation(); @@ -77,8 +133,9 @@ function EmbedDialog({ published: false, locale: '', embedType: 'fullscreen' as const, - enableStreaming: false, theme: ThemeEnum.Light, + ...defaultWidgetSettings, + ...initialWidgetSettings, }, }); @@ -98,8 +155,18 @@ function EmbedDialog({ locale, embedType, enableStreaming, + muteWidget, theme, userId, + widgetTitle, + widgetSubtitle, + widgetFooterText, + widgetFooterLink, + widgetAccentColor, + widgetBackgroundColor, + widgetTextColor, + widgetHeaderTextColor, + widgetFooterTextColor, } = values; const baseRoute = embedType === 'widget' @@ -125,6 +192,39 @@ function EmbedDialog({ if (embedType === 'widget') { src.searchParams.append('mode', 'master'); src.searchParams.append('streaming', String(enableStreaming)); + src.searchParams.append('muted', String(muteWidget)); + if (!isEmpty(trim(widgetTitle))) { + src.searchParams.append('widget_title', widgetTitle ?? ''); + } + if (!isEmpty(trim(widgetSubtitle))) { + src.searchParams.append('widget_subtitle', widgetSubtitle ?? ''); + } + if (!isEmpty(trim(widgetFooterText))) { + src.searchParams.append('widget_footer', widgetFooterText ?? ''); + } + if (!isEmpty(trim(widgetFooterLink))) { + src.searchParams.append('widget_footer_link', widgetFooterLink ?? ''); + } + src.searchParams.append( + 'widget_accent_color', + normalizeHexColor(widgetAccentColor, '#2563eb'), + ); + src.searchParams.append( + 'widget_background_color', + normalizeHexColor(widgetBackgroundColor, '#ffffff'), + ); + src.searchParams.append( + 'widget_text_color', + normalizeHexColor(widgetTextColor, '#111827'), + ); + src.searchParams.append( + 'widget_header_text_color', + normalizeHexColor(widgetHeaderTextColor, '#ffffff'), + ); + src.searchParams.append( + 'widget_footer_text_color', + normalizeHexColor(widgetFooterTextColor, '#111827'), + ); } if (theme && embedType === 'fullscreen') { src.searchParams.append('theme', theme); @@ -179,9 +279,29 @@ window.addEventListener('message',e=>{ window.open(iframeSrc, '_blank'); }, [generateIframeSrc]); + const handleSaveWidgetSettings = useCallback(async () => { + if (!onSaveWidgetSettings) { + return; + } + + await onSaveWidgetSettings({ + enableStreaming: values.enableStreaming, + muteWidget: values.muteWidget, + widgetTitle: values.widgetTitle, + widgetSubtitle: values.widgetSubtitle, + widgetFooterText: values.widgetFooterText, + widgetFooterLink: values.widgetFooterLink, + widgetAccentColor: values.widgetAccentColor, + widgetBackgroundColor: values.widgetBackgroundColor, + widgetTextColor: values.widgetTextColor, + widgetHeaderTextColor: values.widgetHeaderTextColor, + widgetFooterTextColor: values.widgetFooterTextColor, + }); + }, [onSaveWidgetSettings, values]); + return ( - + {t('common.embedIntoSite')} @@ -189,103 +309,274 @@ window.addEventListener('message',e=>{
- ( - - {t('chat.embedType')} - - -
- - -
-
- - -
-
-
- -
- )} - /> - {values.embedType === 'fullscreen' && ( - ( - - {t('chat.theme')} - - -
- - -
-
- - -
-
-
- -
+ + + Embed Setup + Widget Customization + + + ( + + {t('chat.embedType')} + + +
+ + +
+
+ + +
+
+
+ +
+ )} + /> + {values.embedType === 'fullscreen' && ( + ( + + {t('chat.theme')} + + +
+ + +
+
+ + +
+
+
+ +
+ )} + /> )} - /> - )} - - {isAgent && ( - - )} - {values.embedType === 'widget' && ( - - )} - - - - {isAgent && ( - - - - )} + + {isAgent && ( + + )} + {values.embedType === 'widget' && ( + + )} + {values.embedType === 'widget' && ( + + )} + + + + {isAgent && ( + + + + )} +
+ +
+ These settings apply to the floating widget embed. +
+
+ + + + + + + + + + + + + ( + + Widget accent color + +
+ + +
+
+ +
+ )} + /> + ( + + Background color + +
+ + +
+
+ +
+ )} + /> + ( + + Text color + +
+ + +
+
+ +
+ )} + /> + ( + + Header text color + +
+ + +
+
+ +
+ )} + /> + ( + + Footer text color + +
+ + +
+
+ +
+ )} + /> +
+
+
{t('search.embedCode')} -
+
+ @@ -293,14 +584,26 @@ window.addEventListener('message',e=>{
- +
+ {isAgent && onSaveWidgetSettings && ( + + {t('flow.save')} widget settings + + )} + +
{t(isAgent ? 'flow' : 'chat', { keyPrefix: 'header' })} ID diff --git a/web/src/components/floating-chat-widget-markdown.tsx b/web/src/components/floating-chat-widget-markdown.tsx index 3a4e4942c66..51912d72afb 100644 --- a/web/src/components/floating-chat-widget-markdown.tsx +++ b/web/src/components/floating-chat-widget-markdown.tsx @@ -36,8 +36,7 @@ import { } from 'react-syntax-highlighter/dist/esm/styles/prism'; import rehypeKatex from 'rehype-katex'; import rehypeRaw from 'rehype-raw'; -import remarkGfm from 'remark-gfm'; -import remarkMath from 'remark-math'; +import { MarkdownRemarkPlugins } from '@/constants/markdown-remark-plugins'; import { visitParents } from 'unist-util-visit-parents'; import styles from './floating-chat-widget-markdown.module.less'; import { useIsDarkTheme } from './theme-provider'; @@ -292,13 +291,15 @@ const FloatingChatWidgetMarkdown = ({
( -

{children}

- ), + p: (props: any) => { + const { children, node, ...rest } = props; + void node; + return

{children}

; + }, 'custom-typography': ({ children }: { children: string }) => renderReference(children), code(props: any) { diff --git a/web/src/components/floating-chat-widget.tsx b/web/src/components/floating-chat-widget.tsx index 46fb49482a4..9b5c375cff3 100644 --- a/web/src/components/floating-chat-widget.tsx +++ b/web/src/components/floating-chat-widget.tsx @@ -1,3 +1,4 @@ +import CopyToClipboard from '@/components/copy-to-clipboard'; import PdfSheet from '@/components/pdf-drawer'; import { useClickDrawer } from '@/components/pdf-drawer/hooks'; import { MessageType, SharedFrom } from '@/constants/chat'; @@ -14,6 +15,75 @@ import { } from '../pages/next-chats/hooks/use-send-shared-message'; import FloatingChatWidgetMarkdown from './floating-chat-widget-markdown'; +/** + * Normalizes a hex color input and falls back to a safe default when invalid. + */ +const normalizeHexColor = (value: string | null, fallback: string) => { + return value && /^#([0-9a-f]{3}|[0-9a-f]{6})$/i.test(value) + ? value + : fallback; +}; + +/** + * Darkens a hex color to derive hover and gradient variants for the widget chrome. + */ +const darkenHexColor = (hexColor: string, amount = 0.12) => { + const normalizedHex = hexColor.replace('#', ''); + const expandedHex = + normalizedHex.length === 3 + ? normalizedHex + .split('') + .map((char) => `${char}${char}`) + .join('') + : normalizedHex; + const channels = expandedHex.match(/.{2}/g); + + if (!channels) { + return hexColor; + } + + return `#${channels + .map((channel) => { + const value = parseInt(channel, 16); + const adjustedValue = Math.max( + 0, + Math.min(255, Math.round(value * (1 - amount))), + ); + return adjustedValue.toString(16).padStart(2, '0'); + }) + .join('')}`; +}; + +/** + * Accepts a footer link from the widget query string and returns a safe HTTP(S) URL. + */ +const normalizeWidgetFooterLink = (value: string | null) => { + const normalizedValue = value?.trim(); + + if (!normalizedValue) { + return undefined; + } + + const candidate = /^[a-z][a-z\d+.-]*:/i.test(normalizedValue) + ? normalizedValue + : `https://${normalizedValue}`; + + try { + const url = new URL(candidate); + + if (url.protocol === 'http:' || url.protocol === 'https:') { + return url.toString(); + } + } catch { + return undefined; + } + + return undefined; +}; + +/** + * Renders the embeddable floating chat widget and applies URL-driven widget settings. + */ const FloatingChatWidget = () => { const { t } = useTranslation(); const [isOpen, setIsOpen] = useState(false); @@ -36,6 +106,34 @@ const FloatingChatWidget = () => { const urlParams = new URLSearchParams(window.location.search); const mode = urlParams.get('mode') || 'full'; // 'button', 'window', or 'full' const enableStreaming = urlParams.get('streaming') === 'true'; // Only enable if explicitly set to true + const isMuted = urlParams.get('muted') === 'true'; + const widgetTitle = urlParams.get('widget_title')?.trim(); + const widgetSubtitle = urlParams.get('widget_subtitle')?.trim(); + const widgetFooter = urlParams.get('widget_footer')?.trim(); + const widgetFooterLink = normalizeWidgetFooterLink( + urlParams.get('widget_footer_link'), + ); + const widgetAccentColor = normalizeHexColor( + urlParams.get('widget_accent_color'), + '#2563eb', + ); + const widgetAccentColorStrong = darkenHexColor(widgetAccentColor); + const widgetBackgroundColor = normalizeHexColor( + urlParams.get('widget_background_color'), + '#ffffff', + ); + const widgetTextColor = normalizeHexColor( + urlParams.get('widget_text_color'), + '#111827', + ); + const widgetHeaderTextColor = normalizeHexColor( + urlParams.get('widget_header_text_color'), + '#ffffff', + ); + const widgetFooterTextColor = normalizeHexColor( + urlParams.get('widget_footer_text_color'), + '#111827', + ); const { handlePressEnter, @@ -58,6 +156,49 @@ const FloatingChatWidget = () => { )(); const title = data.title; + const displayTitle = widgetTitle || title || t('chat.chatSupport'); + const displaySubtitle = widgetSubtitle || t('chat.replyInstantly'); + const displayFooter = widgetFooter || ''; + const renderFooter = () => { + if (!displayFooter) { + return null; + } + + return ( +
+ {widgetFooterLink ? ( + + {displayFooter} + + ) : ( + displayFooter + )} +
+ ); + }; + const bodyContainerStyle: React.CSSProperties = { + borderRadius: '0 0 16px 16px', + backgroundColor: widgetBackgroundColor, + color: widgetTextColor, + }; + const inputStyle: React.CSSProperties = { + minHeight: '44px', + maxHeight: '120px', + color: widgetTextColor, + backgroundColor: widgetBackgroundColor, + }; const { visible, hideModal, documentId, selectedChunk, clickDocumentButton } = useClickDrawer(); @@ -69,6 +210,10 @@ const FloatingChatWidget = () => { // Play sound when opening const playNotificationSound = useCallback(() => { + if (isMuted) { + return; + } + try { const audioContext = new ( window.AudioContext || (window as any).webkitAudioContext @@ -94,10 +239,14 @@ const FloatingChatWidget = () => { console.warn(error); // Silent fail if audio not supported } - }, []); + }, [isMuted]); // Play sound for AI responses (Intercom-style) const playResponseSound = useCallback(() => { + if (isMuted) { + return; + } + try { const audioContext = new ( window.AudioContext || (window as any).webkitAudioContext @@ -124,7 +273,7 @@ const FloatingChatWidget = () => { // Silent fail if audio not supported } - }, []); + }, [isMuted]); // Set loaded state and locale useEffect(() => { @@ -338,9 +487,10 @@ const FloatingChatWidget = () => { '*', ); }} - className={`w-14 h-14 bg-blue-600 hover:bg-blue-700 text-white rounded-full transition-all duration-300 flex items-center justify-center group ${ + className={`w-14 h-14 text-white rounded-full transition-all duration-300 flex items-center justify-center group ${ isOpen ? 'scale-95' : 'scale-100 hover:scale-105' }`} + style={{ backgroundColor: widgetAccentColor }} >
{