diff --git a/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs b/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs index 550e9ff216..1c4b1a7c80 100644 --- a/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs +++ b/engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs @@ -860,6 +860,9 @@ async fn handle_tunnel_message_mk2( authorized_tunnel_routes: &HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>, msg: protocol::mk2::ToServerTunnelMessage, ) -> Result<()> { + let route = (msg.message_id.gateway_id, msg.message_id.request_id); + let clear_route = should_clear_tunnel_route_mk2(&msg.message_kind); + // Extract inner data length before consuming msg let inner_data_len = tunnel_message_inner_data_len_mk2(&msg.message_kind); @@ -868,10 +871,7 @@ async fn handle_tunnel_message_mk2( return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build()); } - if !authorized_tunnel_routes - .contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id)) - .await - { + if !authorized_tunnel_routes.contains_async(&route).await { return Err( errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(), ); @@ -899,6 +899,10 @@ async fn handle_tunnel_message_mk2( ) })?; + if clear_route { + authorized_tunnel_routes.remove_async(&route).await; + } + Ok(()) } @@ -909,6 +913,9 @@ async fn handle_tunnel_message_mk1( authorized_tunnel_routes: &HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>, msg: protocol::ToServerTunnelMessage, ) -> Result<()> { + let route = (msg.message_id.gateway_id, msg.message_id.request_id); + let clear_route = should_clear_tunnel_route_mk1(&msg.message_kind); + // Ignore DeprecatedTunnelAck messages (used only for backwards compatibility) if matches!( msg.message_kind, @@ -925,10 +932,7 @@ async fn handle_tunnel_message_mk1( return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build()); } - if !authorized_tunnel_routes - .contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id)) - .await - { + if !authorized_tunnel_routes.contains_async(&route).await { return Err( errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(), ); @@ -950,9 +954,35 @@ async fn handle_tunnel_message_mk1( ) })?; + if clear_route { + authorized_tunnel_routes.remove_async(&route).await; + } + Ok(()) } +fn should_clear_tunnel_route_mk2(msg_kind: &protocol::mk2::ToServerTunnelMessageKind) -> bool { + match msg_kind { + protocol::mk2::ToServerTunnelMessageKind::ToServerResponseStart(response) => { + !response.stream + } + protocol::mk2::ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => chunk.finish, + protocol::mk2::ToServerTunnelMessageKind::ToServerResponseAbort + | protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketClose(_) => true, + _ => false, + } +} + +fn should_clear_tunnel_route_mk1(msg_kind: &protocol::ToServerTunnelMessageKind) -> bool { + match msg_kind { + protocol::ToServerTunnelMessageKind::ToServerResponseStart(response) => !response.stream, + protocol::ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => chunk.finish, + protocol::ToServerTunnelMessageKind::ToServerResponseAbort + | protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(_) => true, + _ => false, + } +} + /// Returns the length of the inner data payload for a tunnel message kind. fn tunnel_message_inner_data_len_mk2(kind: &protocol::mk2::ToServerTunnelMessageKind) -> usize { use protocol::mk2::ToServerTunnelMessageKind; diff --git a/engine/packages/pegboard-runner/tests/support/ws_to_tunnel_task.rs b/engine/packages/pegboard-runner/tests/support/ws_to_tunnel_task.rs index 53b4278136..fbfcf2dc66 100644 --- a/engine/packages/pegboard-runner/tests/support/ws_to_tunnel_task.rs +++ b/engine/packages/pegboard-runner/tests/support/ws_to_tunnel_task.rs @@ -25,6 +25,74 @@ fn response_abort_message_mk2( } } +fn response_start_message_mk2( + gateway_id: protocol::mk2::GatewayId, + request_id: protocol::mk2::RequestId, +) -> protocol::mk2::ToServerTunnelMessage { + response_start_message_mk2_with_stream(gateway_id, request_id, false) +} + +fn response_start_message_mk2_with_stream( + gateway_id: protocol::mk2::GatewayId, + request_id: protocol::mk2::RequestId, + stream: bool, +) -> protocol::mk2::ToServerTunnelMessage { + protocol::mk2::ToServerTunnelMessage { + message_id: protocol::mk2::MessageId { + gateway_id, + request_id, + message_index: 0, + }, + message_kind: protocol::mk2::ToServerTunnelMessageKind::ToServerResponseStart( + protocol::mk2::ToServerResponseStart { + status: 200, + headers: Default::default(), + body: None, + stream, + }, + ), + } +} + +fn response_chunk_message_mk2( + gateway_id: protocol::mk2::GatewayId, + request_id: protocol::mk2::RequestId, + finish: bool, +) -> protocol::mk2::ToServerTunnelMessage { + protocol::mk2::ToServerTunnelMessage { + message_id: protocol::mk2::MessageId { + gateway_id, + request_id, + message_index: 0, + }, + message_kind: protocol::mk2::ToServerTunnelMessageKind::ToServerResponseChunk( + protocol::mk2::ToServerResponseChunk { + body: b"chunk".to_vec(), + finish, + }, + ), + } +} + +fn websocket_message_mk2( + gateway_id: protocol::mk2::GatewayId, + request_id: protocol::mk2::RequestId, +) -> protocol::mk2::ToServerTunnelMessage { + protocol::mk2::ToServerTunnelMessage { + message_id: protocol::mk2::MessageId { + gateway_id, + request_id, + message_index: 0, + }, + message_kind: protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessage( + protocol::mk2::ToServerWebSocketMessage { + data: b"ping".to_vec(), + binary: false, + }, + ), + } +} + fn response_abort_message_mk1( gateway_id: protocol::mk2::GatewayId, request_id: protocol::mk2::RequestId, @@ -39,6 +107,74 @@ fn response_abort_message_mk1( } } +fn websocket_message_mk1( + gateway_id: protocol::mk2::GatewayId, + request_id: protocol::mk2::RequestId, +) -> protocol::ToServerTunnelMessage { + protocol::ToServerTunnelMessage { + message_id: protocol::MessageId { + gateway_id, + request_id, + message_index: 0, + }, + message_kind: protocol::ToServerTunnelMessageKind::ToServerWebSocketMessage( + protocol::ToServerWebSocketMessage { + data: b"ping".to_vec(), + binary: false, + }, + ), + } +} + +fn response_start_message_mk1( + gateway_id: protocol::mk2::GatewayId, + request_id: protocol::mk2::RequestId, +) -> protocol::ToServerTunnelMessage { + response_start_message_mk1_with_stream(gateway_id, request_id, false) +} + +fn response_start_message_mk1_with_stream( + gateway_id: protocol::mk2::GatewayId, + request_id: protocol::mk2::RequestId, + stream: bool, +) -> protocol::ToServerTunnelMessage { + protocol::ToServerTunnelMessage { + message_id: protocol::MessageId { + gateway_id, + request_id, + message_index: 0, + }, + message_kind: protocol::ToServerTunnelMessageKind::ToServerResponseStart( + protocol::ToServerResponseStart { + status: 200, + headers: Default::default(), + body: None, + stream, + }, + ), + } +} + +fn response_chunk_message_mk1( + gateway_id: protocol::mk2::GatewayId, + request_id: protocol::mk2::RequestId, + finish: bool, +) -> protocol::ToServerTunnelMessage { + protocol::ToServerTunnelMessage { + message_id: protocol::MessageId { + gateway_id, + request_id, + message_index: 0, + }, + message_kind: protocol::ToServerTunnelMessageKind::ToServerResponseChunk( + protocol::ToServerResponseChunk { + body: b"chunk".to_vec(), + finish, + }, + ), + } +} + #[tokio::test] async fn rejects_unissued_mk2_tunnel_message_pairs() { let pubsub = memory_pubsub("pegboard-runner-ws-to-tunnel-test-reject-mk2"); @@ -82,7 +218,7 @@ async fn republishes_issued_mk2_tunnel_message_pairs() { &pubsub, 1024, &authorized_tunnel_routes, - response_abort_message_mk2(gateway_id, request_id), + websocket_message_mk2(gateway_id, request_id), ) .await .unwrap(); @@ -92,6 +228,11 @@ async fn republishes_issued_mk2_tunnel_message_pairs() { .unwrap() .unwrap(); assert!(matches!(msg, NextOutput::Message(_))); + assert!( + authorized_tunnel_routes + .contains_async(&(gateway_id, request_id)) + .await + ); } #[tokio::test] @@ -137,7 +278,7 @@ async fn republishes_issued_mk1_tunnel_message_pairs() { &pubsub, 1024, &authorized_tunnel_routes, - response_abort_message_mk1(gateway_id, request_id), + websocket_message_mk1(gateway_id, request_id), ) .await .unwrap(); @@ -147,4 +288,9 @@ async fn republishes_issued_mk1_tunnel_message_pairs() { .unwrap() .unwrap(); assert!(matches!(msg, NextOutput::Message(_))); + assert!( + authorized_tunnel_routes + .contains_async(&(gateway_id, request_id)) + .await + ); }