Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions src/google/adk/plugins/bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1746,8 +1746,10 @@ def _get_events_schema() -> list[bigquery.SchemaField]:
"A JSON object containing arbitrary key-value pairs for"
" additional event metadata. Includes enrichment fields like"
" 'root_agent_name' (turn orchestration), 'model' (request"
" model), 'model_version' (response version), and"
" 'usage_metadata' (detailed token counts)."
" model), 'model_version' (response version),"
" 'usage_metadata' (detailed token counts), and"
" 'finish_reason' (LLM_RESPONSE termination reason, e.g."
" 'STOP', 'MAX_TOKENS', 'SAFETY', 'MALFORMED_FUNCTION_CALL')."
),
),
bigquery.SchemaField(
Expand Down Expand Up @@ -1847,6 +1849,7 @@ def _get_events_schema() -> list[bigquery.SchemaField]:
"JSON_VALUE(attributes, '$.model_version') AS model_version",
"JSON_QUERY(attributes, '$.usage_metadata') AS usage_metadata",
"JSON_QUERY(attributes, '$.cache_metadata') AS cache_metadata",
"JSON_VALUE(attributes, '$.finish_reason') AS finish_reason",
],
"LLM_ERROR": [
"CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms",
Expand Down Expand Up @@ -1958,6 +1961,7 @@ class EventData:
model_version: Optional[str] = None
usage_metadata: Any = None
cache_metadata: Any = None
finish_reason: Optional[str] = None
status: str = "OK"
error_message: Optional[str] = None
extra_attributes: dict[str, Any] = field(default_factory=dict)
Expand Down Expand Up @@ -2821,6 +2825,9 @@ def _enrich_attributes(
else:
attrs["cache_metadata"] = event_data.cache_metadata

if event_data.finish_reason is not None:
attrs["finish_reason"] = event_data.finish_reason

if self.config.log_session_metadata:
try:
session = callback_context._invocation_context.session
Expand Down Expand Up @@ -3417,6 +3424,11 @@ async def after_model_callback(
# Otherwise log_event will fetch current stack (which is parent).
span_id = popped_span_id or span_id

raw_finish_reason = getattr(llm_response, "finish_reason", None)
finish_reason = (
raw_finish_reason.name if raw_finish_reason is not None else None
)

await self._log_event(
"LLM_RESPONSE",
callback_context,
Expand All @@ -3428,6 +3440,7 @@ async def after_model_callback(
model_version=llm_response.model_version,
usage_metadata=llm_response.usage_metadata,
cache_metadata=getattr(llm_response, "cache_metadata", None),
finish_reason=finish_reason,
span_id_override=span_id if is_popped else None,
parent_span_id_override=(parent_span_id if is_popped else None),
),
Expand Down
78 changes: 78 additions & 0 deletions tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1530,6 +1530,60 @@ async def test_after_model_callback_text_response(
# In this test we didn't pass it in kwargs in the updated call above, so it might be missing unless we add it back to kwargs.
# The original test passed it as kwarg.

@pytest.mark.asyncio
async def test_after_model_callback_projects_finish_reason(
    self,
    bq_plugin_inst,
    mock_write_client,
    callback_context,
    dummy_arrow_schema,
):
  """Verifies that a set LlmResponse.finish_reason lands in the attributes JSON.

  Builds a response whose finish_reason is MAX_TOKENS, routes it through
  after_model_callback, and checks the captured BigQuery event row carries
  the enum's name as a string under the 'finish_reason' attribute key.
  """
  truncated_response = llm_response_lib.LlmResponse(
      content=types.Content(parts=[types.Part(text="Truncated")]),
      usage_metadata=types.UsageMetadata(
          prompt_token_count=10, total_token_count=15
      ),
      finish_reason=types.FinishReason.MAX_TOKENS,
  )
  # A span must exist on the stack so the LLM_RESPONSE event can pop it.
  bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context)

  await bq_plugin_inst.after_model_callback(
      callback_context=callback_context,
      llm_response=truncated_response,
  )
  # Give the background writer a moment to flush the event.
  await asyncio.sleep(0.01)

  event_row = await _get_captured_event_dict_async(
      mock_write_client, dummy_arrow_schema
  )
  _assert_common_fields(event_row, "LLM_RESPONSE")
  parsed_attrs = json.loads(event_row["attributes"])
  assert parsed_attrs["finish_reason"] == "MAX_TOKENS"

@pytest.mark.asyncio
async def test_after_model_callback_omits_finish_reason_when_absent(
    self,
    bq_plugin_inst,
    mock_write_client,
    callback_context,
    dummy_arrow_schema,
):
  """Verifies no 'finish_reason' attribute is emitted when the response has none.

  Sends a response constructed without finish_reason through
  after_model_callback and asserts the key is entirely absent from the
  captured event's attributes JSON (rather than present with a null value).
  """
  plain_response = llm_response_lib.LlmResponse(
      content=types.Content(parts=[types.Part(text="ok")]),
  )
  # A span must exist on the stack so the LLM_RESPONSE event can pop it.
  bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context)

  await bq_plugin_inst.after_model_callback(
      callback_context=callback_context,
      llm_response=plain_response,
  )
  # Give the background writer a moment to flush the event.
  await asyncio.sleep(0.01)

  event_row = await _get_captured_event_dict_async(
      mock_write_client, dummy_arrow_schema
  )
  _assert_common_fields(event_row, "LLM_RESPONSE")
  parsed_attrs = json.loads(event_row["attributes"])
  assert "finish_reason" not in parsed_attrs

@pytest.mark.asyncio
async def test_after_model_callback_tool_call(
self,
Expand Down Expand Up @@ -3755,6 +3809,30 @@ def test_usage_metadata_truncated(self):
"output_tokens": 50,
}

def test_finish_reason_included_when_set(self):
  """_enrich_attributes copies a non-None EventData.finish_reason into attrs."""
  plugin = self._make_plugin()
  event_data = bigquery_agent_analytics_plugin.EventData(
      finish_reason="MAX_TOKENS"
  )
  # Stub out the root-agent lookup so enrichment does not need a real trace.
  patcher = mock.patch.object(
      bigquery_agent_analytics_plugin.TraceManager,
      "get_root_agent_name",
      return_value="agent",
  )
  with patcher:
    enriched = plugin._enrich_attributes(
        event_data, self._make_callback_context()
    )
  assert enriched["finish_reason"] == "MAX_TOKENS"

def test_finish_reason_omitted_when_none(self):
  """_enrich_attributes leaves 'finish_reason' out when EventData has None."""
  plugin = self._make_plugin()
  event_data = bigquery_agent_analytics_plugin.EventData()
  # Stub out the root-agent lookup so enrichment does not need a real trace.
  patcher = mock.patch.object(
      bigquery_agent_analytics_plugin.TraceManager,
      "get_root_agent_name",
      return_value="agent",
  )
  with patcher:
    enriched = plugin._enrich_attributes(
        event_data, self._make_callback_context()
    )
  assert "finish_reason" not in enriched


class TestMultiSubagentToolLogging:
"""Tests that tool events from different subagents are attributed correctly.
Expand Down
Loading