
Commit 9695978

google-genai-bot authored and copybara-github committed
feat: Trigger traceCallLlm to set call_llm attributes before span ends
PiperOrigin-RevId: 879435432
1 parent 3c8f488 commit 9695978

3 files changed: 71 additions & 58 deletions


core/src/main/java/com/google/adk/flows/llmflows/BaseLlmFlow.java

Lines changed: 10 additions & 8 deletions
@@ -188,20 +188,22 @@ private Flowable<LlmResponse> callLlm(
                         context, llmRequestBuilder, eventForCallbackUsage, exception)
                     .switchIfEmpty(Single.error(exception))
                     .toFlowable())
-        .doOnNext(
-            llmResp ->
-                Tracing.traceCallLlm(
-                    context,
-                    eventForCallbackUsage.id(),
-                    llmRequestBuilder.build(),
-                    llmResp))
         .doOnError(
             error -> {
               Span span = Span.current();
               span.setStatus(StatusCode.ERROR, error.getMessage());
               span.recordException(error);
             })
-        .compose(Tracing.trace("call_llm"))
+        .compose(
+            Tracing.<LlmResponse>trace("call_llm")
+                .onSuccess(
+                    (span, llmResp) ->
+                        Tracing.traceCallLlm(
+                            span,
+                            context,
+                            eventForCallbackUsage.id(),
+                            llmRequestBuilder.build(),
+                            llmResp)))
         .concatMap(
             llmResp ->
                 handleAfterModelCallback(context, llmResp, eventForCallbackUsage)
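
Why this reordering matters: the old chain invoked Tracing.traceCallLlm from a doOnNext placed upstream of the compose(Tracing.trace("call_llm")) transformer, and traceCallLlm resolved its span via Span.current(). Per the commit title, the call_llm attributes could therefore be written after the span had already ended, and OpenTelemetry silently drops writes to an ended span. A minimal standalone sketch of that failure mode, using the OpenTelemetry SDK's in-memory test exporter (EndedSpanDemo and the tracer name "demo" are illustrative, not ADK code):

import io.opentelemetry.api.trace.Span;
import io.opentelemetry.api.trace.Tracer;
import io.opentelemetry.sdk.testing.exporter.InMemorySpanExporter;
import io.opentelemetry.sdk.trace.SdkTracerProvider;
import io.opentelemetry.sdk.trace.export.SimpleSpanProcessor;

public class EndedSpanDemo {
  public static void main(String[] args) {
    InMemorySpanExporter exporter = InMemorySpanExporter.create();
    SdkTracerProvider tracerProvider =
        SdkTracerProvider.builder()
            .addSpanProcessor(SimpleSpanProcessor.create(exporter))
            .build();
    Tracer tracer = tracerProvider.get("demo");

    Span span = tracer.spanBuilder("call_llm").startSpan();
    span.setAttribute("gen_ai.system", "gcp.vertex.agent"); // recorded
    span.end();
    span.setAttribute("adk_llm_response", "too late"); // ignored: span already ended

    // Only the attribute set before end() appears in the exported span data.
    System.out.println(exporter.getFinishedSpanItems().get(0).getAttributes());
  }
}

Passing the span explicitly through the new onSuccess hook removes the dependence on the ambient Span.current() and runs the write before the lifecycle's doFinally ends the span.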

core/src/main/java/com/google/adk/telemetry/Tracing.java

Lines changed: 60 additions & 49 deletions
@@ -54,6 +54,7 @@
 import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
+import java.util.function.BiConsumer;
 import java.util.function.Consumer;
 import java.util.function.Supplier;
 import org.reactivestreams.Publisher;

@@ -292,62 +293,49 @@ private static Map<String, Object> buildLlmRequestForTrace(LlmRequest llmRequest
    * @param llmResponse The LLM response object.
    */
   public static void traceCallLlm(
+      Span span,
       InvocationContext invocationContext,
       String eventId,
       LlmRequest llmRequest,
       LlmResponse llmResponse) {
-    getValidCurrentSpan("traceCallLlm")
-        .ifPresent(
-            span -> {
-              span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent");
-              llmRequest
-                  .model()
-                  .ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName));
+    span.setAttribute(GEN_AI_SYSTEM, "gcp.vertex.agent");
+    llmRequest.model().ifPresent(modelName -> span.setAttribute(GEN_AI_REQUEST_MODEL, modelName));

-              setInvocationAttributes(span, invocationContext, eventId);
+    setInvocationAttributes(span, invocationContext, eventId);

-              setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest));
-              setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse);
+    setJsonAttribute(span, ADK_LLM_REQUEST, buildLlmRequestForTrace(llmRequest));
+    setJsonAttribute(span, ADK_LLM_RESPONSE, llmResponse);

-              llmRequest
-                  .config()
-                  .ifPresent(
-                      config -> {
-                        config
-                            .topP()
-                            .ifPresent(
-                                topP ->
-                                    span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue()));
-                        config
-                            .maxOutputTokens()
-                            .ifPresent(
-                                maxTokens ->
-                                    span.setAttribute(
-                                        GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue()));
-                      });
-              llmResponse
-                  .usageMetadata()
+    llmRequest
+        .config()
+        .ifPresent(
+            config -> {
+              config
+                  .topP()
+                  .ifPresent(topP -> span.setAttribute(GEN_AI_REQUEST_TOP_P, topP.doubleValue()));
+              config
+                  .maxOutputTokens()
                   .ifPresent(
-                      usage -> {
-                        usage
-                            .promptTokenCount()
-                            .ifPresent(
-                                tokens ->
-                                    span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens));
-                        usage
-                            .candidatesTokenCount()
-                            .ifPresent(
-                                tokens ->
-                                    span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens));
-                      });
-              llmResponse
-                  .finishReason()
-                  .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT))
+                      maxTokens ->
+                          span.setAttribute(GEN_AI_REQUEST_MAX_TOKENS, maxTokens.longValue()));
+            });
+    llmResponse
+        .usageMetadata()
+        .ifPresent(
+            usage -> {
+              usage
+                  .promptTokenCount()
+                  .ifPresent(tokens -> span.setAttribute(GEN_AI_USAGE_INPUT_TOKENS, (long) tokens));
+              usage
+                  .candidatesTokenCount()
                   .ifPresent(
-                      reason ->
-                          span.setAttribute(
-                              GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason)));
+                      tokens -> span.setAttribute(GEN_AI_USAGE_OUTPUT_TOKENS, (long) tokens));
             });
+    llmResponse
+        .finishReason()
+        .map(reason -> reason.knownEnum().name().toLowerCase(Locale.ROOT))
+        .ifPresent(
+            reason -> span.setAttribute(GEN_AI_RESPONSE_FINISH_REASONS, ImmutableList.of(reason)));
   }

   /**
@@ -472,6 +460,7 @@ public static final class TracerProvider<T>
     private final String spanName;
     private Context explicitParentContext;
     private final List<Consumer<Span>> spanConfigurers = new ArrayList<>();
+    private BiConsumer<Span, T> onSuccessConsumer;

     private TracerProvider(String spanName) {
       this.spanName = spanName;

@@ -491,6 +480,16 @@ public TracerProvider<T> setParent(Context parentContext) {
       return this;
     }

+    /**
+     * Registers a callback to be executed with the span and the result item when the stream emits a
+     * success value.
+     */
+    @CanIgnoreReturnValue
+    public TracerProvider<T> onSuccess(BiConsumer<Span, T> consumer) {
+      this.onSuccessConsumer = consumer;
+      return this;
+    }
+
     private Context getParentContext() {
       return explicitParentContext != null ? explicitParentContext : Context.current();
     }
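
The new onSuccessConsumer field and fluent setter above hand the span to a callback while it is still recording. A self-contained sketch of the same transformer pattern (Flowable only, no parent-context plumbing; SpanTransformer and the tracer name "demo" are illustrative stand-ins for the ADK's TracerProvider and TracingLifecycle):

import io.opentelemetry.api.GlobalOpenTelemetry;
import io.opentelemetry.api.trace.Span;
import io.reactivex.rxjava3.core.Flowable;
import io.reactivex.rxjava3.core.FlowableTransformer;
import java.util.function.BiConsumer;

final class SpanTransformer<T> implements FlowableTransformer<T, T> {
  private final String spanName;
  private BiConsumer<Span, T> onSuccessConsumer;

  SpanTransformer(String spanName) {
    this.spanName = spanName;
  }

  SpanTransformer<T> onSuccess(BiConsumer<Span, T> consumer) {
    this.onSuccessConsumer = consumer;
    return this;
  }

  @Override
  public Flowable<T> apply(Flowable<T> upstream) {
    return Flowable.defer(
        () -> {
          // One span holder per subscription, like the ADK's per-defer TracingLifecycle.
          Span[] span = new Span[1];
          Flowable<T> pipeline =
              upstream.doOnSubscribe(
                  s ->
                      span[0] =
                          GlobalOpenTelemetry.getTracer("demo").spanBuilder(spanName).startSpan());
          if (onSuccessConsumer != null) {
            // Fires per emission, before the doFinally below ends the span.
            pipeline = pipeline.doOnNext(t -> onSuccessConsumer.accept(span[0], t));
          }
          return pipeline.doFinally(() -> span[0].end());
        });
  }
}

// Usage, mirroring the BaseLlmFlow change:
//   Flowable.just("response")
//       .compose(
//           new SpanTransformer<String>("call_llm")
//               .onSuccess((span, resp) -> span.setAttribute("adk_llm_response", resp)))
//       .subscribe();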
@@ -521,7 +520,11 @@ public Publisher<T> apply(Flowable<T> upstream) {
       return Flowable.defer(
           () -> {
             TracingLifecycle lifecycle = new TracingLifecycle();
-            return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end);
+            Flowable<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
+            if (onSuccessConsumer != null) {
+              pipeline = pipeline.doOnNext(t -> onSuccessConsumer.accept(lifecycle.span, t));
+            }
+            return pipeline.doFinally(lifecycle::end);
           });
     }

@@ -530,7 +533,11 @@ public SingleSource<T> apply(Single<T> upstream) {
       return Single.defer(
           () -> {
             TracingLifecycle lifecycle = new TracingLifecycle();
-            return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end);
+            Single<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
+            if (onSuccessConsumer != null) {
+              pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t));
+            }
+            return pipeline.doFinally(lifecycle::end);
           });
     }

@@ -539,7 +546,11 @@ public MaybeSource<T> apply(Maybe<T> upstream) {
       return Maybe.defer(
           () -> {
             TracingLifecycle lifecycle = new TracingLifecycle();
-            return upstream.doOnSubscribe(s -> lifecycle.start()).doFinally(lifecycle::end);
+            Maybe<T> pipeline = upstream.doOnSubscribe(s -> lifecycle.start());
+            if (onSuccessConsumer != null) {
+              pipeline = pipeline.doOnSuccess(t -> onSuccessConsumer.accept(lifecycle.span, t));
+            }
+            return pipeline.doFinally(lifecycle::end);
           });
     }
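
All three apply overloads rely on the same RxJava ordering fact: a doOnNext or doOnSuccess attached upstream of doFinally fires before the doFinally action on the success path, so the span handed to onSuccessConsumer is still open when traceCallLlm sets its attributes. A quick standalone check (HookOrderingSketch is an illustrative name):

import io.reactivex.rxjava3.core.Single;

public class HookOrderingSketch {
  public static void main(String[] args) {
    Single.just("llm-response")
        .doOnSuccess(item -> System.out.println("onSuccess hook: " + item))
        .doFinally(() -> System.out.println("doFinally: the span would be ended here"))
        .subscribe();
    // Prints:
    //   onSuccess hook: llm-response
    //   doFinally: the span would be ended here
  }
}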

core/src/test/java/com/google/adk/telemetry/ContextPropagationTest.java

Lines changed: 1 addition & 1 deletion
@@ -503,7 +503,7 @@ public void testTraceCallLlm() {
                   .totalTokenCount(30)
                   .build())
           .build();
-      Tracing.traceCallLlm(buildInvocationContext(), "event-1", llmRequest, llmResponse);
+      Tracing.traceCallLlm(span, buildInvocationContext(), "event-1", llmRequest, llmResponse);
     } finally {
       span.end();
     }
