Skip to content

Commit 6316f68

Browse files
authored
Merge pull request #31 from pangpang20/master
GaussDB Compatibility Improvements
2 parents 03883dc + 392bab1 commit 6316f68

25 files changed

Lines changed: 800 additions & 68 deletions

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ __pycache__/
99
/gaussdb_binary/
1010
.vscode
1111
.venv
12+
myenv
13+
activate_dev.ps1
1214
.coverage
1315
htmlcov
1416
.idea

gaussdb/gaussdb/_connection_info.py

Lines changed: 84 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
from . import pq
1313
from ._tz import get_tzinfo
14-
from .conninfo import make_conninfo
1514

1615

1716
class ConnectionInfo:
@@ -72,26 +71,74 @@ def get_parameters(self) -> dict[str, str]:
7271
either from the connection string and parameters passed to
7372
`~Connection.connect()` or from environment variables. The password
7473
is never returned (you can read it using the `password` attribute).
74+
75+
Note:
76+
GaussDB does not support PGconn.info attribute, uses fallback method.
7577
"""
7678
pyenc = self.encoding
7779

78-
# Get the known defaults to avoid reporting them
79-
defaults = {
80-
i.keyword: i.compiled
81-
for i in pq.Conninfo.get_defaults()
82-
if i.compiled is not None
83-
}
84-
# Not returned by the libq. Bug? Bet we're using SSH.
85-
defaults.setdefault(b"channel_binding", b"prefer")
86-
defaults[b"passfile"] = str(Path.home() / ".pgpass").encode()
87-
88-
return {
89-
i.keyword.decode(pyenc): i.val.decode(pyenc)
90-
for i in self.pgconn.info
91-
if i.val is not None
92-
and i.keyword != b"password"
93-
and i.val != defaults.get(i.keyword)
94-
}
80+
# Check if info attribute is supported (GaussDB does not support)
81+
try:
82+
info = self.pgconn.info
83+
if info is None:
84+
return self._get_parameters_fallback()
85+
except (AttributeError, NotImplementedError):
86+
return self._get_parameters_fallback()
87+
88+
# PostgreSQL normal path
89+
try:
90+
# Get the known defaults to avoid reporting them
91+
defaults = {
92+
i.keyword: i.compiled
93+
for i in pq.Conninfo.get_defaults()
94+
if i.compiled is not None
95+
}
96+
# Not returned by the libq. Bug? Bet we're using SSH.
97+
defaults.setdefault(b"channel_binding", b"prefer")
98+
defaults[b"passfile"] = str(Path.home() / ".pgpass").encode()
99+
100+
return {
101+
i.keyword.decode(pyenc): i.val.decode(pyenc)
102+
for i in info
103+
if i.val is not None
104+
and i.keyword != b"password"
105+
and i.val != defaults.get(i.keyword)
106+
}
107+
except Exception:
108+
# Use fallback on error
109+
return self._get_parameters_fallback()
110+
111+
def _get_parameters_fallback(self) -> dict[str, str]:
112+
"""Fallback method for getting connection parameters.
113+
114+
When PGconn.info is not available (e.g., GaussDB),
115+
retrieve basic connection information from other sources.
116+
"""
117+
params = {}
118+
119+
# Get available information from pgconn attributes
120+
if self.pgconn.host:
121+
params["host"] = self.pgconn.host.decode(self.encoding, errors="replace")
122+
123+
if self.pgconn.port:
124+
params["port"] = self.pgconn.port.decode(self.encoding, errors="replace")
125+
126+
if self.pgconn.db:
127+
params["dbname"] = self.pgconn.db.decode(self.encoding, errors="replace")
128+
129+
if self.pgconn.user:
130+
params["user"] = self.pgconn.user.decode(self.encoding, errors="replace")
131+
132+
# Get other available parameters
133+
try:
134+
if hasattr(self.pgconn, "options") and self.pgconn.options:
135+
params["options"] = self.pgconn.options.decode(
136+
self.encoding, errors="replace"
137+
)
138+
except Exception:
139+
pass
140+
141+
return params
95142

96143
@property
97144
def dsn(self) -> str:
@@ -103,7 +150,25 @@ def dsn(self) -> str:
103150
password is never returned (you can read it using the `password`
104151
attribute).
105152
"""
106-
return make_conninfo(**self.get_parameters())
153+
try:
154+
params = self.get_parameters()
155+
except Exception:
156+
params = self._get_parameters_fallback()
157+
158+
if not params:
159+
return ""
160+
161+
# Build DSN string
162+
parts = []
163+
for key, value in params.items():
164+
if key == "password":
165+
continue # Do not include password
166+
# Escape values
167+
if " " in value or "=" in value or "'" in value:
168+
value = "'" + value.replace("'", "\\'") + "'"
169+
parts.append(f"{key}={value}")
170+
171+
return " ".join(parts)
107172

