diff --git a/.changeset/selfish-bananas-clean.md b/.changeset/selfish-bananas-clean.md new file mode 100644 index 0000000..3508149 --- /dev/null +++ b/.changeset/selfish-bananas-clean.md @@ -0,0 +1,5 @@ +--- +"partyserver": patch +--- + +Check for hibernated websocket connections diff --git a/packages/partyserver/src/connection.ts b/packages/partyserver/src/connection.ts index 8cadaa6..b834f54 100644 --- a/packages/partyserver/src/connection.ts +++ b/packages/partyserver/src/connection.ts @@ -38,6 +38,39 @@ type ConnectionAttachments = { __user?: unknown; }; +function tryGetPartyServerMeta( + ws: WebSocket +): ConnectionAttachments["__pk"] | null { + try { + // Avoid AttachmentCache.get() here: hibernated sockets accepted outside + // PartyServer can have an attachment without a __pk namespace. + const attachment = WebSocket.prototype.deserializeAttachment.call( + ws + ) as unknown; + if (!attachment || typeof attachment !== "object") { + return null; + } + if (!("__pk" in attachment)) { + return null; + } + const pk = (attachment as ConnectionAttachments).__pk as unknown; + if (!pk || typeof pk !== "object") { + return null; + } + const { id, server } = pk as { id?: unknown; server?: unknown }; + if (typeof id !== "string" || typeof server !== "string") { + return null; + } + return pk as ConnectionAttachments["__pk"]; + } catch { + return null; + } +} + +export function isPartyServerWebSocket(ws: WebSocket): boolean { + return tryGetPartyServerMeta(ws) !== null; +} + /** * Cache websocket attachments to avoid having to rehydrate them on every property access. */ @@ -180,6 +213,12 @@ class HibernatingConnectionIterator implements IterableIterator< while ((socket = sockets[this.index++])) { // only yield open sockets to match non-hibernating behaviour if (socket.readyState === WebSocket.READY_STATE_OPEN) { + // Durable Objects hibernation APIs allow storing arbitrary sockets via + // `state.acceptWebSocket()`. Those sockets won't have PartyServer's + // `__pk` attachment namespace and must be ignored. + if (!isPartyServerWebSocket(socket)) { + continue; + } const value = createLazyConnection(socket) as Connection; return { done: false, value }; } @@ -263,15 +302,25 @@ export class HibernatingConnectionManager implements ConnectionManager { constructor(private controller: DurableObjectState) {} getCount() { - return Number(this.controller.getWebSockets().length); + // Only count sockets managed by PartyServer. Other hibernated sockets may + // exist on the same Durable Object via `state.acceptWebSocket()`. + let count = 0; + for (const ws of this.controller.getWebSockets()) { + if (isPartyServerWebSocket(ws)) count++; + } + return count; } getConnection(id: string) { // TODO: Should we cache the connections? const sockets = this.controller.getWebSockets(id); - if (sockets.length === 0) return undefined; - if (sockets.length === 1) - return createLazyConnection(sockets[0]) as Connection; + const matching = sockets.filter((ws) => { + return tryGetPartyServerMeta(ws)?.id === id; + }); + + if (matching.length === 0) return undefined; + if (matching.length === 1) + return createLazyConnection(matching[0]) as Connection; throw new Error( `More than one connection found for id ${id}. Did you mean to use getConnections(tag) instead?` diff --git a/packages/partyserver/src/index.ts b/packages/partyserver/src/index.ts index a009a62..bb5f1a8 100644 --- a/packages/partyserver/src/index.ts +++ b/packages/partyserver/src/index.ts @@ -8,7 +8,8 @@ import { nanoid } from "nanoid"; import { createLazyConnection, HibernatingConnectionManager, - InMemoryConnectionManager + InMemoryConnectionManager, + isPartyServerWebSocket } from "./connection"; import type { ConnectionManager } from "./connection"; @@ -422,6 +423,13 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam return; } + // Ignore websockets accepted outside PartyServer (e.g. via + // `state.acceptWebSocket()` in user code). These sockets won't have the + // `__pk` attachment namespace required to rehydrate a Connection. + if (!isPartyServerWebSocket(ws)) { + return; + } + const connection = createLazyConnection(ws); // rehydrate the server name if it's woken up @@ -449,6 +457,10 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam return; } + if (!isPartyServerWebSocket(ws)) { + return; + } + const connection = createLazyConnection(ws); // rehydrate the server name if it's woken up @@ -470,6 +482,10 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam return; } + if (!isPartyServerWebSocket(ws)) { + return; + } + const connection = createLazyConnection(ws); // rehydrate the server name if it's woken up diff --git a/packages/partyserver/src/tests/index.test.ts b/packages/partyserver/src/tests/index.test.ts index 44991d9..4311be8 100644 --- a/packages/partyserver/src/tests/index.test.ts +++ b/packages/partyserver/src/tests/index.test.ts @@ -149,6 +149,47 @@ describe("Server", () => { expect(response.headers.get("Location")).toBe("https://example3.com"); }); + it("ignores foreign hibernated websockets when broadcasting", async () => { + const ctx = createExecutionContext(); + + // Create a websocket that is accepted via the DO hibernation API directly + // (no PartyServer `__pk` attachment). + const foreignReq = new Request( + "http://example.com/parties/mixed/room/foreign", + { + headers: { Upgrade: "websocket" } + } + ); + const foreignRes = await worker.fetch(foreignReq, env, ctx); + const foreignWs = foreignRes.webSocket!; + foreignWs.accept(); + + // Now connect via PartyServer. onConnect() will call broadcast(), which must + // not crash due to the foreign socket. + const req = new Request("http://example.com/parties/mixed/room", { + headers: { Upgrade: "websocket" } + }); + const res = await worker.fetch(req, env, ctx); + const ws = res.webSocket!; + ws.accept(); + + const { promise, resolve, reject } = Promise.withResolvers(); + ws.addEventListener("message", (message) => { + try { + // We should receive at least one message from the server. + expect(["hello", "connected"]).toContain(message.data); + resolve(); + } catch (e) { + reject(e); + } finally { + ws.close(); + foreignWs.close(); + } + }); + + return promise; + }); + // it("can be connected with a query parameter"); // it("can be connected with a header"); diff --git a/packages/partyserver/src/tests/worker.ts b/packages/partyserver/src/tests/worker.ts index 46468f9..d0c12b4 100644 --- a/packages/partyserver/src/tests/worker.ts +++ b/packages/partyserver/src/tests/worker.ts @@ -11,6 +11,7 @@ function assert(condition: unknown, message: string): asserts condition { export type Env = { Stateful: DurableObjectNamespace; OnStartServer: DurableObjectNamespace; + Mixed: DurableObjectNamespace; }; export class Stateful extends Server { @@ -61,6 +62,37 @@ export class OnStartServer extends Server { } } +export class Mixed extends Server { + static options = { + hibernate: true + }; + + async fetch(request: Request): Promise { + const url = new URL(request.url); + if (url.pathname.endsWith("/foreign")) { + const room = request.headers.get("x-partykit-room"); + if (room) { + await this.setName(room); + } + + const pair = new WebSocketPair(); + const [client, server] = Object.values(pair); + // Accept a hibernated websocket that PartyServer does not manage. This is + // equivalent to user code calling `this.ctx.acceptWebSocket()` directly. + this.ctx.acceptWebSocket(server, ["foreign"]); + return new Response(null, { status: 101, webSocket: client }); + } + + return super.fetch(request); + } + + onConnect(connection: Connection): void { + // Trigger a broadcast while a foreign hibernated socket exists. + this.broadcast("hello"); + connection.send("connected"); + } +} + export default { async fetch(request: Request, env: Env, _ctx: ExecutionContext) { return ( diff --git a/packages/partyserver/src/tests/wrangler.jsonc b/packages/partyserver/src/tests/wrangler.jsonc index 02ca823..3ed9315 100644 --- a/packages/partyserver/src/tests/wrangler.jsonc +++ b/packages/partyserver/src/tests/wrangler.jsonc @@ -18,13 +18,17 @@ { "name": "OnStartServer", "class_name": "OnStartServer" + }, + { + "name": "Mixed", + "class_name": "Mixed" } ] }, "migrations": [ { "tag": "v1", // Should be unique for each entry - "new_classes": ["Stateful", "OnStartServer"] + "new_classes": ["Stateful", "OnStartServer", "Mixed"] } ] } diff --git a/packages/partysocket/src/tests/integration.test.ts b/packages/partysocket/src/tests/integration.test.ts index 7fba7ba..682508c 100644 --- a/packages/partysocket/src/tests/integration.test.ts +++ b/packages/partysocket/src/tests/integration.test.ts @@ -163,7 +163,7 @@ describe("Integration - Full Lifecycle", () => { for (let i = 0; i < messageCount; i++) { expect(receivedMessages[i]).toBe(`message-${i}`); } - } catch (e) { + } catch (_e) { // If we still have Blobs, messages aren't fully processed yet return; } diff --git a/packages/partysocket/src/tests/react-hooks.test.tsx b/packages/partysocket/src/tests/react-hooks.test.tsx index b1598a5..ecf1798 100644 --- a/packages/partysocket/src/tests/react-hooks.test.tsx +++ b/packages/partysocket/src/tests/react-hooks.test.tsx @@ -9,7 +9,7 @@ import { WebSocketServer } from "ws"; import usePartySocket, { useWebSocket } from "../react"; const PORT = 50128; -const URL = `ws://localhost:${PORT}`; +// const URL = `ws://localhost:${PORT}`; describe("usePartySocket", () => { let wss: WebSocketServer; @@ -313,8 +313,12 @@ describe("usePartySocket", () => { test("attaches onOpen event handler", async () => { const onOpen = vitest.fn(); - wss.once("connection", (ws) => { - // Connection established + // Set up connection handler before rendering + const connectionPromise = new Promise((resolve) => { + wss.once("connection", (_ws: any) => { + // Connection established + resolve(); + }); }); const { result } = renderHook(() => @@ -325,13 +329,20 @@ describe("usePartySocket", () => { }) ); + // Wait for connection to be established on server side + await connectionPromise; + + // Wait for connection to be established on client side await waitFor( () => { - expect(onOpen).toHaveBeenCalled(); + expect(result.current.readyState).toBe(WebSocket.OPEN); }, { timeout: 3000 } ); + // Verify onOpen was called + expect(onOpen).toHaveBeenCalled(); + result.current.close(); }); @@ -339,9 +350,13 @@ describe("usePartySocket", () => { const onMessage = vitest.fn(); const testMessage = "hello from server"; - wss.once("connection", (ws) => { - ws.send(testMessage); - }); + const connectionHandler = (ws: any) => { + // Send message after a small delay to ensure connection is fully established + setTimeout(() => { + ws.send(testMessage); + }, 50); + }; + wss.on("connection", connectionHandler); const { result } = renderHook(() => usePartySocket({ @@ -351,6 +366,7 @@ describe("usePartySocket", () => { }) ); + // Wait for message to be received await waitFor( () => { expect(onMessage).toHaveBeenCalled(); @@ -360,6 +376,7 @@ describe("usePartySocket", () => { { timeout: 3000 } ); + wss.off("connection", connectionHandler); result.current.close(); }); @@ -367,7 +384,8 @@ describe("usePartySocket", () => { const onClose = vitest.fn(); wss.once("connection", (ws) => { - setTimeout(() => ws.close(), 50); + // Wait for connection to be fully established before closing + setTimeout(() => ws.close(), 100); }); const { result } = renderHook(() => @@ -378,14 +396,21 @@ describe("usePartySocket", () => { }) ); + // Wait for connection to be established first await waitFor( () => { - expect(onClose).toHaveBeenCalled(); + expect(result.current.readyState).toBe(WebSocket.OPEN); }, { timeout: 3000 } ); - result.current.close(); + // Then wait for close event + await waitFor( + () => { + expect(onClose).toHaveBeenCalled(); + }, + { timeout: 3000 } + ); }); test("attaches onError event handler", async () => { @@ -414,10 +439,12 @@ describe("usePartySocket", () => { const onMessage1 = vitest.fn(); const onMessage2 = vitest.fn(); - wss.once("connection", (ws) => { - setTimeout(() => ws.send("message1"), 50); - setTimeout(() => ws.send("message2"), 100); - }); + const connectionHandler = (ws: any) => { + // Send messages with delays to ensure connection is established + setTimeout(() => ws.send("message1"), 100); + setTimeout(() => ws.send("message2"), 200); + }; + wss.on("connection", connectionHandler); const { result, rerender } = renderHook( ({ onMessage }) => @@ -453,6 +480,7 @@ describe("usePartySocket", () => { { timeout: 3000 } ); + wss.off("connection", connectionHandler); result.current.close(); }); @@ -485,7 +513,7 @@ describe("usePartySocket", () => { }); test("connects automatically when startClosed is false", async () => { - wss.once("connection", (ws) => { + wss.once("connection", (_ws) => { // Connection established }); diff --git a/packages/partysocket/src/tests/react-ssr.test.tsx b/packages/partysocket/src/tests/react-ssr.test.tsx index 15786cf..c79ab66 100644 --- a/packages/partysocket/src/tests/react-ssr.test.tsx +++ b/packages/partysocket/src/tests/react-ssr.test.tsx @@ -3,7 +3,7 @@ */ import { renderToString } from "react-dom/server"; import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; -import { WebSocketServer } from "ws"; +import { type WebSocket as NodeWebSocket, WebSocketServer } from "ws"; import { usePartySocket, useWebSocket } from "../react"; const PORT = 50135; @@ -25,9 +25,16 @@ describe("SSR/Node.js Environment - usePartySocket", () => { }); afterEach(() => { - wss.close(); - global.window = originalWindow; - global.document = originalDocument; + return new Promise((resolve) => { + wss.clients.forEach((client: NodeWebSocket) => { + client.terminate(); + }); + wss.close(() => { + global.window = originalWindow; + global.document = originalDocument; + resolve(); + }); + }); }); it("should use default host when window is not available", () => { @@ -151,7 +158,7 @@ describe("SSR/Node.js Environment - usePartySocket", () => { it("should handle query params in SSR", () => { function TestComponent() { - const socket = usePartySocket({ + const _socket = usePartySocket({ host: "example.com", room: "test-room", query: { token: "abc123" }, @@ -167,7 +174,7 @@ describe("SSR/Node.js Environment - usePartySocket", () => { it("should handle async query params in SSR", () => { function TestComponent() { - const socket = usePartySocket({ + const _socket = usePartySocket({ host: "example.com", room: "test-room", query: async () => ({ token: "abc123" }), @@ -219,9 +226,16 @@ describe("SSR/Node.js Environment - useWebSocket", () => { }); afterEach(() => { - wss.close(); - global.window = originalWindow; - global.document = originalDocument; + return new Promise((resolve) => { + wss.clients.forEach((client: NodeWebSocket) => { + client.terminate(); + }); + wss.close(() => { + global.window = originalWindow; + global.document = originalDocument; + resolve(); + }); + }); }); it("should render with string URL in SSR", () => { @@ -273,7 +287,7 @@ describe("SSR/Node.js Environment - useWebSocket", () => { it("should handle protocols array in SSR", () => { function TestComponent() { - const socket = useWebSocket( + const _socket = useWebSocket( `ws://localhost:${PORT + 1}`, ["protocol1", "protocol2"], { @@ -290,7 +304,7 @@ describe("SSR/Node.js Environment - useWebSocket", () => { it("should handle protocol function in SSR", () => { function TestComponent() { - const socket = useWebSocket( + const _socket = useWebSocket( `ws://localhost:${PORT + 1}`, () => "protocol1", { @@ -307,7 +321,7 @@ describe("SSR/Node.js Environment - useWebSocket", () => { it("should handle async protocol in SSR", () => { function TestComponent() { - const socket = useWebSocket( + const _socket = useWebSocket( `ws://localhost:${PORT + 1}`, async () => "protocol1", { diff --git a/packages/partysocket/src/tests/reconnecting-node.test.ts b/packages/partysocket/src/tests/reconnecting-node.test.ts index b27c3eb..9272654 100644 --- a/packages/partysocket/src/tests/reconnecting-node.test.ts +++ b/packages/partysocket/src/tests/reconnecting-node.test.ts @@ -56,7 +56,7 @@ afterEach(() => { afterAll(() => { return new Promise((resolve) => { - wss.clients.forEach((client) => { + wss.clients.forEach((client: NodeWebSocket) => { client.terminate(); }); wss.close(() => { diff --git a/packages/partysocket/src/tests/reconnecting.test.ts b/packages/partysocket/src/tests/reconnecting.test.ts index cfa0c8e..723ce7e 100644 --- a/packages/partysocket/src/tests/reconnecting.test.ts +++ b/packages/partysocket/src/tests/reconnecting.test.ts @@ -41,7 +41,7 @@ afterEach(() => { afterAll(() => { return new Promise((resolve) => { - wss.clients.forEach((client) => { + wss.clients.forEach((client: NodeWebSocket) => { client.terminate(); }); wss.close(() => { @@ -475,13 +475,13 @@ testDone("start closed", (done, fail) => { const anyMessageText = "hello"; const anyProtocol = "foobar"; - wss.once("connection", (ws) => { - void ws.once("message", (msg) => { + wss.once("connection", (ws: NodeWebSocket) => { + void ws.once("message", (msg: Buffer) => { ws.send(msg); }); }); - wss.once("error", (e) => { + wss.once("error", (e: Error) => { fail(e); }); @@ -526,13 +526,13 @@ testDone("connect, send, receive, close", (done, fail) => { const anyMessageText = "hello"; const anyProtocol = "foobar"; - wss.once("connection", (ws) => { - void ws.once("message", (msg) => { + wss.once("connection", (ws: NodeWebSocket) => { + void ws.once("message", (msg: Buffer) => { ws.send(msg); }); }); - wss.on("error", (e) => { + wss.on("error", (e: Error) => { fail(e); }); @@ -739,8 +739,8 @@ testDone( expect(ws.bufferedAmount).toBe(messages.reduce((a, m) => a + m.length, 0)); let i = 0; - wss.once("connection", (client) => { - client.on("message", (data) => { + wss.once("connection", (client: NodeWebSocket) => { + client.on("message", (data: Buffer) => { // eslint-disable-next-line @typescript-eslint/no-base-to-string if (data.toString() === "ok") { expect(i).toBe(messages.length);