Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions src/ahttpx/_parsers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import enum
import io
import time
import typing

from ._streams import Stream
Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__(self, stream: Stream, mode: str) -> None:
self.stream = stream
self.parser = ReadAheadParser(stream)
self.mode = {'CLIENT': Mode.CLIENT, 'SERVER': Mode.SERVER}[mode]
self.keepalive_duration = 5.0

# Track state...
if self.mode == Mode.CLIENT:
Expand All @@ -107,6 +109,7 @@ def __init__(self, stream: Stream, mode: str) -> None:
# Track connection keep alive...
self.send_keep_alive = True
self.recv_keep_alive = True
self.keepalive_until: float | None = None

# Special states...
self.processing_1xx = False
Expand All @@ -119,6 +122,9 @@ async def send_method_line(self, method: bytes, target: bytes, protocol: bytes)

Sending state will switch to SEND_HEADERS state.
"""
# Scrub connection keepalive
self.keepalive_until = None

if self.send_state != State.SEND_METHOD_LINE:
msg = f"Called 'send_method_line' in invalid state {self.send_state}"
raise ProtocolError(msg)
Expand Down Expand Up @@ -244,6 +250,9 @@ async def recv_method_line(self) -> tuple[bytes, bytes, bytes]:

Receive state will switch to RECV_HEADERS.
"""
# Scrub connection keepalive
self.keepalive_until = None

if self.recv_state != State.RECV_METHOD_LINE:
msg = f"Called 'recv_method_line' in invalid state {self.recv_state}"
raise ProtocolError(msg)
Expand Down Expand Up @@ -409,6 +418,7 @@ async def complete(self):
self.send_keep_alive = True
self.recv_keep_alive = True
self.processing_1xx = False
self.keepalive_until = time.monotonic() + self.keepalive_duration

async def close(self):
if self.send_state != State.CLOSED:
Expand All @@ -425,6 +435,9 @@ def is_idle(self) -> bool:
def is_closed(self) -> bool:
return self.send_state == State.CLOSED

def keepalive_expired(self) -> bool:
return (self.keepalive_until is not None) and (time.monotonic() > self.keepalive_until)

