diff --git a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py index 3a09fc942f..91940f1d22 100644 --- a/src/google/adk/plugins/bigquery_agent_analytics_plugin.py +++ b/src/google/adk/plugins/bigquery_agent_analytics_plugin.py @@ -1746,8 +1746,10 @@ def _get_events_schema() -> list[bigquery.SchemaField]: "A JSON object containing arbitrary key-value pairs for" " additional event metadata. Includes enrichment fields like" " 'root_agent_name' (turn orchestration), 'model' (request" - " model), 'model_version' (response version), and" - " 'usage_metadata' (detailed token counts)." + " model), 'model_version' (response version)," + " 'usage_metadata' (detailed token counts), and" + " 'finish_reason' (LLM_RESPONSE termination reason, e.g." + " 'STOP', 'MAX_TOKENS', 'SAFETY', 'MALFORMED_FUNCTION_CALL')." ), ), bigquery.SchemaField( @@ -1847,6 +1849,7 @@ def _get_events_schema() -> list[bigquery.SchemaField]: "JSON_VALUE(attributes, '$.model_version') AS model_version", "JSON_QUERY(attributes, '$.usage_metadata') AS usage_metadata", "JSON_QUERY(attributes, '$.cache_metadata') AS cache_metadata", + "JSON_VALUE(attributes, '$.finish_reason') AS finish_reason", ], "LLM_ERROR": [ "CAST(JSON_VALUE(latency_ms, '$.total_ms') AS INT64) AS total_ms", @@ -1958,6 +1961,7 @@ class EventData: model_version: Optional[str] = None usage_metadata: Any = None cache_metadata: Any = None + finish_reason: Optional[str] = None status: str = "OK" error_message: Optional[str] = None extra_attributes: dict[str, Any] = field(default_factory=dict) @@ -2821,6 +2825,9 @@ def _enrich_attributes( else: attrs["cache_metadata"] = event_data.cache_metadata + if event_data.finish_reason is not None: + attrs["finish_reason"] = event_data.finish_reason + if self.config.log_session_metadata: try: session = callback_context._invocation_context.session @@ -3417,6 +3424,11 @@ async def after_model_callback( # Otherwise log_event will fetch current stack (which is parent). span_id = popped_span_id or span_id + raw_finish_reason = getattr(llm_response, "finish_reason", None) + finish_reason = ( + raw_finish_reason.name if raw_finish_reason is not None else None + ) + await self._log_event( "LLM_RESPONSE", callback_context, @@ -3428,6 +3440,7 @@ async def after_model_callback( model_version=llm_response.model_version, usage_metadata=llm_response.usage_metadata, cache_metadata=getattr(llm_response, "cache_metadata", None), + finish_reason=finish_reason, span_id_override=span_id if is_popped else None, parent_span_id_override=(parent_span_id if is_popped else None), ), diff --git a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py index 5719adf2b4..85bd599c4a 100644 --- a/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py +++ b/tests/unittests/plugins/test_bigquery_agent_analytics_plugin.py @@ -1530,6 +1530,60 @@ async def test_after_model_callback_text_response( # In this test we didn't pass it in kwargs in the updated call above, so it might be missing unless we add it back to kwargs. # The original test passed it as kwarg. + @pytest.mark.asyncio + async def test_after_model_callback_projects_finish_reason( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + dummy_arrow_schema, + ): + """finish_reason from LlmResponse is projected to attributes JSON.""" + llm_response = llm_response_lib.LlmResponse( + content=types.Content(parts=[types.Part(text="Truncated")]), + usage_metadata=types.UsageMetadata( + prompt_token_count=10, total_token_count=15 + ), + finish_reason=types.FinishReason.MAX_TOKENS, + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context) + await bq_plugin_inst.after_model_callback( + callback_context=callback_context, + llm_response=llm_response, + ) + await asyncio.sleep(0.01) + log_entry = await _get_captured_event_dict_async( + mock_write_client, dummy_arrow_schema + ) + _assert_common_fields(log_entry, "LLM_RESPONSE") + attrs = json.loads(log_entry["attributes"]) + assert attrs["finish_reason"] == "MAX_TOKENS" + + @pytest.mark.asyncio + async def test_after_model_callback_omits_finish_reason_when_absent( + self, + bq_plugin_inst, + mock_write_client, + callback_context, + dummy_arrow_schema, + ): + """When LlmResponse.finish_reason is None, the key is not emitted.""" + llm_response = llm_response_lib.LlmResponse( + content=types.Content(parts=[types.Part(text="ok")]), + ) + bigquery_agent_analytics_plugin.TraceManager.push_span(callback_context) + await bq_plugin_inst.after_model_callback( + callback_context=callback_context, + llm_response=llm_response, + ) + await asyncio.sleep(0.01) + log_entry = await _get_captured_event_dict_async( + mock_write_client, dummy_arrow_schema + ) + _assert_common_fields(log_entry, "LLM_RESPONSE") + attrs = json.loads(log_entry["attributes"]) + assert "finish_reason" not in attrs + @pytest.mark.asyncio async def test_after_model_callback_tool_call( self, @@ -3755,6 +3809,30 @@ def test_usage_metadata_truncated(self): "output_tokens": 50, } + def test_finish_reason_included_when_set(self): + """Should include finish_reason in attributes when set on EventData.""" + plugin = self._make_plugin() + ed = bigquery_agent_analytics_plugin.EventData(finish_reason="MAX_TOKENS") + with mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_root_agent_name", + return_value="agent", + ): + attrs = plugin._enrich_attributes(ed, self._make_callback_context()) + assert attrs["finish_reason"] == "MAX_TOKENS" + + def test_finish_reason_omitted_when_none(self): + """Should not add finish_reason key when EventData.finish_reason is None.""" + plugin = self._make_plugin() + ed = bigquery_agent_analytics_plugin.EventData() + with mock.patch.object( + bigquery_agent_analytics_plugin.TraceManager, + "get_root_agent_name", + return_value="agent", + ): + attrs = plugin._enrich_attributes(ed, self._make_callback_context()) + assert "finish_reason" not in attrs + class TestMultiSubagentToolLogging: """Tests that tool events from different subagents are attributed correctly.