Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/byok_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ Inline RAG additionally supports:
> `score_multiplier` does not apply to OKP results. To control the amount of retrieved
> context, set the `BYOK_RAG_MAX_CHUNKS` and `OKP_RAG_MAX_CHUNKS` constants in `src/constants.py`
> (defaults: 10 and 5 respectively). For Tool RAG, use `TOOL_RAG_MAX_CHUNKS` (default: 10).
> The `RAG_CONTENT_LIMIT` constant (default: 10) caps the number of merged inline RAG
> chunks (BYOK + OKP combined) that are delivered to the LLM. Tool RAG is not affected
> by this limit; it is controlled independently by `TOOL_RAG_MAX_CHUNKS`.

---

Expand Down
7 changes: 4 additions & 3 deletions docs/rag_guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -385,9 +385,10 @@ the number of retrieved chunks, set the constants in `src/constants.py`:

| Constant | Default | Description |
|----------|---------|-------------|
| `OKP_RAG_MAX_CHUNKS` | 5 | Max chunks retrieved from OKP (Inline RAG) |
| `BYOK_RAG_MAX_CHUNKS` | 10 | Max chunks retrieved from BYOK stores (Inline RAG) |
| `TOOL_RAG_MAX_CHUNKS` | 10 | Max chunks retrieved via Tool RAG (`file_search`) |
| `RAG_CONTENT_LIMIT` | 10 | Hard upper bound on the final merged inline RAG chunks (BYOK + OKP) delivered to the LLM |
| `OKP_RAG_MAX_CHUNKS` | 5 | Fetch hint for OKP (Inline RAG); controls how many chunks enter the reranking pool |
| `BYOK_RAG_MAX_CHUNKS` | 10 | Fetch hint for BYOK stores (Inline RAG); controls how many chunks enter the reranking pool |
| `TOOL_RAG_MAX_CHUNKS` | 10 | Max chunks retrieved via Tool RAG (`file_search`); independent from `RAG_CONTENT_LIMIT` |

**Limitations:**

Expand Down
3 changes: 3 additions & 0 deletions src/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,9 @@
USER_QUOTA_LIMITER: Final[str] = "user_limiter"
CLUSTER_QUOTA_LIMITER: Final[str] = "cluster_limiter"

# Hard cap on the merged inline RAG chunks (BYOK + OKP) delivered to the LLM.
# Tool RAG is limited separately by TOOL_RAG_MAX_CHUNKS.
RAG_CONTENT_LIMIT: Final[int] = 10

