diff --git a/tidb_vector/integrations/vector_client.py b/tidb_vector/integrations/vector_client.py index 174c2a9..cd6b846 100644 --- a/tidb_vector/integrations/vector_client.py +++ b/tidb_vector/integrations/vector_client.py @@ -4,7 +4,7 @@ import logging import enum import uuid -from typing import Type, Tuple, Any, Dict, Generator, Iterable, List, Optional +from typing import Sequence, Type, Tuple, Any, Dict, Generator, Iterable, List, Optional import sqlalchemy from sqlalchemy.orm import Session, declarative_base @@ -73,6 +73,7 @@ class QueryResult: document: str metadata: dict distance: float + embedding: Sequence[float] class TiDBVectorClient: @@ -303,6 +304,7 @@ def query( metadata=doc.meta, id=doc.id, distance=doc.distance, + embedding=doc.embedding, ) for doc in relevant_docs ] @@ -326,6 +328,7 @@ def _vector_search( self._table_model.id, self._table_model.meta, self._table_model.document, + self._table_model.embedding, self.distance_strategy(query_embedding).label("distance"), ) .filter(filter_by) @@ -342,6 +345,7 @@ def _vector_search( self._table_model.id, self._table_model.meta, self._table_model.document, + self._table_model.embedding, self.distance_strategy(query_embedding).label("distance"), ) .order_by(sqlalchemy.asc("distance")) @@ -354,6 +358,7 @@ def _vector_search( subquery.c.id, subquery.c.meta, subquery.c.document, + subquery.c.embedding, subquery.c.distance, ) .filter(filter_by)