diff --git a/.changeset/cors-support.md b/.changeset/cors-support.md new file mode 100644 index 0000000..2a2573e --- /dev/null +++ b/.changeset/cors-support.md @@ -0,0 +1,7 @@ +--- +"partyserver": patch +--- + +Add CORS support to `routePartykitRequest`. + +Pass `cors: true` for permissive defaults or `cors: { ...headers }` for custom CORS headers. Preflight (OPTIONS) requests are handled automatically for matched routes, and CORS headers are appended to all non-WebSocket responses — including responses returned by `onBeforeRequest`. diff --git a/package-lock.json b/package-lock.json index 89621ad..09ada71 100644 --- a/package-lock.json +++ b/package-lock.json @@ -12235,7 +12235,7 @@ "devDependencies": { "@cloudflare/workers-types": "^4.20251218.0", "hono": "^4.11.1", - "partyserver": "^0.1.2" + "partyserver": "^0.1.3" }, "peerDependencies": { "@cloudflare/workers-types": "^4.20240729.0", @@ -12260,7 +12260,7 @@ "license": "ISC", "dependencies": { "nanoid": "^5.1.6", - "partysocket": "^1.1.11" + "partysocket": "^1.1.12" } }, "packages/partyhard": { @@ -12327,8 +12327,8 @@ "license": "ISC", "devDependencies": { "@cloudflare/workers-types": "^4.20251218.0", - "partyserver": "^0.1.2", - "partysocket": "^1.1.11" + "partyserver": "^0.1.3", + "partysocket": "^1.1.12" }, "peerDependencies": { "@cloudflare/workers-types": "^4.20240729.0", @@ -12345,7 +12345,7 @@ }, "devDependencies": { "@cloudflare/workers-types": "^4.20251218.0", - "partyserver": "^0.1.2" + "partyserver": "^0.1.3" }, "peerDependencies": { "@cloudflare/workers-types": "^4.20240729.0", @@ -12367,7 +12367,7 @@ "license": "ISC", "dependencies": { "cron-parser": "^5.4.0", - "partyserver": "^0.1.2" + "partyserver": "^0.1.3" } }, "packages/y-partyserver": { @@ -12382,7 +12382,7 @@ "devDependencies": { "@cloudflare/workers-types": "^4.20251218.0", "@types/lodash.debounce": "^4.0.9", - "partyserver": "^0.1.2", + "partyserver": "^0.1.3", "ws": "^8.18.3", "yjs": "^13.6.28" }, diff --git a/packages/hono-party/package.json b/packages/hono-party/package.json index 9364a9a..4302b1d 100644 --- a/packages/hono-party/package.json +++ b/packages/hono-party/package.json @@ -37,6 +37,6 @@ "devDependencies": { "@cloudflare/workers-types": "^4.20251218.0", "hono": "^4.11.1", - "partyserver": "^0.1.2" + "partyserver": "^0.1.3" } } diff --git a/packages/partyfn/package.json b/packages/partyfn/package.json index 875a83d..e54087c 100644 --- a/packages/partyfn/package.json +++ b/packages/partyfn/package.json @@ -19,7 +19,7 @@ ], "dependencies": { "nanoid": "^5.1.6", - "partysocket": "^1.1.11" + "partysocket": "^1.1.12" }, "scripts": { "build": "tsx scripts/build.ts" diff --git a/packages/partyserver/src/index.ts b/packages/partyserver/src/index.ts index 9c36442..b82313b 100644 --- a/packages/partyserver/src/index.ts +++ b/packages/partyserver/src/index.ts @@ -102,6 +102,30 @@ export interface PartyServerOptions< jurisdiction?: DurableObjectJurisdiction; locationHint?: DurableObjectLocationHint; props?: Props; + /** + * Whether to enable CORS for matched routes. + * + * When `true`, uses default permissive CORS headers: + * - Access-Control-Allow-Origin: * + * - Access-Control-Allow-Methods: GET, POST, HEAD, OPTIONS + * - Access-Control-Allow-Headers: * + * - Access-Control-Max-Age: 86400 + * + * For credentialed requests, pass explicit headers with a specific origin: + * ```ts + * cors: { + * "Access-Control-Allow-Origin": "https://myapp.com", + * "Access-Control-Allow-Credentials": "true", + * "Access-Control-Allow-Methods": "GET, POST, HEAD, OPTIONS", + * "Access-Control-Allow-Headers": "Content-Type, Authorization" + * } + * ``` + * + * When set to a `HeadersInit` value, uses those as the CORS headers instead. + * CORS preflight (OPTIONS) requests are handled automatically for matched routes. + * Non-WebSocket responses on matched routes will also have the CORS headers appended. + */ + cors?: boolean | HeadersInit; onBeforeConnect?: ( req: Request, lobby: { @@ -122,8 +146,31 @@ export interface PartyServerOptions< | Promise; } /** - * A utility function for PartyKit style routing. + * Resolve CORS options into a concrete headers object (or null if CORS is disabled). */ +function resolveCorsHeaders( + cors: boolean | HeadersInit | undefined +): Record | null { + if (cors === true) { + return { + "Access-Control-Allow-Origin": "*", + "Access-Control-Allow-Methods": "GET, POST, HEAD, OPTIONS", + "Access-Control-Allow-Headers": "*", + "Access-Control-Max-Age": "86400" + }; + } + if (cors && typeof cors === "object") { + // Convert any HeadersInit shape to a plain record + const h = new Headers(cors as HeadersInit); + const record: Record = {}; + h.forEach((value, key) => { + record[key] = value; + }); + return record; + } + return null; +} + export async function routePartykitRequest< Env extends Cloudflare.Env = Cloudflare.Env, T extends Server = Server, @@ -188,6 +235,26 @@ Did you forget to add a durable object binding to the class ${namespace[0].toUpp return new Response("Invalid request", { status: 400 }); } + // Resolve CORS headers for this matched route + const corsHeaders = resolveCorsHeaders(options?.cors); + const isWebSocket = + req.headers.get("Upgrade")?.toLowerCase() === "websocket"; + + // Helper: append CORS headers to a response (skipped for WebSocket upgrades) + function withCorsHeaders(response: Response): Response { + if (!corsHeaders || isWebSocket) return response; + const newResponse = new Response(response.body, response); + for (const [key, value] of Object.entries(corsHeaders)) { + newResponse.headers.set(key, value); + } + return newResponse; + } + + // Handle CORS preflight requests for matched routes + if (req.method === "OPTIONS" && corsHeaders) { + return new Response(null, { headers: corsHeaders }); + } + let doNamespace = map[namespace]; if (options?.jurisdiction) { doNamespace = doNamespace.jurisdiction(options.jurisdiction); @@ -210,7 +277,7 @@ Did you forget to add a durable object binding to the class ${namespace[0].toUpp req.headers.set("x-partykit-props", JSON.stringify(options?.props)); } - if (req.headers.get("Upgrade")?.toLowerCase() === "websocket") { + if (isWebSocket) { if (options?.onBeforeConnect) { const reqOrRes = await options.onBeforeConnect(req, { party: namespace, @@ -231,12 +298,12 @@ Did you forget to add a durable object binding to the class ${namespace[0].toUpp if (reqOrRes instanceof Request) { req = reqOrRes; } else if (reqOrRes instanceof Response) { - return reqOrRes; + return withCorsHeaders(reqOrRes); } } } - return stub.fetch(req); + return withCorsHeaders(await stub.fetch(req)); } else { return null; } @@ -336,17 +403,6 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam return Response.json({ ok: true }); } - // Handle keep-alive WebSocket endpoint (internal use for waitUntil) - if (url.pathname === "/cdn-cgi/partyserver/keep-alive/") { - if (request.headers.get("Upgrade")?.toLowerCase() === "websocket") { - const { 0: client, 1: server } = new WebSocketPair(); - // Always use hibernation API for keep-alive (efficient, internal-only) - this.ctx.acceptWebSocket(server, ["partyserver-keepalive"]); - return new Response(null, { status: 101, webSocket: client }); - } - return new Response("WebSocket required", { status: 426 }); - } - if (request.headers.get("Upgrade")?.toLowerCase() !== "websocket") { return await this.onRequest(request); } else { @@ -414,15 +470,6 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam } async webSocketMessage(ws: WebSocket, message: WSMessage): Promise { - // Handle keep-alive pings first (internal waitUntil mechanism) - const tags = this.ctx.getTags(ws); - if (tags.includes("partyserver-keepalive")) { - if (message === "ping") { - ws.send("pong"); - } - return; - } - // Ignore websockets accepted outside PartyServer (e.g. via // `state.acceptWebSocket()` in user code). These sockets won't have the // `__pk` attachment namespace required to rehydrate a Connection. @@ -451,12 +498,6 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam reason: string, wasClean: boolean ): Promise { - // Ignore keep-alive socket closes (internal waitUntil mechanism) - const tags = this.ctx.getTags(ws); - if (tags.includes("partyserver-keepalive")) { - return; - } - if (!isPartyServerWebSocket(ws)) { return; } @@ -476,12 +517,6 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam } async webSocketError(ws: WebSocket, error: unknown): Promise { - // Ignore keep-alive socket errors (internal waitUntil mechanism) - const tags = this.ctx.getTags(ws); - if (tags.includes("partyserver-keepalive")) { - return; - } - if (!isPartyServerWebSocket(ws)) { return; } @@ -630,114 +665,6 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam return []; } - /** - * Execute a long-running async function while keeping the Durable Object alive. - * - * Durable Objects normally terminate 70-140s after the last network request. - * This method keeps the DO alive by establishing a WebSocket connection to itself - * and sending periodic ping messages. - * - * @experimental This API is experimental and may change in future versions. - * - * @param fn - The async function to execute - * @param timeoutMs - Maximum time to keep the DO alive (default: 30 minutes) - * @returns The result of the async function - * - * @remarks - * Requires the `enable_ctx_exports` compatibility flag in wrangler.jsonc: - * ```json - * { - * "compatibility_flags": ["enable_ctx_exports"] - * } - * ``` - * - * @example - * ```typescript - * const result = await this.experimental_waitUntil(async () => { - * // Long-running operation - * await processLargeDataset(); - * return { success: true }; - * }, 60 * 60 * 1000); // 1 hour timeout - * ``` - */ - async experimental_waitUntil( - fn: () => Promise, - timeoutMs: number = 30 * 60 * 1000 // 30 minutes default - ): Promise { - // Get namespace from ctx.exports (requires enable_ctx_exports compatibility flag) - const exports = ( - this.ctx as DurableObjectState & { exports?: Record } - ).exports; - if (!exports) { - throw new Error( - "waitUntil requires the 'enable_ctx_exports' compatibility flag. " + - 'Add it to your wrangler.jsonc: { "compatibility_flags": ["enable_ctx_exports"] }' - ); - } - - const namespace = exports[this.#ParentClass.name] as - | DurableObjectNamespace - | undefined; - if (!namespace) { - throw new Error( - `Could not find namespace for ${this.#ParentClass.name} in ctx.exports. ` + - "Make sure the class name matches your Durable Object binding." - ); - } - - const stub = namespace.get(this.ctx.id); - - // Connect to self via WebSocket for keep-alive - const response = await stub.fetch( - "http://dummy-example.cloudflare.com/cdn-cgi/partyserver/keep-alive/", - { - headers: { - Upgrade: "websocket", - "x-partykit-room": this.name - } - } - ); - - const ws = response.webSocket; - if (!ws) { - throw new Error("Failed to establish keep-alive WebSocket connection"); - } - ws.accept(); - - // Set up ping interval (every 10 seconds) - const pingInterval = setInterval(() => { - try { - ws.send("ping"); - } catch { - // WebSocket may have closed, ignore - } - }, 10_000); - - // Create a timeout promise that rejects after timeoutMs - let timeoutId: ReturnType; - const timeoutPromise = new Promise((_, reject) => { - timeoutId = setTimeout(() => { - reject( - new Error(`experimental_waitUntil timed out after ${timeoutMs}ms`) - ); - }, timeoutMs); - }); - - try { - // Race the function against the timeout - const result = await Promise.race([fn(), timeoutPromise]); - return result; - } finally { - clearTimeout(timeoutId!); - clearInterval(pingInterval); - try { - ws.close(1000, "Complete"); - } catch { - // Ignore close errors - } - } - } - #_props?: Props; // Implemented by the user diff --git a/packages/partyserver/src/tests/index.test.ts b/packages/partyserver/src/tests/index.test.ts index 6d133da..127ba26 100644 --- a/packages/partyserver/src/tests/index.test.ts +++ b/packages/partyserver/src/tests/index.test.ts @@ -300,3 +300,130 @@ describe("Server", () => { // describe("hibernated"); // describe("in-memory"); }); + +describe("CORS", () => { + it("returns CORS headers on OPTIONS preflight for matched routes", async () => { + const ctx = createExecutionContext(); + const request = new Request( + "http://example.com/cors-parties/cors-server/room1", + { method: "OPTIONS" } + ); + const response = await worker.fetch(request, env, ctx); + expect(response.status).toBe(200); + expect(response.headers.get("Access-Control-Allow-Origin")).toBe("*"); + expect(response.headers.get("Access-Control-Allow-Methods")).toBe( + "GET, POST, HEAD, OPTIONS" + ); + expect(response.headers.get("Access-Control-Allow-Headers")).toBe("*"); + expect(response.headers.get("Access-Control-Max-Age")).toBe("86400"); + // Credentials header should NOT be in defaults (contradicts wildcard origin) + expect(response.headers.get("Access-Control-Allow-Credentials")).toBeNull(); + }); + + it("does not handle OPTIONS for unmatched routes (returns 404 from fallback)", async () => { + const ctx = createExecutionContext(); + const request = new Request("http://example.com/other-path", { + method: "OPTIONS" + }); + const response = await worker.fetch(request, env, ctx); + expect(response.status).toBe(404); + expect(response.headers.get("Access-Control-Allow-Origin")).toBeNull(); + }); + + it("does not handle OPTIONS for routes without cors enabled", async () => { + const ctx = createExecutionContext(); + // The default /parties/ prefix has no cors option + const request = new Request("http://example.com/parties/stateful/room1", { + method: "OPTIONS" + }); + const response = await worker.fetch(request, env, ctx); + // Without cors, OPTIONS goes to the DO like any other request + expect(response.headers.get("Access-Control-Allow-Origin")).toBeNull(); + }); + + it("appends CORS headers to regular (non-WebSocket) responses", async () => { + const ctx = createExecutionContext(); + const request = new Request( + "http://example.com/cors-parties/cors-server/room1" + ); + const response = await worker.fetch(request, env, ctx); + expect(response.status).toBe(200); + expect(await response.json()).toEqual({ cors: true }); + expect(response.headers.get("Access-Control-Allow-Origin")).toBe("*"); + expect(response.headers.get("Access-Control-Allow-Methods")).toBe( + "GET, POST, HEAD, OPTIONS" + ); + }); + + it("does not append CORS headers to WebSocket upgrade responses", async () => { + const ctx = createExecutionContext(); + const request = new Request( + "http://example.com/cors-parties/cors-server/room1", + { + headers: { Upgrade: "websocket" } + } + ); + const response = await worker.fetch(request, env, ctx); + expect(response.status).toBe(101); + expect(response.headers.get("Access-Control-Allow-Origin")).toBeNull(); + response.webSocket?.accept(); + response.webSocket?.close(); + }); + + it("supports custom HeadersInit CORS headers", async () => { + const ctx = createExecutionContext(); + const request = new Request( + "http://example.com/custom-cors-parties/custom-cors-server/room1" + ); + const response = await worker.fetch(request, env, ctx); + expect(response.status).toBe(200); + expect(await response.json()).toEqual({ customCors: true }); + expect(response.headers.get("Access-Control-Allow-Origin")).toBe( + "https://example.com" + ); + expect(response.headers.get("Access-Control-Allow-Methods")).toBe( + "GET, POST" + ); + // Should not have the default headers that weren't specified + expect(response.headers.get("Access-Control-Max-Age")).toBeNull(); + }); + + it("supports custom HeadersInit CORS headers on OPTIONS preflight", async () => { + const ctx = createExecutionContext(); + const request = new Request( + "http://example.com/custom-cors-parties/custom-cors-server/room1", + { method: "OPTIONS" } + ); + const response = await worker.fetch(request, env, ctx); + expect(response.status).toBe(200); + expect(response.headers.get("Access-Control-Allow-Origin")).toBe( + "https://example.com" + ); + expect(response.headers.get("Access-Control-Allow-Methods")).toBe( + "GET, POST" + ); + }); + + it("does not add CORS headers when cors option is not set", async () => { + const ctx = createExecutionContext(); + const request = new Request("http://example.com/parties/stateful/room1"); + const response = await worker.fetch(request, env, ctx); + expect(response.status).toBe(200); + expect(response.headers.get("Access-Control-Allow-Origin")).toBeNull(); + }); + + it("appends CORS headers to responses returned by onBeforeRequest", async () => { + const ctx = createExecutionContext(); + const request = new Request( + "http://example.com/cors-parties/cors-server/blocked" + ); + const response = await worker.fetch(request, env, ctx); + expect(response.status).toBe(403); + expect(await response.text()).toBe("Forbidden"); + // CORS headers must be present so the browser can read the error + expect(response.headers.get("Access-Control-Allow-Origin")).toBe("*"); + expect(response.headers.get("Access-Control-Allow-Methods")).toBe( + "GET, POST, HEAD, OPTIONS" + ); + }); +}); diff --git a/packages/partyserver/src/tests/worker.ts b/packages/partyserver/src/tests/worker.ts index def870d..a84f2c3 100644 --- a/packages/partyserver/src/tests/worker.ts +++ b/packages/partyserver/src/tests/worker.ts @@ -15,6 +15,8 @@ export type Env = { ConfigurableState: DurableObjectNamespace; ConfigurableStateInMemory: DurableObjectNamespace; StateRoundTrip: DurableObjectNamespace; + CorsServer: DurableObjectNamespace; + CustomCorsServer: DurableObjectNamespace; }; export class Stateful extends Server { @@ -189,8 +191,50 @@ export class ConfigurableStateInMemory extends Server { } } +export class CorsServer extends Server { + onRequest(): Response | Promise { + return Response.json({ cors: true }); + } +} + +export class CustomCorsServer extends Server { + onRequest(): Response | Promise { + return Response.json({ customCors: true }); + } +} + export default { async fetch(request: Request, env: Env, _ctx: ExecutionContext) { + const url = new URL(request.url); + + // Route requests under /cors-parties/ with cors: true + if (url.pathname.startsWith("/cors-parties/")) { + return ( + (await routePartykitRequest(request, env, { + prefix: "cors-parties", + cors: true, + onBeforeRequest: async (_req, { name }) => { + if (name === "blocked") { + return new Response("Forbidden", { status: 403 }); + } + } + })) || new Response("Not Found", { status: 404 }) + ); + } + + // Route requests under /custom-cors-parties/ with custom CORS headers + if (url.pathname.startsWith("/custom-cors-parties/")) { + return ( + (await routePartykitRequest(request, env, { + prefix: "custom-cors-parties", + cors: { + "Access-Control-Allow-Origin": "https://example.com", + "Access-Control-Allow-Methods": "GET, POST" + } + })) || new Response("Not Found", { status: 404 }) + ); + } + return ( (await routePartykitRequest(request, env, { onBeforeConnect: async (_request, { party, name }) => { diff --git a/packages/partyserver/src/tests/wrangler.jsonc b/packages/partyserver/src/tests/wrangler.jsonc index fe591aa..d3ca216 100644 --- a/packages/partyserver/src/tests/wrangler.jsonc +++ b/packages/partyserver/src/tests/wrangler.jsonc @@ -34,6 +34,14 @@ { "name": "StateRoundTrip", "class_name": "StateRoundTrip" + }, + { + "name": "CorsServer", + "class_name": "CorsServer" + }, + { + "name": "CustomCorsServer", + "class_name": "CustomCorsServer" } ] }, @@ -49,6 +57,10 @@ "ConfigurableStateInMemory", "StateRoundTrip" ] + }, + { + "tag": "v3", + "new_classes": ["CorsServer", "CustomCorsServer"] } ] } diff --git a/packages/partysub/package.json b/packages/partysub/package.json index df17dcf..9e84aa5 100644 --- a/packages/partysub/package.json +++ b/packages/partysub/package.json @@ -46,7 +46,7 @@ }, "devDependencies": { "@cloudflare/workers-types": "^4.20251218.0", - "partyserver": "^0.1.2", - "partysocket": "^1.1.11" + "partyserver": "^0.1.3", + "partysocket": "^1.1.12" } } diff --git a/packages/partysync/package.json b/packages/partysync/package.json index 116aa90..9eaa1e8 100644 --- a/packages/partysync/package.json +++ b/packages/partysync/package.json @@ -54,6 +54,6 @@ }, "devDependencies": { "@cloudflare/workers-types": "^4.20251218.0", - "partyserver": "^0.1.2" + "partyserver": "^0.1.3" } } diff --git a/packages/partywhen/package.json b/packages/partywhen/package.json index d50c21e..31371c3 100644 --- a/packages/partywhen/package.json +++ b/packages/partywhen/package.json @@ -29,6 +29,6 @@ "description": "A library for scheduling and running tasks in Cloudflare Workers", "dependencies": { "cron-parser": "^5.4.0", - "partyserver": "^0.1.2" + "partyserver": "^0.1.3" } } diff --git a/packages/y-partyserver/package.json b/packages/y-partyserver/package.json index 2c50e6a..87fb36c 100644 --- a/packages/y-partyserver/package.json +++ b/packages/y-partyserver/package.json @@ -53,7 +53,7 @@ "devDependencies": { "@cloudflare/workers-types": "^4.20251218.0", "@types/lodash.debounce": "^4.0.9", - "partyserver": "^0.1.2", + "partyserver": "^0.1.3", "ws": "^8.18.3", "yjs": "^13.6.28" }