Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions python/python/knowledge_graph/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from .service import LanceKnowledgeGraph, create_default_service
from .store import LanceGraphStore
from .webservice import create_app
from lance_graph import VectorSearch, DistanceMetric

TableMapping = Mapping[str, pa.Table]

Expand Down Expand Up @@ -66,6 +67,23 @@ def run(
)
return query.execute(sources)

def run_with_vector_rerank(
    self,
    statement: str,
    vector_search: "VectorSearch",
    *,
    datasets: Optional[TableMapping] = None,
) -> pa.Table:
    """Run a Cypher statement, then rerank its rows by vector similarity.

    The registered tables are copied into a fresh dict. When ``datasets`` is
    supplied, each entry is passed through ``_ensure_table`` and layered on
    top of that copy before the query is executed with ``vector_search``.
    """
    compiled = CypherQuery(statement).with_config(self.config)

    merged: Dict[str, pa.Table] = dict(self._tables)
    if datasets:
        # Per-call overrides shadow registered tables that share a name.
        for name, table in datasets.items():
            merged[name] = _ensure_table(name, table)

    return compiled.execute_with_vector_rerank(merged, vector_search)

def tables(self) -> Dict[str, pa.Table]:
    """Return a new dict holding the currently registered datasets.

    Only the mapping is copied; the table objects themselves are shared
    with the internal registry.
    """
    snapshot = dict(self._tables)
    return snapshot
Expand Down Expand Up @@ -129,4 +147,6 @@ def build(self) -> KnowledgeGraph:
"preview_extraction",
"HeuristicExtractor",
"LLMExtractor",
"VectorSearch",
"DistanceMetric",
]
108 changes: 106 additions & 2 deletions python/python/knowledge_graph/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@

from __future__ import annotations

from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional

import pyarrow as pa
import yaml
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from pydantic import BaseModel, Field

from .config import KnowledgeGraphConfig
from .service import LanceKnowledgeGraph
Expand All @@ -23,6 +23,45 @@ class QueryResponse(BaseModel):
row_count: int


class VectorQueryRequest(BaseModel):
    """Request body for vector-reranked Cypher queries.

    Supply either ``vector`` (raw floats) or ``query_text`` (auto-embedded via OpenAI).
    """

    # What to run, and which column the similarity search targets.
    query: str = Field(default=..., description="Cypher statement to execute.")
    column: str = Field(
        default=..., description="Name of the vector column to search."
    )

    # The query point: callers send exactly one of these two fields. The
    # model itself does not enforce that; the route handler checks it.
    vector: Optional[List[float]] = Field(
        default=None,
        description="Query vector (float list). Mutually exclusive with query_text.",
    )
    query_text: Optional[str] = Field(
        default=None,
        description="Text to embed as query vector. Requires OpenAI API key.",
    )

    # Search tuning knobs, all optional with defaults.
    metric: Literal["cosine", "l2", "dot"] = Field(
        default="cosine", description="Distance metric: cosine | l2 | dot."
    )
    top_k: int = Field(
        default=10, ge=1, le=10000, description="Number of nearest neighbours."
    )
    include_distance: bool = Field(
        default=True, description="Include _distance column in results."
    )
    embedding_model: str = Field(
        default="text-embedding-3-small",
        description="OpenAI embedding model (only used when query_text is provided).",
    )


class VectorQueryResponse(BaseModel):
    # Response body returned by the vector-reranked query endpoint.

    # Result rows, one dict per row keyed by column name.
    rows: List[Dict[str, Any]]
    # Number of entries in ``rows``.
    row_count: int

    # Search parameters echoed back from the request.
    column: str
    metric: str
    top_k: int


class DatasetUpsertRequest(BaseModel):
records: List[Dict[str, Any]]
merge: bool = True
Expand Down Expand Up @@ -98,6 +137,71 @@ async def get_schema() -> Dict[str, Any]:
payload = yaml.safe_load(handle) or {}
return {"path": str(schema_path), "schema": payload}