108173
@property
109174
def status(self) -> pq.ConnStatus:

gaussdb/gaussdb/_oids.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,3 +122,55 @@
122122
YEAR_OID = 1038
123123

124124
# autogenerated: end
125+
126+
127+
# =====================================================
128+
# GaussDB OID 别名映射(PostgreSQL OID -> GaussDB 等效 OID 列表)
129+
# =====================================================
130+
131+
GAUSSDB_OID_ALIASES: dict[int, list[int]] = {
132+
# date 类型可能映射到多个 OID
133+
DATE_OID: [DATE_OID, SMALLDATETIME_OID],
134+
# timestamp 类型
135+
TIMESTAMP_OID: [TIMESTAMP_OID, SMALLDATETIME_OID],
136+
TIMESTAMPTZ_OID: [TIMESTAMPTZ_OID],
137+
# 其他类型保持一对一映射
138+
}
139+
140+
141+
def is_compatible_oid(expected_oid: int, actual_oid: int) -> bool:
142+
"""
143+
检查两个 OID 是否兼容
144+
145+
用于 GaussDB 场景下的类型比较,考虑 OID 别名。
146+
147+
Args:
148+
expected_oid: 期望的 OID
149+
actual_oid: 实际的 OID
150+
151+
Returns:
152+
是否兼容
153+
"""
154+
if expected_oid == actual_oid:
155+
return True
156+
157+
# 检查别名映射
158+
aliases = GAUSSDB_OID_ALIASES.get(expected_oid, [expected_oid])
159+
return actual_oid in aliases
160+
161+
162+
def get_oid_name(oid: int) -> str:
163+
"""
164+
获取 OID 对应的类型名称
165+
166+
Args:
167+
oid: 类型 OID
168+
169+
Returns:
170+
类型名称字符串
171+
"""
172+
# 反向查找 OID 常量名
173+
for name, value in globals().items():
174+
if name.endswith("_OID") and value == oid:
175+
return name.replace("_OID", "").lower()
176+
return f"oid_{oid}"

gaussdb/gaussdb/_py_transformer.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,33 @@
3838
PY_TEXT = PyFormat.TEXT
3939

4040

41+
def _is_empty_value(val: Any) -> bool:
42+
"""检测值是否为等效空值(GaussDB 兼容)"""
43+
if val is None:
44+
return True
45+
if isinstance(val, (bytes, str)) and len(val) == 0:
46+
return True
47+
if isinstance(val, (dict, list)) and len(val) == 0:
48+
return True
49+
return False
50+
51+
52+
def _normalize_empty_value(val: Any, normalize_to_none: bool = False) -> Any:
53+
"""
54+
将空值规范化处理
55+
56+
Args:
57+
val: 要处理的值
58+
normalize_to_none: 如果为 True,将空字符串/空字典等转为 None
59+
60+
Returns:
61+
规范化后的值
62+
"""
63+
if normalize_to_none and _is_empty_value(val):
64+
return None
65+
return val
66+
67+
4168
class Transformer(AdaptContext):
4269
"""
4370
An object that can adapt efficiently between Python and GaussDB.

gaussdb/gaussdb/_typeinfo.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from . import sql
1717
from .abc import AdaptContext, Query
1818
from .rows import dict_row
19+
from ._oids import GAUSSDB_OID_ALIASES
1920
from ._compat import TypeAlias, TypeVar
2021
from ._typemod import TypeModifier
2122
from ._encodings import conn_encoding
@@ -209,6 +210,56 @@ def get_precision(self, fmod: int) -> int | None:
209210
def get_scale(self, fmod: int) -> int | None:
210211
return self.typemod.get_scale(fmod)
211212

213+
@classmethod
214+
def fetch_runtime_oid(cls, conn: Any, typename: str) -> int | None:
215+
"""
216+
运行时获取类型 OID
217+
218+
从数据库查询正确的 OID,处理 GaussDB 差异。
219+
220+
Args:
221+
conn: 数据库连接
222+
typename: 类型名称
223+
224+
Returns:
225+
类型 OID,查询失败返回 None
226+
"""
227+
try:
228+
from .connection import Connection
229+
230+
if isinstance(conn, Connection):
231+
result = conn.execute(
232+
"SELECT oid FROM pg_type WHERE typname = %s", [typename]
233+
).fetchone()
234+
else:
235+
# AsyncConnection
236+
import asyncio
237+
238+
async def _fetch():
239+
result = await conn.execute(
240+
"SELECT oid FROM pg_type WHERE typname = %s", [typename]
241+
)
242+
return await result.fetchone()
243+
244+
result = asyncio.run(_fetch())
245+
246+
return result[0] if result else None
247+
except Exception:
248+
return None
249+
250+
@classmethod
251+
def get_compatible_oids(cls, base_oid: int) -> list[int]:
252+
"""
253+
获取兼容的 OID 列表
254+
255+
Args:
256+
base_oid: 基础 OID
257+
258+
Returns:
259+
包含基础 OID 及其别名的列表
260+
"""
261+
return GAUSSDB_OID_ALIASES.get(base_oid, [base_oid])
262+
212263

213264
class TypesRegistry:
214265
"""

