diff --git a/apps/cli/src/agent/__tests__/json-event-emitter-result.test.ts b/apps/cli/src/agent/__tests__/json-event-emitter-result.test.ts index 2be7adcbb53..ffec8b64763 100644 --- a/apps/cli/src/agent/__tests__/json-event-emitter-result.test.ts +++ b/apps/cli/src/agent/__tests__/json-event-emitter-result.test.ts @@ -32,6 +32,13 @@ function emitMessage(emitter: JsonEventEmitter, message: ClineMessage): void { ) } +function emitMessageUpdate(emitter: JsonEventEmitter, message: ClineMessage): void { + ;(emitter as unknown as { handleMessage: (msg: ClineMessage, isUpdate: boolean) => void }).handleMessage( + message, + true, + ) +} + function emitTaskCompleted(emitter: JsonEventEmitter, event: TaskCompletedEvent): void { ;(emitter as unknown as { handleTaskCompleted: (taskCompleted: TaskCompletedEvent) => void }).handleTaskCompleted( event, @@ -63,6 +70,178 @@ function createCompletedStateInfo(message: ClineMessage): AgentStateInfo { } describe("JsonEventEmitter result emission", () => { + it("reports context usage when context window is configured", () => { + const { stdout, lines } = createMockStdout() + const emitter = new JsonEventEmitter({ mode: "stream-json", stdout, contextWindow: 200 }) + + emitMessage(emitter, { + ts: 5, + type: "say", + say: "api_req_started", + partial: false, + text: JSON.stringify({ + cost: 0.001, + tokensIn: 40, + tokensOut: 20, + }), + } as ClineMessage) + + emitMessage(emitter, { + ts: 6, + type: "say", + say: "api_req_started", + partial: false, + text: JSON.stringify({ + cost: 0.002, + tokensIn: 30, + tokensOut: 10, + }), + } as ClineMessage) + + const completionMessage = createAskCompletionMessage(7, "done") + emitTaskCompleted(emitter, { + success: true, + stateInfo: createCompletedStateInfo(completionMessage), + message: completionMessage, + }) + + const result = lines().find((line) => line.type === "result") + const cost = result?.cost as Record + + expect(cost?.contextWindow).toBe(200) + expect(cost?.contextTokens).toBe(40) + expect(cost?.contextUsagePercent).toBe(20) + }) + + it("reports token usage and context usage when api_req_started has no cost field", () => { + const { stdout, lines } = createMockStdout() + const emitter = new JsonEventEmitter({ mode: "stream-json", stdout, contextWindow: 1000 }) + + emitMessage(emitter, { + ts: 8, + type: "say", + say: "api_req_started", + partial: false, + text: JSON.stringify({ + tokensIn: 120, + tokensOut: 30, + cacheWrites: 10, + cacheReads: 5, + }), + } as ClineMessage) + + const completionMessage = createAskCompletionMessage(9, "done") + emitTaskCompleted(emitter, { + success: true, + stateInfo: createCompletedStateInfo(completionMessage), + message: completionMessage, + }) + + const result = lines().find((line) => line.type === "result") + const cost = result?.cost as Record + + expect(cost?.inputTokens).toBe(120) + expect(cost?.outputTokens).toBe(30) + expect(cost?.cacheWrites).toBe(10) + expect(cost?.cacheReads).toBe(5) + expect(cost).not.toHaveProperty("totalCost") + expect(cost?.contextTokens).toBe(150) + expect(cost?.contextWindow).toBe(1000) + expect(cost?.contextUsagePercent).toBe(15) + }) + + it("aggregates token usage and cost across api requests in a completion turn", () => { + const { stdout, lines } = createMockStdout() + const emitter = new JsonEventEmitter({ mode: "stream-json", stdout }) + + emitMessage(emitter, { + ts: 10, + type: "say", + say: "api_req_started", + partial: false, + text: JSON.stringify({ + cost: 0.01, + tokensIn: 100, + tokensOut: 50, + cacheWrites: 20, + cacheReads: 10, + }), + } as ClineMessage) + + emitMessage(emitter, { + ts: 11, + type: "say", + say: "api_req_started", + partial: false, + text: JSON.stringify({ + cost: 0.02, + tokensIn: 25, + tokensOut: 10, + cacheWrites: 5, + cacheReads: 2, + }), + } as ClineMessage) + + const completionMessage = createAskCompletionMessage(12, "done") + emitTaskCompleted(emitter, { + success: true, + stateInfo: createCompletedStateInfo(completionMessage), + message: completionMessage, + }) + + const result = lines().find((line) => line.type === "result") + expect(result).toBeDefined() + expect(result?.cost).toMatchObject({ + totalCost: 0.03, + inputTokens: 125, + outputTokens: 60, + cacheWrites: 25, + cacheReads: 12, + }) + }) + + it("captures cost from updated api_req_started messages with the same message id", () => { + const { stdout, lines } = createMockStdout() + const emitter = new JsonEventEmitter({ mode: "stream-json", stdout }) + + // Placeholder message without final usage. + emitMessage(emitter, { + ts: 20, + type: "say", + say: "api_req_started", + partial: false, + text: JSON.stringify({ apiProtocol: "openai" }), + } as ClineMessage) + + // Later update of the same message with finalized usage/cost. + emitMessageUpdate(emitter, { + ts: 20, + type: "say", + say: "api_req_started", + partial: false, + text: JSON.stringify({ + apiProtocol: "openai", + cost: 0.004, + tokensIn: 40, + tokensOut: 10, + }), + } as ClineMessage) + + const completionMessage = createAskCompletionMessage(21, "done") + emitTaskCompleted(emitter, { + success: true, + stateInfo: createCompletedStateInfo(completionMessage), + message: completionMessage, + }) + + const result = lines().find((line) => line.type === "result") + expect(result?.cost).toMatchObject({ + totalCost: 0.004, + inputTokens: 40, + outputTokens: 10, + }) + }) + it("prefers current completion message content over stale cached completion text", () => { const { stdout, lines } = createMockStdout() const emitter = new JsonEventEmitter({ mode: "stream-json", stdout }) @@ -125,5 +304,6 @@ describe("JsonEventEmitter result emission", () => { expect(output).toHaveLength(2) expect(output[0]?.content).toBe("FIRST") expect(output[1]).not.toHaveProperty("content") + expect(output[1]).not.toHaveProperty("cost") }) }) diff --git a/apps/cli/src/agent/json-event-emitter.ts b/apps/cli/src/agent/json-event-emitter.ts index f46c39be506..abc030e7040 100644 --- a/apps/cli/src/agent/json-event-emitter.ts +++ b/apps/cli/src/agent/json-event-emitter.ts @@ -38,6 +38,8 @@ export interface JsonEventEmitterOptions { protocol?: string /** Supported stdin protocol capabilities emitted in system:init */ capabilities?: string[] + /** Optional context window for reporting context usage in result.cost */ + contextWindow?: number } /** @@ -55,21 +57,31 @@ function parseToolInfo(text: string | undefined): { name: string; input: Record< } /** - * Parse API request cost information from api_req_started message text. + * Parse API request usage information from api_req_started message text. */ function parseApiReqCost(text: string | undefined): JsonEventCost | undefined { if (!text) return undefined try { const parsed = JSON.parse(text) - return parsed.cost !== undefined - ? { - totalCost: parsed.cost, - inputTokens: parsed.tokensIn, - outputTokens: parsed.tokensOut, - cacheWrites: parsed.cacheWrites, - cacheReads: parsed.cacheReads, - } - : undefined + const usage: JsonEventCost = {} + + if (typeof parsed.cost === "number") { + usage.totalCost = parsed.cost + } + if (typeof parsed.tokensIn === "number") { + usage.inputTokens = parsed.tokensIn + } + if (typeof parsed.tokensOut === "number") { + usage.outputTokens = parsed.tokensOut + } + if (typeof parsed.cacheWrites === "number") { + usage.cacheWrites = parsed.cacheWrites + } + if (typeof parsed.cacheReads === "number") { + usage.cacheReads = parsed.cacheReads + } + + return Object.keys(usage).length > 0 ? usage : undefined } catch { return undefined } @@ -97,12 +109,14 @@ export class JsonEventEmitter { private events: JsonEvent[] = [] private unsubscribers: (() => void)[] = [] private pendingWrites = new Set>() - private lastCost: JsonEventCost | undefined private requestIdProvider: () => string | undefined private schemaVersion: number private protocol: string private capabilities: string[] + private contextWindow: number | undefined private seenMessageIds = new Set() + // Track finalized api_req_started usage for the current completion turn. + private currentTurnApiReqCosts = new Map() // Track previous content for delta computation private previousContent = new Map() // Track previous tool-use content for structured (non-append-only) delta computation. @@ -127,6 +141,19 @@ export class JsonEventEmitter { "stdin:ping", "stdin:shutdown", ] + this.contextWindow = + typeof options.contextWindow === "number" && + Number.isFinite(options.contextWindow) && + options.contextWindow > 0 + ? options.contextWindow + : undefined + } + + setContextWindow(contextWindow: number | undefined): void { + this.contextWindow = + typeof contextWindow === "number" && Number.isFinite(contextWindow) && contextWindow > 0 + ? contextWindow + : undefined } /** @@ -326,6 +353,84 @@ export class JsonEventEmitter { return event } + private updateCurrentTurnApiCost(msgId: number, cost: JsonEventCost): void { + this.currentTurnApiReqCosts.set(msgId, cost) + } + + private getCurrentTurnCost(): JsonEventCost | undefined { + if (this.currentTurnApiReqCosts.size === 0) { + return undefined + } + + let totalCost = 0 + let inputTokens = 0 + let outputTokens = 0 + let cacheWrites = 0 + let cacheReads = 0 + + let hasTotalCost = false + let hasInputTokens = false + let hasOutputTokens = false + let hasCacheWrites = false + let hasCacheReads = false + + for (const value of this.currentTurnApiReqCosts.values()) { + if (typeof value.totalCost === "number") { + totalCost += value.totalCost + hasTotalCost = true + } + if (typeof value.inputTokens === "number") { + inputTokens += value.inputTokens + hasInputTokens = true + } + if (typeof value.outputTokens === "number") { + outputTokens += value.outputTokens + hasOutputTokens = true + } + if (typeof value.cacheWrites === "number") { + cacheWrites += value.cacheWrites + hasCacheWrites = true + } + if (typeof value.cacheReads === "number") { + cacheReads += value.cacheReads + hasCacheReads = true + } + } + + const aggregated: JsonEventCost = {} + if (hasTotalCost) aggregated.totalCost = totalCost + if (hasInputTokens) aggregated.inputTokens = inputTokens + if (hasOutputTokens) aggregated.outputTokens = outputTokens + if (hasCacheWrites) aggregated.cacheWrites = cacheWrites + if (hasCacheReads) aggregated.cacheReads = cacheReads + + if (this.contextWindow !== undefined) { + let latestMessageId = Number.NEGATIVE_INFINITY + let latestRequestCost: JsonEventCost | undefined + + for (const [messageId, usage] of this.currentTurnApiReqCosts.entries()) { + if (messageId > latestMessageId) { + latestMessageId = messageId + latestRequestCost = usage + } + } + + if (latestRequestCost) { + const contextTokensIn = + typeof latestRequestCost.inputTokens === "number" ? latestRequestCost.inputTokens : 0 + const contextTokensOut = + typeof latestRequestCost.outputTokens === "number" ? latestRequestCost.outputTokens : 0 + const contextTokens = contextTokensIn + contextTokensOut + + aggregated.contextTokens = contextTokens + aggregated.contextWindow = this.contextWindow + aggregated.contextUsagePercent = (contextTokens / this.contextWindow) * 100 + } + } + + return Object.keys(aggregated).length > 0 ? aggregated : undefined + } + /** * Handle a ClineMessage and emit the appropriate JSON event. */ @@ -339,6 +444,14 @@ export class JsonEventEmitter { // Skip duplicate complete messages if (isDone && this.seenMessageIds.has(msg.ts)) { + // api_req_started messages are updated in-place (same ts) with final usage/cost. + // We still need to process those updates for result metrics, even if we skip re-emitting. + if (msg.type === "say" && msg.say === "api_req_started") { + const cost = parseApiReqCost(msg.text) + if (cost) { + this.updateCurrentTurnApiCost(msg.ts, cost) + } + } return } @@ -409,7 +522,7 @@ export class JsonEventEmitter { case "api_req_started": { const cost = parseApiReqCost(msg.text) if (cost) { - this.lastCost = cost + this.updateCurrentTurnApiCost(msg.ts, cost) } break } @@ -602,6 +715,7 @@ export class JsonEventEmitter { // Prefer the completion payload from the current event. If it is empty, // fall back to the most recent tracked completion text, then assistant text. const resultContent = event.message?.text || this.completionResultContent || this.lastAssistantText + const resultCost = this.getCurrentTurnCost() this.emitEvent({ type: "result", @@ -609,16 +723,17 @@ export class JsonEventEmitter { content: resultContent, done: true, success: event.success, - cost: this.lastCost, + cost: resultCost, }) // Prevent stale completion content from leaking into later turns. this.completionResultContent = undefined this.lastAssistantText = undefined + this.currentTurnApiReqCosts.clear() // For "json" mode, output the final accumulated result if (this.mode === "json") { - this.outputFinalResult(event.success, resultContent) + this.outputFinalResult(event.success, resultContent, resultCost) } } @@ -659,12 +774,12 @@ export class JsonEventEmitter { /** * Output the final accumulated result (for "json" mode). */ - private outputFinalResult(success: boolean, content?: string): void { + private outputFinalResult(success: boolean, content?: string, cost?: JsonEventCost): void { const output: JsonFinalOutput = { type: "result", success, content, - cost: this.lastCost, + cost, events: this.events.filter((e) => e.type !== "result"), // Exclude the result event itself } @@ -707,10 +822,10 @@ export class JsonEventEmitter { */ clear(): void { this.events = [] - this.lastCost = undefined this.seenMessageIds.clear() this.previousContent.clear() this.previousToolUseContent.clear() + this.currentTurnApiReqCosts.clear() this.completionResultContent = undefined this.lastAssistantText = undefined this.expectPromptEchoAsUser = true diff --git a/apps/cli/src/commands/cli/run.ts b/apps/cli/src/commands/cli/run.ts index f2c0e039129..af97ffbf69d 100644 --- a/apps/cli/src/commands/cli/run.ts +++ b/apps/cli/src/commands/cli/run.ts @@ -6,6 +6,7 @@ import { createElement } from "react" import pWaitFor from "p-wait-for" import { setLogger } from "@roo-code/vscode-shim" +import type { ModelRecord } from "@roo-code/types" import { FlagOptions, @@ -51,8 +52,8 @@ function normalizeError(error: unknown): Error { return error instanceof Error ? error : new Error(String(error)) } -async function warmRooModels(host: ExtensionHost): Promise { - await new Promise((resolve, reject) => { +async function fetchRooModels(host: ExtensionHost): Promise { + return new Promise((resolve, reject) => { let settled = false const cleanup = () => { @@ -92,7 +93,7 @@ async function warmRooModels(host: ExtensionHost): Promise { return } - finish(() => resolve()) + finish(() => resolve(isRecord(values?.models) ? (values.models as ModelRecord) : {})) } const timeoutId = setTimeout(() => { @@ -104,6 +105,14 @@ async function warmRooModels(host: ExtensionHost): Promise { }) } +function resolveRooModelContextWindow(models: ModelRecord, modelId: string): number | undefined { + const contextWindow = models[modelId]?.contextWindow + if (typeof contextWindow !== "number" || !Number.isFinite(contextWindow) || contextWindow <= 0) { + return undefined + } + return contextWindow +} + export async function run(promptArg: string | undefined, flagOptions: FlagOptions) { setLogger({ info: () => {}, @@ -541,9 +550,18 @@ export async function run(promptArg: string | undefined, flagOptions: FlagOption try { await host.activate() + let rooContextWindow: number | undefined + if (extensionHostOptions.provider === "roo") { try { - await warmRooModels(host) + const rooModels = await fetchRooModels(host) + rooContextWindow = resolveRooModelContextWindow(rooModels, extensionHostOptions.model) + + if (flagOptions.debug && rooContextWindow === undefined) { + console.error( + `[CLI] Warning: Roo model context window unavailable for model: ${extensionHostOptions.model}`, + ) + } } catch (warmupError) { if (flagOptions.debug) { const message = warmupError instanceof Error ? warmupError.message : String(warmupError) @@ -552,6 +570,10 @@ export async function run(promptArg: string | undefined, flagOptions: FlagOption } } + if (rooContextWindow !== undefined) { + jsonEmitter?.setContextWindow(rooContextWindow) + } + if (jsonEmitter) { jsonEmitter.attachToClient(host.client) } diff --git a/packages/types/src/__tests__/cli.test.ts b/packages/types/src/__tests__/cli.test.ts index 483e633b4c9..11300e93643 100644 --- a/packages/types/src/__tests__/cli.test.ts +++ b/packages/types/src/__tests__/cli.test.ts @@ -76,5 +76,28 @@ describe("CLI types", () => { expect(result.success).toBe(true) }) + + it("preserves context usage fields in result cost", () => { + const result = rooCliFinalOutputSchema.safeParse({ + type: "result", + success: true, + cost: { + totalCost: 0.01, + inputTokens: 100, + outputTokens: 50, + contextTokens: 150, + contextWindow: 200_000, + contextUsagePercent: 0.075, + }, + events: [], + }) + + expect(result.success).toBe(true) + if (result.success) { + expect(result.data.cost?.contextTokens).toBe(150) + expect(result.data.cost?.contextWindow).toBe(200_000) + expect(result.data.cost?.contextUsagePercent).toBe(0.075) + } + }) }) }) diff --git a/packages/types/src/cli.ts b/packages/types/src/cli.ts index 738db4a103e..042c3e7739f 100644 --- a/packages/types/src/cli.ts +++ b/packages/types/src/cli.ts @@ -125,6 +125,9 @@ export const rooCliCostSchema = z.object({ outputTokens: z.number().optional(), cacheWrites: z.number().optional(), cacheReads: z.number().optional(), + contextTokens: z.number().optional(), + contextWindow: z.number().optional(), + contextUsagePercent: z.number().optional(), }) export type RooCliCost = z.infer