Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
]
dependencies = [
"h11==0.*",
"truststore==0.10",
"certifi",
]
dynamic = ["version"]

Expand Down
21 changes: 17 additions & 4 deletions src/ahttpx/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
import types
import typing

import certifi

__all__ = ["NetworkBackend", "NetworkStream", "timeout"]


class NetworkStream:
def __init__(
self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, address: str = ''
) -> None:
peername = writer.get_extra_info('peername')

self._reader = reader
self._writer = writer
self._address = f"{peername[0]}:{peername[1]}"
self._ = address
self._tls = False
self._closed = False

Expand Down Expand Up @@ -81,12 +81,25 @@ async def __aexit__(


class NetworkBackend:
def __init__(self):
self._ssl_context = ssl.create_default_context(cafile=certifi.where())

async def connect(self, host: str, port: int) -> NetworkStream:
"""
Connect to the given address, returning a Stream instance.
"""
address = f"{host}:{port}"
reader, writer = await asyncio.open_connection(host, port)
return NetworkStream(reader, writer, address=address)

async def connect_tls(self, host: str, port: int, hostname: str = '') -> NetworkStream:
"""
Connect to the given address, returning a Stream instance.
"""
address = f"{host}:{port}"
reader, writer = await asyncio.open_connection(host, port)
return NetworkStream(reader, writer, address=f"{host}:{port}")
await writer.start_tls(self._ssl_context, server_hostname=hostname)
return NetworkStream(reader, writer, address=address)

async def serve(self, host: str, port: int, handler: typing.Callable[[NetworkStream], None]) -> NetworkServer:
async def callback(reader, writer):
Expand Down
18 changes: 4 additions & 14 deletions src/ahttpx/_pool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import ssl
import time
import typing
import types
Expand Down Expand Up @@ -54,15 +53,11 @@ async def stream(


class ConnectionPool(Transport):
def __init__(self, ssl_context: ssl.SSLContext | None = None, backend: NetworkBackend | None = None):
if ssl_context is None:
import truststore
ssl_context = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
def __init__(self, backend: NetworkBackend | None = None):
if backend is None:
backend = NetworkBackend()

self._connections: list[Connection] = []
self._ssl_context = ssl_context
self._network_backend = backend
self._limit_concurrency = Semaphore(100)
self._closed = False
Expand Down Expand Up @@ -99,7 +94,6 @@ async def _get_connection(self, request: Request) -> "Connection":
conn = await open_connection(
origin,
hostname=request.headers["Host"],
ssl_context=self._ssl_context,
backend=self._network_backend
)
self._connections.append(conn)
Expand Down Expand Up @@ -302,7 +296,6 @@ async def __aexit__(
async def open_connection(
url: URL | str,
hostname: str = '',
ssl_context: ssl.SSLContext | None = None,
backend: NetworkBackend | None = None,
) -> Connection:

Expand All @@ -316,13 +309,10 @@ async def open_connection(

host = url.host
port = url.port or {"http": 80, "https": 443}[url.scheme]
hostname = hostname or url.host

stream = await backend.connect(host, port)
if url.scheme == "https":
if ssl_context is None:
import truststore
ssl_context = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
await stream.start_tls(ssl_context, hostname=hostname)
stream = await backend.connect_tls(host, port, hostname)
else:
stream = await backend.connect(host, port)

return Connection(stream, url)
23 changes: 17 additions & 6 deletions src/httpx/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
import types
import typing

import certifi


_timeout_stack: contextvars.ContextVar[list[float]] = contextvars.ContextVar("timeout_context", default=[])

__all__ = ["NetworkBackend", "NetworkStream", "timeout"]
Expand Down Expand Up @@ -47,8 +50,6 @@ def get_current_timeout() -> float | None:

class NetworkStream:
def __init__(self, sock: socket.socket, address: tuple[str, int]) -> None:
peername = sock.getpeername()

self._socket = sock
self._address = address
self._is_tls = False
Expand All @@ -75,10 +76,6 @@ def write(self, buffer: bytes) -> None:
n = self._socket.send(buffer)
buffer = buffer[n:]

def start_tls(self, ctx: ssl.SSLContext, hostname: str | None = None) -> None:
self._socket = ctx.wrap_socket(self._socket, server_hostname=hostname)
self._is_tls = True

def close(self) -> None:
if not self._is_closed:
timeout = get_current_timeout()
Expand Down Expand Up @@ -194,6 +191,9 @@ def _handler(self, stream):


class NetworkBackend:
def __init__(self):
self._ssl_context = ssl.create_default_context(cafile=certifi.where())

def connect(self, host: str, port: int) -> NetworkStream:
"""
Connect to the given address, returning a NetworkStream instance.
Expand All @@ -203,6 +203,17 @@ def connect(self, host: str, port: int) -> NetworkStream:
sock = socket.create_connection(address, timeout=timeout)
return NetworkStream(sock, address)

def connect_tls(self, host: str, port: int, hostname: str = '') -> NetworkStream:
"""
Connect to the given address, returning a NetworkStream instance.
"""
address = (host, port)
hostname = hostname or host
timeout = get_current_timeout()
sock = socket.create_connection(address, timeout=timeout)
sock = self._ssl_context.wrap_socket(sock, server_hostname=hostname)
return NetworkStream(sock, address)

def listen(self, host: str, port: int) -> NetworkListener:
"""
List on the given address, returning a NetworkListener instance.
Expand Down
20 changes: 5 additions & 15 deletions src/httpx/_pool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import ssl
import time
import typing
import types
Expand Down Expand Up @@ -54,15 +53,11 @@ def stream(


class ConnectionPool(Transport):
def __init__(self, ssl_context: ssl.SSLContext | None = None, backend: NetworkBackend | None = None):
if ssl_context is None:
import truststore
ssl_context = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
def __init__(self, backend: NetworkBackend | None = None):
if backend is None:
backend = NetworkBackend()

self._connections: list[Connection] = []
self._ssl_context = ssl_context
self._network_backend = backend
self._limit_concurrency = Semaphore(100)
self._closed = False
Expand All @@ -72,7 +67,7 @@ def send(self, request: Request) -> Response:
if self._closed:
raise RuntimeError("ConnectionPool is closed.")

# TODO: concurrency
# TODO: concurrency limiting
self._cleanup()
connection = self._get_connection(request)
response = connection.send(request)
Expand All @@ -99,7 +94,6 @@ def _get_connection(self, request: Request) -> "Connection":
conn = open_connection(
origin,
hostname=request.headers["Host"],
ssl_context=self._ssl_context,
backend=self._network_backend
)
self._connections.append(conn)
Expand Down Expand Up @@ -302,7 +296,6 @@ def __exit__(
def open_connection(
url: URL | str,
hostname: str = '',
ssl_context: ssl.SSLContext | None = None,
backend: NetworkBackend | None = None,
) -> Connection:

Expand All @@ -316,13 +309,10 @@ def open_connection(

host = url.host
port = url.port or {"http": 80, "https": 443}[url.scheme]
hostname = hostname or url.host

stream = backend.connect(host, port)
if url.scheme == "https":
if ssl_context is None:
import truststore
ssl_context = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
stream.start_tls(ssl_context, hostname=hostname)
stream = backend.connect_tls(host, port, hostname)
else:
stream = backend.connect(host, port)

return Connection(stream, url)