Skip to content

Commit 7c84b2b

Browse files
authored
fix(pegboard-runner): clear terminal tunnel routes (#4621)
# Description Please include a summary of the changes and the related issue. Please also include relevant motivation and context. ## Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [ ] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] This change requires a documentation update ## How Has This Been Tested? Please describe the tests that you ran to verify your changes. ## Checklist: - [ ] My code follows the style guidelines of this project - [ ] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have made corresponding changes to the documentation - [ ] My changes generate no new warnings - [ ] I have added tests that prove my fix is effective or that my feature works - [ ] New and existing unit tests pass locally with my changes
1 parent 49c97f4 commit 7c84b2b

2 files changed

Lines changed: 186 additions & 10 deletions

File tree

engine/packages/pegboard-runner/src/ws_to_tunnel_task.rs

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,9 @@ async fn handle_tunnel_message_mk2(
860860
authorized_tunnel_routes: &HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>,
861861
msg: protocol::mk2::ToServerTunnelMessage,
862862
) -> Result<()> {
863+
let route = (msg.message_id.gateway_id, msg.message_id.request_id);
864+
let clear_route = should_clear_tunnel_route_mk2(&msg.message_kind);
865+
863866
// Extract inner data length before consuming msg
864867
let inner_data_len = tunnel_message_inner_data_len_mk2(&msg.message_kind);
865868

@@ -868,10 +871,7 @@ async fn handle_tunnel_message_mk2(
868871
return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build());
869872
}
870873

871-
if !authorized_tunnel_routes
872-
.contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id))
873-
.await
874-
{
874+
if !authorized_tunnel_routes.contains_async(&route).await {
875875
return Err(
876876
errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(),
877877
);
@@ -899,6 +899,10 @@ async fn handle_tunnel_message_mk2(
899899
)
900900
})?;
901901

902+
if clear_route {
903+
authorized_tunnel_routes.remove_async(&route).await;
904+
}
905+
902906
Ok(())
903907
}
904908

