Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## 0.18.0 [unreleased]

### Features

1. [#196](https://github.com/InfluxCommunity/influxdb3-python/pull/196): Support passing middleware functions to the Flight client.

### Bug Fixes

1. [#194](https://github.com/InfluxCommunity/influxdb3-python/pull/194): Fix `InfluxDBClient3.write_file()` and `InfluxDBClient3.write_dataframe()` fail with batching mode.
Expand Down
33 changes: 33 additions & 0 deletions Examples/query_with_middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from pyarrow import flight

from config import Config
from influxdb_client_3 import InfluxDBClient3, flight_client_options


# This middleware will add an additional attribute `some-attribute` to the header
class ModifyHeaderClientMiddleware(flight.ClientMiddleware):
def sending_headers(self):
return {
"some-attribute": "some-value",
}

def received_headers(self, headers):
pass


class ModifyHeaderClientMiddlewareFactory(flight.ClientMiddlewareFactory):
def start_call(self, info):
return ModifyHeaderClientMiddleware()


config = Config()
middleware = [ModifyHeaderClientMiddlewareFactory()]
client = InfluxDBClient3(
host=config.host,
token=config.token,
database=config.database,
flight_client_options=flight_client_options(middleware=middleware)
)

df = client.query(query="select * from cpu11 limit 10", mode="pandas")
print(len(df))
15 changes: 14 additions & 1 deletion influxdb_client_3/query/query_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ class QueryApiOptions(object):
flight_client_options (dict): base set of flight client options passed to internal pyarrow.flight.FlightClient
timeout(float): timeout in seconds to wait for a response
disable_grpc_compression (bool): disable gRPC compression for query responses
middleware (list): list of middleware functions to be applied to Flight calls
"""
_DEFAULT_TIMEOUT = 300.0
tls_root_certs: bytes = None
Expand All @@ -28,13 +29,15 @@ class QueryApiOptions(object):
flight_client_options: dict = None
timeout: float = None
disable_grpc_compression: bool = False
middleware: list = None

def __init__(self, root_certs_path: str,
verify: bool,
proxy: str,
flight_client_options: dict,
timeout: float = _DEFAULT_TIMEOUT,
disable_grpc_compression: bool = False):
disable_grpc_compression: bool = False,
middleware: list = None):
"""
Initialize a set of QueryApiOptions

Expand All @@ -45,6 +48,7 @@ def __init__(self, root_certs_path: str,
to be passed to internal pyarrow.flight.FlightClient.
:param timeout: timeout in seconds to wait for a response.
:param disable_grpc_compression: disable gRPC compression for query responses.
:param middleware: list of middleware functions to be applied to Flight calls.
"""
if root_certs_path:
self.tls_root_certs = self._read_certs(root_certs_path)
Expand All @@ -53,6 +57,7 @@ def __init__(self, root_certs_path: str,
self.flight_client_options = flight_client_options
self.timeout = timeout
self.disable_grpc_compression = disable_grpc_compression
self.middleware = middleware

def _read_certs(self, path: str) -> bytes:
with open(path, "rb") as certs_file:
Expand Down Expand Up @@ -81,6 +86,7 @@ class QueryApiOptionsBuilder(object):
_flight_client_options: dict = None
_timeout: float = None
_disable_grpc_compression: bool = False
_middleware: list = None

def root_certs(self, path: str):
self._root_certs_path = path
Expand All @@ -107,6 +113,10 @@ def disable_grpc_compression(self, disable: bool):
self._disable_grpc_compression = disable
return self

def middleware(self, middleware: list):
self._middleware = middleware
return self

def build(self) -> QueryApiOptions:
"""Build a QueryApiOptions object with previously set values"""
return QueryApiOptions(
Expand All @@ -116,6 +126,7 @@ def build(self) -> QueryApiOptions:
flight_client_options=self._flight_client_options,
timeout=self._timeout,
disable_grpc_compression=self._disable_grpc_compression,
middleware=self._middleware
)


Expand Down Expand Up @@ -181,6 +192,8 @@ def __init__(self,
self._flight_client_options["generic_options"].append(
("grpc.compression_enabled_algorithms_bitset", 1)
)
if options.middleware:
self._flight_client_options["middleware"] = options.middleware
if self._proxy:
self._flight_client_options["generic_options"].append(("grpc.http_proxy", self._proxy))
self._flight_client = FlightClient(connection_string, **self._flight_client_options)
Expand Down
43 changes: 41 additions & 2 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
Ticket
)

from influxdb_client_3 import InfluxDBClient3
from influxdb_client_3 import InfluxDBClient3, flight_client_options
from influxdb_client_3.query.query_api import QueryApiOptionsBuilder, QueryApi
from influxdb_client_3.version import USER_AGENT
from tests.util import asyncio_run
Expand All @@ -25,7 +25,8 @@
HeaderCheckServerMiddlewareFactory,
NoopAuthHandler,
get_req_headers,
set_req_headers
set_req_headers, ModifyHeaderClientMiddlewareFactory,
HeaderCheckServerMiddlewareFactory1
)


Expand Down Expand Up @@ -175,11 +176,13 @@ def test_query_client_with_options(self):
cert_chain = 'mTLS_explicit_chain'
self.create_cert_file(cert_file)
test_flight_client_options = {'private_key': private_key, 'cert_chain': cert_chain}
middleware = [ModifyHeaderClientMiddlewareFactory()]
options = QueryApiOptionsBuilder()\
.proxy(proxy_name) \
.root_certs(cert_file) \
.tls_verify(False) \
.flight_client_options(test_flight_client_options) \
.middleware(middleware) \
.build()

client = QueryApi(connection,
Expand All @@ -195,6 +198,7 @@ def test_query_client_with_options(self):
assert client._flight_client_options['private_key'] == private_key
assert client._flight_client_options['cert_chain'] == cert_chain
assert client._proxy == proxy_name
assert client._flight_client_options['middleware'] == middleware
fc_opts = client._flight_client_options
assert dict(fc_opts['generic_options'])['grpc.secondary_user_agent'].startswith('influxdb3-python/')
assert dict(fc_opts['generic_options'])['grpc.http_proxy'] == proxy_name
Expand Down Expand Up @@ -311,6 +315,41 @@ def test_prepare_query(self):
assert _req_headers['authorization'] == [f"Bearer {token}"]
set_req_headers({})

def test_query_with_middleware_success(self):
with HeaderCheckFlightServer(
auth_handler=NoopAuthHandler(),
middleware={"check": HeaderCheckServerMiddlewareFactory1()}) as server:

middleware = [ModifyHeaderClientMiddlewareFactory()]
client = InfluxDBClient3(
host=f'http://localhost:{server.port}',
org='test_org',
databse='test_db',
token='TEST_TOKEN',
flight_client_options=flight_client_options(middleware=middleware)
)

df = client.query(query='SELECT * FROM test', mode="pandas")
self.assertIsNotNone(df)

def test_query_with_missing_middleware(self):
with HeaderCheckFlightServer(
auth_handler=NoopAuthHandler(),
middleware={"check": HeaderCheckServerMiddlewareFactory1()}) as server:

client = InfluxDBClient3(
host=f'http://localhost:{server.port}',
org='test_org',
databse='test_db',
token='TEST_TOKEN'
)

try:
client.query(query='SELECT * FROM test', mode="pandas")
self.fail("Should have failed due to missing middleware")
except Exception as e:
assert "Invalid header value from middleware" in str(e)

@asyncio_run
async def test_query_async_pandas(self):
with ConstantFlightServer() as server:
Expand Down
29 changes: 28 additions & 1 deletion tests/util/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
from pyarrow import (
array,
Table,
concat_tables, ArrowException
concat_tables, ArrowException, flight
)
from pyarrow._flight import FlightInternalError
from pyarrow.flight import (
FlightServerBase,
RecordBatchStream,
Expand Down Expand Up @@ -159,6 +160,32 @@ def number_batches(table):
yield batch, buf


class ModifyHeaderClientMiddleware(flight.ClientMiddleware):
def sending_headers(self):
return {
"header-from-middleware": "some-value",
}

def received_headers(self, headers):
pass


class ModifyHeaderClientMiddlewareFactory(flight.ClientMiddlewareFactory):
def start_call(self, info):
return ModifyHeaderClientMiddleware()


class HeaderCheckServerMiddlewareFactory1(ServerMiddlewareFactory):
"""Factory to create HeaderCheckServerMiddleware and check header values"""
def start_call(self, info, headers):
values = case_insensitive_header_lookup(headers, "header-from-middleware")
if values is None or values[0] != 'some-value':
raise FlightInternalError("Invalid header value from middleware")
global req_headers
req_headers = headers
return HeaderCheckServerMiddleware('')


class ErrorFlightServer(FlightServerBase):
def do_get(self, context, ticket):
raise ArrowException