diff --git a/poetry.lock b/poetry.lock index 3267940..d522198 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4106,7 +4106,7 @@ version = "12.0.0" description = "Python Imaging Library (fork)" optional = false python-versions = ">=3.10" -groups = ["main", "metrics"] +groups = ["main"] files = [ {file = "pillow-12.0.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:3adfb466bbc544b926d50fe8f4a4e6abd8c6bffd28a26177594e6e9b2b76572b"}, {file = "pillow-12.0.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1ac11e8ea4f611c3c0147424eae514028b5e9077dd99ab91e1bd7bc33ff145e1"}, @@ -6059,30 +6059,31 @@ test = ["Cython", "array-api-strict (>=2.3.1)", "asv", "gmpy2", "hypothesis (>=6 [[package]] name = "sentence-transformers" -version = "3.4.1" -description = "State-of-the-Art Text Embeddings" +version = "5.2.0" +description = "Embeddings, Retrieval, and Reranking" optional = false -python-versions = ">=3.9" -groups = ["main", "metrics"] +python-versions = ">=3.10" +groups = ["metrics"] files = [ - {file = "sentence_transformers-3.4.1-py3-none-any.whl", hash = "sha256:e026dc6d56801fd83f74ad29a30263f401b4b522165c19386d8bc10dcca805da"}, - {file = "sentence_transformers-3.4.1.tar.gz", hash = "sha256:68daa57504ff548340e54ff117bd86c1d2f784b21e0fb2689cf3272b8937b24b"}, + {file = "sentence_transformers-5.2.0-py3-none-any.whl", hash = "sha256:aa57180f053687d29b08206766ae7db549be5074f61849def7b17bf0b8025ca2"}, + {file = "sentence_transformers-5.2.0.tar.gz", hash = "sha256:acaeb38717de689f3dab45d5e5a02ebe2f75960a4764ea35fea65f58a4d3019f"}, ] [package.dependencies] huggingface-hub = ">=0.20.0" -Pillow = "*" scikit-learn = "*" scipy = "*" torch = ">=1.11.0" tqdm = "*" -transformers = ">=4.41.0,<5.0.0" +transformers = ">=4.41.0,<6.0.0" +typing_extensions = ">=4.5.0" [package.extras] -dev = ["accelerate (>=0.20.3)", "datasets", "peft", "pre-commit", "pytest", "pytest-cov"] -onnx = ["optimum[onnxruntime] (>=1.23.1)"] -onnx-gpu = ["optimum[onnxruntime-gpu] (>=1.23.1)"] -openvino = ["optimum-intel[openvino] (>=1.20.0)"] +dev = ["Pillow", "accelerate (>=0.20.3)", "datasets", "peft", "pre-commit", "pytest", "pytest-cov"] +image = ["Pillow"] +onnx = ["optimum-onnx[onnxruntime]"] +onnx-gpu = ["optimum-onnx[onnxruntime-gpu]"] +openvino = ["optimum-intel[openvino]"] train = ["accelerate (>=0.20.3)", "datasets"] [[package]] @@ -7679,4 +7680,4 @@ cffi = ["cffi (>=1.17,<2.0) ; platform_python_implementation != \"PyPy\" and pyt [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.13" -content-hash = "4339b0051819bb83b7764e9f27a68c74c390063dd547a94312e7134947286c44" +content-hash = "af1b46c365d37a727c0d91a93ea91e36f302abb522392a7f41599f15893b2ee8" diff --git a/pyproject.toml b/pyproject.toml index 98b58a8..212fd7d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,6 @@ python-multipart = "^0.0.18" qdrant-client = "^1.15.0" requests = "^2.32.4" scikit-learn = "^1.5.1" -sentence-transformers = "^3.4.1" sqlalchemy = "^2.0.35" transformers="^4.50.0" torch = {version = "^2.2.2+cpu", source = "pytorch_cpu"} diff --git a/src/app/services/search.py b/src/app/services/search.py index 5633838..cbdff10 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -11,8 +11,9 @@ from qdrant_client import models as qdrant_models from qdrant_client.http import exceptions as qdrant_exceptions from qdrant_client.http import models as http_models -from sentence_transformers import SentenceTransformer from sklearn.metrics.pairwise import cosine_similarity +import torch +from transformers import AutoModel, AutoTokenizer from src.app.models.collections import Collection from src.app.models.search import ( @@ -139,10 +140,14 @@ def _get_model(self, curr_model: str) -> dict: try: time_start = time.time() # TODO: path should be an env variable - model = SentenceTransformer(f"../models/embedding/{curr_model}/") + model_path = f"../models/embedding/{curr_model}/" + model = AutoModel.from_pretrained(model_path) + tokenizer = AutoTokenizer.from_pretrained(model_path) + model.eval() self.model[curr_model] = { - "max_seq_length": model.get_max_seq_length(), + "max_seq_length": tokenizer.model_max_length, "instance": model, + "tokenizer": tokenizer, } time_end = time.time() @@ -179,6 +184,15 @@ def _split_input_seq_len(self, seq_len: int, input: str) -> list[str]: return inputs + @log_time_and_error_sync + def _compute_embeddings(self, model, tokenizer, inputs: list[str]) -> np.ndarray: + with torch.no_grad(): + tokenized_inputs = tokenizer(inputs, padding=True, truncation=True, return_tensors='pt') + model_output = model(**tokenized_inputs) + embeddings = model_output[0][:, 0] + embeddings = torch.nn.functional.normalize(embeddings, dim=1).numpy() + return embeddings + @log_time_and_error_sync async def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: logger.debug("Creating embeddings model=%s", curr_model) @@ -188,11 +202,11 @@ async def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: seq_len = self.model[curr_model]["max_seq_length"] model = self.model[curr_model]["instance"] + tokenizer = self.model[curr_model]["tokenizer"] inputs = self._split_input_seq_len(seq_len, search_input) try: - embeddings = await run_in_threadpool(model.encode, inputs) - # embeddings = model.encode(sentences=inputs) + embeddings = await run_in_threadpool(self._compute_embeddings, model, tokenizer, inputs) embeddings = np.mean(embeddings, axis=0) except Exception as ex: logger.error("api_error=EMBED_ERROR model=%s", curr_model) @@ -210,8 +224,11 @@ async def simple_search_handler(self, qp: EnhancedSearchQuery): model = await run_in_threadpool( self._get_model, curr_model="granite-embedding-107m-multilingual" ) + model_instance = model["instance"] - embedding = await run_in_threadpool(model_instance.encode, qp.query) + tokenizer = model["tokenizer"] + embedding_input = qp.query if isinstance(qp.query, list) else [qp.query] + embedding = await run_in_threadpool(self._compute_embeddings, model_instance, tokenizer, embedding_input) result = await self.search( collection_info="collection_welearn_mul_granite-embedding-107m-multilingual", embedding=embedding,