Skip to content

Commit 67191e0

Browse files
committed
fix(pegboard-gateway): enforce tunnel message state
1 parent 8088fbf commit 67191e0

2 files changed

Lines changed: 142 additions & 4 deletions

File tree

engine/packages/pegboard-gateway/src/lib.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use tokio_tungstenite::tungstenite::{
2727
};
2828
use universaldb::utils::IsolationLevel::*;
2929

30-
use crate::shared_state::{InFlightRequestHandle, SharedState};
30+
use crate::shared_state::{InFlightRequestHandle, InFlightRequestState, SharedState};
3131

3232
mod keepalive_task;
3333
mod metrics;
@@ -178,7 +178,12 @@ impl PegboardGateway {
178178
..
179179
} = self
180180
.shared_state
181-
.start_in_flight_request(tunnel_subject, runner_protocol_version, request_id)
181+
.start_in_flight_request(
182+
tunnel_subject,
183+
runner_protocol_version,
184+
request_id,
185+
InFlightRequestState::AwaitingHttpResponseStart,
186+
)
182187
.await;
183188

184189
// Start request
@@ -304,7 +309,16 @@ impl PegboardGateway {
304309
new,
305310
} = self
306311
.shared_state
307-
.start_in_flight_request(tunnel_subject.clone(), runner_protocol_version, request_id)
312+
.start_in_flight_request(
313+
tunnel_subject.clone(),
314+
runner_protocol_version,
315+
request_id,
316+
if after_hibernation {
317+
InFlightRequestState::ActiveWebSocket
318+
} else {
319+
InFlightRequestState::AwaitingWebSocketOpen
320+
},
321+
)
308322
.await;
309323