def description(self) -> str:
return {
State.SEND_METHOD_LINE: "idle",
Expand All @@ -439,10 +452,9 @@ def __repr__(self) -> str:


class HTTPStream(Stream):
def __init__(self, parser: HTTPParser, callback: typing.Callable | None = None):
def __init__(self, parser: HTTPParser):
self._parser = parser
self._buffer = io.BytesIO()
self._callback = callback

async def read(self, size=-1) -> bytes:
sections = []
Expand Down Expand Up @@ -474,12 +486,8 @@ async def read(self, size=-1) -> bytes:
return output

async def close(self) -> None:
try:
self._buffer.close()
await self._parser.complete()
finally:
if self._callback is not None:
await self._callback()
self._buffer.close()
await self._parser.complete()


class ReadAheadParser:
Expand Down
16 changes: 5 additions & 11 deletions src/ahttpx/_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,8 @@ async def _get_connection(self, request: Request) -> "Connection":
# Attempt to reuse an existing connection.
url = request.url
origin = URL(scheme=url.scheme, host=url.host, port=url.port)
now = time.monotonic()
for conn in self._connections:
if conn.origin() == origin and conn.is_idle() and not conn.is_expired(now):
if conn.origin() == origin and conn.is_idle() and not conn.is_expired():
return conn

# Or else create a new connection.
Expand All @@ -102,7 +101,7 @@ async def _get_connection(self, request: Request) -> "Connection":
async def _cleanup(self) -> None:
now = time.monotonic()
for conn in list(self._connections):
if conn.is_expired(now):
if conn.is_expired():
await conn.close()
if conn.is_closed():
self._connections.remove(conn)
Expand Down Expand Up @@ -142,8 +141,6 @@ class Connection(Transport):
def __init__(self, stream: Stream, origin: URL | str):
self._stream = stream
self._origin = URL(origin) if not isinstance(origin, URL) else origin
self._keepalive_duration = 5.0
self._idle_expiry = time.monotonic() + self._keepalive_duration
self._request_lock = Lock()
self._parser = HTTPParser(stream, mode='CLIENT')

Expand All @@ -154,8 +151,8 @@ def origin(self) -> URL:
def is_idle(self) -> bool:
return self._parser.is_idle()

def is_expired(self, when: float) -> bool:
return self._parser.is_idle() and when > self._idle_expiry
def is_expired(self) -> bool:
return self._parser.is_idle() and self._parser.keepalive_expired()

def is_closed(self) -> bool:
return self._parser.is_closed()
Expand All @@ -170,7 +167,7 @@ async def send(self, request: Request) -> Response:
await self._send_head(request)
await self._send_body(request)
code, headers = await self._recv_head()
stream = HTTPStream(self._parser, callback=self._complete)
stream = HTTPStream(self._parser)
# TODO...
return Response(code, headers=headers, content=stream)
# finally:
Expand Down Expand Up @@ -233,9 +230,6 @@ async def _recv_body(self) -> bytes:
return await self._parser.recv_body()

# Request/response cycle complete...
async def _complete(self) -> None:
self._idle_expiry = time.monotonic() + self._keepalive_duration

async def _close(self) -> None:
await self._parser.close()

Expand Down
10 changes: 2 additions & 8 deletions src/ahttpx/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,13 @@ def __init__(self, stream, endpoint):
self._stream = stream
self._endpoint = endpoint
self._parser = HTTPParser(stream, mode='SERVER')
self._keepalive_duration = 5.0
self._idle_expiry = time.monotonic() + self._keepalive_duration

# API entry points...
async def handle_requests(self):
try:
while not await self._parser.recv_close():
while not (self._parser.keepalive_expired() or await self._parser.recv_close()):
method, url, headers = await self._recv_head()
stream = HTTPStream(self._parser, callback=self._complete)
stream = HTTPStream(self._parser)
# TODO: Handle endpoint exceptions
async with Request(method, url, headers=headers, content=stream) as request:
try:
Expand Down Expand Up @@ -82,10 +80,6 @@ async def _send_body(self, response: Response):
await self._parser.send_body(data)
await self._parser.send_body(b'')

# Start it all over again...
async def _complete(self):
self._idle_expiry = time.monotonic() + self._keepalive_duration


class HTTPServer:
def __init__(self, host, port):
Expand Down
24 changes: 16 additions & 8 deletions src/httpx/_parsers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import enum
import io
import time
import typing

from ._streams import Stream
Expand Down Expand Up @@ -89,6 +90,7 @@ def __init__(self, stream: Stream, mode: str) -> None:
self.stream = stream
self.parser = ReadAheadParser(stream)
self.mode = {'CLIENT': Mode.CLIENT, 'SERVER': Mode.SERVER}[mode]
self.keepalive_duration = 5.0

# Track state...
if self.mode == Mode.CLIENT:
Expand All @@ -107,6 +109,7 @@ def __init__(self, stream: Stream, mode: str) -> None:
# Track connection keep alive...
self.send_keep_alive = True
self.recv_keep_alive = True
self.keepalive_until: float | None = None

# Special states...
self.processing_1xx = False
Expand All @@ -119,6 +122,9 @@ def send_method_line(self, method: bytes, target: bytes, protocol: bytes) -> Non

Sending state will switch to SEND_HEADERS state.
"""
# Scrub connection keepalive
self.keepalive_until = None

if self.send_state != State.SEND_METHOD_LINE:
msg = f"Called 'send_method_line' in invalid state {self.send_state}"
raise ProtocolError(msg)
Expand Down Expand Up @@ -244,6 +250,9 @@ def recv_method_line(self) -> tuple[bytes, bytes, bytes]:

Receive state will switch to RECV_HEADERS.
"""
# Scrub connection keepalive
self.keepalive_until = None

if self.recv_state != State.RECV_METHOD_LINE:
msg = f"Called 'recv_method_line' in invalid state {self.recv_state}"
raise ProtocolError(msg)
Expand Down Expand Up @@ -409,6 +418,7 @@ def complete(self):
self.send_keep_alive = True
self.recv_keep_alive = True
self.processing_1xx = False
self.keepalive_until = time.monotonic() + self.keepalive_duration

def close(self):
if self.send_state != State.CLOSED:
Expand All @@ -425,6 +435,9 @@ def is_idle(self) -> bool:
def is_closed(self) -> bool:
return self.send_state == State.CLOSED

def keepalive_expired(self) -> bool:
return (self.keepalive_until is not None) and (time.monotonic() > self.keepalive_until)

def description(self) -> str:
return {
State.SEND_METHOD_LINE: "idle",
Expand All @@ -439,10 +452,9 @@ def __repr__(self) -> str:


class HTTPStream(Stream):
def __init__(self, parser: HTTPParser, callback: typing.Callable | None = None):
def __init__(self, parser: HTTPParser):
self._parser = parser
self._buffer = io.BytesIO()
self._callback = callback

def read(self, size=-1) -> bytes:
sections = []
Expand Down Expand Up @@ -474,12 +486,8 @@ def read(self, size=-1) -> bytes:
return output

def close(self) -> None:
try:
self._buffer.close()
self._parser.complete()
finally:
if self._callback is not None:
self._callback()
self._buffer.close()
self._parser.complete()


class ReadAheadParser:
Expand Down
16 changes: 5 additions & 11 deletions src/httpx/_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,8 @@ def _get_connection(self, request: Request) -> "Connection":
# Attempt to reuse an existing connection.
url = request.url
origin = URL(scheme=url.scheme, host=url.host, port=url.port)
now = time.monotonic()
for conn in self._connections:
if conn.origin() == origin and conn.is_idle() and not conn.is_expired(now):
if conn.origin() == origin and conn.is_idle() and not conn.is_expired():
return conn

# Or else create a new connection.
Expand All @@ -102,7 +101,7 @@ def _get_connection(self, request: Request) -> "Connection":
def _cleanup(self) -> None:
now = time.monotonic()
for conn in list(self._connections):
if conn.is_expired(now):
if conn.is_expired():
conn.close()
if conn.is_closed():
self._connections.remove(conn)
Expand Down Expand Up @@ -142,8 +141,6 @@ class Connection(Transport):
def __init__(self, stream: Stream, origin: URL | str):
self._stream = stream
self._origin = URL(origin) if not isinstance(origin, URL) else origin
self._keepalive_duration = 5.0
self._idle_expiry = time.monotonic() + self._keepalive_duration
self._request_lock = Lock()
self._parser = HTTPParser(stream, mode='CLIENT')

Expand All @@ -154,8 +151,8 @@ def origin(self) -> URL:
def is_idle(self) -> bool:
return self._parser.is_idle()

def is_expired(self, when: float) -> bool:
return self._parser.is_idle() and when > self._idle_expiry
def is_expired(self) -> bool:
return self._parser.is_idle() and self._parser.keepalive_expired()

def is_closed(self) -> bool:
return self._parser.is_closed()
Expand All @@ -170,7 +167,7 @@ def send(self, request: Request) -> Response:
self._send_head(request)
self._send_body(request)
code, headers = self._recv_head()
stream = HTTPStream(self._parser, callback=self._complete)
stream = HTTPStream(self._parser)
# TODO...
return Response(code, headers=headers, content=stream)
# finally:
Expand Down Expand Up @@ -233,9 +230,6 @@ def _recv_body(self) -> bytes:
return self._parser.recv_body()

# Request/response cycle complete...
def _complete(self) -> None:
self._idle_expiry = time.monotonic() + self._keepalive_duration

def _close(self) -> None:
self._parser.close()

Expand Down
10 changes: 2 additions & 8 deletions src/httpx/_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,13 @@ def __init__(self, stream, endpoint):
self._stream = stream
self._endpoint = endpoint
self._parser = HTTPParser(stream, mode='SERVER')
self._keepalive_duration = 5.0
self._idle_expiry = time.monotonic() + self._keepalive_duration

# API entry points...
def handle_requests(self):
try:
while not self._parser.recv_close():
while not (self._parser.keepalive_expired() or self._parser.recv_close()):
method, url, headers = self._recv_head()
stream = HTTPStream(self._parser, callback=self._complete)
stream = HTTPStream(self._parser)
# TODO: Handle endpoint exceptions
with Request(method, url, headers=headers, content=stream) as request:
try:
Expand Down Expand Up @@ -82,10 +80,6 @@ def _send_body(self, response: Response):
self._parser.send_body(data)
self._parser.send_body(b'')

# Start it all over again...
def _complete(self):
self._idle_expiry = time.monotonic() + self._keepalive_duration


class HTTPServer:
def __init__(self, host, port):
Expand Down
Loading