From fe3f2ccc2b3d39ba823e9e8efa97f16d856874b8 Mon Sep 17 00:00:00 2001 From: Anqi <16240361+Nicole00@users.noreply.github.com> Date: Thu, 29 Jan 2026 11:04:59 +0800 Subject: [PATCH 1/5] refactor nebula client and nebula pool & add async interface & add tests --- docker-compose.yml | 218 ++++ docs/1_started.md | 24 +- docs/4_error_handling.md | 2 +- docs/5_vector_and_special_types.md | 4 +- docs/6_scan_all.md | 24 +- example.py | 374 ++++--- example/NebulaPoolExample.py | 11 +- example_async_usage.py | 159 +++ pdm.lock | 10 +- pyproject.toml | 1 + src/nebulagraph_python/__init__.py | 8 +- src/nebulagraph_python/client/__init__.py | 15 +- src/nebulagraph_python/client/_connection.py | 682 +++++------- .../client/_connection_pool.py | 180 ---- src/nebulagraph_python/client/_session.py | 176 --- .../client/_session_pool.py | 289 ----- src/nebulagraph_python/client/auth_result.py | 34 + src/nebulagraph_python/client/client.py | 302 ------ .../client/client_pool_factory.py | 166 +++ src/nebulagraph_python/client/constants.py | 42 +- .../client/nebula_client.py | 422 ++++++++ src/nebulagraph_python/client/nebula_pool.py | 276 +++++ src/nebulagraph_python/client/pool.py | 306 ------ .../client/round_robin_load_balancer.py | 110 ++ test_async_client.py | 75 ++ test_nebula_connection.py | 73 ++ test_refactored.py | 96 ++ tests/INTEGRATION_TEST_README.md | 103 -- tests/test_connection.py | 763 +++++++++++++ tests/test_connection_pool.py | 782 -------------- tests/test_integration.py | 701 ++++++++++++ tests/test_nebula_client.py | 608 +++++++++++ .../test_nebula_client_decode_integration.py | 4 +- tests/test_nebula_pool.py | 889 ++++++++++++++++ tests/test_nebula_pool_integration.py | 625 +++++++++++ tests/test_session_pool.py | 998 ------------------ 36 files changed, 5798 insertions(+), 3754 deletions(-) create mode 100644 docker-compose.yml create mode 100644 example_async_usage.py delete mode 100644 src/nebulagraph_python/client/_connection_pool.py delete mode 100644 src/nebulagraph_python/client/_session.py delete mode 100644 src/nebulagraph_python/client/_session_pool.py create mode 100644 src/nebulagraph_python/client/auth_result.py delete mode 100644 src/nebulagraph_python/client/client.py create mode 100644 src/nebulagraph_python/client/client_pool_factory.py create mode 100644 src/nebulagraph_python/client/nebula_client.py create mode 100644 src/nebulagraph_python/client/nebula_pool.py delete mode 100644 src/nebulagraph_python/client/pool.py create mode 100644 src/nebulagraph_python/client/round_robin_load_balancer.py create mode 100644 test_async_client.py create mode 100644 test_nebula_connection.py create mode 100644 test_refactored.py delete mode 100644 tests/INTEGRATION_TEST_README.md create mode 100644 tests/test_connection.py delete mode 100644 tests/test_connection_pool.py create mode 100644 tests/test_integration.py create mode 100644 tests/test_nebula_client.py create mode 100644 tests/test_nebula_pool.py create mode 100644 tests/test_nebula_pool_integration.py delete mode 100644 tests/test_session_pool.py diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 00000000..ae672280 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,218 @@ +version: '3.8' + +services: + nebula-metad0: + image: vesoft/nebula-graph:v3.8.0 + environment: + USER: root + TZ: UTC + command: + - --meta_server_addrs=nebula-metad0:9559,nebula-metad1:9559,nebula-metad2:9559 + - --local_ip=nebula-metad0 + - --ws_ip=nebula-metad0 + - --port=9559 + - --data_path=/data/meta + - 
--log_dir=/logs + - --v=0 + - --minloglevel=0 + healthcheck: + test: ["CMD", "curl", "-f", "http://nebula-metad0:19559/status"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9559:9559" + - "19559:19559" + volumes: + - ./data/meta0:/data/meta + - ./logs/meta0:/logs + networks: + - nebula-net + restart: on-failure + cap_add: + - SYS_PTRACE + + nebula-metad1: + image: vesoft/nebula-graph:v3.8.0 + environment: + USER: root + TZ: UTC + command: + - --meta_server_addrs=nebula-metad0:9559,nebula-metad1:9559,nebula-metad2:9559 + - --local_ip=nebula-metad1 + - --ws_ip=nebula-metad1 + - --port=9559 + - --data_path=/data/meta + - --log_dir=/logs + - --v=0 + - --minloglevel=0 + healthcheck: + test: ["CMD", "curl", "-f", "http://nebula-metad1:19559/status"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9560:9559" + - "19560:19559" + volumes: + - ./data/meta1:/data/meta + - ./logs/meta1:/logs + networks: + - nebula-net + restart: on-failure + cap_add: + - SYS_PTRACE + + nebula-metad2: + image: vesoft/nebula-graph:v3.8.0 + environment: + USER: root + TZ: UTC + command: + - --meta_server_addrs=nebula-metad0:9559,nebula-metad1:9559,nebula-metad2:9559 + - --local_ip=nebula-metad2 + - --ws_ip=nebula-metad2 + - --port=9559 + - --data_path=/data/meta + - --log_dir=/logs + - --v=0 + - --minloglevel=0 + healthcheck: + test: ["CMD", "curl", "-f", "http://nebula-metad2:19559/status"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9561:9559" + - "19561:19559" + volumes: + - ./data/meta2:/data/meta + - ./logs/meta2:/logs + networks: + - nebula-net + restart: on-failure + cap_add: + - SYS_PTRACE + + nebula-storaged0: + image: vesoft/nebula-graph:v3.8.0 + environment: + USER: root + TZ: UTC + command: + - --meta_server_addrs=nebula-metad0:9559,nebula-metad1:9559,nebula-metad2:9559 + - --local_ip=nebula-storaged0 + - --ws_ip=nebula-storaged0 + - --port=9779 + - --data_path=/data/storage + - --log_dir=/logs + - --v=0 + - --minloglevel=0 + depends_on: + - nebula-metad0 + - nebula-metad1 + - nebula-metad2 + healthcheck: + test: ["CMD", "curl", "-f", "http://nebula-storaged0:19779/status"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9779:9779" + - "19779:19779" + volumes: + - ./data/storage0:/data/storage + - ./logs/storage0:/logs + networks: + - nebula-net + restart: on-failure + cap_add: + - SYS_PTRACE + + nebula-storaged1: + image: vesoft/nebula-graph:v3.8.0 + environment: + USER: root + TZ: UTC + command: + - --meta_server_addrs=nebula-metad0:9559,nebula-metad1:9559,nebula-metad2:9559 + - --local_ip=nebula-storaged1 + - --ws_ip=nebula-storaged1 + - --port=9779 + - --data_path=/data/storage + - --log_dir=/logs + - --v=0 + - --minloglevel=0 + depends_on: + - nebula-metad0 + - nebula-metad1 + - nebula-metad2 + healthcheck: + test: ["CMD", "curl", "-f", "http://nebula-storaged1:19779/status"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9780:9779" + - "19780:19779" + volumes: + - ./data/storage1:/data/storage + - ./logs/storage1:/logs + networks: + - nebula-net + restart: on-failure + cap_add: + - SYS_PTRACE + + nebula-graphd: + image: vesoft/nebula-graph:v3.8.0 + environment: + USER: root + TZ: UTC + command: + - --meta_server_addrs=nebula-metad0:9559,nebula-metad1:9559,nebula-metad2:9559 + - --local_ip=nebula-graphd + - --ws_ip=nebula-graphd + - --port=9669 + - --log_dir=/logs + - --v=0 + - --minloglevel=0 + depends_on: + - nebula-metad0 + - nebula-metad1 + - 
nebula-metad2 + - nebula-storaged0 + - nebula-storaged1 + healthcheck: + test: ["CMD", "curl", "-f", "http://nebula-graphd:19669/status"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 20s + ports: + - "9669:9669" + - "19669:19669" + volumes: + - ./logs/graph:/logs + networks: + - nebula-net + restart: on-failure + cap_add: + - SYS_PTRACE + + nebula-console: + image: vesoft/nebula-console:v3.8.0 + entrypoint: ["sleep", "infinity"] + depends_on: + - nebula-graphd + networks: + - nebula-net + +networks: + nebula-net: + driver: bridge \ No newline at end of file diff --git a/docs/1_started.md b/docs/1_started.md index d7eb7074..73e4b331 100644 --- a/docs/1_started.md +++ b/docs/1_started.md @@ -21,16 +21,24 @@ pip install -e . ```python import asyncio -from nebulagraph_python.client import NebulaAsyncClient +from nebulagraph_python.client import AsyncNebulaClient async def main() -> None: - async with await NebulaAsyncClient.connect( - hosts=["127.0.0.1:9669"], - username="root", - password="NebulaGraph01", - ) as client: - result = await client.execute("RETURN 1 AS a, 2 AS b") - result.print() + # Create async client + # Note: AsyncNebulaClient requires manual initialization + client = AsyncNebulaClient( + addresses="127.0.0.1:9669", + user_name="root", + password="nebula", + connect_timeout_ms=3000, + request_timeout_ms=30000 + ) + + # Initialize the connection + await client._init_client() + + result = await client.execute("RETURN 1 AS a, 2 AS b") + result.print() asyncio.run(main()) ``` diff --git a/docs/4_error_handling.md b/docs/4_error_handling.md index d364039b..0beb459e 100644 --- a/docs/4_error_handling.md +++ b/docs/4_error_handling.md @@ -17,7 +17,7 @@ The relevant definitions live in the `nebulagraph_python.error` module. ```python import asyncio -from nebulagraph_python.client.client import NebulaAsyncClient +from nebulagraph_python.client.nebula_client import NebulaAsyncClient from nebulagraph_python.error import NebulaGraphRemoteError, ErrorCode async def main(): diff --git a/docs/5_vector_and_special_types.md b/docs/5_vector_and_special_types.md index e95b2b48..763572c8 100644 --- a/docs/5_vector_and_special_types.md +++ b/docs/5_vector_and_special_types.md @@ -27,7 +27,7 @@ API: Examples: ```python -from nebulagraph_python.client.client import NebulaClient +from nebulagraph_python.client.nebula_client import NebulaClient from nebulagraph_python.py_data_types import NVector # Connect (adjust hosts/credentials to your environment) @@ -70,7 +70,7 @@ API: Examples: ```python -from nebulagraph_python.client.client import NebulaClient +from nebulagraph_python.client.nebula_client import NebulaClient from nebulagraph_python.py_data_types import NDuration cli = NebulaClient(hosts=["127.0.0.1:9669"], username="root", password="Nebula.123") diff --git a/docs/6_scan_all.md b/docs/6_scan_all.md index df088312..4c53b93e 100644 --- a/docs/6_scan_all.md +++ b/docs/6_scan_all.md @@ -9,22 +9,24 @@ This guide shows how to iterate through every node and edge type in a graph and ### Example: scan everything in a space ```python -from nebulagraph_python.client import NebulaClient, SessionConfig -from nebulagraph_python.orm import get_graph_type +from nebulagraph_python.client import NebulaClient from nebulagraph_python.tools import scan_edges, scan_nodes +from nebulagraph_python.tools.get_graph_type import get_graph_type +from nebulagraph_python.tools.session_conf import get_graph_type_name + # Create client client = NebulaClient( - hosts=["127.0.0.1:9669"], - username="root", - 
password="NebulaGraph01", - session_config=SessionConfig( - graph="movie", - ), + addresses="127.0.0.1:9669", + user_name="root", + password="nebula", + connect_timeout_ms=3000, + request_timeout_ms=30000, ) # Discover schema metadata for the target graph -graph_type = get_graph_type(client, graph_name="movie") +graph_type_name = get_graph_type_name(client, graph_name="movie") +graph_type = get_graph_type(client, graph_type_name=graph_type_name) # Scan all nodes (by type) for node_type_name, node_type in graph_type.nodes.items(): @@ -50,9 +52,13 @@ for edge_type_name, edge_type in graph_type.edges.items(): ) ): print("edge", edge_type_name, count, src, edge, dst) + +# Close the client +client.close() ``` ### Notes - **`properties_list`**: provide the properties you want to fetch. Using `list(node_type.properties.keys())` pulls all properties for that type. - **`batch_size`**: controls page size when scanning nodes; tune it based on data volume and memory. - The edge scan yields tuples `(src, edge, dst)` describing the full relation. +- Remember to close the client when done to release resources. diff --git a/example.py b/example.py index ee1e8165..158aae7c 100644 --- a/example.py +++ b/example.py @@ -1,196 +1,234 @@ -from nebulagraph_python import NebulaAsyncClient, SessionConfig, SessionPoolConfig +from nebulagraph_python import NebulaClient, NebulaPool, NebulaPoolConfig -async def async_client_example(): - # Create client - client = await NebulaAsyncClient.connect( - hosts=["127.0.0.1:9669", "127.0.0.1:9670"], - username="root", - password="NebulaGraph01", - session_config=SessionConfig( - graph="movie", - timezone="Asia/Shanghai", - values={"a": "1", "b": "[1, 2, 3]"}, - ), - ) +def sync_client_example(): + """Example using synchronous NebulaClient""" - (await client.execute_py("RETURN $a, $b")).print() - (await client.execute_py("SHOW CURRENT_SESSION")).print() - (await client.execute_py("DESC GRAPH TYPE movie_type")).print() - - query = """ - USE movie - MATCH p=(a:Movie{name:"Unpromised Land"})-[e:WithGenre]->(b:Genre) - RETURN p as path, e as edge_WithGenre, b as genre_node, a.name as movie_name, 3.14 as float_val, true as bool_val - LIMIT 2 - """ - # Execute query - result = await client.execute(query) - - # Print results - result.print() - - # Convert to pandas DataFrame - # df = result.as_pandas_df() - # df.to_csv("query_result.csv", index=False) - - # Get one row - result = await client.execute(query) - row = result.one() - - # Get column names - print(row.column_names) - - # Get one value - cell = row.col_values[0].cast() - print(type(cell), cell) - - # Cast to primitive - cell_primitive = row.col_values[0].cast_primitive() - print(type(cell_primitive), cell_primitive) - - ###### - # special value type example - ###### - - query = """ - RETURN local_datetime("2016-09-20T01:01:01", "%Y-%m-%dT%H:%M:%S") AS localdatetime, - local_time("05:06:07.089", "%H:%M:%S") AS localtime, - zoned_time("05:06:07.089 +08:00", "%H:%M:%S %Ez") AS zonetime, - zoned_datetime("2016-09-20T01:01:01 +0800", "%Y-%m-%dT%H:%M:%S %z") AS zoneddatetime, - date("Tue, 2016-09-20", "%a, %Y-%m-%d") AS d, - RECORD {a: 1, b: true, c: "str literal"} AS record1, - LIST [1, 2, 3, 4, 5] AS l, - "str literal" AS str_literal - """ - - (await client.execute_py(query)).print() - - ###### - # execute_py example - ###### - - query = """ - RETURN {{v1}} as v1, {{v2}} as v2, {{v3}} as v3 - """ - args = {"v1": 1, "v2": "alice", "v3": [True, False, True]} - - res = await client.execute_py(query, args) - # get the first row in 
primitive type - row = res.one().as_primitive() - res.print() - # assert the row is the same as the args, in python primitive type - assert row == args - # get the result in column-oriented primitive type - print(res.as_primitive_by_column()) - # get the result in row-oriented primitive type - print(list(res.as_primitive_by_row())) - - ###### - # embedding vector example - ###### - - # FOR DDL, DML refer to ann.feature - - # Query KNN - await client.execute_py( - """ - CREATE GRAPH TYPE IF NOT EXISTS ann_test_type { - NODE N1 (:N1&N2{ - idx INT64 PRIMARY KEY, - vec1 VECTOR<3, FLOAT> - }) - } - """ + # Create client using direct initialization + # Note: NebulaClient automatically establishes connection in __init__ + # and inherits from NebulaBaseExecutor, so it can use execute_py() directly + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="root", + password="nebula", + connect_timeout_ms=3000, + request_timeout_ms=30000 ) - await client.execute_py( - """ - CREATE GRAPH IF NOT EXISTS ann_test ann_test_type - """ - ) + try: + print(f"Connected to NebulaGraph at {client.get_host()}") + print(f"Session ID: {client.get_session_id()}") + print(f"Server version: {client.get_version()}") - await client.execute_py( - """ - USE ann_test INSERT OR REPLACE (@N1 {idx: 1, vec1: vector<3, float>([1, 2, 3])}) - """ - ) + # Execute simple query + result = client.execute("SHOW HOSTS") + print("SHOW HOSTS result:") + result.print() - await client.execute_py( + # Execute query with custom timeout + result = client.execute_with_timeout("SHOW SPACES", 5000) + print("SHOW SPACES result:") + result.print() + + # Test ping + is_alive = client.ping() + print(f"Server is alive: {is_alive}") + + # execute_py example (NebulaClient now supports this directly) + query = """ + RETURN {{v1}} as v1, {{v2}} as v2, {{v3}} as v3 """ - USE ann_test INSERT OR REPLACE (@N1 {idx: 2, vec1: vector<3, float>([4, 5, 6])}) - """ - ) + args = {"v1": 1, "v2": "alice", "v3": [True, False, True]} + + res = client.execute_py(query, args) + print("execute_py result:") + res.print() - query = """ - USE ann_test - MATCH (v:N1|N2) - ORDER BY vector_distance(vector<3, float>([1, 2, 3]), v.vec1) LIMIT 3 - RETURN v, v.vec1 as vec1 - """ + # Get the first row in primitive type + row = res.one().as_primitive() + assert row == args + print(f"Row as primitive: {row}") - (await client.execute_py(query)).print() + # Get result in column-oriented primitive type + print(f"Column-oriented: {res.as_primitive_by_column()}") - await client.close() # Explicitly close the client to release all resources + # Get result in row-oriented primitive type + print(f"Row-oriented: {list(res.as_primitive_by_row())}") + except Exception as e: + print(f"Error: {e}") + finally: + client.close() + print("Client closed\n") -async def async_session_pool_example(): - """In this example we will create a client with session pool to execute queries async concurrently""" - from asyncio import gather - client = await NebulaAsyncClient.connect( - hosts=["127.0.0.1:9669", "127.0.0.1:9670"], +def async_client_example(): + """Example using AsyncNebulaClient""" + import asyncio + from nebulagraph_python.client import AsyncNebulaClient + + async def run(): + # Create async client using direct initialization + # Note: AsyncNebulaClient does NOT automatically connect in __init__ + # You must call _init_client() manually before using the client + client = AsyncNebulaClient( + addresses="127.0.0.1:9669", + user_name="root", + password="nebula", + connect_timeout_ms=3000, + 
request_timeout_ms=30000 + ) + + try: + # Initialize connection (required for AsyncNebulaClient) + await client._init_client() + print(f"Connected to NebulaGraph at {client.get_host()}") + print(f"Session ID: {client.get_session_id()}") + print(f"Server version: {client.get_version()}") + + # Execute simple query + result = await client.execute("SHOW HOSTS") + print("SHOW HOSTS result:") + result.print() + + # Execute query with custom timeout + result = await client.execute_with_timeout("SHOW SPACES", 5000) + print("SHOW SPACES result:") + result.print() + + # Test ping + is_alive = await client.ping() + print(f"Server is alive: {is_alive}") + + # Execute multiple queries concurrently + queries = [ + "SHOW HOSTS", + "SHOW SPACES", + "SHOW TAGS", + ] + + results = await asyncio.gather(*[client.execute(q) for q in queries]) + for i, (query, result) in enumerate(zip(queries, results), 1): + print(f"Query {i} ({query}): {result.row_size()} rows") + + except Exception as e: + print(f"Error: {e}") + finally: + await client.close() + print("Async client closed\n") + + asyncio.run(run()) + + +def pool_example(): + """Example using NebulaPool for connection pooling""" + + # Create pool configuration + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", username="root", - password="NebulaGraph01", - session_pool_config=SessionPoolConfig(), # Add the session pool config to use session pool + password="nebula", + max_client_size=10, + min_client_size=2, + max_wait_ms=5000, + graph="movie", # Optional: set default graph ) - tasks = [client.execute_py("RETURN {{idx}}", {"idx": x}) for x in range(8)] - results = await gather(*tasks) - for result in results: - print(result.as_primitive_by_column()) - await client.close() # Explicitly close to release all resources + pool = NebulaPool(config) - # Using context manager to automatically close the client - async with await NebulaAsyncClient.connect( - hosts=["127.0.0.1:9669"], - username="root", - password="NebulaGraph01", - ) as client: - (await client.execute_py("RETURN 1")).print() + try: + print("Pool created successfully") + print(f"Active sessions: {pool.get_active_sessions()}") + print(f"Idle sessions: {pool.get_idle_sessions()}") + + # Get a client from the pool + client = pool.get_client() + print(f"Got client, session ID: {client.get_session_id()}") + print(f"Active sessions: {pool.get_active_sessions()}") -def sync_session_pool_example(): - """In this example we will create a client with session pool to execute queries multi-threaded concurrently""" + # Execute queries + result = client.execute("SHOW HOSTS") + print("Query result:") + result.print() + + # Use execute_py (NebulaClient now supports this directly) + res = client.execute_py("RETURN 1 AS num") + print("execute_py result:") + res.print() + + # Return the client to the pool + pool.return_client(client) + print(f"Returned client, idle sessions: {pool.get_idle_sessions()}") + + # Get another client + client2 = pool.get_client() + print(f"Got another client, session ID: {client2.get_session_id()}") + + result = client2.execute("SHOW SPACES") + result.print() + + pool.return_client(client2) + + except Exception as e: + print(f"Error: {e}") + finally: + pool.close() + print("Pool closed\n") + + +def multi_threaded_pool_example(): + """Example using NebulaPool with multiple threads""" from concurrent.futures import ThreadPoolExecutor + from nebulagraph_python import NebulaPool, NebulaPoolConfig + + def query_task(pool, idx): + """Task to execute a query""" + client = pool.get_client() + 
try: + result = client.execute(f"RETURN {idx} AS num") + return result.as_primitive_by_column() + finally: + pool.return_client(client) + + # Create pool + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="root", + password="nebula", + max_client_size=10, + min_client_size=2, + ) - from nebulagraph_python import NebulaClient, SessionPoolConfig - - # Using context manager to automatically close the client - with ( - NebulaClient( - hosts=["127.0.0.1:9669"], - username="root", - password="NebulaGraph01", - session_pool_config=SessionPoolConfig(), # Add the session pool config to use session pool - ) as client, - ThreadPoolExecutor(max_workers=8) as executor, - ): - futures = [ - executor.submit(client.execute_py, "RETURN {{idx}}", {"idx": x}) - for x in range(8) - ] - for future in futures: - print(future.result().as_primitive_by_column()) + with NebulaPool(config) as pool: + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(query_task, pool, i) for i in range(10)] + for future in futures: + print(future.result()) if __name__ == "__main__": - import asyncio import logging - logging.basicConfig(level=logging.DEBUG) - logging.getLogger("nebulagraph_python").setLevel(logging.DEBUG) + logging.basicConfig(level=logging.INFO) + logging.getLogger("nebulagraph_python").setLevel(logging.INFO) + + print("=" * 60) + print("Synchronous Client Example") + print("=" * 60) + sync_client_example() + + print("=" * 60) + print("Asynchronous Client Example") + print("=" * 60) + async_client_example() + + print("=" * 60) + print("Connection Pool Example") + print("=" * 60) + pool_example() - asyncio.run(async_client_example()) - asyncio.run(async_session_pool_example()) - sync_session_pool_example() + print("=" * 60) + print("Multi-threaded Pool Example") + print("=" * 60) + multi_threaded_pool_example() diff --git a/example/NebulaPoolExample.py b/example/NebulaPoolExample.py index 575eaf77..0eb5523a 100755 --- a/example/NebulaPoolExample.py +++ b/example/NebulaPoolExample.py @@ -2,17 +2,8 @@ # -*- coding: utf-8 -*- from typing import Optional, Dict -from nebulagraph_python.client.pool import NebulaPool +from nebulagraph_python.client.nebula_pool import NebulaPool, NebulaPoolConfig, SessionConfig from nebulagraph_python.data import HostAddress -from dataclasses import dataclass, field - -@dataclass -class SessionConfig: - schema: Optional[str] = None - graph: Optional[str] = None - timezone: Optional[str] = None - values: Dict[str, str] = field(default_factory=dict) - configs: Dict[str, str] = field(default_factory=dict) graph_name = "test_graph" diff --git a/example_async_usage.py b/example_async_usage.py new file mode 100644 index 00000000..9d4bbece --- /dev/null +++ b/example_async_usage.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +""" +Example demonstrating the usage of AsyncNebulaClient + +This example shows how to use the async interface for NebulaGraph operations. 
+""" + +import asyncio +from nebulagraph_python.client import AsyncNebulaClient + + +async def main(): + """Main async function demonstrating AsyncNebulaClient usage""" + + # Method 1: Direct initialization + print("=== Method 1: Direct Initialization ===") + client = AsyncNebulaClient( + addresses="127.0.0.1:9669", + user_name="root", + password="nebula", + connect_timeout_ms=3000, + request_timeout_ms=30000, + ) + + try: + # Initialize the connection (this is async) + await client._init_client() + print(f"✓ Connected to: {client.get_host()}") + print(f"✓ Session ID: {client.get_session_id()}") + print(f"✓ Server version: {client.get_version()}") + + # Execute a query + result = await client.execute("SHOW HOSTS") + print(f"✓ Query executed successfully, rows: {result.row_size()}") + + # Execute with custom timeout + result = await client.execute_with_timeout("SHOW SPACES", 5000) + print(f"✓ Spaces query executed, rows: {result.row_size()}") + + # Ping the server + is_alive = await client.ping() + print(f"✓ Server is alive: {is_alive}") + + except Exception as e: + print(f"✗ Error: {e}") + finally: + await client.close() + print("✓ Client closed\n") + + # Method 2: Direct initialization with TLS disabled + print("=== Method 2: Direct Initialization with TLS ===") + client = AsyncNebulaClient( + addresses="127.0.0.1:9669", + user_name="root", + password="nebula", + connect_timeout_ms=3000, + request_timeout_ms=30000, + enable_tls=False + ) + + try: + await client._init_client() + print(f"✓ Connected to: {client.get_host()}") + + # Execute multiple queries concurrently + queries = [ + "SHOW HOSTS", + "SHOW SPACES", + "SHOW TAGS", + ] + + results = await asyncio.gather(*[client.execute(q) for q in queries]) + for i, (query, result) in enumerate(zip(queries, results), 1): + print(f"✓ Query {i} ({query}): {result.row_size()} rows") + + except Exception as e: + print(f"✗ Error: {e}") + finally: + await client.close() + print("✓ Client closed\n") + print("✓ Client closed\n") + + # Method 3: Async context manager pattern (you can create your own) + print("=== Method 3: Async Context Manager Pattern ===") + + async with AsyncNebulaClient( + addresses="127.0.0.1:9669", + user_name="root", + password="nebula" + ) as client: + await client._init_client() + print(f"✓ Connected with context manager: {client.get_host()}") + + # Simple query + result = await client.execute("SHOW HOSTS") + print(f"✓ Result: {result.row_size()} rows") + + +# Add async context manager support to AsyncNebulaClient +async def async_context_manager_example(): + """Example of creating an async context manager""" + + async def get_async_client(addresses, user_name, password): + """Factory function for creating and initializing an async client""" + client = AsyncNebulaClient( + addresses=addresses, + user_name=user_name, + password=password + ) + await client._init_client() + return client + + # You can also use this pattern: + async def with_async_client(addresses, user_name, password, callback): + """Execute a callback with an async client""" + client = AsyncNebulaClient( + addresses=addresses, + user_name=user_name, + password=password + ) + try: + await client._init_client() + return await callback(client) + finally: + await client.close() + + # Usage example + async def query_callback(client): + """Callback to execute queries""" + result = await client.execute("SHOW HOSTS") + return result.row_size() + + try: + row_count = await with_async_client( + "127.0.0.1:9669", + "root", + "nebula", + query_callback + ) + print(f"✓ 
Callback pattern: {row_count} rows") + except Exception as e: + print(f"✗ Callback pattern error (expected if server not running): {e}") + + +if __name__ == "__main__": + print("AsyncNebulaClient Usage Examples") + print("=" * 60) + print() + + try: + # Run the main example + asyncio.run(main()) + + # Run the context manager example + asyncio.run(async_context_manager_example()) + + except Exception as e: + print(f"\nNote: Make sure NebulaGraph server is running at 127.0.0.1:9669") + print(f"Error details: {e}") \ No newline at end of file diff --git a/pdm.lock b/pdm.lock index 5b4a0971..cbda8db5 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "all", "console", "full"] strategy = [] lock_version = "4.4" -content_hash = "sha256:8b25d18ccf0074597c9efe1596774a54ae221562ec27393cff85396a7252b648" +content_hash = "sha256:bb5c7603bb95487b552e303cd016f871b518cc1cc1d7c23839c26ec0876cf286" [[package]] name = "annotated-types" @@ -166,6 +166,14 @@ files = [ {file = "pluggy-1.6.0.tar.gz", hash = "sha256:7dcc130b76258d33b90f61b658791dede3486c3e6bfb003ee5c9bfb396dd22f3"}, ] +[[package]] +name = "pool" +version = "0.1.2dev" +summary = "general purposed connection pool for gevent, split from sqlalchemy" +files = [ + {file = "pool-0.1.2dev.tar.gz", hash = "sha256:99a3aefa842c5bb87c92a3bbd40cacab212f104268a8fc6e706a2df92002f785"}, +] + [[package]] name = "prompt-toolkit" version = "3.0.51" diff --git a/pyproject.toml b/pyproject.toml index 3f7a8431..f51f521d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,7 @@ dependencies = [ "anyio>=4.9.0", "minijinja>=2.12.0", "pytest>=9.0.2", + "pool>=0.1.2.dev0", ] requires-python = ">=3.10" license = {text = "Apache-2.0"} diff --git a/src/nebulagraph_python/__init__.py b/src/nebulagraph_python/__init__.py index fb2f7e9f..2fdf8fe0 100644 --- a/src/nebulagraph_python/__init__.py +++ b/src/nebulagraph_python/__init__.py @@ -14,27 +14,23 @@ from .client import ( ConnectionConfig, - NebulaAsyncClient, NebulaBaseAsyncExecutor, NebulaBaseExecutor, NebulaClient, NebulaPool, - SessionConfig, - SessionPoolConfig, + NebulaPoolConfig, unwrap_value, ) from .result_set import Record, ResultSet __all__ = [ "ConnectionConfig", - "NebulaAsyncClient", "NebulaBaseAsyncExecutor", "NebulaBaseExecutor", "NebulaClient", "NebulaPool", + "NebulaPoolConfig", "Record", "ResultSet", - "SessionConfig", - "SessionPoolConfig", "unwrap_value", ] diff --git a/src/nebulagraph_python/client/__init__.py b/src/nebulagraph_python/client/__init__.py index 5553f42a..966de877 100644 --- a/src/nebulagraph_python/client/__init__.py +++ b/src/nebulagraph_python/client/__init__.py @@ -13,26 +13,27 @@ # limitations under the License. 
from nebulagraph_python.client._connection import ( + AsyncConnection, ConnectionConfig, + _parse_hosts, ) -from nebulagraph_python.client._session import SessionConfig -from nebulagraph_python.client._session_pool import SessionPoolConfig from nebulagraph_python.client.base_executor import ( NebulaBaseAsyncExecutor, NebulaBaseExecutor, unwrap_value, ) -from nebulagraph_python.client.client import NebulaAsyncClient, NebulaClient -from nebulagraph_python.client.pool import NebulaPool +from nebulagraph_python.client.nebula_client import AsyncNebulaClient, NebulaClient +from nebulagraph_python.client.nebula_pool import NebulaPool, NebulaPoolConfig __all__ = [ + "AsyncConnection", + "AsyncNebulaClient", "ConnectionConfig", - "NebulaAsyncClient", "NebulaBaseAsyncExecutor", "NebulaBaseExecutor", "NebulaClient", "NebulaPool", - "SessionConfig", - "SessionPoolConfig", + "NebulaPoolConfig", "unwrap_value", + "_parse_hosts", ] diff --git a/src/nebulagraph_python/client/_connection.py b/src/nebulagraph_python/client/_connection.py index ced1580f..f1525070 100644 --- a/src/nebulagraph_python/client/_connection.py +++ b/src/nebulagraph_python/client/_connection.py @@ -12,57 +12,45 @@ # See the License for the specific language governing permissions and # limitations under the License. +"""Connection classes matching Java implementation""" + +import asyncio import json +import logging +from abc import ABC, abstractmethod from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Literal, Optional, Union -import anyio import grpc import grpc.aio +from nebulagraph_python.proto import ( + common_pb2, + graph_pb2, + graph_pb2_grpc, +) -from nebulagraph_python.client import constants -from nebulagraph_python.client.logger import logger +from nebulagraph_python.client.auth_result import AuthResult +from nebulagraph_python.client.constants import DEFAULT_CONNECT_TIMEOUT_MS, DEFAULT_REQUEST_TIMEOUT_MS from nebulagraph_python.data import HostAddress, SSLParam from nebulagraph_python.error import ( AuthenticatingError, - ConnectingError, ErrorCode, ExecutingError, - InternalError, - NebulaGraphRemoteError, ) -from nebulagraph_python.proto import ( - common_pb2, - graph_pb2, - graph_pb2_grpc, -) -from nebulagraph_python.result_set import ResultSet if TYPE_CHECKING: - from nebulagraph_python.client._session import SessionConfig + from nebulagraph_python.client.nebula_client import NebulaClient - -def _parse_hosts(hosts: Union[str, List[str], List[HostAddress]]) -> List[HostAddress]: - """Convert various host formats to list of HostAddress objects""" - if isinstance(hosts, str): - hosts = hosts.split(",") - - addresses = [] - for host in hosts: - if isinstance(host, HostAddress): - addresses.append(host) - else: - addr, port = host.split(":") - addresses.append(HostAddress(addr, int(port))) - return addresses +logger = logging.getLogger(__name__) @dataclass class ConnectionConfig: + """Configuration for connections (backward compatibility)""" hosts: List[HostAddress] = field(default_factory=list) ssl_param: Optional[SSLParam] = None - connect_timeout: Optional[float] = constants.DEFAULT_CONNECT_TIMEOUT - request_timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT + connect_timeout: Optional[float] = 3.0 + request_timeout: Optional[float] = 60.0 ping_before_execute: bool = False @classmethod @@ -70,8 +58,8 @@ def from_defaults( cls, hosts: Union[str, List[str], List[HostAddress]], ssl_param: 
Union[SSLParam, Literal[True], None] = None, - connect_timeout: Optional[float] = constants.DEFAULT_CONNECT_TIMEOUT, - request_timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT, + connect_timeout: Optional[float] = 3.0, + request_timeout: Optional[float] = 60.0, ): if ssl_param is True: ssl_param = SSLParam() @@ -87,432 +75,338 @@ def __post_init__(self): raise ValueError("hosts cannot be empty") -@dataclass -class Connection: - """Represents a connection to a NebulaGraph server. It is built upon grpc and is thread-safe. - - Required to explicitly call `close()` to release all resources. - """ +def _parse_hosts(hosts: Union[str, List[str], List[HostAddress]]) -> List[HostAddress]: + """Convert various host formats to list of HostAddress objects (backward compatibility)""" + if isinstance(hosts, str): + hosts = hosts.split(",") - # Config - config: ConnectionConfig - # Track which host was successfully connected for session routing - connected: HostAddress | None = field(default=None, init=False) + addresses = [] + for host in hosts: + if isinstance(host, HostAddress): + addresses.append(host) + else: + addr, port = host.split(":") + addresses.append(HostAddress(addr, int(port))) + return addresses - # Owned Resources - _stub: Optional[graph_pb2_grpc.GraphServiceStub] = field(default=None, init=False) - _channel: Optional[grpc.Channel] = field(default=None, init=False) - def __post_init__(self): - self.connect() - - def connect(self): - """Establish connection to NebulaGraph""" - last_error: Optional[Exception] = None +class Connection(ABC): + """Abstract base class for connections, matching Java Connection""" + + server_addr: Optional[HostAddress] = None + + def get_server_address(self) -> Optional[HostAddress]: + """Get the server address""" + return self.server_addr + + @abstractmethod + def open(self, address: HostAddress, builder: "NebulaClient.Builder") -> None: + """Open connection to the server""" + pass + + @abstractmethod + def close(self) -> None: + """Close the connection""" + pass + + @abstractmethod + def ping(self, session_id: int, timeout_ms: int) -> bool: + """Ping the server""" + pass + + +class GrpcConnection(Connection): + """gRPC connection implementation, matching Java GrpcConnection""" + + def __init__(self): + self.channel: Optional[grpc.Channel] = None + self.stub: Optional[graph_pb2_grpc.GraphServiceStub] = None + self.connect_timeout: int = 0 + self.request_timeout: int = 0 + + def open(self, address: HostAddress, client: "NebulaClient") -> None: + """Open gRPC connection to the server""" + self.server_addr = address + self.connect_timeout = client.connect_timeout_mills + self.request_timeout = client.request_timeout_mills + + formatted_host = address.host + if ":" in formatted_host and not formatted_host.startswith("["): + formatted_host = f"[{formatted_host}]" + + channel_options = [ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ] + + if client.enable_tls: + ssl_param = client.ssl_param + if ssl_param is None: + ssl_param = SSLParam() + + self.channel = grpc.secure_channel( + f"{formatted_host}:{address.port}", + credentials=grpc.ssl_channel_credentials( + root_certificates=ssl_param.ca_crt, + private_key=ssl_param.private_key, + certificate_chain=ssl_param.cert, + ), + options=channel_options, + ) + else: + self.channel = grpc.insecure_channel( + f"{formatted_host}:{address.port}", + options=channel_options, + ) - # Try each address until one succeeds - for host_addr in self.config.hosts: + # Wait for channel to 
be ready + if self.connect_timeout > 0: try: - channel_options = [ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ("grpc.enable_deadline_checking", 1), - ] - - if self.config.ssl_param: - self._channel = grpc.secure_channel( - f"{host_addr.host}:{host_addr.port}", - credentials=grpc.ssl_channel_credentials( - root_certificates=self.config.ssl_param.ca_crt, - private_key=self.config.ssl_param.private_key, - certificate_chain=self.config.ssl_param.cert, - ), - options=channel_options, - ) - else: - self._channel = grpc.insecure_channel( - f"{host_addr.host}:{host_addr.port}", - options=channel_options, - ) - - # Wait for channel to be ready with timeout - if self.config.connect_timeout is not None: - try: - grpc.channel_ready_future(self._channel).result( - timeout=self.config.connect_timeout - ) - except grpc.FutureTimeoutError as e: - raise ConnectingError( - f"Connection timeout after {self.config.connect_timeout} seconds to {host_addr.host}:{host_addr.port}" - ) from e - else: - grpc.channel_ready_future( - self._channel - ).result() # Wait indefinitely if no timeout - - self._stub = graph_pb2_grpc.GraphServiceStub(self._channel) - logger.info( - f"Successfully connected to {host_addr.host}:{host_addr.port}." + grpc.channel_ready_future(self.channel).result( + timeout=self.connect_timeout / 1000.0 ) - # Remember which host we actually connected to - self.connected = host_addr - return - except Exception as e: - logger.warning( - f"Failed to connect to {(host_addr.host, host_addr.port) if host_addr else 'No Available Addr'}: {e}", + except grpc.FutureTimeoutError: + raise ExecutingError( + f"Connection timeout after {self.connect_timeout}ms to {address}" ) - last_error = e - self.close() - else: - return - - # If we get here, all connection attempts failed - raise ConnectingError( - f"Failed to connect to any of the provided hosts. Last error: {last_error}", - ) - - def close(self): - """Close the connection. No Exception will be raised but an error will be logged.""" - try: - if self._channel: - self._channel.close() - self._channel = None - self._stub = None - self.connected = None - except Exception: - logger.exception("Failed to close connection") - - def reconnect(self): - self.close() - self.connect() - - def ping(self) -> bool: - """Ping the connection to check if it's healthy. - - Returns: - True if the connection is healthy, False otherwise. 
- """ - if not self._stub: - return False - try: - request = graph_pb2.ExecuteRequest( - session_id=-1, - stmt="RETURN 1".encode("utf-8"), - ) - _response = self._stub.Execute(request, timeout=self.config.connect_timeout) - return True - except Exception: - return False - - def execute( - self, - session_id: int, - statement: str, - *, - timeout: Optional[float] = None, - do_ping: bool = False, - ) -> ResultSet: - # Retry connection if ping fails for only one time - if (self.config.ping_before_execute or do_ping) and not self.ping(): - self.close() - self.connect() - if not self._stub: - raise InternalError("Connection not established") - logger.debug(f"Executing in Hosts: {self.config.hosts}") - logger.debug(f"Executing statement: {statement}") - - try: - request = graph_pb2.ExecuteRequest( - session_id=session_id, - stmt=statement.encode("utf-8"), - ) - logger.debug(f"Request: {request}") - # Use request_timeout as default if timeout is not specified - effective_timeout = ( - timeout if timeout is not None else self.config.request_timeout - ) - response = self._stub.Execute(request, timeout=effective_timeout) - logger.debug(f"Response: {response}") - except grpc.RpcError as e: - logger.error(f"RPC error during execute: {e.code()} {e.details()}") - raise ExecutingError(f"RPC error: {e.details()}") from e - except Exception as e: - logger.error(f"Unexpected error during execute: {e}") - raise ExecutingError("Unexpected error during execute") from e - - return ResultSet(response) + self.stub = graph_pb2_grpc.GraphServiceStub(self.channel) + + def close(self) -> None: + """Close the gRPC connection""" + if self.channel is not None: + self.channel.close() + self.channel = None + self.stub = None + + def ping(self, session_id: int, timeout_ms: int) -> bool: + """Ping the server""" + response = self.execute(session_id, "RETURN 1", timeout_ms) + return ( + response.status.code == b"00000" + if hasattr(response, "status") + else True + ) def authenticate( - self, - username: str, - password: Optional[str] = None, - *, - auth_options: Optional[Dict[str, Any]] = None, - session_config: Optional["SessionConfig"] = None, - ) -> int: - """Authenticate with NebulaGraph and return session ID. 
May raise Exception when authentication failed.""" - from nebulagraph_python.client._session import SessionConfig, init_session - - if not self._stub: - raise InternalError("Connection not established") - - _auth_options = auth_options or {} - _session_config = session_config or SessionConfig() + self, user: str, auth_options: Dict[str, object] + ) -> AuthResult: + """Authenticate with the server""" + if self.stub is None: + raise ExecutingError("Connection not established") client_info = common_pb2.ClientInfo( lang=common_pb2.ClientInfo.PYTHON, protocol_version=b"5.0.0", ) - auth_info_dict = ( - {"password": password, **_auth_options} if password else _auth_options - ) - auth_info_bytes = json.dumps(auth_info_dict).encode("utf-8") + user_bytes = user.encode("utf-8") if user else b"" + auth_info_bytes = json.dumps(auth_options).encode("utf-8") request = graph_pb2.AuthRequest( - username=username.encode("utf-8"), + username=user_bytes, auth_info=auth_info_bytes, client_info=client_info, ) try: - response = self._stub.Authenticate( - request, timeout=self.config.request_timeout + response = self.stub.Authenticate( + request, timeout=self.connect_timeout / 1000.0 ) except grpc.RpcError as e: - logger.error(f"RPC error during authenticate: {e.code()} {e.details()}") - raise AuthenticatingError( - f"RPC error during authentication: {e.details()}" - ) from e - except Exception as e: - logger.error(f"Unexpected error during authenticate: {e}") - raise AuthenticatingError("Unexpected error during authentication") from e + self.close() + if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + raise AuthenticatingError( + f"authenticate to {self.server_addr} timeout after {self.connect_timeout}ms" + ) + raise AuthenticatingError(f"RPC error: {e.details()}") if response.status.code != b"00000": - raise NebulaGraphRemoteError( - code=ErrorCode(response.status.code.decode("utf-8")), - message=response.status.message.decode("utf-8"), + self.close() + raise AuthenticatingError( + response.status.message.decode("utf-8") ) - # Initialize session and return session ID - init_session(self, int(response.session_id), _session_config) - return int(response.session_id) + return AuthResult( + session_id=int(response.session_id), + version=response.version.decode("utf-8"), + ) + def execute( + self, session_id: int, stmt: str, timeout: int + ) -> graph_pb2.ExecuteResponse: + """Execute a statement""" + if stmt is None: + raise ValueError("statement is null") -@dataclass -class AsyncConnection: - """Represents a connection to a NebulaGraph server. It is built upon grpc.aio and is async/coroutine-level safe but not thread-safe. + if self.stub is None: + raise ExecutingError("Connection not established") + + request = graph_pb2.ExecuteRequest( + session_id=session_id, stmt=stmt.encode("utf-8") + ) - Required to explicitly call `close()` to release all resources. 
- """ + try: + return self.stub.Execute(request, timeout=timeout / 1000.0) + except grpc.RpcError as e: + if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + raise ExecutingError( + f"request to {self.server_addr} timeout after {timeout}ms" + ) + raise ExecutingError(f"RPC error: {e.details()}") - config: ConnectionConfig - connected: HostAddress | None = None - _stub: Optional[graph_pb2_grpc.GraphServiceStub] = field(default=None, init=False) - _channel: Optional[grpc.aio.Channel] = field( - default=None, init=False - ) # Use grpc.aio.Channel + def execute_default_timeout( + self, session_id: int, stmt: str + ) -> graph_pb2.ExecuteResponse: + """Execute a statement with default timeout""" + return self.execute(session_id, stmt, self.request_timeout) - # Note: __post_init__ cannot be async. - # An async factory method (e.g., AsyncConnection.create(...)) or an explicit await self.connect() - # after __init__ would be needed. For now, connect will be called separately. - async def connect(self): - last_error: Optional[Exception] = None +class AsyncConnection: + """Async gRPC connection for backward compatibility""" + + def __init__(self, config): + self.config = config + self.server_addr: Optional[HostAddress] = None + self._stub: Optional[graph_pb2_grpc.GraphServiceStub] = None + self._channel: Optional[grpc.aio.Channel] = None + self.connect_timeout: float = config.connect_timeout + self.request_timeout: float = config.request_timeout + + async def connect(self, address: HostAddress) -> None: + """Open async gRPC connection to the server""" + self.server_addr = address + + formatted_host = address.host + if ":" in formatted_host and not formatted_host.startswith("["): + formatted_host = f"[{formatted_host}]" + + channel_options = [ + ("grpc.max_send_message_length", -1), + ("grpc.max_receive_message_length", -1), + ] + + if self.config.ssl_param: + ssl_param = self.config.ssl_param + self._channel = grpc.aio.secure_channel( + f"{formatted_host}:{address.port}", + credentials=grpc.ssl_channel_credentials( + root_certificates=ssl_param.ca_crt, + private_key=ssl_param.private_key, + certificate_chain=ssl_param.cert, + ), + options=channel_options, + ) + else: + self._channel = grpc.aio.insecure_channel( + f"{formatted_host}:{address.port}", + options=channel_options, + ) - for host_addr in self.config.hosts: + # Wait for channel to be ready + if self.connect_timeout > 0: try: - channel_options = [ - ("grpc.max_send_message_length", -1), - ("grpc.max_receive_message_length", -1), - ("grpc.enable_deadline_checking", 1), # Deadline checking is good - ] - - if self.config.ssl_param: - self._channel = grpc.aio.secure_channel( - f"{host_addr.host}:{host_addr.port}", - credentials=grpc.ssl_channel_credentials( - root_certificates=self.config.ssl_param.ca_crt, - private_key=self.config.ssl_param.private_key, - certificate_chain=self.config.ssl_param.cert, - ), - options=channel_options, - ) - else: - self._channel = grpc.aio.insecure_channel( - f"{host_addr.host}:{host_addr.port}", - options=channel_options, - ) - - if self.config.connect_timeout is not None: - try: - with anyio.fail_after( - self.config.connect_timeout, - ): - await self._channel.channel_ready() - except TimeoutError as e: - raise ConnectingError( - f"Connection timeout after {self.config.connect_timeout} seconds to {host_addr.host}:{host_addr.port}" - ) from e - else: - await ( - self._channel.channel_ready() - ) # Wait indefinitely if no timeout - - self._stub = graph_pb2_grpc.GraphServiceStub(self._channel) - logger.info( - 
f"Successfully connected to {host_addr.host}:{host_addr.port} asynchronously." + await asyncio.wait_for( + self._channel.channel_ready(), + timeout=self.connect_timeout ) - self.connected = host_addr - return - except Exception as e: - logger.warning( - f"Failed to connect asynchronously to {(host_addr.host, host_addr.port) if host_addr else 'No Available Addr'}: {e}", + except asyncio.TimeoutError: + await self.close() + raise ExecutingError( + f"Connection timeout after {self.connect_timeout}s to {address}" ) - last_error = e - if self._channel: # Ensure channel is closed on partial failure before trying next host - await self._channel.close() - self._channel = None - self._stub = None # Also clear stub - - # If we get here, all connection attempts failed - raise ConnectingError( - f"Failed to connect asynchronously to any of the provided hosts. Last error: {last_error}", - ) - async def close(self): - try: - if self._channel: - await self._channel.close() - self._channel = None - self._stub = None - self.connected = None - except BaseException: - logger.exception("Failed to close async connection") - - async def reconnect(self): - await self.close() - await self.connect() - - async def ping(self) -> bool: - """Ping the connection to check if it's healthy. - - Returns: - True if the connection is healthy, False otherwise. - """ - if not self._stub: - return False - try: - request = graph_pb2.ExecuteRequest( - session_id=-1, - stmt="RETURN 1".encode("utf-8"), - ) - _response = await self._stub.Execute( - request, timeout=self.config.connect_timeout - ) - return True - except Exception: - return False + self._stub = graph_pb2_grpc.GraphServiceStub(self._channel) + + async def close(self) -> None: + """Close the async gRPC connection""" + if self._channel is not None: + await self._channel.close() + self._channel = None + self._stub = None + + async def ping(self, session_id: int, timeout_ms: int) -> bool: + """Ping the server""" + response = await self.execute(session_id, "RETURN 1", timeout_ms) + return ( + response.status.code == b"00000" + if hasattr(response, "status") + else True + ) async def execute( - self, - session_id: int, - statement: str, - *, - timeout: Optional[float] = None, - do_ping: bool = False, - ) -> ResultSet: - # Retry connection if ping fails for only one time - if (self.config.ping_before_execute or do_ping) and not await self.ping(): - await self.close() - await self.connect() - if not self._stub: - raise InternalError("Async connection not established or stub is missing.") + self, session_id: int, stmt: str, timeout: int + ) -> graph_pb2.ExecuteResponse: + """Execute a statement""" + if stmt is None: + raise ValueError("statement is null") + + if self._stub is None: + raise ExecutingError("Connection not established") - logger.debug(f"Executing in Hosts: {self.config.hosts}") - logger.debug(f"Async executing statement: {statement}") + request = graph_pb2.ExecuteRequest( + session_id=session_id, stmt=stmt.encode("utf-8") + ) try: - request = graph_pb2.ExecuteRequest( - session_id=session_id, - stmt=statement.encode("utf-8"), + return await self._stub.Execute( + request, + timeout=timeout / 1000.0 ) - logger.debug(f"Async request: {request}") - effective_timeout = ( - timeout if timeout is not None else self.config.request_timeout - ) - # The stub call itself is now awaitable - response = await self._stub.Execute(request, timeout=effective_timeout) # type: ignore - logger.debug(f"Async response: {response}") - except grpc.aio.AioRpcError as e: # Catch async gRPC 
errors - # TODO: Map to specific Nebula errors like ExecutingError, AuthenticatingError - logger.error(f"Async RPC error during execute: {e.code()} {e.details()}") - raise ExecutingError(f"RPC error: {e.details()}") from e - except Exception as e: - logger.error(f"Unexpected error during async execute: {e}") - raise ExecutingError("Unexpected error during async execute") from e - - return ResultSet(response) # ResultSet creation should be the same - - async def authenticate( - self, - username: str, - password: Optional[str] = None, - *, - auth_options: Optional[Dict[str, Any]] = None, - session_config: Optional["SessionConfig"] = None, # Re-use SessionConfig - ) -> int: - from nebulagraph_python.client._session import SessionConfig, ainit_session + except grpc.aio.AioRpcError as e: + if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + raise ExecutingError( + f"request to {self.server_addr} timeout after {timeout}ms" + ) + raise ExecutingError(f"RPC error: {e.details()}") - if not self._stub: - raise InternalError("Async connection not established or stub is missing.") + async def execute_default_timeout( + self, session_id: int, stmt: str + ) -> graph_pb2.ExecuteResponse: + """Execute a statement with default timeout""" + return await self.execute(session_id, stmt, int(self.request_timeout * 1000)) - _auth_options = auth_options or {} - _session_config = session_config or SessionConfig() + async def authenticate( + self, user: str, auth_options: Dict[str, object] + ) -> AuthResult: + """Authenticate with the server""" + if self._stub is None: + raise ExecutingError("Connection not established") client_info = common_pb2.ClientInfo( lang=common_pb2.ClientInfo.PYTHON, - protocol_version=b"5.0.0", # Ensure this is up-to-date or configurable + protocol_version=b"5.0.0", ) - auth_info_dict = ( - {"password": password, **_auth_options} if password else _auth_options - ) - auth_info_bytes = json.dumps(auth_info_dict).encode("utf-8") + user_bytes = user.encode("utf-8") if user else b"" + auth_info_bytes = json.dumps(auth_options).encode("utf-8") request = graph_pb2.AuthRequest( - username=username.encode("utf-8"), + username=user_bytes, auth_info=auth_info_bytes, client_info=client_info, ) try: - # Use request_timeout as default if timeout is not specified for authenticate response = await self._stub.Authenticate( - request, timeout=self.config.request_timeout + request, + timeout=self.connect_timeout ) except grpc.aio.AioRpcError as e: - logger.error( - f"Async RPC error during authenticate: {e.code()} {e.details()}" - ) - raise AuthenticatingError( - f"RPC error during authentication: {e.details()}" - ) from e - except Exception as e: # Catch other potential errors - logger.error(f"Unexpected error during async authenticate: {e}") - raise AuthenticatingError( - "Unexpected error during async authentication" - ) from e + await self.close() + if e.code() == grpc.StatusCode.DEADLINE_EXCEEDED: + raise AuthenticatingError( + f"authenticate to {self.server_addr} timeout after {self.connect_timeout}s" + ) + raise AuthenticatingError(f"RPC error: {e.details()}") if response.status.code != b"00000": - raise NebulaGraphRemoteError( - code=ErrorCode(response.status.code.decode("utf-8")), - message=response.status.message.decode("utf-8"), + await self.close() + raise AuthenticatingError( + response.status.message.decode("utf-8") ) - # Create and return an AsyncSession instance - # The AsyncSession class will need to be defined. 
- # For now, let's create it and initialize its async parts (like setting session params). - await ainit_session(self, int(response.session_id), _session_config) - return int(response.session_id) + return AuthResult( + session_id=int(response.session_id), + version=response.version.decode("utf-8"), + ) \ No newline at end of file diff --git a/src/nebulagraph_python/client/_connection_pool.py b/src/nebulagraph_python/client/_connection_pool.py deleted file mode 100644 index 290f4f52..00000000 --- a/src/nebulagraph_python/client/_connection_pool.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright 2025 vesoft-inc -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import copy -import logging -import threading -from dataclasses import dataclass, field - -from anyio import Lock - -from nebulagraph_python.client._connection import ( - AsyncConnection, - Connection, - ConnectionConfig, -) -from nebulagraph_python.data import HostAddress -from nebulagraph_python.error import PoolError - -logger = logging.getLogger(__name__) - - -@dataclass -class AsyncConnectionPool: - """Manages a pool of AsyncConnections with one connection per address. - - Uses round-robin strategy for getting connections. - If ping is enabled in ConnectionConfig and fails, the connection will try to reconnect once. - If still fails, the next address will be tried, until all addresses are tried and failed. - - This pool is async/coroutine-level safe but not thread-safe. 
- """ - - conn_conf: ConnectionConfig - - _connections: dict[HostAddress, AsyncConnection] = field( - default_factory=dict, init=False - ) - _current_index: int = field(default=0, init=False) - _lock: Lock = field(default_factory=Lock, init=False) - - @property - def addresses(self) -> list[HostAddress]: - return self.conn_conf.hosts - - @property - def current_address(self) -> HostAddress: - return self.addresses[self._current_index] - - async def next_address(self) -> HostAddress: - async with self._lock: - self._current_index = (self._current_index + 1) % len(self.addresses) - return self.addresses[self._current_index] - - def __post_init__(self): - """Create a new connection for the specified host address, without connecting.""" - # Create a config with only this host - for host_addr in self.addresses: - copied_conf = copy.copy(self.conn_conf) - copied_conf.hosts = [host_addr] - copied_conf.ping_before_execute = ( - False # Because ping will be done when borrowing a connection - ) - conn = AsyncConnection(copied_conf) - self._connections[host_addr] = conn - - async def connect(self): - for conn in self._connections.values(): - await conn.connect() - - async def get_connection(self, host_addr: HostAddress) -> AsyncConnection | None: - conn = self._connections[host_addr] - if self.conn_conf.ping_before_execute and not await conn.ping(): - try: - await conn.reconnect() - except Exception: - logger.exception("Error reconnecting to server %s", host_addr) - return None - return self._connections[host_addr] - - async def next_connection(self) -> tuple[HostAddress, AsyncConnection]: - for _ in range(len(self.addresses)): - host_addr = await self.next_address() - conn = await self.get_connection(host_addr) - if conn is not None: - return host_addr, conn - else: - continue - raise PoolError("No connection available in the pool") - - async def close(self): - """Close all connections in the pool.""" - for conn in self._connections.values(): - await conn.close() - self._connections.clear() - - -@dataclass -class ConnectionPool: - """Manages a pool of Connections with one connection per address. - - Uses round-robin strategy for getting connections. - If ping is enabled in ConnectionConfig and fails, the connection will try to reconnect once. - If still fails, the next address will be tried, until all addresses are tried and failed. - - This pool is thread-safe. 
- """ - - conn_conf: ConnectionConfig - - _connections: dict[HostAddress, Connection] = field( - default_factory=dict, init=False - ) - _current_index: int = field(default=0, init=False) - _lock: threading.Lock = field(default_factory=threading.Lock, init=False) - - @property - def addresses(self) -> list[HostAddress]: - return self.conn_conf.hosts - - @property - def current_address(self) -> HostAddress: - return self.addresses[self._current_index] - - def next_address(self) -> HostAddress: - with self._lock: - self._current_index = (self._current_index + 1) % len(self.addresses) - return self.addresses[self._current_index] - - def __post_init__(self): - """Create a new connection for the specified host address, without connecting.""" - # Create a config with only this host - for host_addr in self.addresses: - copied_conf = copy.copy(self.conn_conf) - copied_conf.hosts = [host_addr] - copied_conf.ping_before_execute = ( - False # Because ping will be done when borrowing a connection - ) - conn = Connection(copied_conf) - self._connections[host_addr] = conn - - def connect(self): - for conn in self._connections.values(): - conn.connect() - - def get_connection(self, host_addr: HostAddress) -> Connection | None: - conn = self._connections[host_addr] - if self.conn_conf.ping_before_execute and not conn.ping(): - try: - conn.reconnect() - except Exception: - logger.exception("Error reconnecting to server %s", host_addr) - return None - return self._connections[host_addr] - - def next_connection(self) -> tuple[HostAddress, Connection]: - for _ in range(len(self.addresses)): - host_addr = self.next_address() - conn = self.get_connection(host_addr) - if conn is not None: - return host_addr, conn - else: - continue - raise PoolError("No connection available in the pool") - - def close(self): - """Close all connections in the pool.""" - for conn in self._connections.values(): - conn.close() - self._connections.clear() diff --git a/src/nebulagraph_python/client/_session.py b/src/nebulagraph_python/client/_session.py deleted file mode 100644 index 83bd322b..00000000 --- a/src/nebulagraph_python/client/_session.py +++ /dev/null @@ -1,176 +0,0 @@ -# Copyright 2025 vesoft-inc -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Dict, Optional - -if TYPE_CHECKING: - from nebulagraph_python.client._connection import AsyncConnection, Connection - -import uuid - -from nebulagraph_python._error_code import ErrorCode -from nebulagraph_python.client.logger import logger -from nebulagraph_python.error import ExecutingError - - -@dataclass(kw_only=True, frozen=True) -class SessionConfig: - schema: Optional[str] = None - graph: Optional[str] = None - timezone: Optional[str] = None - values: Dict[str, str] = field(default_factory=dict) - configs: Dict[str, str] = field(default_factory=dict) - - -@dataclass(kw_only=True) -class SessionBase: - username: str - password: str | None - session_config: SessionConfig | None - auth_options: Dict[str, str] | None - - _session: int = -1 - _hash: int = field(default_factory=lambda: uuid.uuid4().int) - - -@dataclass -class Session(SessionBase): - _conn: "Connection" - - def execute( - self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False - ): - res = self._conn.execute( - self._session, statement, timeout=timeout, do_ping=do_ping - ) - # Retry for only one time - if res.status_code == ErrorCode.SESSION_NOT_FOUND.value: - self._session = self._conn.authenticate( - self.username, - self.password, - session_config=self.session_config, - auth_options=self.auth_options, - ) - res = self._conn.execute( - self._session, statement, timeout=timeout, do_ping=do_ping - ) - res.raise_on_error() - return res - - def _close(self): - """Close session""" - try: - self._conn.execute(self._session, "SESSION CLOSE") - except Exception: - logger.exception("Failed to close session") - - def __hash__(self): - return self._hash - - def __eq__(self, other): - return self._hash == other._hash - - -@dataclass -class AsyncSession(SessionBase): - _conn: "AsyncConnection" - - async def execute( - self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False - ): - res = await self._conn.execute( - self._session, statement, timeout=timeout, do_ping=do_ping - ) - # Retry for only one time - if res.status_code == ErrorCode.SESSION_NOT_FOUND.value: - self._session = await self._conn.authenticate( - self.username, - self.password, - session_config=self.session_config, - auth_options=self.auth_options, - ) - res = await self._conn.execute( - self._session, statement, timeout=timeout, do_ping=do_ping - ) - res.raise_on_error() - return res - - async def _close(self): - try: - await self._conn.execute(self._session, "SESSION CLOSE") - except Exception: - logger.exception("Failed to close async session") - - def __hash__(self): - return self._hash - - def __eq__(self, other): - return self._hash == other._hash - - -async def ainit_session(conn: "AsyncConnection", sid: int, config: SessionConfig): - # All execute calls here must be awaited and then checked - try: - if config.schema is not None: - result = await conn.execute(sid, f"SESSION SET SCHEMA `{config.schema}`") - result.raise_on_error() - if config.graph is not None: - result = await conn.execute(sid, f"SESSION SET GRAPH `{config.graph}`") - result.raise_on_error() - if config.timezone is not None: - result = await conn.execute( - sid, f"SESSION SET TIME ZONE `{config.timezone}`" - ) - result.raise_on_error() - if config.values: - result = await conn.execute( - sid, - f"SESSION SET VALUE {','.join(f'${k_}={v_}' for k_, v_ in config.values.items())}", - ) - result.raise_on_error() - if config.configs: - for k, v in config.configs.items(): - result 
= await conn.execute(sid, f"SESSION SET {k}={v}") - result.raise_on_error() - except ExecutingError as e: - logger.error(f"Error during async session post-init: {e}") - raise - - -def init_session(conn: "Connection", sid: int, config: SessionConfig): - """Initialize session with configuration settings""" - try: - if config.schema is not None: - result = conn.execute(sid, f"SESSION SET SCHEMA `{config.schema}`") - result.raise_on_error() - if config.graph is not None: - result = conn.execute(sid, f"SESSION SET GRAPH `{config.graph}`") - result.raise_on_error() - if config.timezone is not None: - result = conn.execute(sid, f"SESSION SET TIME ZONE `{config.timezone}`") - result.raise_on_error() - if config.values: - result = conn.execute( - sid, - f"SESSION SET VALUE {','.join(f'${k_}={v_}' for k_, v_ in config.values.items())}", - ) - result.raise_on_error() - if config.configs: - for k, v in config.configs.items(): - result = conn.execute(sid, f"SESSION SET {k}={v}") - result.raise_on_error() - except ExecutingError as e: - logger.error(f"Error during session post-init: {e}") - raise diff --git a/src/nebulagraph_python/client/_session_pool.py b/src/nebulagraph_python/client/_session_pool.py deleted file mode 100644 index 2a085a33..00000000 --- a/src/nebulagraph_python/client/_session_pool.py +++ /dev/null @@ -1,289 +0,0 @@ -# Copyright 2025 vesoft-inc -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import threading -from contextlib import asynccontextmanager, contextmanager -from dataclasses import dataclass, field -from typing import Any, Dict, Optional, Set - -from anyio import Lock, Semaphore, fail_after - -from nebulagraph_python.client._connection import AsyncConnection, Connection -from nebulagraph_python.client._session import ( - AsyncSession, - Session, - SessionConfig, -) -from nebulagraph_python.client.constants import ( - DEFAULT_SESSION_POOL_SIZE, - DEFAULT_SESSION_POOL_WAIT_TIMEOUT, -) -from nebulagraph_python.error import PoolError - -logger = logging.getLogger(__name__) - - -@dataclass -class SessionPoolConfig: - """Configuration for the SessionPool. - Args: - size: The number of sessions to be managed by the SessionPool. - wait_timeout: The maximum time to wait for a session to be available. If None, wait indefinitely. - """ - - size: int = field(default=DEFAULT_SESSION_POOL_SIZE) - wait_timeout: float | None = field(default=DEFAULT_SESSION_POOL_WAIT_TIMEOUT) - - def __post_init__(self): - if self.size <= 0: - raise ValueError( - f"SessionPoolConfig.size must be greater than 0, but got {self.size}" - ) - if self.wait_timeout is not None and self.wait_timeout <= 0: - self.wait_timeout = None - - -class AsyncSessionPool: - """Manage a pool of sessions. 
It is built upon anyio Lock and is async/coroutine-level safe but not thread-safe.""" - - free_sessions_queue: Set[AsyncSession] - busy_sessions_queue: Set[AsyncSession] - queue_lock: Lock - queue_count: Semaphore - config: SessionPoolConfig - - @classmethod - async def connect( - cls, - conn: AsyncConnection, - username: str, - password: Optional[str] = None, - auth_options: Optional[Dict[str, Any]] = None, - session_config: Optional[SessionConfig] = None, - pool_config: Optional[SessionPoolConfig] = None, - ): - pool_config = pool_config or SessionPoolConfig() - sessions: Set[AsyncSession] = set() - try: - for _ in range(pool_config.size): - sessions.add( - AsyncSession( - conn, - username=username, - password=password, - session_config=session_config, - auth_options=auth_options, - ) - ) - return cls(sessions, pool_config) - except Exception: - # Clean up any sessions that were successfully created - for session in sessions: - await session._close() - raise - - def __init__( - self, - sessions: Set[AsyncSession], - config: SessionPoolConfig, - ): - """Initialize the SessionPool - - Args: - sessions: The sessions to be managed by the SessionPool. - config: Configuration for the SessionPool. - """ - if len(sessions) != config.size: - raise ValueError( - f"The number of sessions ({len(sessions)}) does not match the size of the pool ({config.size})" - ) - self.free_sessions_queue = sessions - self.busy_sessions_queue = set() - self.queue_lock = Lock() - self.queue_count = Semaphore(len(sessions)) - self.config = config - - @asynccontextmanager - async def borrow(self): - got_session: Optional[AsyncSession] = None - - # Event-based loop (wait for free session to be available) - while True: - if self.config.wait_timeout is not None: - try: - with fail_after(self.config.wait_timeout): - await self.queue_count.acquire() - except TimeoutError: - break - else: - await self.queue_count.acquire() - async with self.queue_lock: - if not self.free_sessions_queue: - logger.error( - "No free sessions available after acquired semaphore, which indicates a bug in the AsyncSessionPool" - ) - # Release semaphore and retry if no sessions available - self.queue_count.release() - continue - session = self.free_sessions_queue.pop() - self.busy_sessions_queue.add(session) - got_session = session - break - - if got_session is None: - raise PoolError( - f"No session available in the SessionPool after waiting {self.config.wait_timeout} seconds" - ) - - try: - yield got_session - finally: - # Ensure session is returned to pool even if exception occurs - async with self.queue_lock: - if got_session in self.busy_sessions_queue: - self.free_sessions_queue.add(got_session) - self.busy_sessions_queue.remove(got_session) - self.queue_count.release() - - async def _close(self): - # Acquire all semaphore permits to prevent new borrows - for _ in range(self.config.size): - await self.queue_count.acquire() - async with self.queue_lock: - # Close all free sessions - for session in self.free_sessions_queue: - await session._close() - # Close all busy sessions (if any remain) - for session in self.busy_sessions_queue: - logger.error( - "Busy sessions remain after acquire all semaphore permits, which indicates a bug in the AsyncSessionPool" - ) - await session._close() - - -class SessionPool: - """Manage a pool of sessions. 
It is built upon threading.Lock and is thread-safe.""" - - free_sessions_queue: Set[Session] - busy_sessions_queue: Set[Session] - queue_lock: threading.Lock - queue_count: threading.Semaphore - config: SessionPoolConfig - - @classmethod - def connect( - cls, - conn: Connection, - username: str, - password: Optional[str] = None, - auth_options: Optional[Dict[str, Any]] = None, - session_config: Optional[SessionConfig] = None, - pool_config: Optional[SessionPoolConfig] = None, - ): - pool_config = pool_config or SessionPoolConfig() - sessions: Set[Session] = set() - try: - for _ in range(pool_config.size): - sessions.add( - Session( - conn, - username=username, - password=password, - session_config=session_config, - auth_options=auth_options, - ) - ) - return cls(sessions, pool_config) - except Exception: - # Clean up any sessions that were successfully created - for session in sessions: - session._close() - raise - - def __init__( - self, - sessions: Set[Session], - config: SessionPoolConfig, - ): - """Initialize the SessionPool - - Args: - sessions: The sessions to be managed by the SessionPool. - config: Configuration for the SessionPool. - """ - if len(sessions) != config.size: - raise ValueError( - f"The number of sessions ({len(sessions)}) does not match the size of the pool ({config.size})" - ) - self.free_sessions_queue = sessions - self.busy_sessions_queue = set() - self.queue_lock = threading.Lock() - self.queue_count = threading.Semaphore(len(sessions)) - self.config = config - - @contextmanager - def borrow(self): - got_session: Optional[Session] = None - - # Event-based loop (wait for free session to be available) - while True: - if self.config.wait_timeout is not None: - acquired = self.queue_count.acquire(timeout=self.config.wait_timeout) - if not acquired: - break - else: - self.queue_count.acquire() - with self.queue_lock: - if not self.free_sessions_queue: - logger.error( - "No free sessions available after acquired semaphore, which indicates a bug in the SessionPool" - ) - # Release semaphore and retry if no sessions available - self.queue_count.release() - continue - session = self.free_sessions_queue.pop() - self.busy_sessions_queue.add(session) - got_session = session - break - - if got_session is None: - raise PoolError( - f"No session available in the SessionPool after waiting {self.config.wait_timeout} seconds" - ) - - try: - yield got_session - finally: - # Ensure session is returned to pool even if exception occurs - with self.queue_lock: - if got_session in self.busy_sessions_queue: - self.free_sessions_queue.add(got_session) - self.busy_sessions_queue.remove(got_session) - self.queue_count.release() - - def _close(self): - # Acquire all semaphore permits to prevent new borrows - for _ in range(self.config.size): - self.queue_count.acquire() - with self.queue_lock: - # Close all free sessions - for session in self.free_sessions_queue: - session._close() - # Close all busy sessions (if any remain) - for session in self.busy_sessions_queue: - logger.error( - "Busy sessions remain after acquire all semaphore permits, which indicates a bug in the SessionPool" - ) - session._close() diff --git a/src/nebulagraph_python/client/auth_result.py b/src/nebulagraph_python/client/auth_result.py new file mode 100644 index 00000000..05929b8c --- /dev/null +++ b/src/nebulagraph_python/client/auth_result.py @@ -0,0 +1,34 @@ +# Copyright 2025 vesoft-inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""AuthResult class matching Java implementation""" + + +from dataclasses import dataclass + + +@dataclass(frozen=True) +class AuthResult: + """Result of authentication, matching Java AuthResult class""" + + session_id: int + version: str + + def get_session_id(self) -> int: + """Get the session ID""" + return self.session_id + + def get_version(self) -> str: + """Get the server version""" + return self.version \ No newline at end of file diff --git a/src/nebulagraph_python/client/client.py b/src/nebulagraph_python/client/client.py deleted file mode 100644 index 4b6abc06..00000000 --- a/src/nebulagraph_python/client/client.py +++ /dev/null @@ -1,302 +0,0 @@ -# Copyright 2025 vesoft-inc -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -from collections.abc import AsyncGenerator, Generator -from contextlib import asynccontextmanager, contextmanager -from typing import Any, Dict, List, Literal, Optional, Union - -from nebulagraph_python.client._connection import ( - AsyncConnection, - Connection, - ConnectionConfig, -) -from nebulagraph_python.client._connection_pool import ( - AsyncConnectionPool, - ConnectionPool, -) -from nebulagraph_python.client._session import ( - AsyncSession, - Session, - SessionConfig, -) -from nebulagraph_python.client._session_pool import ( - AsyncSessionPool, - SessionPool, - SessionPoolConfig, -) -from nebulagraph_python.client.base_executor import ( - NebulaBaseAsyncExecutor, - NebulaBaseExecutor, -) -from nebulagraph_python.data import HostAddress, SSLParam -from nebulagraph_python.error import PoolError -from nebulagraph_python.result_set import ResultSet - -logger = logging.getLogger(__name__) - - -class NebulaAsyncClient(NebulaBaseAsyncExecutor): - """The async client for connecting to NebulaGraph. It is async/coroutine-level safe but not thread-safe, - which means you can not share the client instance across threads, - but you can call `await client.execute()` concurrently in async coroutines. - - Required to explicitly call `close()` to release all resources. - """ - - # Owned Resources - _conn: AsyncConnection | AsyncConnectionPool - _sessions: dict[HostAddress, AsyncSession | AsyncSessionPool] - - def __init__(*args, **kwargs): - raise RuntimeError( - "Using `await NebulaAsyncClient.connect()` to create a client instance." 
- ) - - @classmethod - async def connect( - cls, - hosts: Union[str, List[str], List[HostAddress]], - username: str, - password: Optional[str] = None, - *, - ssl_param: Union[SSLParam, Literal[True], None] = None, - auth_options: Optional[Dict[str, Any]] = None, - conn_config: Optional[ConnectionConfig] = None, - session_config: Optional[SessionConfig] = None, - session_pool_config: Optional[SessionPoolConfig] = None, - ): - """Connect to NebulaGraph and initialize the client - - Args: - ---- - hosts: Single host string ("hostname:port"), list of host strings, - or list of HostAddress objects - username: Username for authentication - password: Password for authentication - ssl_param: SSL configuration - auth_options: dict of authentication options - conn_config: Connection configuration. If provided, it overrides hosts and ssl_param. - session_config: Session configuration. - """ - self = super().__new__(cls) - conn_conf = conn_config or ConnectionConfig.from_defaults(hosts, ssl_param) - hosts = conn_conf.hosts - self._sessions = {} - if len(hosts) == 1: - self._conn = AsyncConnection(conn_conf) - await self._conn.connect() - else: - self._conn = AsyncConnectionPool(conn_conf) - await self._conn.connect() - try: - for host_addr in hosts: - conn = ( - await self._conn.get_connection(host_addr) - if isinstance(self._conn, AsyncConnectionPool) - else self._conn - ) - if conn is None: - raise PoolError( - f"Failed to get connection to {host_addr} when initializing NebulaAsyncClient" - ) - if session_pool_config: - self._sessions[host_addr] = await AsyncSessionPool.connect( - conn=conn, - username=username, - password=password, - auth_options=auth_options or {}, - session_config=session_config or SessionConfig(), - pool_config=session_pool_config, - ) - else: - self._sessions[host_addr] = AsyncSession( - _conn=conn, - username=username, - password=password, - session_config=session_config or SessionConfig(), - auth_options=auth_options or {}, - ) - except Exception as e: - await self._conn.close() - raise e - return self - - async def execute( - self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False - ) -> ResultSet: - async with self.borrow() as session: - return await session.execute(statement, timeout=timeout, do_ping=do_ping) - - @asynccontextmanager - async def borrow(self) -> AsyncGenerator[AsyncSession, None]: - if isinstance(self._conn, AsyncConnectionPool): - addr, conn = await self._conn.next_connection() - else: - conn = self._conn - addr = conn.connected - if addr is None: - raise ValueError("Connection not connected") - - _session = self._sessions[addr] - - if isinstance(_session, AsyncSessionPool): - async with _session.borrow() as session: - yield session - else: - yield _session - - async def close(self): - """Close the client connection and session. No Exception will be raised but an error will be logged.""" - for session in self._sessions.values(): - await session._close() - await self._conn.close() - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc_value, traceback): - await self.close() - - -class NebulaClient(NebulaBaseExecutor): - """The client for connecting to NebulaGraph. It is thread-safe, - which means you can share a client instance across threads and call `execute` concurrently. - - Required to explicitly call `close()` to release all resources. 
- """ - - # Owned Resources - _conn: Connection | ConnectionPool - _sessions: dict[HostAddress, Session | SessionPool] - - def __init__( - self, - hosts: Union[str, List[str], List[HostAddress]], - username: str, - password: Optional[str] = None, - *, - ssl_param: Union[SSLParam, Literal[True], None] = None, - auth_options: Optional[Dict[str, Any]] = None, - conn_config: Optional[ConnectionConfig] = None, - session_config: Optional[SessionConfig] = None, - session_pool_config: Optional[SessionPoolConfig] = None, - ): - """Initialize NebulaGraph client - - Args: - ---- - hosts: Single host string ("hostname:port"), list of host strings, - or list of HostAddress objects - username: Username for authentication - password: Password for authentication - ssl_param: SSL configuration - auth_options: dict of authentication options - conn_config: Connection configuration. If provided, it overrides hosts and ssl_param. - session_config: Session configuration. - session_pool_config: Session pool configuration. If provided, a session pool will be created. - """ - conn_conf = conn_config or ConnectionConfig.from_defaults(hosts, ssl_param) - hosts = conn_conf.hosts - self._sessions = {} - if len(hosts) == 1: - self._conn = Connection(conn_conf) - self._conn.connect() - else: - self._conn = ConnectionPool(conn_conf) - self._conn.connect() - try: - for host_addr in hosts: - conn = ( - self._conn.get_connection(host_addr) - if isinstance(self._conn, ConnectionPool) - else self._conn - ) - if conn is None: - raise PoolError( - f"Failed to get connection to {host_addr} when initializing NebulaClient" - ) - if session_pool_config: - self._sessions[host_addr] = SessionPool.connect( - conn=conn, - username=username, - password=password, - auth_options=auth_options or {}, - session_config=session_config or SessionConfig(), - pool_config=session_pool_config, - ) - else: - self._sessions[host_addr] = Session( - _conn=conn, - username=username, - password=password, - session_config=session_config or SessionConfig(), - auth_options=auth_options or {}, - ) - except Exception as e: - self._conn.close() - raise e - - def execute( - self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False - ) -> ResultSet: - """Execute a statement using a borrowed session, raising on errors.""" - with self.borrow() as session: - return session.execute(statement, timeout=timeout, do_ping=do_ping) - - @contextmanager - def borrow(self) -> Generator[Session, None, None]: - """Yield a session bound to the selected connection.""" - if isinstance(self._conn, ConnectionPool): - addr, conn = self._conn.next_connection() - else: - conn = self._conn - addr = conn.connected - if addr is None: - raise ValueError("Connection not connected") - - # Route to the correct session (pool or single session) - _session = self._sessions[addr] - - if isinstance(_session, SessionPool): - with _session.borrow() as session: - yield session - else: - yield _session - - def ping(self, timeout: Optional[float] = None) -> bool: - try: - res = ( - (self.execute(statement="RETURN 1", timeout=timeout)) - .one() - .as_primitive() - ) - if not res == {"1": 1}: - raise ValueError(f"Unexpected result from ping: {res}") - return True - except Exception: - logger.exception("Failed to ping NebulaGraph") - return False - - def close(self): - """Close the client connection and session. 
No Exception will be raised but an error will be logged.""" - for session in self._sessions.values(): - session._close() - self._conn.close() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, traceback): - self.close() diff --git a/src/nebulagraph_python/client/client_pool_factory.py b/src/nebulagraph_python/client/client_pool_factory.py new file mode 100644 index 00000000..b6fc7ebd --- /dev/null +++ b/src/nebulagraph_python/client/client_pool_factory.py @@ -0,0 +1,166 @@ +# Copyright 2025 vesoft-inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ClientPoolFactory matching Java implementation""" + +import logging +import time +from typing import TYPE_CHECKING + +from nebulagraph_python.data import SSLParam +from nebulagraph_python.error import AuthenticatingError, ExecutingError + +if TYPE_CHECKING: + from nebulagraph_python.client.nebula_client import NebulaClient + from nebulagraph_python.client.nebula_pool import NebulaPool + from nebulagraph_python.client.round_robin_load_balancer import RoundRobinLoadBalancer + +logger = logging.getLogger(__name__) + + +class ClientPoolFactory: + """Factory for creating NebulaClient instances for the pool""" + + def __init__( + self, + load_balancer: "RoundRobinLoadBalancer", + builder: "NebulaPool.Builder", + ): + """Initialize the factory""" + self.load_balancer = load_balancer + self.builder = builder + + def create(self) -> "NebulaClient": + """Create a new NebulaClient instance""" + try_create = 0 + io_exception = None + auth_exception = None + + while try_create < self.load_balancer.address_size(): + try: + return self._create_client() + except ExecutingError as e: + io_exception = e + except AuthenticatingError as e: + auth_exception = e + try_create += 1 + + if auth_exception is not None: + raise auth_exception + if io_exception is not None: + raise io_exception + + raise ExecutingError( + "No servers host is available, please check your servers is up and network between client and server is connected." 
+ ) + + def _create_client(self) -> "NebulaClient": + """Create and configure a NebulaClient""" + from nebulagraph_python.client.nebula_client import NebulaClient + + address = self.load_balancer.get_address() + + ssl_param = None + if self.builder.enable_tls: + ssl_param = SSLParam( + ca_crt=self.builder.tls_ca.encode() if self.builder.tls_ca else None, + private_key=self.builder.tls_key.encode() if self.builder.tls_key else None, + cert=self.builder.tls_cert.encode() if self.builder.tls_cert else None, + ) + + client = NebulaClient( + f"{address.host}:{address.port}", + self.builder.user_name, + auth_options=self.builder.auth_options, + connect_timeout_ms=self.builder.connect_timeout_mills, + request_timeout_ms=self.builder.request_timeout_mills, + scan_parallel=self.builder.scan_parallel, + enable_tls=self.builder.enable_tls, + ssl_param=ssl_param, + ) + + # Set home schema, home graph and time zone for session + try: + if self.builder.schema and self.builder.schema.strip(): + stmt = f'SESSION SET SCHEMA `{self.builder.schema}`' + result_set = client.execute(stmt) + if not result_set.is_succeeded: + raise RuntimeError( + f"{stmt} failed for {result_set.status_message}" + ) + + if self.builder.graph and self.builder.graph.strip(): + stmt = f'SESSION SET GRAPH "{self.builder.graph}"' + result_set = client.execute(stmt) + if not result_set.is_succeeded: + raise RuntimeError( + f"{stmt} failed for {result_set.status_message}" + ) + + if self.builder.timezone and self.builder.timezone.strip(): + stmt = f'SESSION SET timezone="{self.builder.timezone}"' + result_set = client.execute(stmt) + if not result_set.is_succeeded: + raise RuntimeError( + f"{stmt} failed for {result_set.status_message}" + ) + + for key, value in self.builder.session_configs.items(): + stmt = f"SESSION SET {key}={value}" + result_set = client.execute(stmt) + if not result_set.is_succeeded: + raise RuntimeError( + f"{stmt} failed for {result_set.status_message}" + ) + + if self.builder.parameters: + parameters_set_statement = "SESSION SET VALUE " + for param_key, param_value in self.builder.parameters.items(): + parameters_set_statement += ( + f"${param_key}={param_value}," + ) + parameters_set_statement = parameters_set_statement[:-1] + if parameters_set_statement: + result = client.execute(parameters_set_statement) + if not result.is_succeeded: + raise RuntimeError( + f"{parameters_set_statement} failed for {result.status_message}" + ) + + for pre_stmt in self.builder.pre_statements: + res = client.execute(pre_stmt) + if not res.is_succeeded: + raise RuntimeError( + f"{pre_stmt} failed for {res.status_message}" + ) + + except ExecutingError as e: + client.close() + raise e + + return client + + def destroy(self, client: "NebulaClient") -> None: + """Destroy a NebulaClient instance""" + try: + client.close() + except Exception as e: + logger.warn(f"session release failed: {e}") + + def validate(self, client: "NebulaClient") -> bool: + """Validate if a NebulaClient is still valid""" + is_alive = ( + time.time() * 1000 - client.get_create_time() + ) < self.builder.max_life_time_ms + return client.ping(self.builder.server_ping_timeout_mills) and is_alive diff --git a/src/nebulagraph_python/client/constants.py b/src/nebulagraph_python/client/constants.py index d75ca3c2..1f3b45d1 100644 --- a/src/nebulagraph_python/client/constants.py +++ b/src/nebulagraph_python/client/constants.py @@ -12,15 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. 
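A note on the factory above before the constants diff: _create_client applies the builder's session settings by running plain SESSION SET statements through the newly authenticated client and aborts creation if any of them does not succeed. A hedged illustration of the statements it would emit for hypothetical builder values (the values below are placeholders, not defaults):

# Hypothetical builder values and the statements _create_client would run for them.
schema, graph, timezone = "ldbc", "snb", "UTC"
parameters = {"limit": "10"}
statements = [
    f"SESSION SET SCHEMA `{schema}`",
    f'SESSION SET GRAPH "{graph}"',
    f'SESSION SET timezone="{timezone}"',
    "SESSION SET VALUE " + ",".join(f"${k}={v}" for k, v in parameters.items()),
]
# Each statement runs in order; a non-succeeded result raises and aborts client creation.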
-DEFAULT_CONNECT_TIMEOUT = 3 # 3 seconds -DEFAULT_REQUEST_TIMEOUT = 60 # 1 minute +"""Constants for NebulaGraph client, matching Java implementation""" -DEFAULT_MAX_CLIENT_SIZE = 10 -DEFAULT_MIN_CLIENT_SIZE = 1 -DEFAULT_TEST_ON_BORROW = False -DEFAULT_STRICTLY_SERVER_HEALTHY = False -DEFAULT_MAX_WAIT = 60 # 1 minute +# New constants matching Java implementation +DEFAULT_MAX_CLIENT_SIZE: int = 10 +DEFAULT_MIN_CLIENT_SIZE: int = 1 +DEFAULT_CONNECT_TIMEOUT_MS: int = 3 * 1000 # 3 seconds +DEFAULT_REQUEST_TIMEOUT_MS: int = 60 * 1000 # 1 minute +DEFAULT_MAX_TIMEOUT_MS: int = 2**31 - 1 # about 25 days +DEFAULT_MAX_PING_TIMEOUT_MS: int = 10 * 60 * 1000 +DEFAULT_PING_TIMEOUT_MS: int = 1000 +DEFAULT_HEALTH_CHECK_TIME_MS: int = 5 * 60 * 1000 +DEFAULT_TEST_ON_BORROW: bool = True +DEFAULT_BLOCK_WHEN_EXHAUSTED: bool = False +DEFAULT_MAX_WAIT_MS: int = 2**63 - 1 // 1000 +DEFAULT_IDLE_EVICT_SCHEDULE_MS: int = -1 +DEFAULT_MIN_EVICTABLE_IDLE_TIME_MS: int = 30 * 60 * 1000 +DEFAULT_STRICT_SERVER_HEALTHY: bool = False +DEFAULT_MAX_LIFE_TIME_MS: int = 2**63 - 1 +DEFAULT_BATCH_SIZE: int = 1000 +DEFAULT_SCAN_PARALLEL: int = 10 +DEFAULT_ENABLE_TLS: bool = False +DEFAULT_DISABLE_VERIFY_SERVER_CERT: bool = False +DEFAULT_TLS_PEER_NAME_VERIFY: bool = True -DEFAULT_SESSION_POOL_SIZE = 4 -DEFAULT_SESSION_POOL_WAIT_TIMEOUT = 60 # 1 minute -DEFAULT_SESSION_POOL_WAIT_RETRY_INTERVAL = 0.1 +# Backward compatibility constants (old API) +DEFAULT_SESSION_POOL_SIZE: int = 10 +DEFAULT_SESSION_POOL_WAIT_TIMEOUT: float = 0.0 +DEFAULT_MAX_CLIENT_SIZE_OLD: int = 10 +DEFAULT_MIN_CLIENT_SIZE_OLD: int = 1 +DEFAULT_TEST_ON_BORROW_OLD: bool = True +DEFAULT_STRICTLY_SERVER_HEALTHY: bool = False +DEFAULT_MAX_WAIT: float = 5.0 +DEFAULT_CONNECT_TIMEOUT: float = 3.0 +DEFAULT_REQUEST_TIMEOUT: float = 60.0 \ No newline at end of file diff --git a/src/nebulagraph_python/client/nebula_client.py b/src/nebulagraph_python/client/nebula_client.py new file mode 100644 index 00000000..28cb17f6 --- /dev/null +++ b/src/nebulagraph_python/client/nebula_client.py @@ -0,0 +1,422 @@ +# Copyright 2025 vesoft-inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
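The file added below is the centerpiece of the refactor: a single NebulaClient that owns one connection and one session, with millisecond timeouts mirroring the Java client. A minimal usage sketch, assuming a reachable graphd at a placeholder address with placeholder credentials (the import path follows the new module location shown in this patch; whether the class is also re-exported from the package root is not visible in this hunk):

from nebulagraph_python.client.nebula_client import NebulaClient

client = NebulaClient(
    "127.0.0.1:9669",          # a comma-separated "host:port" list is also accepted
    "root",
    "nebula",
    connect_timeout_ms=3000,
    request_timeout_ms=60000,
)
try:
    result = client.execute("RETURN 1")
    print(result.is_succeeded)
finally:
    client.close()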
+ +"""NebulaClient implementation matching Java NebulaClient""" + +import asyncio +import logging +import random +import threading +import time +from typing import Dict, List, Optional, TYPE_CHECKING + +import grpc + +from nebulagraph_python._error_code import ErrorCode +from nebulagraph_python.client._connection import GrpcConnection, AsyncConnection, ConnectionConfig +from nebulagraph_python.client.auth_result import AuthResult +from nebulagraph_python.client.base_executor import NebulaBaseExecutor, NebulaBaseAsyncExecutor +from nebulagraph_python.client.constants import ( + DEFAULT_CONNECT_TIMEOUT_MS, + DEFAULT_ENABLE_TLS, + DEFAULT_MAX_TIMEOUT_MS, + DEFAULT_PING_TIMEOUT_MS, + DEFAULT_REQUEST_TIMEOUT_MS, + DEFAULT_SCAN_PARALLEL, +) +from nebulagraph_python.data import HostAddress, SSLParam +from nebulagraph_python.error import ( + AuthenticatingError, + ExecutingError, + NebulaGraphRemoteError, +) +from nebulagraph_python.result_set import ResultSet + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + + +class NebulaClient(NebulaBaseExecutor): + """Client to connect to NebulaGraph, matching Java NebulaClient""" + + def __init__( + self, + addresses: str, + user_name: str = None, + password: Optional[str] = None, + *, + connect_timeout_ms: int = DEFAULT_CONNECT_TIMEOUT_MS, + request_timeout_ms: int = DEFAULT_REQUEST_TIMEOUT_MS, + server_ping_timeout_ms: int = DEFAULT_PING_TIMEOUT_MS, + scan_parallel: int = DEFAULT_SCAN_PARALLEL, + enable_tls: bool = DEFAULT_ENABLE_TLS, + ssl_param: Optional[SSLParam] = None, + auth_options: Optional[Dict[str, object]] = None, + ): + """Initialize NebulaClient with configuration parameters + + Args: + addresses: NebulaGraph server addresses (e.g., "127.0.0.1:9669,127.0.0.2:9669") + user_name: Username for authentication + password: Password for authentication + connect_timeout_ms: Connection timeout in milliseconds + request_timeout_ms: Request timeout in milliseconds + server_ping_timeout_ms: Server ping timeout in milliseconds + scan_parallel: Scan parallel degree + enable_tls: Enable TLS connection + ssl_param: SSL parameters + auth_options: Additional authentication options + """ + self.servers: List[HostAddress] = self._validate_address(addresses) + self.user_name: str = user_name + self.password: Optional[str] = password + self.auth_options: Dict[str, object] = auth_options or {} + + if password: + self.auth_options["password"] = password + + self.connect_timeout_mills: int = connect_timeout_ms + self.request_timeout_mills: int = request_timeout_ms + self.server_ping_timeout_mills: int = server_ping_timeout_ms + self.scan_parallel: int = scan_parallel + self.enable_tls: bool = enable_tls + self.ssl_param: Optional[SSLParam] = ssl_param + + self.connection: Optional[GrpcConnection] = None + self.session_id: int = -1 + self.version: str = "" + self.create_time: int = 0 + self.is_closed: bool = False + self._lock = threading.Lock() + + self._init_client() + + def execute( + self, + statement: str, + *, + timeout: Optional[float] = None, + do_ping: bool = False, + ) -> ResultSet: + """Execute a query with optional timeout""" + if timeout is None: + timeout = self.request_timeout_mills + return self.execute_with_timeout(statement, int(timeout)) + + def execute_with_timeout(self, gql: str, request_timeout: int) -> ResultSet: + """Execute a query with custom timeout""" + with self._lock: + self._check_closed() + response = self.connection.execute_default_timeout( + self.session_id, gql + ) + return ResultSet(response) + + def 
get_session_id(self) -> int: + """Get the session ID""" + return self.session_id + + def get_version(self) -> str: + """Get the server version""" + return self.version + + def get_create_time(self) -> int: + """Get the creation time""" + return self.create_time + + def get_host(self) -> str: + """Get the connected host address""" + if self.connection: + return str(self.connection.get_server_address()) + return "" + + def get_connect_timeout_mills(self) -> int: + """Get the connection timeout""" + return self.connect_timeout_mills + + def get_request_timeout_mills(self) -> int: + """Get the request timeout""" + return self.request_timeout_mills + + def get_scan_parallel(self) -> int: + """Get the scan parallel""" + return self.scan_parallel + + def ping(self, timeout_ms: int = DEFAULT_PING_TIMEOUT_MS) -> bool: + """Ping the server""" + with self._lock: + self._check_closed() + try: + return self.connection.ping(self.session_id, timeout_ms) + except ExecutingError as e: + logger.error(f"ping error for host {self.get_host()}: {e}") + return False + + def close(self) -> None: + """Close the client""" + with self._lock: + if not self.is_closed: + self.is_closed = True + if self.connection is not None: + try: + self.connection.execute( + self.session_id, "SESSION CLOSE", 1000 + ) + self.connection.close() + except Exception as e: + logger.warn(f"signout failed: {e}") + self.connection = None + + def is_closed_client(self) -> bool: + """Check if the client is closed""" + return self.is_closed + + def _check_closed(self) -> None: + """Check if the client is closed and raise exception if so""" + if self.is_closed: + raise RuntimeError("The NebulaClient already closed.") + + def _init_client(self) -> None: + """Initialize the client connection""" + auth_result: Optional[AuthResult] = None + self.connection = GrpcConnection() + + try_connect_times = len(self.servers) + random.shuffle(self.servers) + + while try_connect_times > 0: + try_connect_times -= 1 + try: + self.connection.open(self.servers[try_connect_times], self) + auth_result = self.connection.authenticate( + self.user_name, self.auth_options + ) + self.session_id = auth_result.get_session_id() + self.version = auth_result.get_version() + self.create_time = int(time.time() * 1000) + break + except AuthenticatingError as e: + logger.error(f"create NebulaClient failed: {e}") + raise + except Exception as e: + if try_connect_times == 0: + logger.error(f"create NebulaClient failed: {e}") + raise + + @staticmethod + def _validate_address(addresses: str) -> List[HostAddress]: + """Validate and parse addresses""" + result = [] + if isinstance(addresses, str): + for addr in addresses.split(","): + addr = addr.strip() + if ":" in addr: + host, port = addr.rsplit(":", 1) + result.append(HostAddress(host, int(port))) + else: + raise ValueError(f"Invalid address format: {addr}") + return result + + +class AsyncNebulaClient(NebulaBaseAsyncExecutor): + """Async client to connect to NebulaGraph, matching Java NebulaClient with async support""" + + def __init__( + self, + addresses: str, + user_name: str = None, + password: Optional[str] = None, + *, + connect_timeout_ms: int = DEFAULT_CONNECT_TIMEOUT_MS, + request_timeout_ms: int = DEFAULT_REQUEST_TIMEOUT_MS, + server_ping_timeout_ms: int = DEFAULT_PING_TIMEOUT_MS, + scan_parallel: int = DEFAULT_SCAN_PARALLEL, + enable_tls: bool = DEFAULT_ENABLE_TLS, + ssl_param: Optional[SSLParam] = None, + auth_options: Optional[Dict[str, object]] = None, + ): + """Initialize AsyncNebulaClient with configuration 
parameters + + Args: + addresses: NebulaGraph server addresses (e.g., "127.0.0.1:9669,127.0.0.2:9669") + user_name: Username for authentication + password: Password for authentication + connect_timeout_ms: Connection timeout in milliseconds + request_timeout_ms: Request timeout in milliseconds + server_ping_timeout_ms: Server ping timeout in milliseconds + scan_parallel: Scan parallel degree + enable_tls: Enable TLS connection + ssl_param: SSL parameters + auth_options: Additional authentication options + """ + self.servers: List[HostAddress] = self._validate_address(addresses) + self.user_name: str = user_name + self.password: Optional[str] = password + self.auth_options: Dict[str, object] = auth_options or {} + + if password: + self.auth_options["password"] = password + + self.connect_timeout_mills: int = connect_timeout_ms + self.request_timeout_mills: int = request_timeout_ms + self.server_ping_timeout_mills: int = server_ping_timeout_ms + self.scan_parallel: int = scan_parallel + self.enable_tls: bool = enable_tls + self.ssl_param: Optional[SSLParam] = ssl_param + + self.connection: Optional[AsyncConnection] = None + self.session_id: int = -1 + self.version: str = "" + self.create_time: int = 0 + self.is_closed: bool = False + self._lock = asyncio.Lock() + + async def execute( + self, + statement: str, + *, + timeout: Optional[float] = None, + do_ping: bool = False, + ) -> ResultSet: + """Execute a query with optional timeout""" + if timeout is None: + timeout = self.request_timeout_mills + return await self.execute_with_timeout(statement, int(timeout)) + + async def execute_with_timeout(self, gql: str, request_timeout: int) -> ResultSet: + """Execute a query with custom timeout""" + async with self._lock: + self._check_closed() + response = await self.connection.execute_default_timeout( + self.session_id, gql + ) + return ResultSet(response) + + def get_session_id(self) -> int: + """Get the session ID""" + return self.session_id + + def get_version(self) -> str: + """Get the server version""" + return self.version + + def get_create_time(self) -> int: + """Get the creation time""" + return self.create_time + + def get_host(self) -> str: + """Get the connected host address""" + if self.connection: + return str(self.connection.server_addr) + return "" + + def get_connect_timeout_mills(self) -> int: + """Get the connection timeout""" + return self.connect_timeout_mills + + def get_request_timeout_mills(self) -> int: + """Get the request timeout""" + return self.request_timeout_mills + + def get_scan_parallel(self) -> int: + """Get the scan parallel""" + return self.scan_parallel + + async def ping(self, timeout_ms: int = DEFAULT_PING_TIMEOUT_MS) -> bool: + """Ping the server""" + async with self._lock: + self._check_closed() + try: + return await self.connection.ping(self.session_id, timeout_ms) + except ExecutingError as e: + logger.error(f"ping error for host {self.get_host()}: {e}") + return False + + async def close(self) -> None: + """Close the client""" + async with self._lock: + if not self.is_closed: + self.is_closed = True + if self.connection is not None: + try: + await self.connection.execute( + self.session_id, "SESSION CLOSE", 1000 + ) + await self.connection.close() + except Exception as e: + logger.warn(f"signout failed: {e}") + self.connection = None + + def is_closed_client(self) -> bool: + """Check if the client is closed""" + return self.is_closed + + def _check_closed(self) -> None: + """Check if the client is closed and raise exception if so""" + if self.is_closed: 
+ raise RuntimeError("The AsyncNebulaClient already closed.") + + async def _init_client(self) -> None: + """Initialize the client connection""" + auth_result: Optional[AuthResult] = None + + # Create connection config + config = ConnectionConfig.from_defaults( + hosts=self.servers, + ssl_param=self.enable_tls or self.ssl_param, + connect_timeout=self.connect_timeout_mills / 1000.0, + request_timeout=self.request_timeout_mills / 1000.0, + ) + if self.ssl_param: + config.ssl_param = self.ssl_param + + self.connection = AsyncConnection(config) + + try_connect_times = len(self.servers) + random.shuffle(self.servers) + + while try_connect_times > 0: + try_connect_times -= 1 + try: + await self.connection.connect(self.servers[try_connect_times]) + auth_result = await self.connection.authenticate( + self.user_name, self.auth_options + ) + self.session_id = auth_result.get_session_id() + self.version = auth_result.get_version() + self.create_time = int(time.time() * 1000) + break + except AuthenticatingError as e: + logger.error(f"create AsyncNebulaClient failed: {e}") + raise + except Exception as e: + if try_connect_times == 0: + logger.error(f"create AsyncNebulaClient failed: {e}") + raise + + @staticmethod + def _validate_address(addresses: str) -> List[HostAddress]: + """Validate and parse addresses""" + result = [] + if isinstance(addresses, str): + for addr in addresses.split(","): + addr = addr.strip() + if ":" in addr: + host, port = addr.rsplit(":", 1) + result.append(HostAddress(host, int(port))) + else: + raise ValueError(f"Invalid address format: {addr}") + return result \ No newline at end of file diff --git a/src/nebulagraph_python/client/nebula_pool.py b/src/nebulagraph_python/client/nebula_pool.py new file mode 100644 index 00000000..91ddba05 --- /dev/null +++ b/src/nebulagraph_python/client/nebula_pool.py @@ -0,0 +1,276 @@ +# Copyright 2025 vesoft-inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
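The nebula_pool.py added below replaces the old NebulaPool/SessionPool pair with a single pool of NebulaClient instances configured through a dataclass. A minimal usage sketch with placeholder address and credentials (only a subset of NebulaPoolConfig fields is set; the rest keep their defaults):

from nebulagraph_python.client.nebula_pool import NebulaPool, NebulaPoolConfig

config = NebulaPoolConfig(
    addresses="127.0.0.1:9669",
    username="root",
    password="nebula",
    min_client_size=1,
    max_client_size=4,
)
with NebulaPool(config) as pool:          # __exit__ closes the pool and its clients
    client = pool.get_client()
    try:
        print(client.execute("RETURN 1").is_succeeded)
    finally:
        pool.return_client(client)        # hands the client back for reuse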
+ +"""NebulaPool implementation using Python configuration style""" + +import logging +import time +import threading +from dataclasses import dataclass, field +from typing import Dict, List, Optional, TYPE_CHECKING + +from nebulagraph_python.client.client_pool_factory import ClientPoolFactory +from nebulagraph_python.client.constants import ( + DEFAULT_BLOCK_WHEN_EXHAUSTED, + DEFAULT_CONNECT_TIMEOUT_MS, + DEFAULT_ENABLE_TLS, + DEFAULT_HEALTH_CHECK_TIME_MS, + DEFAULT_IDLE_EVICT_SCHEDULE_MS, + DEFAULT_MAX_CLIENT_SIZE, + DEFAULT_MAX_LIFE_TIME_MS, + DEFAULT_MAX_WAIT_MS, + DEFAULT_MIN_CLIENT_SIZE, + DEFAULT_MIN_EVICTABLE_IDLE_TIME_MS, + DEFAULT_PING_TIMEOUT_MS, + DEFAULT_REQUEST_TIMEOUT_MS, + DEFAULT_SCAN_PARALLEL, + DEFAULT_STRICT_SERVER_HEALTHY, + DEFAULT_TEST_ON_BORROW, +) +from nebulagraph_python.data import HostAddress, SSLParam +from nebulagraph_python.client.nebula_client import NebulaClient +from nebulagraph_python.client.round_robin_load_balancer import RoundRobinLoadBalancer + +if TYPE_CHECKING: + pass + +logger = logging.getLogger(__name__) + + +@dataclass +class NebulaPoolConfig: + """Configuration for NebulaPool using Python dataclass""" + + # Connection settings + addresses: str + username: str + password: Optional[str] = None + + # Pool settings + max_client_size: int = DEFAULT_MAX_CLIENT_SIZE + min_client_size: int = DEFAULT_MIN_CLIENT_SIZE + max_wait_ms: int = DEFAULT_MAX_WAIT_MS + block_when_exhausted: bool = DEFAULT_BLOCK_WHEN_EXHAUSTED + + # Timeout settings + connect_timeout_ms: int = DEFAULT_CONNECT_TIMEOUT_MS + request_timeout_ms: int = DEFAULT_REQUEST_TIMEOUT_MS + server_ping_timeout_ms: int = DEFAULT_PING_TIMEOUT_MS + + # Health check settings + health_check_time_ms: int = DEFAULT_HEALTH_CHECK_TIME_MS + test_on_borrow: bool = DEFAULT_TEST_ON_BORROW + + # Eviction settings + idle_evict_schedule_ms: int = DEFAULT_IDLE_EVICT_SCHEDULE_MS + min_evictable_idle_time_ms: int = DEFAULT_MIN_EVICTABLE_IDLE_TIME_MS + + # Server health settings + strictly_server_healthy: bool = DEFAULT_STRICT_SERVER_HEALTHY + + # Life time settings + max_life_time_ms: int = DEFAULT_MAX_LIFE_TIME_MS + + # Session settings + graph: Optional[str] = None + schema: Optional[str] = None + timezone: Optional[str] = None + session_configs: Dict[str, str] = field(default_factory=dict) + parameters: Dict[str, str] = field(default_factory=dict) + pre_statements: List[str] = field(default_factory=list) + + # Other settings + scan_parallel: int = DEFAULT_SCAN_PARALLEL + enable_tls: bool = DEFAULT_ENABLE_TLS + ssl_param: Optional[SSLParam] = None + auth_options: Dict[str, object] = field(default_factory=dict) + + def __post_init__(self): + """Initialize auth_options with password if provided""" + if self.password: + self.auth_options["password"] = self.password + + +class NebulaPool: + """NebulaGraph connection pool using Python configuration style""" + + def __init__(self, config: NebulaPoolConfig): + """Initialize the NebulaPool with configuration""" + self.config = config + self._load_balancer: Optional[RoundRobinLoadBalancer] = None + self._factory: Optional[ClientPoolFactory] = None + self._pool: List[NebulaClient] = [] + self._in_use: Dict[NebulaClient, bool] = {} + self._lock = threading.Lock() + self._closed = False + + self._init_pool() + + def _init_pool(self) -> None: + """Initialize the connection pool""" + # Parse addresses + addresses = self._parse_addresses(self.config.addresses) + + # Create load balancer config + class LoadBalancerConfig: + def __init__(self, pool_config: NebulaPoolConfig, addrs: 
List[HostAddress]): + self.address = addrs + self.strictly_server_healthy = pool_config.strictly_server_healthy + self.user_name = pool_config.username + self.auth_options = pool_config.auth_options + self.connect_timeout_mills = pool_config.connect_timeout_ms + self.request_timeout_mills = pool_config.request_timeout_ms + self.server_ping_timeout_mills = pool_config.server_ping_timeout_ms + self.scan_parallel = pool_config.scan_parallel + self.enable_tls = pool_config.enable_tls + self.disable_verify_server_cert = False + self.tls_ca = ( + pool_config.ssl_param.ca_crt.decode() + if pool_config.ssl_param and pool_config.ssl_param.ca_crt + else None + ) + self.tls_cert = ( + pool_config.ssl_param.cert.decode() + if pool_config.ssl_param and pool_config.ssl_param.cert + else None + ) + self.tls_key = ( + pool_config.ssl_param.private_key.decode() + if pool_config.ssl_param and pool_config.ssl_param.private_key + else None + ) + # Session settings + self.graph = pool_config.graph + self.schema = pool_config.schema + self.timezone = pool_config.timezone + self.session_configs = pool_config.session_configs + self.parameters = pool_config.parameters + self.pre_statements = pool_config.pre_statements + self.max_life_time_ms = pool_config.max_life_time_ms + + lb_config = LoadBalancerConfig(self.config, addresses) + self._load_balancer = RoundRobinLoadBalancer(lb_config) + self._factory = ClientPoolFactory(self._load_balancer, lb_config) + + # Check server health + self._load_balancer.check_servers() + + # Initialize minimum number of clients + for _ in range(self.config.min_client_size): + try: + client = self._factory.create() + self._pool.append(client) + self._in_use[client] = False + except Exception as e: + logger.warning(f"Failed to create initial client: {e}") + + @staticmethod + def _parse_addresses(addresses: str) -> List[HostAddress]: + """Parse address string to HostAddress list""" + result = [] + for addr in addresses.split(","): + addr = addr.strip() + if ":" in addr: + host, port = addr.rsplit(":", 1) + result.append(HostAddress(host, int(port))) + else: + raise ValueError(f"Invalid address format: {addr}") + return result + + def get_client(self) -> NebulaClient: + """Get a client from the pool""" + if self._closed: + raise RuntimeError("Pool is closed") + + start_time = time.time() + while time.time() - start_time < self.config.max_wait_ms / 1000.0: + with self._lock: + # Try to find an available client + for client in self._pool: + if not self._in_use.get(client, False): + if self.config.test_on_borrow: + if not self._factory.validate(client): + self._pool.remove(client) + del self._in_use[client] + self._factory.destroy(client) + continue + self._in_use[client] = True + return client + + # Try to create a new client if under max limit + if len(self._pool) < self.config.max_client_size: + try: + client = self._factory.create() + self._pool.append(client) + self._in_use[client] = True + return client + except Exception as e: + logger.error(f"Failed to create new client: {e}") + + # If block_when_exhausted is False, raise exception + if not self.config.block_when_exhausted: + raise RuntimeError("No available clients in pool") + + # Wait a bit before retrying + time.sleep(0.01) + + raise RuntimeError(f"Timeout waiting for client after {self.config.max_wait_ms}ms") + + def return_client(self, client: NebulaClient) -> None: + """Return a client to the pool""" + if self._closed: + return + + with self._lock: + if client in self._in_use: + # Check if client should be invalidated + if ( + 
client.is_closed_client() + or (time.time() * 1000 - client.get_create_time()) + >= self.config.max_life_time_ms + ): + self._pool.remove(client) + del self._in_use[client] + self._factory.destroy(client) + else: + self._in_use[client] = False + + def close(self) -> None: + """Close the pool and all clients""" + with self._lock: + if not self._closed: + self._closed = True + for client in self._pool: + try: + self._factory.destroy(client) + except Exception as e: + logger.warning(f"Failed to close client: {e}") + self._pool.clear() + self._in_use.clear() + + def get_active_sessions(self) -> int: + """Get the number of active sessions""" + with self._lock: + return sum(1 for in_use in self._in_use.values() if in_use) + + def get_idle_sessions(self) -> int: + """Get the number of idle sessions""" + with self._lock: + return sum(1 for in_use in self._in_use.values() if not in_use) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() \ No newline at end of file diff --git a/src/nebulagraph_python/client/pool.py b/src/nebulagraph_python/client/pool.py deleted file mode 100644 index d0a8669c..00000000 --- a/src/nebulagraph_python/client/pool.py +++ /dev/null @@ -1,306 +0,0 @@ -# Copyright 2025 vesoft-inc -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging -import time -from contextlib import contextmanager -from copy import copy -from dataclasses import dataclass -from itertools import cycle -from threading import Lock -from typing import Any, Dict, Iterator, List, Literal, Optional, Union - -from nebulagraph_python.client import constants -from nebulagraph_python.client._connection import ( - ConnectionConfig, - _parse_hosts, -) -from nebulagraph_python.client._session import SessionConfig -from nebulagraph_python.client.base_executor import NebulaBaseExecutor -from nebulagraph_python.client.client import NebulaClient -from nebulagraph_python.data import HostAddress, SSLParam -from nebulagraph_python.error import InternalError, PoolError - -logger = logging.getLogger(__name__) - - -@dataclass -class NebulaPoolConfig: - """Configuration for the NebulaGraph connection pool""" - - max_client_size: int = constants.DEFAULT_MAX_CLIENT_SIZE - min_client_size: int = constants.DEFAULT_MIN_CLIENT_SIZE - test_on_borrow: bool = constants.DEFAULT_TEST_ON_BORROW - strictly_server_healthy: bool = constants.DEFAULT_STRICTLY_SERVER_HEALTHY - max_wait: float = constants.DEFAULT_MAX_WAIT - - -class NebulaPool(NebulaBaseExecutor): - """A connection pool that manages multiple NebulaGraph clients with round-robin load balancing. - Safe for thread-level concurrency, not async/coroutine-level. - - Required to explicitly call `close()` to release all resources. 
- """ - - # Config - hosts: List[HostAddress] - username: str - password: str - ssl_param: Union[SSLParam, Literal[True], None] - auth_options: Optional[Dict[str, Any]] - pool_config: NebulaPoolConfig - session_config: Optional[SessionConfig] - conn_config: Optional[ConnectionConfig] - - # Owned Resources - _clients: List[NebulaClient] - - # State - _lock: Lock - _client_cycle: Iterator[NebulaClient] - _hosts_cycle: Iterator[HostAddress] - _in_use: Dict[NebulaClient, bool] # Track if client is in use - - def __init__( - self, - hosts: Union[str, List[str], List[HostAddress]], - username: str, - password: str, - *, - ssl_param: Union[SSLParam, Literal[True], None] = None, - auth_options: Optional[Dict[str, Any]] = None, - pool_config: Optional[NebulaPoolConfig] = None, - session_config: Optional[SessionConfig] = None, - conn_config: Optional[ConnectionConfig] = None, - ): - """Initialize NebulaGraph connection pool - - Args: - ---- - hosts: Single host string ("hostname:port"), list of host strings, - or list of HostAddress objects - username: Username for authentication - password: Password for authentication - ssl_param: SSL configuration - auth_options: dict of authentication options - pool_config: Pool configuration - session_config: Session configuration - connection_config: Connection configuration. If provided, - it will override the hosts and ssl_param - """ - self.hosts = ( - _parse_hosts(hosts) - if (conn_config is None or not conn_config.hosts) - else conn_config.hosts - ) - self.username = username - self.password = password - self.ssl_param = ssl_param - self.auth_options = auth_options - self.pool_config = pool_config or NebulaPoolConfig() - self.session_config = session_config - self.conn_config = conn_config - - self._clients = [] - self._lock = Lock() - self._in_use = {} # Initialize tracking dict - self._hosts_cycle = cycle(self.hosts) - - # Initialize the client pool - self.fulfill_pool() - - def fulfill_pool(self, locked: bool = False): - """May raise exception with partial success""" - to_fill_num = max(self.pool_config.max_client_size - len(self._clients), 0) - - def _inner_default() -> None: - for _ in range(to_fill_num): - client = NebulaClient( - hosts=self.hosts, - username=self.username, - password=self.password, - ssl_param=self.ssl_param, - auth_options=self.auth_options, - conn_config=self.conn_config, - session_config=self.session_config, - ) - self._clients.append(client) - self._in_use[client] = False - # Initialize the round-robin cycle - self._client_cycle = cycle(self._clients) - - def _inner_for_strictly_server_healthy() -> None: - # When new pool is created and strictly_server_healthy is True, - # we need to connect to all hosts - for _ in range(len(self.hosts)): - # Round-robin host address selection - host = next(self._hosts_cycle) - conn_config = None - if self.conn_config: - conn_config = copy(self.conn_config) - conn_config.hosts = [host] - - client = NebulaClient( - hosts=[host], - username=self.username, - password=self.password, - ssl_param=self.ssl_param, - auth_options=self.auth_options, - conn_config=self.conn_config, - session_config=self.session_config, - ) - if len(self._clients) < self.pool_config.max_client_size: - self._clients.append(client) - self._in_use[client] = False - else: - client.close() - # Initialize the round-robin cycle - self._client_cycle = cycle(self._clients) - - def _inner(): - try: - if not self.pool_config.strictly_server_healthy: - _inner_default() - else: - _inner_for_strictly_server_healthy() - except Exception 
as e: - self._client_cycle = cycle(self._clients) - raise e - - if not locked: - with self._lock: - _inner() - else: - _inner() - - def kick_from_pool(self, client: NebulaClient, locked: bool = False) -> None: - """Kick a client from the pool and close its connection. - - Args: - ---- - client: The client to kick from the pool - - Raises: - ------ - InternalError: If the client is not from this pool - """ - if client not in self._clients: - raise InternalError("Client does not belong to this pool") - - def _inner(): - self._clients.remove(client) - self._in_use.pop(client) - # Close the client connection - client.close() - # Recreate the cycle with remaining clients - if len(self._clients) < self.pool_config.min_client_size: - try: - self.fulfill_pool(locked=True) - except Exception: - logger.exception("Failed or partial success when fulfilling pool") - else: - self._client_cycle = cycle(self._clients) - - if not locked: - with self._lock: - _inner() - else: - _inner() - - def get_client(self) -> NebulaClient: - """Get the next available client using round-robin selection. - - Returns: - ------- - NebulaClient: The next available client from the pool - - Raises: - ------ - InternalError: When kicking a client from the pool fails - PoolError: If all clients are in use after max_wait seconds - """ - - def _inner(): - # Try one full cycle through the clients - for _ in range(len(self._clients)): - client = next(self._client_cycle) - if self._in_use[client]: - continue - if self.pool_config.test_on_borrow and not client.ping(): - self.kick_from_pool(client, locked=True) - continue - - self._in_use[client] = True - return client - return None - - with self._lock: - start_time = time.time() - while time.time() - start_time < self.pool_config.max_wait: - client = _inner() - if client: - return client - raise PoolError("All clients are in use") - - def return_client(self, client: NebulaClient) -> None: - """Return a client back to the pool. - - Args: - ---- - client: The client to return to the pool - - Raises: - ------ - InternalError: If the client is not from this pool - """ - if client not in self._clients: - raise InternalError("Client does not belong to this pool") - - with self._lock: - self._in_use[client] = False - - @contextmanager - def borrow(self): - """Borrow a client from the pool using a context manager. - - Returns: - ------- - ContextManager[NebulaClient]: A context manager that yields a client - - Raises: - ------ - PoolError: If all clients are in use after max_wait seconds - InternalError: If kicking a client from the pool fails - - Example: - ------- - with pool.borrow() as client: - result = client.execute("SHOW HOSTS") - """ - client = self.get_client() - try: - yield client - finally: - self.return_client(client) - - def execute( - self, statement: str, *, timeout: Optional[float] = None, do_ping: bool = False - ): - with self.borrow() as client: - return client.execute(statement, timeout=timeout, do_ping=do_ping) - - def close(self): - """Close all clients in the pool. 
No Exception will be raised but errors will be logged."""
-        for client in self._clients:
-            client.close()
diff --git a/src/nebulagraph_python/client/round_robin_load_balancer.py b/src/nebulagraph_python/client/round_robin_load_balancer.py
new file mode 100644
index 00000000..9637a0ba
--- /dev/null
+++ b/src/nebulagraph_python/client/round_robin_load_balancer.py
@@ -0,0 +1,110 @@
+# Copyright 2025 vesoft-inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""RoundRobinLoadBalancer matching Java implementation"""
+
+import logging
+from typing import TYPE_CHECKING, Dict, List, Optional
+
+from nebulagraph_python.data import HostAddress, SSLParam
+from nebulagraph_python.error import AuthenticatingError, ExecutingError
+
+if TYPE_CHECKING:
+    from nebulagraph_python.client.nebula_pool import NebulaPool
+
+logger = logging.getLogger(__name__)
+
+
+class RoundRobinLoadBalancer:
+    """Round-robin load balancer for NebulaGraph servers"""
+
+    def __init__(self, builder: "NebulaPool.Builder"):
+        """Initialize the load balancer"""
+        self.addresses: List[HostAddress] = list(builder.address)
+        self.strictly_server_healthy: bool = builder.strictly_server_healthy
+        self.user_name: str = builder.user_name
+        self.auth_options: Dict[str, object] = builder.auth_options
+        self.connection_timeout: int = builder.connect_timeout_mills
+        self.enable_tls: bool = builder.enable_tls
+        self.disable_verify_server_cert: bool = builder.disable_verify_server_cert
+        self.tls_ca: Optional[str] = builder.tls_ca
+        self.tls_cert: Optional[str] = builder.tls_cert
+        self.tls_key: Optional[str] = builder.tls_key
+
+        from nebulagraph_python.client.nebula_client import NebulaClient
+
+        self._nebula_client_class = NebulaClient
+        self._pos: int = 0
+
+    def address_size(self) -> int:
+        """Get the number of addresses"""
+        return len(self.addresses)
+
+    def get_address(self) -> HostAddress:
+        """Get the next address using round-robin"""
+        if self._pos >= 2**31 - 1:
+            self._pos = 0
+        new_pos = self._pos % len(self.addresses)
+        self._pos += 1
+        return self.addresses[new_pos]
+
+    def ping(self, addr: HostAddress) -> bool:
+        """Ping a server address"""
+        from nebulagraph_python.client.nebula_client import NebulaClient
+
+        ssl_param = None
+        if self.enable_tls:
+            ssl_param = SSLParam(
+                ca_crt=self.tls_ca.encode() if self.tls_ca else None,
+                private_key=self.tls_key.encode() if self.tls_key else None,
+                cert=self.tls_cert.encode() if self.tls_cert else None,
+            )
+
+        client = NebulaClient(
+            f"{addr.host}:{addr.port}",
+            self.user_name,
+            auth_options=self.auth_options,
+            connect_timeout_ms=self.connection_timeout,
+            enable_tls=self.enable_tls,
+            ssl_param=ssl_param,
+        )
+        client.close()
+        return True
+
+    def check_servers(self) -> None:
+        """Check if servers are healthy"""
+        last_auth_e: Optional[AuthenticatingError] = None
+        last_io_e: Optional[ExecutingError] = None
+        good_address: int = 0
+
+        for host_address in self.addresses:
+            try:
+                self.ping(host_address)
+                good_address += 1
+            except AuthenticatingError as e:
+                last_auth_e = e
+            except ExecutingError as e:
+                
last_io_e = e + + if self.strictly_server_healthy: + if good_address == self.address_size(): + return + else: + if good_address >= 1: + return + + if last_auth_e is not None: + raise last_auth_e + if last_io_e is not None: + raise last_io_e \ No newline at end of file diff --git a/test_async_client.py b/test_async_client.py new file mode 100644 index 00000000..65028a26 --- /dev/null +++ b/test_async_client.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 +"""Test script for AsyncNebulaClient functionality""" + +import asyncio +from nebulagraph_python.client import AsyncNebulaClient + + +async def test_async_client(): + """Test AsyncNebulaClient basic operations""" + # Create async client using direct initialization + client = AsyncNebulaClient( + addresses="127.0.0.1:9669", + user_name="root", + password="nebula", + connect_timeout_ms=3000, + request_timeout_ms=30000, + ) + + try: + # Initialize connection + await client._init_client() + print(f"Connected to NebulaGraph at {client.get_host()}") + print(f"Session ID: {client.get_session_id()}") + print(f"Server version: {client.get_version()}") + + # Test ping + ping_result = await client.ping() + print(f"Ping result: {ping_result}") + + # Test execute + result = await client.execute("SHOW HOSTS") + print("Query result:") + print(result) + + # Test execute_with_timeout + result2 = await client.execute_with_timeout("SHOW SPACES", 5000) + print("Spaces result:") + print(result2) + + except Exception as e: + print(f"Error: {e}") + finally: + # Close client + await client.close() + print("Client closed") + + +async def test_async_with_tls(): + """Test AsyncNebulaClient with TLS disabled""" + client = AsyncNebulaClient( + addresses="127.0.0.1:9669", + user_name="root", + password="nebula", + connect_timeout_ms=3000, + request_timeout_ms=30000, + enable_tls=False + ) + + try: + await client._init_client() + print(f"TLS test - Connected to {client.get_host()}") + print(f"Session ID: {client.get_session_id()}") + except Exception as e: + print(f"TLS test error: {e}") + finally: + await client.close() + + +if __name__ == "__main__": + print("Testing AsyncNebulaClient...") + print("=" * 50) + asyncio.run(test_async_client()) + print("\n" + "=" * 50) + print("Testing AsyncNebulaClient with TLS...") + asyncio.run(test_async_with_tls()) \ No newline at end of file diff --git a/test_nebula_connection.py b/test_nebula_connection.py new file mode 100644 index 00000000..419094b3 --- /dev/null +++ b/test_nebula_connection.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python3 +# Copyright 2025 vesoft-inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+"""Simple script to test NebulaGraph connection"""
+
+import os
+import sys
+
+# Add src to path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
+
+from nebulagraph_python import NebulaPool, NebulaPoolConfig
+
+# Configuration
+NEBULA_HOSTS = os.getenv("NEBULA_HOSTS", "192.168.8.6:3820")
+NEBULA_USER = os.getenv("NEBULA_USER", "root")
+NEBULA_PASSWORD = os.getenv("NEBULA_PASSWORD", "NebulaGraph01")
+
+def test_connection():
+    """Test basic connection to NebulaGraph"""
+    print(f"Testing connection to NebulaGraph at {NEBULA_HOSTS}...")
+    print(f"User: {NEBULA_USER}")
+
+    try:
+        config = NebulaPoolConfig(
+            addresses=NEBULA_HOSTS,
+            username=NEBULA_USER,
+            password=NEBULA_PASSWORD,
+            max_client_size=1,
+            min_client_size=1,
+        )
+        pool = NebulaPool(config)
+        client = pool.get_client()
+
+        # Test simple query
+        result = client.execute("RETURN 1 AS num")
+        if result.is_succeeded:
+            print("✓ Connection successful!")
+            print(f"✓ Query executed successfully: {result}")
+        else:
+            print(f"✗ Query failed: {result.error_msg()}")
+
+        pool.return_client(client)
+        pool.close()
+        return True
+
+    except Exception as e:
+        print(f"✗ Connection failed: {e}")
+        print("\nPlease ensure:")
+        print("1. NebulaGraph server is running")
+        print(f"2. Server address is correct (current: {NEBULA_HOSTS})")
+        print("3. Username and password are correct")
+        print("\nYou can set custom credentials using environment variables:")
+        print("  export NEBULA_HOSTS='127.0.0.1:9669'")
+        print("  export NEBULA_USER='root'")
+        print("  export NEBULA_PASSWORD='nebula'")
+        return False
+
+if __name__ == "__main__":
+    success = test_connection()
+    sys.exit(0 if success else 1)
diff --git a/test_refactored.py b/test_refactored.py
new file mode 100644
index 00000000..7e254a8b
--- /dev/null
+++ b/test_refactored.py
@@ -0,0 +1,96 @@
+#!/usr/bin/env python3
+"""
+Test the refactored NebulaClient and NebulaPool
+"""
+
+from nebulagraph_python.client import NebulaClient, NebulaPool, NebulaPoolConfig
+
+print("=" * 60)
+print("Testing the refactored NebulaClient")
+print("=" * 60)
+
+try:
+    # Test NebulaClient
+    client = NebulaClient(
+        "192.168.8.6:3820",
+        "root",
+        "NebulaGraph01",
+        connect_timeout_ms=3000,
+        request_timeout_ms=60000,
+    )
+
+    print("✓ NebulaClient created successfully")
+    print(f"  Session ID: {client.get_session_id()}")
+    print(f"  Version: {client.get_version()}")
+    print(f"  Host: {client.get_host()}")
+
+    # Execute a query
+    result = client.execute("SHOW GRAPHS")
+    print("✓ Query executed successfully: SHOW GRAPHS")
+    print(f"  Result: {result}")
+
+    # Test ping
+    ping_result = client.ping()
+    print(f"✓ Ping result: {ping_result}")
+
+    client.close()
+    print("✓ Client closed")
+
+except Exception as e:
+    print(f"✗ NebulaClient test failed: {e}")
+    import traceback
+    traceback.print_exc()
+
+print("\n" + "=" * 60)
+print("Testing the refactored NebulaPool")
+print("=" * 60)
+
+try:
+    # Test NebulaPool
+    config = NebulaPoolConfig(
+        addresses="192.168.8.6:3820",
+        username="root",
+        password="NebulaGraph01",
+        max_client_size=3,
+        min_client_size=1,
+    )
+
+    pool = NebulaPool(config)
+    print("✓ NebulaPool created successfully")
+
+    # Get a client from the pool and execute a query
+    client = pool.get_client()
+    result = client.execute("SHOW GRAPHS")
+    print("✓ Query executed successfully: SHOW GRAPHS")
+    print(f"  Result: {result}")
+    pool.return_client(client)
+
+    # Test getting and returning clients
+    client1 = pool.get_client()
+    print("✓ Client 1 acquired")
+    print(f"  Active sessions: {pool.get_active_sessions()}")
+    print(f"  Idle sessions: {pool.get_idle_sessions()}")
+
+    client2 = pool.get_client()
+    print("✓ Client 2 acquired")
+    print(f"  Active sessions: {pool.get_active_sessions()}")
+
+    pool.return_client(client1)
+    print("✓ Client 1 returned")
+    print(f"  Active sessions: {pool.get_active_sessions()}")
+    print(f"  Idle sessions: {pool.get_idle_sessions()}")
+
+    pool.return_client(client2)
+    print("✓ Client 2 returned")
+
+    pool.close()
+    print("✓ Pool closed")
+
+except Exception as e:
+    print(f"✗ NebulaPool test failed: {e}")
+    import traceback
+    traceback.print_exc()
+
+print("\n" + "=" * 60)
+print("Test complete")
+print("=" * 60)
\ No newline at end of file
diff --git a/tests/INTEGRATION_TEST_README.md b/tests/INTEGRATION_TEST_README.md
deleted file mode 100644
index 9d1b8d85..00000000
--- a/tests/INTEGRATION_TEST_README.md
+++ /dev/null
@@ -1,103 +0,0 @@
-# Integration Tests for NebulaGraph Python Client
-
-This directory contains integration tests that require an actual NebulaGraph server connection.
-
-## Prerequisites
-
-1. **Start NebulaGraph Server**
-
-   Make sure NebulaGraph is running and accessible. Default address: `127.0.0.1:9669`
-
-2. **Set Environment Variables** (Optional)
-
-   If your NebulaGraph server uses different credentials, set these environment variables:
-
-   ```bash
-   export NEBULA_HOSTS="127.0.0.1:9669"
-   export NEBULA_USER="root"
-   export NEBULA_PASSWORD="nebula"
-   ```
-
-   Default values:
-   - `NEBULA_HOSTS`: `127.0.0.1:9669`
-   - `NEBULA_USER`: `root`
-   - `NEBULA_PASSWORD`: `nebula`
-
-## Running Integration Tests
-
-```bash
-# Run all integration tests
-pdm run python tests/test_nebula_client_decode_integration.py -v
-```
-
-## Test Coverage
-
-The integration tests cover:
-
-### Geography Types
-- Point decoding
-- LineString decoding
-- Polygon decoding
-
-### Basic Types
-- Integer
-- String (including Unicode/Chinese)
-- Boolean
-- Date
-- Duration (month-based and time-based)
-- List
-- Record (named tuple)
-- Vector (embedding)
-- Set
-- Map
-
-### Graph Types
-- Node decoding
-- Edge decoding
-- Path decoding
-
-## Test Data
-
-The tests create a temporary graph named `decode` with:
-- Node type `player` with various property types
-- Node type `person`
-- Edge type `friend` connecting person nodes
-
-Test data is automatically cleaned up after test completion.
-
-## Troubleshooting
-
-### Connection Refused
-```
-Failed to set up test environment: Connection refused
-```
-Make sure NebulaGraph server is running:
-```bash
-# Check if NebulaGraph is listening
-netstat -an | grep 9669
-```
-
-### Authentication Failed
-```
-Failed to set up test environment: Authentication failed
-```
-Verify your credentials are correct:
-```bash
-export NEBULA_USER="your_username"
-export NEBULA_PASSWORD="your_password"
-```
-
-### Graph Already Exists
-The tests automatically drop existing `decode` and `decode_type` graphs before creating new ones. If you encounter issues, manually clean up:
-```bash
-# Using Nebula Console
-DROP GRAPH IF EXISTS decode;
-DROP GRAPH TYPE IF EXISTS decode_type;
-```
-
-## Notes
-
-- These tests require a running NebulaGraph server
-- Tests create and drop temporary graphs automatically
-- Tests are independent and can be run in any order
-- All tests are skipped if the initial connection fails
diff --git a/tests/test_connection.py b/tests/test_connection.py
new file mode 100644
index 00000000..6ede33d0
--- /dev/null
+++ b/tests/test_connection.py
@@ -0,0 +1,763 @@
+# Copyright 2025 vesoft-inc
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from unittest.mock import AsyncMock, Mock, patch, MagicMock + +import grpc +import pytest + +from nebulagraph_python.client._connection import ( + AsyncConnection, + Connection, + ConnectionConfig, + _parse_hosts, +) +from nebulagraph_python.data import HostAddress, SSLParam +from nebulagraph_python.error import ( + AuthenticatingError, + ConnectingError, + ErrorCode, + ExecutingError, + InternalError, +) + + +class TestParseHosts: + """Test cases for _parse_hosts function""" + + def test_parse_single_string_host(self): + """Test parsing a single host string""" + hosts = _parse_hosts("127.0.0.1:9669") + assert len(hosts) == 1 + assert hosts[0].host == "127.0.0.1" + assert hosts[0].port == 9669 + + def test_parse_multiple_string_hosts(self): + """Test parsing multiple host strings""" + hosts = _parse_hosts("127.0.0.1:9669,127.0.0.2:9669") + assert len(hosts) == 2 + assert hosts[0].host == "127.0.0.1" + assert hosts[0].port == 9669 + assert hosts[1].host == "127.0.0.2" + assert hosts[1].port == 9669 + + def test_parse_host_address_objects(self): + """Test parsing HostAddress objects""" + hosts = _parse_hosts([HostAddress("127.0.0.1", 9669), HostAddress("127.0.0.2", 9670)]) + assert len(hosts) == 2 + assert hosts[0].host == "127.0.0.1" + assert hosts[0].port == 9669 + assert hosts[1].host == "127.0.0.2" + assert hosts[1].port == 9670 + + def test_parse_mixed_hosts(self): + """Test parsing mixed host formats""" + hosts = _parse_hosts(["127.0.0.1:9669", HostAddress("127.0.0.2", 9670)]) + assert len(hosts) == 2 + assert hosts[0].host == "127.0.0.1" + assert hosts[0].port == 9669 + assert hosts[1].host == "127.0.0.2" + assert hosts[1].port == 9670 + + +class TestConnectionConfig: + """Test cases for ConnectionConfig""" + + def test_from_defaults_basic(self): + """Test creating ConnectionConfig from defaults""" + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + assert len(config.hosts) == 1 + assert config.hosts[0].host == "127.0.0.1" + assert config.hosts[0].port == 9669 + assert config.ssl_param is None + assert config.connect_timeout is not None + assert config.request_timeout is not None + + def test_from_defaults_with_ssl_true(self): + """Test creating ConnectionConfig with SSL enabled""" + config = ConnectionConfig.from_defaults("127.0.0.1:9669", ssl_param=True) + assert config.ssl_param is not None + assert isinstance(config.ssl_param, SSLParam) + + def test_from_defaults_with_ssl_param(self): + """Test creating ConnectionConfig with custom SSLParam""" + ssl = SSLParam(ca_crt=b"ca", private_key=b"key", cert=b"cert") + config = ConnectionConfig.from_defaults("127.0.0.1:9669", ssl_param=ssl) + assert config.ssl_param == ssl + + def test_from_defaults_multiple_hosts(self): + """Test creating ConnectionConfig with multiple hosts""" + config = ConnectionConfig.from_defaults("127.0.0.1:9669,127.0.0.2:9669") + assert len(config.hosts) == 2 + + def test_from_defaults_with_timeouts(self): + """Test creating ConnectionConfig with custom timeouts""" + config = ConnectionConfig.from_defaults( + "127.0.0.1:9669", connect_timeout=10.0, request_timeout=30.0 + ) + 
assert config.connect_timeout == 10.0 + assert config.request_timeout == 30.0 + + def test_connection_config_empty_hosts_raises_error(self): + """Test that empty hosts raises ValueError""" + with pytest.raises(ValueError, match="hosts cannot be empty"): + ConnectionConfig(hosts=[]) + + +class TestConnection: + """Test cases for synchronous Connection""" + + @patch("nebulagraph_python.client._connection.grpc.channel_ready_future") + @patch("nebulagraph_python.client._connection.grpc.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + def test_connect_success(self, mock_stub_class, mock_channel_class, mock_future): + """Test successful connection""" + mock_future_result = MagicMock() + mock_future_result.result.return_value = None + mock_future.return_value = mock_future_result + mock_channel = MagicMock() + mock_channel_class.return_value = mock_channel + mock_stub = MagicMock() + mock_stub_class.return_value = mock_stub + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = Connection(config) + + assert conn._stub is not None + assert conn._channel is not None + assert conn.connected is not None + assert conn.connected.host == "127.0.0.1" + assert conn.connected.port == 9669 + + @patch("nebulagraph_python.client._connection.grpc.channel_ready_future") + @patch("nebulagraph_python.client._connection.grpc.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + def test_connect_with_timeout(self, mock_stub_class, mock_channel_class, mock_future): + """Test connection with timeout""" + mock_future_result = MagicMock() + mock_future_result.result.return_value = None + mock_future.return_value = mock_future_result + mock_channel = MagicMock() + mock_channel_class.return_value = mock_channel + mock_stub = MagicMock() + mock_stub_class.return_value = mock_stub + + config = ConnectionConfig.from_defaults("127.0.0.1:9669", connect_timeout=5.0) + conn = Connection(config) + + assert conn._stub is not None + + @patch("nebulagraph_python.client._connection.grpc.channel_ready_future") + @patch("nebulagraph_python.client._connection.grpc.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + def test_connect_failover_to_second_host(self, mock_stub_class, mock_channel_class, mock_future): + """Test connection failover to second host""" + call_count = [0] + + def create_channel(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise Exception("First host failed") + mock_channel = MagicMock() + return mock_channel + + mock_channel_class.side_effect = create_channel + mock_stub = MagicMock() + mock_stub_class.return_value = mock_stub + mock_future_result = MagicMock() + mock_future_result.result.return_value = None + mock_future.return_value = mock_future_result + + config = ConnectionConfig.from_defaults("127.0.0.1:9669,127.0.0.2:9669") + conn = Connection(config) + + assert conn._stub is not None + assert conn.connected.host == "127.0.0.2" + + @patch("nebulagraph_python.client._connection.grpc.insecure_channel") + def test_connect_all_hosts_fail(self, mock_channel_class): + """Test connection failure when all hosts fail""" + mock_channel_class.side_effect = Exception("Connection failed") + + config = ConnectionConfig.from_defaults("127.0.0.1:9669,127.0.0.2:9669") + with pytest.raises(ConnectingError, match="Failed to connect to any"): + Connection(config) + + @patch("nebulagraph_python.client._connection.grpc.channel_ready_future") 
+ @patch("nebulagraph_python.client._connection.grpc.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + def test_close(self, mock_stub_class, mock_channel_class, mock_future): + """Test closing connection""" + mock_future_result = MagicMock() + mock_future_result.result.return_value = None + mock_future.return_value = mock_future_result + mock_channel = MagicMock() + mock_channel_class.return_value = mock_channel + mock_stub = MagicMock() + mock_stub_class.return_value = mock_stub + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = Connection(config) + conn.close() + + assert conn._channel is None + assert conn._stub is None + assert conn.connected is None + + @patch("nebulagraph_python.client._connection.grpc.channel_ready_future") + @patch("nebulagraph_python.client._connection.grpc.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + def test_ping_success(self, mock_stub_class, mock_channel_class, mock_future): + """Test successful ping""" + mock_future_result = MagicMock() + mock_future_result.result.return_value = None + mock_future.return_value = mock_future_result + mock_channel = MagicMock() + mock_channel_class.return_value = mock_channel + mock_stub = MagicMock() + mock_stub_class.return_value = mock_stub + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = Connection(config) + + assert conn.ping() is True + + @patch("nebulagraph_python.client._connection.grpc.channel_ready_future") + @patch("nebulagraph_python.client._connection.grpc.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + def test_ping_failure(self, mock_stub_class, mock_channel_class, mock_future): + """Test ping failure""" + mock_future_result = MagicMock() + mock_future_result.result.return_value = None + mock_future.return_value = mock_future_result + mock_channel = MagicMock() + mock_channel_class.return_value = mock_channel + mock_stub = MagicMock() + mock_stub.Execute.side_effect = Exception("Ping failed") + mock_stub_class.return_value = mock_stub + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = Connection(config) + + assert conn.ping() is False + + def test_ping_no_stub(self): + """Test ping when stub is not initialized""" + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = Connection(config) + conn._stub = None + + assert conn.ping() is False + + @patch("nebulagraph_python.client._connection.grpc.channel_ready_future") + @patch("nebulagraph_python.client._connection.grpc.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + def test_execute_success(self, mock_stub_class, mock_channel_class, mock_future): + """Test successful execute""" + mock_future_result = MagicMock() + mock_future_result.result.return_value = None + mock_future.return_value = mock_future_result + mock_channel = MagicMock() + mock_channel_class.return_value = mock_channel + mock_stub = MagicMock() + mock_response = MagicMock() + mock_stub_class.return_value = mock_stub + mock_stub.Execute.return_value = mock_response + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = Connection(config) + + # Mock authenticate to set session_id + with patch.object(conn, "authenticate", return_value=1): + result = conn.execute(1, "RETURN 1") + + assert result is not None + mock_stub.Execute.assert_called_once() + + 
@patch("nebulagraph_python.client._connection.grpc.channel_ready_future") + @patch("nebulagraph_python.client._connection.grpc.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + def test_execute_with_timeout(self, mock_stub_class, mock_channel_class, mock_future): + """Test execute with custom timeout""" + mock_future_result = MagicMock() + mock_future_result.result.return_value = None + mock_future.return_value = mock_future_result + mock_channel = MagicMock() + mock_channel_class.return_value = mock_channel + mock_stub = MagicMock() + mock_response = MagicMock() + mock_stub_class.return_value = mock_stub + mock_stub.Execute.return_value = mock_response + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = Connection(config) + + with patch.object(conn, "authenticate", return_value=1): + conn.execute(1, "RETURN 1", timeout=5.0) + + # Verify timeout was passed + call_args = mock_stub.Execute.call_args + assert call_args[1]["timeout"] == 5.0 + + @patch("nebulagraph_python.client._connection.grpc.channel_ready_future") + @patch("nebulagraph_python.client._connection.grpc.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + def test_execute_rpc_error(self, mock_stub_class, mock_channel_class, mock_future): + """Test execute with RPC error""" + mock_future_result = MagicMock() + mock_future_result.result.return_value = None + mock_future.return_value = mock_future_result + mock_channel = MagicMock() + mock_channel_class.return_value = mock_channel + mock_stub = MagicMock() + mock_stub_class.return_value = mock_stub + mock_stub.Execute.side_effect = grpc.RpcError("RPC error") + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = Connection(config) + + with patch.object(conn, "authenticate", return_value=1): + with pytest.raises(ExecutingError, match="RPC error"): + conn.execute(1, "RETURN 1") + + def test_execute_no_stub(self): + """Test execute when stub is not initialized""" + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = Connection(config) + conn._stub = None + + with pytest.raises(InternalError, match="Connection not established"): + conn.execute(1, "RETURN 1") + + @patch("nebulagraph_python.client._connection.grpc.channel_ready_future") + @patch("nebulagraph_python.client._connection.grpc.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + def test_authenticate_success(self, mock_stub_class, mock_channel_class, mock_future): + """Test successful authentication""" + mock_future_result = MagicMock() + mock_future_result.result.return_value = None + mock_future.return_value = mock_future_result + mock_channel = MagicMock() + mock_channel_class.return_value = mock_channel + mock_stub = MagicMock() + mock_response = MagicMock() + mock_response.status.code = b"00000" + mock_response.session_id = 12345 + mock_stub_class.return_value = mock_stub + mock_stub.Authenticate.return_value = mock_response + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = Connection(config) + + with patch("nebulagraph_python.client.client.init_session"): + session_id = conn.authenticate("user", "pass") + + assert session_id == 12345 + mock_stub.Authenticate.assert_called_once() + + @patch("nebulagraph_python.client._connection.grpc.channel_ready_future") + @patch("nebulagraph_python.client._connection.grpc.insecure_channel") + 
@patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + def test_authenticate_with_auth_options(self, mock_stub_class, mock_channel_class, mock_future): + """Test authentication with auth options""" + mock_future_result = MagicMock() + mock_future_result.result.return_value = None + mock_future.return_value = mock_future_result + mock_channel = MagicMock() + mock_channel_class.return_value = mock_channel + mock_stub = MagicMock() + mock_response = MagicMock() + mock_response.status.code = b"00000" + mock_response.session_id = 12345 + mock_stub_class.return_value = mock_stub + mock_stub.Authenticate.return_value = mock_response + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = Connection(config) + + with patch("nebulagraph_python.client.client.init_session"): + session_id = conn.authenticate( + "user", "pass", auth_options={"option1": "value1"} + ) + + assert session_id == 12345 + + @patch("nebulagraph_python.client._connection.grpc.channel_ready_future") + @patch("nebulagraph_python.client._connection.grpc.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + def test_authenticate_failure(self, mock_stub_class, mock_channel_class, mock_future): + """Test authentication failure""" + mock_future_result = MagicMock() + mock_future_result.result.return_value = None + mock_future.return_value = mock_future_result + mock_channel = MagicMock() + mock_channel_class.return_value = mock_channel + mock_stub = MagicMock() + mock_response = MagicMock() + mock_response.status.code = b"E_AUTH_FAILURE" + mock_response.status.message = b"Authentication failed" + mock_stub_class.return_value = mock_stub + mock_stub.Authenticate.return_value = mock_response + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = Connection(config) + + with pytest.raises(Exception): + conn.authenticate("user", "wrong_pass") + + @patch("nebulagraph_python.client._connection.grpc.channel_ready_future") + @patch("nebulagraph_python.client._connection.grpc.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + def test_authenticate_rpc_error(self, mock_stub_class, mock_channel_class, mock_future): + """Test authentication with RPC error""" + mock_future_result = MagicMock() + mock_future_result.result.return_value = None + mock_future.return_value = mock_future_result + mock_channel = MagicMock() + mock_channel_class.return_value = mock_channel + mock_stub = MagicMock() + mock_stub_class.return_value = mock_stub + mock_stub.Authenticate.side_effect = grpc.RpcError("RPC error") + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = Connection(config) + + with pytest.raises(AuthenticatingError, match="RPC error"): + conn.authenticate("user", "pass") + + def test_authenticate_no_stub(self): + """Test authenticate when stub is not initialized""" + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = Connection(config) + conn._stub = None + + with pytest.raises(InternalError, match="Connection not established"): + conn.authenticate("user", "pass") + + @patch("nebulagraph_python.client._connection.grpc.channel_ready_future") + @patch("nebulagraph_python.client._connection.grpc.secure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + def test_connect_with_ssl(self, mock_stub_class, mock_channel_class, mock_future): + """Test connection with SSL""" + mock_future_result = MagicMock() + 
mock_future_result.result.return_value = None + mock_future.return_value = mock_future_result + mock_channel = MagicMock() + mock_channel_class.return_value = mock_channel + mock_stub = MagicMock() + mock_stub_class.return_value = mock_stub + + ssl = SSLParam(ca_crt=b"ca", private_key=b"key", cert=b"cert") + config = ConnectionConfig.from_defaults("127.0.0.1:9669", ssl_param=ssl) + conn = Connection(config) + + assert conn._stub is not None + assert conn._channel is not None + + +class TestAsyncConnection: + """Test cases for asynchronous AsyncConnection""" + + @pytest.mark.asyncio + @patch("nebulagraph_python.client._connection.grpc.aio.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + async def test_connect_success(self, mock_stub_class, mock_channel_class): + """Test successful async connection""" + mock_channel = AsyncMock() + mock_channel_class.return_value = mock_channel + mock_stub = AsyncMock() + mock_stub_class.return_value = mock_stub + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = AsyncConnection(config) + await conn.connect() + + assert conn._stub is not None + assert conn._channel is not None + assert conn.connected is not None + + @pytest.mark.asyncio + @patch("nebulagraph_python.client._connection.grpc.aio.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + async def test_connect_with_timeout(self, mock_stub_class, mock_channel_class): + """Test async connection with timeout""" + mock_channel = AsyncMock() + mock_channel_class.return_value = mock_channel + mock_stub = AsyncMock() + mock_stub_class.return_value = mock_stub + + config = ConnectionConfig.from_defaults("127.0.0.1:9669", connect_timeout=5.0) + conn = AsyncConnection(config) + await conn.connect() + + assert conn._stub is not None + + @pytest.mark.asyncio + @patch("nebulagraph_python.client._connection.grpc.aio.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + async def test_connect_failover_to_second_host(self, mock_stub_class, mock_channel_class): + """Test async connection failover to second host""" + call_count = [0] + + async def create_channel(*args, **kwargs): + call_count[0] += 1 + if call_count[0] == 1: + raise Exception("First host failed") + mock_channel = AsyncMock() + return mock_channel + + mock_channel_class.side_effect = create_channel + mock_stub = AsyncMock() + mock_stub_class.return_value = mock_stub + + config = ConnectionConfig.from_defaults("127.0.0.1:9669,127.0.0.2:9669") + conn = AsyncConnection(config) + await conn.connect() + + assert conn._stub is not None + assert conn.connected.host == "127.0.0.2" + + @pytest.mark.asyncio + @patch("nebulagraph_python.client._connection.grpc.aio.insecure_channel") + async def test_connect_all_hosts_fail(self, mock_channel_class): + """Test async connection failure when all hosts fail""" + mock_channel_class.side_effect = Exception("Connection failed") + + config = ConnectionConfig.from_defaults("127.0.0.1:9669,127.0.0.2:9669") + conn = AsyncConnection(config) + + with pytest.raises(ConnectingError, match="Failed to connect asynchronously"): + await conn.connect() + + @pytest.mark.asyncio + @patch("nebulagraph_python.client._connection.grpc.aio.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + async def test_close(self, mock_stub_class, mock_channel_class): + """Test closing async connection""" + mock_channel = AsyncMock() + 
mock_channel_class.return_value = mock_channel + mock_stub = AsyncMock() + mock_stub_class.return_value = mock_stub + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = AsyncConnection(config) + await conn.connect() + await conn.close() + + assert conn._channel is None + assert conn._stub is None + assert conn.connected is None + + @pytest.mark.asyncio + @patch("nebulagraph_python.client._connection.grpc.aio.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + async def test_ping_success(self, mock_stub_class, mock_channel_class): + """Test successful async ping""" + mock_channel = AsyncMock() + mock_channel_class.return_value = mock_channel + mock_stub = AsyncMock() + mock_stub_class.return_value = mock_stub + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = AsyncConnection(config) + await conn.connect() + + assert await conn.ping() is True + + @pytest.mark.asyncio + @patch("nebulagraph_python.client._connection.grpc.aio.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + async def test_ping_failure(self, mock_stub_class, mock_channel_class): + """Test async ping failure""" + mock_channel = AsyncMock() + mock_channel_class.return_value = mock_channel + mock_stub = AsyncMock() + mock_stub.Execute.side_effect = Exception("Ping failed") + mock_stub_class.return_value = mock_stub + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = AsyncConnection(config) + await conn.connect() + + assert await conn.ping() is False + + @pytest.mark.asyncio + async def test_ping_no_stub(self): + """Test async ping when stub is not initialized""" + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = AsyncConnection(config) + conn._stub = None + + assert await conn.ping() is False + + @pytest.mark.asyncio + @patch("nebulagraph_python.client._connection.grpc.aio.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + async def test_execute_success(self, mock_stub_class, mock_channel_class): + """Test successful async execute""" + mock_channel = AsyncMock() + mock_channel_class.return_value = mock_channel + mock_stub = AsyncMock() + mock_response = MagicMock() + mock_stub_class.return_value = mock_stub + mock_stub.Execute.return_value = mock_response + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = AsyncConnection(config) + await conn.connect() + + with patch.object(conn, "authenticate", return_value=1): + result = await conn.execute(1, "RETURN 1") + + assert result is not None + mock_stub.Execute.assert_called_once() + + @pytest.mark.asyncio + @patch("nebulagraph_python.client._connection.grpc.aio.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + async def test_execute_with_timeout(self, mock_stub_class, mock_channel_class): + """Test async execute with custom timeout""" + mock_channel = AsyncMock() + mock_channel_class.return_value = mock_channel + mock_stub = AsyncMock() + mock_response = MagicMock() + mock_stub_class.return_value = mock_stub + mock_stub.Execute.return_value = mock_response + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = AsyncConnection(config) + await conn.connect() + + with patch.object(conn, "authenticate", return_value=1): + await conn.execute(1, "RETURN 1", timeout=5.0) + + call_args = mock_stub.Execute.call_args + assert call_args[1]["timeout"] == 5.0 + + @pytest.mark.asyncio + 
@patch("nebulagraph_python.client._connection.grpc.aio.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + async def test_execute_rpc_error(self, mock_stub_class, mock_channel_class): + """Test async execute with RPC error""" + mock_channel = AsyncMock() + mock_channel_class.return_value = mock_channel + mock_stub = AsyncMock() + mock_stub_class.return_value = mock_stub + mock_stub.Execute.side_effect = grpc.aio.AioRpcError("RPC error") + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = AsyncConnection(config) + await conn.connect() + + with patch.object(conn, "authenticate", return_value=1): + with pytest.raises(ExecutingError, match="RPC error"): + await conn.execute(1, "RETURN 1") + + @pytest.mark.asyncio + async def test_execute_no_stub(self): + """Test async execute when stub is not initialized""" + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = AsyncConnection(config) + conn._stub = None + + with pytest.raises(InternalError, match="Async connection not established"): + await conn.execute(1, "RETURN 1") + + @pytest.mark.asyncio + @patch("nebulagraph_python.client._connection.grpc.aio.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + async def test_authenticate_success(self, mock_stub_class, mock_channel_class): + """Test successful async authentication""" + mock_channel = AsyncMock() + mock_channel_class.return_value = mock_channel + mock_stub = AsyncMock() + mock_response = MagicMock() + mock_response.status.code = b"00000" + mock_response.session_id = 12345 + mock_stub_class.return_value = mock_stub + mock_stub.Authenticate.return_value = mock_response + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = AsyncConnection(config) + await conn.connect() + + with patch("nebulagraph_python.client._connection.ainit_session"): + session_id = await conn.authenticate("user", "pass") + + assert session_id == 12345 + mock_stub.Authenticate.assert_called_once() + + @pytest.mark.asyncio + @patch("nebulagraph_python.client._connection.grpc.aio.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + async def test_authenticate_failure(self, mock_stub_class, mock_channel_class): + """Test async authentication failure""" + mock_channel = AsyncMock() + mock_channel_class.return_value = mock_channel + mock_stub = AsyncMock() + mock_response = MagicMock() + mock_response.status.code = b"E_AUTH_FAILURE" + mock_response.status.message = b"Authentication failed" + mock_stub_class.return_value = mock_stub + mock_stub.Authenticate.return_value = mock_response + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = AsyncConnection(config) + await conn.connect() + + with pytest.raises(Exception): + await conn.authenticate("user", "wrong_pass") + + @pytest.mark.asyncio + @patch("nebulagraph_python.client._connection.grpc.aio.insecure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + async def test_authenticate_rpc_error(self, mock_stub_class, mock_channel_class): + """Test async authentication with RPC error""" + mock_channel = AsyncMock() + mock_channel_class.return_value = mock_channel + mock_stub = AsyncMock() + mock_stub_class.return_value = mock_stub + mock_stub.Authenticate.side_effect = grpc.aio.AioRpcError("RPC error") + + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = AsyncConnection(config) + await conn.connect() + + 
with pytest.raises(AuthenticatingError, match="RPC error"): + await conn.authenticate("user", "pass") + + @pytest.mark.asyncio + async def test_authenticate_no_stub(self): + """Test async authenticate when stub is not initialized""" + config = ConnectionConfig.from_defaults("127.0.0.1:9669") + conn = AsyncConnection(config) + conn._stub = None + + with pytest.raises(InternalError, match="Async connection not established"): + await conn.authenticate("user", "pass") + + @pytest.mark.asyncio + @patch("nebulagraph_python.client._connection.grpc.aio.secure_channel") + @patch("nebulagraph_python.client._connection.graph_pb2_grpc.GraphServiceStub") + async def test_connect_with_ssl(self, mock_stub_class, mock_channel_class): + """Test async connection with SSL""" + mock_channel = AsyncMock() + mock_channel_class.return_value = mock_channel + mock_stub = AsyncMock() + mock_stub_class.return_value = mock_stub + + ssl = SSLParam(ca_crt=b"ca", private_key=b"key", cert=b"cert") + config = ConnectionConfig.from_defaults("127.0.0.1:9669", ssl_param=ssl) + conn = AsyncConnection(config) + await conn.connect() + + assert conn._stub is not None + assert conn._channel is not None \ No newline at end of file diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py deleted file mode 100644 index 883f5682..00000000 --- a/tests/test_connection_pool.py +++ /dev/null @@ -1,782 +0,0 @@ -import asyncio -import threading -import time -from unittest.mock import AsyncMock, Mock, patch - -import pytest - -from nebulagraph_python.client._connection import ( - AsyncConnection, - Connection, - ConnectionConfig, -) -from nebulagraph_python.client._connection_pool import ( - AsyncConnectionPool, - ConnectionPool, -) -from nebulagraph_python.data import HostAddress -from nebulagraph_python.error import ConnectingError, PoolError - - -class TestConnectionPool: - """Test cases for ConnectionPool (synchronous)""" - - def test_init_single_host(self): - """Test initialization with a single host""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host]) - pool = ConnectionPool(config) - - assert len(pool.addresses) == 1 - assert pool.addresses[0] == host - assert pool.current_address == host - assert len(pool._connections) == 1 - assert host in pool._connections - - def test_init_multiple_hosts(self): - """Test initialization with multiple hosts""" - hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - HostAddress("localhost", 9671), - ] - config = ConnectionConfig(hosts=hosts) - pool = ConnectionPool(config) - - assert len(pool.addresses) == 3 - assert pool.addresses == hosts - assert pool.current_address == hosts[0] - assert len(pool._connections) == 3 - for host in hosts: - assert host in pool._connections - - def test_init_with_ping_enabled(self): - """Test initialization with ping enabled""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host], ping_before_execute=True) - pool = ConnectionPool(config) - - # Each connection should have ping disabled in its config - for conn in pool._connections.values(): - assert not conn.config.ping_before_execute - - @patch('nebulagraph_python.client._connection_pool.Connection') - def test_init_creates_connections_with_single_host_config(self, mock_connection): - """Test that each connection is created with a config containing only its host""" - hosts = [ - HostAddress("host1", 9669), - HostAddress("host2", 9669), - ] - config = ConnectionConfig(hosts=hosts) - pool = ConnectionPool(config) - - 
# Should create one connection per host - assert mock_connection.call_count == 2 - - # Each connection should be created with a config containing only one host - for call_args in mock_connection.call_args_list: - conn_config = call_args[0][0] - assert len(conn_config.hosts) == 1 - assert not conn_config.ping_before_execute - - def test_next_address_round_robin(self): - """Test that next_address implements round-robin""" - hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - HostAddress("localhost", 9671), - ] - config = ConnectionConfig(hosts=hosts) - pool = ConnectionPool(config) - - # Initial address should be hosts[0] - assert pool.current_address == hosts[0] - - # Should cycle through hosts - assert pool.next_address() == hosts[1] - assert pool.next_address() == hosts[2] - assert pool.next_address() == hosts[0] # Back to first - assert pool.next_address() == hosts[1] - - def test_next_address_thread_safety(self): - """Test that next_address is thread-safe""" - hosts = [HostAddress(f"localhost", 9669 + i) for i in range(3)] - config = ConnectionConfig(hosts=hosts) - pool = ConnectionPool(config) - - results = [] - errors = [] - - def get_addresses(thread_id): - try: - for _ in range(100): - addr = pool.next_address() - results.append((thread_id, addr)) - except Exception as e: - errors.append((thread_id, e)) - - # Start multiple threads - threads = [] - for i in range(5): - thread = threading.Thread(target=get_addresses, args=(i,)) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - assert len(errors) == 0, f"Unexpected errors: {errors}" - assert len(results) == 500 # 5 threads * 100 calls each - - # Check that all addresses are valid - for _, addr in results: - assert addr in hosts - - @patch('nebulagraph_python.client._connection_pool.Connection') - def test_connect_success(self, mock_connection): - """Test successful connection to all hosts""" - hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - ] - config = ConnectionConfig(hosts=hosts) - pool = ConnectionPool(config) - - # Mock the connections - mock_conn1 = Mock() - mock_conn2 = Mock() - pool._connections[hosts[0]] = mock_conn1 - pool._connections[hosts[1]] = mock_conn2 - - pool.connect() - - # Should call connect on all connections - mock_conn1.connect.assert_called_once() - mock_conn2.connect.assert_called_once() - - def test_get_connection_without_ping(self): - """Test getting connection when ping is disabled""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host], ping_before_execute=False) - pool = ConnectionPool(config) - - # Mock the connection - mock_conn = Mock() - pool._connections[host] = mock_conn - - result = pool.get_connection(host) - - assert result == mock_conn - mock_conn.ping.assert_not_called() - - def test_get_connection_with_ping_success(self): - """Test getting connection when ping succeeds""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host], ping_before_execute=True) - pool = ConnectionPool(config) - - # Mock the connection - mock_conn = Mock() - mock_conn.ping.return_value = True - pool._connections[host] = mock_conn - - result = pool.get_connection(host) - - assert result == mock_conn - mock_conn.ping.assert_called_once() - mock_conn.reconnect.assert_not_called() - - def test_get_connection_with_ping_fail_reconnect_success(self): - """Test getting connection when ping fails but reconnect succeeds""" - host = 
HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host], ping_before_execute=True) - pool = ConnectionPool(config) - - # Mock the connection - mock_conn = Mock() - mock_conn.ping.return_value = False - pool._connections[host] = mock_conn - - result = pool.get_connection(host) - - assert result == mock_conn - mock_conn.ping.assert_called_once() - mock_conn.reconnect.assert_called_once() - - @patch('nebulagraph_python.client._connection_pool.logger') - def test_get_connection_with_ping_fail_reconnect_fail(self, mock_logger): - """Test getting connection when both ping and reconnect fail""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host], ping_before_execute=True) - pool = ConnectionPool(config) - - # Mock the connection - mock_conn = Mock() - mock_conn.ping.return_value = False - mock_conn.reconnect.side_effect = ConnectingError("Connection failed") - pool._connections[host] = mock_conn - - result = pool.get_connection(host) - - assert result is None - mock_conn.ping.assert_called_once() - mock_conn.reconnect.assert_called_once() - mock_logger.exception.assert_called_once() - - def test_next_connection_success(self): - """Test getting next available connection""" - hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - ] - config = ConnectionConfig(hosts=hosts) - pool = ConnectionPool(config) - - # Mock connections - mock_conn1 = Mock() - mock_conn2 = Mock() - pool._connections[hosts[0]] = mock_conn1 - pool._connections[hosts[1]] = mock_conn2 - - # Mock get_connection to return the connection for the current address - with patch.object(pool, 'get_connection', side_effect=[mock_conn1]): - addr, conn = pool.next_connection() - - assert addr == hosts[1] # Should advance to next address - assert conn == mock_conn1 - - def test_next_connection_with_failures(self): - """Test getting next connection when some hosts are unavailable""" - hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - HostAddress("localhost", 9671), - ] - config = ConnectionConfig(hosts=hosts) - pool = ConnectionPool(config) - - # Mock get_connection to fail for first two hosts, succeed for third - def mock_get_connection(host_addr): - if host_addr in [hosts[0], hosts[1]]: - return None - return Mock() - - with patch.object(pool, 'get_connection', side_effect=mock_get_connection): - addr, conn = pool.next_connection() - - assert addr == hosts[2] - assert conn is not None - - def test_next_connection_all_fail(self): - """Test getting next connection when all hosts are unavailable""" - hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - ] - config = ConnectionConfig(hosts=hosts) - pool = ConnectionPool(config) - - # Mock get_connection to always return None - with patch.object(pool, 'get_connection', return_value=None): - with pytest.raises(PoolError, match="No connection available in the pool"): - pool.next_connection() - - def test_close_all_connections(self): - """Test closing all connections in the pool""" - hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - ] - config = ConnectionConfig(hosts=hosts) - pool = ConnectionPool(config) - - # Mock connections - mock_conn1 = Mock() - mock_conn2 = Mock() - pool._connections[hosts[0]] = mock_conn1 - pool._connections[hosts[1]] = mock_conn2 - - pool.close() - - # Should close all connections and clear the dictionary - mock_conn1.close.assert_called_once() - mock_conn2.close.assert_called_once() - assert len(pool._connections) == 0 
- - def test_concurrent_next_connection(self): - """Test concurrent access to next_connection""" - hosts = [HostAddress(f"localhost", 9669 + i) for i in range(3)] - config = ConnectionConfig(hosts=hosts) - pool = ConnectionPool(config) - - # Mock all connections to be available - for host in hosts: - pool._connections[host] = Mock() - - results = [] - errors = [] - - def get_next_connection(thread_id): - try: - for _ in range(10): - with patch.object(pool, 'get_connection', return_value=Mock()): - addr, conn = pool.next_connection() - results.append((thread_id, addr)) - except Exception as e: - errors.append((thread_id, e)) - - # Start multiple threads - threads = [] - for i in range(3): - thread = threading.Thread(target=get_next_connection, args=(i,)) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - assert len(errors) == 0, f"Unexpected errors: {errors}" - assert len(results) == 30 # 3 threads * 10 calls each - - -class TestAsyncConnectionPool: - """Test cases for AsyncConnectionPool (asynchronous)""" - - @pytest.mark.asyncio - async def test_init_single_host(self): - """Test initialization with a single host""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host]) - pool = AsyncConnectionPool(config) - - assert len(pool.addresses) == 1 - assert pool.addresses[0] == host - assert pool.current_address == host - assert len(pool._connections) == 1 - assert host in pool._connections - - @pytest.mark.asyncio - async def test_init_multiple_hosts(self): - """Test initialization with multiple hosts""" - hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - HostAddress("localhost", 9671), - ] - config = ConnectionConfig(hosts=hosts) - pool = AsyncConnectionPool(config) - - assert len(pool.addresses) == 3 - assert pool.addresses == hosts - assert pool.current_address == hosts[0] - assert len(pool._connections) == 3 - for host in hosts: - assert host in pool._connections - - @pytest.mark.asyncio - async def test_init_with_ping_enabled(self): - """Test initialization with ping enabled""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host], ping_before_execute=True) - pool = AsyncConnectionPool(config) - - # Each connection should have ping disabled in its config - for conn in pool._connections.values(): - assert not conn.config.ping_before_execute - - @pytest.mark.asyncio - @patch('nebulagraph_python.client._connection_pool.AsyncConnection') - async def test_init_creates_connections_with_single_host_config(self, mock_connection): - """Test that each connection is created with a config containing only its host""" - hosts = [ - HostAddress("host1", 9669), - HostAddress("host2", 9669), - ] - config = ConnectionConfig(hosts=hosts) - pool = AsyncConnectionPool(config) - - # Should create one connection per host - assert mock_connection.call_count == 2 - - # Each connection should be created with a config containing only one host - for call_args in mock_connection.call_args_list: - conn_config = call_args[0][0] - assert len(conn_config.hosts) == 1 - assert not conn_config.ping_before_execute - - @pytest.mark.asyncio - async def test_next_address_round_robin(self): - """Test that next_address implements round-robin""" - hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - HostAddress("localhost", 9671), - ] - config = ConnectionConfig(hosts=hosts) - pool = AsyncConnectionPool(config) - - # Initial address should be 
hosts[0] - assert pool.current_address == hosts[0] - - # Should cycle through hosts - assert await pool.next_address() == hosts[1] - assert await pool.next_address() == hosts[2] - assert await pool.next_address() == hosts[0] # Back to first - assert await pool.next_address() == hosts[1] - - @pytest.mark.asyncio - async def test_next_address_async_safety(self): - """Test that next_address is async/coroutine-level safe""" - hosts = [HostAddress(f"localhost", 9669 + i) for i in range(10)] - config = ConnectionConfig(hosts=hosts) - pool = AsyncConnectionPool(config) - - results = [] - - async def get_addresses(task_id): - for _ in range(100): - addr = await pool.next_address() - results.append((task_id, addr)) - - # Start multiple tasks - tasks = [] - for i in range(5): - task = asyncio.create_task(get_addresses(i)) - tasks.append(task) - - # Wait for all tasks to complete - await asyncio.gather(*tasks) - - assert len(results) == 500 # 5 tasks * 100 calls each - - # Check that all addresses are valid - for _, addr in results: - assert addr in hosts - - @pytest.mark.asyncio - @patch('nebulagraph_python.client._connection_pool.AsyncConnection') - async def test_connect_success(self, mock_connection): - """Test successful connection to all hosts""" - hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - ] - config = ConnectionConfig(hosts=hosts) - pool = AsyncConnectionPool(config) - - # Mock the connections - mock_conn1 = AsyncMock() - mock_conn2 = AsyncMock() - pool._connections[hosts[0]] = mock_conn1 - pool._connections[hosts[1]] = mock_conn2 - - await pool.connect() - - # Should call connect on all connections - mock_conn1.connect.assert_called_once() - mock_conn2.connect.assert_called_once() - - @pytest.mark.asyncio - async def test_get_connection_without_ping(self): - """Test getting connection when ping is disabled""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host], ping_before_execute=False) - pool = AsyncConnectionPool(config) - - # Mock the connection - mock_conn = AsyncMock() - pool._connections[host] = mock_conn - - result = await pool.get_connection(host) - - assert result == mock_conn - mock_conn.ping.assert_not_called() - - @pytest.mark.asyncio - async def test_get_connection_with_ping_success(self): - """Test getting connection when ping succeeds""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host], ping_before_execute=True) - pool = AsyncConnectionPool(config) - - # Mock the connection - mock_conn = AsyncMock() - mock_conn.ping.return_value = True - pool._connections[host] = mock_conn - - result = await pool.get_connection(host) - - assert result == mock_conn - mock_conn.ping.assert_called_once() - mock_conn.reconnect.assert_not_called() - - @pytest.mark.asyncio - async def test_get_connection_with_ping_fail_reconnect_success(self): - """Test getting connection when ping fails but reconnect succeeds""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host], ping_before_execute=True) - pool = AsyncConnectionPool(config) - - # Mock the connection - mock_conn = AsyncMock() - mock_conn.ping.return_value = False - pool._connections[host] = mock_conn - - result = await pool.get_connection(host) - - assert result == mock_conn - mock_conn.ping.assert_called_once() - mock_conn.reconnect.assert_called_once() - - @pytest.mark.asyncio - @patch('nebulagraph_python.client._connection_pool.logger') - async def test_get_connection_with_ping_fail_reconnect_fail(self, mock_logger): - 
"""Test getting connection when both ping and reconnect fail""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host], ping_before_execute=True) - pool = AsyncConnectionPool(config) - - # Mock the connection - mock_conn = AsyncMock() - mock_conn.ping.return_value = False - mock_conn.reconnect.side_effect = ConnectingError("Connection failed") - pool._connections[host] = mock_conn - - result = await pool.get_connection(host) - - assert result is None - mock_conn.ping.assert_called_once() - mock_conn.reconnect.assert_called_once() - mock_logger.exception.assert_called_once() - - @pytest.mark.asyncio - async def test_next_connection_success(self): - """Test getting next available connection""" - hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - ] - config = ConnectionConfig(hosts=hosts) - pool = AsyncConnectionPool(config) - - # Mock connections - mock_conn1 = AsyncMock() - mock_conn2 = AsyncMock() - pool._connections[hosts[0]] = mock_conn1 - pool._connections[hosts[1]] = mock_conn2 - - # Mock get_connection to return the connection for the current address - async def mock_get_connection(host_addr): - if host_addr == hosts[1]: - return mock_conn1 - return None - - with patch.object(pool, 'get_connection', side_effect=mock_get_connection): - addr, conn = await pool.next_connection() - - assert addr == hosts[1] # Should advance to next address - assert conn == mock_conn1 - - @pytest.mark.asyncio - async def test_next_connection_with_failures(self): - """Test getting next connection when some hosts are unavailable""" - hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - HostAddress("localhost", 9671), - ] - config = ConnectionConfig(hosts=hosts) - pool = AsyncConnectionPool(config) - - # Mock get_connection to fail for first two hosts, succeed for third - async def mock_get_connection(host_addr): - if host_addr in [hosts[0], hosts[1]]: - return None - return AsyncMock() - - with patch.object(pool, 'get_connection', side_effect=mock_get_connection): - addr, conn = await pool.next_connection() - - assert addr == hosts[2] - assert conn is not None - - @pytest.mark.asyncio - async def test_next_connection_all_fail(self): - """Test getting next connection when all hosts are unavailable""" - hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - ] - config = ConnectionConfig(hosts=hosts) - pool = AsyncConnectionPool(config) - - # Mock get_connection to always return None - async def mock_get_connection(host_addr): - return None - - with patch.object(pool, 'get_connection', side_effect=mock_get_connection): - with pytest.raises(PoolError, match="No connection available in the pool"): - await pool.next_connection() - - @pytest.mark.asyncio - async def test_close_all_connections(self): - """Test closing all connections in the pool""" - hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - ] - config = ConnectionConfig(hosts=hosts) - pool = AsyncConnectionPool(config) - - # Mock connections - mock_conn1 = AsyncMock() - mock_conn2 = AsyncMock() - pool._connections[hosts[0]] = mock_conn1 - pool._connections[hosts[1]] = mock_conn2 - - await pool.close() - - # Should close all connections and clear the dictionary - mock_conn1.close.assert_called_once() - mock_conn2.close.assert_called_once() - assert len(pool._connections) == 0 - - @pytest.mark.asyncio - async def test_concurrent_next_connection(self): - """Test concurrent access to next_connection""" - hosts = 
[HostAddress(f"localhost", 9669 + i) for i in range(5)] - config = ConnectionConfig(hosts=hosts) - pool = AsyncConnectionPool(config) - - # Mock all connections to be available - for host in hosts: - pool._connections[host] = AsyncMock() - - results = [] - - async def get_next_connection(task_id): - for _ in range(10): - async def mock_get_connection(host_addr): - return AsyncMock() - - with patch.object(pool, 'get_connection', side_effect=mock_get_connection): - addr, conn = await pool.next_connection() - results.append((task_id, addr)) - - # Start multiple tasks - tasks = [] - for i in range(3): - task = asyncio.create_task(get_next_connection(i)) - tasks.append(task) - - # Wait for all tasks to complete - await asyncio.gather(*tasks) - - assert len(results) == 30 # 3 tasks * 10 calls each - - -class TestConnectionPoolEdgeCases: - """Test edge cases and error conditions for both pool types""" - - def test_pool_empty_hosts_list(self): - """Test sync pool with empty hosts list""" - with pytest.raises(ValueError): - config = ConnectionConfig(hosts=[]) - - - - def test_sync_pool_single_host_round_robin(self): - """Test sync pool with single host always returns same address""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host]) - pool = ConnectionPool(config) - - for _ in range(10): - assert pool.next_address() == host - - @pytest.mark.asyncio - async def test_async_pool_single_host_round_robin(self): - """Test async pool with single host always returns same address""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host]) - pool = AsyncConnectionPool(config) - - for _ in range(10): - assert await pool.next_address() == host - - @patch('nebulagraph_python.client._connection_pool.logger') - def test_sync_get_connection_exception_during_reconnect(self, mock_logger): - """Test sync pool handling of unexpected exceptions during reconnect""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host], ping_before_execute=True) - pool = ConnectionPool(config) - - # Mock the connection - mock_conn = Mock() - mock_conn.ping.return_value = False - mock_conn.reconnect.side_effect = RuntimeError("Unexpected error") - pool._connections[host] = mock_conn - - # Should catch the exception and return None - result = pool.get_connection(host) - assert result is None - mock_logger.exception.assert_called_once() - - @pytest.mark.asyncio - @patch('nebulagraph_python.client._connection_pool.logger') - async def test_async_get_connection_exception_during_reconnect(self, mock_logger): - """Test async pool handling of unexpected exceptions during reconnect""" - host = HostAddress("localhost", 9669) - config = ConnectionConfig(hosts=[host], ping_before_execute=True) - pool = AsyncConnectionPool(config) - - # Mock the connection - mock_conn = AsyncMock() - mock_conn.ping.return_value = False - mock_conn.reconnect.side_effect = RuntimeError("Unexpected error") - pool._connections[host] = mock_conn - - # Should catch the exception and return None - result = await pool.get_connection(host) - assert result is None - mock_logger.exception.assert_called_once() - - def test_sync_pool_config_modification_isolation(self): - """Test that pool config modifications don't affect original config""" - original_hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - ] - config = ConnectionConfig(hosts=original_hosts, ping_before_execute=True) - pool = ConnectionPool(config) - - # Original config should remain unchanged - assert 
config.ping_before_execute is True - assert len(config.hosts) == 2 - - # Each connection should have modified config - for conn in pool._connections.values(): - assert not conn.config.ping_before_execute - assert len(conn.config.hosts) == 1 - - @pytest.mark.asyncio - async def test_async_pool_config_modification_isolation(self): - """Test that pool config modifications don't affect original config""" - original_hosts = [ - HostAddress("localhost", 9669), - HostAddress("localhost", 9670), - ] - config = ConnectionConfig(hosts=original_hosts, ping_before_execute=True) - pool = AsyncConnectionPool(config) - - # Original config should remain unchanged - assert config.ping_before_execute is True - assert len(config.hosts) == 2 - - # Each connection should have modified config - for conn in pool._connections.values(): - assert not conn.config.ping_before_execute - assert len(conn.config.hosts) == 1 - diff --git a/tests/test_integration.py b/tests/test_integration.py new file mode 100644 index 00000000..0502055f --- /dev/null +++ b/tests/test_integration.py @@ -0,0 +1,701 @@ +# Copyright 2025 vesoft-inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import os +import pytest + +from nebulagraph_python import ( + NebulaClient, + NebulaAsyncClient, + NebulaPool, + NebulaPoolConfig, + SessionConfig, + SessionPoolConfig, +) + +# Read the test configuration from environment variables, falling back to defaults +NEBULA_HOST = os.getenv("NEBULA_HOST", "192.168.8.6") +NEBULA_PORT = os.getenv("NEBULA_PORT", "3820") +NEBULA_USER = os.getenv("NEBULA_USER", "root") +NEBULA_PASSWORD = os.getenv("NEBULA_PASSWORD", "NebulaGraph01") + +NEBULA_ADDRESS = f"{NEBULA_HOST}:{NEBULA_PORT}" + + +class TestConnectionIntegration: + """Real-connection tests - verify Connection functionality""" + + def test_connection_basic(self): + """Test a basic connection""" + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + assert client is not None + client.close() + + def test_connection_ping(self): + """Test the connection ping""" + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + assert client.ping() is True + client.close() + + def test_connection_execute_simple_query(self): + """Test executing a simple query""" + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + result = client.execute("RETURN 1") + assert result is not None + assert result.is_succeeded + client.close() + + def test_connection_execute_show_hosts(self): + """Test executing the SHOW HOSTS command""" + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + # NebulaGraph 5.0 uses a different syntax + result = client.execute("SHOW HOSTS GRAPH") + assert result is not None + assert result.is_succeeded + client.close() + + def test_connection_context_manager(self): + """Test the context manager""" + with NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) as client: + assert client is not None + result = client.execute("RETURN 1") + assert result.is_succeeded + + def test_connection_with_session_config(self): + """Test a connection with a session config""" + session_config = SessionConfig() + client = NebulaClient( + NEBULA_ADDRESS, +
NEBULA_USER, + NEBULA_PASSWORD, + session_config=session_config, + ) + assert client is not None + result = client.execute("RETURN 1") + assert result.is_succeeded + client.close() + + +class TestAsyncConnectionIntegration: + """Async connection tests - verify AsyncConnection functionality""" + + @pytest.mark.asyncio + async def test_async_connection_basic(self): + """Test a basic async connection""" + client = await NebulaAsyncClient.connect( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + assert client is not None + await client.close() + + @pytest.mark.asyncio + async def test_async_connection_execute(self): + """Test executing a query asynchronously""" + client = await NebulaAsyncClient.connect( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + result = await client.execute("RETURN 1") + assert result is not None + assert result.is_succeeded + await client.close() + + @pytest.mark.asyncio + async def test_async_connection_show_hosts(self): + """Test executing SHOW HOSTS asynchronously""" + client = await NebulaAsyncClient.connect( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + result = await client.execute("SHOW HOSTS") + assert result is not None + assert result.is_succeeded + await client.close() + + @pytest.mark.asyncio + async def test_async_connection_context_manager(self): + """Test the async context manager""" + async with await NebulaAsyncClient.connect( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) as client: + assert client is not None + result = await client.execute("RETURN 1") + assert result.is_succeeded + + +class TestSessionPoolIntegration: + """Session pool integration tests""" + + def test_session_pool_basic(self): + """Test a basic session pool""" + pool_config = SessionPoolConfig(size=3) + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + session_pool_config=pool_config, + ) + assert client is not None + + # Execute several queries + for i in range(5): + result = client.execute("RETURN 1") + assert result.is_succeeded + + client.close() + + def test_session_pool_concurrent(self): + """Test concurrent access to the session pool""" + import threading + import time + + pool_config = SessionPoolConfig(size=3) + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + session_pool_config=pool_config, + ) + + results = [] + errors = [] + + def execute_query(thread_id): + try: + result = client.execute("RETURN 1") + results.append(thread_id) + time.sleep(0.1) + except Exception as e: + errors.append((thread_id, e)) + + threads = [] + for i in range(5): + thread = threading.Thread(target=execute_query, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert len(errors) == 0 + assert len(results) == 5 + client.close() + + def test_session_pool_borrow_session(self): + """Test borrowing a session""" + pool_config = SessionPoolConfig(size=2) + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + session_pool_config=pool_config, + ) + + with client.borrow() as session: + result = session.execute("RETURN 1") + assert result.is_succeeded + + client.close() + + +class TestAsyncSessionPoolIntegration: + """Async session pool integration tests""" + + @pytest.mark.asyncio + async def test_async_session_pool_basic(self): + """Test a basic async session pool""" + pool_config = SessionPoolConfig(size=3) + client = await NebulaAsyncClient.connect( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + session_pool_config=pool_config, + ) + assert client is not None + + # Execute several queries + for i in range(5): + result = await client.execute("RETURN 1") + assert result.is_succeeded + + await client.close() + + @pytest.mark.asyncio + async def test_async_session_pool_concurrent(self): + """Test concurrent access to the async session pool""" + pool_config = 
SessionPoolConfig(size=3) + client = await NebulaAsyncClient.connect( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + session_pool_config=pool_config, + ) + + async def execute_query(task_id): + result = await client.execute("RETURN 1") + assert result.is_succeeded + + tasks = [execute_query(i) for i in range(5)] + await asyncio.gather(*tasks) + + await client.close() + + @pytest.mark.asyncio + async def test_async_session_pool_borrow_session(self): + """Test borrowing a session asynchronously""" + pool_config = SessionPoolConfig(size=2) + client = await NebulaAsyncClient.connect( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + session_pool_config=pool_config, + ) + + async with client.borrow() as session: + result = await session.execute("RETURN 1") + assert result.is_succeeded + + await client.close() + + +class TestNebulaPoolIntegration: + """Connection pool integration tests""" + + def test_nebula_pool_basic(self): + """Test a basic connection pool""" + pool_config = NebulaPoolConfig( + max_client_size=3, min_client_size=1, max_wait=10.0 + ) + pool = NebulaPool( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + pool_config=pool_config, + ) + assert pool is not None + + result = pool.execute("RETURN 1") + assert result.is_succeeded + + pool.close() + + def test_nebula_pool_borrow_client(self): + """Test borrowing a client""" + pool_config = NebulaPoolConfig( + max_client_size=2, min_client_size=1, max_wait=10.0 + ) + pool = NebulaPool( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + pool_config=pool_config, + ) + + with pool.borrow() as client: + result = client.execute("RETURN 1") + assert result.is_succeeded + + pool.close() + + def test_nebula_pool_concurrent(self): + """Test concurrent access to the connection pool""" + import threading + import time + + pool_config = NebulaPoolConfig( + max_client_size=5, min_client_size=2, max_wait=10.0 + ) + pool = NebulaPool( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + pool_config=pool_config, + ) + + results = [] + errors = [] + + def execute_query(thread_id): + try: + result = pool.execute("RETURN 1") + results.append(thread_id) + time.sleep(0.1) + except Exception as e: + errors.append((thread_id, e)) + + threads = [] + for i in range(5): + thread = threading.Thread(target=execute_query, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert len(errors) == 0 + assert len(results) == 5 + pool.close() + + def test_nebula_pool_round_robin(self): + """Test round-robin load balancing""" + pool_config = NebulaPoolConfig( + max_client_size=2, min_client_size=1, max_wait=10.0 + ) + pool = NebulaPool( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + pool_config=pool_config, + ) + + # Execute several queries; the pool should rotate through its clients + for i in range(4): + result = pool.execute("RETURN 1") + assert result.is_succeeded + + pool.close() + + def test_nebula_pool_context_manager(self): + """Test the connection pool context manager""" + pool_config = NebulaPoolConfig( + max_client_size=2, min_client_size=1, max_wait=10.0 + ) + pool = NebulaPool( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + pool_config=pool_config, + ) + # NebulaPool does not support the context manager protocol, so close it manually + result = pool.execute("RETURN 1") + assert result.is_succeeded + pool.close() + + def test_nebula_pool_get_client_and_return(self): + """Test getting and returning a client""" + pool_config = NebulaPoolConfig( + max_client_size=2, min_client_size=1, max_wait=10.0 + ) + pool = NebulaPool( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + pool_config=pool_config, + ) + + client = pool.get_client() + assert client is not None + + result = client.execute("RETURN 1") + assert result.is_succeeded + + pool.return_client(client) + pool.close() + + 
+class TestGraphOperations: + """Graph operation tests""" + + def test_create_space(self): + """Test creating a graph space""" + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + + # Drop the graph space if it already exists + client.execute("DROP SPACE IF EXISTS test_space") + + # Create the graph space + result = client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") + assert result.is_succeeded + + # Switch to the graph space + result = client.execute("USE test_space") + assert result.is_succeeded + + client.close() + + def test_create_tag(self): + """Test creating a tag""" + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + + # Make sure the graph space exists + client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") + client.execute("USE test_space") + + # Create the tag + result = client.execute("CREATE TAG IF NOT EXISTS person(name string, age int)") + assert result.is_succeeded + + client.close() + + def test_create_edge(self): + """Test creating an edge type""" + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + + # Make sure the graph space exists + client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") + client.execute("USE test_space") + + # Create the edge type + result = client.execute("CREATE EDGE IF NOT EXISTS follow(degree int)") + assert result.is_succeeded + + client.close() + + def test_insert_vertex(self): + """Test inserting vertices""" + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + + # Prepare the graph space + client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") + client.execute("USE test_space") + client.execute("CREATE TAG IF NOT EXISTS person(name string, age int)") + + # Insert vertices + result = client.execute('INSERT VERTEX person(name, age) VALUES "1":("Tom", 18), "2":("Jerry", 20)') + assert result.is_succeeded + + client.close() + + def test_insert_edge(self): + """Test inserting an edge""" + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + + # Prepare the graph space + client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") + client.execute("USE test_space") + client.execute("CREATE TAG IF NOT EXISTS person(name string, age int)") + client.execute("CREATE EDGE IF NOT EXISTS follow(degree int)") + client.execute('INSERT VERTEX person(name, age) VALUES "1":("Tom", 18), "2":("Jerry", 20)') + + # Insert the edge + result = client.execute('INSERT EDGE follow(degree) VALUES "1"->"2":(90)') + assert result.is_succeeded + + client.close() + + def test_query_vertex(self): + """Test querying a vertex""" + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + + # Prepare the data + client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") + client.execute("USE test_space") + client.execute("CREATE TAG IF NOT EXISTS person(name string, age int)") + client.execute('INSERT VERTEX person(name, age) VALUES "1":("Tom", 18)') + + # Query the vertex + result = client.execute('FETCH PROP ON person "1" YIELD vertex as v') + assert result.is_succeeded + + client.close() + + def test_query_edge(self): + """Test querying an edge""" + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + + # Prepare the data + client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") + client.execute("USE test_space") + client.execute("CREATE TAG IF NOT EXISTS person(name string, age int)") + client.execute("CREATE 
EDGE IF NOT EXISTS follow(degree int)") + client.execute('INSERT VERTEX person(name, age) VALUES "1":("Tom", 18), "2":("Jerry", 20)') + client.execute('INSERT EDGE follow(degree) VALUES "1"->"2":(90)') + + # Query the edge + result = client.execute('FETCH PROP ON follow "1"->"2" YIELD edge as e') + assert result.is_succeeded + + client.close() + + def test_complex_query(self): + """Test a complex query""" + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + + # Prepare the data + client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") + client.execute("USE test_space") + client.execute("CREATE TAG IF NOT EXISTS person(name string, age int)") + client.execute("CREATE EDGE IF NOT EXISTS follow(degree int)") + client.execute('INSERT VERTEX person(name, age) VALUES "1":("Tom", 18), "2":("Jerry", 20), "3":("Alice", 22)') + client.execute('INSERT EDGE follow(degree) VALUES "1"->"2":(90), "2"->"3":(80)') + + # Complex query: find the people that Tom follows + result = client.execute('GO FROM "1" OVER follow YIELD $$.person.name AS name, $$.person.age AS age') + assert result.is_succeeded + + client.close() + + +class TestErrorHandling: + """Error handling tests""" + + def test_invalid_query(self): + """Test an invalid query""" + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + + result = client.execute("INVALID QUERY") + assert not result.is_succeeded + + client.close() + + def test_wrong_credentials(self): + """Test wrong credentials""" + with pytest.raises(Exception): + client = NebulaClient( + NEBULA_ADDRESS, + "wrong_user", + "wrong_password", + ) + client.close() + + def test_connection_timeout(self): + """Test the connection timeout""" + from nebulagraph_python.client._connection import ConnectionConfig + + conn_config = ConnectionConfig.from_defaults( + NEBULA_ADDRESS, connect_timeout=1.0 + ) + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + conn_config=conn_config, + ) + # The connection should still succeed + assert client.ping() + client.close() + + +class TestPerformance: + """Performance tests""" + + def test_batch_insert(self): + """Test batch insert""" + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + + # Prepare the graph space + client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") + client.execute("USE test_space") + client.execute("CREATE TAG IF NOT EXISTS person(name string, age int)") + + # Batch insert + vertices = [] + for i in range(100): + vertices.append(f'"{i}":("Person{i}", {20 + i % 30})') + + query = f'INSERT VERTEX person(name, age) VALUES {", ".join(vertices)}' + result = client.execute(query) + assert result.is_succeeded + + client.close() + + def test_concurrent_queries(self): + """Test concurrent queries""" + import threading + + client = NebulaClient( + NEBULA_ADDRESS, + NEBULA_USER, + NEBULA_PASSWORD, + ) + + results = [] + errors = [] + + def execute_query(thread_id): + try: + result = client.execute("RETURN 1") + results.append(thread_id) + except Exception as e: + errors.append((thread_id, e)) + + threads = [] + for i in range(10): + thread = threading.Thread(target=execute_query, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert len(errors) == 0 + assert len(results) == 10 + client.close() \ No newline at end of file diff --git a/tests/test_nebula_client.py b/tests/test_nebula_client.py new file mode 100644 index 00000000..7847d0ae --- /dev/null +++ b/tests/test_nebula_client.py @@ -0,0 +1,608 @@ +# Copyright 2025 vesoft-inc +# +# Licensed under the Apache License, Version 2.0 
(the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Complete tests for NebulaClient""" + +import time +from unittest.mock import Mock, MagicMock, patch, call + +import pytest + +from nebulagraph_python.client.nebula_client import NebulaClient +from nebulagraph_python.data import HostAddress, SSLParam +from nebulagraph_python.error import AuthenticatingError, ExecutingError +from nebulagraph_python.client.constants import ( + DEFAULT_CONNECT_TIMEOUT_MS, + DEFAULT_REQUEST_TIMEOUT_MS, + DEFAULT_PING_TIMEOUT_MS, + DEFAULT_SCAN_PARALLEL, + DEFAULT_ENABLE_TLS, + DEFAULT_MAX_TIMEOUT_MS, +) +from nebulagraph_python.proto import graph_pb2 + + +class TestNebulaClient: + """Test cases for NebulaClient""" + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_creation_with_defaults(self, mock_connection_class): + """Test creating a NebulaClient with default values""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass" + ) + + assert client.servers == [HostAddress("127.0.0.1", 9669)] + assert client.user_name == "test_user" + assert client.password == "test_pass" + assert client.connect_timeout_mills == DEFAULT_CONNECT_TIMEOUT_MS + assert client.request_timeout_mills == DEFAULT_REQUEST_TIMEOUT_MS + assert client.server_ping_timeout_mills == DEFAULT_PING_TIMEOUT_MS + assert client.scan_parallel == DEFAULT_SCAN_PARALLEL + assert client.enable_tls == DEFAULT_ENABLE_TLS + assert client.session_id == 12345 + assert client.version == "v5.0.0" + assert client.is_closed is False + mock_connection.open.assert_called_once() + mock_connection.authenticate.assert_called_once() + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_creation_with_custom_timeouts(self, mock_connection_class): + """Test creating a NebulaClient with custom timeouts""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + connect_timeout_ms=5000, + request_timeout_ms=120000, + server_ping_timeout_ms=2000, + ) + + assert client.connect_timeout_mills == 5000 + assert client.request_timeout_mills == 120000 + assert client.server_ping_timeout_mills == 2000 + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_creation_with_scan_parallel(self, mock_connection_class): + """Test creating a NebulaClient with custom scan_parallel""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + 
mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + scan_parallel=20, + ) + + assert client.scan_parallel == 20 + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_creation_with_tls(self, mock_connection_class): + """Test creating a NebulaClient with TLS enabled""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + ssl_param = SSLParam(ca_crt=b"ca", private_key=b"key", cert=b"cert") + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + enable_tls=True, + ssl_param=ssl_param, + ) + + assert client.enable_tls is True + assert client.ssl_param == ssl_param + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_creation_with_auth_options(self, mock_connection_class): + """Test creating a NebulaClient with auth options""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + auth_options = {"password": "test_pass", "custom_option": "value"} + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + auth_options=auth_options, + ) + + assert client.auth_options == auth_options + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_creation_multiple_addresses(self, mock_connection_class): + """Test creating a NebulaClient with multiple addresses""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669,127.0.0.2:9669,127.0.0.3:9669", + user_name="test_user", + password="test_pass", + ) + + assert len(client.servers) == 3 + assert HostAddress("127.0.0.1", 9669) in client.servers + assert HostAddress("127.0.0.2", 9669) in client.servers + assert HostAddress("127.0.0.3", 9669) in client.servers + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_execute(self, mock_connection_class): + """Test executing a query""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + # Create a proper ExecuteResponse object + mock_execute_response = graph_pb2.ExecuteResponse() + mock_execute_response.status.code = b"00000" + mock_execute_response.status.message = b"Success" + # Initialize summary field by creating a Summary object + from nebulagraph_python.proto.graph_pb2 import Summary + mock_execute_response.summary.CopyFrom(Summary()) + 
mock_connection.execute_default_timeout.return_value = mock_execute_response + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) + + result = client.execute("RETURN 1") + + assert result is not None + mock_connection.execute_default_timeout.assert_called_once_with(12345, "RETURN 1") + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_execute_with_timeout(self, mock_connection_class): + """Test executing a query with custom timeout""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + # Create a proper ExecuteResponse object + mock_execute_response = graph_pb2.ExecuteResponse() + mock_execute_response.status.code = b"00000" + mock_execute_response.status.message = b"Success" + # Initialize summary field by creating a Summary object + from nebulagraph_python.proto.graph_pb2 import Summary + mock_execute_response.summary.CopyFrom(Summary()) + mock_connection.execute_default_timeout.return_value = mock_execute_response + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) + + result = client.execute_with_timeout("RETURN 1", 5000) + + assert result is not None + mock_connection.execute_default_timeout.assert_called_once_with(12345, "RETURN 1") + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_get_session_id(self, mock_connection_class): + """Test getting session ID""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) + + assert client.get_session_id() == 12345 + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_get_version(self, mock_connection_class): + """Test getting version""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) + + assert client.get_version() == "v5.0.0" + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_get_create_time(self, mock_connection_class): + """Test getting create time""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) + + create_time = client.get_create_time() + assert create_time > 0 + assert create_time <= int(time.time() * 1000) + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_get_host(self, 
mock_connection_class): + """Test getting host address""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + mock_connection.get_server_address.return_value = HostAddress("127.0.0.1", 9669) + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) + + assert client.get_host() == "127.0.0.1:9669" + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_get_connect_timeout_mills(self, mock_connection_class): + """Test getting connect timeout""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + connect_timeout_ms=5000, + ) + + assert client.get_connect_timeout_mills() == 5000 + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_get_request_timeout_mills(self, mock_connection_class): + """Test getting request timeout""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + request_timeout_ms=120000, + ) + + assert client.get_request_timeout_mills() == 120000 + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_get_scan_parallel(self, mock_connection_class): + """Test getting scan parallel""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + scan_parallel=20, + ) + + assert client.get_scan_parallel() == 20 + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_ping_success(self, mock_connection_class): + """Test successful ping""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + mock_connection.ping.return_value = True + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) + + assert client.ping() is True + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_ping_failure(self, mock_connection_class): + """Test ping failure""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + 
mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + mock_connection.ping.side_effect = ExecutingError("Ping failed") + mock_connection.get_server_address.return_value = HostAddress("127.0.0.1", 9669) + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) + + assert client.ping() is False + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_close(self, mock_connection_class): + """Test closing the client""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) + + client.close() + + assert client.is_closed is True + mock_connection.execute.assert_called_once_with(12345, "SESSION CLOSE", 1000) + mock_connection.close.assert_called_once() + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_close_exception_handling(self, mock_connection_class): + """Test closing client handles exceptions gracefully""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + mock_connection.execute.side_effect = Exception("Close failed") + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) + + # Should not raise exception + client.close() + + assert client.is_closed is True + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_is_closed_client(self, mock_connection_class): + """Test checking if client is closed""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) + + assert client.is_closed_client() is False + + client.close() + + assert client.is_closed_client() is True + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_execute_after_close_raises_error(self, mock_connection_class): + """Test executing after close raises error""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) + + client.close() + + with pytest.raises(RuntimeError, match="The NebulaClient already closed"): + client.execute("RETURN 1") + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_ping_after_close_raises_error(self, mock_connection_class): + """Test ping after close raises error""" + mock_connection = MagicMock() + 
mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) + + client.close() + + with pytest.raises(RuntimeError, match="The NebulaClient already closed"): + client.ping() + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_validate_address_valid(self, mock_connection_class): + """Test validating valid address""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + addresses = NebulaClient._validate_address("127.0.0.1:9669") + + assert addresses == [HostAddress("127.0.0.1", 9669)] + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_validate_multiple_addresses(self, mock_connection_class): + """Test validating multiple addresses""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + addresses = NebulaClient._validate_address("127.0.0.1:9669,127.0.0.2:9669,127.0.0.3:9669") + + assert len(addresses) == 3 + assert addresses[0] == HostAddress("127.0.0.1", 9669) + assert addresses[1] == HostAddress("127.0.0.2", 9669) + assert addresses[2] == HostAddress("127.0.0.3", 9669) + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_validate_address_invalid(self, mock_connection_class): + """Test validating invalid address raises error""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + mock_connection.authenticate.return_value = mock_auth_result + + with pytest.raises(ValueError, match="Invalid address format"): + NebulaClient._validate_address("127.0.0.1") + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_init_retries_on_failure(self, mock_connection_class): + """Test client initialization retries on connection failure""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_auth_result = MagicMock() + mock_auth_result.get_session_id.return_value = 12345 + mock_auth_result.get_version.return_value = "v5.0.0" + + # First two attempts fail, third succeeds + mock_connection.open.side_effect = [ + Exception("Connection failed"), + Exception("Connection failed"), + None, + ] + mock_connection.authenticate.return_value = mock_auth_result + + client = NebulaClient( + addresses="127.0.0.1:9669,127.0.0.2:9669,127.0.0.3:9669", + user_name="test_user", + password="test_pass", + ) + + assert client.session_id == 12345 + assert mock_connection.open.call_count == 3 + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_init_auth_failure_raises_error(self, mock_connection_class): + """Test client initialization raises error on auth failure""" + 
mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_connection.authenticate.side_effect = AuthenticatingError("Auth failed") + + with pytest.raises(AuthenticatingError, match="Auth failed"): + NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) + + @patch("nebulagraph_python.client.nebula_client.GrpcConnection") + def test_client_init_all_servers_fail_raises_error(self, mock_connection_class): + """Test client initialization raises error when all servers fail""" + mock_connection = MagicMock() + mock_connection_class.return_value = mock_connection + mock_connection.open.side_effect = Exception("Connection failed") + + with pytest.raises(Exception, match="Connection failed"): + NebulaClient( + addresses="127.0.0.1:9669", + user_name="test_user", + password="test_pass", + ) \ No newline at end of file diff --git a/tests/test_nebula_client_decode_integration.py b/tests/test_nebula_client_decode_integration.py index 7d369a12..9b91ee13 100644 --- a/tests/test_nebula_client_decode_integration.py +++ b/tests/test_nebula_client_decode_integration.py @@ -57,8 +57,8 @@ def setUpClass(cls): cls.client = None try: cls.client = NebulaClient( - hosts=cls.hosts, - username=cls.user, + addresses=cls.hosts, + user_name=cls.user, password=cls.password, ) diff --git a/tests/test_nebula_pool.py b/tests/test_nebula_pool.py new file mode 100644 index 00000000..03b562d1 --- /dev/null +++ b/tests/test_nebula_pool.py @@ -0,0 +1,889 @@ +# Copyright 2025 vesoft-inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Complete tests for NebulaPool and NebulaPoolConfig""" + +import time +import threading +from unittest.mock import Mock, MagicMock, patch, AsyncMock + +import pytest + +from nebulagraph_python.client.nebula_pool import NebulaPool, NebulaPoolConfig +from nebulagraph_python.client.nebula_client import NebulaClient +from nebulagraph_python.data import HostAddress, SSLParam +from nebulagraph_python.error import AuthenticatingError, ExecutingError +from nebulagraph_python.client.constants import ( + DEFAULT_MAX_CLIENT_SIZE, + DEFAULT_MIN_CLIENT_SIZE, + DEFAULT_CONNECT_TIMEOUT_MS, + DEFAULT_REQUEST_TIMEOUT_MS, + DEFAULT_PING_TIMEOUT_MS, + DEFAULT_HEALTH_CHECK_TIME_MS, + DEFAULT_TEST_ON_BORROW, + DEFAULT_BLOCK_WHEN_EXHAUSTED, + DEFAULT_MAX_WAIT_MS, + DEFAULT_IDLE_EVICT_SCHEDULE_MS, + DEFAULT_MIN_EVICTABLE_IDLE_TIME_MS, + DEFAULT_STRICT_SERVER_HEALTHY, + DEFAULT_MAX_LIFE_TIME_MS, + DEFAULT_SCAN_PARALLEL, + DEFAULT_ENABLE_TLS, +) + + +class TestNebulaPoolConfig: + """Test cases for NebulaPoolConfig""" + + def test_config_defaults(self): + """Test NebulaPoolConfig with default values""" + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass" + ) + assert config.addresses == "127.0.0.1:9669" + assert config.username == "test_user" + assert config.password == "test_pass" + assert config.max_client_size == DEFAULT_MAX_CLIENT_SIZE + assert config.min_client_size == DEFAULT_MIN_CLIENT_SIZE + assert config.max_wait_ms == DEFAULT_MAX_WAIT_MS + assert config.block_when_exhausted == DEFAULT_BLOCK_WHEN_EXHAUSTED + assert config.connect_timeout_ms == DEFAULT_CONNECT_TIMEOUT_MS + assert config.request_timeout_ms == DEFAULT_REQUEST_TIMEOUT_MS + assert config.server_ping_timeout_ms == DEFAULT_PING_TIMEOUT_MS + assert config.health_check_time_ms == DEFAULT_HEALTH_CHECK_TIME_MS + assert config.test_on_borrow == DEFAULT_TEST_ON_BORROW + assert config.idle_evict_schedule_ms == DEFAULT_IDLE_EVICT_SCHEDULE_MS + assert config.min_evictable_idle_time_ms == DEFAULT_MIN_EVICTABLE_IDLE_TIME_MS + assert config.strictly_server_healthy == DEFAULT_STRICT_SERVER_HEALTHY + assert config.max_life_time_ms == DEFAULT_MAX_LIFE_TIME_MS + assert config.scan_parallel == DEFAULT_SCAN_PARALLEL + assert config.enable_tls == DEFAULT_ENABLE_TLS + assert config.graph is None + assert config.schema is None + assert config.timezone is None + assert config.session_configs == {} + assert config.parameters == {} + assert config.pre_statements == [] + assert config.ssl_param is None + + def test_config_custom_pool_settings(self): + """Test NebulaPoolConfig with custom pool settings""" + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + max_client_size=20, + min_client_size=5, + max_wait_ms=5000, + block_when_exhausted=True, + ) + assert config.max_client_size == 20 + assert config.min_client_size == 5 + assert config.max_wait_ms == 5000 + assert config.block_when_exhausted is True + + def test_config_custom_timeout_settings(self): + """Test NebulaPoolConfig with custom timeout settings""" + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + connect_timeout_ms=5000, + request_timeout_ms=120000, + server_ping_timeout_ms=2000, + ) + assert config.connect_timeout_ms == 5000 + assert config.request_timeout_ms == 120000 + assert config.server_ping_timeout_ms == 2000 + + def test_config_custom_health_check_settings(self): + """Test NebulaPoolConfig with custom health check settings""" + 
config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + health_check_time_ms=300000, + test_on_borrow=False, + ) + assert config.health_check_time_ms == 300000 + assert config.test_on_borrow is False + + def test_config_custom_eviction_settings(self): + """Test NebulaPoolConfig with custom eviction settings""" + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + idle_evict_schedule_ms=60000, + min_evictable_idle_time_ms=900000, + ) + assert config.idle_evict_schedule_ms == 60000 + assert config.min_evictable_idle_time_ms == 900000 + + def test_config_custom_server_settings(self): + """Test NebulaPoolConfig with custom server settings""" + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + strictly_server_healthy=True, + max_life_time_ms=3600000, + ) + assert config.strictly_server_healthy is True + assert config.max_life_time_ms == 3600000 + + def test_config_custom_session_settings(self): + """Test NebulaPoolConfig with custom session settings""" + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + graph="test_graph", + schema="test_schema", + timezone="UTC", + session_configs={"key": "value"}, + parameters={"param1": "value1"}, + pre_statements=["USE test_graph"], + ) + assert config.graph == "test_graph" + assert config.schema == "test_schema" + assert config.timezone == "UTC" + assert config.session_configs == {"key": "value"} + assert config.parameters == {"param1": "value1"} + assert config.pre_statements == ["USE test_graph"] + + def test_config_custom_other_settings(self): + """Test NebulaPoolConfig with custom other settings""" + ssl_param = SSLParam(ca_crt=b"ca", private_key=b"key", cert=b"cert") + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + scan_parallel=20, + enable_tls=True, + ssl_param=ssl_param, + ) + assert config.scan_parallel == 20 + assert config.enable_tls is True + assert config.ssl_param == ssl_param + + def test_config_auth_options_post_init(self): + """Test that auth_options is populated with password in __post_init__""" + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass" + ) + assert config.auth_options == {"password": "test_pass"} + + def test_config_auth_options_without_password(self): + """Test auth_options without password""" + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password=None + ) + assert config.auth_options == {} + + def test_config_multiple_addresses(self): + """Test NebulaPoolConfig with multiple addresses""" + config = NebulaPoolConfig( + addresses="127.0.0.1:9669,127.0.0.2:9669,127.0.0.3:9669", + username="test_user", + password="test_pass" + ) + assert config.addresses == "127.0.0.1:9669,127.0.0.2:9669,127.0.0.3:9669" + + def test_config_all_parameters(self): + """Test NebulaPoolConfig with all parameters""" + ssl_param = SSLParam(ca_crt=b"ca", private_key=b"key", cert=b"cert") + config = NebulaPoolConfig( + addresses="127.0.0.1:9669,127.0.0.2:9669", + username="test_user", + password="test_pass", + max_client_size=20, + min_client_size=5, + max_wait_ms=5000, + block_when_exhausted=True, + connect_timeout_ms=5000, + request_timeout_ms=120000, + server_ping_timeout_ms=2000, + health_check_time_ms=300000, + test_on_borrow=False, + idle_evict_schedule_ms=60000, + 
min_evictable_idle_time_ms=900000, + strictly_server_healthy=True, + max_life_time_ms=3600000, + graph="test_graph", + schema="test_schema", + timezone="UTC", + session_configs={"key": "value"}, + parameters={"param1": "value1"}, + pre_statements=["USE test_graph"], + scan_parallel=20, + enable_tls=True, + ssl_param=ssl_param, + ) + assert config.addresses == "127.0.0.1:9669,127.0.0.2:9669" + assert config.username == "test_user" + assert config.password == "test_pass" + assert config.max_client_size == 20 + assert config.min_client_size == 5 + assert config.max_wait_ms == 5000 + assert config.block_when_exhausted is True + assert config.connect_timeout_ms == 5000 + assert config.request_timeout_ms == 120000 + assert config.server_ping_timeout_ms == 2000 + assert config.health_check_time_ms == 300000 + assert config.test_on_borrow is False + assert config.idle_evict_schedule_ms == 60000 + assert config.min_evictable_idle_time_ms == 900000 + assert config.strictly_server_healthy is True + assert config.max_life_time_ms == 3600000 + assert config.graph == "test_graph" + assert config.schema == "test_schema" + assert config.timezone == "UTC" + assert config.session_configs == {"key": "value"} + assert config.parameters == {"param1": "value1"} + assert config.pre_statements == ["USE test_graph"] + assert config.scan_parallel == 20 + assert config.enable_tls is True + assert config.ssl_param == ssl_param + + +class TestNebulaPool: + """Test cases for NebulaPool""" + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_creation_with_defaults(self, mock_factory_class, mock_lb_class): + """Test creating a NebulaPool with default configuration""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client = MagicMock(spec=NebulaClient) + mock_client.is_closed_client.return_value = False + mock_client.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.return_value = mock_client + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=2 + ) + + pool = NebulaPool(config) + + assert pool.config == config + assert len(pool._pool) == 2 + assert pool._closed is False + mock_lb.check_servers.assert_called_once() + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_creation_with_custom_config(self, mock_factory_class, mock_lb_class): + """Test creating a NebulaPool with custom configuration""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client = MagicMock(spec=NebulaClient) + mock_client.is_closed_client.return_value = False + mock_client.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.return_value = mock_client + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669,127.0.0.2:9669", + username="test_user", + password="test_pass", + max_client_size=10, + min_client_size=3, + test_on_borrow=False, + ) + + pool = NebulaPool(config) + + assert len(pool._pool) == 3 + assert mock_factory.create.call_count == 3 + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def 
test_pool_creation_with_ssl(self, mock_factory_class, mock_lb_class): + """Test creating a NebulaPool with SSL enabled""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client = MagicMock(spec=NebulaClient) + mock_client.is_closed_client.return_value = False + mock_client.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.return_value = mock_client + + ssl_param = SSLParam(ca_crt=b"ca", private_key=b"key", cert=b"cert") + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + enable_tls=True, + ssl_param=ssl_param, + min_client_size=1 + ) + + pool = NebulaPool(config) + + assert pool.config.enable_tls is True + assert pool.config.ssl_param == ssl_param + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_get_client_success(self, mock_factory_class, mock_lb_class): + """Test successfully getting a client from the pool""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client = MagicMock(spec=NebulaClient) + mock_client.is_closed_client.return_value = False + mock_client.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.return_value = mock_client + mock_factory.validate.return_value = True + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=1, + test_on_borrow=True + ) + + pool = NebulaPool(config) + + client = pool.get_client() + + assert client is not None + assert client in pool._pool + assert pool._in_use[client] is True + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_get_client_creates_new(self, mock_factory_class, mock_lb_class): + """Test getting a client creates a new one if under max limit""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client1 = MagicMock(spec=NebulaClient) + mock_client1.is_closed_client.return_value = False + mock_client1.get_create_time.return_value = int(time.time() * 1000) + mock_client2 = MagicMock(spec=NebulaClient) + mock_client2.is_closed_client.return_value = False + mock_client2.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.side_effect = [mock_client1, mock_client2] + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=1, + max_client_size=2 + ) + + pool = NebulaPool(config) + + # Get all clients + client1 = pool.get_client() + client2 = pool.get_client() + + assert client1 != client2 + assert len(pool._pool) == 2 + assert pool._in_use[client1] is True + assert pool._in_use[client2] is True + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_get_client_timeout(self, mock_factory_class, mock_lb_class): + """Test getting a client times out when all are in use""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client = MagicMock(spec=NebulaClient) + 
mock_client.is_closed_client.return_value = False + mock_client.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.return_value = mock_client + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=1, + max_client_size=1, + max_wait_ms=100, + block_when_exhausted=True # Enable blocking to wait for timeout + ) + + pool = NebulaPool(config) + + # Get the only client + client1 = pool.get_client() + + # Try to get another - should timeout + with pytest.raises(RuntimeError, match="Timeout waiting for client"): + pool.get_client() + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_get_client_block_when_exhausted_false(self, mock_factory_class, mock_lb_class): + """Test getting a client raises when block_when_exhausted is False""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client = MagicMock(spec=NebulaClient) + mock_client.is_closed_client.return_value = False + mock_client.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.return_value = mock_client + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=1, + max_client_size=1, + block_when_exhausted=False + ) + + pool = NebulaPool(config) + + # Get the only client + client1 = pool.get_client() + + # Try to get another - should raise immediately + with pytest.raises(RuntimeError, match="No available clients in pool"): + pool.get_client() + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_get_client_test_on_borrow_invalidates(self, mock_factory_class, mock_lb_class): + """Test getting a client with test_on_borrow invalidates invalid clients""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client1 = MagicMock(spec=NebulaClient) + mock_client1.is_closed_client.return_value = False + mock_client1.get_create_time.return_value = int(time.time() * 1000) + mock_client2 = MagicMock(spec=NebulaClient) + mock_client2.is_closed_client.return_value = False + mock_client2.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.side_effect = [mock_client1, mock_client2] + mock_factory.validate.side_effect = [False, True] + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=1, + test_on_borrow=True + ) + + pool = NebulaPool(config) + + # First client should be invalidated, second should be returned + client = pool.get_client() + + assert client == mock_client2 + assert mock_client1 not in pool._pool + assert mock_client2 in pool._pool + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_return_client_success(self, mock_factory_class, mock_lb_class): + """Test successfully returning a client to the pool""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client = MagicMock(spec=NebulaClient) + mock_client.is_closed_client.return_value = False + 
mock_client.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.return_value = mock_client + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=1 + ) + + pool = NebulaPool(config) + + client = pool.get_client() + assert pool._in_use[client] is True + + pool.return_client(client) + assert pool._in_use[client] is False + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_return_client_closed(self, mock_factory_class, mock_lb_class): + """Test returning a closed client removes it from pool""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client = MagicMock(spec=NebulaClient) + mock_client.is_closed_client.return_value = True + mock_client.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.return_value = mock_client + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=1 + ) + + pool = NebulaPool(config) + + client = pool.get_client() + pool.return_client(client) + + assert client not in pool._pool + assert client not in pool._in_use + mock_factory.destroy.assert_called_once_with(client) + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_return_client_expired(self, mock_factory_class, mock_lb_class): + """Test returning an expired client removes it from pool""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client = MagicMock(spec=NebulaClient) + mock_client.is_closed_client.return_value = False + # Set create time to 2 hours ago + mock_client.get_create_time.return_value = int(time.time() * 1000) - 7200000 + mock_factory.create.return_value = mock_client + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=1, + max_life_time_ms=3600000 # 1 hour + ) + + pool = NebulaPool(config) + + client = pool.get_client() + pool.return_client(client) + + assert client not in pool._pool + assert client not in pool._in_use + mock_factory.destroy.assert_called_once_with(client) + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_close(self, mock_factory_class, mock_lb_class): + """Test closing the pool""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client1 = MagicMock(spec=NebulaClient) + mock_client2 = MagicMock(spec=NebulaClient) + mock_factory.create.side_effect = [mock_client1, mock_client2] + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=2 + ) + + pool = NebulaPool(config) + pool.close() + + assert pool._closed is True + assert len(pool._pool) == 0 + assert len(pool._in_use) == 0 + mock_factory.destroy.assert_any_call(mock_client1) + mock_factory.destroy.assert_any_call(mock_client2) + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def 
test_pool_get_client_after_close(self, mock_factory_class, mock_lb_class): + """Test getting a client after pool is closed raises error""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client = MagicMock(spec=NebulaClient) + mock_client.is_closed_client.return_value = False + mock_client.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.return_value = mock_client + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=1 + ) + + pool = NebulaPool(config) + pool.close() + + with pytest.raises(RuntimeError, match="Pool is closed"): + pool.get_client() + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_get_active_sessions(self, mock_factory_class, mock_lb_class): + """Test getting active session count""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client1 = MagicMock(spec=NebulaClient) + mock_client2 = MagicMock(spec=NebulaClient) + mock_client3 = MagicMock(spec=NebulaClient) + mock_factory.create.side_effect = [mock_client1, mock_client2, mock_client3] + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=3 + ) + + pool = NebulaPool(config) + + assert pool.get_active_sessions() == 0 + + client1 = pool.get_client() + assert pool.get_active_sessions() == 1 + + client2 = pool.get_client() + assert pool.get_active_sessions() == 2 + + pool.return_client(client1) + assert pool.get_active_sessions() == 1 + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_get_idle_sessions(self, mock_factory_class, mock_lb_class): + """Test getting idle session count""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client1 = MagicMock(spec=NebulaClient) + mock_client2 = MagicMock(spec=NebulaClient) + mock_client3 = MagicMock(spec=NebulaClient) + mock_client1.is_closed_client.return_value = False + mock_client2.is_closed_client.return_value = False + mock_client3.is_closed_client.return_value = False + mock_client1.get_create_time.return_value = int(time.time() * 1000) + mock_client2.get_create_time.return_value = int(time.time() * 1000) + mock_client3.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.side_effect = [mock_client1, mock_client2, mock_client3] + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=3 + ) + + pool = NebulaPool(config) + + # Initially all 3 clients are idle + assert pool.get_idle_sessions() == 3 + + # Get client1 - now 2 idle + client1 = pool.get_client() + assert pool.get_idle_sessions() == 2 + + # Get client2 - now 1 idle + client2 = pool.get_client() + assert pool.get_idle_sessions() == 1 + + # Return client1 - now 2 idle + pool.return_client(client1) + assert pool.get_idle_sessions() == 2 + + # Return client2 - now 3 idle + pool.return_client(client2) + assert pool.get_idle_sessions() == 3 + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + 
@patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_context_manager(self, mock_factory_class, mock_lb_class): + """Test pool as context manager""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client = MagicMock(spec=NebulaClient) + mock_client.is_closed_client.return_value = False + mock_client.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.return_value = mock_client + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=1 + ) + + with NebulaPool(config) as pool: + assert pool._closed is False + + assert pool._closed is True + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_concurrent_access(self, mock_factory_class, mock_lb_class): + """Test concurrent access to the pool""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_clients = [MagicMock(spec=NebulaClient) for _ in range(3)] + for client in mock_clients: + client.is_closed_client.return_value = False + client.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.side_effect = mock_clients + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + max_client_size=3, + min_client_size=3 + ) + + pool = NebulaPool(config) + + results = [] + errors = [] + + def use_pool(thread_id): + try: + client = pool.get_client() + results.append(thread_id) + time.sleep(0.1) + pool.return_client(client) + except Exception as e: + errors.append((thread_id, e)) + + threads = [] + for i in range(3): + thread = threading.Thread(target=use_pool, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert len(errors) == 0 + assert len(results) == 3 + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_parse_addresses(self, mock_factory_class, mock_lb_class): + """Test address parsing""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client = MagicMock(spec=NebulaClient) + mock_client.is_closed_client.return_value = False + mock_client.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.return_value = mock_client + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669,127.0.0.2:9669,127.0.0.3:9669", + username="test_user", + password="test_pass", + min_client_size=1 + ) + + pool = NebulaPool(config) + + addresses = NebulaPool._parse_addresses(config.addresses) + assert len(addresses) == 3 + assert addresses[0] == HostAddress("127.0.0.1", 9669) + assert addresses[1] == HostAddress("127.0.0.2", 9669) + assert addresses[2] == HostAddress("127.0.0.3", 9669) + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_parse_addresses_invalid(self, mock_factory_class, mock_lb_class): + """Test parsing invalid addresses raises error""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + 
mock_client = MagicMock(spec=NebulaClient) + mock_client.is_closed_client.return_value = False + mock_client.get_create_time.return_value = int(time.time() * 1000) + mock_factory.create.return_value = mock_client + + config = NebulaPoolConfig( + addresses="127.0.0.1", # Missing port + username="test_user", + password="test_pass", + min_client_size=1 + ) + + with pytest.raises(ValueError, match="Invalid address format"): + NebulaPool._parse_addresses(config.addresses) + + @patch("nebulagraph_python.client.nebula_pool.RoundRobinLoadBalancer") + @patch("nebulagraph_python.client.nebula_pool.ClientPoolFactory") + def test_pool_init_failure_handles_gracefully(self, mock_factory_class, mock_lb_class): + """Test pool initialization handles creation failures gracefully""" + mock_lb = MagicMock() + mock_lb_class.return_value = mock_lb + mock_factory = MagicMock() + mock_factory_class.return_value = mock_factory + mock_client1 = MagicMock(spec=NebulaClient) + mock_client2 = MagicMock(spec=NebulaClient) + mock_client1.is_closed_client.return_value = False + mock_client1.get_create_time.return_value = int(time.time() * 1000) + mock_client2.is_closed_client.return_value = False + mock_client2.get_create_time.return_value = int(time.time() * 1000) + + # First call succeeds, second fails, third succeeds + mock_factory.create.side_effect = [mock_client1, Exception("Create failed"), mock_client2] + + config = NebulaPoolConfig( + addresses="127.0.0.1:9669", + username="test_user", + password="test_pass", + min_client_size=3 + ) + + # Pool should still be created, but with fewer clients + pool = NebulaPool(config) + + # Should have 2 clients (one failed) + assert len(pool._pool) == 2 \ No newline at end of file diff --git a/tests/test_nebula_pool_integration.py b/tests/test_nebula_pool_integration.py new file mode 100644 index 00000000..bbe4387c --- /dev/null +++ b/tests/test_nebula_pool_integration.py @@ -0,0 +1,625 @@ +# Copyright 2025 vesoft-inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
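
Note: both the unit tests above and the integration tests below exercise the same pool lifecycle — build a NebulaPoolConfig, open a NebulaPool, borrow a client with get_client(), run statements, hand the client back with return_client(), and rely on the context manager to close the pool. A minimal usage sketch of that flow, with placeholder endpoint and credentials (everything else — class names, parameters, and methods — is taken from these tests):

from nebulagraph_python import NebulaPool, NebulaPoolConfig

config = NebulaPoolConfig(
    addresses="127.0.0.1:9669",   # comma-separated "host:port" list (placeholder endpoint)
    username="root",              # placeholder credentials
    password="nebula",
    min_client_size=1,            # clients created eagerly when the pool starts
    max_client_size=10,           # hard cap; see the exhaustion/timeout tests
)

with NebulaPool(config) as pool:      # context manager closes the pool on exit
    client = pool.get_client()        # borrow a NebulaClient from the pool
    try:
        result = client.execute("RETURN 1")
        assert result.is_succeeded
    finally:
        pool.return_client(client)    # always return the client so others can reuse it
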
+ +"""Integration tests for NebulaPool with real NebulaGraph connection""" + +import os +import time +import threading +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pytest + +from nebulagraph_python import NebulaPool, NebulaPoolConfig +from nebulagraph_python.error import AuthenticatingError, ExecutingError + +# 从环境变量获取测试配置,如果没有则使用默认值 +NEBULA_HOSTS = os.getenv("NEBULA_HOSTS", "192.168.8.6:3820") +NEBULA_USER = os.getenv("NEBULA_USER", "root") +NEBULA_PASSWORD = os.getenv("NEBULA_PASSWORD", "NebulaGraph01") + + +@pytest.mark.integration +class TestNebulaPoolIntegration: + + def test_nebula_pool_basic(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=10, + min_client_size=1, + ) + pool = NebulaPool(config) + client = pool.get_client() + result = client.execute("RETURN 1") + assert result.is_succeeded + pool.return_client(client) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_builder(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + connect_timeout_ms=1111, + request_timeout_ms=2222, + scan_parallel=15, + health_check_time_ms=3333, + max_client_size=10, + min_client_size=1, + ) + pool = NebulaPool(config) + client = pool.get_client() + assert client.get_connect_timeout_mills() == 1111 + assert client.get_request_timeout_mills() == 2222 + assert client.get_scan_parallel() == 15 + pool.return_client(client) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_null_user(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=None, + password=None, + connect_timeout_ms=1111, + request_timeout_ms=2222, + scan_parallel=15, + max_client_size=10, + min_client_size=1, + ) + pool = NebulaPool(config) + client = pool.get_client() + pool.return_client(client) + pytest.fail("Should have raised AuthenticatingError") + except AuthenticatingError: + # Expected + pass + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_wrong_password(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password="wrong_password", + max_client_size=10, + min_client_size=1, + ) + pool = NebulaPool(config) + pytest.fail("Should have raised AuthenticatingError") + except AuthenticatingError as e: + # Expected + assert "invalid username or password" in str(e) or "Auth failed" in str(e) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_wrong_server(self): + print("<==== test_nebula_pool_wrong_server ====>") + pool = None + try: + config = NebulaPoolConfig( + addresses="127.0.0.1:1000", + username=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=10, + min_client_size=1, + ) + pool = NebulaPool(config) + pytest.fail("Should have raised ExecutingError") + except ExecutingError as e: + # Expected - can be timeout or connection refused + assert "Connection refused" in str(e) or "UNAVAILABLE" in str(e) or "timeout" in str(e).lower() + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + finally: + if pool is not None: + pool.close() + + def 
test_nebula_pool_session_set_graph(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + # graph="test_pool_space", # Skip graph setting for simplicity + max_client_size=10, + min_client_size=1, + ) + pool = NebulaPool(config) + client = pool.get_client() + # Just verify connection works + result = client.execute("RETURN 1") + assert result.is_succeeded + pool.return_client(client) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_session_set_timezone(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + timezone="Asia/Shanghai", + max_client_size=10, + min_client_size=1, + ) + pool = NebulaPool(config) + client = pool.get_client() + # Just verify connection works + result = client.execute("show current_session") + assert result.is_succeeded + for record in result: + timezone_value = record["timezone"] + assert timezone_value.cast() == "Asia/Shanghai" + pool.return_client(client) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_session_set_format(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + session_configs={ + "date_format": "\"%Y/%m/%d\"", + "local_datetime_format": "\"%Y-%m-%d %H:%M:%S\"", + }, + max_client_size=10, + min_client_size=1, + ) + pool = NebulaPool(config) + client = pool.get_client() + # Just verify connection works + result = client.execute("show session configs") + assert result.is_succeeded + for record in result: + name_value = record["name"] + if name_value.cast() == "date_format": + assert record["value"].cast() == "%Y/%m/%d" + if name_value.cast() == "local_datetime_format": + assert record["value"].cast() == "%Y-%m-%d %H:%M:%S" + pool.return_client(client) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_pre_statements(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + pre_statements=["RETURN 1", "session set timezone=\"Asia/Shanghai\""], + max_client_size=10, + min_client_size=1, + ) + pool = NebulaPool(config) + client = pool.get_client() + # Just verify connection works after pre-statements + result = client.execute("show current_session") + assert result.is_succeeded + for record in result: + timezone_value = record["timezone"] + assert timezone_value.cast() == "Asia/Shanghai" + pool.return_client(client) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_wrong_pre_statement(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + pre_statements=["wrong statement"], + max_client_size=10, + min_client_size=1, + ) + # Pool initialization will log warning but continue + # This is expected behavior - pool is resilient + pool = NebulaPool(config) + # Pool should still work for other operations + # Even if initial client creation failed + try: + client = pool.get_client() + result = client.execute("RETURN 1") + # This might succeed if pool recovered or if min_client_size was 0 + pool.return_client(client) + except 
Exception as e: + # Expected if no clients could be created + pass + except Exception as e: + # If pool creation itself fails, that's acceptable + assert "wrong statement" in str(e).lower() or "syntax error" in str(e).lower() + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_max_life_time(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + max_life_time_ms=5000, # 5秒 + max_client_size=1, + min_client_size=1, + ) + pool = NebulaPool(config) + client1 = pool.get_client() + session_id1 = client1.get_session_id() + time.sleep(6) # wait to beyond the max life + pool.return_client(client1) + client2 = pool.get_client() + session_id2 = client2.get_session_id() + assert session_id1 != session_id2 + pool.return_client(client2) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_multiple_clients(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=5, + min_client_size=2, + ) + pool = NebulaPool(config) + + clients = [] + for i in range(3): + client = pool.get_client() + result = client.execute("RETURN 1") + assert result.is_succeeded + clients.append(client) + + for client in clients: + pool.return_client(client) + + # get client again + client = pool.get_client() + result = client.execute("RETURN 1") + assert result.is_succeeded + pool.return_client(client) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_concurrent_access(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=10, + min_client_size=1, + ) + pool = NebulaPool(config) + + failed_count = [0] + lock = threading.Lock() + + def execute_query(thread_id): + try: + client = pool.get_client() + result = client.execute("RETURN 1") + if not result.is_succeeded: + with lock: + failed_count[0] += 1 + pool.return_client(client) + except Exception as e: + with lock: + failed_count[0] += 1 + + # create 10 thread to execute parallel + threads = [] + for i in range(10): + thread = threading.Thread(target=execute_query, args=(i,)) + threads.append(thread) + thread.start() + + for thread in threads: + thread.join() + + assert failed_count[0] == 0, f"{failed_count[0]} threads failed" + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_get_active_idle_sessions(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=5, + min_client_size=3, + ) + pool = NebulaPool(config) + + # at begin:all client is idle + assert pool.get_idle_sessions() == 3 + assert pool.get_active_sessions() == 0 + + # get one client + client1 = pool.get_client() + assert pool.get_idle_sessions() == 2 + assert pool.get_active_sessions() == 1 + + # get another client + client2 = pool.get_client() + assert pool.get_idle_sessions() == 1 + assert pool.get_active_sessions() == 2 + + # return one client + pool.return_client(client1) + assert pool.get_idle_sessions() == 2 + assert pool.get_active_sessions() == 1 + + # return another client + pool.return_client(client2) + assert pool.get_idle_sessions() == 3 + assert pool.get_active_sessions() == 0 + 
except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_context_manager(self): + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=10, + min_client_size=1, + ) + with NebulaPool(config) as pool: + client = pool.get_client() + result = client.execute("RETURN 1") + assert result.is_succeeded + pool.return_client(client) + except Exception as e: + pytest.fail(f"Test failed: {e}") + + def test_nebula_pool_multiple_addresses(self): + pool = None + try: + addresses = f"{NEBULA_HOSTS},{NEBULA_HOSTS}" + config = NebulaPoolConfig( + addresses=addresses, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=10, + min_client_size=1, + ) + pool = NebulaPool(config) + client = pool.get_client() + result = client.execute("RETURN 1") + assert result.is_succeeded + pool.return_client(client) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_test_on_borrow(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + test_on_borrow=True, + max_client_size=2, + min_client_size=1, + ) + pool = NebulaPool(config) + client = pool.get_client() + result = client.execute("RETURN 1") + assert result.is_succeeded + pool.return_client(client) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_timeout_when_exhausted(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=1, + min_client_size=1, + max_wait_ms=100, + block_when_exhausted=True, + ) + pool = NebulaPool(config) + + client1 = pool.get_client() + + with pytest.raises(RuntimeError, match="Timeout waiting for client"): + pool.get_client() + + pool.return_client(client1) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_no_block_when_exhausted(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=1, + min_client_size=1, + block_when_exhausted=False, + ) + pool = NebulaPool(config) + + client1 = pool.get_client() + + with pytest.raises(RuntimeError, match="No available clients in pool"): + pool.get_client() + + pool.return_client(client1) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_client_ping(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=10, + min_client_size=1, + ) + pool = NebulaPool(config) + client = pool.get_client() + assert client.ping() is True + pool.return_client(client) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_client_execute_complex_query(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=10, + min_client_size=1, + ) + pool = NebulaPool(config) + client = pool.get_client() + + queries = [ + "RETURN 1 AS num", + "RETURN 'hello' AS str", + "RETURN 1 + 2 AS result", + "RETURN [1, 2, 3]", 
+ "RETURN {key: 'value'}", + "RETURN 1.5 AS float_num", + ] + + for query in queries: + result = client.execute(query) + assert result.is_succeeded, f"Query failed: {query} - {result.status_message}" + + pool.return_client(client) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() + + def test_nebula_pool_reuse_client(self): + pool = None + try: + config = NebulaPoolConfig( + addresses=NEBULA_HOSTS, + username=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=1, + min_client_size=1, + ) + pool = NebulaPool(config) + + # get one client + client1 = pool.get_client() + session_id1 = client1.get_session_id() + pool.return_client(client1) + + # get one client again + client2 = pool.get_client() + session_id2 = client2.get_session_id() + assert session_id1 == session_id2, "Should reuse the same client" + pool.return_client(client2) + except Exception as e: + pytest.fail(f"Test failed: {e}") + finally: + if pool is not None: + pool.close() diff --git a/tests/test_session_pool.py b/tests/test_session_pool.py deleted file mode 100644 index 05b40872..00000000 --- a/tests/test_session_pool.py +++ /dev/null @@ -1,998 +0,0 @@ -import asyncio -import threading -import time -from unittest.mock import AsyncMock, Mock, patch - -import pytest - -from nebulagraph_python.client._session_pool import ( - AsyncSessionPool, - SessionPool, - SessionPoolConfig, -) -from nebulagraph_python.client._session import ( - Session, - AsyncSession, - SessionConfig, -) -from nebulagraph_python.error import PoolError -from copy import copy - - -class TestSessionPool: - """Test cases for SessionPool (synchronous)""" - - def test_init_basic(self): - """Test basic initialization""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - Session(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=3) - pool = SessionPool(copy(sessions), config) - - assert pool.free_sessions_queue == sessions - assert pool.busy_sessions_queue == set() - assert len(pool.free_sessions_queue) == 3 - assert pool.queue_count._value == 3 # Semaphore initial value - - def test_init_with_config(self): - """Test initialization with custom config""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=2, wait_timeout=10.0) - pool = SessionPool(copy(sessions), config) - - assert pool.config.size == 2 - assert pool.config.wait_timeout == 10.0 - - def test_init_with_all_config_params(self): - """Test initialization with all configuration parameters""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - Session(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), - } - config = SessionPoolConfig( - size=3, - wait_timeout=5.0, - ) - pool = SessionPool(copy(sessions), config) - - assert pool.config.size == 3 - assert pool.config.wait_timeout == 5.0 - assert 
len(pool.free_sessions_queue) == 3 - - def test_borrow_single_session(self): - """Test borrowing a single session""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - Session(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), - } - pool = SessionPool(copy(sessions), SessionPoolConfig(size=3)) - - with pool.borrow() as session: - assert session in sessions - assert session in pool.busy_sessions_queue - assert session not in pool.free_sessions_queue - assert len(pool.busy_sessions_queue) == 1 - assert len(pool.free_sessions_queue) == 2 - - # After context exit, session should be returned - assert len(pool.busy_sessions_queue) == 0 - assert len(pool.free_sessions_queue) == 3 - assert session in pool.free_sessions_queue - - def test_borrow_all_sessions(self): - """Test borrowing all available sessions""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - } - pool = SessionPool(copy(sessions), SessionPoolConfig(size=2)) - - with pool.borrow() as session1: - with pool.borrow() as session2: - assert {session1, session2} == sessions - assert len(pool.busy_sessions_queue) == 2 - assert len(pool.free_sessions_queue) == 0 - - def test_borrow_timeout_exceeded(self): - """Test borrowing when timeout is exceeded""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=1, wait_timeout=0.2) - pool = SessionPool(copy(sessions), config) - - with pool.borrow(): # Acquire the only session - # Try to borrow another session - should timeout - with pytest.raises(PoolError, match="No session available in the SessionPool after waiting 0.2 seconds"): - with pool.borrow(): - pass - - def test_borrow_infinite_wait_with_release(self): - """Test borrowing with infinite wait that succeeds when session becomes available""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=1, wait_timeout=None) - pool = SessionPool(copy(sessions), config) - - def release_session(): - time.sleep(0.1) # Wait a bit - # This will be triggered by the context manager exit - - def acquire_and_release(): - with pool.borrow(): - threading.Thread(target=release_session).start() - time.sleep(0.2) # Hold session briefly - - # Start a thread that will acquire and then release the session - thread1 = threading.Thread(target=acquire_and_release) - thread1.start() - - time.sleep(0.05) # Ensure first thread acquires the session - - # This should succeed once the first thread releases the session - start_time = time.time() - with pool.borrow() as session: - assert session in sessions - elapsed = time.time() - start_time - assert elapsed >= 0.15 # Should have waited for release - - thread1.join() - - def test_concurrent_borrowing(self): - """Test concurrent borrowing from multiple threads""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username=f"user{i}", password=f"pass{i}", session_config=None, auth_options=None) - for i in range(5) - } - config = 
SessionPoolConfig(size=5) - pool = SessionPool(copy(sessions), config) - results = [] - errors = [] - - def borrow_session(thread_id): - try: - with pool.borrow() as session: - results.append((thread_id, session)) - time.sleep(0.1) # Simulate work - except Exception as e: - errors.append((thread_id, e)) - - # Start multiple threads - threads = [] - for i in range(5): - thread = threading.Thread(target=borrow_session, args=(i,)) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - assert len(errors) == 0, f"Unexpected errors: {errors}" - assert len(results) == 5 - - # All sessions should be returned - assert len(pool.busy_sessions_queue) == 0 - assert len(pool.free_sessions_queue) == 5 - - def test_semaphore_consistency(self): - """Test that semaphore behavior stays consistent with actual session availability""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=2) - pool = SessionPool(copy(sessions), config) - - # Test we can acquire sessions sequentially - with pool.borrow(): - # One session borrowed - test we can still acquire one more - acquired_second = pool.queue_count.acquire(blocking=False) - if acquired_second: - pool.queue_count.release() - assert acquired_second, "Should be able to acquire second session" - - with pool.borrow(): - # Both sessions borrowed - test we cannot acquire more - cannot_acquire = not pool.queue_count.acquire(blocking=False) - assert cannot_acquire, "Should not be able to acquire third session" - - # All sessions returned - test we can acquire again - acquired_after_return = pool.queue_count.acquire(blocking=False) - if acquired_after_return: - pool.queue_count.release() - assert acquired_after_return, "Should be able to acquire session after return" - - def test_close_all_free_sessions(self): - """Test closing pool with all sessions free""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - Session(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=3) - pool = SessionPool(copy(sessions), config) - - # Mock the close_session method for all sessions - for session in sessions: - session._close = Mock() - - pool._close() - - # Should close all sessions - for session in sessions: - session._close.assert_called_once() - - @patch('nebulagraph_python.client._session_pool.logger') - def test_close_with_busy_sessions(self, mock_logger): - """Test closing pool with some busy sessions""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - Session(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=3) - pool = SessionPool(copy(sessions), config) - - # Mock the close_session method for all sessions - for session in sessions: - session._close = Mock() - - # Manually move a session to busy state - busy_session = list(sessions)[1] # 
Get the second session - pool.free_sessions_queue.remove(busy_session) - pool.busy_sessions_queue.add(busy_session) - - pool._close() - - # Should close all sessions - for session in sessions: - session._close.assert_called_once() - # Should log error about busy sessions - mock_logger.error.assert_called_once() - assert "Busy sessions remain" in mock_logger.error.call_args[0][0] - - def test_connect_success(self): - """Test successful connection via classmethod""" - mock_conn = Mock() - - config = SessionPoolConfig(size=4) - - # Test the connect method - pool = SessionPool.connect( - conn=mock_conn, - username="test_user", - password="test_pass", - pool_config=config - ) - - # Verify the pool was created correctly - assert len(pool.free_sessions_queue) == 4 - assert len(pool.busy_sessions_queue) == 0 - - def test_connect_partial_failure(self): - """Test connection with partial failure during setup""" - # Create a mock connection that fails after creating some sessions - mock_conn = Mock() - - config = SessionPoolConfig(size=3) - - # Mock Session constructor to fail on third call - original_session = Session - call_count = 0 - - def mock_session_init(*args, **kwargs): - nonlocal call_count - call_count += 1 - if call_count == 3: - raise Exception("Auth failed") - return original_session(*args, **kwargs) - - with patch('nebulagraph_python.client._session_pool.Session', side_effect=mock_session_init): - with pytest.raises(Exception, match="Auth failed"): - SessionPool.connect( - conn=mock_conn, - username="test_user", - password="test_pass", - pool_config=config - ) - - def test_connect_authentication_failure_first_attempt(self): - """Test connection failure on first authentication attempt""" - mock_conn = Mock() - - config = SessionPoolConfig(size=2) - - # Mock Session constructor to fail on first call - with patch('nebulagraph_python.client._session_pool.Session', side_effect=Exception("Auth failed on first attempt")): - with pytest.raises(Exception, match="Auth failed on first attempt"): - SessionPool.connect( - conn=mock_conn, - username="test_user", - password="test_pass", - pool_config=config - ) - - def test_multiple_borrow_release_cycles(self): - """Test multiple borrow-release cycles work correctly""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=2) - pool = SessionPool(copy(sessions), config) - - # First cycle - with pool.borrow() as session1: - assert session1 in pool.busy_sessions_queue - assert len(pool.free_sessions_queue) == 1 - - assert len(pool.busy_sessions_queue) == 0 - assert len(pool.free_sessions_queue) == 2 - - # Second cycle - with pool.borrow() as session2: - with pool.borrow() as session3: - assert {session2, session3} == sessions - assert len(pool.busy_sessions_queue) == 2 - assert len(pool.free_sessions_queue) == 0 - - assert len(pool.busy_sessions_queue) == 0 - assert len(pool.free_sessions_queue) == 2 - - -class TestAsyncSessionPool: - """Test cases for AsyncSessionPool (asynchronous)""" - - @pytest.mark.asyncio - async def test_init_basic(self): - """Test basic initialization""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, 
auth_options=None), - AsyncSession(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=3) - pool = AsyncSessionPool(copy(sessions), config) - - assert pool.free_sessions_queue == sessions - assert pool.busy_sessions_queue == set() - assert len(pool.free_sessions_queue) == 3 - # Test that we can borrow a session (semaphore has permits) - async with pool.borrow() as session: - assert session in sessions - - @pytest.mark.asyncio - async def test_init_with_config(self): - """Test initialization with custom config""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=2, wait_timeout=10.0) - pool = AsyncSessionPool(copy(sessions), config) - - assert pool.config.size == 2 - assert pool.config.wait_timeout == 10.0 - - @pytest.mark.asyncio - async def test_init_with_all_config_params(self): - """Test initialization with all configuration parameters""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - AsyncSession(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), - } - config = SessionPoolConfig( - size=3, - wait_timeout=5.0, - ) - pool = AsyncSessionPool(copy(sessions), config) - - assert pool.config.size == 3 - assert pool.config.wait_timeout == 5.0 - assert len(pool.free_sessions_queue) == 3 - - @pytest.mark.asyncio - async def test_borrow_single_session(self): - """Test borrowing a single session""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - AsyncSession(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=3) - pool = AsyncSessionPool(copy(sessions), config) - - async with pool.borrow() as session: - assert session in sessions - assert session in pool.busy_sessions_queue - assert session not in pool.free_sessions_queue - assert len(pool.busy_sessions_queue) == 1 - assert len(pool.free_sessions_queue) == 2 - - # After context exit, session should be returned - assert len(pool.busy_sessions_queue) == 0 - assert len(pool.free_sessions_queue) == 3 - assert session in pool.free_sessions_queue - - @pytest.mark.asyncio - async def test_borrow_all_sessions(self): - """Test borrowing all available sessions""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=2) - pool = AsyncSessionPool(copy(sessions), config) - - async with pool.borrow() as session1: - async with pool.borrow() as session2: - assert {session1, session2} == sessions - assert len(pool.busy_sessions_queue) == 2 - assert len(pool.free_sessions_queue) == 0 - - @pytest.mark.asyncio - async def test_borrow_timeout_exceeded(self): - """Test 
borrowing when timeout is exceeded""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=1, wait_timeout=0.2) - pool = AsyncSessionPool(copy(sessions), config) - - async with pool.borrow(): # Acquire the only session - # Try to borrow another session - should timeout - with pytest.raises(PoolError, match="No session available in the SessionPool after waiting 0.2 seconds"): - async with pool.borrow(): - pass - - @pytest.mark.asyncio - async def test_borrow_infinite_wait_with_release(self): - """Test borrowing with infinite wait that succeeds when session becomes available""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=1, wait_timeout=None) - pool = AsyncSessionPool(copy(sessions), config) - - async def acquire_and_release(): - async with pool.borrow(): - await asyncio.sleep(0.2) # Hold session briefly - - # Start a task that will acquire and then release the session - task1 = asyncio.create_task(acquire_and_release()) - - await asyncio.sleep(0.05) # Ensure first task acquires the session - - # This should succeed once the first task releases the session - start_time = time.time() - async with pool.borrow() as session: - assert session in sessions - elapsed = time.time() - start_time - assert elapsed >= 0.15 # Should have waited for release - - await task1 - - @pytest.mark.asyncio - async def test_concurrent_borrowing(self): - """Test concurrent borrowing from multiple coroutines""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username=f"user{i}", password=f"pass{i}", session_config=None, auth_options=None) - for i in range(5) - } - config = SessionPoolConfig(size=5) - pool = AsyncSessionPool(copy(sessions), config) - results = [] - errors = [] - - async def borrow_session(task_id): - try: - async with pool.borrow() as session: - results.append((task_id, session)) - await asyncio.sleep(0.1) # Simulate async work - except Exception as e: - errors.append((task_id, e)) - - # Start multiple concurrent tasks - tasks = [borrow_session(i) for i in range(5)] - await asyncio.gather(*tasks) - - assert len(errors) == 0, f"Unexpected errors: {errors}" - assert len(results) == 5 - - # All sessions should be returned - assert len(pool.busy_sessions_queue) == 0 - assert len(pool.free_sessions_queue) == 5 - - @pytest.mark.asyncio - async def test_semaphore_consistency(self): - """Test that semaphore behavior stays consistent with actual session availability""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=2) - pool = AsyncSessionPool(copy(sessions), config) - - # Test we can acquire sessions sequentially - async with pool.borrow(): - # One session borrowed - test we can still acquire one more - acquired_second = False - try: - pool.queue_count.acquire_nowait() - acquired_second = True - pool.queue_count.release() - except: - pass - assert acquired_second, "Should be able to acquire second session" - - async with pool.borrow(): - # Both sessions borrowed - test we cannot acquire more - cannot_acquire = False - try: - pool.queue_count.acquire_nowait() - 
except: - cannot_acquire = True - assert cannot_acquire, "Should not be able to acquire third session" - - # All sessions returned - test we can acquire again - acquired_after_return = False - try: - pool.queue_count.acquire_nowait() - acquired_after_return = True - pool.queue_count.release() - except: - pass - assert acquired_after_return, "Should be able to acquire session after return" - - @pytest.mark.asyncio - async def test_close_all_free_sessions(self): - """Test closing pool with all sessions free""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - AsyncSession(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=3) - pool = AsyncSessionPool(copy(sessions), config) - - # Mock the close_session method for all sessions - for session in sessions: - session._close = AsyncMock() - - await pool._close() - - # Should close all sessions - for session in sessions: - session._close.assert_called_once() - - @pytest.mark.asyncio - @patch('nebulagraph_python.client._session_pool.logger') - async def test_close_with_busy_sessions(self, mock_logger): - """Test closing pool with some busy sessions""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - AsyncSession(_conn=mock_conn, username="user3", password="pass3", session_config=None, auth_options=None), - } - pool = AsyncSessionPool(copy(sessions), config=SessionPoolConfig(size=3)) - - # Mock the close_session method for all sessions - for session in sessions: - session._close = AsyncMock() - - # Manually move a session to busy state - busy_session = list(sessions)[1] # Get the second session - pool.free_sessions_queue.remove(busy_session) - pool.busy_sessions_queue.add(busy_session) - - await pool._close() - - # Should close all sessions - for session in sessions: - session._close.assert_called_once() - # Should log error about busy sessions - mock_logger.error.assert_called_once() - assert "Busy sessions remain" in mock_logger.error.call_args[0][0] - - @pytest.mark.asyncio - async def test_connect_success(self): - """Test successful connection via classmethod""" - mock_conn = AsyncMock() - - config = SessionPoolConfig(size=4) - - # Test the connect method - pool = await AsyncSessionPool.connect( - conn=mock_conn, - username="test_user", - password="test_pass", - pool_config=config - ) - - # Verify the pool was created correctly - assert len(pool.free_sessions_queue) == 4 - assert len(pool.busy_sessions_queue) == 0 - - @pytest.mark.asyncio - async def test_connect_partial_failure(self): - """Test connection with partial failure during setup""" - mock_conn = AsyncMock() - - config = SessionPoolConfig(size=3) - - # Mock AsyncSession constructor to fail on third call - original_session = AsyncSession - call_count = 0 - - def mock_session_init(*args, **kwargs): - nonlocal call_count - call_count += 1 - if call_count == 3: - raise Exception("Auth failed") - return original_session(*args, **kwargs) - - with patch('nebulagraph_python.client._session_pool.AsyncSession', side_effect=mock_session_init): - with pytest.raises(Exception, match="Auth failed"): - 
await AsyncSessionPool.connect( - conn=mock_conn, - username="test_user", - password="test_pass", - pool_config=config - ) - - @pytest.mark.asyncio - async def test_connect_authentication_failure_first_attempt(self): - """Test connection failure on first authentication attempt""" - mock_conn = AsyncMock() - - config = SessionPoolConfig(size=2) - - # Mock AsyncSession constructor to fail on first call - with patch('nebulagraph_python.client._session_pool.AsyncSession', side_effect=Exception("Auth failed on first attempt")): - with pytest.raises(Exception, match="Auth failed on first attempt"): - await AsyncSessionPool.connect( - conn=mock_conn, - username="test_user", - password="test_pass", - pool_config=config - ) - - @pytest.mark.asyncio - async def test_multiple_borrow_release_cycles(self): - """Test multiple borrow-release cycles work correctly""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=2) - pool = AsyncSessionPool(copy(sessions), config) - - # First cycle - async with pool.borrow() as session1: - assert session1 in pool.busy_sessions_queue - assert len(pool.free_sessions_queue) == 1 - - assert len(pool.busy_sessions_queue) == 0 - assert len(pool.free_sessions_queue) == 2 - - # Second cycle - async with pool.borrow() as session2: - async with pool.borrow() as session3: - assert {session2, session3} == sessions - assert len(pool.busy_sessions_queue) == 2 - assert len(pool.free_sessions_queue) == 0 - - assert len(pool.busy_sessions_queue) == 0 - assert len(pool.free_sessions_queue) == 2 - - -class TestSessionPoolEdgeCases: - """Test edge cases for both sync and async session pools""" - - def test_sync_pool_exception_in_context(self): - """Test that sessions are properly returned even when exceptions occur in sync pool""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=1) - pool = SessionPool(copy(sessions), config) - - with pytest.raises(ValueError): - with pool.borrow() as session: - assert session in pool.busy_sessions_queue - raise ValueError("Test exception") - - # Session should be returned to pool even after exception - assert len(pool.busy_sessions_queue) == 0 - assert len(pool.free_sessions_queue) == 1 - - @pytest.mark.asyncio - async def test_async_pool_exception_in_context(self): - """Test that sessions are properly returned even when exceptions occur in async pool""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=1) - pool = AsyncSessionPool(copy(sessions), config) - - with pytest.raises(ValueError): - async with pool.borrow() as session: - assert session in pool.busy_sessions_queue - raise ValueError("Test exception") - - # Session should be returned to pool even after exception - assert len(pool.busy_sessions_queue) == 0 - assert len(pool.free_sessions_queue) == 1 - - def test_sync_multiple_exceptions_in_context(self): - """Test multiple exceptions in sync pool context managers""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - Session(_conn=mock_conn, 
username="user2", password="pass2", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=2) - pool = SessionPool(copy(sessions), config) - - with pytest.raises(ValueError): - with pool.borrow(): - with pool.borrow(): - assert len(pool.busy_sessions_queue) == 2 - raise ValueError("Test exception") - - # All sessions should be returned - assert len(pool.busy_sessions_queue) == 0 - assert len(pool.free_sessions_queue) == 2 - - @pytest.mark.asyncio - async def test_async_multiple_exceptions_in_context(self): - """Test multiple exceptions in async pool context managers""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - AsyncSession(_conn=mock_conn, username="user2", password="pass2", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=2) - pool = AsyncSessionPool(copy(sessions), config) - - with pytest.raises(ValueError): - async with pool.borrow(): - async with pool.borrow(): - assert len(pool.busy_sessions_queue) == 2 - raise ValueError("Test exception") - - # All sessions should be returned - assert len(pool.busy_sessions_queue) == 0 - assert len(pool.free_sessions_queue) == 2 - - def test_sync_empty_pool(self): - """Test sync pool with zero sessions""" - with pytest.raises(ValueError, match="SessionPoolConfig.size must be greater than 0, but got 0"): - config = SessionPoolConfig(size=0) - - @pytest.mark.asyncio - async def test_async_empty_pool(self): - """Test async pool with zero sessions""" - with pytest.raises(ValueError, match="SessionPoolConfig.size must be greater than 0, but got 0"): - config = SessionPoolConfig(size=0) - - def test_sync_zero_timeout(self): - """Test sync pool with zero timeout converts to None""" - config = SessionPoolConfig(size=1, wait_timeout=0.0) - # Should convert zero timeout to None - assert config.wait_timeout is None - - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - } - pool = SessionPool(copy(sessions), config) - - with pool.borrow(): - # Should timeout immediately since timeout is None (infinite wait) but session is busy - pass # This test is mainly about the config behavior - - @pytest.mark.asyncio - async def test_async_zero_timeout(self): - """Test async pool with zero timeout converts to None""" - config = SessionPoolConfig(size=1, wait_timeout=0.0) - # Should convert zero timeout to None - assert config.wait_timeout is None - - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - } - pool = AsyncSessionPool(copy(sessions), config) - - async with pool.borrow(): - # Should timeout immediately since timeout is None (infinite wait) but session is busy - pass # This test is mainly about the config behavior - - def test_sync_negative_timeout(self): - """Test sync pool with negative timeout converts to None""" - config = SessionPoolConfig(size=1, wait_timeout=-1.0) - # Should convert negative timeout to None - assert config.wait_timeout is None - - -class TestSessionPoolConfig: - """Test SessionPoolConfig validation and behavior""" - - def test_config_default_values(self): - """Test default configuration values""" - config = SessionPoolConfig(size=5) - - assert config.size == 5 - assert config.wait_timeout == 60 # Default is 60 seconds, not None - - def test_config_custom_values(self): - """Test custom 
configuration values""" - config = SessionPoolConfig( - size=10, - wait_timeout=30.0, - ) - - assert config.size == 10 - assert config.wait_timeout == 30.0 - - def test_config_zero_size(self): - """Test configuration with zero pool size""" - # SessionPoolConfig raises ValueError for size <= 0 - with pytest.raises(ValueError, match="SessionPoolConfig.size must be greater than 0"): - config = SessionPoolConfig(size=0) - - def test_config_large_size(self): - """Test configuration with large pool size""" - config = SessionPoolConfig(size=1000) - assert config.size == 1000 - - def test_sync_pool_with_custom_retry_interval(self): - """Test sync pool behavior with custom retry interval""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=1, wait_timeout=0.3) - pool = SessionPool(copy(sessions), config) - - with pool.borrow(): - start_time = time.time() - with pytest.raises(PoolError): - with pool.borrow(): - pass - elapsed = time.time() - start_time - # Should have attempted multiple retries within the timeout - assert 0.25 <= elapsed <= 0.35 - - @pytest.mark.asyncio - async def test_async_pool_with_custom_retry_interval(self): - """Test async pool behavior with custom retry interval""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username="user1", password="pass1", session_config=None, auth_options=None), - } - config = SessionPoolConfig(size=1, wait_timeout=0.3) - pool = AsyncSessionPool(copy(sessions), config) - - async with pool.borrow(): - start_time = time.time() - with pytest.raises(PoolError): - async with pool.borrow(): - pass - elapsed = time.time() - start_time - # Should have attempted multiple retries within the timeout - assert 0.25 <= elapsed <= 0.35 - - -class TestSessionPoolStressTests: - """Stress tests for session pools""" - - def test_sync_high_concurrency_stress(self): - """Test sync pool under high concurrency stress""" - mock_conn = Mock() - sessions = { - Session(_conn=mock_conn, username=f"user{i}", password=f"pass{i}", session_config=None, auth_options=None) - for i in range(10) - } - config = SessionPoolConfig(size=10) - pool = SessionPool(copy(sessions), config) - results = [] - errors = [] - - def stress_worker(worker_id): - try: - for i in range(5): # Each worker does 5 operations - with pool.borrow() as session: - results.append((worker_id, i, session)) - time.sleep(0.01) # Very short work simulation - except Exception as e: - errors.append((worker_id, e)) - - # Start 20 threads (more than pool size) - threads = [] - for i in range(20): - thread = threading.Thread(target=stress_worker, args=(i,)) - threads.append(thread) - thread.start() - - # Wait for all threads to complete - for thread in threads: - thread.join() - - assert len(errors) == 0, f"Unexpected errors: {errors}" - assert len(results) == 100 # 20 workers * 5 operations each - - # All sessions should be returned - assert len(pool.busy_sessions_queue) == 0 - assert len(pool.free_sessions_queue) == 10 - - @pytest.mark.asyncio - async def test_async_high_concurrency_stress(self): - """Test async pool under high concurrency stress""" - mock_conn = AsyncMock() - sessions = { - AsyncSession(_conn=mock_conn, username=f"user{i}", password=f"pass{i}", session_config=None, auth_options=None) - for i in range(10) - } - config = SessionPoolConfig(size=10) - pool = AsyncSessionPool(copy(sessions), config) - results = [] - errors = [] - - async def 
stress_worker(worker_id): - try: - for i in range(5): # Each worker does 5 operations - async with pool.borrow() as session: - results.append((worker_id, i, session)) - await asyncio.sleep(0.01) # Very short async work simulation - except Exception as e: - errors.append((worker_id, e)) - - # Start 20 tasks (more than pool size) - tasks = [stress_worker(i) for i in range(20)] - await asyncio.gather(*tasks) - - assert len(errors) == 0, f"Unexpected errors: {errors}" - assert len(results) == 100 # 20 workers * 5 operations each - - # All sessions should be returned - assert len(pool.busy_sessions_queue) == 0 - assert len(pool.free_sessions_queue) == 10 \ No newline at end of file From a070f7c51d4d895d22a8dcc3c282e783b8e439c7 Mon Sep 17 00:00:00 2001 From: Anqi <16240361+Nicole00@users.noreply.github.com> Date: Thu, 29 Jan 2026 11:10:47 +0800 Subject: [PATCH 2/5] remove docker-compose.yml --- docker-compose.yml | 218 --------------------------------------------- 1 file changed, 218 deletions(-) delete mode 100644 docker-compose.yml diff --git a/docker-compose.yml b/docker-compose.yml deleted file mode 100644 index ae672280..00000000 --- a/docker-compose.yml +++ /dev/null @@ -1,218 +0,0 @@ -version: '3.8' - -services: - nebula-metad0: - image: vesoft/nebula-graph:v3.8.0 - environment: - USER: root - TZ: UTC - command: - - --meta_server_addrs=nebula-metad0:9559,nebula-metad1:9559,nebula-metad2:9559 - - --local_ip=nebula-metad0 - - --ws_ip=nebula-metad0 - - --port=9559 - - --data_path=/data/meta - - --log_dir=/logs - - --v=0 - - --minloglevel=0 - healthcheck: - test: ["CMD", "curl", "-f", "http://nebula-metad0:19559/status"] - interval: 30s - timeout: 10s - retries: 3 - start_period: 20s - ports: - - "9559:9559" - - "19559:19559" - volumes: - - ./data/meta0:/data/meta - - ./logs/meta0:/logs - networks: - - nebula-net - restart: on-failure - cap_add: - - SYS_PTRACE - - nebula-metad1: - image: vesoft/nebula-graph:v3.8.0 - environment: - USER: root - TZ: UTC - command: - - --meta_server_addrs=nebula-metad0:9559,nebula-metad1:9559,nebula-metad2:9559 - - --local_ip=nebula-metad1 - - --ws_ip=nebula-metad1 - - --port=9559 - - --data_path=/data/meta - - --log_dir=/logs - - --v=0 - - --minloglevel=0 - healthcheck: - test: ["CMD", "curl", "-f", "http://nebula-metad1:19559/status"] - interval: 30s - timeout: 10s - retries: 3 - start_period: 20s - ports: - - "9560:9559" - - "19560:19559" - volumes: - - ./data/meta1:/data/meta - - ./logs/meta1:/logs - networks: - - nebula-net - restart: on-failure - cap_add: - - SYS_PTRACE - - nebula-metad2: - image: vesoft/nebula-graph:v3.8.0 - environment: - USER: root - TZ: UTC - command: - - --meta_server_addrs=nebula-metad0:9559,nebula-metad1:9559,nebula-metad2:9559 - - --local_ip=nebula-metad2 - - --ws_ip=nebula-metad2 - - --port=9559 - - --data_path=/data/meta - - --log_dir=/logs - - --v=0 - - --minloglevel=0 - healthcheck: - test: ["CMD", "curl", "-f", "http://nebula-metad2:19559/status"] - interval: 30s - timeout: 10s - retries: 3 - start_period: 20s - ports: - - "9561:9559" - - "19561:19559" - volumes: - - ./data/meta2:/data/meta - - ./logs/meta2:/logs - networks: - - nebula-net - restart: on-failure - cap_add: - - SYS_PTRACE - - nebula-storaged0: - image: vesoft/nebula-graph:v3.8.0 - environment: - USER: root - TZ: UTC - command: - - --meta_server_addrs=nebula-metad0:9559,nebula-metad1:9559,nebula-metad2:9559 - - --local_ip=nebula-storaged0 - - --ws_ip=nebula-storaged0 - - --port=9779 - - --data_path=/data/storage - - --log_dir=/logs - - --v=0 - - --minloglevel=0 - 
depends_on: - - nebula-metad0 - - nebula-metad1 - - nebula-metad2 - healthcheck: - test: ["CMD", "curl", "-f", "http://nebula-storaged0:19779/status"] - interval: 30s - timeout: 10s - retries: 3 - start_period: 20s - ports: - - "9779:9779" - - "19779:19779" - volumes: - - ./data/storage0:/data/storage - - ./logs/storage0:/logs - networks: - - nebula-net - restart: on-failure - cap_add: - - SYS_PTRACE - - nebula-storaged1: - image: vesoft/nebula-graph:v3.8.0 - environment: - USER: root - TZ: UTC - command: - - --meta_server_addrs=nebula-metad0:9559,nebula-metad1:9559,nebula-metad2:9559 - - --local_ip=nebula-storaged1 - - --ws_ip=nebula-storaged1 - - --port=9779 - - --data_path=/data/storage - - --log_dir=/logs - - --v=0 - - --minloglevel=0 - depends_on: - - nebula-metad0 - - nebula-metad1 - - nebula-metad2 - healthcheck: - test: ["CMD", "curl", "-f", "http://nebula-storaged1:19779/status"] - interval: 30s - timeout: 10s - retries: 3 - start_period: 20s - ports: - - "9780:9779" - - "19780:19779" - volumes: - - ./data/storage1:/data/storage - - ./logs/storage1:/logs - networks: - - nebula-net - restart: on-failure - cap_add: - - SYS_PTRACE - - nebula-graphd: - image: vesoft/nebula-graph:v3.8.0 - environment: - USER: root - TZ: UTC - command: - - --meta_server_addrs=nebula-metad0:9559,nebula-metad1:9559,nebula-metad2:9559 - - --local_ip=nebula-graphd - - --ws_ip=nebula-graphd - - --port=9669 - - --log_dir=/logs - - --v=0 - - --minloglevel=0 - depends_on: - - nebula-metad0 - - nebula-metad1 - - nebula-metad2 - - nebula-storaged0 - - nebula-storaged1 - healthcheck: - test: ["CMD", "curl", "-f", "http://nebula-graphd:19669/status"] - interval: 30s - timeout: 10s - retries: 3 - start_period: 20s - ports: - - "9669:9669" - - "19669:19669" - volumes: - - ./logs/graph:/logs - networks: - - nebula-net - restart: on-failure - cap_add: - - SYS_PTRACE - - nebula-console: - image: vesoft/nebula-console:v3.8.0 - entrypoint: ["sleep", "infinity"] - depends_on: - - nebula-graphd - networks: - - nebula-net - -networks: - nebula-net: - driver: bridge \ No newline at end of file From 63fe1435c1d14b34d2401bb5e589be048b9e1e0d Mon Sep 17 00:00:00 2001 From: Anqi <16240361+Nicole00@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:18:36 +0800 Subject: [PATCH 3/5] update parameter --- docs/1_started.md | 6 +- docs/2_concurrency.md | 6 +- docs/3_templating_query.md | 2 +- docs/4_error_handling.md | 2 +- docs/5_vector_and_special_types.md | 4 +- docs/7_orm.md | 2 +- docs/7_orm_example.py | 2 +- example.py | 4 +- example/NebulaPoolExample.py | 60 ++- src/nebulagraph_python/client/nebula_pool.py | 6 +- src/ng_console/__init__.py | 2 +- tests/test_integration.py | 416 +++++-------------- tests/test_nebula_pool.py | 70 ++-- 13 files changed, 191 insertions(+), 391 deletions(-) diff --git a/docs/1_started.md b/docs/1_started.md index 73e4b331..a2e73245 100644 --- a/docs/1_started.md +++ b/docs/1_started.md @@ -124,7 +124,7 @@ from nebulagraph_python import NebulaClient with NebulaClient( hosts=["127.0.0.1:9669"], - username="root", + user_name="root", password="NebulaGraph01", ) as client: result = client.execute("RETURN 1 AS a, 2 AS b") @@ -142,7 +142,7 @@ from nebulagraph_python import NebulaClient client = NebulaClient( hosts=["127.0.0.1:9669"], - username="root", + user_name="root", password="NebulaGraph01", ) try: @@ -161,7 +161,7 @@ from nebulagraph_python.client import NebulaAsyncClient async def main() -> None: client = await NebulaAsyncClient.connect( hosts=["127.0.0.1:9669"], - username="root", + 
user_name="root", password="NebulaGraph01", ) try: diff --git a/docs/2_concurrency.md b/docs/2_concurrency.md index eb063b36..cc2fa289 100644 --- a/docs/2_concurrency.md +++ b/docs/2_concurrency.md @@ -32,7 +32,7 @@ async def concurrent_example(): # Create client with session pool for concurrency async with await NebulaAsyncClient.connect( hosts=["127.0.0.1:9669", "127.0.0.1:9670"], # Multiple hosts for HA - username="root", + user_name="root", password="NebulaGraph01", session_pool_config=SessionPoolConfig( size=3, # Pool of 3 sessions per host @@ -70,7 +70,7 @@ By default, statements run on a random session from the pool. When you need to r async def contextual_example(): async with await NebulaAsyncClient.connect( hosts=["127.0.0.1:9669"], - username="root", + user_name="root", password="NebulaGraph01", session_pool_config=SessionPoolConfig(), ) as client: @@ -148,7 +148,7 @@ from nebulagraph_python import NebulaClient, SessionPoolConfig with NebulaClient( hosts=["127.0.0.1:9669"], - username="root", + user_name="root", password="NebulaGraph01", session_pool_config=SessionPoolConfig(), # enables multiple sessions per host ) as client, ThreadPoolExecutor(max_workers=8) as executor: diff --git a/docs/3_templating_query.md b/docs/3_templating_query.md index e57865e8..d4a18522 100644 --- a/docs/3_templating_query.md +++ b/docs/3_templating_query.md @@ -17,7 +17,7 @@ from nebulagraph_python.client import NebulaAsyncClient async def main() -> None: async with await NebulaAsyncClient.connect( hosts=["127.0.0.1:9669"], - username="root", + user_name="root", password="NebulaGraph01", ) as client: query = """ diff --git a/docs/4_error_handling.md b/docs/4_error_handling.md index 0beb459e..f19d157c 100644 --- a/docs/4_error_handling.md +++ b/docs/4_error_handling.md @@ -22,7 +22,7 @@ from nebulagraph_python.error import NebulaGraphRemoteError, ErrorCode async def main(): async with await NebulaAsyncClient.connect( - hosts="localhost:9669", username="root", password="NebulaGraph01" + hosts="localhost:9669", user_name="root", password="NebulaGraph01" ) as client: try: rs = await client.execute_py("USE not_exist_graph RETURN 1") diff --git a/docs/5_vector_and_special_types.md b/docs/5_vector_and_special_types.md index 763572c8..a6f50a80 100644 --- a/docs/5_vector_and_special_types.md +++ b/docs/5_vector_and_special_types.md @@ -31,7 +31,7 @@ from nebulagraph_python.client.nebula_client import NebulaClient from nebulagraph_python.py_data_types import NVector # Connect (adjust hosts/credentials to your environment) -cli = NebulaClient(hosts=["127.0.0.1:9669"], username="root", password="NebulaGraph01") +cli = NebulaClient(hosts=["127.0.0.1:9669"], user_name="root", password="NebulaGraph01") # RETURN a vector and read it from the result res = cli.execute_py("RETURN vector<3, float>([1, 2, 3]) AS vec") @@ -73,7 +73,7 @@ Examples: from nebulagraph_python.client.nebula_client import NebulaClient from nebulagraph_python.py_data_types import NDuration -cli = NebulaClient(hosts=["127.0.0.1:9669"], username="root", password="Nebula.123") +cli = NebulaClient(hosts=["127.0.0.1:9669"], user_name="root", password="Nebula.123") # RETURN a duration literal from the server and read it # Adjust the literal to your NebulaGraph version if needed diff --git a/docs/7_orm.md b/docs/7_orm.md index d4bd3f94..41db352e 100644 --- a/docs/7_orm.md +++ b/docs/7_orm.md @@ -31,7 +31,7 @@ from nebulagraph_python.py_data_types import NVector client = NebulaClient( hosts=["127.0.0.1:9669"], - username="root", + user_name="root", 
password="NebulaGraph01", session_config=SessionConfig(graph="movie"), ) diff --git a/docs/7_orm_example.py b/docs/7_orm_example.py index 1c40ddec..8827f6c4 100644 --- a/docs/7_orm_example.py +++ b/docs/7_orm_example.py @@ -24,7 +24,7 @@ # Create client client = NebulaClient( hosts=["127.0.0.1:9669"], - username="root", + user_name="root", password="NebulaGraph01", session_config=SessionConfig( graph="movie", diff --git a/example.py b/example.py index 158aae7c..ab1efce5 100644 --- a/example.py +++ b/example.py @@ -126,7 +126,7 @@ def pool_example(): # Create pool configuration config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="root", + user_name="root", password="nebula", max_client_size=10, min_client_size=2, @@ -194,7 +194,7 @@ def query_task(pool, idx): # Create pool config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="root", + user_name="root", password="nebula", max_client_size=10, min_client_size=2, diff --git a/example/NebulaPoolExample.py b/example/NebulaPoolExample.py index 0eb5523a..82a85db9 100755 --- a/example/NebulaPoolExample.py +++ b/example/NebulaPoolExample.py @@ -1,57 +1,55 @@ #!/usr/bin/env python3 # -*- coding: utf-8 -*- -from typing import Optional, Dict -from nebulagraph_python.client.nebula_pool import NebulaPool, NebulaPoolConfig, SessionConfig -from nebulagraph_python.data import HostAddress +from nebulagraph_python import NebulaPool, NebulaPoolConfig +from nebulagraph_python.client import NebulaBaseExecutor + +class NebulaPoolExecutor(NebulaBaseExecutor): + """Wrapper to make NebulaPool compatible with NebulaBaseExecutor""" + + def __init__(self, pool): + self.pool = pool + self.client = None + + def execute(self, statement: str, *, timeout=None, do_ping=False): + if self.client is None: + self.client = self.pool.get_client() + return self.client.execute_with_timeout(statement, timeout or 30000) + graph_name = "test_graph" def main(): - # config the connect information - hosts = ["127.0.0.1:9669"] - username = "root" + # configure the connection information + addresses = "127.0.0.1:9669" + user_name = "root" password = "NebulaGraph01" # create NebulaPool - pool = NebulaPool( - hosts=hosts, - username=username, + config = NebulaPoolConfig( + addresses=addresses, + user_name=user_name, password=password, - session_config=SessionConfig(graph=graph_name) + graph=graph_name ) + pool = NebulaPool(config) try: print("use execute_py to execute `SHOW GRAPHS` ...") - result = pool.execute_py("SHOW GRAPHS") + executor = NebulaPoolExecutor(pool) + result = executor.execute("SHOW GRAPHS") - # 打印结果 + # print results print("\n query result:") print("-" * 50) - result.print(style="table") - - print("\n\nuse execute to execute `SHOW GRAPHS`:") - print("-" * 50) - result2 = pool.execute("SHOW GRAPHS") - result2.print(style="table") - - # get the query result + result.print() print("-" * 50) - if result.size > 0: - for row in result: - print(f"Row: {row}") - else: - print("Empty") except Exception as e: - print(f"\nerror: {e}") - import traceback - traceback.print_exc() + print(f"Error: {e}") finally: - print("\nclose the pool...") pool.close() - print("closed") if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/src/nebulagraph_python/client/nebula_pool.py b/src/nebulagraph_python/client/nebula_pool.py index 91ddba05..bc4ddc13 100644 --- a/src/nebulagraph_python/client/nebula_pool.py +++ b/src/nebulagraph_python/client/nebula_pool.py @@ -54,7 +54,7 @@ class NebulaPoolConfig: # Connection settings addresses: str - 
username: str + user_name: str password: Optional[str] = None # Pool settings @@ -127,7 +127,7 @@ class LoadBalancerConfig: def __init__(self, pool_config: NebulaPoolConfig, addrs: List[HostAddress]): self.address = addrs self.strictly_server_healthy = pool_config.strictly_server_healthy - self.user_name = pool_config.username + self.user_name = pool_config.user_name self.auth_options = pool_config.auth_options self.connect_timeout_mills = pool_config.connect_timeout_ms self.request_timeout_mills = pool_config.request_timeout_ms @@ -273,4 +273,4 @@ def __enter__(self): return self def __exit__(self, exc_type, exc_val, exc_tb): - self.close() \ No newline at end of file + self.close() diff --git a/src/ng_console/__init__.py b/src/ng_console/__init__.py index 4a697599..066732d8 100644 --- a/src/ng_console/__init__.py +++ b/src/ng_console/__init__.py @@ -26,7 +26,7 @@ def create_client(hosts: str, username: str, password: str) -> NebulaClient: """Create and verify NebulaGraph client connection""" try: - client = NebulaClient(hosts, username, password) + client = NebulaClient(hosts, user_name=username, password=password) # Test connection if not client.ping(): raise RuntimeError("Failed to connect to NebulaGraph") diff --git a/tests/test_integration.py b/tests/test_integration.py index 0502055f..bf716baa 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -18,14 +18,12 @@ from nebulagraph_python import ( NebulaClient, - NebulaAsyncClient, NebulaPool, NebulaPoolConfig, - SessionConfig, - SessionPoolConfig, ) +from nebulagraph_python.client import AsyncNebulaClient -# 从环境变量获取测试配置,如果没有则使用默认值 +# Get test configuration from environment variables, or use default values NEBULA_HOST = os.getenv("NEBULA_HOST", "192.168.8.6") NEBULA_PORT = os.getenv("NEBULA_PORT", "3820") NEBULA_USER = os.getenv("NEBULA_USER", "root") @@ -35,10 +33,10 @@ class TestConnectionIntegration: - """实际连接测试 - 测试Connection功能""" + """Connection integration tests - Test Connection functionality""" def test_connection_basic(self): - """测试基本连接""" + """Test basic connection""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, @@ -48,7 +46,7 @@ def test_connection_basic(self): client.close() def test_connection_ping(self): - """测试连接ping功能""" + """Test connection ping functionality""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, @@ -58,7 +56,7 @@ def test_connection_ping(self): client.close() def test_connection_execute_simple_query(self): - """测试执行简单查询""" + """Test executing simple query""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, @@ -70,20 +68,20 @@ def test_connection_execute_simple_query(self): client.close() def test_connection_execute_show_hosts(self): - """测试执行SHOW HOSTS命令""" + """Test executing SHOW HOSTS command""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, NEBULA_PASSWORD, ) - # NebulaGraph 5.0 使用不同的语法 + # NebulaGraph 5.0 uses different syntax result = client.execute("SHOW HOSTS GRAPH") assert result is not None assert result.is_succeeded client.close() def test_connection_context_manager(self): - """测试上下文管理器""" + """Test context manager""" with NebulaClient( NEBULA_ADDRESS, NEBULA_USER, @@ -93,43 +91,31 @@ def test_connection_context_manager(self): result = client.execute("RETURN 1") assert result.is_succeeded - def test_connection_with_session_config(self): - """测试带会话配置的连接""" - session_config = SessionConfig() - client = NebulaClient( - NEBULA_ADDRESS, - NEBULA_USER, - NEBULA_PASSWORD, - session_config=session_config, - ) - assert client is not None - result = 
client.execute("RETURN 1") - assert result.is_succeeded - client.close() - class TestAsyncConnectionIntegration: - """异步连接测试 - 测试AsyncConnection功能""" + """Async connection integration tests - Test AsyncConnection functionality""" @pytest.mark.asyncio async def test_async_connection_basic(self): - """测试基本异步连接""" - client = await NebulaAsyncClient.connect( + """Test basic async connection""" + client = AsyncNebulaClient( NEBULA_ADDRESS, NEBULA_USER, NEBULA_PASSWORD, ) + await client._init_client() assert client is not None await client.close() @pytest.mark.asyncio async def test_async_connection_execute(self): - """测试异步执行查询""" - client = await NebulaAsyncClient.connect( + """Test async query execution""" + client = AsyncNebulaClient( NEBULA_ADDRESS, NEBULA_USER, NEBULA_PASSWORD, ) + await client._init_client() result = await client.execute("RETURN 1") assert result is not None assert result.is_succeeded @@ -137,226 +123,87 @@ async def test_async_connection_execute(self): @pytest.mark.asyncio async def test_async_connection_show_hosts(self): - """测试异步执行SHOW HOSTS""" - client = await NebulaAsyncClient.connect( + """Test async SHOW HOSTS execution""" + client = AsyncNebulaClient( NEBULA_ADDRESS, NEBULA_USER, NEBULA_PASSWORD, ) + await client._init_client() result = await client.execute("SHOW HOSTS") assert result is not None assert result.is_succeeded await client.close() - @pytest.mark.asyncio - async def test_async_connection_context_manager(self): - """测试异步上下文管理器""" - async with await NebulaAsyncClient.connect( - NEBULA_ADDRESS, - NEBULA_USER, - NEBULA_PASSWORD, - ) as client: - assert client is not None - result = await client.execute("RETURN 1") - assert result.is_succeeded - - -class TestSessionPoolIntegration: - """会话池集成测试""" - - def test_session_pool_basic(self): - """测试基本会话池""" - pool_config = SessionPoolConfig(size=3) - client = NebulaClient( - NEBULA_ADDRESS, - NEBULA_USER, - NEBULA_PASSWORD, - session_pool_config=pool_config, - ) - assert client is not None - - # 执行多个查询 - for i in range(5): - result = client.execute("RETURN 1") - assert result.is_succeeded - - client.close() - - def test_session_pool_concurrent(self): - """测试会话池并发访问""" - import threading - import time - - pool_config = SessionPoolConfig(size=3) - client = NebulaClient( - NEBULA_ADDRESS, - NEBULA_USER, - NEBULA_PASSWORD, - session_pool_config=pool_config, - ) - - results = [] - errors = [] - - def execute_query(thread_id): - try: - result = client.execute("RETURN 1") - results.append(thread_id) - time.sleep(0.1) - except Exception as e: - errors.append((thread_id, e)) - - threads = [] - for i in range(5): - thread = threading.Thread(target=execute_query, args=(i,)) - threads.append(thread) - thread.start() - - for thread in threads: - thread.join() - - assert len(errors) == 0 - assert len(results) == 5 - client.close() - - def test_session_pool_borrow_session(self): - """测试借用会话""" - pool_config = SessionPoolConfig(size=2) - client = NebulaClient( - NEBULA_ADDRESS, - NEBULA_USER, - NEBULA_PASSWORD, - session_pool_config=pool_config, - ) - - with client.borrow() as session: - result = session.execute("RETURN 1") - assert result.is_succeeded - - client.close() - - -class TestAsyncSessionPoolIntegration: - """异步会话池集成测试""" - - @pytest.mark.asyncio - async def test_async_session_pool_basic(self): - """测试基本异步会话池""" - pool_config = SessionPoolConfig(size=3) - client = await NebulaAsyncClient.connect( - NEBULA_ADDRESS, - NEBULA_USER, - NEBULA_PASSWORD, - session_pool_config=pool_config, - ) - assert client is not None - - # 
执行多个查询 - for i in range(5): - result = await client.execute("RETURN 1") - assert result.is_succeeded - - await client.close() - - @pytest.mark.asyncio - async def test_async_session_pool_concurrent(self): - """测试异步会话池并发访问""" - pool_config = SessionPoolConfig(size=3) - client = await NebulaAsyncClient.connect( - NEBULA_ADDRESS, - NEBULA_USER, - NEBULA_PASSWORD, - session_pool_config=pool_config, - ) - - async def execute_query(task_id): - result = await client.execute("RETURN 1") - assert result.is_succeeded - - tasks = [execute_query(i) for i in range(5)] - await asyncio.gather(*tasks) - - await client.close() - - @pytest.mark.asyncio - async def test_async_session_pool_borrow_session(self): - """测试异步借用会话""" - pool_config = SessionPoolConfig(size=2) - client = await NebulaAsyncClient.connect( - NEBULA_ADDRESS, - NEBULA_USER, - NEBULA_PASSWORD, - session_pool_config=pool_config, - ) - - async with client.borrow() as session: - result = await session.execute("RETURN 1") - assert result.is_succeeded - - await client.close() - class TestNebulaPoolIntegration: - """连接池集成测试""" + """Connection pool integration tests""" def test_nebula_pool_basic(self): - """测试基本连接池""" + """Test basic connection pool""" pool_config = NebulaPoolConfig( - max_client_size=3, min_client_size=1, max_wait=10.0 - ) - pool = NebulaPool( - NEBULA_ADDRESS, - NEBULA_USER, - NEBULA_PASSWORD, - pool_config=pool_config, - ) + addresses=NEBULA_ADDRESS, + user_name=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=3, + min_client_size=1, + max_wait_ms=10000, + ) + pool = NebulaPool(pool_config) assert pool is not None - result = pool.execute("RETURN 1") + client = pool.get_client() + result = client.execute("RETURN 1") assert result.is_succeeded + pool.return_client(client) pool.close() - def test_nebula_pool_borrow_client(self): - """测试借用客户端""" + def test_nebula_pool_get_client_and_return(self): + """Test getting and returning client""" pool_config = NebulaPoolConfig( - max_client_size=2, min_client_size=1, max_wait=10.0 - ) - pool = NebulaPool( - NEBULA_ADDRESS, - NEBULA_USER, - NEBULA_PASSWORD, - pool_config=pool_config, + addresses=NEBULA_ADDRESS, + user_name=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=2, + min_client_size=1, + max_wait_ms=10000, ) + pool = NebulaPool(pool_config) - with pool.borrow() as client: - result = client.execute("RETURN 1") - assert result.is_succeeded + client = pool.get_client() + assert client is not None + result = client.execute("RETURN 1") + assert result.is_succeeded + + pool.return_client(client) pool.close() def test_nebula_pool_concurrent(self): - """测试连接池并发访问""" + """Test connection pool concurrent access""" import threading import time pool_config = NebulaPoolConfig( - max_client_size=5, min_client_size=2, max_wait=10.0 - ) - pool = NebulaPool( - NEBULA_ADDRESS, - NEBULA_USER, - NEBULA_PASSWORD, - pool_config=pool_config, + addresses=NEBULA_ADDRESS, + user_name=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=5, + min_client_size=2, + max_wait_ms=10000, ) + pool = NebulaPool(pool_config) results = [] errors = [] def execute_query(thread_id): try: - result = pool.execute("RETURN 1") + client = pool.get_client() + result = client.execute("RETURN 1") results.append(thread_id) + pool.return_client(client) time.sleep(0.1) except Exception as e: errors.append((thread_id, e)) @@ -374,192 +221,152 @@ def execute_query(thread_id): assert len(results) == 5 pool.close() - def test_nebula_pool_round_robin(self): - """测试轮询负载均衡""" - pool_config = NebulaPoolConfig( - 
max_client_size=2, min_client_size=1, max_wait=10.0 - ) - pool = NebulaPool( - NEBULA_ADDRESS, - NEBULA_USER, - NEBULA_PASSWORD, - pool_config=pool_config, - ) - - # 执行多个查询,应该轮询使用不同的客户端 - for i in range(4): - result = pool.execute("RETURN 1") - assert result.is_succeeded - - pool.close() - def test_nebula_pool_context_manager(self): - """测试连接池上下文管理器""" - pool_config = NebulaPoolConfig( - max_client_size=2, min_client_size=1, max_wait=10.0 - ) - pool = NebulaPool( - NEBULA_ADDRESS, - NEBULA_USER, - NEBULA_PASSWORD, - pool_config=pool_config, - ) - # NebulaPool不支持上下文管理器,手动关闭 - result = pool.execute("RETURN 1") - assert result.is_succeeded - pool.close() - - def test_nebula_pool_get_client_and_return(self): - """测试获取和返回客户端""" + """Test connection pool context manager""" pool_config = NebulaPoolConfig( - max_client_size=2, min_client_size=1, max_wait=10.0 - ) - pool = NebulaPool( - NEBULA_ADDRESS, - NEBULA_USER, - NEBULA_PASSWORD, - pool_config=pool_config, - ) - - client = pool.get_client() - assert client is not None - - result = client.execute("RETURN 1") - assert result.is_succeeded - - pool.return_client(client) - pool.close() + addresses=NEBULA_ADDRESS, + user_name=NEBULA_USER, + password=NEBULA_PASSWORD, + max_client_size=2, + min_client_size=1, + max_wait_ms=10000, + ) + with NebulaPool(pool_config) as pool: + client = pool.get_client() + result = client.execute("RETURN 1") + assert result.is_succeeded + pool.return_client(client) class TestGraphOperations: - """图操作测试""" + """Graph operations tests""" def test_create_space(self): - """测试创建图空间""" + """Test creating graph space""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, NEBULA_PASSWORD, ) - # 删除可能存在的图空间 + # Drop existing graph space if exists client.execute("DROP SPACE IF EXISTS test_space") - # 创建图空间 + # Create graph space result = client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") assert result.is_succeeded - # 使用图空间 + # Use graph space result = client.execute("USE test_space") assert result.is_succeeded client.close() def test_create_tag(self): - """测试创建标签""" + """Test creating tag""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, NEBULA_PASSWORD, ) - # 确保图空间存在 + # Ensure graph space exists client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") client.execute("USE test_space") - # 创建标签 + # Create tag result = client.execute("CREATE TAG IF NOT EXISTS person(name string, age int)") assert result.is_succeeded client.close() def test_create_edge(self): - """测试创建边类型""" + """Test creating edge type""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, NEBULA_PASSWORD, ) - # 确保图空间存在 + # Ensure graph space exists client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") client.execute("USE test_space") - # 创建边类型 + # Create edge type result = client.execute("CREATE EDGE IF NOT EXISTS follow(degree int)") assert result.is_succeeded client.close() def test_insert_vertex(self): - """测试插入顶点""" + """Test inserting vertex""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, NEBULA_PASSWORD, ) - # 准备图空间 + # Prepare graph space client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") client.execute("USE test_space") client.execute("CREATE TAG IF NOT EXISTS person(name string, age int)") - # 插入顶点 + # Insert vertex result = client.execute('INSERT VERTEX person(name, age) VALUES 
"1":("Tom", 18), "2":("Jerry", 20)') assert result.is_succeeded client.close() def test_insert_edge(self): - """测试插入边""" + """Test inserting edge""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, NEBULA_PASSWORD, ) - # 准备图空间 + # Prepare graph space client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") client.execute("USE test_space") client.execute("CREATE TAG IF NOT EXISTS person(name string, age int)") client.execute("CREATE EDGE IF NOT EXISTS follow(degree int)") client.execute('INSERT VERTEX person(name, age) VALUES "1":("Tom", 18), "2":("Jerry", 20)') - # 插入边 + # Insert edge result = client.execute('INSERT EDGE follow(degree) VALUES "1"->"2":(90)') assert result.is_succeeded client.close() def test_query_vertex(self): - """测试查询顶点""" + """Test querying vertex""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, NEBULA_PASSWORD, ) - # 准备数据 + # Prepare data client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") client.execute("USE test_space") client.execute("CREATE TAG IF NOT EXISTS person(name string, age int)") client.execute('INSERT VERTEX person(name, age) VALUES "1":("Tom", 18)') - # 查询顶点 + # Query vertex result = client.execute('FETCH PROP ON person "1" YIELD vertex as v') assert result.is_succeeded client.close() def test_query_edge(self): - """测试查询边""" + """Test querying edge""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, NEBULA_PASSWORD, ) - # 准备数据 + # Prepare data client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") client.execute("USE test_space") client.execute("CREATE TAG IF NOT EXISTS person(name string, age int)") @@ -567,21 +374,21 @@ def test_query_edge(self): client.execute('INSERT VERTEX person(name, age) VALUES "1":("Tom", 18), "2":("Jerry", 20)') client.execute('INSERT EDGE follow(degree) VALUES "1"->"2":(90)') - # 查询边 + # Query edge result = client.execute('FETCH PROP ON follow "1"->"2" YIELD edge as e') assert result.is_succeeded client.close() def test_complex_query(self): - """测试复杂查询""" + """Test complex query""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, NEBULA_PASSWORD, ) - # 准备数据 + # Prepare data client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") client.execute("USE test_space") client.execute("CREATE TAG IF NOT EXISTS person(name string, age int)") @@ -589,7 +396,7 @@ def test_complex_query(self): client.execute('INSERT VERTEX person(name, age) VALUES "1":("Tom", 18), "2":("Jerry", 20), "3":("Alice", 22)') client.execute('INSERT EDGE follow(degree) VALUES "1"->"2":(90), "2"->"3":(80)') - # 复杂查询:查找Tom关注的人 + # Complex query: find people Tom follows result = client.execute('GO FROM "1" OVER follow YIELD $$.person.name AS name, $$.person.age AS age') assert result.is_succeeded @@ -597,10 +404,10 @@ def test_complex_query(self): class TestErrorHandling: - """错误处理测试""" + """Error handling tests""" def test_invalid_query(self): - """测试无效查询""" + """Test invalid query""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, @@ -613,7 +420,7 @@ def test_invalid_query(self): client.close() def test_wrong_credentials(self): - """测试错误凭据""" + """Test wrong credentials""" with pytest.raises(Exception): client = NebulaClient( NEBULA_ADDRESS, @@ -623,40 +430,35 @@ def test_wrong_credentials(self): client.close() def test_connection_timeout(self): - """测试连接超时""" - from 
nebulagraph_python.client._connection import ConnectionConfig - - conn_config = ConnectionConfig.from_defaults( - NEBULA_ADDRESS, connect_timeout=1.0 - ) + """Test connection timeout""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, NEBULA_PASSWORD, - conn_config=conn_config, + connect_timeout_ms=1000, ) - # 应该能连接成功 + # Should connect successfully assert client.ping() client.close() class TestPerformance: - """性能测试""" + """Performance tests""" def test_batch_insert(self): - """测试批量插入""" + """Test batch insert""" client = NebulaClient( NEBULA_ADDRESS, NEBULA_USER, NEBULA_PASSWORD, ) - # 准备图空间 + # Prepare graph space client.execute("CREATE SPACE IF NOT EXISTS test_space(partition_num=10, replica_factor=1, vid_type=FIXED_STRING(32))") client.execute("USE test_space") client.execute("CREATE TAG IF NOT EXISTS person(name string, age int)") - # 批量插入 + # Batch insert vertices = [] for i in range(100): vertices.append(f'"{i}":("Person{i}", {20 + i % 30})') @@ -668,7 +470,7 @@ def test_batch_insert(self): client.close() def test_concurrent_queries(self): - """测试并发查询""" + """Test concurrent queries""" import threading client = NebulaClient( @@ -698,4 +500,4 @@ def execute_query(thread_id): assert len(errors) == 0 assert len(results) == 10 - client.close() \ No newline at end of file + client.close() diff --git a/tests/test_nebula_pool.py b/tests/test_nebula_pool.py index 03b562d1..deaf10f7 100644 --- a/tests/test_nebula_pool.py +++ b/tests/test_nebula_pool.py @@ -50,11 +50,11 @@ def test_config_defaults(self): """Test NebulaPoolConfig with default values""" config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass" ) assert config.addresses == "127.0.0.1:9669" - assert config.username == "test_user" + assert config.user_name == "test_user" assert config.password == "test_pass" assert config.max_client_size == DEFAULT_MAX_CLIENT_SIZE assert config.min_client_size == DEFAULT_MIN_CLIENT_SIZE @@ -83,7 +83,7 @@ def test_config_custom_pool_settings(self): """Test NebulaPoolConfig with custom pool settings""" config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", max_client_size=20, min_client_size=5, @@ -99,7 +99,7 @@ def test_config_custom_timeout_settings(self): """Test NebulaPoolConfig with custom timeout settings""" config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", connect_timeout_ms=5000, request_timeout_ms=120000, @@ -113,7 +113,7 @@ def test_config_custom_health_check_settings(self): """Test NebulaPoolConfig with custom health check settings""" config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", health_check_time_ms=300000, test_on_borrow=False, @@ -125,7 +125,7 @@ def test_config_custom_eviction_settings(self): """Test NebulaPoolConfig with custom eviction settings""" config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", idle_evict_schedule_ms=60000, min_evictable_idle_time_ms=900000, @@ -137,7 +137,7 @@ def test_config_custom_server_settings(self): """Test NebulaPoolConfig with custom server settings""" config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", strictly_server_healthy=True, max_life_time_ms=3600000, @@ -149,7 +149,7 @@ def 
test_config_custom_session_settings(self): """Test NebulaPoolConfig with custom session settings""" config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", graph="test_graph", schema="test_schema", @@ -170,7 +170,7 @@ def test_config_custom_other_settings(self): ssl_param = SSLParam(ca_crt=b"ca", private_key=b"key", cert=b"cert") config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", scan_parallel=20, enable_tls=True, @@ -184,7 +184,7 @@ def test_config_auth_options_post_init(self): """Test that auth_options is populated with password in __post_init__""" config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass" ) assert config.auth_options == {"password": "test_pass"} @@ -193,7 +193,7 @@ def test_config_auth_options_without_password(self): """Test auth_options without password""" config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password=None ) assert config.auth_options == {} @@ -202,7 +202,7 @@ def test_config_multiple_addresses(self): """Test NebulaPoolConfig with multiple addresses""" config = NebulaPoolConfig( addresses="127.0.0.1:9669,127.0.0.2:9669,127.0.0.3:9669", - username="test_user", + user_name="test_user", password="test_pass" ) assert config.addresses == "127.0.0.1:9669,127.0.0.2:9669,127.0.0.3:9669" @@ -212,7 +212,7 @@ def test_config_all_parameters(self): ssl_param = SSLParam(ca_crt=b"ca", private_key=b"key", cert=b"cert") config = NebulaPoolConfig( addresses="127.0.0.1:9669,127.0.0.2:9669", - username="test_user", + user_name="test_user", password="test_pass", max_client_size=20, min_client_size=5, @@ -238,7 +238,7 @@ def test_config_all_parameters(self): ssl_param=ssl_param, ) assert config.addresses == "127.0.0.1:9669,127.0.0.2:9669" - assert config.username == "test_user" + assert config.user_name == "test_user" assert config.password == "test_pass" assert config.max_client_size == 20 assert config.min_client_size == 5 @@ -282,7 +282,7 @@ def test_pool_creation_with_defaults(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=2 ) @@ -309,7 +309,7 @@ def test_pool_creation_with_custom_config(self, mock_factory_class, mock_lb_clas config = NebulaPoolConfig( addresses="127.0.0.1:9669,127.0.0.2:9669", - username="test_user", + user_name="test_user", password="test_pass", max_client_size=10, min_client_size=3, @@ -337,7 +337,7 @@ def test_pool_creation_with_ssl(self, mock_factory_class, mock_lb_class): ssl_param = SSLParam(ca_crt=b"ca", private_key=b"key", cert=b"cert") config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", enable_tls=True, ssl_param=ssl_param, @@ -365,7 +365,7 @@ def test_pool_get_client_success(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=1, test_on_borrow=True @@ -397,7 +397,7 @@ def test_pool_get_client_creates_new(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=1, max_client_size=2 @@ -429,7 +429,7 @@ def 
test_pool_get_client_timeout(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=1, max_client_size=1, @@ -461,7 +461,7 @@ def test_pool_get_client_block_when_exhausted_false(self, mock_factory_class, mo config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=1, max_client_size=1, @@ -496,7 +496,7 @@ def test_pool_get_client_test_on_borrow_invalidates(self, mock_factory_class, mo config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=1, test_on_borrow=True @@ -526,7 +526,7 @@ def test_pool_return_client_success(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=1 ) @@ -554,7 +554,7 @@ def test_pool_return_client_closed(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=1 ) @@ -584,7 +584,7 @@ def test_pool_return_client_expired(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=1, max_life_time_ms=3600000 # 1 hour @@ -613,7 +613,7 @@ def test_pool_close(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=2 ) @@ -642,7 +642,7 @@ def test_pool_get_client_after_close(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=1 ) @@ -668,7 +668,7 @@ def test_pool_get_active_sessions(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=3 ) @@ -707,7 +707,7 @@ def test_pool_get_idle_sessions(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=3 ) @@ -748,7 +748,7 @@ def test_pool_context_manager(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=1 ) @@ -774,7 +774,7 @@ def test_pool_concurrent_access(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", max_client_size=3, min_client_size=3 @@ -821,7 +821,7 @@ def test_pool_parse_addresses(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1:9669,127.0.0.2:9669,127.0.0.3:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=1 ) @@ -849,7 +849,7 @@ def test_pool_parse_addresses_invalid(self, mock_factory_class, mock_lb_class): config = NebulaPoolConfig( addresses="127.0.0.1", # Missing port - username="test_user", + user_name="test_user", password="test_pass", min_client_size=1 ) @@ -877,7 +877,7 @@ def 
test_pool_init_failure_handles_gracefully(self, mock_factory_class, mock_lb_ config = NebulaPoolConfig( addresses="127.0.0.1:9669", - username="test_user", + user_name="test_user", password="test_pass", min_client_size=3 ) @@ -886,4 +886,4 @@ def test_pool_init_failure_handles_gracefully(self, mock_factory_class, mock_lb_ pool = NebulaPool(config) # Should have 2 clients (one failed) - assert len(pool._pool) == 2 \ No newline at end of file + assert len(pool._pool) == 2 From abb32ccc245d9f7cb39ee25749cf9e5136a773da Mon Sep 17 00:00:00 2001 From: Anqi <16240361+Nicole00@users.noreply.github.com> Date: Thu, 29 Jan 2026 15:19:56 +0800 Subject: [PATCH 4/5] remove sleep --- src/nebulagraph_python/client/nebula_pool.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/nebulagraph_python/client/nebula_pool.py b/src/nebulagraph_python/client/nebula_pool.py index bc4ddc13..337bfa48 100644 --- a/src/nebulagraph_python/client/nebula_pool.py +++ b/src/nebulagraph_python/client/nebula_pool.py @@ -222,9 +222,6 @@ def get_client(self) -> NebulaClient: if not self.config.block_when_exhausted: raise RuntimeError("No available clients in pool") - # Wait a bit before retrying - time.sleep(0.01) - raise RuntimeError(f"Timeout waiting for client after {self.config.max_wait_ms}ms") def return_client(self, client: NebulaClient) -> None: From f18cdcbe72fbfe0e2fd4fa7a45707aed2c6f4131 Mon Sep 17 00:00:00 2001 From: Anqi <16240361+Nicole00@users.noreply.github.com> Date: Thu, 29 Jan 2026 17:33:38 +0800 Subject: [PATCH 5/5] update --- src/nebulagraph_python/client/_connection.py | 3 +-- src/nebulagraph_python/client/auth_result.py | 4 +--- src/nebulagraph_python/client/client_pool_factory.py | 1 - src/nebulagraph_python/client/constants.py | 6 ++---- src/nebulagraph_python/client/nebula_client.py | 3 +-- src/nebulagraph_python/client/round_robin_load_balancer.py | 3 +-- 6 files changed, 6 insertions(+), 14 deletions(-) diff --git a/src/nebulagraph_python/client/_connection.py b/src/nebulagraph_python/client/_connection.py index f1525070..42dd3380 100644 --- a/src/nebulagraph_python/client/_connection.py +++ b/src/nebulagraph_python/client/_connection.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Connection classes matching Java implementation""" import asyncio import json @@ -409,4 +408,4 @@ async def authenticate( return AuthResult( session_id=int(response.session_id), version=response.version.decode("utf-8"), - ) \ No newline at end of file + ) diff --git a/src/nebulagraph_python/client/auth_result.py b/src/nebulagraph_python/client/auth_result.py index 05929b8c..605841e0 100644 --- a/src/nebulagraph_python/client/auth_result.py +++ b/src/nebulagraph_python/client/auth_result.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License.
-"""AuthResult class matching Java implementation""" - from dataclasses import dataclass @@ -31,4 +29,4 @@ def get_session_id(self) -> int: def get_version(self) -> str: """Get the server version""" - return self.version \ No newline at end of file + return self.version diff --git a/src/nebulagraph_python/client/client_pool_factory.py b/src/nebulagraph_python/client/client_pool_factory.py index b6fc7ebd..b18d3b10 100644 --- a/src/nebulagraph_python/client/client_pool_factory.py +++ b/src/nebulagraph_python/client/client_pool_factory.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""ClientPoolFactory matching Java implementation""" import logging import time diff --git a/src/nebulagraph_python/client/constants.py b/src/nebulagraph_python/client/constants.py index 1f3b45d1..28a5c32c 100644 --- a/src/nebulagraph_python/client/constants.py +++ b/src/nebulagraph_python/client/constants.py @@ -12,9 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Constants for NebulaGraph client, matching Java implementation""" -# New constants matching Java implementation DEFAULT_MAX_CLIENT_SIZE: int = 10 DEFAULT_MIN_CLIENT_SIZE: int = 1 DEFAULT_CONNECT_TIMEOUT_MS: int = 3 * 1000 # 3 seconds @@ -36,7 +34,7 @@ DEFAULT_DISABLE_VERIFY_SERVER_CERT: bool = False DEFAULT_TLS_PEER_NAME_VERIFY: bool = True -# Backward compatibility constants (old API) +# old default config DEFAULT_SESSION_POOL_SIZE: int = 10 DEFAULT_SESSION_POOL_WAIT_TIMEOUT: float = 0.0 DEFAULT_MAX_CLIENT_SIZE_OLD: int = 10 @@ -45,4 +43,4 @@ DEFAULT_STRICTLY_SERVER_HEALTHY: bool = False DEFAULT_MAX_WAIT: float = 5.0 DEFAULT_CONNECT_TIMEOUT: float = 3.0 -DEFAULT_REQUEST_TIMEOUT: float = 60.0 \ No newline at end of file +DEFAULT_REQUEST_TIMEOUT: float = 60.0 diff --git a/src/nebulagraph_python/client/nebula_client.py b/src/nebulagraph_python/client/nebula_client.py index 28cb17f6..91e9ba7b 100644 --- a/src/nebulagraph_python/client/nebula_client.py +++ b/src/nebulagraph_python/client/nebula_client.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""NebulaClient implementation matching Java NebulaClient""" import asyncio import logging @@ -419,4 +418,4 @@ def _validate_address(addresses: str) -> List[HostAddress]: result.append(HostAddress(host, int(port))) else: raise ValueError(f"Invalid address format: {addr}") - return result \ No newline at end of file + return result diff --git a/src/nebulagraph_python/client/round_robin_load_balancer.py b/src/nebulagraph_python/client/round_robin_load_balancer.py index 9637a0ba..bb4db5ed 100644 --- a/src/nebulagraph_python/client/round_robin_load_balancer.py +++ b/src/nebulagraph_python/client/round_robin_load_balancer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""RoundRobinLoadBalancer matching Java implementation""" import logging from typing import TYPE_CHECKING, Dict, List @@ -107,4 +106,4 @@ def check_servers(self) -> None: if last_auth_e is not None: raise last_auth_e if last_io_e is not None: - raise last_io_e \ No newline at end of file + raise last_io_e