diff --git a/src/nebulagraph_python/client/__init__.py b/src/nebulagraph_python/client/__init__.py index 966de877..9a6992dc 100644 --- a/src/nebulagraph_python/client/__init__.py +++ b/src/nebulagraph_python/client/__init__.py @@ -15,8 +15,8 @@ from nebulagraph_python.client._connection import ( AsyncConnection, ConnectionConfig, - _parse_hosts, ) +from nebulagraph_python.client.address_utils import parse_address, parse_hosts from nebulagraph_python.client.base_executor import ( NebulaBaseAsyncExecutor, NebulaBaseExecutor, @@ -35,5 +35,6 @@ "NebulaPool", "NebulaPoolConfig", "unwrap_value", - "_parse_hosts", + "parse_address", + "parse_hosts", ] diff --git a/src/nebulagraph_python/client/_connection.py b/src/nebulagraph_python/client/_connection.py index 42dd3380..cb430c86 100644 --- a/src/nebulagraph_python/client/_connection.py +++ b/src/nebulagraph_python/client/_connection.py @@ -27,7 +27,7 @@ graph_pb2, graph_pb2_grpc, ) - +from nebulagraph_python.client.address_utils import parse_hosts from nebulagraph_python.client.auth_result import AuthResult from nebulagraph_python.client.constants import DEFAULT_CONNECT_TIMEOUT_MS, DEFAULT_REQUEST_TIMEOUT_MS from nebulagraph_python.data import HostAddress, SSLParam @@ -63,7 +63,7 @@ def from_defaults( if ssl_param is True: ssl_param = SSLParam() return cls( - hosts=_parse_hosts(hosts), + hosts=parse_hosts(hosts), ssl_param=ssl_param, connect_timeout=connect_timeout, request_timeout=request_timeout, @@ -74,21 +74,6 @@ def __post_init__(self): raise ValueError("hosts cannot be empty") -def _parse_hosts(hosts: Union[str, List[str], List[HostAddress]]) -> List[HostAddress]: - """Convert various host formats to list of HostAddress objects (backward compatibility)""" - if isinstance(hosts, str): - hosts = hosts.split(",") - - addresses = [] - for host in hosts: - if isinstance(host, HostAddress): - addresses.append(host) - else: - addr, port = host.split(":") - addresses.append(HostAddress(addr, int(port))) - return addresses - 
# Copyright 2025 vesoft-inc
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Address parsing utilities for NebulaGraph clients.

This module provides centralized address parsing functionality supporting
IPv4, IPv6, and hostname formats. It's designed to be used across all
client and pool implementations without creating circular dependencies.
"""

from typing import List, Union

from nebulagraph_python.data import HostAddress

# Largest value a TCP port number can take.
_MAX_PORT = 65535


def _parse_port(port_str: str, address: str) -> int:
    """Convert the port part of *address* to an int.

    Only plain ASCII digits are accepted, so inputs that ``int()`` would
    tolerate but that are never valid ports ("-1", "+9669", "9_669") are
    rejected with the same message as any other malformed port.

    Raises:
        ValueError: If the port is not numeric or lies outside 0-65535.
    """
    port_str = port_str.strip()
    if not (port_str.isascii() and port_str.isdigit()):
        raise ValueError(f"Invalid port number: {port_str} (address: {address})")
    port = int(port_str)
    if port > _MAX_PORT:
        raise ValueError(
            f"Invalid port number: {port_str} (must be 0-{_MAX_PORT}, address: {address})"
        )
    return port


def parse_address(address: str) -> HostAddress:
    """Parse a single address string (IPv4, IPv6, or hostname) into HostAddress.

    Supports:
    - IPv4: "127.0.0.1:9669"
    - IPv6 with brackets: "[2001:db8::1]:9669"
    - IPv6 full: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:9669"
    - IPv6 compressed: "[fe80::1]:9669"
    - IPv6 localhost: "[::1]:9669"
    - IPv6 with zone index: "[fe80::1%eth0]:9669"
    - Hostname: "localhost:9669"

    For input without brackets the LAST colon is taken as the host/port
    separator, so "2001:db8::1:9669" yields host "2001:db8::1". Prefer the
    bracketed form for IPv6; it is the only unambiguous spelling.

    Args:
        address: Address string to parse. Surrounding whitespace is ignored.

    Returns:
        HostAddress object. For bracketed input the host is stored without
        the brackets.

    Raises:
        TypeError: If address is not a string.
        ValueError: If the address is empty, has no port, has an empty host,
            or the port is not an integer in the range 0-65535.

    Examples:
        >>> parse_address("127.0.0.1:9669")
        HostAddress(host='127.0.0.1', port=9669)

        >>> parse_address("[2001:db8::1]:9669")
        HostAddress(host='2001:db8::1', port=9669)

        >>> parse_address("[::1]:9669")
        HostAddress(host='::1', port=9669)
    """
    if not isinstance(address, str):
        raise TypeError(f"address must be a str, got {type(address).__name__}")

    address = address.strip()
    if not address:
        raise ValueError("Invalid address format: address is empty")

    # Bracketed IPv6 literal: [2001:db8::1]:9669
    if address.startswith("["):
        host, closing, rest = address[1:].partition("]")
        if not closing:
            raise ValueError(f"Invalid IPv6 address format: {address}")
        if not rest:
            raise ValueError(f"Port number missing for IPv6 address: {address}")
        if not rest.startswith(":"):
            raise ValueError(f"Invalid IPv6 address format, expected ':' after ']': {address}")
        port_str = rest[1:]
        if not port_str:
            raise ValueError(f"Port number missing after ':': {address}")
        if not host:
            raise ValueError(f"Invalid IPv6 address format, host is empty: {address}")
        return HostAddress(host, _parse_port(port_str, address))

    # IPv4 or hostname: 127.0.0.1:9669 or localhost:9669.
    # rpartition splits on the last colon, which keeps the historical
    # behaviour for unbracketed input that contains several colons.
    host, separator, port_str = address.rpartition(":")
    if not separator:
        raise ValueError(f"Invalid address format: {address}")
    if not host:
        raise ValueError(f"Invalid address format, host is empty: {address}")
    return HostAddress(host, _parse_port(port_str, address))


def parse_hosts(hosts: Union[str, List[str], List[HostAddress]]) -> List[HostAddress]:
    """Convert various host formats to list of HostAddress objects.

    This is a convenience function that accepts multiple input formats:
    - A single string with comma-separated addresses
    - A list of address strings
    - A list of HostAddress objects (each element is reused as-is)
    - A mixed list of strings and HostAddress objects
    - A single HostAddress object (wrapped in a one-element list)

    Args:
        hosts: Host specification in one of the supported formats

    Returns:
        A new list of HostAddress objects, in input order

    Raises:
        ValueError: If any address format is invalid, including an empty
            segment such as the one produced by a trailing comma

    Examples:
        >>> parse_hosts("127.0.0.1:9669")
        [HostAddress(host='127.0.0.1', port=9669)]

        >>> parse_hosts("127.0.0.1:9669,[2001:db8::1]:9669")
        [HostAddress(host='127.0.0.1', port=9669), HostAddress(host='2001:db8::1', port=9669)]

        >>> parse_hosts(["127.0.0.1:9669", "[::1]:9669"])
        [HostAddress(host='127.0.0.1', port=9669), HostAddress(host='::1', port=9669)]
    """
    # A bare HostAddress is not iterable in any useful way; treat it as a
    # one-element specification instead of failing later with an obscure error.
    if isinstance(hosts, HostAddress):
        return [hosts]
    if isinstance(hosts, str):
        hosts = hosts.split(",")

    return [
        host if isinstance(host, HostAddress) else parse_address(host)
        for host in hosts
    ]


# Export public API
__all__ = ["parse_address", "parse_hosts"]
# Methods of the async executor class; the class header lies outside this excerpt.
async def print_query_result(
    self,
    query: str,
    style: str = "table",
    width: Optional[int] = None,
    min_width: int = 8,
    max_width: Optional[int] = None,
    padding: int = 1,
    collapse_padding: bool = False,
) -> None:
    """Execute a query and print the results in a formatted way using rich.

    This is a best-effort display helper: any exception raised while
    executing the query or rendering the result is caught and reported on
    the console instead of being propagated to the caller.

    Args:
    ----
        query: The nGQL query to execute
        style: Output style - either "table" (default) or "rows"
        width: Fixed width for all columns. If None, width will be auto-calculated
        min_width: Minimum width of columns when using table style
        max_width: Maximum width of columns. If None, no maximum is enforced
        padding: Number of spaces around cell contents in table style
        collapse_padding: Reduce padding when cell contents are too wide

    """
    try:
        result = await self.execute(query)
        result.print(
            style=style,
            width=width,
            min_width=min_width,
            max_width=max_width,
            padding=padding,
            collapse_padding=collapse_padding,
        )
    except Exception as e:
        # rich is imported lazily so it is only needed on the error path.
        from rich.console import Console
        from rich.markup import escape
        from rich.traceback import Traceback

        console = Console()
        # The error text is arbitrary (it may contain "[...]", e.g. a
        # bracketed IPv6 address); escape it so rich does not interpret it
        # as markup and fail or garble the message.
        console.print(f"[bold red]Error executing query:[/bold red] {escape(str(e))}")
        # NOTE(review): `debug_flag` is not defined anywhere in this excerpt;
        # presumably a module-level name in this file. Confirm it exists,
        # otherwise this line raises NameError on the error path.
        if debug_flag:
            console.print(Traceback.from_exception(type(e), e, e.__traceback__))


async def pq(self, query: str, **kwargs) -> None:
    """Print query result using rich.

    Short alias for ``print_query_result``; keyword arguments are forwarded
    unchanged.
    """
    await self.print_query_result(query, **kwargs)
# Methods of NebulaClient; the class header lies outside this excerpt.
def _init_client(self) -> None:
    """Initialize the client connection.

    Shuffles ``self.servers`` in place, then tries each server in turn:
    opens the connection, authenticates, and on success records
    ``session_id``, ``version`` and ``create_time`` (milliseconds since the
    epoch).

    Raises:
        AuthenticatingError: Immediately, without trying further servers.
        Exception: Whatever the last server raised, when every server failed.
    """
    # NOTE(review): `config` is built (and its ssl_param overridden) but is
    # never passed to the connection in this method. It is kept because
    # ConnectionConfig may validate its inputs on construction; confirm
    # whether it should be handed to GrpcConnection or removed.
    config = ConnectionConfig.from_defaults(
        hosts=self.servers,
        ssl_param=self.enable_tls or self.ssl_param,
        connect_timeout=self.connect_timeout_mills / 1000.0,
        request_timeout=self.request_timeout_mills / 1000.0,
    )
    if self.ssl_param:
        config.ssl_param = self.ssl_param

    self.connection = GrpcConnection()

    random.shuffle(self.servers)
    total = len(self.servers)

    # Walk the shuffled list from the last entry to the first, the same
    # order the previous index countdown produced.
    for attempt, server in enumerate(reversed(self.servers), start=1):
        try:
            self.connection.open(server, self)
            auth_result = self.connection.authenticate(
                self.user_name, self.auth_options
            )
            self.session_id = auth_result.get_session_id()
            self.version = auth_result.get_version()
            self.create_time = int(time.time() * 1000)
            break
        except AuthenticatingError as e:
            # Wrong credentials will not improve on another server.
            logger.error(f"create NebulaClient failed: {e}")
            raise
        except Exception as e:
            if attempt == total:
                logger.error(f"create NebulaClient failed: {e}")
                raise
            # Previously swallowed silently; record the failover so an
            # unreachable server is visible in the logs.
            logger.warning(f"connect to {server} failed, trying next server: {e}")


def get_host(self) -> str:
    """Get the connected host address.

    Returns:
        ``str()`` of the connection's ``server_addr``, or an empty string
        when there is no connection.
    """
    if self.connection:
        return str(self.connection.server_addr)
    return ""
Optional[AuthResult] = None - self.connection = GrpcConnection() - - try_connect_times = len(self.servers) - random.shuffle(self.servers) - - while try_connect_times > 0: - try_connect_times -= 1 - try: - self.connection.open(self.servers[try_connect_times], self) - auth_result = self.connection.authenticate( - self.user_name, self.auth_options - ) - self.session_id = auth_result.get_session_id() - self.version = auth_result.get_version() - self.create_time = int(time.time() * 1000) - break - except AuthenticatingError as e: - logger.error(f"create NebulaClient failed: {e}") - raise - except Exception as e: - if try_connect_times == 0: - logger.error(f"create NebulaClient failed: {e}") - raise - - @staticmethod - def _validate_address(addresses: str) -> List[HostAddress]: - """Validate and parse addresses""" - result = [] - if isinstance(addresses, str): - for addr in addresses.split(","): - addr = addr.strip() - if ":" in addr: - host, port = addr.rsplit(":", 1) - result.append(HostAddress(host, int(port))) - else: - raise ValueError(f"Invalid address format: {addr}") - return result - class AsyncNebulaClient(NebulaBaseAsyncExecutor): """Async client to connect to NebulaGraph, matching Java NebulaClient with async support""" @@ -261,7 +259,9 @@ def __init__( ssl_param: SSL parameters auth_options: Additional authentication options """ - self.servers: List[HostAddress] = self._validate_address(addresses) + # Parse addresses using centralized address parser from address_utils + from nebulagraph_python.client.address_utils import parse_hosts + self.servers: List[HostAddress] = parse_hosts(addresses) self.user_name: str = user_name self.password: Optional[str] = password self.auth_options: Dict[str, object] = auth_options or {} @@ -405,17 +405,3 @@ async def _init_client(self) -> None: if try_connect_times == 0: logger.error(f"create AsyncNebulaClient failed: {e}") raise - - @staticmethod - def _validate_address(addresses: str) -> List[HostAddress]: - """Validate and 
parse addresses""" - result = [] - if isinstance(addresses, str): - for addr in addresses.split(","): - addr = addr.strip() - if ":" in addr: - host, port = addr.rsplit(":", 1) - result.append(HostAddress(host, int(port))) - else: - raise ValueError(f"Invalid address format: {addr}") - return result diff --git a/src/nebulagraph_python/client/nebula_pool.py b/src/nebulagraph_python/client/nebula_pool.py index 337bfa48..9f3f47bc 100644 --- a/src/nebulagraph_python/client/nebula_pool.py +++ b/src/nebulagraph_python/client/nebula_pool.py @@ -20,6 +20,7 @@ from dataclasses import dataclass, field from typing import Dict, List, Optional, TYPE_CHECKING +from nebulagraph_python.client.address_utils import parse_hosts from nebulagraph_python.client.client_pool_factory import ClientPoolFactory from nebulagraph_python.client.constants import ( DEFAULT_BLOCK_WHEN_EXHAUSTED, @@ -119,8 +120,7 @@ def __init__(self, config: NebulaPoolConfig): def _init_pool(self) -> None: """Initialize the connection pool""" - # Parse addresses - addresses = self._parse_addresses(self.config.addresses) + addresses = parse_hosts(self.config.addresses) # Create load balancer config class LoadBalancerConfig: @@ -175,18 +175,7 @@ def __init__(self, pool_config: NebulaPoolConfig, addrs: List[HostAddress]): except Exception as e: logger.warning(f"Failed to create initial client: {e}") - @staticmethod - def _parse_addresses(addresses: str) -> List[HostAddress]: - """Parse address string to HostAddress list""" - result = [] - for addr in addresses.split(","): - addr = addr.strip() - if ":" in addr: - host, port = addr.rsplit(":", 1) - result.append(HostAddress(host, int(port))) - else: - raise ValueError(f"Invalid address format: {addr}") - return result + def get_client(self) -> NebulaClient: """Get a client from the pool""" diff --git a/src/nebulagraph_python/data.py b/src/nebulagraph_python/data.py index 13065fd2..93af9598 100644 --- a/src/nebulagraph_python/data.py +++ b/src/nebulagraph_python/data.py 
@@ -26,6 +26,10 @@ class HostAddress: port: int def __str__(self): + """Return string representation with proper IPv6 formatting""" + # IPv6 addresses contain colons, so wrap them in brackets + if ":" in self.host and not self.host.startswith("["): + return f"[{self.host}]:{self.port}" return f"{self.host}:{self.port}" def __hash__(self): diff --git a/tests/test_connection.py b/tests/test_connection.py index 6ede33d0..aac76ae5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -22,8 +22,8 @@ AsyncConnection, Connection, ConnectionConfig, - _parse_hosts, ) +from nebulagraph_python.client.address_utils import parse_hosts from nebulagraph_python.data import HostAddress, SSLParam from nebulagraph_python.error import ( AuthenticatingError, @@ -35,18 +35,18 @@ class TestParseHosts: - """Test cases for _parse_hosts function""" + """Test cases for parse_hosts function""" def test_parse_single_string_host(self): """Test parsing a single host string""" - hosts = _parse_hosts("127.0.0.1:9669") + hosts = parse_hosts("127.0.0.1:9669") assert len(hosts) == 1 assert hosts[0].host == "127.0.0.1" assert hosts[0].port == 9669 def test_parse_multiple_string_hosts(self): """Test parsing multiple host strings""" - hosts = _parse_hosts("127.0.0.1:9669,127.0.0.2:9669") + hosts = parse_hosts("127.0.0.1:9669,127.0.0.2:9669") assert len(hosts) == 2 assert hosts[0].host == "127.0.0.1" assert hosts[0].port == 9669 @@ -55,7 +55,7 @@ def test_parse_multiple_string_hosts(self): def test_parse_host_address_objects(self): """Test parsing HostAddress objects""" - hosts = _parse_hosts([HostAddress("127.0.0.1", 9669), HostAddress("127.0.0.2", 9670)]) + hosts = parse_hosts([HostAddress("127.0.0.1", 9669), HostAddress("127.0.0.2", 9670)]) assert len(hosts) == 2 assert hosts[0].host == "127.0.0.1" assert hosts[0].port == 9669 @@ -64,7 +64,7 @@ def test_parse_host_address_objects(self): def test_parse_mixed_hosts(self): """Test parsing mixed host formats""" - hosts = 
_parse_hosts(["127.0.0.1:9669", HostAddress("127.0.0.2", 9670)]) + hosts = parse_hosts(["127.0.0.1:9669", HostAddress("127.0.0.2", 9670)]) assert len(hosts) == 2 assert hosts[0].host == "127.0.0.1" assert hosts[0].port == 9669 diff --git a/tests/test_ipv6_address_parsing.py b/tests/test_ipv6_address_parsing.py new file mode 100644 index 00000000..3f0ced1c --- /dev/null +++ b/tests/test_ipv6_address_parsing.py @@ -0,0 +1,145 @@ +# Copyright 2025 vesoft-inc +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import pytest

from nebulagraph_python.client.address_utils import parse_address, parse_hosts
from nebulagraph_python.data import HostAddress


class TestIPv6AddressParsing:
    """Test IPv4, IPv6 and hostname parsing in address_utils."""

    @pytest.mark.parametrize(
        ("address", "host", "port"),
        [
            # IPv4 and hostname
            ("127.0.0.1:9669", "127.0.0.1", 9669),
            ("localhost:9669", "localhost", 9669),
            # Bracketed IPv6: plain, full, loopback, all zeros, compressed
            ("[2001:db8::1]:9669", "2001:db8::1", 9669),
            (
                "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:9669",
                "2001:0db8:85a3:0000:0000:8a2e:0370:7334",
                9669,
            ),
            ("[::1]:9669", "::1", 9669),
            ("[::]:9669", "::", 9669),
            ("[fe80::1]:9669", "fe80::1", 9669),
            ("[2001:db8::1234:5678]:9669", "2001:db8::1234:5678", 9669),
            # Zone index is kept as part of the host
            ("[fe80::1%eth0]:9669", "fe80::1%eth0", 9669),
            # Surrounding whitespace is stripped
            ("  127.0.0.1:9669  ", "127.0.0.1", 9669),
        ],
    )
    def test_parse_address_valid(self, address, host, port):
        """Well-formed addresses yield the expected host and port."""
        result = parse_address(address)
        assert result.host == host
        assert result.port == port

    @pytest.mark.parametrize(
        ("address", "message"),
        [
            ("[2001:db8::1", "Invalid IPv6 address format"),
            ("[2001:db8::1]", "Port number missing"),
            ("[2001:db8::1]:", "Port number missing"),
            ("invalid", "Invalid address format"),
            ("127.0.0.1:abc", "Invalid port number"),
            ("[2001:db8::1]-9669", "expected ':' after"),
        ],
    )
    def test_parse_address_invalid(self, address, message):
        """Malformed addresses raise ValueError with a descriptive message."""
        with pytest.raises(ValueError, match=message):
            parse_address(address)

    def test_parse_ipv6_non_numeric_port(self):
        """A non-numeric port after a bracketed IPv6 host raises ValueError."""
        with pytest.raises(ValueError):
            parse_address("[2001:db8::1]:abc")

    @pytest.mark.parametrize(
        ("hosts", "expected"),
        [
            # Comma-separated string
            (
                "127.0.0.1:9669,[2001:db8::1]:9669,localhost:9770",
                [("127.0.0.1", 9669), ("2001:db8::1", 9669), ("localhost", 9770)],
            ),
            # List of strings
            (
                ["127.0.0.1:9669", "[::1]:9669"],
                [("127.0.0.1", 9669), ("::1", 9669)],
            ),
            # Mixed list of HostAddress objects and strings
            (
                [HostAddress("127.0.0.1", 9669), "[2001:db8::1]:9669"],
                [("127.0.0.1", 9669), ("2001:db8::1", 9669)],
            ),
            # Whitespace around each comma-separated entry is stripped
            (
                " 127.0.0.1:9669 , [2001:db8::1]:9669 , localhost:9770 ",
                [("127.0.0.1", 9669), ("2001:db8::1", 9669), ("localhost", 9770)],
            ),
        ],
    )
    def test_parse_hosts(self, hosts, expected):
        """Every supported host specification parses to the expected pairs."""
        result = parse_hosts(hosts)
        assert [(addr.host, addr.port) for addr in result] == expected