@@ -909,6 +913,9 @@ async fn handle_tunnel_message_mk1(
909913
authorized_tunnel_routes: &HashMap<(protocol::mk2::GatewayId, protocol::mk2::RequestId), ()>,
910914
msg: protocol::ToServerTunnelMessage,
911915
) -> Result<()> {
916+
let route = (msg.message_id.gateway_id, msg.message_id.request_id);
917+
let clear_route = should_clear_tunnel_route_mk1(&msg.message_kind);
918+
912919
// Ignore DeprecatedTunnelAck messages (used only for backwards compatibility)
913920
if matches!(
914921
msg.message_kind,
@@ -925,10 +932,7 @@ async fn handle_tunnel_message_mk1(
925932
return Err(errors::WsError::InvalidPacket("payload too large".to_string()).build());
926933
}
927934

928-
if !authorized_tunnel_routes
929-
.contains_async(&(msg.message_id.gateway_id, msg.message_id.request_id))
930-
.await
931-
{
935+
if !authorized_tunnel_routes.contains_async(&route).await {
932936
return Err(
933937
errors::WsError::InvalidPacket("unauthorized tunnel message".to_string()).build(),
934938
);
@@ -950,9 +954,35 @@ async fn handle_tunnel_message_mk1(
950954
)
951955
})?;
952956

957+
if clear_route {
958+
authorized_tunnel_routes.remove_async(&route).await;
959+
}
960+
953961
Ok(())
954962
}
955963

964+
fn should_clear_tunnel_route_mk2(msg_kind: &protocol::mk2::ToServerTunnelMessageKind) -> bool {
965+
match msg_kind {
966+
protocol::mk2::ToServerTunnelMessageKind::ToServerResponseStart(response) => {
967+
!response.stream
968+
}
969+
protocol::mk2::ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => chunk.finish,
970+
protocol::mk2::ToServerTunnelMessageKind::ToServerResponseAbort
971+
| protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketClose(_) => true,
972+
_ => false,
973+
}
974+
}
975+
976+
fn should_clear_tunnel_route_mk1(msg_kind: &protocol::ToServerTunnelMessageKind) -> bool {
977+
match msg_kind {
978+
protocol::ToServerTunnelMessageKind::ToServerResponseStart(response) => !response.stream,
979+
protocol::ToServerTunnelMessageKind::ToServerResponseChunk(chunk) => chunk.finish,
980+
protocol::ToServerTunnelMessageKind::ToServerResponseAbort
981+
| protocol::ToServerTunnelMessageKind::ToServerWebSocketClose(_) => true,
982+
_ => false,
983+
}
984+
}
985+
956986
/// Returns the length of the inner data payload for a tunnel message kind.
957987
fn tunnel_message_inner_data_len_mk2(kind: &protocol::mk2::ToServerTunnelMessageKind) -> usize {
958988
use protocol::mk2::ToServerTunnelMessageKind;

engine/packages/pegboard-runner/tests/support/ws_to_tunnel_task.rs

Lines changed: 148 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,74 @@ fn response_abort_message_mk2(
2525
}
2626
}
2727

28+
fn response_start_message_mk2(
29+
gateway_id: protocol::mk2::GatewayId,
30+
request_id: protocol::mk2::RequestId,
31+
) -> protocol::mk2::ToServerTunnelMessage {
32+
response_start_message_mk2_with_stream(gateway_id, request_id, false)
33+
}
34+
35+
fn response_start_message_mk2_with_stream(
36+
gateway_id: protocol::mk2::GatewayId,
37+
request_id: protocol::mk2::RequestId,
38+
stream: bool,
39+
) -> protocol::mk2::ToServerTunnelMessage {
40+
protocol::mk2::ToServerTunnelMessage {
41+
message_id: protocol::mk2::MessageId {
42+
gateway_id,
43+
request_id,
44+
message_index: 0,
45+
},
46+
message_kind: protocol::mk2::ToServerTunnelMessageKind::ToServerResponseStart(
47+
protocol::mk2::ToServerResponseStart {
48+
status: 200,
49+
headers: Default::default(),
50+
body: None,
51+
stream,
52+
},
53+
),
54+
}
55+
}
56+
57+
fn response_chunk_message_mk2(
58+
gateway_id: protocol::mk2::GatewayId,
59+
request_id: protocol::mk2::RequestId,
60+
finish: bool,
61+
) -> protocol::mk2::ToServerTunnelMessage {
62+
protocol::mk2::ToServerTunnelMessage {
63+
message_id: protocol::mk2::MessageId {
64+
gateway_id,
65+
request_id,
66+
message_index: 0,
67+
},
68+
message_kind: protocol::mk2::ToServerTunnelMessageKind::ToServerResponseChunk(
69+
protocol::mk2::ToServerResponseChunk {
70+
body: b"chunk".to_vec(),
71+
finish,
72+
},
73+
),
74+
}
75+
}
76+
77+
fn websocket_message_mk2(
78+
gateway_id: protocol::mk2::GatewayId,
79+
request_id: protocol::mk2::RequestId,
80+
) -> protocol::mk2::ToServerTunnelMessage {
81+
protocol::mk2::ToServerTunnelMessage {
82+
message_id: protocol::mk2::MessageId {
83+
gateway_id,
84+
request_id,
85+
message_index: 0,
86+
},
87+
message_kind: protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessage(
88+
protocol::mk2::ToServerWebSocketMessage {
89+
data: b"ping".to_vec(),
90+
binary: false,
91+
},
92+
),
93+
}
94+
}
95+
2896
fn response_abort_message_mk1(
2997
gateway_id: protocol::mk2::GatewayId,
3098
request_id: protocol::mk2::RequestId,
@@ -39,6 +107,74 @@ fn response_abort_message_mk1(
39107
}
40108
}
41109

110+
fn websocket_message_mk1(
111+
gateway_id: protocol::mk2::GatewayId,
112+
request_id: protocol::mk2::RequestId,
113+
) -> protocol::ToServerTunnelMessage {
114+
protocol::ToServerTunnelMessage {
115+
message_id: protocol::MessageId {
116+
gateway_id,
117+
request_id,
118+
message_index: 0,
119+
},
120+
message_kind: protocol::ToServerTunnelMessageKind::ToServerWebSocketMessage(
121+
protocol::ToServerWebSocketMessage {
122+
data: b"ping".to_vec(),
123+
binary: false,
124+
},
125+
),
126+
}
127+
}
128+
129+
fn response_start_message_mk1(
130+
gateway_id: protocol::mk2::GatewayId,
131+
request_id: protocol::mk2::RequestId,
132+
) -> protocol::ToServerTunnelMessage {
133+
response_start_message_mk1_with_stream(gateway_id, request_id, false)
134+
}
135+
136+
fn response_start_message_mk1_with_stream(
137+
gateway_id: protocol::mk2::GatewayId,
138+
request_id: protocol::mk2::RequestId,
139+
stream: bool,
140+
) -> protocol::ToServerTunnelMessage {
141+
protocol::ToServerTunnelMessage {
142+
message_id: protocol::MessageId {
143+
gateway_id,
144+
request_id,
145+
message_index: 0,
146+
},
147+
message_kind: protocol::ToServerTunnelMessageKind::ToServerResponseStart(
148+
protocol::ToServerResponseStart {
149+
status: 200,
150+
headers: Default::default(),
151+
body: None,
152+
stream,
153+
},
154+
),
155+
}
156+
}
157+
158+
fn response_chunk_message_mk1(
159+
gateway_id: protocol::mk2::GatewayId,
160+
request_id: protocol::mk2::RequestId,
161+
finish: bool,
162+
) -> protocol::ToServerTunnelMessage {
163+
protocol::ToServerTunnelMessage {
164+
message_id: protocol::MessageId {
165+
gateway_id,
166+
request_id,
167+
message_index: 0,
168+
},
169+
message_kind: protocol::ToServerTunnelMessageKind::ToServerResponseChunk(
170+
protocol::ToServerResponseChunk {
171+
body: b"chunk".to_vec(),
172+
finish,
173+
},
174+
),
175+
}
176+
}
177+
42178
#[tokio::test]
43179
async fn rejects_unissued_mk2_tunnel_message_pairs() {
44180
let pubsub = memory_pubsub("pegboard-runner-ws-to-tunnel-test-reject-mk2");
@@ -82,7 +218,7 @@ async fn republishes_issued_mk2_tunnel_message_pairs() {
82218
&pubsub,
83219
1024,
84220
&authorized_tunnel_routes,
85-
response_abort_message_mk2(gateway_id, request_id),
221+
websocket_message_mk2(gateway_id, request_id),
86222
)
87223
.await
88224
.unwrap();
@@ -92,6 +228,11 @@ async fn republishes_issued_mk2_tunnel_message_pairs() {
92228
.unwrap()
93229
.unwrap();
94230
assert!(matches!(msg, NextOutput::Message(_)));
231+
assert!(
232+
authorized_tunnel_routes
233+
.contains_async(&(gateway_id, request_id))
234+
.await
235+
);
95236
}
96237

97238
#[tokio::test]
@@ -137,7 +278,7 @@ async fn republishes_issued_mk1_tunnel_message_pairs() {
137278
&pubsub,
138279
1024,
139280
&authorized_tunnel_routes,
140-
response_abort_message_mk1(gateway_id, request_id),
281+
websocket_message_mk1(gateway_id, request_id),
141282
)
142283
.await
143284
.unwrap();
@@ -147,4 +288,9 @@ async fn republishes_issued_mk1_tunnel_message_pairs() {
147288
.unwrap()
148289
.unwrap();
149290
assert!(matches!(msg, NextOutput::Message(_)));
291+
assert!(
292+
authorized_tunnel_routes
293+
.contains_async(&(gateway_id, request_id))
294+
.await
295+
);
150296
}

0 commit comments

Comments
 (0)