Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 44 additions & 27 deletions runtime/mcp/knowledge_base/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,19 +64,32 @@ async def tool_embed(text: str) -> dict:


def _is_postgres() -> bool:
    """Return True only when the backend is PostgreSQL and currently reachable.

    L2-B hardening: when ``db_url`` is configured for postgres but the driver
    (psycopg / asyncpg) is not installed, building the engine raises
    ``ImportError``. That case is treated as "not postgres" so callers take
    the sqlite fallback branch that already exists in this file.

    Reachability is also checked with a short-timeout probe connection; an
    unreachable database likewise selects the sqlite fallback.

    Returns:
        True if the engine dialect is ``postgresql`` and a probe connection
        succeeds; False otherwise, including on any exception.
    """
    try:
        eng = get_engine()
        if eng.dialect.name != "postgresql":
            return False
        # Imported lazily: a missing psycopg install must land in the except
        # branch below instead of breaking module import.
        import psycopg

        # Lightweight connectivity probe with a short timeout so an
        # unavailable server cannot block the caller for long.
        # NOTE(review): only host/port/dbname/user/password are forwarded from
        # the engine URL; query options (e.g. sslmode) are not -- confirm that
        # is acceptable for every deployment.
        conn = psycopg.connect(
            host=eng.url.host or "localhost",
            port=eng.url.port or 5432,
            dbname=eng.url.database,
            user=eng.url.username,
            password=eng.url.password or "",
            connect_timeout=3,
        )
        conn.close()
        return True
    except ImportError as exc:
        # Driver not installed (ModuleNotFoundError is a subclass): an expected
        # configuration for sqlite-only setups, so keep the log level low.
        logger.debug("postgres driver unavailable, using sqlite fallback: {}", exc)
        return False
    except Exception as exc:
        # Any other failure (connection refused, dialect load error, ...) also
        # falls back, but is logged so a misconfigured database is not hidden.
        logger.warning("postgres not reachable, using sqlite fallback: {}", exc)
        return False


Expand Down Expand Up @@ -126,27 +139,31 @@ async def tool_search_similar(text: str, top_k: int = 5, source_type: str = "cas
vec = await _embed(text)
if not _is_postgres():
return {"error": "search_similar requires Postgres + pgvector (sqlite fallback only supports indexing)"}
with get_engine().connect() as c:
rows = c.execute(
sql_text(
"SELECT id, source_id, model, payload, 1 - (embedding <=> CAST(:v AS vector)) AS similarity "
"FROM embeddings WHERE source_type = :t "
"ORDER BY embedding <=> CAST(:v AS vector) LIMIT :k"
),
{"v": _vec_literal(vec), "t": source_type, "k": top_k},
).mappings().all()
return {
"count": len(rows),
"results": [
{
"id": r["id"],
"source_id": r["source_id"],
"similarity": float(r["similarity"]),
"preview": (r["payload"] or {}).get("text", "")[:200],
}
for r in rows
],
}
try:
with get_engine().connect() as c:
rows = c.execute(
sql_text(
"SELECT id, source_id, model, payload, 1 - (embedding <=> CAST(:v AS vector)) AS similarity "
"FROM embeddings WHERE source_type = :t "
"ORDER BY embedding <=> CAST(:v AS vector) LIMIT :k"
),
{"v": _vec_literal(vec), "t": source_type, "k": top_k},
).mappings().all()
return {
"count": len(rows),
"results": [
{
"id": r["id"],
"source_id": r["source_id"],
"similarity": float(r["similarity"]),
"preview": (r["payload"] or {}).get("text", "")[:200],
}
for r in rows
],
}
except Exception as e:
logger.warning("search_similar DB error: {}", e)
return {"error": f"search_similar unavailable: {e}"}


def build_server():
Expand Down
11 changes: 10 additions & 1 deletion runtime/orchestrator/adapters/scripts.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,16 @@ def run_script(script_filename: str, args: list[str] | None = None, *, timeout:
cmd = [sys.executable, str(script_path), *(args or [])]
logger.info("running script: {}", " ".join(cmd))
start = time.monotonic()
proc = subprocess.run(cmd, capture_output=True, text=True, timeout=timeout, check=False)
proc = subprocess.run(
cmd,
capture_output=True,
text=True,
encoding="utf-8",
errors="replace",
timeout=timeout,
check=False,
env={**__import__("os").environ, "PYTHONIOENCODING": "utf-8"},
)
dur_ms = int((time.monotonic() - start) * 1000)
return ScriptResult(
script=script_filename,
Expand Down
7 changes: 6 additions & 1 deletion runtime/storage/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@
def get_engine():
    """Return the shared SQLAlchemy engine, creating it on first use.

    The engine is built from ``get_settings().db_url`` with
    ``pool_pre_ping=True`` and ``future=True`` and cached in the module-level
    ``_engine`` so later calls reuse it.

    For URLs that start with ``postgresql`` a 5-second ``connect_timeout`` is
    passed through ``connect_args`` so that an unavailable database fails fast
    instead of hanging the caller.
    """
    global _engine
    if _engine is None:
        url = get_settings().db_url
        kw: dict = {"pool_pre_ping": True, "future": True}
        # Short connect timeout for psycopg to avoid hanging on unavailable DB
        if url.startswith("postgresql"):
            kw["connect_args"] = {"connect_timeout": 5}
        # Build the engine exactly once, with the timeout-aware options above.
        _engine = create_engine(url, **kw)
    return _engine


Expand Down
Loading