Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions rivetkit-typescript/packages/rivetkit-native/index.d.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* tslint:disable */

Check failure on line 1 in rivetkit-typescript/packages/rivetkit-native/index.d.ts

View workflow job for this annotation

GitHub Actions / RivetKit / Quality Check

format

Formatter would have printed the following content:
/* eslint-disable */

/* auto-generated by NAPI-RS */
Expand Down Expand Up @@ -48,6 +48,10 @@
export interface HibernatingRequestEntry {
gatewayId: Buffer
requestId: Buffer
envoyMessageIndex: number
rivetMessageIndex: number
path: string
headers?: Record<string, string>
}
/**
* Start the native envoy client synchronously.
Expand Down
48 changes: 38 additions & 10 deletions rivetkit-typescript/packages/rivetkit-native/src/bridge_actor.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;

use napi::threadsafe_function::ThreadsafeFunctionCallMode;
use rivet_envoy_client::config::{
Expand All @@ -8,7 +9,7 @@ use rivet_envoy_client::config::{
};
use rivet_envoy_client::handle::EnvoyHandle;
use rivet_envoy_protocol as protocol;
use tokio::sync::{Mutex, oneshot};
use tokio::sync::oneshot;

use crate::types;

Expand All @@ -22,7 +23,7 @@ pub type EventCallback = napi::threadsafe_function::ThreadsafeFunction<
pub type ResponseMap = Arc<Mutex<HashMap<String, oneshot::Sender<serde_json::Value>>>>;

/// Map of open WebSocket senders, keyed by concatenated gateway_id + request_id (8 bytes).
pub type WsSenderMap = Arc<Mutex<HashMap<[u8; 8], WebSocketSender>>>;
pub type WsSenderMap = Arc<tokio::sync::Mutex<HashMap<[u8; 8], WebSocketSender>>>;

fn make_ws_key(gateway_id: &protocol::GatewayId, request_id: &protocol::RequestId) -> [u8; 8] {
let mut key = [0u8; 8];
Expand Down Expand Up @@ -85,7 +86,7 @@ impl EnvoyCallbacks for BridgeCallbacks {

let (tx, rx) = oneshot::channel();
{
let mut map = response_map.lock().await;
let mut map = response_map.lock().expect("response_map poisoned");
map.insert(response_id, tx);
}

Expand Down Expand Up @@ -123,7 +124,7 @@ impl EnvoyCallbacks for BridgeCallbacks {

let (tx, rx) = oneshot::channel();
{
let mut map = response_map.lock().await;
let mut map = response_map.lock().expect("response_map poisoned");
map.insert(response_id, tx);
}

Expand Down Expand Up @@ -177,7 +178,7 @@ impl EnvoyCallbacks for BridgeCallbacks {

let (tx, rx) = oneshot::channel();
{
let mut map = response_map.lock().await;
let mut map = response_map.lock().expect("response_map poisoned");
map.insert(response_id, tx);
}

Expand Down Expand Up @@ -301,12 +302,39 @@ impl EnvoyCallbacks for BridgeCallbacks {

fn can_hibernate(
&self,
_actor_id: &str,
_gateway_id: &protocol::GatewayId,
_request_id: &protocol::RequestId,
_request: &HttpRequest,
actor_id: &str,
gateway_id: &protocol::GatewayId,
request_id: &protocol::RequestId,
request: &HttpRequest,
) -> bool {
false
let response_id = uuid::Uuid::new_v4().to_string();
let envelope = serde_json::json!({
"kind": "websocket_can_hibernate",
"actorId": actor_id,
"gatewayId": gateway_id,
"requestId": request_id,
"method": request.method,
"path": request.path,
"headers": request.headers,
"responseId": response_id,
});

let (tx, rx) = oneshot::channel();
{
let mut map = self.response_map.lock().expect("response_map poisoned");
map.insert(response_id, tx);
}

self.event_cb
.call(envelope, ThreadsafeFunctionCallMode::Blocking);

match tokio::task::block_in_place(|| rx.blocking_recv()) {
Ok(response) => response
.get("canHibernate")
.and_then(|value| value.as_bool())
.unwrap_or(false),
Err(_) => false,
}
}
}

Expand Down
13 changes: 8 additions & 5 deletions rivetkit-typescript/packages/rivetkit-native/src/envoy_handle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,10 +277,10 @@ impl JsEnvoyHandle {
rivet_envoy_client::tunnel::HibernatingWebSocketMetadata {
gateway_id,
request_id,
envoy_message_index: 0,
rivet_message_index: 0,
path: String::new(),
headers: HashMap::new(),
envoy_message_index: r.envoy_message_index,
rivet_message_index: r.rivet_message_index,
path: r.path,
headers: r.headers.unwrap_or_else(HashMap::new),
}
})
.collect();
Expand Down Expand Up @@ -371,7 +371,10 @@ impl JsEnvoyHandle {
response_id: String,
data: serde_json::Value,
) -> napi::Result<()> {
let mut map = self.response_map.lock().await;
let mut map = self
.response_map
.lock()
.map_err(|_| napi::Error::from_reason("response_map poisoned"))?;
if let Some(tx) = map.remove(&response_id) {
let _ = tx.send(data);
}
Expand Down
2 changes: 1 addition & 1 deletion rivetkit-typescript/packages/rivetkit-native/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub fn start_envoy_sync_js(
.map_err(|e| napi::Error::from_reason(format!("failed to create tokio runtime: {}", e)))?;
let runtime = Arc::new(runtime);

let response_map: ResponseMap = Arc::new(tokio::sync::Mutex::new(HashMap::new()));
let response_map: ResponseMap = Arc::new(std::sync::Mutex::new(HashMap::new()));
let ws_sender_map: WsSenderMap = Arc::new(tokio::sync::Mutex::new(HashMap::new()));

// Create threadsafe callback for bridging events to JS
Expand Down
4 changes: 4 additions & 0 deletions rivetkit-typescript/packages/rivetkit-native/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@ pub struct JsKvEntry {
pub struct HibernatingRequestEntry {
pub gateway_id: Buffer,
pub request_id: Buffer,
pub envoy_message_index: u16,
pub rivet_message_index: u16,
pub path: String,
pub headers: Option<std::collections::HashMap<String, String>>,
}

/// Encode a protocol MessageId into a 10-byte buffer.
Expand Down
44 changes: 44 additions & 0 deletions rivetkit-typescript/packages/rivetkit-native/wrapper.js
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ function wrapHandle(jsHandle) {
const requests = (metaEntries || []).map((e) => ({
gatewayId: Buffer.from(e.gatewayId),
requestId: Buffer.from(e.requestId),
envoyMessageIndex: e.envoyMessageIndex ?? 0,
rivetMessageIndex: e.rivetMessageIndex ?? 0,
path: e.path ?? "",
headers: e.headers ?? {},
}));
jsHandle.restoreHibernatingRequests(actorId, requests);
},
Expand Down Expand Up @@ -308,6 +312,46 @@ function handleEvent(event, config, wrappedHandle) {
);
break;
}
case "websocket_can_hibernate": {
const gatewayId = Buffer.from(event.gatewayId);
const requestId = Buffer.from(event.requestId);
const headers = new Headers(event.headers || {});
headers.set("Upgrade", "websocket");
headers.set("Connection", "Upgrade");
const request = new Request(`http://actor${event.path}`, {
method: event.method,
headers,
});

Promise.resolve(
config.hibernatableWebSocket
? config.hibernatableWebSocket.canHibernate(
event.actorId,
gatewayId,
requestId,
request,
)
: false,
).then(
async (canHibernate) => {
if (handle._raw) {
await handle._raw.respondCallback(event.responseId, {
canHibernate: Boolean(canHibernate),
});
}
},
async (err) => {
console.error("canHibernate error:", err);
if (handle._raw) {
await handle._raw.respondCallback(event.responseId, {
canHibernate: false,
error: String(err),
});
}
},
);
break;
}
case "websocket_open": {
if (config.websocket) {
const messageId = Buffer.from(event.messageId);
Expand Down
Loading