diff --git a/python/python/knowledge_graph/__init__.py b/python/python/knowledge_graph/__init__.py index c1669836..be49add4 100644 --- a/python/python/knowledge_graph/__init__.py +++ b/python/python/knowledge_graph/__init__.py @@ -25,6 +25,7 @@ from .service import LanceKnowledgeGraph, create_default_service from .store import LanceGraphStore from .webservice import create_app +from lance_graph import VectorSearch, DistanceMetric TableMapping = Mapping[str, pa.Table] @@ -66,6 +67,23 @@ def run( ) return query.execute(sources) + def run_with_vector_rerank( + self, + statement: str, + vector_search: "VectorSearch", + *, + datasets: Optional[TableMapping] = None, + ) -> pa.Table: + """Execute a Cypher statement and rerank results by vector similarity.""" + + query = CypherQuery(statement).with_config(self.config) + sources: Dict[str, pa.Table] = dict(self._tables) + if datasets: + sources.update( + {name: _ensure_table(name, table) for name, table in datasets.items()} + ) + return query.execute_with_vector_rerank(sources, vector_search) + def tables(self) -> Dict[str, pa.Table]: """Return a shallow copy of the registered datasets.""" return dict(self._tables) @@ -129,4 +147,6 @@ def build(self) -> KnowledgeGraph: "preview_extraction", "HeuristicExtractor", "LLMExtractor", + "VectorSearch", + "DistanceMetric", ] diff --git a/python/python/knowledge_graph/component.py b/python/python/knowledge_graph/component.py index ecc1460c..cf87051d 100644 --- a/python/python/knowledge_graph/component.py +++ b/python/python/knowledge_graph/component.py @@ -2,12 +2,12 @@ from __future__ import annotations -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional import pyarrow as pa import yaml from fastapi import APIRouter, HTTPException -from pydantic import BaseModel +from pydantic import BaseModel, Field from .config import KnowledgeGraphConfig from .service import LanceKnowledgeGraph @@ -23,6 +23,45 @@ class QueryResponse(BaseModel): row_count: int +class VectorQueryRequest(BaseModel): + """Request body for vector-reranked Cypher queries. + + Supply either ``vector`` (raw floats) or ``query_text`` (auto-embedded via OpenAI). + """ + + query: str = Field(..., description="Cypher statement to execute.") + column: str = Field(..., description="Name of the vector column to search.") + + # Choose one: pass the vector directly or pass the text and let the server automatically embed it + vector: Optional[List[float]] = Field( + None, + description="Query vector (float list). Mutually exclusive with query_text.", + ) + query_text: Optional[str] = Field( + None, description="Text to embed as query vector. Requires OpenAI API key." + ) + + metric: Literal["cosine", "l2", "dot"] = Field( + "cosine", description="Distance metric: cosine | l2 | dot." + ) + top_k: int = Field(10, ge=1, le=10000, description="Number of nearest neighbours.") + include_distance: bool = Field( + True, description="Include _distance column in results." + ) + embedding_model: str = Field( + "text-embedding-3-small", + description="OpenAI embedding model (only used when query_text is provided).", + ) + + +class VectorQueryResponse(BaseModel): + rows: List[Dict[str, Any]] + row_count: int + column: str + metric: str + top_k: int + + class DatasetUpsertRequest(BaseModel): records: List[Dict[str, Any]] merge: bool = True @@ -98,6 +137,71 @@ async def get_schema() -> Dict[str, Any]: payload = yaml.safe_load(handle) or {} return {"path": str(schema_path), "schema": payload} + @self.router.post("/query/vector", response_model=VectorQueryResponse) + async def execute_vector_query( + request: VectorQueryRequest, + ) -> VectorQueryResponse: + """Execute a Cypher query with vector similarity reranking. + + Supply ``vector`` (raw floats) or ``query_text`` (auto-embedded). + """ + if request.vector is None and request.query_text is None: + raise HTTPException( + status_code=400, + detail="Either 'vector' or 'query_text' must be provided.", + ) + if request.vector is not None and request.query_text is not None: + raise HTTPException( + status_code=400, + detail="Provide only one of 'vector' or 'query_text', not both.", + ) + + service = self._get_service() + + try: + if request.query_text is not None: + # Text: service internally calls EmbeddingGenerator + result = service.query_by_text( + request.query, + request.query_text, + request.column, + top_k=request.top_k, + metric=request.metric, + include_distance=request.include_distance, + embedding_model=request.embedding_model, + ) + else: + # Vector: Constructing VectorSearch directly + from lance_graph import DistanceMetric, VectorSearch + + _metric_map = { + "cosine": DistanceMetric.Cosine, + "l2": DistanceMetric.L2, + "dot": DistanceMetric.Dot, + } + vs = ( + VectorSearch(request.column) + .query_vector(request.vector) + .metric(_metric_map[request.metric]) + .top_k(request.top_k) + .include_distance(request.include_distance) + ) + result = service.run_with_vector_rerank(request.query, vs) + + except RuntimeError as exc: + raise HTTPException(status_code=500, detail=str(exc)) from exc + except ValueError as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + + rows = result.to_pylist() + return VectorQueryResponse( + rows=rows, + row_count=len(rows), + column=request.column, + metric=request.metric, + top_k=request.top_k, + ) + def close(self) -> None: """Release retained resources.""" self._service = None diff --git a/python/python/knowledge_graph/service.py b/python/python/knowledge_graph/service.py index a0e53f63..b78b0a19 100644 --- a/python/python/knowledge_graph/service.py +++ b/python/python/knowledge_graph/service.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Iterable, Mapping, MutableMapping, Optional -from lance_graph import CypherQuery, GraphConfig +from lance_graph import CypherQuery, GraphConfig, VectorSearch, DistanceMetric from .config import KnowledgeGraphConfig, build_default_graph_config from .store import LanceGraphStore @@ -141,6 +141,89 @@ def query( """Alias for :meth:`run` to match the semantic service naming.""" return self.run(statement, datasets=datasets) + def run_with_vector_rerank( + self, + statement: str, + vector_search: "VectorSearch", + *, + datasets: Optional[Mapping[str, "pa.Table"]] = None, + ) -> "pa.Table": + """Execute a Cypher statement and rerank results by vector similarity. + + Parameters + ---------- + statement: + Cypher query string. + vector_search: + A configured ``VectorSearch`` instance (column, vector, metric, top_k). + datasets: + Optional override tables injected on top of persisted datasets. + """ + query = CypherQuery(statement).with_config(self._config) + + referenced_tables = set(query.node_labels()) | set(query.relationship_types()) + base_tables: MutableMapping[str, "pa.Table"] = dict( + self._store.load_tables(referenced_tables) + ) + if datasets: + base_tables.update(datasets) + return query.execute_with_vector_rerank(base_tables, vector_search) + + def query_by_text( + self, + statement: str, + query_text: str, + column: str, + *, + top_k: int = 10, + metric: str = "cosine", + include_distance: bool = True, + embedding_model: str = "text-embedding-3-small", + datasets: Optional[Mapping[str, "pa.Table"]] = None, + ) -> "pa.Table": + """Convenience method: embed ``query_text`` then call run_with_vector_rerank. + + Parameters + ---------- + statement: + Cypher query string. + query_text: + Natural-language text to embed as the query vector. + column: + Name of the vector column in the dataset. + top_k: + Number of nearest neighbours to return. + metric: + Distance metric: "cosine", "l2", or "dot". + include_distance: + Whether to include the ``_distance`` column in results. + embedding_model: + OpenAI embedding model name. + datasets: + Optional override tables. + """ + from .embeddings import EmbeddingGenerator + + _metric_map = { + "cosine": DistanceMetric.Cosine, + "l2": DistanceMetric.L2, + "dot": DistanceMetric.Dot, + } + rust_metric = _metric_map.get(metric.lower(), DistanceMetric.Cosine) + + vector = EmbeddingGenerator(model=embedding_model).embed_one(query_text) + if vector is None: + raise RuntimeError(f"Failed to generate embedding for text: {query_text!r}") + + vs = ( + VectorSearch(column) + .query_vector(vector) + .metric(rust_metric) + .top_k(top_k) + .include_distance(include_distance) + ) + return self.run_with_vector_rerank(statement, vs, datasets=datasets) + def create_default_service( config: Optional[KnowledgeGraphConfig] = None,