From 899d7cd82e0efad782df874940200d44c474879e Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Mon, 16 Mar 2026 14:27:11 -0700 Subject: [PATCH 1/6] feat(query): make max_distance configurable per-vector in MultiVectorQuery Add maxDistance field to Vector class (default 2.0, validated 0.0-2.0) replacing the hardcoded DISTANCE_THRESHOLD constant. Each vector in a multi-vector query can now specify its own distance threshold for VECTOR_RANGE. Also changes the range query join operator from "|" to "AND" to match upstream Python behavior. Port of redis-vl-python commit 61f0f41. --- .../com/redis/vl/query/MultiVectorQuery.java | 15 +- .../main/java/com/redis/vl/query/Vector.java | 26 +++- .../MultiVectorQueryMaxDistanceTest.java | 144 ++++++++++++++++++ .../redis/vl/query/MultiVectorQueryTest.java | 2 +- 4 files changed, 175 insertions(+), 12 deletions(-) create mode 100644 core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceTest.java diff --git a/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java b/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java index e778995..f61ffac 100644 --- a/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java +++ b/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java @@ -69,9 +69,6 @@ @Getter public final class MultiVectorQuery extends AggregationQuery { - /** Distance threshold for VECTOR_RANGE (hardcoded at 2.0 to include all eligible documents) */ - private static final double DISTANCE_THRESHOLD = 2.0; - private final List vectors; private final Filter filterExpression; private final List returnFields; @@ -111,8 +108,8 @@ public static Builder builder() { /** * Build the Redis query string for multi-vector search. * - *

Format: {@code @field1:[VECTOR_RANGE 2.0 $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} | - * @field2:[VECTOR_RANGE 2.0 $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}} + *

Format: {@code @field1:[VECTOR_RANGE max_dist $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} AND + * @field2:[VECTOR_RANGE max_dist $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}} * * @return Query string */ @@ -123,8 +120,8 @@ public String toQueryString() { /** * Build the Redis query string for multi-vector search. * - *

Format: {@code @field1:[VECTOR_RANGE 2.0 $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} | - * @field2:[VECTOR_RANGE 2.0 $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}} + *

Format: {@code @field1:[VECTOR_RANGE max_dist $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} AND + * @field2:[VECTOR_RANGE max_dist $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}} * * @return Query string */ @@ -137,11 +134,11 @@ public String buildQueryString() { String rangeQuery = String.format( "@%s:[VECTOR_RANGE %.1f $vector_%d]=>{$YIELD_DISTANCE_AS: distance_%d}", - v.getFieldName(), DISTANCE_THRESHOLD, i, i); + v.getFieldName(), v.getMaxDistance(), i, i); rangeQueries.add(rangeQuery); } - String baseQuery = String.join(" | ", rangeQueries); + String baseQuery = String.join(" AND ", rangeQueries); // Add filter expression if present if (filterExpression != null) { diff --git a/core/src/main/java/com/redis/vl/query/Vector.java b/core/src/main/java/com/redis/vl/query/Vector.java index 599201e..9304546 100644 --- a/core/src/main/java/com/redis/vl/query/Vector.java +++ b/core/src/main/java/com/redis/vl/query/Vector.java @@ -57,6 +57,7 @@ public final class Vector { private final String fieldName; private final String dtype; private final double weight; + private final double maxDistance; private Vector(Builder builder) { // Validate before modifying state @@ -74,11 +75,16 @@ private Vector(Builder builder) { if (builder.weight <= 0) { throw new IllegalArgumentException("Weight must be positive, got " + builder.weight); } + if (builder.maxDistance < 0.0 || builder.maxDistance > 2.0) { + throw new IllegalArgumentException( + "max_distance must be a value between 0.0 and 2.0, got " + builder.maxDistance); + } this.vector = Arrays.copyOf(builder.vector, builder.vector.length); this.fieldName = builder.fieldName.trim(); this.dtype = builder.dtype; this.weight = builder.weight; + this.maxDistance = builder.maxDistance; } /** @@ -105,6 +111,7 @@ public static class Builder { private String fieldName; private String dtype = "float32"; // Default from Python private double weight = 1.0; // Default from Python + private double maxDistance = 2.0; // Default from Python 
Builder() {} @@ -156,6 +163,19 @@ public Builder weight(double weight) { return this; } + /** + * Set the maximum distance threshold for this vector in multi-vector range queries. + * + *

Must be between 0.0 and 2.0 (inclusive). Default is 2.0. + * + * @param maxDistance Maximum distance threshold + * @return This builder + */ + public Builder maxDistance(double maxDistance) { + this.maxDistance = maxDistance; + return this; + } + /** * Build the Vector instance. * @@ -174,6 +194,7 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; Vector vector1 = (Vector) o; return Double.compare(vector1.weight, weight) == 0 + && Double.compare(vector1.maxDistance, maxDistance) == 0 && Arrays.equals(vector, vector1.vector) && fieldName.equals(vector1.fieldName) && dtype.equals(vector1.dtype); @@ -185,13 +206,14 @@ public int hashCode() { result = 31 * result + fieldName.hashCode(); result = 31 * result + dtype.hashCode(); result = 31 * result + Double.hashCode(weight); + result = 31 * result + Double.hashCode(maxDistance); return result; } @Override public String toString() { return String.format( - "Vector[fieldName=%s, dtype=%s, weight=%.2f, dimensions=%d]", - fieldName, dtype, weight, vector.length); + "Vector[fieldName=%s, dtype=%s, weight=%.2f, maxDistance=%.1f, dimensions=%d]", + fieldName, dtype, weight, maxDistance, vector.length); } } diff --git a/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceTest.java b/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceTest.java new file mode 100644 index 0000000..477d5e6 --- /dev/null +++ b/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceTest.java @@ -0,0 +1,144 @@ +package com.redis.vl.query; + +import static org.assertj.core.api.Assertions.*; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for configurable max_distance in MultiVectorQuery and Vector. + * + *

Ported from Python: commit 61f0f41 - Expose max_distance as an optional setting in multi + * vector queries + * + *

Python reference: tests/unit/test_aggregation_types.py + */ +@DisplayName("MultiVectorQuery max_distance Tests") +class MultiVectorQueryMaxDistanceTest { + + private static final float[] SAMPLE_VECTOR = {0.1f, 0.2f, 0.3f}; + private static final float[] SAMPLE_VECTOR_2 = {0.4f, 0.5f}; + + @Test + @DisplayName("Vector: Should default max_distance to 2.0") + void testVectorDefaultMaxDistance() { + Vector vector = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").build(); + + assertThat(vector.getMaxDistance()).isEqualTo(2.0); + } + + @Test + @DisplayName("Vector: Should accept custom max_distance") + void testVectorCustomMaxDistance() { + Vector vector = + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").maxDistance(0.5).build(); + + assertThat(vector.getMaxDistance()).isEqualTo(0.5); + } + + @Test + @DisplayName("Vector: Should reject max_distance below 0.0") + void testVectorMaxDistanceBelowZero() { + assertThatThrownBy( + () -> + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").maxDistance(-0.1).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("max_distance must be a value between 0.0 and 2.0"); + } + + @Test + @DisplayName("Vector: Should reject max_distance above 2.0") + void testVectorMaxDistanceAboveTwo() { + assertThatThrownBy( + () -> + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").maxDistance(2.1).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("max_distance must be a value between 0.0 and 2.0"); + } + + @Test + @DisplayName("Vector: Should accept max_distance at boundary 0.0") + void testVectorMaxDistanceAtZero() { + Vector vector = + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").maxDistance(0.0).build(); + + assertThat(vector.getMaxDistance()).isEqualTo(0.0); + } + + @Test + @DisplayName("Vector: Should accept max_distance at boundary 2.0") + void testVectorMaxDistanceAtTwo() { + Vector vector = + 
Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").maxDistance(2.0).build(); + + assertThat(vector.getMaxDistance()).isEqualTo(2.0); + } + + @Test + @DisplayName("MultiVectorQuery: Should use per-vector max_distance in query string") + void testQueryStringUsesPerVectorMaxDistance() { + Vector vector1 = + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").maxDistance(0.5).build(); + + Vector vector2 = + Vector.builder().vector(SAMPLE_VECTOR_2).fieldName("field_2").maxDistance(1.0).build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(vector1, vector2).build(); + + String queryString = query.toQueryString(); + + // Each vector should use its own max_distance + assertThat(queryString).contains("@field_1:[VECTOR_RANGE 0.5 $vector_0]"); + assertThat(queryString).contains("@field_2:[VECTOR_RANGE 1.0 $vector_1]"); + } + + @Test + @DisplayName("MultiVectorQuery: Should default to 2.0 when max_distance not set") + void testQueryStringDefaultMaxDistance() { + Vector vector1 = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vector(vector1).build(); + + String queryString = query.toQueryString(); + + assertThat(queryString).contains("VECTOR_RANGE 2.0"); + } + + @Test + @DisplayName("MultiVectorQuery: Should use AND join instead of pipe") + void testQueryStringUsesAndJoin() { + Vector vector1 = + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").maxDistance(0.5).build(); + + Vector vector2 = + Vector.builder().vector(SAMPLE_VECTOR_2).fieldName("field_2").maxDistance(1.0).build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(vector1, vector2).build(); + + String queryString = query.toQueryString(); + + // Python changed from " | " to " AND " in commit 61f0f41 + assertThat(queryString).doesNotContain(" | "); + assertThat(queryString).contains(" AND "); + } + + @Test + @DisplayName("MultiVectorQuery: Query string with mixed default and custom 
max_distance") + void testQueryStringMixedMaxDistance() { + Vector vector1 = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").build(); // default + + Vector vector2 = + Vector.builder() + .vector(SAMPLE_VECTOR_2) + .fieldName("field_2") + .maxDistance(0.3) + .build(); // custom + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(vector1, vector2).build(); + + String queryString = query.toQueryString(); + + assertThat(queryString).contains("@field_1:[VECTOR_RANGE 2.0 $vector_0]"); + assertThat(queryString).contains("@field_2:[VECTOR_RANGE 0.3 $vector_1]"); + } +} diff --git a/core/src/test/java/com/redis/vl/query/MultiVectorQueryTest.java b/core/src/test/java/com/redis/vl/query/MultiVectorQueryTest.java index 6e5ffbb..add1004 100644 --- a/core/src/test/java/com/redis/vl/query/MultiVectorQueryTest.java +++ b/core/src/test/java/com/redis/vl/query/MultiVectorQueryTest.java @@ -257,7 +257,7 @@ void testMultiVectorQueryString() { .contains("{$YIELD_DISTANCE_AS: distance_0}") .contains(String.format("@%s:[VECTOR_RANGE 2.0 $vector_1]", field2)) .contains("{$YIELD_DISTANCE_AS: distance_1}") - .contains(" | "); + .contains(" AND "); } @Test From 09968fd8c301bfdff975c6f95a556b33c460a535 Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Mon, 16 Mar 2026 14:28:12 -0700 Subject: [PATCH 2/6] feat(messagehistory): add ChatRole enum and count() method Add ChatRole enum (USER, ASSISTANT, SYSTEM, TOOL) as single source of truth for valid message roles. The deprecated "llm" role is still accepted with a warning for backward compatibility. Update BaseMessageHistory.validateRoles() to use ChatRole. Add count() and count(sessionTag) methods to MessageHistory and SemanticMessageHistory using CountQuery, allowing callers to get the number of stored messages without fetching them all. Port of redis-vl-python commits 23ecc77 and e7301a2. 
--- .../messagehistory/BaseMessageHistory.java | 56 ++++++++++--- .../extensions/messagehistory/ChatRole.java | 79 +++++++++++++++++ .../messagehistory/MessageHistory.java | 19 +++++ .../SemanticMessageHistory.java | 21 +++++ .../messagehistory/ChatRoleTest.java | 82 ++++++++++++++++++ .../MessageHistoryCountTest.java | 84 +++++++++++++++++++ .../messagehistory/RoleFilteringTest.java | 12 ++- 7 files changed, 341 insertions(+), 12 deletions(-) create mode 100644 core/src/main/java/com/redis/vl/extensions/messagehistory/ChatRole.java create mode 100644 core/src/test/java/com/redis/vl/extensions/messagehistory/ChatRoleTest.java create mode 100644 core/src/test/java/com/redis/vl/extensions/messagehistory/MessageHistoryCountTest.java diff --git a/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java b/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java index 194674e..add4521 100644 --- a/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java +++ b/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java @@ -7,7 +7,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; +import java.util.logging.Logger; /** * Base class for message history implementations. @@ -16,8 +16,7 @@ */ public abstract class BaseMessageHistory { - /** Valid role values for message filtering. */ - private static final Set VALID_ROLES = Set.of("system", "user", "llm", "tool"); + private static final Logger logger = Logger.getLogger(BaseMessageHistory.class.getName()); protected final String name; protected final String sessionTag; @@ -34,6 +33,25 @@ protected BaseMessageHistory(String name, String sessionTag) { this.sessionTag = (sessionTag != null) ? sessionTag : UlidCreator.getUlid().toString(); } + /** + * Count the number of messages in the conversation history. + * + *

Matches Python count() from base_history.py + * + * @return The number of messages for the default session + */ + public abstract long count(); + + /** + * Count the number of messages in the conversation history for a specific session. + * + *

Matches Python count(session_tag=...) from base_history.py + * + * @param sessionTag The session tag to count messages for (null uses default session) + * @return The number of messages + */ + public abstract long count(String sessionTag); + /** Clears the chat message history. */ public abstract void clear(); @@ -143,10 +161,7 @@ protected List validateRoles(Object role) { // Handle single role string if (role instanceof String) { String roleStr = (String) role; - if (!VALID_ROLES.contains(roleStr)) { - throw new IllegalArgumentException( - String.format("Invalid role '%s'. Valid roles are: %s", roleStr, VALID_ROLES)); - } + validateSingleRole(roleStr); return List.of(roleStr); } @@ -166,10 +181,7 @@ protected List validateRoles(Object role) { "role list must contain only strings, found: " + r.getClass().getSimpleName()); } String roleStr = (String) r; - if (!VALID_ROLES.contains(roleStr)) { - throw new IllegalArgumentException( - String.format("Invalid role '%s'. Valid roles are: %s", roleStr, VALID_ROLES)); - } + validateSingleRole(roleStr); validatedRoles.add(roleStr); } @@ -179,6 +191,28 @@ protected List validateRoles(Object role) { throw new IllegalArgumentException("role must be a String, List, or null"); } + /** + * Validate a single role string using ChatRole enum. + * + * @param roleStr The role string to validate + * @throws IllegalArgumentException if the role is not valid + */ + private void validateSingleRole(String roleStr) { + if (ChatRole.isDeprecatedRole(roleStr)) { + logger.warning( + String.format( + "Role '%s' is a deprecated value. Update to valid roles: %s.", + roleStr, java.util.Arrays.toString(ChatRole.values()))); + return; + } + if (!ChatRole.isValidRole(roleStr)) { + throw new IllegalArgumentException( + String.format( + "Invalid role '%s'. 
Valid roles are: %s", + roleStr, java.util.Arrays.toString(ChatRole.values()))); + } + } + public String getName() { return name; } diff --git a/core/src/main/java/com/redis/vl/extensions/messagehistory/ChatRole.java b/core/src/main/java/com/redis/vl/extensions/messagehistory/ChatRole.java new file mode 100644 index 0000000..59351f2 --- /dev/null +++ b/core/src/main/java/com/redis/vl/extensions/messagehistory/ChatRole.java @@ -0,0 +1,79 @@ +package com.redis.vl.extensions.messagehistory; + +import java.util.Set; + +/** + * Enumeration of valid chat message roles. + * + *

Ported from Python: redisvl/extensions/message_history/schema.py (commit 23ecc77) + * + *

This enum serves as the single source of truth for valid roles in chat message history. The + * deprecated "llm" role is accepted for backward compatibility but is not a member of this enum. + */ +public enum ChatRole { + USER("user"), + ASSISTANT("assistant"), + SYSTEM("system"), + TOOL("tool"); + + /** Deprecated role values that are accepted for backward compatibility. */ + private static final Set DEPRECATED_ROLES = Set.of("llm"); + + private final String value; + + ChatRole(String value) { + this.value = value; + } + + /** + * Get the string value of this role. + * + * @return The role string value + */ + public String getValue() { + return value; + } + + /** + * Coerce a string to a ChatRole enum value. + * + * @param role The role string + * @return The matching ChatRole, or null if not a standard role + */ + public static ChatRole fromString(String role) { + if (role == null || role.isEmpty()) { + return null; + } + for (ChatRole chatRole : values()) { + if (chatRole.value.equals(role)) { + return chatRole; + } + } + return null; + } + + /** + * Check if a string is a valid role (including deprecated roles). + * + * @param role The role string to check + * @return true if the role is valid (either standard or deprecated) + */ + public static boolean isValidRole(String role) { + return fromString(role) != null || isDeprecatedRole(role); + } + + /** + * Check if a string is a deprecated role value. 
+ * + * @param role The role string to check + * @return true if the role is deprecated + */ + public static boolean isDeprecatedRole(String role) { + return DEPRECATED_ROLES.contains(role); + } + + @Override + public String toString() { + return value; + } +} diff --git a/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java b/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java index a88750e..4a495a1 100644 --- a/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java +++ b/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java @@ -3,6 +3,7 @@ import static com.redis.vl.extensions.Constants.*; import com.redis.vl.index.SearchIndex; +import com.redis.vl.query.CountQuery; import com.redis.vl.query.Filter; import com.redis.vl.query.FilterQuery; import com.redis.vl.schema.IndexSchema; @@ -64,6 +65,19 @@ public MessageHistory(String name, String sessionTag, String prefix, UnifiedJedi this.defaultSessionFilter = Filter.tag(SESSION_FIELD_NAME, this.sessionTag); } + @Override + public long count() { + return count(null); + } + + @Override + public long count(String sessionTag) { + Filter sessionFilter = + (sessionTag != null) ? 
Filter.tag(SESSION_FIELD_NAME, sessionTag) : defaultSessionFilter; + CountQuery query = new CountQuery(sessionFilter); + return index.count(query); + } + @Override public void clear() { index.clear(); @@ -269,4 +283,9 @@ public void addMessages(List> messages) { public SearchIndex getIndex() { return index; } + + @Override + public String toString() { + return String.format("MessageHistory(name='%s', session_tag='%s')", name, sessionTag); + } } diff --git a/core/src/main/java/com/redis/vl/extensions/messagehistory/SemanticMessageHistory.java b/core/src/main/java/com/redis/vl/extensions/messagehistory/SemanticMessageHistory.java index b75a3f1..4c7c918 100644 --- a/core/src/main/java/com/redis/vl/extensions/messagehistory/SemanticMessageHistory.java +++ b/core/src/main/java/com/redis/vl/extensions/messagehistory/SemanticMessageHistory.java @@ -3,6 +3,7 @@ import static com.redis.vl.extensions.Constants.*; import com.redis.vl.index.SearchIndex; +import com.redis.vl.query.CountQuery; import com.redis.vl.query.Filter; import com.redis.vl.query.FilterQuery; import com.redis.vl.query.VectorQuery; @@ -152,6 +153,19 @@ public SearchIndex getIndex() { return index; } + @Override + public long count() { + return count(null); + } + + @Override + public long count(String sessionTag) { + Filter sessionFilter = + (sessionTag != null) ? Filter.tag(SESSION_FIELD_NAME, sessionTag) : defaultSessionFilter; + CountQuery query = new CountQuery(sessionFilter); + return index.count(query); + } + @Override public void clear() { index.clear(); @@ -490,6 +504,13 @@ private Filter combineWithRoleFilter(Filter sessionFilter, List roles) { } } + @Override + public String toString() { + return String.format( + "SemanticMessageHistory(name='%s', session_tag='%s', distance_threshold=%s)", + name, sessionTag, distanceThreshold); + } + /** * Format messages with metadata deserialization support. 
* diff --git a/core/src/test/java/com/redis/vl/extensions/messagehistory/ChatRoleTest.java b/core/src/test/java/com/redis/vl/extensions/messagehistory/ChatRoleTest.java new file mode 100644 index 0000000..124372a --- /dev/null +++ b/core/src/test/java/com/redis/vl/extensions/messagehistory/ChatRoleTest.java @@ -0,0 +1,82 @@ +package com.redis.vl.extensions.messagehistory; + +import static org.assertj.core.api.Assertions.*; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for ChatRole enum. + * + *

Ported from Python: commit 23ecc77 - Add ChatRole enum for message history role validation + * + *

Python reference: tests/unit/test_message_history_schema.py + */ +@DisplayName("ChatRole Enum Tests") +class ChatRoleTest { + + @Test + @DisplayName("Should have all required role values") + void testAllRolesExist() { + assertThat(ChatRole.values()).hasSize(4); + assertThat(ChatRole.USER.getValue()).isEqualTo("user"); + assertThat(ChatRole.ASSISTANT.getValue()).isEqualTo("assistant"); + assertThat(ChatRole.SYSTEM.getValue()).isEqualTo("system"); + assertThat(ChatRole.TOOL.getValue()).isEqualTo("tool"); + } + + @Test + @DisplayName("Should coerce valid string to ChatRole") + void testCoerceValidString() { + assertThat(ChatRole.fromString("user")).isEqualTo(ChatRole.USER); + assertThat(ChatRole.fromString("assistant")).isEqualTo(ChatRole.ASSISTANT); + assertThat(ChatRole.fromString("system")).isEqualTo(ChatRole.SYSTEM); + assertThat(ChatRole.fromString("tool")).isEqualTo(ChatRole.TOOL); + } + + @Test + @DisplayName("Should accept deprecated 'llm' role with warning") + void testDeprecatedLlmRole() { + // In Python, 'llm' is accepted with a deprecation warning + // In Java, 'llm' is not a ChatRole constant: fromString returns null, isDeprecatedRole is true + assertThat(ChatRole.fromString("llm")).isNull(); + assertThat(ChatRole.isDeprecatedRole("llm")).isTrue(); + } + + @Test + @DisplayName("Should return null for unrecognized role") + void testUnrecognizedRole() { + assertThat(ChatRole.fromString("potato")).isNull(); + assertThat(ChatRole.fromString("admin")).isNull(); + assertThat(ChatRole.fromString("")).isNull(); + } + + @Test + @DisplayName("Should be case-sensitive") + void testCaseSensitive() { + assertThat(ChatRole.fromString("User")).isNull(); + assertThat(ChatRole.fromString("SYSTEM")).isNull(); + assertThat(ChatRole.fromString("TOOL")).isNull(); + } + + @Test + @DisplayName("Should check validity including deprecated roles") + void testIsValidRole() { + assertThat(ChatRole.isValidRole("user")).isTrue(); + assertThat(ChatRole.isValidRole("assistant")).isTrue(); +
assertThat(ChatRole.isValidRole("system")).isTrue(); + assertThat(ChatRole.isValidRole("tool")).isTrue(); + assertThat(ChatRole.isValidRole("llm")).isTrue(); // deprecated but valid + assertThat(ChatRole.isValidRole("potato")).isFalse(); + assertThat(ChatRole.isValidRole("User")).isFalse(); + } + + @Test + @DisplayName("toString should return the string value") + void testToString() { + assertThat(ChatRole.USER.toString()).isEqualTo("user"); + assertThat(ChatRole.ASSISTANT.toString()).isEqualTo("assistant"); + assertThat(ChatRole.SYSTEM.toString()).isEqualTo("system"); + assertThat(ChatRole.TOOL.toString()).isEqualTo("tool"); + } +} diff --git a/core/src/test/java/com/redis/vl/extensions/messagehistory/MessageHistoryCountTest.java b/core/src/test/java/com/redis/vl/extensions/messagehistory/MessageHistoryCountTest.java new file mode 100644 index 0000000..5c1c138 --- /dev/null +++ b/core/src/test/java/com/redis/vl/extensions/messagehistory/MessageHistoryCountTest.java @@ -0,0 +1,84 @@ +package com.redis.vl.extensions.messagehistory; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import redis.clients.jedis.UnifiedJedis; +import redis.clients.jedis.search.FTSearchParams; +import redis.clients.jedis.search.SearchResult; + +/** + * Unit tests for count() method in MessageHistory. 
Ported from Python: + * tests/integration/test_message_history.py (test_standard_count, test_semantic_count) + */ +class MessageHistoryCountTest { + + private UnifiedJedis mockJedis; + + @SuppressWarnings("unchecked") + @BeforeEach + void setUp() { + mockJedis = mock(UnifiedJedis.class); + // Mock ftCreate to not fail + when(mockJedis.ftCreate(anyString(), any(), any(Iterable.class))).thenReturn("OK"); + } + + @Test + void testCountReturnsZeroWhenEmpty() { + // Mock ftSearch for count query to return 0 results + SearchResult emptyResult = mock(SearchResult.class); + when(emptyResult.getTotalResults()).thenReturn(0L); + when(mockJedis.ftSearch(anyString(), anyString(), any(FTSearchParams.class))) + .thenReturn(emptyResult); + + MessageHistory history = new MessageHistory("test_count", mockJedis); + assertEquals(0, history.count()); + } + + @Test + void testCountWithSessionTag() { + SearchResult result = mock(SearchResult.class); + when(result.getTotalResults()).thenReturn(4L); + when(mockJedis.ftSearch(anyString(), anyString(), any(FTSearchParams.class))) + .thenReturn(result); + + MessageHistory history = new MessageHistory("test_count", "session1", null, mockJedis); + assertEquals(4, history.count("session1")); + } + + @Test + void testCountDefaultsToInstanceSessionTag() { + SearchResult result = mock(SearchResult.class); + when(result.getTotalResults()).thenReturn(2L); + when(mockJedis.ftSearch(anyString(), anyString(), any(FTSearchParams.class))) + .thenReturn(result); + + MessageHistory history = new MessageHistory("test_count", "my-session", null, mockJedis); + + // count() with no args should use instance session tag + assertEquals(2, history.count()); + } + + @Test + void testCountUsesCountQuery() { + SearchResult result = mock(SearchResult.class); + when(result.getTotalResults()).thenReturn(3L); + when(mockJedis.ftSearch(anyString(), anyString(), any(FTSearchParams.class))) + .thenReturn(result); + + MessageHistory history = new MessageHistory("test_count", 
"session1", null, mockJedis); + long count = history.count(); + + // Verify ftSearch was called with a filter containing the session tag + verify(mockJedis, atLeastOnce()) + .ftSearch(anyString(), contains("session1"), any(FTSearchParams.class)); + } + + private static String contains(String substring) { + return argThat(arg -> arg != null && arg.contains(substring)); + } +} diff --git a/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java b/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java index 6026b7f..7f8115c 100644 --- a/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java +++ b/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java @@ -62,6 +62,16 @@ public void addMessages(List> messages, String ses @Override public void addMessage(java.util.Map message, String sessionTag) {} + + @Override + public long count() { + return 0; + } + + @Override + public long count(String sessionTag) { + return 0; + } } @Test @@ -124,7 +134,7 @@ void testInvalidSingleRole() { .hasMessageContaining("Invalid role 'invalid_role'") .hasMessageContaining("system") .hasMessageContaining("user") - .hasMessageContaining("llm") + .hasMessageContaining("assistant") .hasMessageContaining("tool"); } From 7b4966d33c9c79b709e2c4ac6b506b899ab3f4ab Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Mon, 16 Mar 2026 14:28:54 -0700 Subject: [PATCH 3/6] feat: add toString() to core public classes Add human-readable toString() implementations to SearchIndex, SemanticCache, LangCacheSemanticCache, SemanticRouter, MessageHistory, and SemanticMessageHistory. Each displays key identifying fields (name, prefix, storage type, thresholds, etc.) without exposing sensitive data. Port of redis-vl-python commit e285fed. 
--- .../cache/LangCacheSemanticCache.java | 5 + .../vl/extensions/cache/SemanticCache.java | 6 + .../vl/extensions/router/SemanticRouter.java | 5 + .../java/com/redis/vl/index/SearchIndex.java | 7 + .../test/java/com/redis/vl/ToStringTest.java | 138 ++++++++++++++++++ 5 files changed, 161 insertions(+) create mode 100644 core/src/test/java/com/redis/vl/ToStringTest.java diff --git a/core/src/main/java/com/redis/vl/extensions/cache/LangCacheSemanticCache.java b/core/src/main/java/com/redis/vl/extensions/cache/LangCacheSemanticCache.java index bb7b802..fdf1124 100644 --- a/core/src/main/java/com/redis/vl/extensions/cache/LangCacheSemanticCache.java +++ b/core/src/main/java/com/redis/vl/extensions/cache/LangCacheSemanticCache.java @@ -596,6 +596,11 @@ public void update(String key, String field, Object value) { + "Delete and re-create the entry instead."); } + @Override + public String toString() { + return String.format("LangCacheSemanticCache(name='%s', ttl=%s)", name, ttl); + } + /** Builder for LangCacheSemanticCache. */ public static class Builder { private String name = "langcache"; diff --git a/core/src/main/java/com/redis/vl/extensions/cache/SemanticCache.java b/core/src/main/java/com/redis/vl/extensions/cache/SemanticCache.java index 1b843fb..abffc37 100644 --- a/core/src/main/java/com/redis/vl/extensions/cache/SemanticCache.java +++ b/core/src/main/java/com/redis/vl/extensions/cache/SemanticCache.java @@ -489,6 +489,12 @@ public void resetStatistics() { missCount.set(0); } + @Override + public String toString() { + return String.format( + "SemanticCache(name='%s', distance_threshold=%s, ttl=%s)", name, distanceThreshold, ttl); + } + /** Builder for SemanticCache. 
*/ public static class Builder { private String name; diff --git a/core/src/main/java/com/redis/vl/extensions/router/SemanticRouter.java b/core/src/main/java/com/redis/vl/extensions/router/SemanticRouter.java index 60001c0..80334a9 100644 --- a/core/src/main/java/com/redis/vl/extensions/router/SemanticRouter.java +++ b/core/src/main/java/com/redis/vl/extensions/router/SemanticRouter.java @@ -148,6 +148,11 @@ public SemanticRouter build() { } } + @Override + public String toString() { + return String.format("SemanticRouter(name='%s', routes=%d)", name, routes.size()); + } + /** * Get the list of route names. Ported from Python: route_names property (line 187) * diff --git a/core/src/main/java/com/redis/vl/index/SearchIndex.java b/core/src/main/java/com/redis/vl/index/SearchIndex.java index 48d6a02..d9594b3 100644 --- a/core/src/main/java/com/redis/vl/index/SearchIndex.java +++ b/core/src/main/java/com/redis/vl/index/SearchIndex.java @@ -1726,6 +1726,13 @@ private List> processHybridResult(HybridResult result) { return processed; } + @Override + public String toString() { + return String.format( + "SearchIndex(name='%s', prefix='%s', storage_type='%s')", + getName(), getPrefix(), getStorageType()); + } + /** * Execute multiple search queries in batch * diff --git a/core/src/test/java/com/redis/vl/ToStringTest.java b/core/src/test/java/com/redis/vl/ToStringTest.java new file mode 100644 index 0000000..69627ce --- /dev/null +++ b/core/src/test/java/com/redis/vl/ToStringTest.java @@ -0,0 +1,138 @@ +package com.redis.vl; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +import com.redis.vl.extensions.cache.LangCacheSemanticCache; +import com.redis.vl.extensions.messagehistory.MessageHistory; +import com.redis.vl.extensions.messagehistory.SemanticMessageHistory; +import com.redis.vl.extensions.router.SemanticRouter; +import 
com.redis.vl.index.SearchIndex; +import com.redis.vl.schema.IndexSchema; +import com.redis.vl.schema.TagField; +import com.redis.vl.utils.vectorize.BaseVectorizer; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import redis.clients.jedis.UnifiedJedis; + +/** + * Unit tests for toString() methods on core public classes. + * + *

<p>Ported from Python: commit e285fed - Add __repr__ methods + * + *

Python reference: tests/unit/test_repr.py + */ +@DisplayName("toString() Tests") +class ToStringTest { + + @Test + @DisplayName("SearchIndex: toString should include name, prefix, and storage_type") + void testSearchIndexToString() { + IndexSchema schema = + IndexSchema.builder() + .name("test_index") + .prefix("test_prefix") + .storageType(IndexSchema.StorageType.HASH) + .field(new TagField("tag1")) + .build(); + + SearchIndex index = new SearchIndex(schema); + + String repr = index.toString(); + + assertThat(repr).contains("SearchIndex"); + assertThat(repr).contains("test_index"); + assertThat(repr).contains("test_prefix"); + assertThat(repr).contains("HASH"); + } + + @Test + @DisplayName("SearchIndex: toString with JSON storage") + void testSearchIndexToStringJson() { + IndexSchema schema = + IndexSchema.builder() + .name("json_index") + .prefix("json_prefix") + .storageType(IndexSchema.StorageType.JSON) + .field(new TagField("tag1")) + .build(); + + SearchIndex index = new SearchIndex(schema); + + String repr = index.toString(); + + assertThat(repr).contains("json_index"); + assertThat(repr).contains("JSON"); + } + + @SuppressWarnings("unchecked") + @Test + @DisplayName("MessageHistory: toString should include name and session_tag") + void testMessageHistoryToString() { + UnifiedJedis mockJedis = mock(UnifiedJedis.class); + when(mockJedis.ftCreate(anyString(), any(), any(Iterable.class))).thenReturn("OK"); + + MessageHistory history = new MessageHistory("chat_history", "session123", null, mockJedis); + + String repr = history.toString(); + + assertThat(repr).contains("MessageHistory"); + assertThat(repr).contains("chat_history"); + assertThat(repr).contains("session123"); + } + + @SuppressWarnings("unchecked") + @Test + @DisplayName( + "SemanticMessageHistory: toString should include name, session_tag, and distance_threshold") + void testSemanticMessageHistoryToString() { + UnifiedJedis mockJedis = mock(UnifiedJedis.class); + 
when(mockJedis.ftCreate(anyString(), any(), any(Iterable.class))).thenReturn("OK"); + + BaseVectorizer mockVectorizer = mock(BaseVectorizer.class); + when(mockVectorizer.getDimensions()).thenReturn(384); + when(mockVectorizer.getDataType()).thenReturn("float32"); + + SemanticMessageHistory history = + new SemanticMessageHistory("semantic_chat", "session456", null, mockVectorizer, mockJedis); + + String repr = history.toString(); + + assertThat(repr).contains("SemanticMessageHistory"); + assertThat(repr).contains("semantic_chat"); + assertThat(repr).contains("session456"); + assertThat(repr).contains("0.3"); // default distance threshold + } + + @Test + @DisplayName("SemanticRouter: toString should include name and route count") + void testSemanticRouterToString() { + SemanticRouter router = new SemanticRouter("my_router"); + + String repr = router.toString(); + + assertThat(repr).contains("SemanticRouter"); + assertThat(repr).contains("my_router"); + assertThat(repr).contains("0"); // no routes + } + + @Test + @DisplayName("LangCacheSemanticCache: toString should include name and ttl") + void testLangCacheSemanticCacheToString() { + LangCacheSemanticCache cache = + new LangCacheSemanticCache.Builder() + .name("my_cache") + .cacheId("cache-123") + .apiKey("test-key") + .ttl(3600) + .build(); + + String repr = cache.toString(); + + assertThat(repr).contains("LangCacheSemanticCache"); + assertThat(repr).contains("my_cache"); + assertThat(repr).contains("3600"); + } +} From d1221bb06ce617c1558922c5feb18263ec7f83cd Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Mon, 16 Mar 2026 14:55:45 -0700 Subject: [PATCH 4/6] test: add integration tests for count() and max_distance features Add integration tests matching Python redis-vl test coverage: - MessageHistoryCountIntegrationTest: count() for both standard and semantic message history (store, count, clear, session scoping) - MultiVectorQueryMaxDistanceIntegrationTest: per-vector max_distance filtering with 
parametrized thresholds, monotonicity, and distance verification against real Redis - Fix MultiVectorQueryIntegrationTest assertion for AND join --- .../MessageHistoryCountIntegrationTest.java | 140 +++++++ .../MultiVectorQueryIntegrationTest.java | 2 +- ...VectorQueryMaxDistanceIntegrationTest.java | 379 ++++++++++++++++++ 3 files changed, 520 insertions(+), 1 deletion(-) create mode 100644 core/src/test/java/com/redis/vl/extensions/messagehistory/MessageHistoryCountIntegrationTest.java create mode 100644 core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceIntegrationTest.java diff --git a/core/src/test/java/com/redis/vl/extensions/messagehistory/MessageHistoryCountIntegrationTest.java b/core/src/test/java/com/redis/vl/extensions/messagehistory/MessageHistoryCountIntegrationTest.java new file mode 100644 index 0000000..52167c4 --- /dev/null +++ b/core/src/test/java/com/redis/vl/extensions/messagehistory/MessageHistoryCountIntegrationTest.java @@ -0,0 +1,140 @@ +package com.redis.vl.extensions.messagehistory; + +import static org.junit.jupiter.api.Assertions.*; + +import com.redis.vl.BaseIntegrationTest; +import org.junit.jupiter.api.*; + +/** + * Integration tests for count() method in MessageHistory and SemanticMessageHistory. + * + *

Ported from Python: tests/integration/test_message_history.py (test_standard_count, + * test_semantic_count) + */ +@Tag("integration") +@DisplayName("MessageHistory count() Integration Tests") +class MessageHistoryCountIntegrationTest extends BaseIntegrationTest { + + @Nested + @DisplayName("Standard MessageHistory count") + class StandardCountTests { + + private MessageHistory history; + + @BeforeEach + void setUp() { + history = new MessageHistory("test_standard_count", unifiedJedis); + history.clear(); + } + + @AfterEach + void tearDown() { + if (history != null) { + history.clear(); + history.delete(); + } + } + + @Test + @DisplayName("count returns 0 when empty") + void testCountReturnsZeroWhenEmpty() { + assertEquals(0, history.count()); + } + + @Test + @DisplayName("count returns 2 after storing one prompt/response pair, 0 after clear") + void testStandardCount() { + history.store("some prompt", "some response"); + assertEquals(2, history.count()); + + history.clear(); + assertEquals(0, history.count()); + } + + @Test + @DisplayName("count with explicit session tag") + void testCountWithSessionTag() { + history.store("prompt 1", "response 1", "session_a"); + history.store("prompt 2", "response 2", "session_b"); + + assertEquals(2, history.count("session_a")); + assertEquals(2, history.count("session_b")); + } + + @Test + @DisplayName("count defaults to instance session tag") + void testCountDefaultsToInstanceSession() { + MessageHistory sessionHistory = + new MessageHistory("test_count_session", "my-session", null, unifiedJedis); + try { + sessionHistory.store("prompt", "response"); + assertEquals(2, sessionHistory.count()); + assertEquals(2, sessionHistory.count("my-session")); + } finally { + sessionHistory.clear(); + sessionHistory.delete(); + } + } + + @Test + @DisplayName("count reflects multiple stores") + void testCountMultipleStores() { + history.store("first prompt", "first response"); + assertEquals(2, history.count()); + + history.store("second 
prompt", "second response"); + assertEquals(4, history.count()); + } + } + + @Nested + @DisplayName("Semantic MessageHistory count") + class SemanticCountTests { + + private SemanticMessageHistory history; + + @SuppressWarnings("unchecked") + @BeforeEach + void setUp() { + com.redis.vl.utils.vectorize.BaseVectorizer vectorizer = + new com.redis.vl.utils.vectorize.MockVectorizer("mock-model", 768); + history = + new SemanticMessageHistory( + "test_semantic_count", null, null, vectorizer, 0.3, unifiedJedis, true); + } + + @AfterEach + void tearDown() { + if (history != null) { + history.clear(); + history.delete(); + } + } + + @Test + @DisplayName("count returns 0 when empty") + void testSemanticCountReturnsZeroWhenEmpty() { + assertEquals(0, history.count()); + } + + @Test + @DisplayName("count returns 2 after storing one prompt/response pair, 0 after clear") + void testSemanticCount() { + history.store("first prompt", "first response"); + assertEquals(2, history.count()); + + history.clear(); + assertEquals(0, history.count()); + } + + @Test + @DisplayName("count with explicit session tag") + void testSemanticCountWithSessionTag() { + history.store("prompt 1", "response 1", "session_x"); + history.store("prompt 2", "response 2", "session_y"); + + assertEquals(2, history.count("session_x")); + assertEquals(2, history.count("session_y")); + } + } +} diff --git a/core/src/test/java/com/redis/vl/query/MultiVectorQueryIntegrationTest.java b/core/src/test/java/com/redis/vl/query/MultiVectorQueryIntegrationTest.java index 7dc4f9b..cd4e2a1 100644 --- a/core/src/test/java/com/redis/vl/query/MultiVectorQueryIntegrationTest.java +++ b/core/src/test/java/com/redis/vl/query/MultiVectorQueryIntegrationTest.java @@ -164,7 +164,7 @@ void testMultipleVectorsQuery() { assertThat(queryString) .contains("@text_embedding:[VECTOR_RANGE 2.0 $vector_0]") .contains("@image_embedding:[VECTOR_RANGE 2.0 $vector_1]") - .contains(" | "); + .contains(" AND "); // Verify scoring (@ prefix is 
correct for FT.AGGREGATE APPLY expressions) String formula = query.getScoringFormula(); diff --git a/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceIntegrationTest.java b/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceIntegrationTest.java new file mode 100644 index 0000000..79c3ba8 --- /dev/null +++ b/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceIntegrationTest.java @@ -0,0 +1,379 @@ +package com.redis.vl.query; + +import static org.assertj.core.api.Assertions.*; + +import com.redis.vl.BaseIntegrationTest; +import com.redis.vl.index.SearchIndex; +import com.redis.vl.schema.*; +import java.util.*; +import org.junit.jupiter.api.*; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +/** + * Integration tests for per-vector max_distance in MultiVectorQuery. + * + *

<p>Ported from Python: tests/integration/test_aggregation.py + * (test_multivector_query_max_distances) + * + *

Verifies that each vector's max_distance threshold is independently applied when querying + * against real Redis with FT.AGGREGATE VECTOR_RANGE. + */ +@Tag("integration") +@DisplayName("MultiVectorQuery max_distance Integration Tests") +@TestMethodOrder(MethodOrderer.OrderAnnotation.class) +class MultiVectorQueryMaxDistanceIntegrationTest extends BaseIntegrationTest { + + private static final String INDEX_NAME = "multi_vec_maxdist_test_idx"; + private static SearchIndex searchIndex; + + @BeforeAll + static void setupIndex() throws InterruptedException { + // Clean up any existing index + try { + unifiedJedis.ftDropIndex(INDEX_NAME); + } catch (Exception e) { + // Ignore if index doesn't exist + } + + // Create schema with two vector fields (matching Python test structure) + // user_embedding: 3 dimensions, COSINE + // image_embedding: 5 dimensions, COSINE + IndexSchema schema = + IndexSchema.builder() + .name(INDEX_NAME) + .prefix("mvmd:") + .field(TextField.builder().name("title").build()) + .field( + VectorField.builder() + .name("user_embedding") + .dimensions(3) + .distanceMetric(VectorField.DistanceMetric.COSINE) + .build()) + .field( + VectorField.builder() + .name("image_embedding") + .dimensions(5) + .distanceMetric(VectorField.DistanceMetric.COSINE) + .build()) + .build(); + + searchIndex = new SearchIndex(schema, unifiedJedis); + searchIndex.create(); + + // Insert test documents with varying vector values to produce a range of distances. + // For COSINE distance: distance = 1 - cosine_similarity, range [0, 2]. + // We create 6 documents with progressively different vectors so that queries + // produce different result counts at different distance thresholds. 
+ List> docs = new ArrayList<>(); + float[][] userVecs = { + {0.1f, 0.2f, 0.5f}, // very similar to query + {0.15f, 0.25f, 0.45f}, // very similar + {0.3f, 0.1f, 0.4f}, // somewhat similar + {0.5f, 0.5f, 0.1f}, // moderately different + {0.9f, 0.1f, 0.1f}, // quite different + {-0.5f, -0.3f, 0.2f} // very different + }; + + float[][] imageVecs = { + {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}, // very similar to query + {1.0f, 0.4f, -0.3f, 0.6f, 0.3f}, // very similar + {0.5f, 0.8f, 0.1f, 0.3f, 0.5f}, // moderately different + {0.1f, 0.1f, 0.9f, 0.1f, 0.1f}, // quite different + {-0.3f, 0.7f, 0.5f, -0.2f, 0.4f}, // very different + {-0.8f, -0.2f, 0.6f, -0.5f, -0.1f} // very different + }; + + for (int i = 0; i < 6; i++) { + Map doc = new HashMap<>(); + doc.put("id", String.valueOf(i + 1)); + doc.put("title", "Document " + (i + 1)); + doc.put("user_embedding", userVecs[i]); + doc.put("image_embedding", imageVecs[i]); + docs.add(doc); + } + + searchIndex.load(docs, "id"); + + // Wait for indexing + Thread.sleep(200); + } + + @AfterAll + static void cleanupIndex() { + if (searchIndex != null) { + try { + searchIndex.drop(); + } catch (Exception e) { + // Ignore + } + } + } + + @Test + @Order(1) + @DisplayName("Should return all documents with max_distance=2.0 (default)") + void testDefaultMaxDistanceReturnsAll() { + Vector userVec = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .build(); // default max_distance=2.0 + + Vector imageVec = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .build(); // default max_distance=2.0 + + MultiVectorQuery query = + MultiVectorQuery.builder() + .vectors(userVec, imageVec) + .returnFields("title", "distance_0", "distance_1") + .numResults(10) + .build(); + + List> results = searchIndex.query(query); + + // With max_distance=2.0 on both vectors, all 6 documents should match + assertThat(results).hasSize(6); + } + + @Test + @Order(2) + 
@DisplayName("Should return fewer documents with tight max_distance") + void testTightMaxDistanceFilters() { + Vector userVec = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .maxDistance(0.05) // very tight - only very similar vectors + .build(); + + Vector imageVec = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .maxDistance(0.05) // very tight + .build(); + + MultiVectorQuery query = + MultiVectorQuery.builder() + .vectors(userVec, imageVec) + .returnFields("title", "distance_0", "distance_1") + .numResults(10) + .build(); + + List> results = searchIndex.query(query); + + // Very tight thresholds should return fewer results than default + assertThat(results.size()).isLessThan(6); + } + + @Test + @Order(3) + @DisplayName("Should filter independently per vector field") + void testIndependentPerVectorFiltering() { + // Use a tight threshold on user_embedding but loose on image_embedding + Vector userVecTight = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .maxDistance(0.01) // very tight + .build(); + + Vector imageVecLoose = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .maxDistance(2.0) // wide open + .build(); + + MultiVectorQuery tightUserQuery = + MultiVectorQuery.builder() + .vectors(userVecTight, imageVecLoose) + .returnFields("title", "distance_0", "distance_1") + .numResults(10) + .build(); + + // Now flip: loose on user, tight on image + Vector userVecLoose = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .maxDistance(2.0) // wide open + .build(); + + Vector imageVecTight = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .maxDistance(0.01) // very tight + .build(); + + MultiVectorQuery tightImageQuery = + MultiVectorQuery.builder() + 
.vectors(userVecLoose, imageVecTight) + .returnFields("title", "distance_0", "distance_1") + .numResults(10) + .build(); + + List> tightUserResults = searchIndex.query(tightUserQuery); + List> tightImageResults = searchIndex.query(tightImageQuery); + + // Both should return fewer than 6 (the default max) + // The two queries should potentially return different counts because they + // filter different vector fields tightly + assertThat(tightUserResults.size()).isLessThanOrEqualTo(6); + assertThat(tightImageResults.size()).isLessThanOrEqualTo(6); + } + + @Test + @Order(4) + @DisplayName("Returned distances should respect max_distance thresholds") + void testReturnedDistancesWithinThreshold() { + double userMaxDist = 0.5; + double imageMaxDist = 0.5; + + Vector userVec = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .maxDistance(userMaxDist) + .build(); + + Vector imageVec = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .maxDistance(imageMaxDist) + .build(); + + MultiVectorQuery query = + MultiVectorQuery.builder() + .vectors(userVec, imageVec) + .returnFields("title", "distance_0", "distance_1", "score_0", "score_1") + .numResults(10) + .build(); + + List> results = searchIndex.query(query); + + // Every returned document should have distances within the thresholds + for (Map result : results) { + Object dist0 = result.get("distance_0"); + Object dist1 = result.get("distance_1"); + if (dist0 != null) { + assertThat(Double.parseDouble(dist0.toString())) + .as("distance_0 for %s", result.get("title")) + .isLessThanOrEqualTo(userMaxDist); + } + if (dist1 != null) { + assertThat(Double.parseDouble(dist1.toString())) + .as("distance_1 for %s", result.get("title")) + .isLessThanOrEqualTo(imageMaxDist); + } + } + } + + @ParameterizedTest(name = "max_distance({0}, {1}) should return results") + @CsvSource({ + "2.0, 2.0", // widest - should return all + "0.5, 0.5", 
// moderate + "0.1, 0.1", // tight + }) + @Order(5) + @DisplayName("Parametrized: tighter thresholds should return fewer or equal results") + void testTighterThresholdsReturnFewerResults(double maxDist1, double maxDist2) { + Vector userVec = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .maxDistance(maxDist1) + .build(); + + Vector imageVec = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .maxDistance(maxDist2) + .build(); + + MultiVectorQuery query = + MultiVectorQuery.builder() + .vectors(userVec, imageVec) + .returnFields("title", "distance_0", "distance_1") + .numResults(10) + .build(); + + List> results = searchIndex.query(query); + + // Results count should be non-negative + assertThat(results.size()).isGreaterThanOrEqualTo(0); + + // All returned results should have distances within thresholds + for (Map result : results) { + Object dist0 = result.get("distance_0"); + Object dist1 = result.get("distance_1"); + if (dist0 != null) { + assertThat(Double.parseDouble(dist0.toString())).isLessThanOrEqualTo(maxDist1); + } + if (dist1 != null) { + assertThat(Double.parseDouble(dist1.toString())).isLessThanOrEqualTo(maxDist2); + } + } + } + + @Test + @Order(6) + @DisplayName("Monotonicity: wider thresholds return >= results than narrower") + void testMonotonicity() { + // Narrow thresholds + Vector userNarrow = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .maxDistance(0.1) + .build(); + Vector imageNarrow = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .maxDistance(0.1) + .build(); + + MultiVectorQuery narrowQuery = + MultiVectorQuery.builder() + .vectors(userNarrow, imageNarrow) + .returnFields("title") + .numResults(10) + .build(); + + // Wide thresholds + Vector userWide = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + 
.fieldName("user_embedding") + .maxDistance(1.5) + .build(); + Vector imageWide = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .maxDistance(1.5) + .build(); + + MultiVectorQuery wideQuery = + MultiVectorQuery.builder() + .vectors(userWide, imageWide) + .returnFields("title") + .numResults(10) + .build(); + + List> narrowResults = searchIndex.query(narrowQuery); + List> wideResults = searchIndex.query(wideQuery); + + assertThat(wideResults.size()) + .as("Wider thresholds should return >= results than narrower") + .isGreaterThanOrEqualTo(narrowResults.size()); + } +} From 975b6b409060b966014225681541d3784a66d74c Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Mon, 16 Mar 2026 15:11:17 -0700 Subject: [PATCH 5/6] refactor(messagehistory): deduplicate count() into BaseMessageHistory Move identical count()/count(sessionTag) implementations from MessageHistory and SemanticMessageHistory into the base class. Subclasses now provide getSearchIndex() and getDefaultSessionFilter() accessors instead of duplicating the query logic. 
--- .../messagehistory/BaseMessageHistory.java | 30 +++++++++++++++++-- .../messagehistory/MessageHistory.java | 12 +++----- .../SemanticMessageHistory.java | 12 +++----- .../messagehistory/RoleFilteringTest.java | 10 ++++--- 4 files changed, 42 insertions(+), 22 deletions(-) diff --git a/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java b/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java index add4521..163a854 100644 --- a/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java +++ b/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java @@ -3,6 +3,9 @@ import static com.redis.vl.extensions.Constants.*; import com.github.f4b6a3.ulid.UlidCreator; +import com.redis.vl.index.SearchIndex; +import com.redis.vl.query.CountQuery; +import com.redis.vl.query.Filter; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -33,6 +36,20 @@ protected BaseMessageHistory(String name, String sessionTag) { this.sessionTag = (sessionTag != null) ? sessionTag : UlidCreator.getUlid().toString(); } + /** + * Get the underlying search index. + * + * @return The search index used by this message history + */ + protected abstract SearchIndex getSearchIndex(); + + /** + * Get the default session filter. + * + * @return The filter for the default session tag + */ + protected abstract Filter getDefaultSessionFilter(); + /** * Count the number of messages in the conversation history. * @@ -40,7 +57,9 @@ protected BaseMessageHistory(String name, String sessionTag) { * * @return The number of messages for the default session */ - public abstract long count(); + public long count() { + return count(null); + } /** * Count the number of messages in the conversation history for a specific session. 
@@ -50,7 +69,14 @@ protected BaseMessageHistory(String name, String sessionTag) { * @param sessionTag The session tag to count messages for (null uses default session) * @return The number of messages */ - public abstract long count(String sessionTag); + public long count(String sessionTag) { + Filter sessionFilter = + (sessionTag != null) + ? Filter.tag(SESSION_FIELD_NAME, sessionTag) + : getDefaultSessionFilter(); + CountQuery query = new CountQuery(sessionFilter); + return getSearchIndex().count(query); + } /** Clears the chat message history. */ public abstract void clear(); diff --git a/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java b/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java index 4a495a1..a834bf5 100644 --- a/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java +++ b/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java @@ -3,7 +3,6 @@ import static com.redis.vl.extensions.Constants.*; import com.redis.vl.index.SearchIndex; -import com.redis.vl.query.CountQuery; import com.redis.vl.query.Filter; import com.redis.vl.query.FilterQuery; import com.redis.vl.schema.IndexSchema; @@ -66,16 +65,13 @@ public MessageHistory(String name, String sessionTag, String prefix, UnifiedJedi } @Override - public long count() { - return count(null); + protected SearchIndex getSearchIndex() { + return index; } @Override - public long count(String sessionTag) { - Filter sessionFilter = - (sessionTag != null) ? 
Filter.tag(SESSION_FIELD_NAME, sessionTag) : defaultSessionFilter; - CountQuery query = new CountQuery(sessionFilter); - return index.count(query); + protected Filter getDefaultSessionFilter() { + return defaultSessionFilter; } @Override diff --git a/core/src/main/java/com/redis/vl/extensions/messagehistory/SemanticMessageHistory.java b/core/src/main/java/com/redis/vl/extensions/messagehistory/SemanticMessageHistory.java index 4c7c918..ad050e3 100644 --- a/core/src/main/java/com/redis/vl/extensions/messagehistory/SemanticMessageHistory.java +++ b/core/src/main/java/com/redis/vl/extensions/messagehistory/SemanticMessageHistory.java @@ -3,7 +3,6 @@ import static com.redis.vl.extensions.Constants.*; import com.redis.vl.index.SearchIndex; -import com.redis.vl.query.CountQuery; import com.redis.vl.query.Filter; import com.redis.vl.query.FilterQuery; import com.redis.vl.query.VectorQuery; @@ -154,16 +153,13 @@ public SearchIndex getIndex() { } @Override - public long count() { - return count(null); + protected SearchIndex getSearchIndex() { + return index; } @Override - public long count(String sessionTag) { - Filter sessionFilter = - (sessionTag != null) ? 
Filter.tag(SESSION_FIELD_NAME, sessionTag) : defaultSessionFilter; - CountQuery query = new CountQuery(sessionFilter); - return index.count(query); + protected Filter getDefaultSessionFilter() { + return defaultSessionFilter; } @Override diff --git a/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java b/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java index 7f8115c..fafb7dc 100644 --- a/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java +++ b/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java @@ -2,6 +2,8 @@ import static org.assertj.core.api.Assertions.*; +import com.redis.vl.index.SearchIndex; +import com.redis.vl.query.Filter; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -64,13 +66,13 @@ public void addMessages(List> messages, String ses public void addMessage(java.util.Map message, String sessionTag) {} @Override - public long count() { - return 0; + protected SearchIndex getSearchIndex() { + return null; } @Override - public long count(String sessionTag) { - return 0; + protected Filter getDefaultSessionFilter() { + return null; } } From b983a9c2b3d7be4ce3258e50447edfd5d67e3f63 Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Mon, 16 Mar 2026 15:35:14 -0700 Subject: [PATCH 6/6] fix(query): preserve full maxDistance precision in VECTOR_RANGE query Replace %.1f format specifier with %s to avoid truncating user-provided maxDistance values. Previously, maxDistance(0.01) would produce "VECTOR_RANGE 0.0" (matching nothing) and maxDistance(0.05) would produce "VECTOR_RANGE 0.1" (doubling the intended radius). Add unit test verifying sub-decimal precision is preserved. 
--- .../com/redis/vl/query/MultiVectorQuery.java | 2 +- .../query/MultiVectorQueryMaxDistanceTest.java | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java b/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java index f61ffac..9a1bbc2 100644 --- a/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java +++ b/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java @@ -133,7 +133,7 @@ public String buildQueryString() { Vector v = vectors.get(i); String rangeQuery = String.format( - "@%s:[VECTOR_RANGE %.1f $vector_%d]=>{$YIELD_DISTANCE_AS: distance_%d}", + "@%s:[VECTOR_RANGE %s $vector_%d]=>{$YIELD_DISTANCE_AS: distance_%d}", v.getFieldName(), v.getMaxDistance(), i, i); rangeQueries.add(rangeQuery); } diff --git a/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceTest.java b/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceTest.java index 477d5e6..ccf6796 100644 --- a/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceTest.java +++ b/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceTest.java @@ -141,4 +141,22 @@ void testQueryStringMixedMaxDistance() { assertThat(queryString).contains("@field_1:[VECTOR_RANGE 2.0 $vector_0]"); assertThat(queryString).contains("@field_2:[VECTOR_RANGE 0.3 $vector_1]"); } + + @Test + @DisplayName("MultiVectorQuery: Should preserve full precision in query string") + void testQueryStringPreservesFullPrecision() { + Vector vector1 = + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").maxDistance(0.01).build(); + + Vector vector2 = + Vector.builder().vector(SAMPLE_VECTOR_2).fieldName("field_2").maxDistance(0.05).build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(vector1, vector2).build(); + + String queryString = query.toQueryString(); + + // These would have been truncated to 0.0 and 0.1 with %.1f format + 
assertThat(queryString).contains("VECTOR_RANGE 0.01"); + assertThat(queryString).contains("VECTOR_RANGE 0.05"); + } }