11import logging
22import time
33from collections .abc import Callable
4- from dataclasses import dataclass
54from enum import IntEnum
65from pathlib import Path
76from socket import IPPROTO_TCP , TCP_NODELAY , socket
1817RLBOT_SERVER_PORT = 23234
1918
2019
21- class SocketDataType (IntEnum ):
22- """
23- See https://github.com/RLBot/core/blob/master/RLBotCS/Types/DataType.cs
24- and https://wiki.rlbot.org/framework/sockets-specification/#data-types
25- """
26-
27- NONE = 0
28- GAME_PACKET = 1
29- FIELD_INFO = 2
30- START_COMMAND = 3
31- MATCH_CONFIGURATION = 4
32- PLAYER_INPUT = 5
33- DESIRED_GAME_STATE = 6
34- RENDER_GROUP = 7
35- REMOVE_RENDER_GROUP = 8
36- MATCH_COMMUNICATION = 9
37- BALL_PREDICTION = 10
38- CONNECTION_SETTINGS = 11
39- STOP_COMMAND = 12
40- SET_LOADOUT = 13
41- INIT_COMPLETE = 14
42- CONTROLLABLE_TEAM_INFO = 15
43-
44-
45- @dataclass (repr = False , eq = False , frozen = True , match_args = False , slots = True )
46- class SocketMessage :
47- type : SocketDataType
48- data : bytes
49-
50-
5120class MsgHandlingResult (IntEnum ):
5221 TERMINATED = 0
5322 NO_INCOMING_MSGS = 1
@@ -66,7 +35,6 @@ class SocketRelay:
6635 is_connected = False
6736 _running = False
6837 """Indicates whether a messages are being handled by the `run` loop (potentially in a background thread)"""
69- _ball_pred = flat .BallPrediction ()
7038
7139 on_connect_handlers : list [Callable [[], None ]] = []
7240 packet_handlers : list [Callable [[flat .GamePacket ], None ]] = []
@@ -77,7 +45,7 @@ class SocketRelay:
7745 controllable_team_info_handlers : list [
7846 Callable [[flat .ControllableTeamInfo ], None ]
7947 ] = []
80- raw_handlers : list [Callable [[SocketMessage ], None ]] = []
48+ raw_handlers : list [Callable [[flat . CoreMessage ], None ]] = []
8149
8250 def __init__ (
8351 self ,
@@ -116,68 +84,61 @@ def _read_exact(self, n: int) -> bytes:
11684 pos += cr
11785 return bytes (buff )
11886
119- def read_message (self ) -> SocketMessage :
120- type_int = self ._read_int ()
87+ def read_message (self ) -> bytes :
12188 size = self ._read_int ()
122- data = self ._read_exact (size )
123- return SocketMessage (SocketDataType (type_int ), data )
89+ return self ._read_exact (size )
12490
125- def send_bytes (self , data : bytes , data_type : SocketDataType ):
91+ def send_bytes (self , data : bytes ):
12692 assert self .is_connected , "Connection has not been established"
12793
12894 size = len (data )
12995 if size > MAX_SIZE_2_BYTES :
130- self .logger .error (
131- "Couldn't send %s message because it was too big!" , data_type .name
132- )
96+ self .logger .error ("Couldn't send message because it was too big!" )
13397 return
13498
135- message = self ._int_to_bytes (data_type ) + self . _int_to_bytes ( size ) + data
99+ message = self ._int_to_bytes (size ) + data
136100 self .socket .sendall (message )
137101
138- def send_init_complete (self ):
139- self .send_bytes (bytes (), SocketDataType .INIT_COMPLETE )
140-
141- def send_set_loadout (self , set_loadout : flat .SetLoadout ):
142- self .send_bytes (set_loadout .pack (), SocketDataType .SET_LOADOUT )
143-
144- def send_match_comm (self , match_comm : flat .MatchComm ):
145- self .send_bytes (match_comm .pack (), SocketDataType .MATCH_COMMUNICATION )
146-
147- def send_player_input (self , player_input : flat .PlayerInput ):
148- self .send_bytes (player_input .pack (), SocketDataType .PLAYER_INPUT )
149-
150- def send_game_state (self , game_state : flat .DesiredGameState ):
151- self .send_bytes (game_state .pack (), SocketDataType .DESIRED_GAME_STATE )
152-
153- def send_render_group (self , render_group : flat .RenderGroup ):
154- self .send_bytes (render_group .pack (), SocketDataType .RENDER_GROUP )
102+ def send_msg (
103+ self ,
104+ msg : (
105+ flat .DisconnectSignal
106+ | flat .StartCommand
107+ | flat .MatchConfiguration
108+ | flat .PlayerInput
109+ | flat .DesiredGameState
110+ | flat .RenderGroup
111+ | flat .RemoveRenderGroup
112+ | flat .MatchComm
113+ | flat .ConnectionSettings
114+ | flat .StopCommand
115+ | flat .SetLoadout
116+ | flat .InitComplete
117+ ),
118+ ):
119+ self .send_bytes (flat .InterfacePacket (msg ).pack ())
155120
156121 def remove_render_group (self , group_id : int ):
157- flatbuffer = flat .RemoveRenderGroup (group_id ).pack ()
158- self .send_bytes (flatbuffer , SocketDataType .REMOVE_RENDER_GROUP )
122+ self .send_msg (flat .RemoveRenderGroup (group_id ))
159123
160124 def stop_match (self , shutdown_server : bool = False ):
161- flatbuffer = flat .StopCommand (shutdown_server ).pack ()
162- self .send_bytes (flatbuffer , SocketDataType .STOP_COMMAND )
125+ self .send_msg (flat .StopCommand (shutdown_server ))
163126
164127 def start_match (self , match_config : Path | flat .MatchConfiguration ):
165128 self .logger .info ("Python interface is attempting to start match..." )
166129
167130 match match_config :
168131 case Path () as path :
169132 string_path = str (path .absolute ().resolve ())
170- flatbuffer = flat .StartCommand (string_path ).pack ()
171- flat_type = SocketDataType .START_COMMAND
133+ flatbuffer = flat .StartCommand (string_path )
172134 case flat .MatchConfiguration () as settings :
173- flatbuffer = settings .pack ()
174- flat_type = SocketDataType .MATCH_CONFIGURATION
135+ flatbuffer = settings
175136 case _:
176137 raise ValueError (
177138 "Expected MatchSettings or path to match settings toml file"
178139 )
179140
180- self .send_bytes (flatbuffer , flat_type )
141+ self .send_msg (flatbuffer )
181142
182143 def connect (
183144 self ,
@@ -242,13 +203,14 @@ def connect(
242203 for handler in self .on_connect_handlers :
243204 handler ()
244205
245- flatbuffer = flat .ConnectionSettings (
246- agent_id = self .agent_id ,
247- wants_ball_predictions = wants_ball_predictions ,
248- wants_comms = wants_match_communications ,
249- close_between_matches = close_between_matches ,
250- ).pack ()
251- self .send_bytes (flatbuffer , SocketDataType .CONNECTION_SETTINGS )
206+ self .send_msg (
207+ flat .ConnectionSettings (
208+ agent_id = self .agent_id ,
209+ wants_ball_predictions = wants_ball_predictions ,
210+ wants_comms = wants_match_communications ,
211+ close_between_matches = close_between_matches ,
212+ )
213+ )
252214
253215 def run (self , * , background_thread : bool = False ):
254216 """
@@ -286,16 +248,14 @@ def handle_incoming_messages(self, blocking: bool = False) -> MsgHandlingResult:
286248 return self .handle_incoming_message (incoming_message )
287249 except flat .InvalidFlatbuffer as e :
288250 self .logger .error (
289- "Error while unpacking message of type %s (%s bytes): %s" ,
290- incoming_message .type .name ,
291- len (incoming_message .data ),
251+ "Error while unpacking message (%s bytes): %s" ,
252+ len (incoming_message ),
292253 e ,
293254 )
294255 return MsgHandlingResult .TERMINATED
295256 except Exception as e :
296257 self .logger .error (
297- "Unexpected error while handling message of type %s: %s" ,
298- incoming_message .type .name ,
258+ "Unexpected error while handling message of type: %s" ,
299259 e ,
300260 )
301261 return MsgHandlingResult .TERMINATED
@@ -306,56 +266,43 @@ def handle_incoming_messages(self, blocking: bool = False) -> MsgHandlingResult:
306266 self .logger .error ("SocketRelay disconnected unexpectedly!" )
307267 return MsgHandlingResult .TERMINATED
308268
309- def handle_incoming_message (
310- self , incoming_message : SocketMessage
311- ) -> MsgHandlingResult :
269+ def handle_incoming_message (self , incoming_message : bytes ) -> MsgHandlingResult :
312270 """
313271 Handles a messages by passing it to the relevant handlers.
314272 Returns True if the message was NOT a shutdown request (i.e. NONE).
315273 """
316274
275+ flatbuffer = flat .CorePacket .unpack (incoming_message ).message
276+
317277 for raw_handler in self .raw_handlers :
318- raw_handler (incoming_message )
278+ raw_handler (flatbuffer )
319279
320- match incoming_message . type :
321- case SocketDataType . NONE :
280+ match flatbuffer . item :
281+ case flat . DisconnectSignal () :
322282 return MsgHandlingResult .TERMINATED
323- case SocketDataType .GAME_PACKET :
324- if len (self .packet_handlers ) > 0 :
325- packet = flat .GamePacket .unpack (incoming_message .data )
326- for handler in self .packet_handlers :
327- handler (packet )
328- case SocketDataType .FIELD_INFO :
329- if len (self .field_info_handlers ) > 0 :
330- field_info = flat .FieldInfo .unpack (incoming_message .data )
331- for handler in self .field_info_handlers :
332- handler (field_info )
333- case SocketDataType .MATCH_CONFIGURATION :
334- if len (self .match_config_handlers ) > 0 :
335- match_settings = flat .MatchConfiguration .unpack (
336- incoming_message .data
337- )
338- for handler in self .match_config_handlers :
339- handler (match_settings )
340- case SocketDataType .MATCH_COMMUNICATION :
341- if len (self .match_comm_handlers ) > 0 :
342- match_comm = flat .MatchComm .unpack (incoming_message .data )
343- for handler in self .match_comm_handlers :
344- handler (match_comm )
345- case SocketDataType .BALL_PREDICTION :
346- if len (self .ball_prediction_handlers ) > 0 :
347- self ._ball_pred .unpack_with (incoming_message .data )
348- for handler in self .ball_prediction_handlers :
349- handler (self ._ball_pred )
350- case SocketDataType .CONTROLLABLE_TEAM_INFO :
351- if len (self .controllable_team_info_handlers ) > 0 :
352- player_mappings = flat .ControllableTeamInfo .unpack (
353- incoming_message .data
354- )
355- for handler in self .controllable_team_info_handlers :
356- handler (player_mappings )
283+ case flat .GamePacket () as packet :
284+ for handler in self .packet_handlers :
285+ handler (packet )
286+ case flat .FieldInfo () as field_info :
287+ for handler in self .field_info_handlers :
288+ handler (field_info )
289+ case flat .MatchConfiguration () as match_config :
290+ for handler in self .match_config_handlers :
291+ handler (match_config )
292+ case flat .MatchComm () as match_comm :
293+ for handler in self .match_comm_handlers :
294+ handler (match_comm )
295+ case flat .BallPrediction () as ball_prediction :
296+ for handler in self .ball_prediction_handlers :
297+ handler (ball_prediction )
298+ case flat .ControllableTeamInfo () as controllable_team_info :
299+ for handler in self .controllable_team_info_handlers :
300+ handler (controllable_team_info )
357301 case _:
358- pass
302+ self .logger .warning (
303+ "Received unknown message type: %s" ,
304+ type (flatbuffer .item ).__name__ ,
305+ )
359306
360307 return MsgHandlingResult .MORE_MSGS_QUEUED
361308
@@ -364,7 +311,7 @@ def disconnect(self):
364311 self .logger .warning ("Asked to disconnect but was already disconnected." )
365312 return
366313
367- self .send_bytes ( bytes ([ 1 ]), SocketDataType . NONE )
314+ self .send_msg ( flat . DisconnectSignal () )
368315 timeout = 5.0
369316 while self ._running and timeout > 0 :
370317 time .sleep (0.1 )
0 commit comments