@self.router.post("/query/vector", response_model=VectorQueryResponse)
async def execute_vector_query(
    request: VectorQueryRequest,
) -> VectorQueryResponse:
    """Execute a Cypher query with vector similarity reranking.

    Supply ``vector`` (raw floats) or ``query_text`` (auto-embedded).

    Responds with HTTP 400 when neither or both inputs are supplied, when
    the supplied input is empty, or when the service raises ``ValueError``;
    responds with HTTP 500 when the service raises ``RuntimeError``.
    """
    has_vector = request.vector is not None
    has_text = request.query_text is not None

    if not (has_vector or has_text):
        raise HTTPException(
            status_code=400,
            detail="Either 'vector' or 'query_text' must be provided.",
        )
    if has_vector and has_text:
        raise HTTPException(
            status_code=400,
            detail="Provide only one of 'vector' or 'query_text', not both.",
        )
    # An empty list or blank string is "not None" but cannot describe a
    # query point, so report it as a client error before doing any work.
    if has_vector and not request.vector:
        raise HTTPException(
            status_code=400,
            detail="'vector' must contain at least one value.",
        )
    if has_text and not request.query_text.strip():
        raise HTTPException(
            status_code=400,
            detail="'query_text' must not be blank.",
        )

    service = self._get_service()

    # NOTE(review): both service calls below are synchronous inside an
    # async handler; if they block (e.g. on a remote embedding request),
    # consider running them in a threadpool -- confirm.
    try:
        if has_text:
            # Text input: the service embeds the text and builds the search.
            result = service.query_by_text(
                request.query,
                request.query_text,
                request.column,
                top_k=request.top_k,
                metric=request.metric,
                include_distance=request.include_distance,
                embedding_model=request.embedding_model,
            )
        else:
            # Vector input: build the VectorSearch here from the raw floats.
            from lance_graph import DistanceMetric, VectorSearch

            metric_by_name = {
                "cosine": DistanceMetric.Cosine,
                "l2": DistanceMetric.L2,
                "dot": DistanceMetric.Dot,
            }
            search = (
                VectorSearch(request.column)
                .query_vector(request.vector)
                .metric(metric_by_name[request.metric])
                .top_k(request.top_k)
                .include_distance(request.include_distance)
            )
            result = service.run_with_vector_rerank(request.query, search)
    except RuntimeError as exc:
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc

    rows = result.to_pylist()
    return VectorQueryResponse(
        rows=rows,
        row_count=len(rows),
        column=request.column,
        metric=request.metric,
        top_k=request.top_k,
    )

def close(self) -> None:
"""Release retained resources."""
self._service = None
Expand Down
85 changes: 84 additions & 1 deletion python/python/knowledge_graph/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from typing import TYPE_CHECKING, Iterable, Mapping, MutableMapping, Optional

from lance_graph import CypherQuery, GraphConfig
from lance_graph import CypherQuery, GraphConfig, VectorSearch, DistanceMetric

from .config import KnowledgeGraphConfig, build_default_graph_config
from .store import LanceGraphStore
Expand Down Expand Up @@ -141,6 +141,89 @@ def query(
"""Alias for :meth:`run` to match the semantic service naming."""
return self.run(statement, datasets=datasets)

def run_with_vector_rerank(
    self,
    statement: str,
    vector_search: "VectorSearch",
    *,
    datasets: Optional[Mapping[str, "pa.Table"]] = None,
) -> "pa.Table":
    """Run a Cypher statement and rerank the result by vector similarity.

    Parameters
    ----------
    statement:
        Cypher query string.
    vector_search:
        A configured ``VectorSearch`` instance (column, vector, metric, top_k).
    datasets:
        Optional tables that override persisted datasets of the same name.
    """
    compiled = CypherQuery(statement).with_config(self._config)

    # Load only the tables named by the statement's node labels and
    # relationship types.
    wanted = {*compiled.node_labels(), *compiled.relationship_types()}
    tables: MutableMapping[str, "pa.Table"] = dict(
        self._store.load_tables(wanted)
    )

    overrides = datasets if datasets else {}
    tables.update(overrides)

    return compiled.execute_with_vector_rerank(tables, vector_search)

def query_by_text(
    self,
    statement: str,
    query_text: str,
    column: str,
    *,
    top_k: int = 10,
    metric: str = "cosine",
    include_distance: bool = True,
    embedding_model: str = "text-embedding-3-small",
    datasets: Optional[Mapping[str, "pa.Table"]] = None,
) -> "pa.Table":
    """Embed ``query_text`` and delegate to :meth:`run_with_vector_rerank`.

    Parameters
    ----------
    statement:
        Cypher query string.
    query_text:
        Natural-language text to embed as the query vector.
    column:
        Name of the vector column in the dataset.
    top_k:
        Number of nearest neighbours to return.
    metric:
        Distance metric: "cosine", "l2", or "dot" (case-insensitive).
    include_distance:
        Whether to include the ``_distance`` column in results.
    embedding_model:
        OpenAI embedding model name.
    datasets:
        Optional override tables.

    Raises
    ------
    ValueError
        If ``metric`` is not one of the supported names.
    RuntimeError
        If the embedding generator returns ``None`` for ``query_text``.
    """
    # Validate the metric first: an unknown name used to fall back to
    # cosine silently, which returns plausibly-ranked but wrong results.
    # Checking before the embedding step also avoids wasted work.
    metric_key = metric.lower()
    supported = ("cosine", "l2", "dot")
    if metric_key not in supported:
        raise ValueError(
            f"Unsupported metric {metric!r}; expected one of: "
            + ", ".join(supported)
            + "."
        )

    from .embeddings import EmbeddingGenerator

    metric_map = {
        "cosine": DistanceMetric.Cosine,
        "l2": DistanceMetric.L2,
        "dot": DistanceMetric.Dot,
    }
    rust_metric = metric_map[metric_key]

    vector = EmbeddingGenerator(model=embedding_model).embed_one(query_text)
    if vector is None:
        raise RuntimeError(f"Failed to generate embedding for text: {query_text!r}")

    vs = (
        VectorSearch(column)
        .query_vector(vector)
        .metric(rust_metric)
        .top_k(top_k)
        .include_distance(include_distance)
    )
    return self.run_with_vector_rerank(statement, vs, datasets=datasets)


def create_default_service(
config: Optional[KnowledgeGraphConfig] = None,
Expand Down