@@ -26,10 +26,60 @@ pub struct InFlightRequestHandle {
2626 pub new : bool ,
2727}
2828
29+ #[ derive( Debug , Clone , Copy , PartialEq , Eq ) ]
30+ pub enum InFlightRequestState {
31+ AwaitingHttpResponseStart ,
32+ AwaitingWebSocketOpen ,
33+ ActiveWebSocket ,
34+ Closed ,
35+ }
36+
37+ impl InFlightRequestState {
38+ fn accept_message ( & mut self , message_kind : & protocol:: mk2:: ToServerTunnelMessageKind ) -> bool {
39+ use protocol:: mk2:: ToServerTunnelMessageKind ;
40+
41+ match ( self , message_kind) {
42+ (
43+ state @ InFlightRequestState :: AwaitingHttpResponseStart ,
44+ ToServerTunnelMessageKind :: ToServerResponseStart ( _)
45+ | ToServerTunnelMessageKind :: ToServerResponseAbort ,
46+ ) => {
47+ * state = InFlightRequestState :: Closed ;
48+ true
49+ }
50+ (
51+ state @ InFlightRequestState :: AwaitingWebSocketOpen ,
52+ ToServerTunnelMessageKind :: ToServerWebSocketOpen ( _) ,
53+ ) => {
54+ * state = InFlightRequestState :: ActiveWebSocket ;
55+ true
56+ }
57+ (
58+ state @ InFlightRequestState :: AwaitingWebSocketOpen ,
59+ ToServerTunnelMessageKind :: ToServerWebSocketClose ( _) ,
60+ )
61+ | (
62+ state @ InFlightRequestState :: ActiveWebSocket ,
63+ ToServerTunnelMessageKind :: ToServerWebSocketClose ( _) ,
64+ ) => {
65+ * state = InFlightRequestState :: Closed ;
66+ true
67+ }
68+ (
69+ InFlightRequestState :: ActiveWebSocket ,
70+ ToServerTunnelMessageKind :: ToServerWebSocketMessage ( _)
71+ | ToServerTunnelMessageKind :: ToServerWebSocketMessageAck ( _) ,
72+ ) => true ,
73+ _ => false ,
74+ }
75+ }
76+ }
77+
2978struct InFlightRequest {
3079 /// UPS subject to send messages to for this request.
3180 receiver_subject : String ,
3281 protocol_version : u16 ,
82+ state : InFlightRequestState ,
3383 /// Sender for incoming messages to this request.
3484 msg_tx : mpsc:: Sender < protocol:: mk2:: ToServerTunnelMessageKind > ,
3585 /// Used to check if the request handler has been dropped.
@@ -134,6 +184,7 @@ impl SharedState {
134184 receiver_subject : String ,
135185 protocol_version : u16 ,
136186 request_id : protocol:: mk2:: RequestId ,
187+ state : InFlightRequestState ,
137188 ) -> InFlightRequestHandle {
138189 let ( msg_tx, msg_rx) = mpsc:: channel ( 128 ) ;
139190 let ( drop_tx, drop_rx) = watch:: channel ( None ) ;
@@ -143,6 +194,7 @@ impl SharedState {
143194 entry. insert_entry ( InFlightRequest {
144195 receiver_subject,
145196 protocol_version,
197+ state,
146198 msg_tx,
147199 drop_tx,
148200 opened : false ,
@@ -159,6 +211,7 @@ impl SharedState {
159211 entry. receiver_subject = receiver_subject;
160212 entry. msg_tx = msg_tx;
161213 entry. drop_tx = drop_tx;
214+ entry. state = state;
162215 entry. opened = false ;
163216 entry. last_pong = util:: timestamp:: now ( ) ;
164217
@@ -355,7 +408,7 @@ impl SharedState {
355408 Ok ( protocol:: mk2:: ToGateway :: ToServerTunnelMessage ( msg) ) => {
356409 let message_id = msg. message_id ;
357410
358- let Some ( in_flight) = self
411+ let Some ( mut in_flight) = self
359412 . in_flight_requests
360413 . get_async ( & message_id. request_id )
361414 . await
@@ -369,6 +422,18 @@ impl SharedState {
369422 continue ;
370423 } ;
371424
425+ if !in_flight. state . accept_message ( & msg. message_kind ) {
426+ tracing:: warn!(
427+ gateway_id=%protocol:: util:: id_to_string( & message_id. gateway_id) ,
428+ request_id=%protocol:: util:: id_to_string( & message_id. request_id) ,
429+ message_index=message_id. message_index,
430+ state=?in_flight. state,
431+ message_kind=?msg. message_kind,
432+ "dropping invalid tunnel message for request state"
433+ ) ;
434+ continue ;
435+ }
436+
372437 // Send message to the request handler to emulate the real network action
373438 let inner_size = match & msg. message_kind {
374439 protocol:: mk2:: ToServerTunnelMessageKind :: ToServerWebSocketMessage (
@@ -619,6 +684,65 @@ fn wrapping_gt(a: u16, b: u16) -> bool {
619684 a != b && a. wrapping_sub ( b) < u16:: MAX / 2
620685}
621686
687+ #[ cfg( test) ]
688+ mod tests {
689+ use super :: InFlightRequestState ;
690+ use rivet_runner_protocol as protocol;
691+
692+ #[ test]
693+ fn http_requests_only_accept_http_terminal_messages ( ) {
694+ let mut state = InFlightRequestState :: AwaitingHttpResponseStart ;
695+ assert ! (
696+ state. accept_message( & protocol:: mk2:: ToServerTunnelMessageKind :: ToServerResponseAbort , )
697+ ) ;
698+ assert_eq ! ( state, InFlightRequestState :: Closed ) ;
699+
700+ let mut state = InFlightRequestState :: AwaitingHttpResponseStart ;
701+ assert ! ( !state. accept_message(
702+ & protocol:: mk2:: ToServerTunnelMessageKind :: ToServerWebSocketMessage (
703+ protocol:: mk2:: ToServerWebSocketMessage {
704+ data: Vec :: new( ) ,
705+ binary: false ,
706+ } ,
707+ ) ,
708+ ) ) ;
709+ assert_eq ! ( state, InFlightRequestState :: AwaitingHttpResponseStart ) ;
710+ }
711+
712+ #[ test]
713+ fn websockets_must_open_before_streaming ( ) {
714+ let mut state = InFlightRequestState :: AwaitingWebSocketOpen ;
715+ assert ! ( !state. accept_message(
716+ & protocol:: mk2:: ToServerTunnelMessageKind :: ToServerWebSocketMessage (
717+ protocol:: mk2:: ToServerWebSocketMessage {
718+ data: Vec :: new( ) ,
719+ binary: false ,
720+ } ,
721+ ) ,
722+ ) ) ;
723+ assert_eq ! ( state, InFlightRequestState :: AwaitingWebSocketOpen ) ;
724+
725+ assert ! ( state. accept_message(
726+ & protocol:: mk2:: ToServerTunnelMessageKind :: ToServerWebSocketOpen (
727+ protocol:: mk2:: ToServerWebSocketOpen {
728+ can_hibernate: false ,
729+ } ,
730+ ) ,
731+ ) ) ;
732+ assert_eq ! ( state, InFlightRequestState :: ActiveWebSocket ) ;
733+ }
734+
735+ #[ test]
736+ fn active_websockets_reject_http_messages ( ) {
737+ let mut state = InFlightRequestState :: ActiveWebSocket ;
738+ assert ! (
739+ !state
740+ . accept_message( & protocol:: mk2:: ToServerTunnelMessageKind :: ToServerResponseAbort , )
741+ ) ;
742+ assert_eq ! ( state, InFlightRequestState :: ActiveWebSocket ) ;
743+ }
744+ }
745+
622746// fn wrapping_lt(a: u16, b: u16) -> bool {
623747// b.wrapping_sub(a) < u16::MAX / 2
624748// }
0 commit comments