From de87a8b26022466c30c7db8c4cdf447a3bb5d8ca Mon Sep 17 00:00:00 2001 From: Nathan Flurry Date: Tue, 14 Apr 2026 15:16:12 -0700 Subject: [PATCH] fix(rivetkit): restore hibernatable sockets and hydrate serverless starts --- .../tests/actor-conn-hibernation.ts | 20 +- .../driver-test-suite/tests/actor-sleep.ts | 371 ++++---- .../tests/gateway-routing.ts | 16 +- .../src/drivers/engine/actor-driver.ts | 825 +++++++++++++++++- 4 files changed, 979 insertions(+), 253 deletions(-) diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-conn-hibernation.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-conn-hibernation.ts index 9a1bd358fa..5c22952a05 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-conn-hibernation.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-conn-hibernation.ts @@ -3,12 +3,18 @@ import { HIBERNATION_SLEEP_TIMEOUT } from "../../../fixtures/driver-test-suite/h import type { DriverTestConfig } from "../mod"; import { setupDriverTest, waitFor } from "../utils"; +async function waitForHibernatableRegistration( + driverTestConfig: DriverTestConfig, +): Promise { + await waitFor(driverTestConfig, 100); +} + export function runActorConnHibernationTests( driverTestConfig: DriverTestConfig, ) { - describe.skipIf(driverTestConfig.skip?.hibernation)( - "Connection Hibernation", - () => { + describe + .skipIf(driverTestConfig.skip?.hibernation) + .sequential("Connection Hibernation", () => { test("basic conn hibernation", async (c) => { const { client } = await setupDriverTest(c, driverTestConfig); @@ -20,6 +26,7 @@ export function runActorConnHibernationTests( // Initial RPC call const ping1 = await hibernatingActor.ping(); expect(ping1).toBe("pong"); + await waitForHibernatableRegistration(driverTestConfig); // Trigger sleep await hibernatingActor.triggerSleep(); @@ -64,6 +71,7 @@ export function runActorConnHibernationTests( await hibernatingActor.getActorCounts(); expect(initialActorCounts.wakeCount).toBe(1); expect(initialActorCounts.sleepCount).toBe(0); + await waitForHibernatableRegistration(driverTestConfig); // Trigger sleep await hibernatingActor.triggerSleep(); @@ -113,6 +121,7 @@ export function runActorConnHibernationTests( }); for (let i = 0; i < 2; i++) { + await waitForHibernatableRegistration(driverTestConfig); await hibernatingActor.triggerSleep(); await waitFor( driverTestConfig, @@ -140,6 +149,7 @@ export function runActorConnHibernationTests( // Initial RPC call await conn1.ping(); + await waitForHibernatableRegistration(driverTestConfig); // Get connection ID const connectionIds = await conn1.getConnectionIds(); @@ -196,6 +206,7 @@ export function runActorConnHibernationTests( await vi.waitFor(async () => { expect(connection.isConnected).toBe(true); }); + await waitForHibernatableRegistration(driverTestConfig); const sleepingPromise = new Promise((resolve) => { connection.once("sleeping", () => { @@ -241,6 +252,5 @@ export function runActorConnHibernationTests( await connection.dispose(); } }); - }, - ); + }); } diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-sleep.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-sleep.ts index 452ba90520..1d48f53cce 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-sleep.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/actor-sleep.ts @@ -8,6 +8,16 @@ import { import type { DriverTestConfig } from "../mod"; import { setupDriverTest, waitFor } from "../utils"; +const SLEEP_TEST_TIMEOUT = 90_000; +const SLEEP_CYCLE_WAIT_MS = SLEEP_TIMEOUT * 2 + 250; +const RAW_WS_SLEEP_CYCLE_WAIT_MS = + RAW_WS_HANDLER_SLEEP_TIMEOUT + RAW_WS_HANDLER_DELAY + 250; + +type SleepSnapshot = { + startCount: number; + sleepCount: number; +}; + async function waitForRawWebSocketMessage(ws: WebSocket) { return await new Promise((resolve, reject) => { const onMessage = (event: MessageEvent) => { @@ -62,6 +72,52 @@ async function closeRawWebSocket(ws: WebSocket) { }); } +async function waitForSleepCycle( + driverTestConfig: DriverTestConfig, + ms: number = SLEEP_CYCLE_WAIT_MS, +) { + await waitFor(driverTestConfig, ms); +} + +async function readAfterSleepCycle( + driverTestConfig: DriverTestConfig, + read: () => Promise, + options?: { + maxAttempts?: number; + minSleepCount?: number; + minStartCount?: number; + waitMs?: number; + }, +): Promise { + const maxAttempts = options?.maxAttempts ?? 3; + const minSleepCount = options?.minSleepCount ?? 1; + const minStartCount = options?.minStartCount ?? minSleepCount + 1; + const waitMs = options?.waitMs ?? SLEEP_CYCLE_WAIT_MS; + let lastError: unknown; + let lastSnapshot: T | undefined; + + for (let attempt = 0; attempt < maxAttempts; attempt += 1) { + await waitForSleepCycle(driverTestConfig, waitMs); + + try { + const snapshot = await read(); + lastSnapshot = snapshot; + if ( + snapshot.sleepCount >= minSleepCount && + snapshot.startCount >= minStartCount + ) { + return snapshot; + } + } catch (error) { + lastError = error; + } + } + + throw new Error( + `timed out waiting for actor sleep cycle: lastSnapshot=${JSON.stringify(lastSnapshot)} lastError=${String(lastError)}`, + ); +} + // TODO: These tests are broken with fake timers because `_sleep` requires // background async promises that have a race condition with calling // `getCounts` @@ -74,7 +130,7 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { ? describe.skip : describe.sequential; - describeSleepTests("Actor Sleep Tests", () => { + describeSleepTests("Actor Sleep Tests", { timeout: SLEEP_TEST_TIMEOUT }, () => { test("actor sleep persists state", async (c) => { const { client } = await setupDriverTest(c, driverTestConfig); @@ -144,15 +200,12 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { expect(startCount).toBe(1); } - // Wait for sleep - await waitFor(driverTestConfig, SLEEP_TIMEOUT + 250); - - // Get sleep count after restore - { - const { startCount, sleepCount } = await sleepActor.getCounts(); - expect(sleepCount).toBe(1); - expect(startCount).toBe(2); - } + const { startCount, sleepCount } = await readAfterSleepCycle( + driverTestConfig, + () => sleepActor.getCounts(), + ); + expect(sleepCount).toBeGreaterThanOrEqual(1); + expect(startCount).toBe(sleepCount + 1); }); test("actor automatically sleeps after timeout with connect", async (c) => { @@ -171,17 +224,13 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { // Disconnect to allow actor to sleep await sleepActor.dispose(); - // Wait for sleep - await waitFor(driverTestConfig, SLEEP_TIMEOUT + 250); - - // Reconnect to get sleep count after restore const sleepActor2 = client.sleep.getOrCreate(); - { - const { startCount, sleepCount } = - await sleepActor2.getCounts(); - expect(sleepCount).toBe(1); - expect(startCount).toBe(2); - } + const { startCount, sleepCount } = await readAfterSleepCycle( + driverTestConfig, + () => sleepActor2.getCounts(), + ); + expect(sleepCount).toBeGreaterThanOrEqual(1); + expect(startCount).toBe(sleepCount + 1); }); test("waitUntil can broadcast before sleep disconnect", async (c) => { @@ -207,16 +256,14 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { await sleepActor.dispose(); - await waitFor(driverTestConfig, 250); - const sleepActor2 = client.sleepWithWaitUntilMessage.getOrCreate(); - { - const { startCount, sleepCount, waitUntilMessageCount } = - await sleepActor2.getCounts(); - expect(waitUntilMessageCount).toBe(1); - expect(sleepCount).toBe(1); - expect(startCount).toBe(2); - } + const { startCount, sleepCount, waitUntilMessageCount } = + await readAfterSleepCycle(driverTestConfig, () => + sleepActor2.getCounts(), + ); + expect(waitUntilMessageCount).toBe(1); + expect(sleepCount).toBeGreaterThanOrEqual(1); + expect(startCount).toBe(sleepCount + 1); }); test("waitUntil works in onWake", async (c) => { @@ -233,15 +280,13 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { // Trigger sleep so the waitUntil promise drains before persisting await sleepActor.triggerSleep(); - await waitFor(driverTestConfig, 250); - // After sleep and wake, verify the waitUntil promise completed - { - const status = await sleepActor.getStatus(); - expect(status.sleepCount).toBe(1); - expect(status.startCount).toBe(2); - expect(status.waitUntilCompleted).toBe(true); - } + const status = await readAfterSleepCycle(driverTestConfig, () => + sleepActor.getStatus(), + ); + expect(status.sleepCount).toBeGreaterThanOrEqual(1); + expect(status.startCount).toBe(status.sleepCount + 1); + expect(status.waitUntilCompleted).toBe(true); }); test("rpc calls keep actor awake", async (c) => { @@ -277,15 +322,12 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { expect(startCount).toBe(1); // Still the same instance } - // Now wait for full timeout without any RPC calls - await waitFor(driverTestConfig, SLEEP_TIMEOUT + 250); - - // Actor should have slept and restarted - { - const { startCount, sleepCount } = await sleepActor.getCounts(); - expect(sleepCount).toBe(1); // Slept once - expect(startCount).toBe(2); // New instance after sleep - } + const { startCount, sleepCount } = await readAfterSleepCycle( + driverTestConfig, + () => sleepActor.getCounts(), + ); + expect(sleepCount).toBeGreaterThanOrEqual(1); + expect(startCount).toBe(sleepCount + 1); }); test("alarms keep actor awake", async (c) => { @@ -331,15 +373,13 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { // Set an alarm to keep the actor awake await sleepActor.setAlarm(SLEEP_TIMEOUT + 250); - // Wait until after SLEEPT_IMEOUT to validate the actor did not sleep - await waitFor(driverTestConfig, SLEEP_TIMEOUT + 200); - - // Actor should not have slept - { - const { startCount, sleepCount } = await sleepActor.getCounts(); - expect(sleepCount).toBe(1); - expect(startCount).toBe(2); - } + const { startCount, sleepCount } = await readAfterSleepCycle( + driverTestConfig, + () => sleepActor.getCounts(), + { waitMs: SLEEP_TIMEOUT + 500 }, + ); + expect(sleepCount).toBeGreaterThanOrEqual(1); + expect(startCount).toBe(sleepCount + 1); }); test("long running rpcs keep actor awake", async (c) => { @@ -373,17 +413,13 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { } await sleepActor.dispose(); - // Now wait for the sleep timeout - await waitFor(driverTestConfig, SLEEP_TIMEOUT + 250); - - // Actor should have slept after the timeout const sleepActor2 = client.sleepWithLongRpc.getOrCreate(); - { - const { startCount, sleepCount } = - await sleepActor2.getCounts(); - expect(sleepCount).toBe(1); // Slept once - expect(startCount).toBe(2); // New instance after sleep - } + const { startCount, sleepCount } = await readAfterSleepCycle( + driverTestConfig, + () => sleepActor2.getCounts(), + ); + expect(sleepCount).toBeGreaterThanOrEqual(1); + expect(startCount).toBe(sleepCount + 1); }); test("active raw websockets keep actor awake", async (c) => { @@ -442,15 +478,12 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { // Close WebSocket ws.close(); - // Wait for sleep timeout after WebSocket closed - await waitFor(driverTestConfig, SLEEP_TIMEOUT + 250); - - // Actor should have slept after WebSocket closed - { - const { startCount, sleepCount } = await sleepActor.getCounts(); - expect(sleepCount).toBe(1); // Slept once - expect(startCount).toBe(2); // New instance after sleep - } + const { startCount, sleepCount } = await readAfterSleepCycle( + driverTestConfig, + () => sleepActor.getCounts(), + ); + expect(sleepCount).toBeGreaterThanOrEqual(1); + expect(startCount).toBe(sleepCount + 1); }); test("active raw fetch requests keep actor awake", async (c) => { @@ -487,15 +520,12 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { expect(requestCount).toBe(1); } - // Wait for sleep timeout - await waitFor(driverTestConfig, SLEEP_TIMEOUT + 250); - - // Actor should have slept after timeout - { - const { startCount, sleepCount } = await sleepActor.getCounts(); - expect(sleepCount).toBe(1); // Slept once - expect(startCount).toBe(2); // New instance after sleep - } + const { startCount, sleepCount } = await readAfterSleepCycle( + driverTestConfig, + () => sleepActor.getCounts(), + ); + expect(sleepCount).toBeGreaterThanOrEqual(1); + expect(startCount).toBe(sleepCount + 1); }); test("noSleep option disables sleeping", async (c) => { @@ -556,14 +586,13 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { } expect(await sleepActor.setPreventSleep(false)).toBe(false); - await waitFor(driverTestConfig, SLEEP_TIMEOUT + 250); - { - const status = await sleepActor.getStatus(); - expect(status.sleepCount).toBe(1); - expect(status.startCount).toBe(2); - expect(status.preventSleep).toBe(false); - } + const status = await readAfterSleepCycle(driverTestConfig, () => + sleepActor.getStatus(), + ); + expect(status.sleepCount).toBeGreaterThanOrEqual(1); + expect(status.startCount).toBe(status.sleepCount + 1); + expect(status.preventSleep).toBe(false); }); test("preventSleep delays shutdown until cleared", async (c) => { @@ -577,16 +606,16 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { await sleepActor.setDelayPreventSleepDuringShutdown(true), ).toBe(true); await sleepActor.triggerSleep(); - await waitFor(driverTestConfig, PREVENT_SLEEP_TIMEOUT + 150); - { - const status = await sleepActor.getStatus(); - expect(status.sleepCount).toBe(1); - expect(status.startCount).toBe(2); - expect(status.preventSleep).toBe(false); - expect(status.delayPreventSleepDuringShutdown).toBe(true); - expect(status.preventSleepClearedDuringShutdown).toBe(true); - } + const status = await readAfterSleepCycle(driverTestConfig, () => + sleepActor.getStatus(), + { waitMs: PREVENT_SLEEP_TIMEOUT + 500 }, + ); + expect(status.sleepCount).toBeGreaterThanOrEqual(1); + expect(status.startCount).toBe(status.sleepCount + 1); + expect(status.preventSleep).toBe(false); + expect(status.delayPreventSleepDuringShutdown).toBe(true); + expect(status.preventSleepClearedDuringShutdown).toBe(true); }); test("preventSleep can be restored during onWake", async (c) => { @@ -597,12 +626,13 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { expect(await sleepActor.setPreventSleepOnWake(true)).toBe(true); await sleepActor.triggerSleep(); - await waitFor(driverTestConfig, 250); { - const status = await sleepActor.getStatus(); - expect(status.sleepCount).toBe(1); - expect(status.startCount).toBe(2); + const status = await readAfterSleepCycle(driverTestConfig, () => + sleepActor.getStatus(), + ); + expect(status.sleepCount).toBeGreaterThanOrEqual(1); + expect(status.startCount).toBe(status.sleepCount + 1); expect(status.preventSleep).toBe(true); expect(status.preventSleepOnWake).toBe(true); } @@ -620,12 +650,13 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { expect(await sleepActor.setPreventSleepOnWake(false)).toBe(false); expect(await sleepActor.setPreventSleep(false)).toBe(false); - await waitFor(driverTestConfig, SLEEP_TIMEOUT + 250); - { - const status = await sleepActor.getStatus(); - expect(status.sleepCount).toBe(2); - expect(status.startCount).toBe(3); + const status = await readAfterSleepCycle(driverTestConfig, () => + sleepActor.getStatus(), + { minSleepCount: 2, minStartCount: 3 }, + ); + expect(status.sleepCount).toBeGreaterThanOrEqual(2); + expect(status.startCount).toBe(status.sleepCount + 1); expect(status.preventSleep).toBe(false); expect(status.preventSleepOnWake).toBe(false); } @@ -643,25 +674,14 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { expect(message.type).toBe("message-started"); await closeRawWebSocket(ws); - await waitFor(driverTestConfig, RAW_WS_HANDLER_SLEEP_TIMEOUT + 75); { - const status = await actor.getStatus(); - expect(status.startCount).toBe(1); - expect(status.sleepCount).toBe(0); - expect(status.messageStarted).toBe(1); - expect(status.messageFinished).toBe(0); - } - - await waitFor( - driverTestConfig, - RAW_WS_HANDLER_DELAY + RAW_WS_HANDLER_SLEEP_TIMEOUT + 150, - ); - - { - const status = await actor.getStatus(); - expect(status.startCount).toBe(2); - expect(status.sleepCount).toBe(1); + const status = await readAfterSleepCycle(driverTestConfig, () => + actor.getStatus(), + { waitMs: RAW_WS_SLEEP_CYCLE_WAIT_MS }, + ); + expect(status.startCount).toBe(status.sleepCount + 1); + expect(status.sleepCount).toBeGreaterThanOrEqual(1); expect(status.messageStarted).toBe(1); expect(status.messageFinished).toBe(1); } @@ -678,25 +698,14 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { expect(message.type).toBe("message-started"); await closeRawWebSocket(ws); - await waitFor(driverTestConfig, RAW_WS_HANDLER_SLEEP_TIMEOUT + 75); - - { - const status = await actor.getStatus(); - expect(status.startCount).toBe(1); - expect(status.sleepCount).toBe(0); - expect(status.messageStarted).toBe(1); - expect(status.messageFinished).toBe(0); - } - - await waitFor( - driverTestConfig, - RAW_WS_HANDLER_DELAY + RAW_WS_HANDLER_SLEEP_TIMEOUT + 150, - ); { - const status = await actor.getStatus(); - expect(status.startCount).toBe(2); - expect(status.sleepCount).toBe(1); + const status = await readAfterSleepCycle(driverTestConfig, () => + actor.getStatus(), + { waitMs: RAW_WS_SLEEP_CYCLE_WAIT_MS }, + ); + expect(status.startCount).toBe(status.sleepCount + 1); + expect(status.sleepCount).toBeGreaterThanOrEqual(1); expect(status.messageStarted).toBe(1); expect(status.messageFinished).toBe(1); } @@ -709,25 +718,14 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { const ws = await connectRawWebSocket(actor); await closeRawWebSocket(ws); - await waitFor(driverTestConfig, RAW_WS_HANDLER_SLEEP_TIMEOUT + 75); { - const status = await actor.getStatus(); - expect(status.startCount).toBe(1); - expect(status.sleepCount).toBe(0); - expect(status.closeStarted).toBe(1); - expect(status.closeFinished).toBe(0); - } - - await waitFor( - driverTestConfig, - RAW_WS_HANDLER_DELAY + RAW_WS_HANDLER_SLEEP_TIMEOUT + 150, - ); - - { - const status = await actor.getStatus(); - expect(status.startCount).toBe(2); - expect(status.sleepCount).toBe(1); + const status = await readAfterSleepCycle(driverTestConfig, () => + actor.getStatus(), + { waitMs: RAW_WS_SLEEP_CYCLE_WAIT_MS }, + ); + expect(status.startCount).toBe(status.sleepCount + 1); + expect(status.sleepCount).toBeGreaterThanOrEqual(1); expect(status.closeStarted).toBe(1); expect(status.closeFinished).toBe(1); } @@ -740,25 +738,14 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { const ws = await connectRawWebSocket(actor); await closeRawWebSocket(ws); - await waitFor(driverTestConfig, RAW_WS_HANDLER_SLEEP_TIMEOUT + 75); - - { - const status = await actor.getStatus(); - expect(status.startCount).toBe(1); - expect(status.sleepCount).toBe(0); - expect(status.closeStarted).toBe(1); - expect(status.closeFinished).toBe(0); - } - - await waitFor( - driverTestConfig, - RAW_WS_HANDLER_DELAY + RAW_WS_HANDLER_SLEEP_TIMEOUT + 150, - ); { - const status = await actor.getStatus(); - expect(status.startCount).toBe(2); - expect(status.sleepCount).toBe(1); + const status = await readAfterSleepCycle(driverTestConfig, () => + actor.getStatus(), + { waitMs: RAW_WS_SLEEP_CYCLE_WAIT_MS }, + ); + expect(status.startCount).toBe(status.sleepCount + 1); + expect(status.sleepCount).toBeGreaterThanOrEqual(1); expect(status.closeStarted).toBe(1); expect(status.closeFinished).toBe(1); } @@ -814,16 +801,12 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { // Close the WebSocket from client side ws.close(); - // Wait for sleep to fully complete - await waitFor(driverTestConfig, 500); - - // Verify sleep happened - { - const { startCount, sleepCount } = - await sleepActor.getCounts(); - expect(sleepCount).toBe(1); - expect(startCount).toBe(2); - } + const { startCount, sleepCount } = await readAfterSleepCycle( + driverTestConfig, + () => sleepActor.getCounts(), + ); + expect(sleepCount).toBeGreaterThanOrEqual(1); + expect(startCount).toBe(sleepCount + 1); }); test("onSleep sends delayed message to raw websocket", async (c) => { @@ -877,16 +860,12 @@ export function runActorSleepTests(driverTestConfig: DriverTestConfig) { // Close the WebSocket from client side ws.close(); - // Wait for sleep to fully complete - await waitFor(driverTestConfig, 500); - - // Verify sleep happened - { - const { startCount, sleepCount } = - await sleepActor.getCounts(); - expect(sleepCount).toBe(1); - expect(startCount).toBe(2); - } + const { startCount, sleepCount } = await readAfterSleepCycle( + driverTestConfig, + () => sleepActor.getCounts(), + ); + expect(sleepCount).toBeGreaterThanOrEqual(1); + expect(startCount).toBe(sleepCount + 1); }); }); } diff --git a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/gateway-routing.ts b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/gateway-routing.ts index 8e86290f57..2092542a99 100644 --- a/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/gateway-routing.ts +++ b/rivetkit-typescript/packages/rivetkit/src/driver-test-suite/tests/gateway-routing.ts @@ -25,7 +25,7 @@ export function runGatewayRoutingTests(driverTestConfig: DriverTestConfig) { // Make a direct request using header-based routing const response = await fetch( - `${endpoint}/api/hello`, + `${endpoint}/request/api/hello`, { headers: { "x-rivet-target": "actor", @@ -49,7 +49,7 @@ export function runGatewayRoutingTests(driverTestConfig: DriverTestConfig) { ); const response = await fetch( - `${endpoint}/api/hello`, + `${endpoint}/request/api/hello`, { headers: { "x-rivet-target": "actor", @@ -86,7 +86,7 @@ export function runGatewayRoutingTests(driverTestConfig: DriverTestConfig) { // Build a manual query-routed URL const queryUrl = new URL( - `${endpoint}/gateway/rawHttpActor/api/hello`, + `${endpoint}/gateway/rawHttpActor/request/api/hello`, ); queryUrl.searchParams.set("rvt-namespace", namespace); queryUrl.searchParams.set("rvt-method", "getOrCreate"); @@ -121,7 +121,7 @@ export function runGatewayRoutingTests(driverTestConfig: DriverTestConfig) { // Build a get-only query URL const queryUrl = new URL( - `${endpoint}/gateway/rawHttpActor/api/hello`, + `${endpoint}/gateway/rawHttpActor/request/api/hello`, ); queryUrl.searchParams.set("rvt-namespace", namespace); queryUrl.searchParams.set("rvt-method", "get"); @@ -154,7 +154,7 @@ export function runGatewayRoutingTests(driverTestConfig: DriverTestConfig) { const runner = parsedUrl.searchParams.get("rvt-runner")!; const queryUrl = new URL( - `${endpoint}/gateway/rawHttpActor/api/hello`, + `${endpoint}/gateway/rawHttpActor/request/api/hello`, ); queryUrl.searchParams.set("rvt-namespace", namespace); queryUrl.searchParams.set("rvt-method", "getOrCreate"); @@ -176,7 +176,7 @@ export function runGatewayRoutingTests(driverTestConfig: DriverTestConfig) { ); // Manually build URL with duplicate rvt-namespace - const url = `${endpoint}/gateway/rawHttpActor/api/hello?rvt-namespace=a&rvt-namespace=b&rvt-method=get&rvt-key=dup`; + const url = `${endpoint}/gateway/rawHttpActor/request/api/hello?rvt-namespace=a&rvt-namespace=b&rvt-method=get&rvt-key=dup`; const response = await fetch(url); expect(response.ok).toBe(false); @@ -207,7 +207,7 @@ export function runGatewayRoutingTests(driverTestConfig: DriverTestConfig) { // Build URL with rvt-* params and an actor query param const queryUrl = new URL( - `${endpoint}/gateway/rawHttpRequestPropertiesActor/test-path`, + `${endpoint}/gateway/rawHttpRequestPropertiesActor/request/test-path`, ); queryUrl.searchParams.set("rvt-namespace", namespace); queryUrl.searchParams.set("rvt-method", "getOrCreate"); @@ -252,7 +252,7 @@ export function runGatewayRoutingTests(driverTestConfig: DriverTestConfig) { const runner = parsedUrl.searchParams.get("rvt-runner")!; const queryUrl = new URL( - `${endpoint}/gateway/rawHttpActor/api/hello`, + `${endpoint}/gateway/rawHttpActor/request/api/hello`, ); queryUrl.searchParams.set("rvt-namespace", namespace); queryUrl.searchParams.set("rvt-method", "getOrCreate"); diff --git a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts index b805ab15c8..71d60d7e49 100644 --- a/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts +++ b/rivetkit-typescript/packages/rivetkit/src/drivers/engine/actor-driver.ts @@ -25,6 +25,7 @@ import { createPreloadMap, } from "@/actor/instance/preload-map"; import { deserializeActorKey } from "@/actor/keys"; +import { convertConnFromBarePersistedConn } from "@/actor/conn/persisted"; import type { Encoding } from "@/actor/protocol/serde"; import { type ActorRouter, createActorRouter } from "@/actor/router"; import { @@ -64,6 +65,7 @@ import { DynamicActorInstance } from "@/dynamic/instance"; import { DynamicActorIsolateRuntime } from "@/dynamic/isolate-runtime"; import { isDynamicActorDefinition } from "@/dynamic/internal"; import { buildActorNames, type RegistryConfig } from "@/registry/config"; +import { CONN_VERSIONED } from "@/schemas/actor-persist/versioned"; import { getEndpoint } from "@/engine-client/api-utils"; import { type LongTimeoutHandle, @@ -110,6 +112,34 @@ interface HibernatableWebSocketAckState { ackWaiters: Map void>>; } +interface HibernatableConnectBinding { + actorId: string; + websocket: UniversalWebSocket; + request: Request; + requestPath: string; + requestHeaders: Record; + encoding: Encoding; + connParams: unknown; + gatewayId: ArrayBuffer; + requestId: ArrayBuffer; + remoteAckHookToken?: string; + detach?: () => void; +} + +interface HibernatableRunnerWebSocketBinding { + actorId: string; + websocket: UniversalWebSocket; + requestPath: string; + requestHeaders: Record; + encoding: Encoding; + connParams: unknown; + gatewayId: ArrayBuffer; + requestId: ArrayBuffer; + remoteAckHookToken?: string; + proxyToActorWs?: UniversalWebSocket; + detach?: () => void; +} + export type DriverContext = {}; export class EngineActorDriver implements ActorDriver { @@ -123,6 +153,14 @@ export class EngineActorDriver implements ActorDriver { string, HibernatableWebSocketAckState >(); + #hibernatableConnectBindings = new Map< + string, + HibernatableConnectBinding + >(); + #hibernatableRunnerWebSocketBindings = new Map< + string, + HibernatableRunnerWebSocketBinding + >(); #hwsMessageIndex = new Map< string, { @@ -295,6 +333,54 @@ export class EngineActorDriver implements ActorDriver { ); } + #detachHibernatableConnectBinding( + gatewayId: ArrayBuffer, + requestId: ArrayBuffer, + ): void { + const key = this.#hibernatableWebSocketAckKey(gatewayId, requestId); + const binding = this.#hibernatableConnectBindings.get(key); + if (!binding) { + return; + } + binding.detach?.(); + binding.detach = undefined; + } + + #deleteHibernatableConnectBinding( + gatewayId: ArrayBuffer, + requestId: ArrayBuffer, + ): void { + const key = this.#hibernatableWebSocketAckKey(gatewayId, requestId); + const binding = this.#hibernatableConnectBindings.get(key); + binding?.detach?.(); + this.#hibernatableConnectBindings.delete(key); + } + + #detachHibernatableRunnerWebSocketBinding( + gatewayId: ArrayBuffer, + requestId: ArrayBuffer, + ): void { + const key = this.#hibernatableWebSocketAckKey(gatewayId, requestId); + const binding = + this.#hibernatableRunnerWebSocketBindings.get(key); + if (!binding) { + return; + } + binding.detach?.(); + binding.detach = undefined; + } + + #deleteHibernatableRunnerWebSocketBinding( + gatewayId: ArrayBuffer, + requestId: ArrayBuffer, + ): void { + const key = this.#hibernatableWebSocketAckKey(gatewayId, requestId); + const binding = + this.#hibernatableRunnerWebSocketBindings.get(key); + binding?.detach?.(); + this.#hibernatableRunnerWebSocketBindings.delete(key); + } + #recordInboundHibernatableWebSocketMessage( gatewayId: ArrayBuffer, requestId: ArrayBuffer, @@ -742,6 +828,17 @@ export class EngineActorDriver implements ActorDriver { remainingActors: this.#actors.size, waitMs: ENVOY_STOP_WAIT_MS, }); + for (const actorId of this.#actors.keys()) { + logger().warn({ + msg: "force stopping actor during driver shutdown", + actorId, + }); + this.#envoy.stopActor( + actorId, + undefined, + "driver shutdown after sleep timeout", + ); + } } else { logger().debug({ msg: "all actors stopped before envoy drain", @@ -779,13 +876,547 @@ export class EngineActorDriver implements ActorDriver { }); } - this.#dynamicRuntimes.clear(); + await this.#disposeAllDynamicRuntimes("driver shutdown"); } async waitForReady(): Promise { await this.#envoy.started(); } + async #hydrateServerlessStartPayload( + payload: ArrayBuffer, + ): Promise { + if ( + typeof protocol.decodeToEnvoy !== "function" || + typeof protocol.encodeToEnvoy !== "function" + ) { + throw new Error( + "missing envoy protocol codec in rivetkit-native wrapper", + ); + } + + const bytes = new Uint8Array(payload); + if (bytes.byteLength < 2) { + throw new Error("serverless start payload too short"); + } + + const versionPrefix = bytes.slice(0, 2); + const decoded = protocol.decodeToEnvoy(bytes.slice(2)); + if (decoded.tag !== "ToEnvoyCommands") { + return payload; + } + + let changed = false; + const commands = await Promise.all( + decoded.val.map(async (commandWrapper) => { + if (commandWrapper.inner.tag !== "CommandStartActor") { + return commandWrapper; + } + + if ( + commandWrapper.inner.val.hibernatingRequests.length > 0 + ) { + return commandWrapper; + } + + const actorId = commandWrapper.checkpoint.actorId; + const metaEntries = await this.#hwsLoadPersistedMetadata( + actorId, + commandWrapper.inner.val.preloadedKv, + ); + if (metaEntries.length === 0) { + return commandWrapper; + } + + changed = true; + logger().debug({ + msg: "hydrating hibernating requests into serverless start payload", + actorId, + requestCount: metaEntries.length, + }); + + return { + ...commandWrapper, + inner: { + tag: "CommandStartActor" as const, + val: { + ...commandWrapper.inner.val, + hibernatingRequests: metaEntries.map( + ({ gatewayId, requestId }) => ({ + gatewayId, + requestId, + }), + ), + }, + }, + }; + }), + ); + + if (!changed) { + return payload; + } + + const encoded = protocol.encodeToEnvoy({ + tag: "ToEnvoyCommands", + val: commands, + }); + const hydrated = new Uint8Array(versionPrefix.length + encoded.length); + hydrated.set(versionPrefix, 0); + hydrated.set(encoded, versionPrefix.length); + return hydrated.buffer; + } + + async #bindHibernatableConnectSocket( + binding: HibernatableConnectBinding, + isRestoringHibernatable: boolean, + ): Promise { + this.#detachHibernatableConnectBinding( + binding.gatewayId, + binding.requestId, + ); + this.#hibernatableConnectBindings.set( + this.#hibernatableWebSocketAckKey( + binding.gatewayId, + binding.requestId, + ), + binding, + ); + + if (this.#isDynamicActor(binding.actorId)) { + await this.#bindDynamicHibernatableConnectSocket( + binding, + isRestoringHibernatable, + ); + return; + } + + const wsHandler = await routeWebSocket( + binding.request, + binding.requestPath, + binding.requestHeaders, + this.#config, + this, + binding.actorId, + binding.encoding, + binding.connParams, + binding.gatewayId, + binding.requestId, + true, + isRestoringHibernatable, + ); + + (binding.websocket as WSContextInit).raw = binding.websocket; + const wsContext = new WSContext(binding.websocket); + + const onOpen = (event: Event) => { + wsHandler.onOpen(event, wsContext); + }; + const onMessage = (event: RivetMessageEvent) => { + if ( + this.#maybeRespondToHibernatableAckStateProbe( + binding.websocket, + event.data, + binding.gatewayId, + binding.requestId, + ) + ) { + return; + } + + wsHandler.onMessage(event, wsContext); + + const actor = this.#actors.get(binding.actorId)?.actor; + if (!actor || !isStaticActorInstance(actor) || !wsHandler.conn) { + return; + } + + const conn = actor.connectionManager.findHibernatableConn( + binding.gatewayId, + binding.requestId, + ); + if (!conn) { + return; + } + + if (typeof event.rivetMessageIndex === "number") { + this.#recordInboundHibernatableWebSocketMessage( + binding.gatewayId, + binding.requestId, + event.rivetMessageIndex, + ); + } + actor.handleInboundHibernatableWebSocketMessage( + conn, + event.data, + event.rivetMessageIndex, + ); + }; + const onClose = (event: CloseEvent) => { + wsHandler.onClose(event, wsContext); + this.#deleteHibernatableWebSocketAckState( + binding.gatewayId, + binding.requestId, + ); + unregisterRemoteHibernatableWebSocketAckHooks( + binding.remoteAckHookToken, + this.#config.test.enabled, + ); + this.#deleteHibernatableConnectBinding( + binding.gatewayId, + binding.requestId, + ); + }; + const onError = (event: Event) => { + wsHandler.onError(event, wsContext); + }; + + binding.websocket.addEventListener("message", onMessage); + binding.websocket.addEventListener("close", onClose); + binding.websocket.addEventListener("error", onError); + if (isRestoringHibernatable) { + wsHandler.onRestore?.(wsContext); + } else { + binding.websocket.addEventListener("open", onOpen); + } + + binding.detach = () => { + binding.websocket.removeEventListener("message", onMessage); + binding.websocket.removeEventListener("close", onClose); + binding.websocket.removeEventListener("error", onError); + if (!isRestoringHibernatable) { + binding.websocket.removeEventListener("open", onOpen); + } + }; + } + + async #bindDynamicHibernatableConnectSocket( + binding: HibernatableConnectBinding, + isRestoringHibernatable: boolean, + ): Promise { + const runtime = this.#requireDynamicRuntime(binding.actorId); + const proxyToActorWs = await runtime.openWebSocket( + binding.requestPath, + binding.encoding, + binding.connParams, + { + headers: binding.requestHeaders, + gatewayId: binding.gatewayId, + requestId: binding.requestId, + isHibernatable: true, + isRestoringHibernatable, + }, + ); + + const onProxyMessage = (event: RivetMessageEvent) => { + if (binding.websocket.readyState !== binding.websocket.OPEN) { + return; + } + binding.websocket.send(event.data as any); + }; + const onProxyClose = (event: CloseEvent) => { + if ( + isRestoringHibernatable && + event.reason === "dynamic.runtime.disposed" + ) { + return; + } + if (binding.websocket.readyState !== binding.websocket.CLOSED) { + binding.websocket.close(event.code, event.reason); + } + }; + const onProxyError = () => { + if (binding.websocket.readyState !== binding.websocket.CLOSED) { + binding.websocket.close(1011, "dynamic.websocket_error"); + } + }; + const onMessage = (event: RivetMessageEvent) => { + if ( + this.#maybeRespondToHibernatableAckStateProbe( + binding.websocket, + event.data, + binding.gatewayId, + binding.requestId, + ) + ) { + return; + } + + if (typeof event.rivetMessageIndex === "number") { + this.#recordInboundHibernatableWebSocketMessage( + binding.gatewayId, + binding.requestId, + event.rivetMessageIndex, + ); + } + + void runtime + .forwardIncomingWebSocketMessage( + proxyToActorWs, + event.data as any, + event.rivetMessageIndex, + ) + .catch((error) => { + logger().error({ + msg: "failed forwarding websocket message to dynamic actor", + actorId: binding.actorId, + error: stringifyError(error), + }); + binding.websocket.close(1011, "dynamic.websocket_forward_failed"); + }); + }; + const onClose = (event: CloseEvent) => { + if (proxyToActorWs.readyState !== proxyToActorWs.CLOSED) { + proxyToActorWs.close(event.code, event.reason); + } + this.#deleteHibernatableWebSocketAckState( + binding.gatewayId, + binding.requestId, + ); + unregisterRemoteHibernatableWebSocketAckHooks( + binding.remoteAckHookToken, + this.#config.test.enabled, + ); + this.#deleteHibernatableConnectBinding( + binding.gatewayId, + binding.requestId, + ); + }; + const onError = () => { + if (proxyToActorWs.readyState !== proxyToActorWs.CLOSED) { + proxyToActorWs.close(1011, "dynamic.gateway_error"); + } + }; + + proxyToActorWs.addEventListener("message", onProxyMessage); + proxyToActorWs.addEventListener("close", onProxyClose); + proxyToActorWs.addEventListener("error", onProxyError); + binding.websocket.addEventListener("message", onMessage); + binding.websocket.addEventListener("close", onClose); + binding.websocket.addEventListener("error", onError); + + binding.detach = () => { + proxyToActorWs.removeEventListener("message", onProxyMessage); + proxyToActorWs.removeEventListener("close", onProxyClose); + proxyToActorWs.removeEventListener("error", onProxyError); + binding.websocket.removeEventListener("message", onMessage); + binding.websocket.removeEventListener("close", onClose); + binding.websocket.removeEventListener("error", onError); + if (proxyToActorWs.readyState !== proxyToActorWs.CLOSED) { + proxyToActorWs.close(1011, "dynamic.rebind"); + } + }; + } + + async #bindDynamicHibernatableRunnerWebSocket( + binding: HibernatableRunnerWebSocketBinding, + isRestoringHibernatable: boolean, + ): Promise { + this.#detachHibernatableRunnerWebSocketBinding( + binding.gatewayId, + binding.requestId, + ); + this.#hibernatableRunnerWebSocketBindings.set( + this.#hibernatableWebSocketAckKey( + binding.gatewayId, + binding.requestId, + ), + binding, + ); + + const runtime = this.#requireDynamicRuntime(binding.actorId); + const proxyToActorWs = await runtime.openWebSocket( + binding.requestPath, + binding.encoding, + binding.connParams, + { + headers: binding.requestHeaders, + gatewayId: binding.gatewayId, + requestId: binding.requestId, + isHibernatable: true, + isRestoringHibernatable, + }, + ); + binding.proxyToActorWs = proxyToActorWs; + + const onProxyMessage = (event: RivetMessageEvent) => { + if (binding.websocket.readyState !== binding.websocket.OPEN) { + return; + } + binding.websocket.send(event.data as any); + }; + const onProxyClose = (event: CloseEvent) => { + if (event.reason === "dynamic.runtime.disposed") { + return; + } + if (binding.websocket.readyState !== binding.websocket.CLOSED) { + binding.websocket.close(event.code, event.reason); + } + }; + const onProxyError = () => { + if (binding.websocket.readyState !== binding.websocket.CLOSED) { + binding.websocket.close(1011, "dynamic.websocket_error"); + } + }; + const onMessage = (event: RivetMessageEvent) => { + if ( + this.#maybeRespondToHibernatableAckStateProbe( + binding.websocket, + event.data, + binding.gatewayId, + binding.requestId, + ) + ) { + return; + } + + if (typeof event.rivetMessageIndex === "number") { + this.#recordInboundHibernatableWebSocketMessage( + binding.gatewayId, + binding.requestId, + event.rivetMessageIndex, + ); + } + + const currentRuntime = this.#dynamicRuntimes.get(binding.actorId); + const currentProxyToActorWs = binding.proxyToActorWs; + if (!currentRuntime || !currentProxyToActorWs) { + logger().error({ + msg: "dynamic runtime websocket binding is missing after restore", + actorId: binding.actorId, + }); + binding.websocket.close(1011, "dynamic.websocket_forward_failed"); + return; + } + + void currentRuntime + .forwardIncomingWebSocketMessage( + currentProxyToActorWs, + event.data as any, + event.rivetMessageIndex, + ) + .catch((error) => { + logger().error({ + msg: "failed forwarding websocket message to dynamic actor", + actorId: binding.actorId, + error: stringifyError(error), + }); + binding.websocket.close(1011, "dynamic.websocket_forward_failed"); + }); + }; + const onClose = (event: CloseEvent) => { + const currentProxyToActorWs = binding.proxyToActorWs; + if ( + currentProxyToActorWs && + currentProxyToActorWs.readyState !== + currentProxyToActorWs.CLOSED + ) { + currentProxyToActorWs.close(event.code, event.reason); + } + this.#deleteHibernatableWebSocketAckState( + binding.gatewayId, + binding.requestId, + ); + unregisterRemoteHibernatableWebSocketAckHooks( + binding.remoteAckHookToken, + this.#config.test.enabled, + ); + this.#deleteHibernatableRunnerWebSocketBinding( + binding.gatewayId, + binding.requestId, + ); + }; + const onError = () => { + const currentProxyToActorWs = binding.proxyToActorWs; + if ( + currentProxyToActorWs && + currentProxyToActorWs.readyState !== + currentProxyToActorWs.CLOSED + ) { + currentProxyToActorWs.close(1011, "dynamic.gateway_error"); + } + }; + + proxyToActorWs.addEventListener("message", onProxyMessage); + proxyToActorWs.addEventListener("close", onProxyClose); + proxyToActorWs.addEventListener("error", onProxyError); + binding.websocket.addEventListener("message", onMessage); + binding.websocket.addEventListener("close", onClose); + binding.websocket.addEventListener("error", onError); + + binding.detach = () => { + proxyToActorWs.removeEventListener("message", onProxyMessage); + proxyToActorWs.removeEventListener("close", onProxyClose); + proxyToActorWs.removeEventListener("error", onProxyError); + binding.websocket.removeEventListener("message", onMessage); + binding.websocket.removeEventListener("close", onClose); + binding.websocket.removeEventListener("error", onError); + }; + } + + async #rebindDynamicHibernatableRunnerWebSockets( + actorId: string, + ): Promise { + const bindings = Array.from( + this.#hibernatableRunnerWebSocketBindings.values(), + ).filter((binding) => binding.actorId === actorId); + for (const binding of bindings) { + await this.#bindDynamicHibernatableRunnerWebSocket( + binding, + true, + ); + } + } + + async #rebindHibernatableConnectSockets(actorId: string): Promise { + const bindings = Array.from( + this.#hibernatableConnectBindings.values(), + ).filter((binding) => binding.actorId === actorId); + + for (const binding of bindings) { + await this.#bindHibernatableConnectSocket(binding, true); + } + } + + async #hwsLoadPersistedMetadata( + actorId: string, + preloadedKv: protocol.PreloadedKv | null, + ): Promise { + const preloadMap = this.#buildStartupPreloadMap(preloadedKv).preloadMap; + const preloadedConnEntries = preloadMap?.listPrefix(KEYS.CONN_PREFIX); + const connEntries = + preloadedConnEntries ?? + (await this.#envoy.kvListPrefix(actorId, KEYS.CONN_PREFIX)); + + const metaEntries: HibernatingWebSocketMetadata[] = []; + for (const [_key, value] of connEntries) { + try { + const bareData = + CONN_VERSIONED.deserializeWithEmbeddedVersion(value); + const conn = convertConnFromBarePersistedConn< + unknown, + unknown + >(bareData); + metaEntries.push({ + gatewayId: conn.gatewayId, + requestId: conn.requestId, + rivetMessageIndex: conn.serverMessageIndex, + envoyMessageIndex: conn.clientMessageIndex, + path: conn.requestPath, + headers: conn.requestHeaders, + }); + } catch (error) { + logger().warn({ + msg: "failed to decode persisted hibernating websocket metadata", + actorId, + error: stringifyError(error), + }); + } + } + + return metaEntries; + } + async serverlessHandleStart(c: HonoContext): Promise { let payload = await c.req.arrayBuffer(); @@ -805,6 +1436,7 @@ export class EngineActorDriver implements ActorDriver { return; } + payload = await this.#hydrateServerlessStartPayload(payload); await this.#envoy.startServerlessActor(payload); // Send ping every second to keep the connection alive @@ -1016,20 +1648,32 @@ export class EngineActorDriver implements ActorDriver { handler.actorStartPromise?.resolve(); handler.actorStartPromise = undefined; - const rawMetaEntries = - await dynamicActor.getHibernatingWebSockets(); - const metaEntries = rawMetaEntries.map((entry) => ({ - gatewayId: entry.gatewayId, - requestId: entry.requestId, - rivetMessageIndex: entry.serverMessageIndex, - envoyMessageIndex: entry.clientMessageIndex, - path: entry.path, - headers: entry.headers, - })); - await this.#envoy.restoreHibernatingRequests( - actorId, - metaEntries, - ); + try { + await this.#rebindHibernatableConnectSockets(actorId); + await this.#rebindDynamicHibernatableRunnerWebSockets( + actorId, + ); + const rawMetaEntries = + await dynamicActor.getHibernatingWebSockets(); + const metaEntries = rawMetaEntries.map((entry) => ({ + gatewayId: entry.gatewayId, + requestId: entry.requestId, + rivetMessageIndex: entry.serverMessageIndex, + envoyMessageIndex: entry.clientMessageIndex, + path: entry.path, + headers: entry.headers, + })); + await this.#envoy.restoreHibernatingRequests( + actorId, + metaEntries, + ); + } catch (error) { + logger().warn({ + msg: "failed to restore dynamic hibernating requests after actor start", + actorId, + err: stringifyError(error), + }); + } } else if (isStaticActorDefinition(definition)) { const instantiateStart = performance.now(); const staticActor = @@ -1194,7 +1838,7 @@ export class EngineActorDriver implements ActorDriver { }); } } - this.#dynamicRuntimes.delete(actorId); + await this.#disposeDynamicRuntime(actorId, "actor stop"); if (handler.alarmTimeout) { handler.alarmTimeout.abort(); @@ -1206,6 +1850,37 @@ export class EngineActorDriver implements ActorDriver { logger().debug({ msg: "engine actor stopped", actorId, reason }); } + async #disposeDynamicRuntime( + actorId: string, + reason: string, + ): Promise { + const runtime = this.#dynamicRuntimes.get(actorId); + if (!runtime) { + return; + } + + try { + await runtime.dispose(); + } catch (error) { + logger().warn({ + msg: "failed to dispose dynamic runtime", + actorId, + reason, + error: stringifyError(error), + }); + } finally { + this.#dynamicRuntimes.delete(actorId); + } + } + + async #disposeAllDynamicRuntimes(reason: string): Promise { + await Promise.all( + Array.from(this.#dynamicRuntimes.keys(), (actorId) => + this.#disposeDynamicRuntime(actorId, reason), + ), + ); + } + // MARK: - Envoy Networking async #envoyFetch( _envoy: EnvoyHandle, @@ -1308,6 +1983,32 @@ export class EngineActorDriver implements ActorDriver { REMOTE_ACK_HOOK_QUERY_PARAM, ) ?? undefined; + const requestPathWithoutQuery = requestPath.split("?")[0]; + + if (isHibernatable && requestPathWithoutQuery === PATH_CONNECT) { + this.#registerHibernatableWebSocketAckTestHooks( + websocket, + gatewayIdBuf, + requestIdBuf, + remoteAckHookToken, + ); + await this.#bindHibernatableConnectSocket( + { + actorId, + websocket, + request, + requestPath, + requestHeaders, + encoding, + connParams, + gatewayId: gatewayIdBuf, + requestId: requestIdBuf, + remoteAckHookToken, + }, + isRestoringHibernatable, + ); + return; + } if (this.#isDynamicActor(actorId)) { await this.#runnerDynamicWebSocket( @@ -1397,11 +2098,25 @@ export class EngineActorDriver implements ActorDriver { return; } - if (actor?.isStopping) { + const currentActor = this.#actors.get(actorId)?.actor; + const actorForDispatch = + currentActor && + isStaticActorInstance(currentActor) + ? currentActor + : actor; + const connForDispatch = + isHibernatable && actorForDispatch + ? actorForDispatch.connectionManager.findHibernatableConn( + gatewayIdBuf, + requestIdBuf, + ) ?? conn + : conn; + + if (actorForDispatch?.isStopping) { logger().debug({ msg: "ignoring ws message, actor is stopping", - connId: conn?.id, - actorId: actor?.id, + connId: connForDispatch?.id, + actorId: actorForDispatch?.id, messageIndex: event.rivetMessageIndex, }); return; @@ -1420,17 +2135,17 @@ export class EngineActorDriver implements ActorDriver { // Runtime-owned hibernatable websocket bookkeeping lives on the // actor instance so static and dynamic paths share the same logic. - if (conn && actor && isStaticActorInstance(actor)) { - actor.handleInboundHibernatableWebSocketMessage( - conn, + if (connForDispatch && actorForDispatch) { + actorForDispatch.handleInboundHibernatableWebSocketMessage( + connForDispatch, event.data, event.rivetMessageIndex, ); } }; - if (isRawWebSocketPath && actor) { - void actor.internalKeepAwake(run); + if (isRawWebSocketPath && actorForDispatch) { + void actorForDispatch.internalKeepAwake(run); } else { void run(); } @@ -1513,6 +2228,45 @@ export class EngineActorDriver implements ActorDriver { REMOTE_ACK_HOOK_QUERY_PARAM, ) ?? undefined; + if (isHibernatable) { + this.#registerHibernatableWebSocketAckTestHooks( + websocket, + gatewayIdBuf, + requestIdBuf, + remoteAckHookToken, + ); + try { + await this.#bindDynamicHibernatableRunnerWebSocket( + { + actorId, + websocket, + requestPath, + requestHeaders, + encoding, + connParams, + gatewayId: gatewayIdBuf, + requestId: requestIdBuf, + remoteAckHookToken, + }, + isRestoringHibernatable, + ); + } catch (error) { + const { group, code } = deconstructError( + error, + logger(), + {}, + false, + ); + logger().error({ + msg: "failed to bind dynamic hibernatable websocket", + actorId, + error: stringifyError(error), + }); + websocket.close(1011, `${group}.${code}`); + } + return; + } + try { runtime = this.#requireDynamicRuntime(actorId); } catch (error) { @@ -1555,15 +2309,6 @@ export class EngineActorDriver implements ActorDriver { return; } - if (isHibernatable) { - this.#registerHibernatableWebSocketAckTestHooks( - websocket, - gatewayIdBuf, - requestIdBuf, - remoteAckHookToken, - ); - } - proxyToActorWs.addEventListener( "message", (event: RivetMessageEvent) => { @@ -1575,7 +2320,7 @@ export class EngineActorDriver implements ActorDriver { ); proxyToActorWs.addEventListener("close", (event) => { - if (isHibernatable && event.reason === "dynamic.runtime.disposed") { + if (event.reason === "dynamic.runtime.disposed") { logger().debug({ msg: "ignoring dynamic runtime dispose close for hibernatable websocket", actorId, @@ -1636,16 +2381,6 @@ export class EngineActorDriver implements ActorDriver { }); websocket.addEventListener("close", (event) => { - if (isHibernatable) { - this.#deleteHibernatableWebSocketAckState( - gatewayIdBuf, - requestIdBuf, - ); - unregisterRemoteHibernatableWebSocketAckHooks( - remoteAckHookToken, - this.#config.test.enabled, - ); - } if (proxyToActorWs.readyState !== proxyToActorWs.CLOSED) { proxyToActorWs.close(event.code, event.reason); } @@ -1787,6 +2522,8 @@ export class EngineActorDriver implements ActorDriver { handler.actorStartPromise?.resolve(); handler.actorStartPromise = undefined; + await this.#rebindHibernatableConnectSockets(actor.id); + // Restore hibernating requests const metaEntries = await this.#hwsLoadAll(actor.id); await this.#envoy.restoreHibernatingRequests(actor.id, metaEntries);