Skip to content

Commit e06fbcb

Browse files
authored
Merge pull request #83 from unwize/feature/websockets
Feature/websockets
2 parents 217beca + 94065f8 commit e06fbcb

3 files changed

Lines changed: 181 additions & 1 deletion

File tree

src/game/game_state_controller.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
are atomic).
88
"""
99

10+
import asyncio
1011
import dataclasses
1112

1213
from loguru import logger
@@ -42,6 +43,7 @@ class GameStateController:
4243
def __init__(self):
4344
self.state_device_stack: list[tuple[sd.StateDevice, StackState]] = []
4445
self.add_state_device(room.room_manager.get_room(cache.get_cache()["player_location"]))
46+
self.input_lock: asyncio.Lock = asyncio.Lock()
4547

4648
# Built-ins
4749

@@ -121,6 +123,23 @@ def deliver_input(self, user_input: any) -> bool:
121123

122124
return False
123125

126+
async def deliver_input_async(self, user_input: any) -> bool:
127+
"""
128+
Deliver the user's input to the top sd.StateDevice. Returns True if the
129+
device accepts the input.
130+
131+
Args:
132+
user_input: Input that the user delivers to the service via the API
133+
134+
Returns: True if the input is accepted, False otherwise.
135+
"""
136+
async with self.input_lock:
137+
if self._get_state_device().validate_input(user_input):
138+
self._get_state_device().input(user_input)
139+
return True
140+
141+
return False
142+
124143
def add_state_device(self, device: sd.StateDevice) -> None:
125144
"""
126145
Appends a sd.StateDevice to the top of the state_device_stack

src/main.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import game
22

3-
from fastapi import FastAPI
3+
from fastapi import FastAPI, WebSocket
44
from loguru import logger
55

66
from timeit import default_timer
@@ -27,6 +27,34 @@ def root_put(user_input: int | str):
2727
return r
2828

2929

30+
@tx_engine.websocket("/")
31+
async def websocket_endpoint(websocket: WebSocket):
32+
"""
33+
An interactive websocket endpoint. Used to communicate in real time with clients instead of in blocking-series.
34+
Functionally equivalent to get/put.
35+
36+
Args:
37+
websocket: The active websocket object used to manage the connection
38+
39+
Returns: None
40+
"""
41+
42+
await websocket.accept()
43+
while True:
44+
data = await websocket.receive_json()
45+
start = default_timer()
46+
r = game.state_device_controller.deliver_input(data)
47+
duration = default_timer() - start
48+
logger.info(f"Completed input submission in {duration}s")
49+
50+
start = default_timer()
51+
r = game.state_device_controller.get_current_frame()
52+
duration = default_timer() - start
53+
logger.info(f"Completed state retrieval in {duration}s")
54+
55+
await websocket.send_text(r.model_dump_json())
56+
57+
3058
@tx_engine.get("/cache")
3159
def cache(cache_path: str):
3260
from game.cache import get_cache

src/viewer/ws_viewer.py

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import json
2+
import os
3+
import asyncio
4+
from websockets.asyncio.client import connect
5+
6+
from loguru import logger
7+
from rich import print
8+
9+
10+
def formatting_to_tags(tags: list[str], opening_tag: bool = None, closing_tag: bool = None) -> str:
11+
buf = ""
12+
if opening_tag:
13+
for tag in tags:
14+
buf = buf + f"[{tag}]"
15+
16+
elif closing_tag:
17+
for tag in tags:
18+
buf = buf + f"[/{tag}]"
19+
20+
return buf
21+
22+
23+
def format_string(content: str, tags: list[str]) -> str:
24+
return formatting_to_tags(tags, opening_tag=True) + content + formatting_to_tags(tags, closing_tag=True)
25+
26+
27+
def parse_content(content: list) -> str:
28+
buf = ""
29+
for element in content:
30+
if type(element) is str:
31+
buf = buf + element
32+
elif type(element) is dict:
33+
buf = (
34+
buf
35+
+ formatting_to_tags(element["formatting"], opening_tag=True)
36+
+ element["value"]
37+
+ formatting_to_tags(element["formatting"], closing_tag=True)
38+
)
39+
return buf
40+
41+
42+
class WebsocketViewer:
43+
"""
44+
A primitive TXEngine client built for websocket connections.
45+
"""
46+
47+
def __init__(self):
48+
self._ip = input("Enter ip address (default: localhost)")
49+
if self._ip.strip() == "":
50+
self._ip = "localhost"
51+
52+
@classmethod
53+
def clear(cls):
54+
os.system("cls")
55+
56+
def get_text_header(self, tx_engine_response: dict) -> str:
57+
input_type = (
58+
tx_engine_response["input_type"]
59+
if type(tx_engine_response["input_type"]) is str
60+
else tx_engine_response["input_type"][0]
61+
)
62+
input_range = tx_engine_response["input_range"]
63+
64+
formatting = ["italic"]
65+
66+
if input_type == "int":
67+
hdr = f"Enter a number between ({input_range['min']} and {input_range['max']}):"
68+
69+
elif input_type == "none":
70+
hdr = "Press any key:"
71+
72+
elif input_type == "str":
73+
hdr = "Enter a string: "
74+
75+
elif input_type == "affirmative":
76+
hdr = "Enter y, n, yes, or no:"
77+
elif input_type == "any":
78+
hdr = "Press any key..."
79+
else:
80+
logger.error(f"Unexpected input type: {input_type}")
81+
logger.debug(f"Failed frame: {str(tx_engine_response)}")
82+
raise ValueError(f"Unexpected input type: {input_type}")
83+
84+
return format_string(hdr, formatting)
85+
86+
def display(self, tx_engine_response: dict):
87+
"""
88+
Primitively print GET results
89+
"""
90+
self.clear()
91+
92+
def entity_to_str(entity_dict: dict[str, any]) -> str:
93+
entity_name = entity_dict["name"]
94+
primary_resource_name = entity_dict["primary_resource_name"]
95+
primary_resource_value = entity_dict["primary_resource_val"]
96+
primary_resource_max = entity_dict["primary_resource_max"]
97+
return f"{entity_name}\n{primary_resource_name}]: [{primary_resource_value}/{primary_resource_max}]"
98+
99+
if "enemies" in tx_engine_response["components"]:
100+
print("ENEMIES")
101+
for enemy in tx_engine_response["components"]["enemies"]:
102+
print(entity_to_str(enemy))
103+
104+
if "allies" in tx_engine_response["components"]:
105+
print("ALLIES")
106+
for ally in tx_engine_response["components"]["allies"]:
107+
print(entity_to_str(ally))
108+
109+
print(parse_content(tx_engine_response["components"]["content"]))
110+
111+
if "options" in tx_engine_response["components"] and type(tx_engine_response["components"]["options"]) is list:
112+
for idx, opt in enumerate(tx_engine_response["components"]["options"]):
113+
print(f"[{idx}] {parse_content(opt)}")
114+
115+
print(self.get_text_header(tx_engine_response))
116+
117+
async def client(self) -> None:
118+
async with connect(f"ws://{self._ip}:8000") as websocket:
119+
await websocket.send("{}") # Ping to get a baseline response
120+
response = await websocket.recv()
121+
while True:
122+
self.clear()
123+
self.display(json.loads(response))
124+
user_input = input()
125+
if user_input.strip() == "":
126+
user_input = "{}"
127+
await websocket.send(user_input)
128+
response = await websocket.recv()
129+
130+
131+
if __name__ == "__main__":
132+
client = WebsocketViewer()
133+
asyncio.run(client.client())

0 commit comments

Comments
 (0)