diff --git a/scripts/unasync b/scripts/unasync index 8ae754e..0210b4e 100755 --- a/scripts/unasync +++ b/scripts/unasync @@ -11,6 +11,7 @@ unasync.unasync_files( "src/ahttpx/_quickstart.py", "src/ahttpx/_response.py", "src/ahttpx/_request.py", + "src/ahttpx/_server.py", "src/ahttpx/_streams.py", "src/ahttpx/_urlencode.py", "src/ahttpx/_urlparse.py", diff --git a/src/ahttpx/__init__.py b/src/ahttpx/__init__.py index 532cf2c..f5bfe85 100644 --- a/src/ahttpx/__init__.py +++ b/src/ahttpx/__init__.py @@ -7,7 +7,7 @@ from ._response import * # Response from ._request import * # Request from ._streams import * # ByteStream, IterByteStream, FileStream, Stream -from ._server import * # serve_http, serve_tcp +from ._server import * # serve_http from ._urlencode import * # quote, unquote, urldecode, urlencode from ._urls import * # QueryParams, URL @@ -38,7 +38,6 @@ "Response", "Request", "serve_http", - "serve_tcp", "Stream", "Text", "timeout", diff --git a/src/ahttpx/_client.py b/src/ahttpx/_client.py index a7c1615..5f61c4f 100644 --- a/src/ahttpx/_client.py +++ b/src/ahttpx/_client.py @@ -10,7 +10,7 @@ from ._streams import Stream from ._urls import URL -__all__ = ["Client", "Content"] +__all__ = ["Client"] class Client: diff --git a/src/ahttpx/_network.py b/src/ahttpx/_network.py index 691c6a8..10eb864 100644 --- a/src/ahttpx/_network.py +++ b/src/ahttpx/_network.py @@ -1,7 +1,7 @@ import asyncio import ssl import types - +import typing __all__ = ["NetworkBackend", "NetworkStream", "timeout"] @@ -59,6 +59,26 @@ async def __aexit__( await self.close() +class NetworkServer: + def __init__(self, host: str, port: int, server: asyncio.Server): + self.host = host + self.port = port + self._server = server + + # Context managed usage... + async def __aenter__(self) -> "NetworkServer": + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ): + self._server.close() + await self._server.wait_closed() + + class NetworkBackend: async def connect(self, host: str, port: int) -> NetworkStream: """ @@ -67,7 +87,16 @@ async def connect(self, host: str, port: int) -> NetworkStream: reader, writer = await asyncio.open_connection(host, port) return NetworkStream(reader, writer, address=f"{host}:{port}") + async def serve(self, host: str, port: int, handler: typing.Callable[[NetworkStream], None]) -> NetworkServer: + async def callback(reader, writer): + stream = NetworkStream(reader, writer) + await handler(stream) + + server = await asyncio.start_server(callback, host, port) + return NetworkServer(host, port, server) + Semaphore = asyncio.Semaphore Lock = asyncio.Lock timeout = asyncio.timeout +sleep = asyncio.sleep diff --git a/src/ahttpx/_server.py b/src/ahttpx/_server.py index a612667..edafaef 100644 --- a/src/ahttpx/_server.py +++ b/src/ahttpx/_server.py @@ -1,14 +1,153 @@ -# TODO... +import contextlib +import logging +import time + +import h11 + +from ._content import Text +from ._request import Request +from ._response import Response +from ._network import NetworkBackend, sleep +from ._streams import IterByteStream __all__ = [ "serve_http", - "serve_tcp", ] +logger = logging.getLogger("httpx.server") + -async def serve_http(): +class ConnectionClosed(Exception): pass -async def serve_tcp(): - pass +class HTTPConnection: + def __init__(self, stream, endpoint): + self._stream = stream + self._endpoint = endpoint + self._state = h11.Connection(our_role=h11.SERVER) + self._keepalive_duration = 5.0 + self._idle_expiry = time.monotonic() + self._keepalive_duration + + # API entry points... + async def handle_requests(self): + try: + method, url, headers = await self._recv_head() + stream = IterByteStream(self._recv_body()) + # TODO: Handle endpoint exceptions + try: + request = Request(method, url, headers=headers, content=stream) + response = await self._endpoint(request) + except Exception as exc: + logger.error("Internal Server Error", exc_info=True) + content = Text("Internal Server Error") + response = Response(code=500, content=content) + await self._send_head(response) + await self._send_body(response) + else: + try: + await self._send_head(response) + await self._send_body(response) + except Exception as exc: + logger.error("Internal Server Error", exc_info=True) + finally: + status_line = f"{request.method} {request.url.target} [{response.code} {response.reason_phrase}]" + logger.info(status_line) + except ConnectionClosed: + pass + finally: + await self._cycle_complete() + + async def close(self): + if self._state.our_state in (h11.DONE, h11.IDLE, h11.MUST_CLOSE): + event = h11.ConnectionClosed() + self._state.send(event) + + await self._stream.close() + + # Receive the request... + async def _recv_head(self) -> tuple[str, str, list[tuple[str, str]]]: + while True: + event = await self._recv_event() + if isinstance(event, h11.Request): + method = event.method.decode('ascii') + target = event.target.decode('ascii') + headers = [ + (k.decode('latin-1'), v.decode('latin-1')) + for k, v in event.headers.raw_items() + ] + return (method, target, headers) + elif isinstance(event, h11.ConnectionClosed): + raise ConnectionClosed() + + async def _recv_body(self): + while True: + event = await self._recv_event() + if isinstance(event, h11.Data): + yield bytes(event.data) + elif isinstance(event, (h11.EndOfMessage, h11.PAUSED)): + break + + async def _recv_event(self) -> h11.Event | type[h11.PAUSED]: + while True: + event = self._state.next_event() + + if event is h11.NEED_DATA: + data = await self._stream.read() + self._state.receive_data(data) + else: + return event # type: ignore[return-value] + + # Return the response... + async def _send_head(self, response: Response): + event = h11.Response( + status_code=response.code, + headers=list(response.headers.items()) + ) + await self._send_event(event) + + async def _send_body(self, response: Response): + async for data in response.stream: + await self._send_event(h11.Data(data=data)) + await self._send_event(h11.EndOfMessage()) + + async def _send_event(self, event: h11.Event) -> None: + data = self._state.send(event) + if data is not None: + await self._stream.write(data) + + # Start it all over again... + async def _cycle_complete(self): + if self._state.our_state is h11.DONE and self._state.their_state is h11.DONE: + await self._state.start_next_cycle() + self._idle_expiry = time.monotonic() + self._keepalive_duration + else: + await self.close() + + +class HTTPServer: + def __init__(self, host, port): + self.url = f"http://{host}:{port}/" + + async def wait(self): + while(True): + await sleep(1) + + +@contextlib.asynccontextmanager +async def serve_http(endpoint): + async def handler(stream): + connection = HTTPConnection(stream, endpoint) + await connection.handle_requests() + + logging.basicConfig( + format="%(levelname)s [%(asctime)s] %(name)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=logging.DEBUG + ) + + backend = NetworkBackend() + async with await backend.serve("127.0.0.1", 8080, handler) as server: + server = HTTPServer(server.host, server.port) + logger.info(f"Serving on {server.url}") + yield server diff --git a/src/httpx/__init__.py b/src/httpx/__init__.py index 532cf2c..f5bfe85 100644 --- a/src/httpx/__init__.py +++ b/src/httpx/__init__.py @@ -7,7 +7,7 @@ from ._response import * # Response from ._request import * # Request from ._streams import * # ByteStream, IterByteStream, FileStream, Stream -from ._server import * # serve_http, serve_tcp +from ._server import * # serve_http from ._urlencode import * # quote, unquote, urldecode, urlencode from ._urls import * # QueryParams, URL @@ -38,7 +38,6 @@ "Response", "Request", "serve_http", - "serve_tcp", "Stream", "Text", "timeout", diff --git a/src/httpx/_client.py b/src/httpx/_client.py index d808ff1..7dcfd8d 100644 --- a/src/httpx/_client.py +++ b/src/httpx/_client.py @@ -10,7 +10,7 @@ from ._streams import Stream from ._urls import URL -__all__ = ["Client", "Content"] +__all__ = ["Client"] class Client: diff --git a/src/httpx/_network.py b/src/httpx/_network.py index 2e242cd..59abb44 100644 --- a/src/httpx/_network.py +++ b/src/httpx/_network.py @@ -1,5 +1,7 @@ +import concurrent.futures import contextlib import contextvars +import select import socket import ssl import threading @@ -44,14 +46,22 @@ def get_current_timeout() -> float | None: class NetworkStream: - def __init__(self, sock: socket.socket) -> None: + def __init__(self, sock: socket.socket, address: tuple[str, int]) -> None: peername = sock.getpeername() self._socket = sock - self._address = f"{peername[0]}:{peername[1]}" + self._address = address self._is_tls = False self._is_closed = False + @property + def host(self) -> str: + return self._address[0] + + @property + def port(self) -> int: + return self._address[1] + def read(self, max_bytes: int = 64 * 1024) -> bytes: timeout = get_current_timeout() self._socket.settimeout(timeout) @@ -79,7 +89,7 @@ def __repr__(self): description = "" description += " TLS" if self._is_tls else "" description += " CLOSED" if self._is_closed else "" - return f"" + return f"" def __del__(self): if not self._is_closed: @@ -98,6 +108,90 @@ def __exit__( self.close() +class NetworkListener: + def __init__(self, sock: socket.socket, address: tuple[str, int]) -> None: + self._server_socket = sock + self._address = address + self._is_closed = False + + @property + def host(self): + return self._address[0] + + @property + def port(self): + return self._address[1] + + def accept(self) -> NetworkStream | None: + """ + Blocks until an incoming connection is accepted, and returns the NetworkStream. + Stops blocking and returns `None` once the listener is closed. + """ + while not self._is_closed: + r, _, _ = select.select([self._server_socket], [], [], 3) + if r: + sock, address = self._server_socket.accept() + return NetworkStream(sock, address) + return None + + def close(self): + self._is_closed = True + self._server_socket.close() + + def __del__(self): + if not self._is_closed: + import warnings + warnings.warn("NetworkListener was garbage collected without being closed.") + + def __enter__(self) -> "NetworkListener": + return self + + def __exit__( + self, + exc_type: type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: types.TracebackType | None = None, + ): + self.close() + + +class NetworkServer: + def __init__(self, listener: NetworkListener, handler: typing.Callable[[NetworkStream], None]) -> None: + self.listener = listener + self.handler = handler + self._max_workers = 5 + self._executor = None + self._thread = None + self._streams = list[NetworkStream] + + @property + def host(self): + return self.listener.host + + @property + def port(self): + return self.listener.port + + def __enter__(self): + self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers) + self._executor.submit(self._serve) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.listener.close() + self._executor.shutdown(wait=True) + + def _serve(self): + while stream := self.listener.accept(): + self._executor.submit(self._handler, stream) + + def _handler(self, stream): + try: + self.handler(stream) + finally: + stream.close() + + class NetworkBackend: def connect(self, host: str, port: int) -> NetworkStream: """ @@ -106,7 +200,23 @@ def connect(self, host: str, port: int) -> NetworkStream: address = (host, port) timeout = get_current_timeout() sock = socket.create_connection(address, timeout=timeout) - return NetworkStream(sock) + return NetworkStream(sock, address) + + def listen(self, host: str, port: int) -> NetworkListener: + """ + List on the given address, returning a NetworkListener instance. + """ + address = (host, port) + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(address) + sock.listen(5) + sock.setblocking(False) + return NetworkListener(sock, address) + + def serve(self, host: str, port: int, handler: typing.Callable[[NetworkStream], None]) -> NetworkServer: + listener = self.listen(host, port) + return NetworkServer(listener, handler) def __repr__(self): return "" @@ -114,3 +224,4 @@ def __repr__(self): Semaphore = threading.Semaphore Lock = threading.Lock +sleep = time.sleep diff --git a/src/httpx/_server.py b/src/httpx/_server.py index 0d4266e..543829e 100644 --- a/src/httpx/_server.py +++ b/src/httpx/_server.py @@ -1,19 +1,17 @@ -import concurrent.futures import contextlib import logging -import select -import socket -import threading import time import h11 -import httpx +from ._content import Text +from ._request import Request +from ._response import Response +from ._network import NetworkBackend, sleep from ._streams import IterByteStream __all__ = [ "serve_http", - "serve_tcp", ] logger = logging.getLogger("httpx.server") @@ -38,12 +36,12 @@ def handle_requests(self): stream = IterByteStream(self._recv_body()) # TODO: Handle endpoint exceptions try: - request = httpx.Request(method, url, headers=headers, content=stream) + request = Request(method, url, headers=headers, content=stream) response = self._endpoint(request) except Exception as exc: logger.error("Internal Server Error", exc_info=True) - content = httpx.Text("Internal Server Error") - response = httpx.Response(code=500, content=content) + content = Text("Internal Server Error") + response = Response(code=500, content=content) self._send_head(response) self._send_body(response) else: @@ -101,14 +99,14 @@ def _recv_event(self) -> h11.Event | type[h11.PAUSED]: return event # type: ignore[return-value] # Return the response... - def _send_head(self, response: httpx.Response): + def _send_head(self, response: Response): event = h11.Response( status_code=response.code, headers=list(response.headers.items()) ) self._send_event(event) - def _send_body(self, response: httpx.Response): + def _send_body(self, response: Response): for data in response.stream: self._send_event(h11.Data(data=data)) self._send_event(h11.EndOfMessage()) @@ -133,59 +131,7 @@ def __init__(self, host, port): def wait(self): while(True): - time.sleep(1) - - -class TCPServer: - def __init__(self, handler, host: str = "127.0.0.1", port: int = 8080): - self.handler = handler - self.host = host - self.port = port - self._max_workers = 5 - self._server_socket = None - self._client_sockets: list[socket.socket] = [] - self._executor = None - self._thread = None - self._shutdown = threading.Event() - - def __enter__(self): - self._server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self._server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - self._server_socket.bind((self.host, self.port)) - self._server_socket.listen(5) - self._server_socket.setblocking(False) - - self._executor = concurrent.futures.ThreadPoolExecutor(max_workers=self._max_workers) - self._thread = threading.Thread(target=self._serve_loop, daemon=True) - self._thread.start() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self._shutdown.set() - self._thread.join() - self._server_socket.close() - for client_socket in list(self._client_sockets): - client_socket.close() - self._executor.shutdown(wait=True) - - def _serve_loop(self): - while not self._shutdown.is_set(): - readable, _, _ = select.select([self._server_socket], [], [], 0.1) - if readable: - try: - client_socket, _ = self._server_socket.accept() - self._executor.submit(self._handler, client_socket) - except socket.error as e: - pass - - def _handler(self, socket): - self._client_sockets.append(socket) - try: - stream = httpx.NetworkStream(socket) - self.handler(stream) - finally: - self._client_sockets.remove(socket) - stream.close() + sleep(1) @contextlib.contextmanager @@ -200,13 +146,8 @@ def handler(stream): level=logging.DEBUG ) - with TCPServer(handler) as server: + backend = NetworkBackend() + with backend.serve("127.0.0.1", 8080, handler) as server: server = HTTPServer(server.host, server.port) logger.info(f"Serving on {server.url}") yield server - - -@contextlib.contextmanager -def serve_tcp(handler): - with TCPServer(handler) as server: - yield server diff --git a/tests/test_network.py b/tests/test_network.py index 09fad74..e8236c5 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -9,7 +9,8 @@ def echo(stream): @pytest.fixture def server(): - with httpx.serve_tcp(echo) as server: + net = httpx.NetworkBackend() + with net.serve("127.0.0.1", 8080, echo) as server: yield server @@ -22,7 +23,7 @@ def test_network_backend_connect(server): net = httpx.NetworkBackend() stream = net.connect(server.host, server.port) try: - assert repr(stream) == f"" + assert repr(stream) == f"" stream.write(b"Hello, world.") content = stream.read() assert content == b"Hello, world." @@ -36,7 +37,7 @@ def test_network_backend_context_managed(server): stream.write(b"Hello, world.") content = stream.read() assert content == b"Hello, world." - assert repr(stream) == f"" + assert repr(stream) == f"" def test_network_backend_timeout(server):