# RAG as a tool constants
DEFAULT_RAG_TOOL: Final[str] = "file_search"
TOOL_RAG_MAX_CHUNKS: Final[int] = 10 # retrieved from RAG as a tool
Expand Down
20 changes: 9 additions & 11 deletions src/utils/vector_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,9 +609,9 @@ async def build_rag_context( # pylint: disable=too-many-locals,too-many-branche
) -> RAGContext:
"""Build RAG context by fetching and merging chunks from all enabled sources.

Fetches 2 * BYOK_RAG_MAX_CHUNKS from each of BYOK and Solr, merges and keeps
top 2 * BYOK_RAG_MAX_CHUNKS by score, reranks with a cross-encoder, then
keeps the top BYOK_RAG_MAX_CHUNKS for context. Enabled sources can be BYOK
Each source fetches using its per-source limit to build the reranking pool.
Results are merged, sorted by score, reranked with a cross-encoder if
enabled, then capped at RAG_CONTENT_LIMIT. Enabled sources can be BYOK
and/or Solr OKP.

Args:
Expand All @@ -626,34 +626,32 @@ async def build_rag_context( # pylint: disable=too-many-locals,too-many-branche
if moderation_decision == "blocked":
return RAGContext()

pool_size = 2 * constants.BYOK_RAG_MAX_CHUNKS
top_k = constants.BYOK_RAG_MAX_CHUNKS
top_k = constants.RAG_CONTENT_LIMIT

# Fetch 2*BYOK_RAG_MAX_CHUNKS from each source in parallel
# Fetch from each source using per-source limits for the reranking pool
byok_chunks_task = _fetch_byok_rag(
client, query, vector_store_ids, max_chunks=pool_size
client, query, vector_store_ids, max_chunks=constants.BYOK_RAG_MAX_CHUNKS
)
solr_chunks_task = _fetch_solr_rag(client, query, solr)

(byok_chunks, byok_documents), (solr_chunks, solr_documents) = await asyncio.gather(
byok_chunks_task, solr_chunks_task
)

# Merge: combine and sort by score, keep top 2*BYOK_RAG_MAX_CHUNKS
# Merge: combine and sort by score
merged = byok_chunks + solr_chunks
merged.sort(
key=lambda c: c.score if c.score is not None else float("-inf"), reverse=True
)
merged = merged[:pool_size]

# Rerank full pool with cross-encoder if enabled; boost BYOK then take top_k
# Rerank full pool with cross-encoder if enabled; apply the BYOK boost, then take top_k
if configuration.reranker.enabled:
logger.info(
"Reranker enabled: processing %d chunks with model '%s'",
len(merged),
configuration.reranker.model,
)
reranked = await rerank_chunks_with_cross_encoder(query, merged, pool_size)
reranked = await rerank_chunks_with_cross_encoder(query, merged, len(merged))
context_chunks = apply_byok_rerank_boost(reranked)[:top_k]
logger.info(
"Reranker completed: returned %d top chunks after BYOK boost",
Expand Down
110 changes: 93 additions & 17 deletions tests/integration/endpoints/test_query_byok_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,25 +1066,25 @@ async def _side_effect(**kwargs: Any) -> Any:


# ==============================================================================
# BYOK_RAG_MAX_CHUNKS Capping Tests
# RAG_CONTENT_LIMIT Capping Tests
# ==============================================================================


@pytest.mark.asyncio
async def test_query_byok_max_chunks_caps_retrieved_results( # pylint: disable=too-many-locals
async def test_query_rag_content_limit_caps_retrieved_results( # pylint: disable=too-many-locals
test_config: AppConfig,
mocker: MockerFixture,
test_request: Request,
test_auth: AuthTuple,
) -> None:
"""Test that BYOK_RAG_MAX_CHUNKS caps the number of returned chunks.
"""Test that RAG_CONTENT_LIMIT caps the number of returned chunks.

A single source returns more chunks than BYOK_RAG_MAX_CHUNKS allows.
The response should contain at most BYOK_RAG_MAX_CHUNKS chunks and
A single source returns more chunks than RAG_CONTENT_LIMIT allows.
The response should contain at most RAG_CONTENT_LIMIT chunks and
they should be the highest-scored ones.

Verifies:
- Number of RAG chunks does not exceed BYOK_RAG_MAX_CHUNKS
- Number of RAG chunks does not exceed RAG_CONTENT_LIMIT
- Returned chunks are the top-scoring ones
"""
entry = mocker.MagicMock()
Expand All @@ -1101,8 +1101,8 @@ async def test_query_byok_max_chunks_caps_retrieved_results( # pylint: disable=
mock_holder_class = mocker.patch("app.endpoints.query.AsyncLlamaStackClientHolder")
mock_client = _build_base_mock_client(mocker)

# Generate more chunks than BYOK_RAG_MAX_CHUNKS
num_chunks = constants.BYOK_RAG_MAX_CHUNKS + 1
# Generate more chunks than RAG_CONTENT_LIMIT
num_chunks = constants.RAG_CONTENT_LIMIT + 1
chunks_data = [
(f"Chunk content {i}", f"chunk-{i}", round(0.50 + i * 0.03, 2))
for i in range(num_chunks)
Expand Down Expand Up @@ -1141,7 +1141,7 @@ async def test_query_byok_max_chunks_caps_retrieved_results( # pylint: disable=
)

assert response.rag_chunks is not None
assert len(response.rag_chunks) == constants.BYOK_RAG_MAX_CHUNKS
assert len(response.rag_chunks) == constants.RAG_CONTENT_LIMIT

# Check that the score is computed properly
for chunk in response.rag_chunks:
Expand All @@ -1161,20 +1161,20 @@ async def test_query_byok_max_chunks_caps_retrieved_results( # pylint: disable=


@pytest.mark.asyncio
async def test_query_byok_max_chunks_caps_across_multiple_sources( # pylint: disable=too-many-locals
async def test_query_rag_content_limit_caps_across_multiple_sources( # pylint: disable=too-many-locals
test_config: AppConfig,
mocker: MockerFixture,
test_request: Request,
test_auth: AuthTuple,
) -> None:
"""Test that BYOK_RAG_MAX_CHUNKS caps chunks across multiple sources.
"""Test that RAG_CONTENT_LIMIT caps chunks across multiple sources.

Two sources each return several chunks. The combined result should
not exceed BYOK_RAG_MAX_CHUNKS and should contain the globally
highest-scored chunks regardless of source.
Two sources each return several chunks. The combined result should not
exceed RAG_CONTENT_LIMIT and should contain the globally highest-scored
chunks regardless of source.

Verifies:
- Total chunks across sources are capped at BYOK_RAG_MAX_CHUNKS
- Total chunks across sources are capped at RAG_CONTENT_LIMIT
- Top-scoring chunks from both sources are included
"""
entry_a = mocker.MagicMock()
Expand All @@ -1194,7 +1194,7 @@ async def test_query_byok_max_chunks_caps_across_multiple_sources( # pylint: di
mock_client = _build_base_mock_client(mocker)

# Overlapping score bands so top-k must pick from both sources
n = constants.BYOK_RAG_MAX_CHUNKS
n = constants.RAG_CONTENT_LIMIT
resp_a = _make_vector_io_response(
mocker,
[
Expand Down Expand Up @@ -1246,7 +1246,7 @@ async def _side_effect(**kwargs: Any) -> Any:
)

assert response.rag_chunks is not None
assert len(response.rag_chunks) == constants.BYOK_RAG_MAX_CHUNKS
assert len(response.rag_chunks) == constants.RAG_CONTENT_LIMIT

# Check that the score is computed properly
for chunk in response.rag_chunks:
Expand All @@ -1266,3 +1266,79 @@ async def _side_effect(**kwargs: Any) -> Any:
chunk_contents = {chunk.content for chunk in response.rag_chunks}
assert "Source A chunk 0" not in chunk_contents
assert "Source B chunk 0" not in chunk_contents


@pytest.mark.asyncio
async def test_query_rag_content_limit_caps_inline_rag(  # pylint: disable=too-many-locals
    test_config: AppConfig,
    mocker: MockerFixture,
    test_request: Request,
    test_auth: AuthTuple,
) -> None:
    """Test that RAG_CONTENT_LIMIT caps inline RAG below BYOK_RAG_MAX_CHUNKS.

    Lowers RAG_CONTENT_LIMIT to a value below BYOK_RAG_MAX_CHUNKS and feeds
    BYOK_RAG_MAX_CHUNKS chunks from a single BYOK source with the reranker
    disabled. Only the lowered number of chunks may reach the response.

    Verifies:
    - Number of inline RAG chunks equals the lowered RAG_CONTENT_LIMIT
    - Returned chunks are exactly the top-scoring ones
    - Returned chunks are ordered by descending score
    """
    # Single source of truth for the lowered cap: used for the patch and for
    # every assertion below so the two can never drift apart.
    lowered_limit = 3
    mocker.patch("utils.vector_search.constants.RAG_CONTENT_LIMIT", lowered_limit)

    entry = mocker.MagicMock()
    entry.rag_id = "big-source"
    entry.vector_db_id = "vs-big-source"
    entry.score_multiplier = 1.0

    test_config.configuration.byok_rag = [entry]
    test_config.configuration.rag.inline = ["big-source"]
    test_config.configuration.reranker.enabled = False

    mock_holder_class = mocker.patch("app.endpoints.query.AsyncLlamaStackClientHolder")
    mock_client = _build_base_mock_client(mocker)

    num_chunks = constants.BYOK_RAG_MAX_CHUNKS
    # Guard against the test silently becoming vacuous if the per-source
    # limit is ever lowered to (or below) the cap under test.
    assert (
        num_chunks > lowered_limit
    ), "test needs more chunks than the lowered RAG_CONTENT_LIMIT"

    # Scores strictly increase with the index, so the highest-scored chunks
    # are the last `lowered_limit` entries.
    chunks_data = [
        (f"Chunk content {i}", f"chunk-{i}", round(0.50 + i * 0.03, 2))
        for i in range(num_chunks)
    ]
    expected_contents = {
        f"Chunk content {i}" for i in range(num_chunks - lowered_limit, num_chunks)
    }
    mock_client.vector_io.query = mocker.AsyncMock(
        return_value=_make_vector_io_response(mocker, chunks_data)
    )

    mock_vs_resp = mocker.MagicMock()
    mock_vs_resp.data = []
    mock_client.vector_stores.list.return_value = mock_vs_resp

    mock_holder_class.return_value.get_client.return_value = mock_client

    query_request = QueryRequest(
        query="test query",
        conversation_id=None,
        provider=None,
        model=None,
        system_prompt=None,
        attachments=None,
        no_tools=False,
        generate_topic_summary=None,
        media_type=None,
        vector_store_ids=None,
        shield_ids=None,
        solr=None,
    )

    response = await query_endpoint_handler(
        request=test_request,
        query_request=query_request,
        auth=test_auth,
        mcp_headers={},
    )

    assert response.rag_chunks is not None
    assert len(response.rag_chunks) == lowered_limit

    # The cap must keep the best chunks, not just any `lowered_limit` chunks.
    returned_contents = {chunk.content for chunk in response.rag_chunks}
    assert returned_contents == expected_contents

    scores: list[float] = [
        chunk.score for chunk in response.rag_chunks if chunk.score is not None
    ]
    assert scores == sorted(scores, reverse=True)
Loading
Loading