diff --git a/.changeset/evil-carrots-grab.md b/.changeset/evil-carrots-grab.md new file mode 100644 index 0000000..684882f --- /dev/null +++ b/.changeset/evil-carrots-grab.md @@ -0,0 +1,5 @@ +--- +"partyserver": patch +--- + +remove redundant initialize code as setName takes care of it, along with the nested blockConcurrencyWhile call diff --git a/packages/partyserver/src/index.ts b/packages/partyserver/src/index.ts index b82313b..9a17c49 100644 --- a/packages/partyserver/src/index.ts +++ b/packages/partyserver/src/index.ts @@ -483,12 +483,6 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam await this.setName(connection.server); // TODO: ^ this shouldn't be async - if (this.#status !== "started") { - // This means the server "woke up" after hibernation - // so we need to hydrate it again - await this.#initialize(); - } - return this.onMessage(connection, message); } @@ -508,11 +502,6 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam await this.setName(connection.server); // TODO: ^ this shouldn't be async - if (this.#status !== "started") { - // This means the server "woke up" after hibernation - // so we need to hydrate it again - await this.#initialize(); - } return this.onClose(connection, code, reason, wasClean); } @@ -527,11 +516,6 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam await this.setName(connection.server); // TODO: ^ this shouldn't be async - if (this.#status !== "started") { - // This means the server "woke up" after hibernation - // so we need to hydrate it again - await this.#initialize(); - } return this.onError(connection, error); } @@ -612,9 +596,7 @@ Did you try connecting directly to this Durable Object? Try using getServerByNam this.#_name = name; if (this.#status !== "started") { - await this.ctx.blockConcurrencyWhile(async () => { - await this.#initialize(); - }); + await this.#initialize(); } } diff --git a/packages/partyserver/src/tests/index.test.ts b/packages/partyserver/src/tests/index.test.ts index 127ba26..358dff7 100644 --- a/packages/partyserver/src/tests/index.test.ts +++ b/packages/partyserver/src/tests/index.test.ts @@ -1,6 +1,7 @@ import { createExecutionContext, - env + env, + runDurableObjectAlarm // waitOnExecutionContext } from "cloudflare:test"; import { describe, expect, it } from "vitest"; @@ -301,6 +302,143 @@ describe("Server", () => { // describe("in-memory"); }); +describe("Hibernating Server (setName handles initialization)", () => { + it("calls onStart before processing connections", async () => { + const ctx = createExecutionContext(); + const request = new Request( + "http://example.com/parties/hibernating-on-start-server/h-test1", + { + headers: { Upgrade: "websocket" } + } + ); + const response = await worker.fetch(request, env, ctx); + const ws = response.webSocket!; + ws.accept(); + + const { promise, resolve, reject } = Promise.withResolvers(); + ws.addEventListener("message", (message) => { + try { + // counter should be 1 because onStart completed before onConnect + expect(message.data).toEqual("1"); + resolve(); + } catch (e) { + reject(e); + } finally { + ws.close(); + } + }); + + return promise; + }); + + it("calls onStart only once with concurrent connections and requests", async () => { + const ctx = createExecutionContext(); + + async function makeConnection() { + const request = new Request( + "http://example.com/parties/hibernating-on-start-server/h-test2", + { + headers: { Upgrade: "websocket" } + } + ); + const response = await worker.fetch(request, env, ctx); + const ws = response.webSocket!; + ws.accept(); + const { promise, resolve, reject } = Promise.withResolvers(); + ws.addEventListener("message", (message) => { + try { + expect(message.data).toEqual("1"); + resolve(); + } catch (e) { + reject(e); + } finally { + ws.close(); + } + }); + return promise; + } + + async function makeRequest() { + const request = new Request( + "http://example.com/parties/hibernating-on-start-server/h-test2" + ); + const response = await worker.fetch(request, env, ctx); + expect(await response.text()).toEqual("1"); + } + + await Promise.all([makeConnection(), makeConnection(), makeRequest()]); + }); + + it("handles websocket messages after initialization", async () => { + const ctx = createExecutionContext(); + const request = new Request( + "http://example.com/parties/hibernating-on-start-server/h-test3", + { + headers: { Upgrade: "websocket" } + } + ); + const response = await worker.fetch(request, env, ctx); + const ws = response.webSocket!; + ws.accept(); + + // Wait for the onConnect message + const connectMessage = await new Promise((resolve) => { + ws.addEventListener("message", (e) => resolve(e.data as string), { + once: true + }); + }); + expect(connectMessage).toEqual("1"); + + // Send a message and verify the server is still initialized + ws.send("hello"); + const echoMessage = await new Promise((resolve) => { + ws.addEventListener("message", (e) => resolve(e.data as string), { + once: true + }); + }); + expect(echoMessage).toEqual("counter:1"); + + ws.close(); + }); +}); + +describe("Alarm (initialize without redundant blockConcurrencyWhile)", () => { + it("properly initializes on alarm and calls onAlarm", async () => { + // Use a single stub for the entire test so runDurableObjectAlarm + // sees the same DO instance that has the alarm scheduled. + const id = env.AlarmServer.idFromName("alarm-test1"); + const stub = env.AlarmServer.get(id); + + // Initialize the DO and schedule an alarm in one request + const res = await stub.fetch( + new Request( + "http://example.com/parties/alarm-server/alarm-test1?setAlarm=1", + { + headers: { "x-partykit-room": "alarm-test1" } + } + ) + ); + expect(await res.text()).toEqual("alarm set"); + + // Trigger the alarm + const ran = await runDurableObjectAlarm(stub); + expect(ran).toBe(true); + + // Verify state: onStart called once, alarm was triggered once + const stateRes = await stub.fetch( + new Request("http://example.com/", { + headers: { "x-partykit-room": "alarm-test1" } + }) + ); + const state = (await stateRes.json()) as { + counter: number; + alarmCount: number; + }; + expect(state.counter).toEqual(1); + expect(state.alarmCount).toEqual(1); + }); +}); + describe("CORS", () => { it("returns CORS headers on OPTIONS preflight for matched routes", async () => { const ctx = createExecutionContext(); diff --git a/packages/partyserver/src/tests/worker.ts b/packages/partyserver/src/tests/worker.ts index a84f2c3..c9c9012 100644 --- a/packages/partyserver/src/tests/worker.ts +++ b/packages/partyserver/src/tests/worker.ts @@ -1,6 +1,6 @@ import { routePartykitRequest, Server } from "../index"; -import type { Connection, ConnectionContext } from "../index"; +import type { Connection, ConnectionContext, WSMessage } from "../index"; function assert(condition: unknown, message: string): asserts condition { if (!condition) { @@ -11,6 +11,8 @@ function assert(condition: unknown, message: string): asserts condition { export type Env = { Stateful: DurableObjectNamespace; OnStartServer: DurableObjectNamespace; + HibernatingOnStartServer: DurableObjectNamespace; + AlarmServer: DurableObjectNamespace; Mixed: DurableObjectNamespace; ConfigurableState: DurableObjectNamespace; ConfigurableStateInMemory: DurableObjectNamespace; @@ -67,6 +69,75 @@ export class OnStartServer extends Server { } } +/** + * Like OnStartServer but with hibernate: true. + * Tests that setName properly initializes the server in the + * hibernating websocket handler path (webSocketMessage, webSocketClose, etc.) + */ +export class HibernatingOnStartServer extends Server { + static options = { + hibernate: true + }; + + counter = 0; + + async onStart() { + assert(this.name, "name is not available inside onStart"); + await new Promise((resolve) => { + setTimeout(() => { + this.counter++; + resolve(); + }, 300); + }); + } + + onConnect(connection: Connection) { + connection.send(this.counter.toString()); + } + + onMessage(connection: Connection, _message: WSMessage) { + connection.send(`counter:${this.counter}`); + } + + onRequest(): Response { + return new Response(this.counter.toString()); + } +} + +/** + * Tests that alarm() properly initializes the server + * without the redundant blockConcurrencyWhile wrapper. + */ +export class AlarmServer extends Server { + static options = { + hibernate: true + }; + + counter = 0; + alarmCount = 0; + + async onStart() { + this.counter++; + } + + onAlarm() { + this.alarmCount++; + } + + async onRequest(request: Request): Promise { + const url = new URL(request.url); + if (url.searchParams.get("setAlarm")) { + // Schedule alarm far in the future so it won't auto-fire + await this.ctx.storage.setAlarm(Date.now() + 60_000); + return new Response("alarm set"); + } + return Response.json({ + counter: this.counter, + alarmCount: this.alarmCount + }); + } +} + export class Mixed extends Server { static options = { hibernate: true diff --git a/packages/partyserver/src/tests/wrangler.jsonc b/packages/partyserver/src/tests/wrangler.jsonc index d3ca216..7376d1d 100644 --- a/packages/partyserver/src/tests/wrangler.jsonc +++ b/packages/partyserver/src/tests/wrangler.jsonc @@ -42,6 +42,14 @@ { "name": "CustomCorsServer", "class_name": "CustomCorsServer" + }, + { + "name": "HibernatingOnStartServer", + "class_name": "HibernatingOnStartServer" + }, + { + "name": "AlarmServer", + "class_name": "AlarmServer" } ] }, @@ -61,6 +69,10 @@ { "tag": "v3", "new_classes": ["CorsServer", "CustomCorsServer"] + }, + { + "tag": "v4", + "new_classes": ["HibernatingOnStartServer", "AlarmServer"] } ] }