Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ import {
getLiveAccountSync,
getLiveAccountSyncDebounceMs,
getLiveAccountSyncPollMs,
getResponseContinuation,
getSessionAffinity,
getSessionAffinityTtlMs,
getSessionAffinityMaxEntries,
Expand Down Expand Up @@ -1336,7 +1337,27 @@ export const OpenAIOAuthPlugin: Plugin = async ({ client }: PluginInput) => {
const originalBody = await parseRequestBodyFromInit(baseInit?.body);
const isStreaming = originalBody.stream === true;
const parsedBody =
Object.keys(originalBody).length > 0 ? originalBody : undefined;
Object.keys(originalBody).length > 0 ? { ...originalBody } : undefined;
const requestPromptCacheKey =
typeof parsedBody?.prompt_cache_key === "string"
? parsedBody.prompt_cache_key.trim()
: "";
const requestThreadId =
(process.env.CODEX_THREAD_ID ?? requestPromptCacheKey ?? "")
.toString()
.trim() || undefined;
const continuationSessionKey = requestThreadId ?? requestPromptCacheKey ?? null;
const shouldUseResponseContinuation =
Boolean(parsedBody) &&
getResponseContinuation(pluginConfig) &&
!parsedBody?.previous_response_id;
if (shouldUseResponseContinuation) {
const lastResponseId =
sessionAffinityStore?.getLastResponseId(continuationSessionKey);
Comment thread
greptile-apps[bot] marked this conversation as resolved.
Outdated
if (lastResponseId && parsedBody) {
parsedBody.previous_response_id = lastResponseId;
}
Comment thread
ndycode marked this conversation as resolved.
Outdated
}

const transformation = await transformRequestForCodex(
baseInit,
Expand Down Expand Up @@ -2447,7 +2468,15 @@ accountAttemptLoop: while (attempted.size < Math.max(1, accountCount)) {
},
);
}
let capturedResponseId: string | null = null;
const successResponse = await handleSuccessResponse(responseForSuccess, isStreaming, {
onResponseId: (responseId) => {
capturedResponseId = responseId;
sessionAffinityStore?.rememberLastResponseId(
sessionAffinityKey,
responseId,
);
Comment thread
ndycode marked this conversation as resolved.
},
Comment thread
ndycode marked this conversation as resolved.
Outdated
streamStallTimeoutMs,
});

Expand Down Expand Up @@ -2516,6 +2545,12 @@ accountAttemptLoop: while (attempted.size < Math.max(1, accountCount)) {
sessionAffinityKey,
successAccountForResponse.index,
);
if (capturedResponseId) {
sessionAffinityStore?.rememberLastResponseId(
sessionAffinityKey,
capturedResponseId,
);
}
Comment thread
ndycode marked this conversation as resolved.
Outdated
runtimeMetrics.successfulRequests++;
runtimeMetrics.lastError = null;
if (lastCodexCliActiveSyncIndex !== successAccountForResponse.index) {
Expand Down
19 changes: 19 additions & 0 deletions lib/config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ export const DEFAULT_PLUGIN_CONFIG: PluginConfig = {
sessionAffinity: true,
sessionAffinityTtlMs: 20 * 60_000,
sessionAffinityMaxEntries: 512,
responseContinuation: false,
proactiveRefreshGuardian: true,
proactiveRefreshIntervalMs: 60_000,
proactiveRefreshBufferMs: 5 * 60_000,
Expand Down Expand Up @@ -917,6 +918,24 @@ export function getSessionAffinityMaxEntries(pluginConfig: PluginConfig): number
);
}

/**
* Controls whether the plugin should automatically continue Responses API turns
* with the last known `previous_response_id` for the active session key.
*
* Reads the `responseContinuation` value from `pluginConfig` and allows an
* environment override via `CODEX_AUTH_RESPONSE_CONTINUATION`.
*
* @param pluginConfig - The plugin configuration to consult for the setting
* @returns `true` if automatic response continuation is enabled, `false` otherwise
*/
export function getResponseContinuation(pluginConfig: PluginConfig): boolean {
return resolveBooleanSetting(
"CODEX_AUTH_RESPONSE_CONTINUATION",
pluginConfig.responseContinuation,
false,
);
}

/**
* Controls whether the proactive refresh guardian is enabled.
*
Expand Down
17 changes: 10 additions & 7 deletions lib/request/fetch-helpers.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ import { queuedRefresh } from "../refresh-queue.js";
import { logRequest, logError, logWarn } from "../logger.js";
import { getCodexInstructions, getModelFamily } from "../prompts/codex.js";
import { transformRequestBody, normalizeModel } from "./request-transformer.js";
import { convertSseToJson, ensureContentType } from "./response-handler.js";
import {
attachResponseIdCapture,
convertSseToJson,
ensureContentType,
} from "./response-handler.js";
import type { UserConfig, RequestBody } from "../types.js";
import { registerCleanup } from "../shutdown.js";
import { CodexAuthError } from "../errors.js";
Expand Down Expand Up @@ -841,7 +845,10 @@ export async function handleErrorResponse(
export async function handleSuccessResponse(
response: Response,
isStreaming: boolean,
options?: { streamStallTimeoutMs?: number },
options?: {
onResponseId?: (responseId: string) => void;
streamStallTimeoutMs?: number;
},
): Promise<Response> {
// Check for deprecation headers (RFC 8594)
const deprecation = response.headers.get("Deprecation");
Expand All @@ -858,11 +865,7 @@ export async function handleSuccessResponse(
}

// For streaming requests (streamText), return stream as-is
return new Response(response.body, {
status: response.status,
statusText: response.statusText,
headers: responseHeaders,
});
return attachResponseIdCapture(response, responseHeaders, options?.onResponseId);
}

async function safeReadBody(response: Response): Promise<string> {
Expand Down
128 changes: 116 additions & 12 deletions lib/request/response-handler.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,57 @@ const log = createLogger("response-handler");
const MAX_SSE_SIZE = 10 * 1024 * 1024; // 10MB limit to prevent memory exhaustion
const DEFAULT_STREAM_STALL_TIMEOUT_MS = 45_000;

function extractResponseId(response: unknown): string | null {
if (!response || typeof response !== "object") return null;
const candidate = (response as { id?: unknown }).id;
return typeof candidate === "string" && candidate.trim().length > 0
? candidate.trim()
: null;
}

function notifyResponseId(
onResponseId: ((responseId: string) => void) | undefined,
response: unknown,
): void {
const responseId = extractResponseId(response);
if (!responseId || !onResponseId) return;
try {
onResponseId(responseId);
} catch (error) {
log.warn("Failed to persist response id from upstream event", {
error: String(error),
responseId,
});
}
}

function maybeCaptureResponseEvent(
data: SSEEventData,
onResponseId?: (responseId: string) => void,
): unknown | null {
if (data.type === "error") {
log.error("SSE error event received", { error: data });
return null;
}

if (data.type === "response.done" || data.type === "response.completed") {
notifyResponseId(onResponseId, data.response);
return data.response ?? null;
}

return null;
}
Comment thread
greptile-apps[bot] marked this conversation as resolved.

/**

* Parse SSE stream to extract final response
* @param sseText - Complete SSE stream text
* @returns Final response object or null if not found
*/
function parseSseStream(sseText: string): unknown | null {
function parseSseStream(
sseText: string,
onResponseId?: (responseId: string) => void,
): unknown | null {
const lines = sseText.split(/\r?\n/);

for (const line of lines) {
Expand All @@ -24,15 +68,8 @@ function parseSseStream(sseText: string): unknown | null {
if (!payload || payload === '[DONE]') continue;
try {
const data = JSON.parse(payload) as SSEEventData;

if (data.type === 'error') {
log.error("SSE error event received", { error: data });
return null;
}

if (data.type === 'response.done' || data.type === 'response.completed') {
return data.response;
}
const finalResponse = maybeCaptureResponseEvent(data, onResponseId);
if (finalResponse) return finalResponse;
} catch {
// Skip malformed JSON
}
Expand All @@ -51,7 +88,10 @@ function parseSseStream(sseText: string): unknown | null {
export async function convertSseToJson(
response: Response,
headers: Headers,
options?: { streamStallTimeoutMs?: number },
options?: {
onResponseId?: (responseId: string) => void;
streamStallTimeoutMs?: number;
},
): Promise<Response> {
if (!response.body) {
throw new Error(`[${PLUGIN_NAME}] Response has no body`);
Expand Down Expand Up @@ -80,7 +120,7 @@ export async function convertSseToJson(
}

// Parse SSE events to extract the final response
const finalResponse = parseSseStream(fullText);
const finalResponse = parseSseStream(fullText, options?.onResponseId);

if (!finalResponse) {
log.warn("Could not find final response in SSE stream");
Expand Down Expand Up @@ -119,6 +159,50 @@ export async function convertSseToJson(

}

function createResponseIdCapturingStream(
body: ReadableStream<Uint8Array>,
onResponseId: (responseId: string) => void,
): ReadableStream<Uint8Array> {
const decoder = new TextDecoder();
let bufferedText = "";

const processBufferedLines = (flush = false): void => {
const lines = bufferedText.split(/\r?\n/);
if (!flush) {
bufferedText = lines.pop() ?? "";
} else {
bufferedText = "";
}

for (const rawLine of lines) {
const trimmedLine = rawLine.trim();
if (!trimmedLine.startsWith("data: ")) continue;
const payload = trimmedLine.slice(6).trim();
if (!payload || payload === "[DONE]") continue;
try {
const data = JSON.parse(payload) as SSEEventData;
maybeCaptureResponseEvent(data, onResponseId);
} catch {
// Ignore malformed SSE lines and keep forwarding the raw stream.
}
}
};

return body.pipeThrough(
new TransformStream<Uint8Array, Uint8Array>({
transform(chunk, controller) {
bufferedText += decoder.decode(chunk, { stream: true });
processBufferedLines();
controller.enqueue(chunk);
},
flush() {
bufferedText += decoder.decode();
processBufferedLines(true);
},
}),
);
}

/**
* Ensure response has content-type header
* @param headers - Response headers
Expand Down Expand Up @@ -186,3 +270,23 @@ export function isEmptyResponse(body: unknown): boolean {

return false;
}

export function attachResponseIdCapture(
response: Response,
headers: Headers,
onResponseId?: (responseId: string) => void,
): Response {
if (!response.body || !onResponseId) {
return new Response(response.body, {
status: response.status,
statusText: response.statusText,
headers,
});
}

return new Response(createResponseIdCapturingStream(response.body, onResponseId), {
status: response.status,
statusText: response.statusText,
headers,
});
}
1 change: 1 addition & 0 deletions lib/schemas.ts
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ export const PluginConfigSchema = z.object({
sessionAffinity: z.boolean().optional(),
sessionAffinityTtlMs: z.number().min(1_000).optional(),
sessionAffinityMaxEntries: z.number().min(8).optional(),
responseContinuation: z.boolean().optional(),
proactiveRefreshGuardian: z.boolean().optional(),
proactiveRefreshIntervalMs: z.number().min(5_000).optional(),
proactiveRefreshBufferMs: z.number().min(30_000).optional(),
Expand Down
44 changes: 44 additions & 0 deletions lib/session-affinity.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export interface SessionAffinityOptions {
interface SessionAffinityEntry {
accountIndex: number;
expiresAt: number;
lastResponseId?: string;
updatedAt: number;
}

Expand Down Expand Up @@ -65,6 +66,8 @@ export class SessionAffinityStore {
if (!key) return;
if (!Number.isFinite(accountIndex) || accountIndex < 0) return;

const existingEntry = this.entries.get(key);

if (this.entries.size >= this.maxEntries && !this.entries.has(key)) {
const oldest = this.findOldestKey();
if (oldest) this.entries.delete(oldest);
Expand All @@ -73,6 +76,47 @@ export class SessionAffinityStore {
this.entries.set(key, {
accountIndex,
expiresAt: now + this.ttlMs,
lastResponseId: existingEntry?.lastResponseId,
updatedAt: now,
});
}

getLastResponseId(sessionKey: string | null | undefined, now = Date.now()): string | null {
const key = normalizeSessionKey(sessionKey);
if (!key) return null;

const entry = this.entries.get(key);
if (!entry) return null;
if (entry.expiresAt <= now) {
this.entries.delete(key);
return null;
}

const lastResponseId =
typeof entry.lastResponseId === "string" ? entry.lastResponseId.trim() : "";
return lastResponseId || null;
}

rememberLastResponseId(
sessionKey: string | null | undefined,
responseId: string | null | undefined,
now = Date.now(),
): void {
const key = normalizeSessionKey(sessionKey);
const normalizedResponseId = typeof responseId === "string" ? responseId.trim() : "";
if (!key || !normalizedResponseId) return;

const entry = this.entries.get(key);
if (!entry) return;
if (entry.expiresAt <= now) {
this.entries.delete(key);
return;
}

this.entries.set(key, {
...entry,
expiresAt: now + this.ttlMs,
lastResponseId: normalizedResponseId,
updatedAt: now,
});
}
Expand Down
Loading