diff --git a/src/ahttpx/_parsers.py b/src/ahttpx/_parsers.py index 6ac2c33..182df44 100644 --- a/src/ahttpx/_parsers.py +++ b/src/ahttpx/_parsers.py @@ -1,5 +1,6 @@ import enum import io +import time import typing from ._streams import Stream @@ -89,6 +90,7 @@ def __init__(self, stream: Stream, mode: str) -> None: self.stream = stream self.parser = ReadAheadParser(stream) self.mode = {'CLIENT': Mode.CLIENT, 'SERVER': Mode.SERVER}[mode] + self.keepalive_duration = 5.0 # Track state... if self.mode == Mode.CLIENT: @@ -107,6 +109,7 @@ def __init__(self, stream: Stream, mode: str) -> None: # Track connection keep alive... self.send_keep_alive = True self.recv_keep_alive = True + self.keepalive_until: float | None = None # Special states... self.processing_1xx = False @@ -119,6 +122,9 @@ async def send_method_line(self, method: bytes, target: bytes, protocol: bytes) Sending state will switch to SEND_HEADERS state. """ + # Scrub connection keepalive + self.keepalive_until = None + if self.send_state != State.SEND_METHOD_LINE: msg = f"Called 'send_method_line' in invalid state {self.send_state}" raise ProtocolError(msg) @@ -244,6 +250,9 @@ async def recv_method_line(self) -> tuple[bytes, bytes, bytes]: Receive state will switch to RECV_HEADERS. """ + # Scrub connection keepalive + self.keepalive_until = None + if self.recv_state != State.RECV_METHOD_LINE: msg = f"Called 'recv_method_line' in invalid state {self.recv_state}" raise ProtocolError(msg) @@ -409,6 +418,7 @@ async def complete(self): self.send_keep_alive = True self.recv_keep_alive = True self.processing_1xx = False + self.keepalive_until = time.monotonic() + self.keepalive_duration async def close(self): if self.send_state != State.CLOSED: @@ -425,6 +435,9 @@ def is_idle(self) -> bool: def is_closed(self) -> bool: return self.send_state == State.CLOSED + def keepalive_expired(self) -> bool: + return (self.keepalive_until is not None) and (time.monotonic() > self.keepalive_until) + def description(self) -> str: return { State.SEND_METHOD_LINE: "idle", @@ -439,10 +452,9 @@ def __repr__(self) -> str: class HTTPStream(Stream): - def __init__(self, parser: HTTPParser, callback: typing.Callable | None = None): + def __init__(self, parser: HTTPParser): self._parser = parser self._buffer = io.BytesIO() - self._callback = callback async def read(self, size=-1) -> bytes: sections = [] @@ -474,12 +486,8 @@ async def read(self, size=-1) -> bytes: return output async def close(self) -> None: - try: - self._buffer.close() - await self._parser.complete() - finally: - if self._callback is not None: - await self._callback() + self._buffer.close() + await self._parser.complete() class ReadAheadParser: diff --git a/src/ahttpx/_pool.py b/src/ahttpx/_pool.py index c4838da..911e76f 100644 --- a/src/ahttpx/_pool.py +++ b/src/ahttpx/_pool.py @@ -84,9 +84,8 @@ async def _get_connection(self, request: Request) -> "Connection": # Attempt to reuse an existing connection. url = request.url origin = URL(scheme=url.scheme, host=url.host, port=url.port) - now = time.monotonic() for conn in self._connections: - if conn.origin() == origin and conn.is_idle() and not conn.is_expired(now): + if conn.origin() == origin and conn.is_idle() and not conn.is_expired(): return conn # Or else create a new connection. @@ -102,7 +101,7 @@ async def _get_connection(self, request: Request) -> "Connection": async def _cleanup(self) -> None: now = time.monotonic() for conn in list(self._connections): - if conn.is_expired(now): + if conn.is_expired(): await conn.close() if conn.is_closed(): self._connections.remove(conn) @@ -142,8 +141,6 @@ class Connection(Transport): def __init__(self, stream: Stream, origin: URL | str): self._stream = stream self._origin = URL(origin) if not isinstance(origin, URL) else origin - self._keepalive_duration = 5.0 - self._idle_expiry = time.monotonic() + self._keepalive_duration self._request_lock = Lock() self._parser = HTTPParser(stream, mode='CLIENT') @@ -154,8 +151,8 @@ def origin(self) -> URL: def is_idle(self) -> bool: return self._parser.is_idle() - def is_expired(self, when: float) -> bool: - return self._parser.is_idle() and when > self._idle_expiry + def is_expired(self) -> bool: + return self._parser.is_idle() and self._parser.keepalive_expired() def is_closed(self) -> bool: return self._parser.is_closed() @@ -170,7 +167,7 @@ async def send(self, request: Request) -> Response: await self._send_head(request) await self._send_body(request) code, headers = await self._recv_head() - stream = HTTPStream(self._parser, callback=self._complete) + stream = HTTPStream(self._parser) # TODO... return Response(code, headers=headers, content=stream) # finally: @@ -233,9 +230,6 @@ async def _recv_body(self) -> bytes: return await self._parser.recv_body() # Request/response cycle complete... - async def _complete(self) -> None: - self._idle_expiry = time.monotonic() + self._keepalive_duration - async def _close(self) -> None: await self._parser.close() diff --git a/src/ahttpx/_server.py b/src/ahttpx/_server.py index 577d001..d30416e 100644 --- a/src/ahttpx/_server.py +++ b/src/ahttpx/_server.py @@ -24,15 +24,13 @@ def __init__(self, stream, endpoint): self._stream = stream self._endpoint = endpoint self._parser = HTTPParser(stream, mode='SERVER') - self._keepalive_duration = 5.0 - self._idle_expiry = time.monotonic() + self._keepalive_duration # API entry points... async def handle_requests(self): try: - while not await self._parser.recv_close(): + while not (self._parser.keepalive_expired() or await self._parser.recv_close()): method, url, headers = await self._recv_head() - stream = HTTPStream(self._parser, callback=self._complete) + stream = HTTPStream(self._parser) # TODO: Handle endpoint exceptions async with Request(method, url, headers=headers, content=stream) as request: try: @@ -82,10 +80,6 @@ async def _send_body(self, response: Response): await self._parser.send_body(data) await self._parser.send_body(b'') - # Start it all over again... - async def _complete(self): - self._idle_expiry = time.monotonic() + self._keepalive_duration - class HTTPServer: def __init__(self, host, port): diff --git a/src/httpx/_parsers.py b/src/httpx/_parsers.py index 415bfef..2c7f1af 100644 --- a/src/httpx/_parsers.py +++ b/src/httpx/_parsers.py @@ -1,5 +1,6 @@ import enum import io +import time import typing from ._streams import Stream @@ -89,6 +90,7 @@ def __init__(self, stream: Stream, mode: str) -> None: self.stream = stream self.parser = ReadAheadParser(stream) self.mode = {'CLIENT': Mode.CLIENT, 'SERVER': Mode.SERVER}[mode] + self.keepalive_duration = 5.0 # Track state... if self.mode == Mode.CLIENT: @@ -107,6 +109,7 @@ def __init__(self, stream: Stream, mode: str) -> None: # Track connection keep alive... self.send_keep_alive = True self.recv_keep_alive = True + self.keepalive_until: float | None = None # Special states... self.processing_1xx = False @@ -119,6 +122,9 @@ def send_method_line(self, method: bytes, target: bytes, protocol: bytes) -> Non Sending state will switch to SEND_HEADERS state. """ + # Scrub connection keepalive + self.keepalive_until = None + if self.send_state != State.SEND_METHOD_LINE: msg = f"Called 'send_method_line' in invalid state {self.send_state}" raise ProtocolError(msg) @@ -244,6 +250,9 @@ def recv_method_line(self) -> tuple[bytes, bytes, bytes]: Receive state will switch to RECV_HEADERS. """ + # Scrub connection keepalive + self.keepalive_until = None + if self.recv_state != State.RECV_METHOD_LINE: msg = f"Called 'recv_method_line' in invalid state {self.recv_state}" raise ProtocolError(msg) @@ -409,6 +418,7 @@ def complete(self): self.send_keep_alive = True self.recv_keep_alive = True self.processing_1xx = False + self.keepalive_until = time.monotonic() + self.keepalive_duration def close(self): if self.send_state != State.CLOSED: @@ -425,6 +435,9 @@ def is_idle(self) -> bool: def is_closed(self) -> bool: return self.send_state == State.CLOSED + def keepalive_expired(self) -> bool: + return (self.keepalive_until is not None) and (time.monotonic() > self.keepalive_until) + def description(self) -> str: return { State.SEND_METHOD_LINE: "idle", @@ -439,10 +452,9 @@ def __repr__(self) -> str: class HTTPStream(Stream): - def __init__(self, parser: HTTPParser, callback: typing.Callable | None = None): + def __init__(self, parser: HTTPParser): self._parser = parser self._buffer = io.BytesIO() - self._callback = callback def read(self, size=-1) -> bytes: sections = [] @@ -474,12 +486,8 @@ def read(self, size=-1) -> bytes: return output def close(self) -> None: - try: - self._buffer.close() - self._parser.complete() - finally: - if self._callback is not None: - self._callback() + self._buffer.close() + self._parser.complete() class ReadAheadParser: diff --git a/src/httpx/_pool.py b/src/httpx/_pool.py index 8959159..796c99b 100644 --- a/src/httpx/_pool.py +++ b/src/httpx/_pool.py @@ -84,9 +84,8 @@ def _get_connection(self, request: Request) -> "Connection": # Attempt to reuse an existing connection. url = request.url origin = URL(scheme=url.scheme, host=url.host, port=url.port) - now = time.monotonic() for conn in self._connections: - if conn.origin() == origin and conn.is_idle() and not conn.is_expired(now): + if conn.origin() == origin and conn.is_idle() and not conn.is_expired(): return conn # Or else create a new connection. @@ -102,7 +101,7 @@ def _get_connection(self, request: Request) -> "Connection": def _cleanup(self) -> None: now = time.monotonic() for conn in list(self._connections): - if conn.is_expired(now): + if conn.is_expired(): conn.close() if conn.is_closed(): self._connections.remove(conn) @@ -142,8 +141,6 @@ class Connection(Transport): def __init__(self, stream: Stream, origin: URL | str): self._stream = stream self._origin = URL(origin) if not isinstance(origin, URL) else origin - self._keepalive_duration = 5.0 - self._idle_expiry = time.monotonic() + self._keepalive_duration self._request_lock = Lock() self._parser = HTTPParser(stream, mode='CLIENT') @@ -154,8 +151,8 @@ def origin(self) -> URL: def is_idle(self) -> bool: return self._parser.is_idle() - def is_expired(self, when: float) -> bool: - return self._parser.is_idle() and when > self._idle_expiry + def is_expired(self) -> bool: + return self._parser.is_idle() and self._parser.keepalive_expired() def is_closed(self) -> bool: return self._parser.is_closed() @@ -170,7 +167,7 @@ def send(self, request: Request) -> Response: self._send_head(request) self._send_body(request) code, headers = self._recv_head() - stream = HTTPStream(self._parser, callback=self._complete) + stream = HTTPStream(self._parser) # TODO... return Response(code, headers=headers, content=stream) # finally: @@ -233,9 +230,6 @@ def _recv_body(self) -> bytes: return self._parser.recv_body() # Request/response cycle complete... - def _complete(self) -> None: - self._idle_expiry = time.monotonic() + self._keepalive_duration - def _close(self) -> None: self._parser.close() diff --git a/src/httpx/_server.py b/src/httpx/_server.py index 4f1ca3a..7ed426b 100644 --- a/src/httpx/_server.py +++ b/src/httpx/_server.py @@ -24,15 +24,13 @@ def __init__(self, stream, endpoint): self._stream = stream self._endpoint = endpoint self._parser = HTTPParser(stream, mode='SERVER') - self._keepalive_duration = 5.0 - self._idle_expiry = time.monotonic() + self._keepalive_duration # API entry points... def handle_requests(self): try: - while not self._parser.recv_close(): + while not (self._parser.keepalive_expired() or self._parser.recv_close()): method, url, headers = self._recv_head() - stream = HTTPStream(self._parser, callback=self._complete) + stream = HTTPStream(self._parser) # TODO: Handle endpoint exceptions with Request(method, url, headers=headers, content=stream) as request: try: @@ -82,10 +80,6 @@ def _send_body(self, response: Response): self._parser.send_body(data) self._parser.send_body(b'') - # Start it all over again... - def _complete(self): - self._idle_expiry = time.monotonic() + self._keepalive_duration - class HTTPServer: def __init__(self, host, port):