Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 52 additions & 8 deletions index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ import {
getLiveAccountSync,
getLiveAccountSyncDebounceMs,
getLiveAccountSyncPollMs,
getResponseContinuation,
getSessionAffinity,
getSessionAffinityTtlMs,
getSessionAffinityMaxEntries,
Expand Down Expand Up @@ -977,7 +978,10 @@ export const OpenAIOAuthPlugin: Plugin = async ({ client }: PluginInput) => {
const ensureSessionAffinity = (
pluginConfig: ReturnType<typeof loadPluginConfig>,
): void => {
if (!getSessionAffinity(pluginConfig)) {
if (
!getSessionAffinity(pluginConfig) &&
!getResponseContinuation(pluginConfig)
) {
sessionAffinityStore = null;
sessionAffinityConfigKey = null;
return;
Expand Down Expand Up @@ -1336,8 +1340,7 @@ export const OpenAIOAuthPlugin: Plugin = async ({ client }: PluginInput) => {
const originalBody = await parseRequestBodyFromInit(baseInit?.body);
const isStreaming = originalBody.stream === true;
const parsedBody =
Object.keys(originalBody).length > 0 ? originalBody : undefined;

Object.keys(originalBody).length > 0 ? { ...originalBody } : undefined;
const transformation = await transformRequestForCodex(
baseInit,
url,
Expand All @@ -1356,13 +1359,33 @@ export const OpenAIOAuthPlugin: Plugin = async ({ client }: PluginInput) => {
let model = transformedBody?.model;
let modelFamily = model ? getModelFamily(model) : "gpt-5.1";
let quotaKey = model ? `${modelFamily}:${model}` : modelFamily;
const responseContinuationEnabled =
getResponseContinuation(pluginConfig);
const threadIdCandidate =
(process.env.CODEX_THREAD_ID ?? promptCacheKey ?? "")
.toString()
.trim() || undefined;
const sessionAffinityKey = threadIdCandidate ?? promptCacheKey ?? null;
const effectivePromptCacheKey =
(sessionAffinityKey ?? promptCacheKey ?? "").toString().trim() || undefined;
const shouldUseResponseContinuation =
Boolean(transformedBody) &&
responseContinuationEnabled &&
!transformedBody?.previous_response_id;
if (shouldUseResponseContinuation && transformedBody) {
const lastResponseId =
sessionAffinityStore?.getLastResponseId(sessionAffinityKey);
if (lastResponseId) {
transformedBody = {
...transformedBody,
previous_response_id: lastResponseId,
};
requestInit = {
...requestInit,
body: JSON.stringify(transformedBody),
};
}
}
const preferredSessionAccountIndex = sessionAffinityStore?.getPreferredAccountIndex(
sessionAffinityKey,
);
Expand Down Expand Up @@ -2397,7 +2420,10 @@ accountAttemptLoop: while (attempted.size < Math.max(1, accountCount)) {
successAccountForResponse = fallbackAccount;
successEntitlementAccountKey = fallbackEntitlementAccountKey;
runtimeMetrics.streamFailoverRecoveries += 1;
if (fallbackAccount.index !== account.index) {
if (
fallbackAccount.index !== account.index &&
!responseContinuationEnabled
) {
runtimeMetrics.streamFailoverCrossAccountRecoveries += 1;
runtimeMetrics.accountRotations += 1;
sessionAffinityStore?.remember(
Expand Down Expand Up @@ -2447,7 +2473,20 @@ accountAttemptLoop: while (attempted.size < Math.max(1, accountCount)) {
},
);
}
let storedResponseIdForSuccess = false;
const successResponse = await handleSuccessResponse(responseForSuccess, isStreaming, {
onResponseId: (responseId) => {
if (!responseContinuationEnabled) return;
sessionAffinityStore?.remember(
sessionAffinityKey,
successAccountForResponse.index,
);
sessionAffinityStore?.rememberLastResponseId(
sessionAffinityKey,
responseId,
);
storedResponseIdForSuccess = true;
},
streamStallTimeoutMs,
});

Expand Down Expand Up @@ -2512,10 +2551,15 @@ accountAttemptLoop: while (attempted.size < Math.max(1, accountCount)) {
capabilityModelKey,
);
entitlementCache.clear(successAccountKey, capabilityModelKey);
sessionAffinityStore?.remember(
sessionAffinityKey,
successAccountForResponse.index,
);
if (
!responseContinuationEnabled ||
(!isStreaming && !storedResponseIdForSuccess)
) {
sessionAffinityStore?.remember(
sessionAffinityKey,
successAccountForResponse.index,
);
}
runtimeMetrics.successfulRequests++;
runtimeMetrics.lastError = null;
if (lastCodexCliActiveSyncIndex !== successAccountForResponse.index) {
Expand Down
19 changes: 19 additions & 0 deletions lib/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ export const DEFAULT_PLUGIN_CONFIG: PluginConfig = {
sessionAffinity: true,
sessionAffinityTtlMs: 20 * 60_000,
sessionAffinityMaxEntries: 512,
responseContinuation: false,
proactiveRefreshGuardian: true,
proactiveRefreshIntervalMs: 60_000,
proactiveRefreshBufferMs: 5 * 60_000,
Expand Down Expand Up @@ -917,6 +918,24 @@ export function getSessionAffinityMaxEntries(pluginConfig: PluginConfig): number
);
}

/**
 * Reports whether automatic Responses API continuation is switched on, i.e.
 * whether follow-up turns should be sent with the last known
 * `previous_response_id` for the active session key.
 *
 * The value is taken from `pluginConfig.responseContinuation`, may be
 * overridden through the `CODEX_AUTH_RESPONSE_CONTINUATION` environment
 * variable, and falls back to `false`.
 *
 * @param pluginConfig - Plugin configuration holding the `responseContinuation` flag
 * @returns `true` when automatic response continuation is enabled, otherwise `false`
 */
export function getResponseContinuation(pluginConfig: PluginConfig): boolean {
  const configuredValue = pluginConfig.responseContinuation;
  const fallbackValue = false;
  return resolveBooleanSetting(
    "CODEX_AUTH_RESPONSE_CONTINUATION",
    configuredValue,
    fallbackValue,
  );
}

/**
* Controls whether the proactive refresh guardian is enabled.
*
Expand Down
17 changes: 10 additions & 7 deletions lib/request/fetch-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ import { queuedRefresh } from "../refresh-queue.js";
import { logRequest, logError, logWarn } from "../logger.js";
import { getCodexInstructions, getModelFamily } from "../prompts/codex.js";
import { transformRequestBody, normalizeModel } from "./request-transformer.js";
import { convertSseToJson, ensureContentType } from "./response-handler.js";
import {
attachResponseIdCapture,
convertSseToJson,
ensureContentType,
} from "./response-handler.js";
import type { UserConfig, RequestBody } from "../types.js";
import { registerCleanup } from "../shutdown.js";
import { CodexAuthError } from "../errors.js";
Expand Down Expand Up @@ -841,7 +845,10 @@ export async function handleErrorResponse(
export async function handleSuccessResponse(
response: Response,
isStreaming: boolean,
options?: { streamStallTimeoutMs?: number },
options?: {
onResponseId?: (responseId: string) => void;
streamStallTimeoutMs?: number;
},
): Promise<Response> {
// Check for deprecation headers (RFC 8594)
const deprecation = response.headers.get("Deprecation");
Expand All @@ -858,11 +865,7 @@ export async function handleSuccessResponse(
}

// For streaming requests (streamText), forward the stream, capturing the response id when options.onResponseId is provided
return new Response(response.body, {
status: response.status,
statusText: response.statusText,
headers: responseHeaders,
});
return attachResponseIdCapture(response, responseHeaders, options?.onResponseId);
}

async function safeReadBody(response: Response): Promise<string> {
Expand Down
142 changes: 130 additions & 12 deletions lib/request/response-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,64 @@ const log = createLogger("response-handler");
const MAX_SSE_SIZE = 10 * 1024 * 1024; // 10MB limit to prevent memory exhaustion
const DEFAULT_STREAM_STALL_TIMEOUT_MS = 45_000;

/**
 * Pulls a usable response id out of an arbitrary payload.
 *
 * @param response - Value expected to be an object carrying an `id` field
 * @returns The trimmed `id` when it is a non-blank string, otherwise `null`
 */
function extractResponseId(response: unknown): string | null {
  // Anything that is not a non-null object cannot carry an id.
  if (typeof response !== "object" || response === null) {
    return null;
  }
  const rawId = (response as { id?: unknown }).id;
  if (typeof rawId !== "string") {
    return null;
  }
  const normalizedId = rawId.trim();
  return normalizedId.length > 0 ? normalizedId : null;
}

/**
 * Hands the response id found on `response` to the supplied callback.
 *
 * Does nothing when no callback is given or when `response` yields no usable
 * id. An exception thrown by the callback is logged as a warning and is not
 * rethrown.
 *
 * @param onResponseId - Optional listener that receives the extracted id
 * @param response - Payload to read the id from
 */
function notifyResponseId(
  onResponseId: ((responseId: string) => void) | undefined,
  response: unknown,
): void {
  const responseId = extractResponseId(response);
  if (responseId && onResponseId) {
    try {
      onResponseId(responseId);
    } catch (error) {
      // A failing listener must not break response handling; record and move on.
      log.warn("Failed to persist response id from upstream event", {
        error: String(error),
        responseId,
      });
    }
  }
}

/**
 * Outcome of inspecting a single parsed SSE event:
 * - `{ kind: "error" }`: the event had type `error`.
 * - `{ kind: "response", response }`: a `response.done` / `response.completed`
 *   event that carried a non-null `response` payload.
 * - `null`: any other event (nothing to capture).
 */
type CapturedResponseEvent =
| { kind: "error" }
| { kind: "response"; response: unknown }
| null;

/**
 * Classifies one parsed SSE event and reports any response id it carries.
 *
 * - `error` events are logged and reported as `{ kind: "error" }`.
 * - `response.done` / `response.completed` events have their response id (if
 *   any) passed to `onResponseId`; when the event carries a non-null
 *   `response`, it is returned as `{ kind: "response", response }`.
 * - Everything else yields `null`.
 *
 * @param data - Parsed SSE event payload
 * @param onResponseId - Optional listener for the final response id
 * @returns The captured outcome, or `null` when the event is not of interest
 */
function maybeCaptureResponseEvent(
  data: SSEEventData,
  onResponseId?: (responseId: string) => void,
): CapturedResponseEvent {
  switch (data.type) {
    case "error":
      log.error("SSE error event received", { error: data });
      return { kind: "error" };
    case "response.done":
    case "response.completed": {
      const finalResponse = data.response;
      notifyResponseId(onResponseId, finalResponse);
      // A terminal event without a payload is treated like any other event.
      return finalResponse === undefined || finalResponse === null
        ? null
        : { kind: "response", response: finalResponse };
    }
    default:
      return null;
  }
}

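/**
 * Parse SSE stream to extract final response
 * @param sseText - Complete SSE stream text
 * @param onResponseId - Optional callback that receives the id of the final response, when it has one
 * @returns Final response object, or null if none is found or an error event is encountered
 */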
function parseSseStream(sseText: string): unknown | null {
function parseSseStream(
sseText: string,
onResponseId?: (responseId: string) => void,
): unknown | null {
const lines = sseText.split(/\r?\n/);

for (const line of lines) {
Expand All @@ -24,15 +75,9 @@ function parseSseStream(sseText: string): unknown | null {
if (!payload || payload === '[DONE]') continue;
try {
const data = JSON.parse(payload) as SSEEventData;

if (data.type === 'error') {
log.error("SSE error event received", { error: data });
return null;
}

if (data.type === 'response.done' || data.type === 'response.completed') {
return data.response;
}
const capturedEvent = maybeCaptureResponseEvent(data, onResponseId);
if (capturedEvent?.kind === "error") return null;
if (capturedEvent?.kind === "response") return capturedEvent.response;
} catch {
// Skip malformed JSON
}
Expand All @@ -51,7 +96,10 @@ function parseSseStream(sseText: string): unknown | null {
export async function convertSseToJson(
response: Response,
headers: Headers,
options?: { streamStallTimeoutMs?: number },
options?: {
onResponseId?: (responseId: string) => void;
streamStallTimeoutMs?: number;
},
): Promise<Response> {
if (!response.body) {
throw new Error(`[${PLUGIN_NAME}] Response has no body`);
Expand Down Expand Up @@ -80,7 +128,7 @@ export async function convertSseToJson(
}

// Parse SSE events to extract the final response
const finalResponse = parseSseStream(fullText);
const finalResponse = parseSseStream(fullText, options?.onResponseId);

if (!finalResponse) {
log.warn("Could not find final response in SSE stream");
Expand Down Expand Up @@ -119,6 +167,56 @@ export async function convertSseToJson(

}

/**
 * Wraps an SSE byte stream so the response id of a `response.done` /
 * `response.completed` event is reported to `onResponseId`, while every chunk
 * is forwarded downstream unchanged.
 *
 * Incoming bytes are decoded incrementally and split into lines; only complete
 * lines starting with `data: ` are parsed as JSON. Malformed lines are ignored.
 * Once an SSE `error` event is seen, inspection stops for the rest of the
 * stream: nothing more is decoded or buffered, so the side buffer cannot grow
 * with text that would never be parsed.
 *
 * @param body - Upstream SSE byte stream
 * @param onResponseId - Listener invoked with the captured response id
 * @returns A stream yielding exactly the same chunks as `body`
 */
function createResponseIdCapturingStream(
  body: ReadableStream<Uint8Array>,
  onResponseId: (responseId: string) => void,
): ReadableStream<Uint8Array> {
  const decoder = new TextDecoder();
  let bufferedText = "";
  let sawErrorEvent = false;

  const processBufferedLines = (flush = false): void => {
    if (sawErrorEvent) return;
    const lines = bufferedText.split(/\r?\n/);
    // Unless flushing, the last element may be a partial line: keep it for the
    // next chunk instead of parsing it now.
    bufferedText = flush ? "" : (lines.pop() ?? "");

    for (const rawLine of lines) {
      const trimmedLine = rawLine.trim();
      if (!trimmedLine.startsWith("data: ")) continue;
      const payload = trimmedLine.slice(6).trim();
      if (!payload || payload === "[DONE]") continue;
      try {
        const data = JSON.parse(payload) as SSEEventData;
        const capturedEvent = maybeCaptureResponseEvent(data, onResponseId);
        if (capturedEvent?.kind === "error") {
          sawErrorEvent = true;
          // Nothing further will be parsed; release any buffered remainder.
          bufferedText = "";
          break;
        }
      } catch {
        // Ignore malformed SSE lines and keep forwarding the raw stream.
      }
    }
  };

  return body.pipeThrough(
    new TransformStream<Uint8Array, Uint8Array>({
      transform(chunk, controller) {
        // After an error event the text is never inspected again, so skip
        // decoding/buffering and act as a plain pass-through.
        if (!sawErrorEvent) {
          bufferedText += decoder.decode(chunk, { stream: true });
          processBufferedLines();
        }
        controller.enqueue(chunk);
      },
      flush() {
        if (sawErrorEvent) return;
        // Drain bytes still held by the decoder, then parse the final line.
        bufferedText += decoder.decode();
        processBufferedLines(true);
      },
    }),
  );
}

/**
* Ensure response has content-type header
* @param headers - Response headers
Expand Down Expand Up @@ -186,3 +284,23 @@ export function isEmptyResponse(body: unknown): boolean {

return false;
}

/**
 * Rebuilds `response` with the given headers, preserving its status and status
 * text. When the response has a body and a listener is supplied, the body is
 * routed through a response-id capturing stream; otherwise the body is passed
 * along untouched.
 *
 * @param response - Upstream response to re-wrap
 * @param headers - Headers to put on the returned response
 * @param onResponseId - Optional listener for the captured response id
 * @returns A new `Response` carrying the (possibly tapped) body
 */
export function attachResponseIdCapture(
  response: Response,
  headers: Headers,
  onResponseId?: (responseId: string) => void,
): Response {
  const { body, status, statusText } = response;
  const forwardedBody =
    body && onResponseId
      ? createResponseIdCapturingStream(body, onResponseId)
      : body;
  return new Response(forwardedBody, { status, statusText, headers });
}
1 change: 1 addition & 0 deletions lib/schemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export const PluginConfigSchema = z.object({
sessionAffinity: z.boolean().optional(),
sessionAffinityTtlMs: z.number().min(1_000).optional(),
sessionAffinityMaxEntries: z.number().min(8).optional(),
responseContinuation: z.boolean().optional(),
proactiveRefreshGuardian: z.boolean().optional(),
proactiveRefreshIntervalMs: z.number().min(5_000).optional(),
proactiveRefreshBufferMs: z.number().min(30_000).optional(),
Expand Down
Loading