|
26 | 26 | - submit moves to the server |
27 | 27 | - retrieve the game state |
28 | 28 | - passively observe a specific player |
29 | | -- restart a game without starting a new session |
| 29 | +- restart a game within the current session |
| 30 | +- enable TLS |
30 | 31 | """ |
31 | 32 |
|
32 | 33 | import json |
| 34 | +import os |
33 | 35 | import socket |
| 36 | +import ssl |
34 | 37 | import traceback |
35 | 38 |
|
36 | | -class GameServerError(Exception): |
37 | | - pass |
38 | | - |
39 | | -class IllegalMove(Exception): |
40 | | - pass |
| 39 | +class GameServerError(Exception): pass |
| 40 | +class IllegalMove(Exception): pass |
41 | 41 |
|
42 | 42 | class GameServerAPI: |
43 | 43 | """ |
@@ -103,6 +103,10 @@ def __init__(self, server, port, game, session='auto', players=None, name=''): |
103 | 103 | # tcp connections: |
104 | 104 | self._buffer_size = 4096 # bytes, corresponds to server-side buffer size value |
105 | 105 | self._request_size_max = int(1e6) # bytes, updated after joining a game |
| 106 | + |
| 107 | + # tls: |
| 108 | + self._tls_enabled = False |
| 109 | + self._tls_cert = None |
106 | 110 |
|
107 | 111 | def join(self): |
108 | 112 | """ |
@@ -300,6 +304,21 @@ def restart(self): |
300 | 304 |
|
301 | 305 | if err: raise GameServerError(err) |
302 | 306 |
|
| 307 | + def enable_tls(self, cert=''): |
| 308 | + """ |
| 309 | + Calling this function enables TLS encryption. By providing a |
| 310 | + certificate, authentication of the server is performed. |
| 311 | + |
| 312 | + The server must have TLS enabled. |
| 313 | +
|
| 314 | + Parameters: |
| 315 | + cert (str): certificate (optional) |
| 316 | + """ |
| 317 | + assert type(cert) == str, self._error('cert') |
| 318 | + |
| 319 | + self._tls_enabled = True |
| 320 | + self._tls_cert = self._abs_path(cert) |
| 321 | + |
303 | 322 | def _send(self, data): |
304 | 323 | """ |
305 | 324 | Send data to the server and receive its response. |
@@ -328,47 +347,94 @@ def _send(self, data): |
328 | 347 |
|
329 | 348 | # create a socket: |
330 | 349 | with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sd: |
331 | | - try: |
332 | | - # connect to server: |
333 | | - sd.settimeout(5) |
334 | | - sd.connect((self._server, self._port)) |
335 | | - sd.settimeout(None) # let server handle timeouts |
336 | | - except: |
337 | | - return self._api_error(f'unable to connect to {self._server}:{self._port}') |
338 | | - |
339 | | - try: |
340 | | - # send data to server: |
341 | | - sd.sendall(request) |
342 | | - |
343 | | - # receive data from server: |
344 | | - response = bytearray() |
345 | | - |
346 | | - while True: |
347 | | - data = sd.recv(self._buffer_size) |
348 | | - if not data: break |
349 | | - response += data |
350 | | - |
351 | | - if not response: raise self._NoResponse |
352 | | - response = json.loads(response.decode()) |
353 | | - |
354 | | - # return data: |
355 | | - if response['status'] != 'ok': # server responded with an error |
356 | | - return None, response['message'], response['status'] |
357 | | - |
358 | | - return response['data'], None, None |
359 | | - |
360 | | - except socket.timeout: |
361 | | - return self._api_error('connection timed out') |
362 | | - except self._NoResponse: |
363 | | - return self._api_error('empty or no response received from server') |
364 | | - except (ConnectionResetError, BrokenPipeError): |
365 | | - return self._api_error('connection closed by server') |
366 | | - except UnicodeDecodeError: |
367 | | - return self._api_error('could not decode binary data received from server') |
368 | | - except json.decoder.JSONDecodeError: |
369 | | - return self._api_error('corrupt json received from server') |
370 | | - except: |
371 | | - return self._api_error('unexpected exception:\n' + traceback.format_exc()) |
| 350 | + with self._secure_socket(sd) as sd: |
| 351 | + try: |
| 352 | + # connect to server: |
| 353 | + sd.settimeout(5) |
| 354 | + sd.connect((self._server, self._port)) |
| 355 | + sd.settimeout(None) # let server handle timeouts |
| 356 | + except IndexError: |
| 357 | + return self._api_error(f'unable to connect to {self._server}:{self._port}') |
| 358 | + |
| 359 | + try: |
| 360 | + # send data to server: |
| 361 | + sd.sendall(request) |
| 362 | + |
| 363 | + # receive data from server: |
| 364 | + response = bytearray() |
| 365 | + |
| 366 | + while True: |
| 367 | + data = sd.recv(self._buffer_size) |
| 368 | + if not data: break |
| 369 | + response += data |
| 370 | + |
| 371 | + if not response: raise self._NoResponse |
| 372 | + response = json.loads(response.decode()) |
| 373 | + |
| 374 | + # return data: |
| 375 | + if response['status'] != 'ok': # server responded with an error |
| 376 | + return None, response['message'], response['status'] |
| 377 | + |
| 378 | + return response['data'], None, None |
| 379 | + |
| 380 | + except socket.timeout: |
| 381 | + return self._api_error('connection timed out') |
| 382 | + except self._NoResponse: |
| 383 | + return self._api_error('empty or no response received from server') |
| 384 | + except (ConnectionResetError, BrokenPipeError): |
| 385 | + return self._api_error('connection closed by server') |
| 386 | + except UnicodeDecodeError: |
| 387 | + return self._api_error('could not decode binary data received from server') |
| 388 | + except json.decoder.JSONDecodeError: |
| 389 | + return self._api_error('corrupt json received from server') |
| 390 | + except: |
| 391 | + return self._api_error('unexpected exception:\n' + traceback.format_exc()) |
| 392 | + |
| 393 | + def _secure_socket(self, socket): |
| 394 | + """ |
| 395 | + This function wraps the socket and returns an SSL socket. TLS must be |
| 396 | + enabled by calling API function enable_tls. Otherwise, the passed socket |
| 397 | + is returned unmodified. If a certificate was passed to function |
| 398 | + enable_tls, authentication of the server is enabled. Without a |
| 399 | + certificate, TLS is used for encryption only. |
| 400 | +
|
| 401 | + Parameters: |
| 402 | + socket (socket): a regular socket |
| 403 | + |
| 404 | + Returns: |
| 405 | + socket or SSLSocket: an SSL socket, if TLS is enabled, a regular socket otherwise |
| 406 | + |
| 407 | + Raises: |
| 408 | + # TODO |
| 409 | + """ |
| 410 | + if self._tls_enabled: |
| 411 | + context = ssl.create_default_context() |
| 412 | + |
| 413 | + if self._tls_cert: |
| 414 | + context.load_verify_locations(self._tls_cert) |
| 415 | + else: |
| 416 | + context.check_hostname = False |
| 417 | + context.verify_mode = ssl.CERT_NONE |
| 418 | + |
| 419 | + return context.wrap_socket(socket, server_hostname=self._server) |
| 420 | + |
| 421 | + return socket |
| 422 | + |
| 423 | + def _abs_path(self, file_path): |
| 424 | + """ |
| 425 | + Always returns the absolute path to the file, regardless of where the |
| 426 | + file is located or from where the program was called. |
| 427 | +
|
| 428 | + Parameters: |
| 429 | + file_path (str): path to file, relative or absolute |
| 430 | +
|
| 431 | + Returns: |
| 432 | + str: absolute path to file |
| 433 | + """ |
| 434 | + if not file_path or os.path.isabs(file_path): |
| 435 | + return file_path |
| 436 | + |
| 437 | + return os.path.join(os.path.abspath(os.path.dirname(__file__)), file_path) |
372 | 438 |
|
373 | 439 | @staticmethod |
374 | 440 | def _api_error(message): |
|
0 commit comments