diff --git a/rivetkit-typescript/packages/rivetkit-native/index.d.ts b/rivetkit-typescript/packages/rivetkit-native/index.d.ts index 3f618fb1f3..9bc723b82d 100644 --- a/rivetkit-typescript/packages/rivetkit-native/index.d.ts +++ b/rivetkit-typescript/packages/rivetkit-native/index.d.ts @@ -48,6 +48,10 @@ export interface JsKvEntry { export interface HibernatingRequestEntry { gatewayId: Buffer requestId: Buffer + envoyMessageIndex: number + rivetMessageIndex: number + path: string + headers?: Record } /** * Start the native envoy client synchronously. diff --git a/rivetkit-typescript/packages/rivetkit-native/src/bridge_actor.rs b/rivetkit-typescript/packages/rivetkit-native/src/bridge_actor.rs index 1735910a16..d5c957e799 100644 --- a/rivetkit-typescript/packages/rivetkit-native/src/bridge_actor.rs +++ b/rivetkit-typescript/packages/rivetkit-native/src/bridge_actor.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; use std::sync::Arc; +use std::sync::Mutex; use napi::threadsafe_function::ThreadsafeFunctionCallMode; use rivet_envoy_client::config::{ @@ -8,7 +9,7 @@ use rivet_envoy_client::config::{ }; use rivet_envoy_client::handle::EnvoyHandle; use rivet_envoy_protocol as protocol; -use tokio::sync::{Mutex, oneshot}; +use tokio::sync::oneshot; use crate::types; @@ -22,7 +23,7 @@ pub type EventCallback = napi::threadsafe_function::ThreadsafeFunction< pub type ResponseMap = Arc>>>; /// Map of open WebSocket senders, keyed by concatenated gateway_id + request_id (8 bytes). -pub type WsSenderMap = Arc>>; +pub type WsSenderMap = Arc>>; fn make_ws_key(gateway_id: &protocol::GatewayId, request_id: &protocol::RequestId) -> [u8; 8] { let mut key = [0u8; 8]; @@ -85,7 +86,7 @@ impl EnvoyCallbacks for BridgeCallbacks { let (tx, rx) = oneshot::channel(); { - let mut map = response_map.lock().await; + let mut map = response_map.lock().expect("response_map poisoned"); map.insert(response_id, tx); } @@ -123,7 +124,7 @@ impl EnvoyCallbacks for BridgeCallbacks { let (tx, rx) = oneshot::channel(); { - let mut map = response_map.lock().await; + let mut map = response_map.lock().expect("response_map poisoned"); map.insert(response_id, tx); } @@ -177,7 +178,7 @@ impl EnvoyCallbacks for BridgeCallbacks { let (tx, rx) = oneshot::channel(); { - let mut map = response_map.lock().await; + let mut map = response_map.lock().expect("response_map poisoned"); map.insert(response_id, tx); } @@ -301,12 +302,39 @@ impl EnvoyCallbacks for BridgeCallbacks { fn can_hibernate( &self, - _actor_id: &str, - _gateway_id: &protocol::GatewayId, - _request_id: &protocol::RequestId, - _request: &HttpRequest, + actor_id: &str, + gateway_id: &protocol::GatewayId, + request_id: &protocol::RequestId, + request: &HttpRequest, ) -> bool { - false + let response_id = uuid::Uuid::new_v4().to_string(); + let envelope = serde_json::json!({ + "kind": "websocket_can_hibernate", + "actorId": actor_id, + "gatewayId": gateway_id, + "requestId": request_id, + "method": request.method, + "path": request.path, + "headers": request.headers, + "responseId": response_id, + }); + + let (tx, rx) = oneshot::channel(); + { + let mut map = self.response_map.lock().expect("response_map poisoned"); + map.insert(response_id, tx); + } + + self.event_cb + .call(envelope, ThreadsafeFunctionCallMode::Blocking); + + match tokio::task::block_in_place(|| rx.blocking_recv()) { + Ok(response) => response + .get("canHibernate") + .and_then(|value| value.as_bool()) + .unwrap_or(false), + Err(_) => false, + } } } diff --git a/rivetkit-typescript/packages/rivetkit-native/src/envoy_handle.rs b/rivetkit-typescript/packages/rivetkit-native/src/envoy_handle.rs index c5159ab285..0d2a5cda38 100644 --- a/rivetkit-typescript/packages/rivetkit-native/src/envoy_handle.rs +++ b/rivetkit-typescript/packages/rivetkit-native/src/envoy_handle.rs @@ -277,10 +277,10 @@ impl JsEnvoyHandle { rivet_envoy_client::tunnel::HibernatingWebSocketMetadata { gateway_id, request_id, - envoy_message_index: 0, - rivet_message_index: 0, - path: String::new(), - headers: HashMap::new(), + envoy_message_index: r.envoy_message_index, + rivet_message_index: r.rivet_message_index, + path: r.path, + headers: r.headers.unwrap_or_else(HashMap::new), } }) .collect(); @@ -371,7 +371,10 @@ impl JsEnvoyHandle { response_id: String, data: serde_json::Value, ) -> napi::Result<()> { - let mut map = self.response_map.lock().await; + let mut map = self + .response_map + .lock() + .map_err(|_| napi::Error::from_reason("response_map poisoned"))?; if let Some(tx) = map.remove(&response_id) { let _ = tx.send(data); } diff --git a/rivetkit-typescript/packages/rivetkit-native/src/lib.rs b/rivetkit-typescript/packages/rivetkit-native/src/lib.rs index 331eab715f..f9ed9a2b80 100644 --- a/rivetkit-typescript/packages/rivetkit-native/src/lib.rs +++ b/rivetkit-typescript/packages/rivetkit-native/src/lib.rs @@ -51,7 +51,7 @@ pub fn start_envoy_sync_js( .map_err(|e| napi::Error::from_reason(format!("failed to create tokio runtime: {}", e)))?; let runtime = Arc::new(runtime); - let response_map: ResponseMap = Arc::new(tokio::sync::Mutex::new(HashMap::new())); + let response_map: ResponseMap = Arc::new(std::sync::Mutex::new(HashMap::new())); let ws_sender_map: WsSenderMap = Arc::new(tokio::sync::Mutex::new(HashMap::new())); // Create threadsafe callback for bridging events to JS diff --git a/rivetkit-typescript/packages/rivetkit-native/src/types.rs b/rivetkit-typescript/packages/rivetkit-native/src/types.rs index 3d03582faf..4040469dd3 100644 --- a/rivetkit-typescript/packages/rivetkit-native/src/types.rs +++ b/rivetkit-typescript/packages/rivetkit-native/src/types.rs @@ -35,6 +35,10 @@ pub struct JsKvEntry { pub struct HibernatingRequestEntry { pub gateway_id: Buffer, pub request_id: Buffer, + pub envoy_message_index: u16, + pub rivet_message_index: u16, + pub path: String, + pub headers: Option>, } /// Encode a protocol MessageId into a 10-byte buffer. diff --git a/rivetkit-typescript/packages/rivetkit-native/wrapper.js b/rivetkit-typescript/packages/rivetkit-native/wrapper.js index 7c0613edd0..af95a172f1 100644 --- a/rivetkit-typescript/packages/rivetkit-native/wrapper.js +++ b/rivetkit-typescript/packages/rivetkit-native/wrapper.js @@ -98,6 +98,10 @@ function wrapHandle(jsHandle) { const requests = (metaEntries || []).map((e) => ({ gatewayId: Buffer.from(e.gatewayId), requestId: Buffer.from(e.requestId), + envoyMessageIndex: e.envoyMessageIndex ?? 0, + rivetMessageIndex: e.rivetMessageIndex ?? 0, + path: e.path ?? "", + headers: e.headers ?? {}, })); jsHandle.restoreHibernatingRequests(actorId, requests); }, @@ -308,6 +312,46 @@ function handleEvent(event, config, wrappedHandle) { ); break; } + case "websocket_can_hibernate": { + const gatewayId = Buffer.from(event.gatewayId); + const requestId = Buffer.from(event.requestId); + const headers = new Headers(event.headers || {}); + headers.set("Upgrade", "websocket"); + headers.set("Connection", "Upgrade"); + const request = new Request(`http://actor${event.path}`, { + method: event.method, + headers, + }); + + Promise.resolve( + config.hibernatableWebSocket + ? config.hibernatableWebSocket.canHibernate( + event.actorId, + gatewayId, + requestId, + request, + ) + : false, + ).then( + async (canHibernate) => { + if (handle._raw) { + await handle._raw.respondCallback(event.responseId, { + canHibernate: Boolean(canHibernate), + }); + } + }, + async (err) => { + console.error("canHibernate error:", err); + if (handle._raw) { + await handle._raw.respondCallback(event.responseId, { + canHibernate: false, + error: String(err), + }); + } + }, + ); + break; + } case "websocket_open": { if (config.websocket) { const messageId = Buffer.from(event.messageId);