From 56ddd4f4660be4a7e9d7b16d26040f70d5418d61 Mon Sep 17 00:00:00 2001 From: are-ces <195810094+are-ces@users.noreply.github.com> Date: Tue, 19 May 2026 09:36:15 +0200 Subject: [PATCH 1/2] LCORE-1333: Add BYOK RAG integration tests for /responses endpoint Add test_responses_byok_integration.py with integration tests covering inline RAG, tool RAG, combined RAG, score multiplier, chunk capping, and RAG_CONTENT_LIMIT enforcement for the /responses endpoint. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../test_responses_byok_integration.py | 704 ++++++++++++++++++ 1 file changed, 704 insertions(+) create mode 100644 tests/integration/endpoints/test_responses_byok_integration.py diff --git a/tests/integration/endpoints/test_responses_byok_integration.py b/tests/integration/endpoints/test_responses_byok_integration.py new file mode 100644 index 000000000..06b97341a --- /dev/null +++ b/tests/integration/endpoints/test_responses_byok_integration.py @@ -0,0 +1,704 @@ +"""Integration tests for the /responses endpoint BYOK RAG functionality.""" + +from typing import Any + +import pytest +from fastapi import Request +from pytest_mock import MockerFixture + +import constants +from app.endpoints.responses import responses_endpoint_handler +from authentication.interface import AuthTuple +from configuration import AppConfig +from models.api.requests import ResponsesRequest +from models.api.responses.successful import ResponsesResponse +from models.common.responses.responses_context import ResponsesContext +from tests.integration.endpoints.test_query_byok_integration import ( + _build_base_mock_client, + _make_byok_vector_io_response, + _make_vector_io_response, +) + +# --------------------------------------------------------------------------- +# Constants +# --------------------------------------------------------------------------- + +MOCK_AUTH: AuthTuple = ( + "00000000-0000-0000-0000-000", + "lightspeed-user", + True, + "", +) + +_RESPONSE_DUMP: dict[str, Any] = { + "id": "resp-1", + "object": "response", + "created_at": 1700000000, + "status": "completed", + "model": "test-provider/test-model", + "output": [ + { + "type": "message", + "id": "msg-1", + "role": "assistant", + "status": "completed", + "content": [ + { + "type": "output_text", + "text": "OpenShift is a Kubernetes distribution.", + "annotations": [], + } + ], + } + ], + "usage": { + "input_tokens": 50, + "output_tokens": 20, + "total_tokens": 70, + "input_tokens_details": {"cached_tokens": 0}, + "output_tokens_details": {"reasoning_tokens": 0}, + }, +} + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_responses_mock_client(mocker: MockerFixture) -> Any: + """Build a mock client suitable for the /responses endpoint.""" + mock_client = _build_base_mock_client(mocker) + mock_client.responses.create.return_value.model_dump.return_value = ( + _RESPONSE_DUMP.copy() + ) + return mock_client + + +def _patch_all_client_holders(mocker: MockerFixture, mock_client: Any) -> None: + """Patch AsyncLlamaStackClientHolder in all modules used by the responses endpoint.""" + for module in ( + "app.endpoints.responses", + "utils.endpoints", + "utils.responses", + ): + holder = mocker.patch(f"{module}.AsyncLlamaStackClientHolder") + holder.return_value.get_client.return_value = mock_client + + original_cls = ResponsesContext + + def _skip_validation(**kwargs: Any) -> ResponsesContext: + return original_cls.model_construct(**kwargs) + + mocker.patch( + "app.endpoints.responses.ResponsesContext", side_effect=_skip_validation + ) + + +# ============================================================================== +# Inline BYOK RAG Tests +# ============================================================================== + + +@pytest.mark.asyncio +async def test_responses_byok_inline_rag_injects_context( # pylint: disable=too-many-locals + test_config: AppConfig, + mocker: MockerFixture, + test_request: Request, +) -> None: + """Test that inline BYOK RAG fetches chunks and injects context into the input. + + Verifies: + - vector_io.query is called for BYOK inline RAG + - RAG context is injected into the responses.create input + - Response is a valid ResponsesResponse + """ + entry = mocker.MagicMock() + entry.rag_id = "test-knowledge" + entry.vector_db_id = "vs-byok-knowledge" + entry.score_multiplier = 1.0 + test_config.configuration.byok_rag = [entry] + test_config.configuration.rag.inline = ["test-knowledge"] + + mock_client = _build_responses_mock_client(mocker) + _patch_all_client_holders(mocker, mock_client) + + mock_client.vector_io.query = mocker.AsyncMock( + return_value=_make_byok_vector_io_response(mocker) + ) + + mock_vs_resp = mocker.MagicMock() + mock_vs_resp.data = [] + mock_client.vector_stores.list.return_value = mock_vs_resp + + responses_request = ResponsesRequest(input="What is OpenShift?", stream=False) + + response = await responses_endpoint_handler( + request=test_request, + responses_request=responses_request, + auth=MOCK_AUTH, + mcp_headers={}, + ) + + assert isinstance(response, ResponsesResponse) + + # Verify vector_io.query was called for inline RAG + mock_client.vector_io.query.assert_called() + call_kwargs = mock_client.vector_io.query.call_args.kwargs + assert call_kwargs["query"] == "What is OpenShift?" + + # Verify RAG context was injected into responses.create input + create_call = mock_client.responses.create.call_args_list[0] + input_text = create_call.kwargs.get("input", "") + assert "file_search found" in input_text + assert "OpenShift is a Kubernetes distribution" in input_text + + +@pytest.mark.asyncio +async def test_responses_byok_inline_rag_error_is_handled_gracefully( # pylint: disable=too-many-locals + test_config: AppConfig, + mocker: MockerFixture, + test_request: Request, +) -> None: + """Test that BYOK RAG search failures are handled gracefully. + + Verifies: + - When vector_io.query raises an exception, the endpoint still succeeds + - The error is silently handled (BYOK search errors are non-fatal) + """ + entry = mocker.MagicMock() + entry.rag_id = "test-knowledge" + entry.vector_db_id = "vs-byok-knowledge" + entry.score_multiplier = 1.0 + test_config.configuration.byok_rag = [entry] + test_config.configuration.rag.inline = ["test-knowledge"] + + mock_client = _build_responses_mock_client(mocker) + _patch_all_client_holders(mocker, mock_client) + + mock_client.vector_io.query = mocker.AsyncMock( + side_effect=Exception("Connection refused") + ) + + mock_vs_resp = mocker.MagicMock() + mock_vs_resp.data = [] + mock_client.vector_stores.list.return_value = mock_vs_resp + + responses_request = ResponsesRequest(input="What is OpenShift?", stream=False) + + response = await responses_endpoint_handler( + request=test_request, + responses_request=responses_request, + auth=MOCK_AUTH, + mcp_headers={}, + ) + + # Endpoint should succeed despite BYOK RAG failure + assert isinstance(response, ResponsesResponse) + + +# ============================================================================== +# Tool-based BYOK RAG Tests +# ============================================================================== + + +@pytest.mark.asyncio +async def test_responses_byok_tool_rag_returns_tool_calls( # pylint: disable=too-many-locals + test_config: AppConfig, + mocker: MockerFixture, + test_request: Request, +) -> None: + """Test that BYOK tool RAG configures file_search tool in responses.create call. + + Verifies: + - file_search tool is present in the tools passed to responses.create + - The tool includes the configured vector store ID + """ + byok_entry = mocker.MagicMock() + byok_entry.rag_id = "test-knowledge" + byok_entry.vector_db_id = "vs-byok-knowledge" + byok_entry.score_multiplier = 1.0 + byok_entry.model_dump.return_value = { + "rag_id": "test-knowledge", + "rag_type": "inline::faiss", + "embedding_model": "sentence-transformers/all-mpnet-base-v2", + "embedding_dimension": 768, + "vector_db_id": "vs-byok-knowledge", + "db_path": "/tmp/test-db", + "score_multiplier": 1.0, + } + + test_config.configuration.byok_rag = [byok_entry] + test_config.configuration.rag.inline = [] + test_config.configuration.rag.tool = ["test-knowledge"] + + mock_client = _build_responses_mock_client(mocker) + _patch_all_client_holders(mocker, mock_client) + + mock_vector_store = mocker.MagicMock() + mock_vector_store.id = "vs-byok-knowledge" + mock_list_result = mocker.MagicMock() + mock_list_result.data = [mock_vector_store] + mock_client.vector_stores.list.return_value = mock_list_result + + responses_request = ResponsesRequest(input="What is OpenShift?", stream=False) + + await responses_endpoint_handler( + request=test_request, + responses_request=responses_request, + auth=MOCK_AUTH, + mcp_headers={}, + ) + + assert mock_client.responses.create.called + call_kwargs = mock_client.responses.create.call_args_list[0] + + tools = call_kwargs.kwargs.get("tools", []) + file_search_tools = [ + t + for t in tools + if (t.get("type") if isinstance(t, dict) else getattr(t, "type", None)) + == "file_search" + ] + assert len(file_search_tools) == 1 + + +# ============================================================================== +# Combined Inline + Tool RAG Tests +# ============================================================================== + + +@pytest.mark.asyncio +async def test_responses_byok_combined_inline_and_tool_rag( # pylint: disable=too-many-locals + test_config: AppConfig, + mocker: MockerFixture, + test_request: Request, +) -> None: + """Test that inline and tool-based BYOK RAG are both active when configured. + + Verifies: + - Inline RAG context is injected into the input + - file_search tool is present in the tools passed to responses.create + """ + byok_entry = mocker.MagicMock() + byok_entry.rag_id = "test-knowledge" + byok_entry.vector_db_id = "vs-byok-knowledge" + byok_entry.score_multiplier = 1.0 + byok_entry.model_dump.return_value = { + "rag_id": "test-knowledge", + "rag_type": "inline::faiss", + "embedding_model": "sentence-transformers/all-mpnet-base-v2", + "embedding_dimension": 768, + "vector_db_id": "vs-byok-knowledge", + "db_path": "/tmp/test-db", + "score_multiplier": 1.0, + } + test_config.configuration.byok_rag = [byok_entry] + test_config.configuration.rag.inline = ["test-knowledge"] + test_config.configuration.rag.tool = ["test-knowledge"] + + mock_client = _build_responses_mock_client(mocker) + _patch_all_client_holders(mocker, mock_client) + + # Inline RAG returns chunks via vector_io + mock_client.vector_io.query = mocker.AsyncMock( + return_value=_make_byok_vector_io_response(mocker) + ) + + # Tool RAG vector stores + mock_vector_store = mocker.MagicMock() + mock_vector_store.id = "vs-byok-knowledge" + mock_list_result = mocker.MagicMock() + mock_list_result.data = [mock_vector_store] + mock_client.vector_stores.list.return_value = mock_list_result + + responses_request = ResponsesRequest(input="What is OpenShift?", stream=False) + + response = await responses_endpoint_handler( + request=test_request, + responses_request=responses_request, + auth=MOCK_AUTH, + mcp_headers={}, + ) + + assert isinstance(response, ResponsesResponse) + + # Verify inline RAG context was injected + create_call = mock_client.responses.create.call_args_list[0] + input_text = create_call.kwargs.get("input", "") + assert "file_search found" in input_text + + # Verify tool RAG file_search tool is present + tools = create_call.kwargs.get("tools", []) + file_search_tools = [ + t + for t in tools + if (t.get("type") if isinstance(t, dict) else getattr(t, "type", None)) + == "file_search" + ] + assert len(file_search_tools) == 1 + + +# ============================================================================== +# Inline RAG rag_id Resolution Tests +# ============================================================================== + + +@pytest.mark.asyncio +async def test_responses_byok_inline_rag_only_configured_rag_id_is_queried( # pylint: disable=too-many-locals + test_config: AppConfig, + mocker: MockerFixture, + test_request: Request, +) -> None: + """Test that only the rag_id listed in rag.inline triggers retrieval. + + Two BYOK sources are registered (source-a and source-b) but only + source-a is listed in rag.inline. Only the vector_db_id for + source-a should be queried. + + Verifies: + - vector_io.query is called exactly once (for the configured source) + - The call targets the correct vector_db_id + """ + entry_a = mocker.MagicMock() + entry_a.rag_id = "source-a" + entry_a.vector_db_id = "vs-source-a" + entry_a.score_multiplier = 1.0 + + entry_b = mocker.MagicMock() + entry_b.rag_id = "source-b" + entry_b.vector_db_id = "vs-source-b" + entry_b.score_multiplier = 1.0 + + test_config.configuration.byok_rag = [entry_a, entry_b] + test_config.configuration.rag.inline = ["source-a"] + + mock_client = _build_responses_mock_client(mocker) + _patch_all_client_holders(mocker, mock_client) + + mock_client.vector_io.query = mocker.AsyncMock( + return_value=_make_byok_vector_io_response(mocker) + ) + + mock_vs_resp = mocker.MagicMock() + mock_vs_resp.data = [] + mock_client.vector_stores.list.return_value = mock_vs_resp + + responses_request = ResponsesRequest(input="What is OpenShift?", stream=False) + + response = await responses_endpoint_handler( + request=test_request, + responses_request=responses_request, + auth=MOCK_AUTH, + mcp_headers={}, + ) + + assert isinstance(response, ResponsesResponse) + + assert mock_client.vector_io.query.call_count == 1 + call_kwargs = mock_client.vector_io.query.call_args.kwargs + assert call_kwargs["vector_store_id"] == "vs-source-a" + + +# ============================================================================== +# Score Multiplier Priority Tests +# ============================================================================== + + +@pytest.mark.asyncio +async def test_responses_byok_score_multiplier_shifts_chunk_priority( # pylint: disable=too-many-locals + test_config: AppConfig, + mocker: MockerFixture, + test_request: Request, +) -> None: + """Test that score_multiplier can shift chunk priority across sources. + + Doc A (source-a) has high base similarity (0.90) with multiplier 1.0. + Doc B (source-b) has low base similarity (0.40) with multiplier 5.0. + After weighting: Doc A = 0.90, Doc B = 2.00. + Doc B should appear above Doc A in the final context. + + Verifies: + - The chunk with the higher weighted score appears first in the context + - score_multiplier correctly influences ranking + """ + entry_a = mocker.MagicMock() + entry_a.rag_id = "source-a" + entry_a.vector_db_id = "vs-source-a" + entry_a.score_multiplier = 1.0 + + entry_b = mocker.MagicMock() + entry_b.rag_id = "source-b" + entry_b.vector_db_id = "vs-source-b" + entry_b.score_multiplier = 5.0 + + test_config.configuration.byok_rag = [entry_a, entry_b] + test_config.configuration.rag.inline = ["source-a", "source-b"] + + mock_client = _build_responses_mock_client(mocker) + _patch_all_client_holders(mocker, mock_client) + + # Source A: high base similarity + resp_a = _make_vector_io_response( + mocker, + [ + ("Doc A content - high similarity", "doc-a", 0.90), + ], + ) + # Source B: low base similarity + resp_b = _make_vector_io_response( + mocker, + [ + ("Doc B content - low similarity", "doc-b", 0.40), + ], + ) + + # Return different results per vector store + async def _side_effect(**kwargs: Any) -> Any: + if kwargs["vector_store_id"] == "vs-source-a": + return resp_a + return resp_b + + mock_client.vector_io.query = mocker.AsyncMock(side_effect=_side_effect) + + mock_vs_resp = mocker.MagicMock() + mock_vs_resp.data = [] + mock_client.vector_stores.list.return_value = mock_vs_resp + + responses_request = ResponsesRequest(input="test query", stream=False) + + response = await responses_endpoint_handler( + request=test_request, + responses_request=responses_request, + auth=MOCK_AUTH, + mcp_headers={}, + ) + + assert isinstance(response, ResponsesResponse) + + # Doc B (weighted 2.0) should rank above Doc A (weighted 0.9) in the context + create_call = mock_client.responses.create.call_args_list[0] + input_text = create_call.kwargs.get("input", "") + assert "file_search found 2 chunks:" in input_text + + # Doc B (higher weighted score) should appear before Doc A in the context + pos_b = input_text.find("Doc B content - low similarity") + pos_a = input_text.find("Doc A content - high similarity") + assert pos_b != -1 and pos_a != -1 + assert ( + pos_b < pos_a + ), "Doc B should appear before Doc A due to higher weighted score" + + +# ============================================================================== +# RAG_CONTENT_LIMIT Capping Tests +# ============================================================================== + + +@pytest.mark.asyncio +async def test_responses_rag_content_limit_caps_retrieved_results( # pylint: disable=too-many-locals + test_config: AppConfig, + mocker: MockerFixture, + test_request: Request, +) -> None: + """Test that RAG_CONTENT_LIMIT caps the number of returned chunks. + + A single source returns more chunks than RAG_CONTENT_LIMIT allows. + The context sent to the LLM should contain at most RAG_CONTENT_LIMIT chunks. + + Verifies: + - Context chunk count does not exceed RAG_CONTENT_LIMIT + - Returned chunks are the top-scoring ones + """ + entry = mocker.MagicMock() + entry.rag_id = "big-source" + entry.vector_db_id = "vs-big-source" + entry.score_multiplier = 1.0 + + test_config.configuration.byok_rag = [entry] + test_config.configuration.rag.inline = ["big-source"] + test_config.configuration.reranker.enabled = False + + mock_client = _build_responses_mock_client(mocker) + _patch_all_client_holders(mocker, mock_client) + + # Generate more chunks than RAG_CONTENT_LIMIT + num_chunks = constants.RAG_CONTENT_LIMIT + 1 + chunks_data = [ + (f"Chunk content {i}", f"chunk-{i}", round(0.50 + i * 0.03, 2)) + for i in range(num_chunks) + ] + mock_client.vector_io.query = mocker.AsyncMock( + return_value=_make_vector_io_response(mocker, chunks_data) + ) + + mock_vs_resp = mocker.MagicMock() + mock_vs_resp.data = [] + mock_client.vector_stores.list.return_value = mock_vs_resp + + responses_request = ResponsesRequest(input="test query", stream=False) + + response = await responses_endpoint_handler( + request=test_request, + responses_request=responses_request, + auth=MOCK_AUTH, + mcp_headers={}, + ) + + assert isinstance(response, ResponsesResponse) + + create_call = mock_client.responses.create.call_args_list[0] + input_text = create_call.kwargs.get("input", "") + expected_header = f"file_search found {constants.RAG_CONTENT_LIMIT} chunks:" + assert expected_header in input_text + + # The highest-scored chunk should be present + assert f"Chunk content {num_chunks - 1}" in input_text + # The lowest-scored chunk should be excluded + assert "Chunk content 0" not in input_text + + +@pytest.mark.asyncio +async def test_responses_rag_content_limit_caps_across_multiple_sources( # pylint: disable=too-many-locals + test_config: AppConfig, + mocker: MockerFixture, + test_request: Request, +) -> None: + """Test that RAG_CONTENT_LIMIT caps chunks across multiple sources. + + Two sources each return several chunks. The combined result should not + exceed RAG_CONTENT_LIMIT and should contain the globally highest-scored + chunks regardless of source. + + Verifies: + - Total chunks across sources are capped at RAG_CONTENT_LIMIT + - Top-scoring chunks from both sources are included + """ + entry_a = mocker.MagicMock() + entry_a.rag_id = "source-a" + entry_a.vector_db_id = "vs-source-a" + entry_a.score_multiplier = 1.0 + + entry_b = mocker.MagicMock() + entry_b.rag_id = "source-b" + entry_b.vector_db_id = "vs-source-b" + entry_b.score_multiplier = 1.0 + + test_config.configuration.byok_rag = [entry_a, entry_b] + test_config.configuration.rag.inline = ["source-a", "source-b"] + + mock_client = _build_responses_mock_client(mocker) + _patch_all_client_holders(mocker, mock_client) + + # Overlapping score bands so top-k must pick from both sources + n = constants.RAG_CONTENT_LIMIT + resp_a = _make_vector_io_response( + mocker, + [ + (f"Source A chunk {i}", f"a-chunk-{i}", round(0.70 + i * 0.05, 2)) + for i in range(n) + ], + ) + resp_b = _make_vector_io_response( + mocker, + [ + (f"Source B chunk {i}", f"b-chunk-{i}", round(0.72 + i * 0.05, 2)) + for i in range(n) + ], + ) + + async def _side_effect(**kwargs: Any) -> Any: + if kwargs["vector_store_id"] == "vs-source-a": + return resp_a + return resp_b + + mock_client.vector_io.query = mocker.AsyncMock(side_effect=_side_effect) + + mock_vs_resp = mocker.MagicMock() + mock_vs_resp.data = [] + mock_client.vector_stores.list.return_value = mock_vs_resp + + responses_request = ResponsesRequest(input="test query", stream=False) + + response = await responses_endpoint_handler( + request=test_request, + responses_request=responses_request, + auth=MOCK_AUTH, + mcp_headers={}, + ) + + assert isinstance(response, ResponsesResponse) + + create_call = mock_client.responses.create.call_args_list[0] + input_text = create_call.kwargs.get("input", "") + expected_header = f"file_search found {constants.RAG_CONTENT_LIMIT} chunks:" + assert expected_header in input_text + + # Both sources should survive the cap (high-scoring chunks from each) + assert "Source A chunk" in input_text + assert "Source B chunk" in input_text + + # Lowest-scoring chunks from each source should be dropped + assert "Source A chunk 0" not in input_text + assert "Source B chunk 0" not in input_text + + +@pytest.mark.asyncio +async def test_responses_rag_content_limit_caps_inline_rag( # pylint: disable=too-many-locals + test_config: AppConfig, + mocker: MockerFixture, + test_request: Request, +) -> None: + """Test that RAG_CONTENT_LIMIT caps inline RAG below BYOK_RAG_MAX_CHUNKS. + + Sets RAG_CONTENT_LIMIT to 3 (below BYOK_RAG_MAX_CHUNKS=10) and feeds + 10 chunks. The context sent to the LLM should contain at most 3 chunks. + + Verifies: + - Context chunk count equals the lowered RAG_CONTENT_LIMIT + - Only the highest-scored chunks appear in the context + """ + mocker.patch("utils.vector_search.constants.RAG_CONTENT_LIMIT", 3) + + entry = mocker.MagicMock() + entry.rag_id = "big-source" + entry.vector_db_id = "vs-big-source" + entry.score_multiplier = 1.0 + + test_config.configuration.byok_rag = [entry] + test_config.configuration.rag.inline = ["big-source"] + test_config.configuration.reranker.enabled = False + + mock_client = _build_responses_mock_client(mocker) + _patch_all_client_holders(mocker, mock_client) + + num_chunks = constants.BYOK_RAG_MAX_CHUNKS + chunks_data = [ + (f"Chunk content {i}", f"chunk-{i}", round(0.50 + i * 0.03, 2)) + for i in range(num_chunks) + ] + mock_client.vector_io.query = mocker.AsyncMock( + return_value=_make_vector_io_response(mocker, chunks_data) + ) + + mock_vs_resp = mocker.MagicMock() + mock_vs_resp.data = [] + mock_client.vector_stores.list.return_value = mock_vs_resp + + responses_request = ResponsesRequest(input="test query", stream=False) + + response = await responses_endpoint_handler( + request=test_request, + responses_request=responses_request, + auth=MOCK_AUTH, + mcp_headers={}, + ) + + assert isinstance(response, ResponsesResponse) + + create_call = mock_client.responses.create.call_args_list[0] + input_text = create_call.kwargs.get("input", "") + expected_header = "file_search found 3 chunks:" + assert expected_header in input_text + + assert f"Chunk content {num_chunks - 1}" in input_text + assert "Chunk content 0" not in input_text From a18cf4ca70ebc60198d3553761648dd6ba7f7ce3 Mon Sep 17 00:00:00 2001 From: are-ces <195810094+are-ces@users.noreply.github.com> Date: Tue, 19 May 2026 09:36:23 +0200 Subject: [PATCH 2/2] LCORE-1333: Add RAG_CONTENT_LIMIT to cap inline RAG chunks Add RAG_CONTENT_LIMIT constant (default: 10) that caps the final merged BYOK + OKP output from build_rag_context. Per-source constants remain as fetch hints for the reranking pool. Tool RAG is unaffected. Co-Authored-By: Claude Opus 4.6 (1M context) --- docs/byok_guide.md | 3 + docs/rag_guide.md | 7 +- src/constants.py | 3 + src/utils/vector_search.py | 20 ++-- .../endpoints/test_query_byok_integration.py | 110 +++++++++++++++--- .../test_streaming_query_byok_integration.py | 103 +++++++++++++--- 6 files changed, 198 insertions(+), 48 deletions(-) diff --git a/docs/byok_guide.md b/docs/byok_guide.md index 5213a8d15..918a434cf 100644 --- a/docs/byok_guide.md +++ b/docs/byok_guide.md @@ -85,6 +85,9 @@ Inline RAG additionally supports: > `score_multiplier` does not apply to OKP results. To control the amount of retrieved > context, set the `BYOK_RAG_MAX_CHUNKS` and `OKP_RAG_MAX_CHUNKS` constants in `src/constants.py` > (defaults: 10 and 5 respectively). For Tool RAG, use `TOOL_RAG_MAX_CHUNKS` (default: 10). +> The `RAG_CONTENT_LIMIT` constant (default: 10) caps the final merged inline RAG +> chunks (BYOK + OKP) delivered to the LLM. Tool RAG is controlled independently +> by `TOOL_RAG_MAX_CHUNKS`. --- diff --git a/docs/rag_guide.md b/docs/rag_guide.md index 598272ea7..d33e24189 100644 --- a/docs/rag_guide.md +++ b/docs/rag_guide.md @@ -385,9 +385,10 @@ the number of retrieved chunks, set the constants in `src/constants.py`: | Constant | Default | Description | |----------|---------|-------------| -| `OKP_RAG_MAX_CHUNKS` | 5 | Max chunks retrieved from OKP (Inline RAG) | -| `BYOK_RAG_MAX_CHUNKS` | 10 | Max chunks retrieved from BYOK stores (Inline RAG) | -| `TOOL_RAG_MAX_CHUNKS` | 10 | Max chunks retrieved via Tool RAG (`file_search`) | +| `RAG_CONTENT_LIMIT` | 10 | Hard upper bound on the final merged inline RAG chunks (BYOK + OKP) delivered to the LLM | +| `OKP_RAG_MAX_CHUNKS` | 5 | Fetch hint for OKP (Inline RAG); controls how many chunks enter the reranking pool | +| `BYOK_RAG_MAX_CHUNKS` | 10 | Fetch hint for BYOK stores (Inline RAG); controls how many chunks enter the reranking pool | +| `TOOL_RAG_MAX_CHUNKS` | 10 | Max chunks retrieved via Tool RAG (`file_search`); independent from `RAG_CONTENT_LIMIT` | **Limitations:** diff --git a/src/constants.py b/src/constants.py index 88dd2aee5..46ef11cce 100644 --- a/src/constants.py +++ b/src/constants.py @@ -188,6 +188,9 @@ USER_QUOTA_LIMITER: Final[str] = "user_limiter" CLUSTER_QUOTA_LIMITER: Final[str] = "cluster_limiter" +# Hard cap on total RAG chunks delivered to the LLM across all sources +RAG_CONTENT_LIMIT: Final[int] = 10 + # RAG as a tool constants DEFAULT_RAG_TOOL: Final[str] = "file_search" TOOL_RAG_MAX_CHUNKS: Final[int] = 10 # retrieved from RAG as a tool diff --git a/src/utils/vector_search.py b/src/utils/vector_search.py index 3cd7d0b9a..6defd011a 100644 --- a/src/utils/vector_search.py +++ b/src/utils/vector_search.py @@ -609,9 +609,9 @@ async def build_rag_context( # pylint: disable=too-many-locals,too-many-branche ) -> RAGContext: """Build RAG context by fetching and merging chunks from all enabled sources. - Fetches 2 * BYOK_RAG_MAX_CHUNKS from each of BYOK and Solr, merges and keeps - top 2 * BYOK_RAG_MAX_CHUNKS by score, reranks with a cross-encoder, then - keeps the top BYOK_RAG_MAX_CHUNKS for context. Enabled sources can be BYOK + Each source fetches using its per-source limit to build the reranking pool. + Results are merged, sorted by score, reranked with a cross-encoder if + enabled, then capped at RAG_CONTENT_LIMIT. Enabled sources can be BYOK and/or Solr OKP. Args: @@ -626,12 +626,11 @@ async def build_rag_context( # pylint: disable=too-many-locals,too-many-branche if moderation_decision == "blocked": return RAGContext() - pool_size = 2 * constants.BYOK_RAG_MAX_CHUNKS - top_k = constants.BYOK_RAG_MAX_CHUNKS + top_k = constants.RAG_CONTENT_LIMIT - # Fetch 2*BYOK_RAG_MAX_CHUNKS from each source in parallel + # Fetch from each source using per-source limits for the reranking pool byok_chunks_task = _fetch_byok_rag( - client, query, vector_store_ids, max_chunks=pool_size + client, query, vector_store_ids, max_chunks=constants.BYOK_RAG_MAX_CHUNKS ) solr_chunks_task = _fetch_solr_rag(client, query, solr) @@ -639,21 +638,20 @@ async def build_rag_context( # pylint: disable=too-many-locals,too-many-branche byok_chunks_task, solr_chunks_task ) - # Merge: combine and sort by score, keep top 2*BYOK_RAG_MAX_CHUNKS + # Merge: combine and sort by score merged = byok_chunks + solr_chunks merged.sort( key=lambda c: c.score if c.score is not None else float("-inf"), reverse=True ) - merged = merged[:pool_size] - # Rerank full pool with cross-encoder if enabled; boost BYOK then take top_k + # Rerank full pool with cross-encoder if enabled; then take top_k if configuration.reranker.enabled: logger.info( "Reranker enabled: processing %d chunks with model '%s'", len(merged), configuration.reranker.model, ) - reranked = await rerank_chunks_with_cross_encoder(query, merged, pool_size) + reranked = await rerank_chunks_with_cross_encoder(query, merged, len(merged)) context_chunks = apply_byok_rerank_boost(reranked)[:top_k] logger.info( "Reranker completed: returned %d top chunks after BYOK boost", diff --git a/tests/integration/endpoints/test_query_byok_integration.py b/tests/integration/endpoints/test_query_byok_integration.py index bdf080489..589501eff 100644 --- a/tests/integration/endpoints/test_query_byok_integration.py +++ b/tests/integration/endpoints/test_query_byok_integration.py @@ -1066,25 +1066,25 @@ async def _side_effect(**kwargs: Any) -> Any: # ============================================================================== -# BYOK_RAG_MAX_CHUNKS Capping Tests +# RAG_CONTENT_LIMIT Capping Tests # ============================================================================== @pytest.mark.asyncio -async def test_query_byok_max_chunks_caps_retrieved_results( # pylint: disable=too-many-locals +async def test_query_rag_content_limit_caps_retrieved_results( # pylint: disable=too-many-locals test_config: AppConfig, mocker: MockerFixture, test_request: Request, test_auth: AuthTuple, ) -> None: - """Test that BYOK_RAG_MAX_CHUNKS caps the number of returned chunks. + """Test that RAG_CONTENT_LIMIT caps the number of returned chunks. - A single source returns more chunks than BYOK_RAG_MAX_CHUNKS allows. - The response should contain at most BYOK_RAG_MAX_CHUNKS chunks and + A single source returns more chunks than RAG_CONTENT_LIMIT allows. + The response should contain at most RAG_CONTENT_LIMIT chunks and they should be the highest-scored ones. Verifies: - - Number of RAG chunks does not exceed BYOK_RAG_MAX_CHUNKS + - Number of RAG chunks does not exceed RAG_CONTENT_LIMIT - Returned chunks are the top-scoring ones """ entry = mocker.MagicMock() @@ -1101,8 +1101,8 @@ async def test_query_byok_max_chunks_caps_retrieved_results( # pylint: disable= mock_holder_class = mocker.patch("app.endpoints.query.AsyncLlamaStackClientHolder") mock_client = _build_base_mock_client(mocker) - # Generate more chunks than BYOK_RAG_MAX_CHUNKS - num_chunks = constants.BYOK_RAG_MAX_CHUNKS + 1 + # Generate more chunks than RAG_CONTENT_LIMIT + num_chunks = constants.RAG_CONTENT_LIMIT + 1 chunks_data = [ (f"Chunk content {i}", f"chunk-{i}", round(0.50 + i * 0.03, 2)) for i in range(num_chunks) @@ -1141,7 +1141,7 @@ async def test_query_byok_max_chunks_caps_retrieved_results( # pylint: disable= ) assert response.rag_chunks is not None - assert len(response.rag_chunks) == constants.BYOK_RAG_MAX_CHUNKS + assert len(response.rag_chunks) == constants.RAG_CONTENT_LIMIT # Check that the score is computed properly for chunk in response.rag_chunks: @@ -1161,20 +1161,20 @@ async def test_query_byok_max_chunks_caps_retrieved_results( # pylint: disable= @pytest.mark.asyncio -async def test_query_byok_max_chunks_caps_across_multiple_sources( # pylint: disable=too-many-locals +async def test_query_rag_content_limit_caps_across_multiple_sources( # pylint: disable=too-many-locals test_config: AppConfig, mocker: MockerFixture, test_request: Request, test_auth: AuthTuple, ) -> None: - """Test that BYOK_RAG_MAX_CHUNKS caps chunks across multiple sources. + """Test that RAG_CONTENT_LIMIT caps chunks across multiple sources. - Two sources each return several chunks. The combined result should - not exceed BYOK_RAG_MAX_CHUNKS and should contain the globally - highest-scored chunks regardless of source. + Two sources each return several chunks. The combined result should not + exceed RAG_CONTENT_LIMIT and should contain the globally highest-scored + chunks regardless of source. Verifies: - - Total chunks across sources are capped at BYOK_RAG_MAX_CHUNKS + - Total chunks across sources are capped at RAG_CONTENT_LIMIT - Top-scoring chunks from both sources are included """ entry_a = mocker.MagicMock() @@ -1194,7 +1194,7 @@ async def test_query_byok_max_chunks_caps_across_multiple_sources( # pylint: di mock_client = _build_base_mock_client(mocker) # Overlapping score bands so top-k must pick from both sources - n = constants.BYOK_RAG_MAX_CHUNKS + n = constants.RAG_CONTENT_LIMIT resp_a = _make_vector_io_response( mocker, [ @@ -1246,7 +1246,7 @@ async def _side_effect(**kwargs: Any) -> Any: ) assert response.rag_chunks is not None - assert len(response.rag_chunks) == constants.BYOK_RAG_MAX_CHUNKS + assert len(response.rag_chunks) == constants.RAG_CONTENT_LIMIT # Check that the score is computed properly for chunk in response.rag_chunks: @@ -1266,3 +1266,79 @@ async def _side_effect(**kwargs: Any) -> Any: chunk_contents = {chunk.content for chunk in response.rag_chunks} assert "Source A chunk 0" not in chunk_contents assert "Source B chunk 0" not in chunk_contents + + +@pytest.mark.asyncio +async def test_query_rag_content_limit_caps_inline_rag( # pylint: disable=too-many-locals + test_config: AppConfig, + mocker: MockerFixture, + test_request: Request, + test_auth: AuthTuple, +) -> None: + """Test that RAG_CONTENT_LIMIT caps inline RAG below BYOK_RAG_MAX_CHUNKS. + + Sets RAG_CONTENT_LIMIT to 3 (below BYOK_RAG_MAX_CHUNKS=10) and feeds + 10 chunks. The result should be capped at 3. + + Verifies: + - Number of inline RAG chunks equals the lowered RAG_CONTENT_LIMIT + - Returned chunks are the top-scoring ones + """ + mocker.patch("utils.vector_search.constants.RAG_CONTENT_LIMIT", 3) + + entry = mocker.MagicMock() + entry.rag_id = "big-source" + entry.vector_db_id = "vs-big-source" + entry.score_multiplier = 1.0 + + test_config.configuration.byok_rag = [entry] + test_config.configuration.rag.inline = ["big-source"] + test_config.configuration.reranker.enabled = False + + mock_holder_class = mocker.patch("app.endpoints.query.AsyncLlamaStackClientHolder") + mock_client = _build_base_mock_client(mocker) + + num_chunks = constants.BYOK_RAG_MAX_CHUNKS + chunks_data = [ + (f"Chunk content {i}", f"chunk-{i}", round(0.50 + i * 0.03, 2)) + for i in range(num_chunks) + ] + mock_client.vector_io.query = mocker.AsyncMock( + return_value=_make_vector_io_response(mocker, chunks_data) + ) + + mock_vs_resp = mocker.MagicMock() + mock_vs_resp.data = [] + mock_client.vector_stores.list.return_value = mock_vs_resp + + mock_holder_class.return_value.get_client.return_value = mock_client + + query_request = QueryRequest( + query="test query", + conversation_id=None, + provider=None, + model=None, + system_prompt=None, + attachments=None, + no_tools=False, + generate_topic_summary=None, + media_type=None, + vector_store_ids=None, + shield_ids=None, + solr=None, + ) + + response = await query_endpoint_handler( + request=test_request, + query_request=query_request, + auth=test_auth, + mcp_headers={}, + ) + + assert response.rag_chunks is not None + assert len(response.rag_chunks) == 3 + + scores: list[float] = [ + chunk.score for chunk in response.rag_chunks if chunk.score is not None + ] + assert scores == sorted(scores, reverse=True) diff --git a/tests/integration/endpoints/test_streaming_query_byok_integration.py b/tests/integration/endpoints/test_streaming_query_byok_integration.py index 3cb5b878e..29c1737b1 100644 --- a/tests/integration/endpoints/test_streaming_query_byok_integration.py +++ b/tests/integration/endpoints/test_streaming_query_byok_integration.py @@ -908,24 +908,24 @@ async def _side_effect(**kwargs: Any) -> Any: # ============================================================================== -# BYOK_RAG_MAX_CHUNKS Capping Streaming Tests +# RAG_CONTENT_LIMIT Capping Streaming Tests # ============================================================================== @pytest.mark.asyncio -async def test_streaming_query_byok_max_chunks_caps_context( # pylint: disable=too-many-locals +async def test_streaming_query_rag_content_limit_caps_context( # pylint: disable=too-many-locals test_config: AppConfig, mocker: MockerFixture, test_request: Request, test_auth: AuthTuple, ) -> None: - """Test that BYOK_RAG_MAX_CHUNKS caps chunks in streaming query context. + """Test that RAG_CONTENT_LIMIT caps chunks in streaming query context. - A source returns more chunks than BYOK_RAG_MAX_CHUNKS. The injected - context should contain at most BYOK_RAG_MAX_CHUNKS chunk entries. + A source returns more chunks than RAG_CONTENT_LIMIT. The injected context + should contain at most RAG_CONTENT_LIMIT chunk entries. Verifies: - - Context chunk count does not exceed BYOK_RAG_MAX_CHUNKS + - Context chunk count does not exceed RAG_CONTENT_LIMIT - Only the highest-scored chunks appear in the context """ entry = mocker.MagicMock() @@ -941,8 +941,8 @@ async def test_streaming_query_byok_max_chunks_caps_context( # pylint: disable= ) mock_client = _build_base_streaming_mock_client(mocker) - # Generate more chunks than BYOK_RAG_MAX_CHUNKS - num_chunks = constants.BYOK_RAG_MAX_CHUNKS + 5 + # Generate more chunks than RAG_CONTENT_LIMIT + num_chunks = constants.RAG_CONTENT_LIMIT + 5 chunks_data = [ (f"Chunk content {i}", f"chunk-{i}", round(0.50 + i * 0.03, 2)) for i in range(num_chunks) @@ -973,7 +973,7 @@ async def test_streaming_query_byok_max_chunks_caps_context( # pylint: disable= # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. create_call = mock_client.responses.create.call_args_list[0] input_text = create_call.kwargs["input"] - expected_header = f"file_search found {constants.BYOK_RAG_MAX_CHUNKS} chunks:" + expected_header = f"file_search found {constants.RAG_CONTENT_LIMIT} chunks:" assert expected_header in input_text # The lowest-scoring chunk should NOT be in the context @@ -983,20 +983,20 @@ async def test_streaming_query_byok_max_chunks_caps_context( # pylint: disable= @pytest.mark.asyncio -async def test_streaming_query_byok_max_chunks_caps_across_multiple_sources( # pylint: disable=too-many-locals +async def test_streaming_query_rag_content_limit_caps_across_multiple_sources( # pylint: disable=too-many-locals test_config: AppConfig, mocker: MockerFixture, test_request: Request, test_auth: AuthTuple, ) -> None: - """Test that BYOK_RAG_MAX_CHUNKS caps chunks across multiple sources in streaming. + """Test that RAG_CONTENT_LIMIT caps chunks across multiple sources in streaming. - Two sources each return several chunks. The combined context should - not exceed BYOK_RAG_MAX_CHUNKS and should contain the globally - highest-scored chunks regardless of source. + Two sources each return several chunks. The combined context should not + exceed RAG_CONTENT_LIMIT and should contain the globally highest-scored + chunks regardless of source. Verifies: - - Total chunks across sources are capped at BYOK_RAG_MAX_CHUNKS + - Total chunks across sources are capped at RAG_CONTENT_LIMIT - Only the highest-scored chunks appear in the context """ entry_a = mocker.MagicMock() @@ -1018,7 +1018,7 @@ async def test_streaming_query_byok_max_chunks_caps_across_multiple_sources( # mock_client = _build_base_streaming_mock_client(mocker) # Overlapping score bands so top-k must pick from both sources - n = constants.BYOK_RAG_MAX_CHUNKS + n = constants.RAG_CONTENT_LIMIT resp_a = _make_vector_io_response( mocker, [ @@ -1062,7 +1062,7 @@ async def _side_effect(**kwargs: Any) -> Any: # .kwargs holds its keyword arguments, e.g. "input" is the full prompt text sent to the model. create_call = mock_client.responses.create.call_args_list[0] input_text = create_call.kwargs["input"] - expected_header = f"file_search found {constants.BYOK_RAG_MAX_CHUNKS} chunks:" + expected_header = f"file_search found {constants.RAG_CONTENT_LIMIT} chunks:" assert expected_header in input_text # Both sources must appear in the context (overlapping scores guarantee this) @@ -1072,3 +1072,72 @@ async def _side_effect(**kwargs: Any) -> Any: # Lowest-scoring chunks from each source must be dropped assert "Source A chunk 0" not in input_text assert "Source B chunk 0" not in input_text + + +@pytest.mark.asyncio +async def test_streaming_query_rag_content_limit_caps_inline_rag( # pylint: disable=too-many-locals + test_config: AppConfig, + mocker: MockerFixture, + test_request: Request, + test_auth: AuthTuple, +) -> None: + """Test that RAG_CONTENT_LIMIT caps inline RAG below BYOK_RAG_MAX_CHUNKS in streaming. + + Sets RAG_CONTENT_LIMIT to 3 (below BYOK_RAG_MAX_CHUNKS=10) and feeds + 10 chunks. The context should contain at most 3 chunk entries. + + Verifies: + - Context chunk count equals the lowered RAG_CONTENT_LIMIT + - Only the highest-scored chunks appear in the context + """ + mocker.patch("utils.vector_search.constants.RAG_CONTENT_LIMIT", 3) + + entry = mocker.MagicMock() + entry.rag_id = "big-source" + entry.vector_db_id = "vs-big-source" + entry.score_multiplier = 1.0 + + test_config.configuration.byok_rag = [entry] + test_config.configuration.rag.inline = ["big-source"] + test_config.configuration.reranker.enabled = False + + mock_holder_class = mocker.patch( + "app.endpoints.streaming_query.AsyncLlamaStackClientHolder" + ) + mock_client = _build_base_streaming_mock_client(mocker) + + num_chunks = constants.BYOK_RAG_MAX_CHUNKS + chunks_data = [ + (f"Chunk content {i}", f"chunk-{i}", round(0.50 + i * 0.03, 2)) + for i in range(num_chunks) + ] + mock_client.vector_io.query = mocker.AsyncMock( + return_value=_make_vector_io_response(mocker, chunks_data) + ) + + mock_vs_resp = mocker.MagicMock() + mock_vs_resp.data = [] + mock_client.vector_stores.list.return_value = mock_vs_resp + + mock_holder_class.return_value.get_client.return_value = mock_client + + query_request = QueryRequest(query="test query") + + response = await streaming_query_endpoint_handler( + request=test_request, + query_request=query_request, + auth=test_auth, + mcp_headers={}, + ) + + assert isinstance(response, StreamingResponse) + + create_call = mock_client.responses.create.call_args_list[0] + input_text = create_call.kwargs["input"] + expected_header = "file_search found 3 chunks:" + assert expected_header in input_text + + # The highest-scoring chunk should be in the context + assert f"Chunk content {num_chunks - 1}" in input_text + # Low-scoring chunks should be excluded + assert "Chunk content 0" not in input_text