310324
ensure!(

engine/packages/pegboard-gateway/src/shared_state.rs

Lines changed: 125 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,60 @@ pub struct InFlightRequestHandle {
2626
pub new: bool,
2727
}
2828

29+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
30+
pub enum InFlightRequestState {
31+
AwaitingHttpResponseStart,
32+
AwaitingWebSocketOpen,
33+
ActiveWebSocket,
34+
Closed,
35+
}
36+
37+
impl InFlightRequestState {
38+
fn accept_message(&mut self, message_kind: &protocol::mk2::ToServerTunnelMessageKind) -> bool {
39+
use protocol::mk2::ToServerTunnelMessageKind;
40+
41+
match (self, message_kind) {
42+
(
43+
state @ InFlightRequestState::AwaitingHttpResponseStart,
44+
ToServerTunnelMessageKind::ToServerResponseStart(_)
45+
| ToServerTunnelMessageKind::ToServerResponseAbort,
46+
) => {
47+
*state = InFlightRequestState::Closed;
48+
true
49+
}
50+
(
51+
state @ InFlightRequestState::AwaitingWebSocketOpen,
52+
ToServerTunnelMessageKind::ToServerWebSocketOpen(_),
53+
) => {
54+
*state = InFlightRequestState::ActiveWebSocket;
55+
true
56+
}
57+
(
58+
state @ InFlightRequestState::AwaitingWebSocketOpen,
59+
ToServerTunnelMessageKind::ToServerWebSocketClose(_),
60+
)
61+
| (
62+
state @ InFlightRequestState::ActiveWebSocket,
63+
ToServerTunnelMessageKind::ToServerWebSocketClose(_),
64+
) => {
65+
*state = InFlightRequestState::Closed;
66+
true
67+
}
68+
(
69+
InFlightRequestState::ActiveWebSocket,
70+
ToServerTunnelMessageKind::ToServerWebSocketMessage(_)
71+
| ToServerTunnelMessageKind::ToServerWebSocketMessageAck(_),
72+
) => true,
73+
_ => false,
74+
}
75+
}
76+
}
77+
2978
struct InFlightRequest {
3079
/// UPS subject to send messages to for this request.
3180
receiver_subject: String,
3281
protocol_version: u16,
82+
state: InFlightRequestState,
3383
/// Sender for incoming messages to this request.
3484
msg_tx: mpsc::Sender<protocol::mk2::ToServerTunnelMessageKind>,
3585
/// Used to check if the request handler has been dropped.
@@ -134,6 +184,7 @@ impl SharedState {
134184
receiver_subject: String,
135185
protocol_version: u16,
136186
request_id: protocol::mk2::RequestId,
187+
state: InFlightRequestState,
137188
) -> InFlightRequestHandle {
138189
let (msg_tx, msg_rx) = mpsc::channel(128);
139190
let (drop_tx, drop_rx) = watch::channel(None);
@@ -143,6 +194,7 @@ impl SharedState {
143194
entry.insert_entry(InFlightRequest {
144195
receiver_subject,
145196
protocol_version,
197+
state,
146198
msg_tx,
147199
drop_tx,
148200
opened: false,
@@ -159,6 +211,7 @@ impl SharedState {
159211
entry.receiver_subject = receiver_subject;
160212
entry.msg_tx = msg_tx;
161213
entry.drop_tx = drop_tx;
214+
entry.state = state;
162215
entry.opened = false;
163216
entry.last_pong = util::timestamp::now();
164217

@@ -355,7 +408,7 @@ impl SharedState {
355408
Ok(protocol::mk2::ToGateway::ToServerTunnelMessage(msg)) => {
356409
let message_id = msg.message_id;
357410

358-
let Some(in_flight) = self
411+
let Some(mut in_flight) = self
359412
.in_flight_requests
360413
.get_async(&message_id.request_id)
361414
.await
@@ -369,6 +422,18 @@ impl SharedState {
369422
continue;
370423
};
371424

425+
if !in_flight.state.accept_message(&msg.message_kind) {
426+
tracing::warn!(
427+
gateway_id=%protocol::util::id_to_string(&message_id.gateway_id),
428+
request_id=%protocol::util::id_to_string(&message_id.request_id),
429+
message_index=message_id.message_index,
430+
state=?in_flight.state,
431+
message_kind=?msg.message_kind,
432+
"dropping invalid tunnel message for request state"
433+
);
434+
continue;
435+
}
436+
372437
// Send message to the request handler to emulate the real network action
373438
let inner_size = match &msg.message_kind {
374439
protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessage(
@@ -619,6 +684,65 @@ fn wrapping_gt(a: u16, b: u16) -> bool {
619684
a != b && a.wrapping_sub(b) < u16::MAX / 2
620685
}
621686

687+
#[cfg(test)]
688+
mod tests {
689+
use super::InFlightRequestState;
690+
use rivet_runner_protocol as protocol;
691+
692+
#[test]
693+
fn http_requests_only_accept_http_terminal_messages() {
694+
let mut state = InFlightRequestState::AwaitingHttpResponseStart;
695+
assert!(
696+
state.accept_message(&protocol::mk2::ToServerTunnelMessageKind::ToServerResponseAbort,)
697+
);
698+
assert_eq!(state, InFlightRequestState::Closed);
699+
700+
let mut state = InFlightRequestState::AwaitingHttpResponseStart;
701+
assert!(!state.accept_message(
702+
&protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessage(
703+
protocol::mk2::ToServerWebSocketMessage {
704+
data: Vec::new(),
705+
binary: false,
706+
},
707+
),
708+
));
709+
assert_eq!(state, InFlightRequestState::AwaitingHttpResponseStart);
710+
}
711+
712+
#[test]
713+
fn websockets_must_open_before_streaming() {
714+
let mut state = InFlightRequestState::AwaitingWebSocketOpen;
715+
assert!(!state.accept_message(
716+
&protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketMessage(
717+
protocol::mk2::ToServerWebSocketMessage {
718+
data: Vec::new(),
719+
binary: false,
720+
},
721+
),
722+
));
723+
assert_eq!(state, InFlightRequestState::AwaitingWebSocketOpen);
724+
725+
assert!(state.accept_message(
726+
&protocol::mk2::ToServerTunnelMessageKind::ToServerWebSocketOpen(
727+
protocol::mk2::ToServerWebSocketOpen {
728+
can_hibernate: false,
729+
},
730+
),
731+
));
732+
assert_eq!(state, InFlightRequestState::ActiveWebSocket);
733+
}
734+
735+
#[test]
736+
fn active_websockets_reject_http_messages() {
737+
let mut state = InFlightRequestState::ActiveWebSocket;
738+
assert!(
739+
!state
740+
.accept_message(&protocol::mk2::ToServerTunnelMessageKind::ToServerResponseAbort,)
741+
);
742+
assert_eq!(state, InFlightRequestState::ActiveWebSocket);
743+
}
744+
}
745+
622746
// fn wrapping_lt(a: u16, b: u16) -> bool {
623747
// b.wrapping_sub(a) < u16::MAX / 2
624748
// }

0 commit comments

Comments
 (0)