gaussdb/gaussdb/raw_cursor.py

Lines changed: 49 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from typing import TYPE_CHECKING
1010

11+
from . import errors as e
1112
from .abc import ConnectionType, Params, Query
1213
from .sql import Composable
1314
from .rows import Row
@@ -26,6 +27,19 @@
2627

2728

2829
class GaussDBRawQuery(GaussDBQuery):
30+
"""
31+
GaussDB raw query class.
32+
33+
Only supports positional placeholders ($1, $2, ...), not named placeholders.
34+
"""
35+
36+
# Query cache size
37+
_CACHE_SIZE = 128
38+
39+
def __init__(self, *args, **kwargs):
40+
super().__init__(*args, **kwargs)
41+
self._query_cache: dict[bytes, bytes] = {}
42+
2943
def convert(self, query: Query, vars: Params | None) -> None:
3044
if isinstance(query, str):
3145
bquery = query.encode(self._encoding)
@@ -34,14 +48,43 @@ def convert(self, query: Query, vars: Params | None) -> None:
3448
else:
3549
bquery = query
3650

37-
self.query = bquery
51+
# Try to get from cache
52+
if bquery in self._query_cache:
53+
self.query = self._query_cache[bquery]
54+
else:
55+
# Validate query doesn't contain named placeholders
56+
if b"%(" in bquery:
57+
raise e.ProgrammingError(
58+
"RawCursor does not support named placeholders (%(name)s). "
59+
"Use positional placeholders ($1, $2, ...) instead."
60+
)
61+
62+
self.query = bquery
63+
64+
# Cache result
65+
if len(self._query_cache) < self._CACHE_SIZE:
66+
self._query_cache[bquery] = bquery
67+
3868
self._want_formats = self._order = None
3969
self.dump(vars)
4070

4171
def dump(self, vars: Params | None) -> None:
72+
"""
73+
Serialize parameters.
74+
75+
Args:
76+
vars: Parameter sequence (must be sequence, not dict)
77+
78+
Raises:
79+
TypeError: If parameters are not a sequence
80+
"""
4281
if vars is not None:
4382
if not GaussDBQuery.is_params_sequence(vars):
44-
raise TypeError("raw queries require a sequence of parameters")
83+
raise TypeError(
84+
"RawCursor requires a sequence of parameters (tuple or list), "
85+
f"got {type(vars).__name__}. "
86+
"For named parameters, use regular Cursor instead."
87+
)
4588
self._want_formats = [PyFormat.AUTO] * len(vars)
4689

4790
self.params = self._tx.dump_sequence(vars, self._want_formats)
@@ -52,6 +95,10 @@ def dump(self, vars: Params | None) -> None:
5295
self.types = ()
5396
self.formats = None
5497

98+
def clear_cache(self) -> None:
99+
"""Clear query cache."""
100+
self._query_cache.clear()
101+
55102

56103
class RawCursorMixin(BaseCursor[ConnectionType, Row]):
57104
_query_cls = GaussDBRawQuery

gaussdb/gaussdb/types/array.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -471,3 +471,23 @@ def _load_binary(data: Buffer, tx: Transformer) -> list[Any]:
471471
out = [out[i : i + dim] for i in range(0, len(out), dim)]
472472

473473
return out
474+
475+
476+
def array_equals_unordered(arr1: list[Any], arr2: list[Any]) -> bool:
477+
"""
478+
Compare two arrays without considering element order.
479+
480+
Used for GaussDB compatibility scenarios where array element order may differ.
481+
"""
482+
if arr1 is None and arr2 is None:
483+
return True
484+
if arr1 is None or arr2 is None:
485+
return False
486+
if len(arr1) != len(arr2):
487+
return False
488+
489+
try:
490+
return sorted(arr1) == sorted(arr2)
491+
except TypeError:
492+
# Elements not sortable, fall back to set comparison
493+
return set(map(str, arr1)) == set(map(str, arr2))

0 commit comments

Comments
 (0)