diff --git a/CHANGELOG.md b/CHANGELOG.md index c021310..d8380d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,10 @@ ## 0.18.0 [unreleased] +### Features + +1. [#196](https://github.com/InfluxCommunity/influxdb3-python/pull/196): Support passing middleware functions to the Flight client. + ### Bug Fixes 1. [#194](https://github.com/InfluxCommunity/influxdb3-python/pull/194): Fix `InfluxDBClient3.write_file()` and `InfluxDBClient3.write_dataframe()` fail with batching mode. diff --git a/Examples/query_with_middleware.py b/Examples/query_with_middleware.py new file mode 100644 index 0000000..e9f3d7f --- /dev/null +++ b/Examples/query_with_middleware.py @@ -0,0 +1,33 @@ +from pyarrow import flight + +from config import Config +from influxdb_client_3 import InfluxDBClient3, flight_client_options + + +# This middleware will add an additional attribute `some-attribute` to the header +class ModifyHeaderClientMiddleware(flight.ClientMiddleware): + def sending_headers(self): + return { + "some-attribute": "some-value", + } + + def received_headers(self, headers): + pass + + +class ModifyHeaderClientMiddlewareFactory(flight.ClientMiddlewareFactory): + def start_call(self, info): + return ModifyHeaderClientMiddleware() + + +config = Config() +middleware = [ModifyHeaderClientMiddlewareFactory()] +client = InfluxDBClient3( + host=config.host, + token=config.token, + database=config.database, + flight_client_options=flight_client_options(middleware=middleware) +) + +df = client.query(query="select * from cpu11 limit 10", mode="pandas") +print(len(df)) diff --git a/influxdb_client_3/query/query_api.py b/influxdb_client_3/query/query_api.py index 0ba92de..7631795 100644 --- a/influxdb_client_3/query/query_api.py +++ b/influxdb_client_3/query/query_api.py @@ -20,6 +20,7 @@ class QueryApiOptions(object): flight_client_options (dict): base set of flight client options passed to internal pyarrow.flight.FlightClient timeout(float): timeout in seconds to wait for a response disable_grpc_compression (bool): disable gRPC compression for query responses + middleware (list): list of middleware functions to be applied to Flight calls """ _DEFAULT_TIMEOUT = 300.0 tls_root_certs: bytes = None @@ -28,13 +29,15 @@ class QueryApiOptions(object): flight_client_options: dict = None timeout: float = None disable_grpc_compression: bool = False + middleware: list = None def __init__(self, root_certs_path: str, verify: bool, proxy: str, flight_client_options: dict, timeout: float = _DEFAULT_TIMEOUT, - disable_grpc_compression: bool = False): + disable_grpc_compression: bool = False, + middleware: list = None): """ Initialize a set of QueryApiOptions @@ -45,6 +48,7 @@ def __init__(self, root_certs_path: str, to be passed to internal pyarrow.flight.FlightClient. :param timeout: timeout in seconds to wait for a response. :param disable_grpc_compression: disable gRPC compression for query responses. + :param middleware: list of middleware functions to be applied to Flight calls. """ if root_certs_path: self.tls_root_certs = self._read_certs(root_certs_path) @@ -53,6 +57,7 @@ def __init__(self, root_certs_path: str, self.flight_client_options = flight_client_options self.timeout = timeout self.disable_grpc_compression = disable_grpc_compression + self.middleware = middleware def _read_certs(self, path: str) -> bytes: with open(path, "rb") as certs_file: @@ -81,6 +86,7 @@ class QueryApiOptionsBuilder(object): _flight_client_options: dict = None _timeout: float = None _disable_grpc_compression: bool = False + _middleware: list = None def root_certs(self, path: str): self._root_certs_path = path @@ -107,6 +113,10 @@ def disable_grpc_compression(self, disable: bool): self._disable_grpc_compression = disable return self + def middleware(self, middleware: list): + self._middleware = middleware + return self + def build(self) -> QueryApiOptions: """Build a QueryApiOptions object with previously set values""" return QueryApiOptions( @@ -116,6 +126,7 @@ def build(self) -> QueryApiOptions: flight_client_options=self._flight_client_options, timeout=self._timeout, disable_grpc_compression=self._disable_grpc_compression, + middleware=self._middleware ) @@ -181,6 +192,8 @@ def __init__(self, self._flight_client_options["generic_options"].append( ("grpc.compression_enabled_algorithms_bitset", 1) ) + if options.middleware: + self._flight_client_options["middleware"] = options.middleware if self._proxy: self._flight_client_options["generic_options"].append(("grpc.http_proxy", self._proxy)) self._flight_client = FlightClient(connection_string, **self._flight_client_options) diff --git a/tests/test_query.py b/tests/test_query.py index ade4b2d..f702f17 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -12,7 +12,7 @@ Ticket ) -from influxdb_client_3 import InfluxDBClient3 +from influxdb_client_3 import InfluxDBClient3, flight_client_options from influxdb_client_3.query.query_api import QueryApiOptionsBuilder, QueryApi from influxdb_client_3.version import USER_AGENT from tests.util import asyncio_run @@ -25,7 +25,8 @@ HeaderCheckServerMiddlewareFactory, NoopAuthHandler, get_req_headers, - set_req_headers + set_req_headers, ModifyHeaderClientMiddlewareFactory, + HeaderCheckServerMiddlewareFactory1 ) @@ -175,11 +176,13 @@ def test_query_client_with_options(self): cert_chain = 'mTLS_explicit_chain' self.create_cert_file(cert_file) test_flight_client_options = {'private_key': private_key, 'cert_chain': cert_chain} + middleware = [ModifyHeaderClientMiddlewareFactory()] options = QueryApiOptionsBuilder()\ .proxy(proxy_name) \ .root_certs(cert_file) \ .tls_verify(False) \ .flight_client_options(test_flight_client_options) \ + .middleware(middleware) \ .build() client = QueryApi(connection, @@ -195,6 +198,7 @@ def test_query_client_with_options(self): assert client._flight_client_options['private_key'] == private_key assert client._flight_client_options['cert_chain'] == cert_chain assert client._proxy == proxy_name + assert client._flight_client_options['middleware'] == middleware fc_opts = client._flight_client_options assert dict(fc_opts['generic_options'])['grpc.secondary_user_agent'].startswith('influxdb3-python/') assert dict(fc_opts['generic_options'])['grpc.http_proxy'] == proxy_name @@ -311,6 +315,41 @@ def test_prepare_query(self): assert _req_headers['authorization'] == [f"Bearer {token}"] set_req_headers({}) + def test_query_with_middleware_success(self): + with HeaderCheckFlightServer( + auth_handler=NoopAuthHandler(), + middleware={"check": HeaderCheckServerMiddlewareFactory1()}) as server: + + middleware = [ModifyHeaderClientMiddlewareFactory()] + client = InfluxDBClient3( + host=f'http://localhost:{server.port}', + org='test_org', + databse='test_db', + token='TEST_TOKEN', + flight_client_options=flight_client_options(middleware=middleware) + ) + + df = client.query(query='SELECT * FROM test', mode="pandas") + self.assertIsNotNone(df) + + def test_query_with_missing_middleware(self): + with HeaderCheckFlightServer( + auth_handler=NoopAuthHandler(), + middleware={"check": HeaderCheckServerMiddlewareFactory1()}) as server: + + client = InfluxDBClient3( + host=f'http://localhost:{server.port}', + org='test_org', + databse='test_db', + token='TEST_TOKEN' + ) + + try: + client.query(query='SELECT * FROM test', mode="pandas") + self.fail("Should have failed due to missing middleware") + except Exception as e: + assert "Invalid header value from middleware" in str(e) + @asyncio_run async def test_query_async_pandas(self): with ConstantFlightServer() as server: diff --git a/tests/util/mocks.py b/tests/util/mocks.py index 5d7201e..8716e19 100644 --- a/tests/util/mocks.py +++ b/tests/util/mocks.py @@ -5,8 +5,9 @@ from pyarrow import ( array, Table, - concat_tables, ArrowException + concat_tables, ArrowException, flight ) +from pyarrow._flight import FlightInternalError from pyarrow.flight import ( FlightServerBase, RecordBatchStream, @@ -159,6 +160,32 @@ def number_batches(table): yield batch, buf +class ModifyHeaderClientMiddleware(flight.ClientMiddleware): + def sending_headers(self): + return { + "header-from-middleware": "some-value", + } + + def received_headers(self, headers): + pass + + +class ModifyHeaderClientMiddlewareFactory(flight.ClientMiddlewareFactory): + def start_call(self, info): + return ModifyHeaderClientMiddleware() + + +class HeaderCheckServerMiddlewareFactory1(ServerMiddlewareFactory): + """Factory to create HeaderCheckServerMiddleware and check header values""" + def start_call(self, info, headers): + values = case_insensitive_header_lookup(headers, "header-from-middleware") + if values is None or values[0] != 'some-value': + raise FlightInternalError("Invalid header value from middleware") + global req_headers + req_headers = headers + return HeaderCheckServerMiddleware('') + + class ErrorFlightServer(FlightServerBase): def do_get(self, context, ticket): raise ArrowException