From b731542a77b09ed8d58acd09a94d46ed367e5b6e Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Fri, 6 Feb 2026 15:16:36 -0800 Subject: [PATCH 1/6] wip: adding max_distance via switch to use AND vector range query --- redisvl/query/aggregate.py | 22 +++++++--- tests/integration/test_aggregation.py | 38 +++++++++++++++++ .../integration/test_redis_cluster_support.py | 2 + tests/integration/test_search_index.py | 1 + tests/unit/test_aggregation_types.py | 41 ++++++++++++++++--- 5 files changed, 94 insertions(+), 10 deletions(-) diff --git a/redisvl/query/aggregate.py b/redisvl/query/aggregate.py index 1f1da2ac..5f877976 100644 --- a/redisvl/query/aggregate.py +++ b/redisvl/query/aggregate.py @@ -24,6 +24,7 @@ class Vector(BaseModel): field_name: str dtype: str = "float32" weight: float = 1.0 + max_distance: float = 2.0 @field_validator("dtype") @classmethod @@ -36,6 +37,15 @@ def validate_dtype(cls, dtype: str) -> str: ) return dtype + @field_validator("max_distance") + @classmethod + def validate_max_distance(cls, max_distance: float) -> float: + if not isinstance(max_distance, float) or isinstance(max_distance, int): + raise ValueError("max_distance must be a value between 0.0 and 2.0") + if max_distance < 0.0 or max_distance > 2.0: + raise ValueError("max_distance must be a value between 0.0 and 2.0") + return max_distance + @model_validator(mode="after") def validate_vector(self) -> Self: """If the vector passed in is an array of float convert it to a byte string.""" @@ -361,21 +371,23 @@ def _build_query_string(self) -> str: # base KNN query range_queries = [] - for i, (vector, field) in enumerate( - [(v.vector, v.field_name) for v in self._vectors] + for i, (vector, field, max_dist) in enumerate( + [(v.vector, v.field_name, v.max_distance) for v in self._vectors] ): range_queries.append( - f"@{field}:[VECTOR_RANGE 2.0 $vector_{i}]=>{{$YIELD_DISTANCE_AS: distance_{i}}}" + f"@{field}:[VECTOR_RANGE {max_dist} $vector_{i}]=>{{$YIELD_DISTANCE_AS: distance_{i}}}" ) - range_query = " | ".join(range_queries) + range_query = " AND ".join(range_queries) filter_expression = self._filter_expression if isinstance(self._filter_expression, FilterExpression): filter_expression = str(self._filter_expression) if filter_expression: - return f"({range_query}) AND ({filter_expression})" + return ( + f"({range_query}) AND ({filter_expression})" + ) else: return f"{range_query}" diff --git a/tests/integration/test_aggregation.py b/tests/integration/test_aggregation.py index 1ee7a2a1..ca66b78f 100644 --- a/tests/integration/test_aggregation.py +++ b/tests/integration/test_aggregation.py @@ -675,6 +675,44 @@ def test_multivector_query_datatypes(index): ) # allow for small floating point error +###### +@pytest.mark.paramatrize(distances_and_results, [[( , ) 5],[( , ) 5], [( , ) 5], [( , ) 5]]) +@pytest.mark.paramatrize(distances, [ ] ) +def test_multivector_query_max_distances(index): + skip_if_redis_version_below(index.client, "7.2.0") + + vector_vals = [[0.1, 0.2, 0.5], [1.2, 0.3, -0.4, 0.7, 0.2]] + vector_fields = ["user_embedding", "image_embedding"] + distances = [1.0947, 0.19] + distances = [2.0, 1.0019] + return_fields = [ + "distance_0", + "distance_1", + "score_0", + "score_1", + "user_embedding", + "image_embedding", + ] + + vectors = [] + for vector, field, distance in zip(vector_vals, vector_fields, distances): + vectors.append(Vector(vector=vector, field_name=field, max_distance=distance)) + + multi_query = MultiVectorQuery( + vectors=vectors, + return_fields=return_fields, + ) + print(multi_query) #### + results = index.query(multi_query) + + # verify we're filtering vectors based on max_distances + for i in range(len(results)): + print(results[i]) + assert float(results[i][f"distance_0"]) <= distances[0] + assert float(results[i][f"distance_1"]) <= distances[1] + assert False + +###### def test_multivector_query_mixed_index(index): # test that we can do multi vector queries on indices with both a 'flat' and 'hnsw' index skip_if_redis_version_below(index.client, "7.2.0") diff --git a/tests/integration/test_redis_cluster_support.py b/tests/integration/test_redis_cluster_support.py index 80b82420..0d18dea3 100644 --- a/tests/integration/test_redis_cluster_support.py +++ b/tests/integration/test_redis_cluster_support.py @@ -89,6 +89,7 @@ def test_search_index_cluster_info(redis_cluster_url): finally: index.delete(drop=True) + @pytest.mark.requires_cluster @pytest.mark.asyncio async def test_async_search_index_cluster_info(redis_cluster_url): @@ -110,6 +111,7 @@ async def test_async_search_index_cluster_info(redis_cluster_url): await index.delete(drop=True) await client.aclose() + @pytest.mark.requires_cluster @pytest.mark.asyncio async def test_async_search_index_client(redis_cluster_url): diff --git a/tests/integration/test_search_index.py b/tests/integration/test_search_index.py index ae64a229..ebfedbe7 100644 --- a/tests/integration/test_search_index.py +++ b/tests/integration/test_search_index.py @@ -304,6 +304,7 @@ def test_search_index_delete(index): assert not index.exists() assert index.name not in convert_bytes(index.client.execute_command("FT._LIST")) + @pytest.mark.parametrize("num_docs", [0, 1, 5, 10, 2042]) def test_search_index_clear(index, num_docs): index.create(overwrite=True, drop=True) diff --git a/tests/unit/test_aggregation_types.py b/tests/unit/test_aggregation_types.py index 503674f9..87780e46 100644 --- a/tests/unit/test_aggregation_types.py +++ b/tests/unit/test_aggregation_types.py @@ -315,6 +315,7 @@ def test_multi_vector_query(): assert multivector_query._vectors[0].field_name == "field_1" assert multivector_query._vectors[0].weight == 1.0 assert multivector_query._vectors[0].dtype == "float32" + assert multivector_query._vectors[0].max_distance == 2.0 assert multivector_query._filter_expression == None assert multivector_query._num_results == 10 assert multivector_query._loadfields == [] @@ -325,10 +326,21 @@ def test_multi_vector_query(): vector_field_names = ["field_1", "field_2", "field_3", "field_4"] weights = [0.2, 0.5, 0.6, 0.1] dtypes = ["float32", "float32", "float32", "float32"] + distances = [2.0, 1.5, 0.4, 0.01] args = [] - for vec, field, weight, dtype in zip(vectors, vector_field_names, weights, dtypes): - args.append(Vector(vector=vec, field_name=field, weight=weight, dtype=dtype)) + for vec, field, weight, dtype, distance in zip( + vectors, vector_field_names, weights, dtypes, distances + ): + args.append( + Vector( + vector=vec, + field_name=field, + weight=weight, + dtype=dtype, + max_distance=distance, + ) + ) multivector_query = MultiVectorQuery(vectors=args) @@ -358,16 +370,28 @@ def test_multi_vector_query_string(): field_2 = "image embedding" weight_1 = 0.2 weight_2 = 0.7 + max_distance_1 = 0.7 + max_distance_2 = 1.8 multi_vector_query = MultiVectorQuery( vectors=[ - Vector(vector=sample_vector_2, field_name=field_1, weight=weight_1), - Vector(vector=sample_vector_3, field_name=field_2, weight=weight_2), + Vector( + vector=sample_vector_2, + field_name=field_1, + weight=weight_1, + max_distance=max_distance_1, + ), + Vector( + vector=sample_vector_3, + field_name=field_2, + weight=weight_2, + max_distance=max_distance_2, + ), ] ) assert ( str(multi_vector_query) - == f"@{field_1}:[VECTOR_RANGE 2.0 $vector_0]=>{{$YIELD_DISTANCE_AS: distance_0}} | @{field_2}:[VECTOR_RANGE 2.0 $vector_1]=>{{$YIELD_DISTANCE_AS: distance_1}} SCORER TFIDF DIALECT 2 APPLY (2 - @distance_0)/2 AS score_0 APPLY (2 - @distance_1)/2 AS score_1 APPLY @score_0 * {weight_1} + @score_1 * {weight_2} AS combined_score SORTBY 2 @combined_score DESC MAX 10" + == f"@{field_1}:[VECTOR_RANGE {max_distance_1} $vector_0]=>{{$YIELD_DISTANCE_AS: distance_0}} | @{field_2}:[VECTOR_RANGE {max_distance_2} $vector_1]=>{{$YIELD_DISTANCE_AS: distance_1}} SCORER TFIDF DIALECT 2 APPLY (2 - @distance_0)/2 AS score_0 APPLY (2 - @distance_1)/2 AS score_1 APPLY @score_0 * {weight_1} + @score_1 * {weight_2} AS combined_score SORTBY 2 @combined_score DESC MAX 10" ) @@ -411,6 +435,13 @@ def test_vector_object_validation(): vec = Vector(vector=sample_vector, field_name="text embedding", dtype=dtype) assert isinstance(vec, Vector) + # max_distance is bounded to [0, 2.0] + for distance in [-0.1, 2.001, 35, -float("inf"), +float("inf")]: + with pytest.raises(ValueError): + vec = Vector( + vector=sample_vector, field_name="text embedding", max_distance=distance + ) + def test_vector_object_handles_byte_conversion(): # test that passing an array of floats gets converted to bytes From 45898dfdcc169842729ddefa84e6daef1f1348c0 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Thu, 19 Feb 2026 14:34:02 -0800 Subject: [PATCH 2/6] adds max_distance as optional multivector query setting --- redisvl/query/aggregate.py | 4 +-- tests/integration/test_aggregation.py | 43 +++++++++++++++++++-------- tests/unit/test_aggregation_types.py | 2 +- 3 files changed, 33 insertions(+), 16 deletions(-) diff --git a/redisvl/query/aggregate.py b/redisvl/query/aggregate.py index 5f877976..f1d5fb5f 100644 --- a/redisvl/query/aggregate.py +++ b/redisvl/query/aggregate.py @@ -385,9 +385,7 @@ def _build_query_string(self) -> str: filter_expression = str(self._filter_expression) if filter_expression: - return ( - f"({range_query}) AND ({filter_expression})" - ) + return f"({range_query}) AND ({filter_expression})" else: return f"{range_query}" diff --git a/tests/integration/test_aggregation.py b/tests/integration/test_aggregation.py index ca66b78f..fe26839b 100644 --- a/tests/integration/test_aggregation.py +++ b/tests/integration/test_aggregation.py @@ -675,16 +675,24 @@ def test_multivector_query_datatypes(index): ) # allow for small floating point error -###### -@pytest.mark.paramatrize(distances_and_results, [[( , ) 5],[( , ) 5], [( , ) 5], [( , ) 5]]) -@pytest.mark.paramatrize(distances, [ ] ) -def test_multivector_query_max_distances(index): +# paramatrized format is ((max_distance_1, max_distance_2), expected_num_results) +@pytest.mark.parametrize( + "distances_and_results", + [ + [(0.2, 0.2), 0], + [(0.9, 0.2), 0], + [(0.35, 0.5), 1], + [(0.2, 0.9), 2], + [(0.3, 1.0), 3], + [(1.3, 1.9), 6], + ], +) +def test_multivector_query_max_distances(index, distances_and_results): skip_if_redis_version_below(index.client, "7.2.0") vector_vals = [[0.1, 0.2, 0.5], [1.2, 0.3, -0.4, 0.7, 0.2]] vector_fields = ["user_embedding", "image_embedding"] - distances = [1.0947, 0.19] - distances = [2.0, 1.0019] + distances, num_results = distances_and_results return_fields = [ "distance_0", "distance_1", @@ -701,18 +709,29 @@ def test_multivector_query_max_distances(index): multi_query = MultiVectorQuery( vectors=vectors, return_fields=return_fields, + num_results=10, ) - print(multi_query) #### results = index.query(multi_query) + # verify we get the right number of total results + assert len(results) == num_results + # verify we're filtering vectors based on max_distances for i in range(len(results)): - print(results[i]) - assert float(results[i][f"distance_0"]) <= distances[0] - assert float(results[i][f"distance_1"]) <= distances[1] - assert False + assert float(results[i]["distance_0"]) <= distances[0] + assert float(results[i]["distance_1"]) <= distances[1] + + # check we're indeed filtering on both distances and not just the lesser of the two + if results: + first_distances = [float(result["distance_0"]) for result in results] + second_distances = [float(result["distance_1"]) for result in results] + + # this test only applies for our specific test case values + assert (max(first_distances) > distances[1]) or ( + max(second_distances) > distances[0] + ) + -###### def test_multivector_query_mixed_index(index): # test that we can do multi vector queries on indices with both a 'flat' and 'hnsw' index skip_if_redis_version_below(index.client, "7.2.0") diff --git a/tests/unit/test_aggregation_types.py b/tests/unit/test_aggregation_types.py index 87780e46..c85cedea 100644 --- a/tests/unit/test_aggregation_types.py +++ b/tests/unit/test_aggregation_types.py @@ -391,7 +391,7 @@ def test_multi_vector_query_string(): assert ( str(multi_vector_query) - == f"@{field_1}:[VECTOR_RANGE {max_distance_1} $vector_0]=>{{$YIELD_DISTANCE_AS: distance_0}} | @{field_2}:[VECTOR_RANGE {max_distance_2} $vector_1]=>{{$YIELD_DISTANCE_AS: distance_1}} SCORER TFIDF DIALECT 2 APPLY (2 - @distance_0)/2 AS score_0 APPLY (2 - @distance_1)/2 AS score_1 APPLY @score_0 * {weight_1} + @score_1 * {weight_2} AS combined_score SORTBY 2 @combined_score DESC MAX 10" + == f"@{field_1}:[VECTOR_RANGE {max_distance_1} $vector_0]=>{{$YIELD_DISTANCE_AS: distance_0}} AND @{field_2}:[VECTOR_RANGE {max_distance_2} $vector_1]=>{{$YIELD_DISTANCE_AS: distance_1}} SCORER TFIDF DIALECT 2 APPLY (2 - @distance_0)/2 AS score_0 APPLY (2 - @distance_1)/2 AS score_1 APPLY @score_0 * {weight_1} + @score_1 * {weight_2} AS combined_score SORTBY 2 @combined_score DESC MAX 10" ) From 6d239f9677885e727ca4ad40ff3c85916c7dc6f8 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek <165097110+justin-cechmanek@users.noreply.github.com> Date: Thu, 19 Feb 2026 15:00:32 -0800 Subject: [PATCH 3/6] simplify max_distance validation Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- redisvl/query/aggregate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redisvl/query/aggregate.py b/redisvl/query/aggregate.py index f1d5fb5f..9c1ef2cd 100644 --- a/redisvl/query/aggregate.py +++ b/redisvl/query/aggregate.py @@ -40,7 +40,7 @@ def validate_dtype(cls, dtype: str) -> str: @field_validator("max_distance") @classmethod def validate_max_distance(cls, max_distance: float) -> float: - if not isinstance(max_distance, float) or isinstance(max_distance, int): + if not isinstance(max_distance, (float, int)): raise ValueError("max_distance must be a value between 0.0 and 2.0") if max_distance < 0.0 or max_distance > 2.0: raise ValueError("max_distance must be a value between 0.0 and 2.0") From 3624f427010a4253aec7b6078eea7d51687ccc57 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Thu, 19 Feb 2026 15:12:52 -0800 Subject: [PATCH 4/6] adds docstring to Vector data class --- redisvl/query/aggregate.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/redisvl/query/aggregate.py b/redisvl/query/aggregate.py index 9c1ef2cd..bc0ec91f 100644 --- a/redisvl/query/aggregate.py +++ b/redisvl/query/aggregate.py @@ -18,6 +18,13 @@ class Vector(BaseModel): """ Simple object containing the necessary arguments to perform a multi vector query. + + Args: + vector: The vector values as a list of floats or bytes + field_name: The name of the vector field to search + dtype: The data type of the vector (default: "float32") + weight: The weight for this vector in the combined score (default: 1.0) + max_distance: The maximum distance for vector range search (default: 2.0, range: [0.0, 2.0]) """ vector: Union[List[float], bytes] From 618c31ded7bb9128535cec0f8b74eac39fa85685 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek <165097110+justin-cechmanek@users.noreply.github.com> Date: Fri, 20 Feb 2026 17:17:04 -0800 Subject: [PATCH 5/6] simplify max_distance validation Co-authored-by: Vishal Bala --- redisvl/query/aggregate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redisvl/query/aggregate.py b/redisvl/query/aggregate.py index bc0ec91f..3b429072 100644 --- a/redisvl/query/aggregate.py +++ b/redisvl/query/aggregate.py @@ -31,7 +31,7 @@ class Vector(BaseModel): field_name: str dtype: str = "float32" weight: float = 1.0 - max_distance: float = 2.0 + max_distance: float = Field(default=2.0, ge=0.0, le=2.0) @field_validator("dtype") @classmethod From f5b47a3579b7d355ee5190ef520095c9b2bb9925 Mon Sep 17 00:00:00 2001 From: Justin Cechmanek Date: Fri, 20 Feb 2026 17:23:02 -0800 Subject: [PATCH 6/6] imports pydantic Field --- redisvl/query/aggregate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/redisvl/query/aggregate.py b/redisvl/query/aggregate.py index 3b429072..d852961a 100644 --- a/redisvl/query/aggregate.py +++ b/redisvl/query/aggregate.py @@ -1,7 +1,7 @@ import warnings from typing import Any, Dict, List, Optional, Set, Union -from pydantic import BaseModel, field_validator, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from redis.commands.search.aggregation import AggregateRequest, Desc from typing_extensions import Self