diff --git a/pyproject.toml b/pyproject.toml index 291091a..b5aacb4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ ] dependencies = [ "h11==0.*", - "truststore==0.10", + "certifi", ] dynamic = ["version"] diff --git a/src/ahttpx/_network.py b/src/ahttpx/_network.py index e3d7f9e..6c9661e 100644 --- a/src/ahttpx/_network.py +++ b/src/ahttpx/_network.py @@ -3,6 +3,8 @@ import types import typing +import certifi + __all__ = ["NetworkBackend", "NetworkStream", "timeout"] @@ -10,11 +12,9 @@ class NetworkStream: def __init__( self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter, address: str = '' ) -> None: - peername = writer.get_extra_info('peername') - self._reader = reader self._writer = writer - self._address = f"{peername[0]}:{peername[1]}" + self._ = address self._tls = False self._closed = False @@ -81,12 +81,25 @@ async def __aexit__( class NetworkBackend: + def __init__(self): + self._ssl_context = ssl.create_default_context(cafile=certifi.where()) + async def connect(self, host: str, port: int) -> NetworkStream: """ Connect to the given address, returning a Stream instance. """ + address = f"{host}:{port}" + reader, writer = await asyncio.open_connection(host, port) + return NetworkStream(reader, writer, address=address) + + async def connect_tls(self, host: str, port: int, hostname: str = '') -> NetworkStream: + """ + Connect to the given address, returning a Stream instance. + """ + address = f"{host}:{port}" reader, writer = await asyncio.open_connection(host, port) - return NetworkStream(reader, writer, address=f"{host}:{port}") + await writer.start_tls(self._ssl_context, server_hostname=hostname) + return NetworkStream(reader, writer, address=address) async def serve(self, host: str, port: int, handler: typing.Callable[[NetworkStream], None]) -> NetworkServer: async def callback(reader, writer): diff --git a/src/ahttpx/_pool.py b/src/ahttpx/_pool.py index f93460a..fdb09c1 100644 --- a/src/ahttpx/_pool.py +++ b/src/ahttpx/_pool.py @@ -1,4 +1,3 @@ -import ssl import time import typing import types @@ -54,15 +53,11 @@ async def stream( class ConnectionPool(Transport): - def __init__(self, ssl_context: ssl.SSLContext | None = None, backend: NetworkBackend | None = None): - if ssl_context is None: - import truststore - ssl_context = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + def __init__(self, backend: NetworkBackend | None = None): if backend is None: backend = NetworkBackend() self._connections: list[Connection] = [] - self._ssl_context = ssl_context self._network_backend = backend self._limit_concurrency = Semaphore(100) self._closed = False @@ -99,7 +94,6 @@ async def _get_connection(self, request: Request) -> "Connection": conn = await open_connection( origin, hostname=request.headers["Host"], - ssl_context=self._ssl_context, backend=self._network_backend ) self._connections.append(conn) @@ -302,7 +296,6 @@ async def __aexit__( async def open_connection( url: URL | str, hostname: str = '', - ssl_context: ssl.SSLContext | None = None, backend: NetworkBackend | None = None, ) -> Connection: @@ -316,13 +309,10 @@ async def open_connection( host = url.host port = url.port or {"http": 80, "https": 443}[url.scheme] - hostname = hostname or url.host - stream = await backend.connect(host, port) if url.scheme == "https": - if ssl_context is None: - import truststore - ssl_context = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - await stream.start_tls(ssl_context, hostname=hostname) + stream = await backend.connect_tls(host, port, hostname) + else: + stream = await backend.connect(host, port) return Connection(stream, url) diff --git a/src/httpx/_network.py b/src/httpx/_network.py index 7fa0992..472eabb 100644 --- a/src/httpx/_network.py +++ b/src/httpx/_network.py @@ -9,6 +9,9 @@ import types import typing +import certifi + + _timeout_stack: contextvars.ContextVar[list[float]] = contextvars.ContextVar("timeout_context", default=[]) __all__ = ["NetworkBackend", "NetworkStream", "timeout"] @@ -47,8 +50,6 @@ def get_current_timeout() -> float | None: class NetworkStream: def __init__(self, sock: socket.socket, address: tuple[str, int]) -> None: - peername = sock.getpeername() - self._socket = sock self._address = address self._is_tls = False @@ -75,10 +76,6 @@ def write(self, buffer: bytes) -> None: n = self._socket.send(buffer) buffer = buffer[n:] - def start_tls(self, ctx: ssl.SSLContext, hostname: str | None = None) -> None: - self._socket = ctx.wrap_socket(self._socket, server_hostname=hostname) - self._is_tls = True - def close(self) -> None: if not self._is_closed: timeout = get_current_timeout() @@ -194,6 +191,9 @@ def _handler(self, stream): class NetworkBackend: + def __init__(self): + self._ssl_context = ssl.create_default_context(cafile=certifi.where()) + def connect(self, host: str, port: int) -> NetworkStream: """ Connect to the given address, returning a NetworkStream instance. @@ -203,6 +203,17 @@ def connect(self, host: str, port: int) -> NetworkStream: sock = socket.create_connection(address, timeout=timeout) return NetworkStream(sock, address) + def connect_tls(self, host: str, port: int, hostname: str = '') -> NetworkStream: + """ + Connect to the given address, returning a NetworkStream instance. + """ + address = (host, port) + hostname = hostname or host + timeout = get_current_timeout() + sock = socket.create_connection(address, timeout=timeout) + sock = self._ssl_context.wrap_socket(sock, server_hostname=hostname) + return NetworkStream(sock, address) + def listen(self, host: str, port: int) -> NetworkListener: """ List on the given address, returning a NetworkListener instance. diff --git a/src/httpx/_pool.py b/src/httpx/_pool.py index 9c31458..43b9b65 100644 --- a/src/httpx/_pool.py +++ b/src/httpx/_pool.py @@ -1,4 +1,3 @@ -import ssl import time import typing import types @@ -54,15 +53,11 @@ def stream( class ConnectionPool(Transport): - def __init__(self, ssl_context: ssl.SSLContext | None = None, backend: NetworkBackend | None = None): - if ssl_context is None: - import truststore - ssl_context = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + def __init__(self, backend: NetworkBackend | None = None): if backend is None: backend = NetworkBackend() self._connections: list[Connection] = [] - self._ssl_context = ssl_context self._network_backend = backend self._limit_concurrency = Semaphore(100) self._closed = False @@ -72,7 +67,7 @@ def send(self, request: Request) -> Response: if self._closed: raise RuntimeError("ConnectionPool is closed.") - # TODO: concurrency + # TODO: concurrency limiting self._cleanup() connection = self._get_connection(request) response = connection.send(request) @@ -99,7 +94,6 @@ def _get_connection(self, request: Request) -> "Connection": conn = open_connection( origin, hostname=request.headers["Host"], - ssl_context=self._ssl_context, backend=self._network_backend ) self._connections.append(conn) @@ -302,7 +296,6 @@ def __exit__( def open_connection( url: URL | str, hostname: str = '', - ssl_context: ssl.SSLContext | None = None, backend: NetworkBackend | None = None, ) -> Connection: @@ -316,13 +309,10 @@ def open_connection( host = url.host port = url.port or {"http": 80, "https": 443}[url.scheme] - hostname = hostname or url.host - stream = backend.connect(host, port) if url.scheme == "https": - if ssl_context is None: - import truststore - ssl_context = truststore.SSLContext(ssl.PROTOCOL_TLS_CLIENT) - stream.start_tls(ssl_context, hostname=hostname) + stream = backend.connect_tls(host, port, hostname) + else: + stream = backend.connect(host, port) return Connection(stream, url)