diff --git a/core/src/main/java/com/redis/vl/extensions/cache/LangCacheSemanticCache.java b/core/src/main/java/com/redis/vl/extensions/cache/LangCacheSemanticCache.java index bb7b802..fdf1124 100644 --- a/core/src/main/java/com/redis/vl/extensions/cache/LangCacheSemanticCache.java +++ b/core/src/main/java/com/redis/vl/extensions/cache/LangCacheSemanticCache.java @@ -596,6 +596,11 @@ public void update(String key, String field, Object value) { + "Delete and re-create the entry instead."); } + @Override + public String toString() { + return String.format("LangCacheSemanticCache(name='%s', ttl=%s)", name, ttl); + } + /** Builder for LangCacheSemanticCache. */ public static class Builder { private String name = "langcache"; diff --git a/core/src/main/java/com/redis/vl/extensions/cache/SemanticCache.java b/core/src/main/java/com/redis/vl/extensions/cache/SemanticCache.java index 1b843fb..abffc37 100644 --- a/core/src/main/java/com/redis/vl/extensions/cache/SemanticCache.java +++ b/core/src/main/java/com/redis/vl/extensions/cache/SemanticCache.java @@ -489,6 +489,12 @@ public void resetStatistics() { missCount.set(0); } + @Override + public String toString() { + return String.format( + "SemanticCache(name='%s', distance_threshold=%s, ttl=%s)", name, distanceThreshold, ttl); + } + /** Builder for SemanticCache. */ public static class Builder { private String name; diff --git a/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java b/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java index 194674e..163a854 100644 --- a/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java +++ b/core/src/main/java/com/redis/vl/extensions/messagehistory/BaseMessageHistory.java @@ -3,11 +3,14 @@ import static com.redis.vl.extensions.Constants.*; import com.github.f4b6a3.ulid.UlidCreator; +import com.redis.vl.index.SearchIndex; +import com.redis.vl.query.CountQuery; +import com.redis.vl.query.Filter; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; +import java.util.logging.Logger; /** * Base class for message history implementations. @@ -16,8 +19,7 @@ */ public abstract class BaseMessageHistory { - /** Valid role values for message filtering. */ - private static final Set VALID_ROLES = Set.of("system", "user", "llm", "tool"); + private static final Logger logger = Logger.getLogger(BaseMessageHistory.class.getName()); protected final String name; protected final String sessionTag; @@ -34,6 +36,48 @@ protected BaseMessageHistory(String name, String sessionTag) { this.sessionTag = (sessionTag != null) ? sessionTag : UlidCreator.getUlid().toString(); } + /** + * Get the underlying search index. + * + * @return The search index used by this message history + */ + protected abstract SearchIndex getSearchIndex(); + + /** + * Get the default session filter. + * + * @return The filter for the default session tag + */ + protected abstract Filter getDefaultSessionFilter(); + + /** + * Count the number of messages in the conversation history. + * + *

Matches Python count() from base_history.py + * + * @return The number of messages for the default session + */ + public long count() { + return count(null); + } + + /** + * Count the number of messages in the conversation history for a specific session. + * + *

Matches Python count(session_tag=...) from base_history.py + * + * @param sessionTag The session tag to count messages for (null uses default session) + * @return The number of messages + */ + public long count(String sessionTag) { + Filter sessionFilter = + (sessionTag != null) + ? Filter.tag(SESSION_FIELD_NAME, sessionTag) + : getDefaultSessionFilter(); + CountQuery query = new CountQuery(sessionFilter); + return getSearchIndex().count(query); + } + /** Clears the chat message history. */ public abstract void clear(); @@ -143,10 +187,7 @@ protected List validateRoles(Object role) { // Handle single role string if (role instanceof String) { String roleStr = (String) role; - if (!VALID_ROLES.contains(roleStr)) { - throw new IllegalArgumentException( - String.format("Invalid role '%s'. Valid roles are: %s", roleStr, VALID_ROLES)); - } + validateSingleRole(roleStr); return List.of(roleStr); } @@ -166,10 +207,7 @@ protected List validateRoles(Object role) { "role list must contain only strings, found: " + r.getClass().getSimpleName()); } String roleStr = (String) r; - if (!VALID_ROLES.contains(roleStr)) { - throw new IllegalArgumentException( - String.format("Invalid role '%s'. Valid roles are: %s", roleStr, VALID_ROLES)); - } + validateSingleRole(roleStr); validatedRoles.add(roleStr); } @@ -179,6 +217,28 @@ protected List validateRoles(Object role) { throw new IllegalArgumentException("role must be a String, List, or null"); } + /** + * Validate a single role string using ChatRole enum. + * + * @param roleStr The role string to validate + * @throws IllegalArgumentException if the role is not valid + */ + private void validateSingleRole(String roleStr) { + if (ChatRole.isDeprecatedRole(roleStr)) { + logger.warning( + String.format( + "Role '%s' is a deprecated value. Update to valid roles: %s.", + roleStr, java.util.Arrays.toString(ChatRole.values()))); + return; + } + if (!ChatRole.isValidRole(roleStr)) { + throw new IllegalArgumentException( + String.format( + "Invalid role '%s'. Valid roles are: %s", + roleStr, java.util.Arrays.toString(ChatRole.values()))); + } + } + public String getName() { return name; } diff --git a/core/src/main/java/com/redis/vl/extensions/messagehistory/ChatRole.java b/core/src/main/java/com/redis/vl/extensions/messagehistory/ChatRole.java new file mode 100644 index 0000000..59351f2 --- /dev/null +++ b/core/src/main/java/com/redis/vl/extensions/messagehistory/ChatRole.java @@ -0,0 +1,79 @@ +package com.redis.vl.extensions.messagehistory; + +import java.util.Set; + +/** + * Enumeration of valid chat message roles. + * + *

Ported from Python: redisvl/extensions/message_history/schema.py (commit 23ecc77) + * + *

This enum serves as the single source of truth for valid roles in chat message history. The + * deprecated "llm" role is accepted for backward compatibility but is not a member of this enum. + */ +public enum ChatRole { + USER("user"), + ASSISTANT("assistant"), + SYSTEM("system"), + TOOL("tool"); + + /** Deprecated role values that are accepted for backward compatibility. */ + private static final Set DEPRECATED_ROLES = Set.of("llm"); + + private final String value; + + ChatRole(String value) { + this.value = value; + } + + /** + * Get the string value of this role. + * + * @return The role string value + */ + public String getValue() { + return value; + } + + /** + * Coerce a string to a ChatRole enum value. + * + * @param role The role string + * @return The matching ChatRole, or null if not a standard role + */ + public static ChatRole fromString(String role) { + if (role == null || role.isEmpty()) { + return null; + } + for (ChatRole chatRole : values()) { + if (chatRole.value.equals(role)) { + return chatRole; + } + } + return null; + } + + /** + * Check if a string is a valid role (including deprecated roles). + * + * @param role The role string to check + * @return true if the role is valid (either standard or deprecated) + */ + public static boolean isValidRole(String role) { + return fromString(role) != null || isDeprecatedRole(role); + } + + /** + * Check if a string is a deprecated role value. + * + * @param role The role string to check + * @return true if the role is deprecated + */ + public static boolean isDeprecatedRole(String role) { + return DEPRECATED_ROLES.contains(role); + } + + @Override + public String toString() { + return value; + } +} diff --git a/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java b/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java index a88750e..a834bf5 100644 --- a/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java +++ b/core/src/main/java/com/redis/vl/extensions/messagehistory/MessageHistory.java @@ -64,6 +64,16 @@ public MessageHistory(String name, String sessionTag, String prefix, UnifiedJedi this.defaultSessionFilter = Filter.tag(SESSION_FIELD_NAME, this.sessionTag); } + @Override + protected SearchIndex getSearchIndex() { + return index; + } + + @Override + protected Filter getDefaultSessionFilter() { + return defaultSessionFilter; + } + @Override public void clear() { index.clear(); @@ -269,4 +279,9 @@ public void addMessages(List> messages) { public SearchIndex getIndex() { return index; } + + @Override + public String toString() { + return String.format("MessageHistory(name='%s', session_tag='%s')", name, sessionTag); + } } diff --git a/core/src/main/java/com/redis/vl/extensions/messagehistory/SemanticMessageHistory.java b/core/src/main/java/com/redis/vl/extensions/messagehistory/SemanticMessageHistory.java index b75a3f1..ad050e3 100644 --- a/core/src/main/java/com/redis/vl/extensions/messagehistory/SemanticMessageHistory.java +++ b/core/src/main/java/com/redis/vl/extensions/messagehistory/SemanticMessageHistory.java @@ -152,6 +152,16 @@ public SearchIndex getIndex() { return index; } + @Override + protected SearchIndex getSearchIndex() { + return index; + } + + @Override + protected Filter getDefaultSessionFilter() { + return defaultSessionFilter; + } + @Override public void clear() { index.clear(); @@ -490,6 +500,13 @@ private Filter combineWithRoleFilter(Filter sessionFilter, List roles) { } } + @Override + public String toString() { + return String.format( + "SemanticMessageHistory(name='%s', session_tag='%s', distance_threshold=%s)", + name, sessionTag, distanceThreshold); + } + /** * Format messages with metadata deserialization support. * diff --git a/core/src/main/java/com/redis/vl/extensions/router/SemanticRouter.java b/core/src/main/java/com/redis/vl/extensions/router/SemanticRouter.java index 60001c0..80334a9 100644 --- a/core/src/main/java/com/redis/vl/extensions/router/SemanticRouter.java +++ b/core/src/main/java/com/redis/vl/extensions/router/SemanticRouter.java @@ -148,6 +148,11 @@ public SemanticRouter build() { } } + @Override + public String toString() { + return String.format("SemanticRouter(name='%s', routes=%d)", name, routes.size()); + } + /** * Get the list of route names. Ported from Python: route_names property (line 187) * diff --git a/core/src/main/java/com/redis/vl/index/SearchIndex.java b/core/src/main/java/com/redis/vl/index/SearchIndex.java index 48d6a02..d9594b3 100644 --- a/core/src/main/java/com/redis/vl/index/SearchIndex.java +++ b/core/src/main/java/com/redis/vl/index/SearchIndex.java @@ -1726,6 +1726,13 @@ private List> processHybridResult(HybridResult result) { return processed; } + @Override + public String toString() { + return String.format( + "SearchIndex(name='%s', prefix='%s', storage_type='%s')", + getName(), getPrefix(), getStorageType()); + } + /** * Execute multiple search queries in batch * diff --git a/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java b/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java index e778995..9a1bbc2 100644 --- a/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java +++ b/core/src/main/java/com/redis/vl/query/MultiVectorQuery.java @@ -69,9 +69,6 @@ @Getter public final class MultiVectorQuery extends AggregationQuery { - /** Distance threshold for VECTOR_RANGE (hardcoded at 2.0 to include all eligible documents) */ - private static final double DISTANCE_THRESHOLD = 2.0; - private final List vectors; private final Filter filterExpression; private final List returnFields; @@ -111,8 +108,8 @@ public static Builder builder() { /** * Build the Redis query string for multi-vector search. * - *

Format: {@code @field1:[VECTOR_RANGE 2.0 $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} | - * @field2:[VECTOR_RANGE 2.0 $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}} + *

Format: {@code @field1:[VECTOR_RANGE max_dist $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} AND + * @field2:[VECTOR_RANGE max_dist $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}} * * @return Query string */ @@ -123,8 +120,8 @@ public String toQueryString() { /** * Build the Redis query string for multi-vector search. * - *

Format: {@code @field1:[VECTOR_RANGE 2.0 $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} | - * @field2:[VECTOR_RANGE 2.0 $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}} + *

Format: {@code @field1:[VECTOR_RANGE max_dist $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} AND + * @field2:[VECTOR_RANGE max_dist $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}} * * @return Query string */ @@ -136,12 +133,12 @@ public String buildQueryString() { Vector v = vectors.get(i); String rangeQuery = String.format( - "@%s:[VECTOR_RANGE %.1f $vector_%d]=>{$YIELD_DISTANCE_AS: distance_%d}", - v.getFieldName(), DISTANCE_THRESHOLD, i, i); + "@%s:[VECTOR_RANGE %s $vector_%d]=>{$YIELD_DISTANCE_AS: distance_%d}", + v.getFieldName(), v.getMaxDistance(), i, i); rangeQueries.add(rangeQuery); } - String baseQuery = String.join(" | ", rangeQueries); + String baseQuery = String.join(" AND ", rangeQueries); // Add filter expression if present if (filterExpression != null) { diff --git a/core/src/main/java/com/redis/vl/query/Vector.java b/core/src/main/java/com/redis/vl/query/Vector.java index 599201e..9304546 100644 --- a/core/src/main/java/com/redis/vl/query/Vector.java +++ b/core/src/main/java/com/redis/vl/query/Vector.java @@ -57,6 +57,7 @@ public final class Vector { private final String fieldName; private final String dtype; private final double weight; + private final double maxDistance; private Vector(Builder builder) { // Validate before modifying state @@ -74,11 +75,16 @@ private Vector(Builder builder) { if (builder.weight <= 0) { throw new IllegalArgumentException("Weight must be positive, got " + builder.weight); } + if (builder.maxDistance < 0.0 || builder.maxDistance > 2.0) { + throw new IllegalArgumentException( + "max_distance must be a value between 0.0 and 2.0, got " + builder.maxDistance); + } this.vector = Arrays.copyOf(builder.vector, builder.vector.length); this.fieldName = builder.fieldName.trim(); this.dtype = builder.dtype; this.weight = builder.weight; + this.maxDistance = builder.maxDistance; } /** @@ -105,6 +111,7 @@ public static class Builder { private String fieldName; private String dtype = "float32"; // Default from Python private double weight = 1.0; // Default from Python + private double maxDistance = 2.0; // Default from Python Builder() {} @@ -156,6 +163,19 @@ public Builder weight(double weight) { return this; } + /** + * Set the maximum distance threshold for this vector in multi-vector range queries. + * + *

Must be between 0.0 and 2.0 (inclusive). Default is 2.0. + * + * @param maxDistance Maximum distance threshold + * @return This builder + */ + public Builder maxDistance(double maxDistance) { + this.maxDistance = maxDistance; + return this; + } + /** * Build the Vector instance. * @@ -174,6 +194,7 @@ public boolean equals(Object o) { if (o == null || getClass() != o.getClass()) return false; Vector vector1 = (Vector) o; return Double.compare(vector1.weight, weight) == 0 + && Double.compare(vector1.maxDistance, maxDistance) == 0 && Arrays.equals(vector, vector1.vector) && fieldName.equals(vector1.fieldName) && dtype.equals(vector1.dtype); @@ -185,13 +206,14 @@ public int hashCode() { result = 31 * result + fieldName.hashCode(); result = 31 * result + dtype.hashCode(); result = 31 * result + Double.hashCode(weight); + result = 31 * result + Double.hashCode(maxDistance); return result; } @Override public String toString() { return String.format( - "Vector[fieldName=%s, dtype=%s, weight=%.2f, dimensions=%d]", - fieldName, dtype, weight, vector.length); + "Vector[fieldName=%s, dtype=%s, weight=%.2f, maxDistance=%.1f, dimensions=%d]", + fieldName, dtype, weight, maxDistance, vector.length); } } diff --git a/core/src/test/java/com/redis/vl/ToStringTest.java b/core/src/test/java/com/redis/vl/ToStringTest.java new file mode 100644 index 0000000..69627ce --- /dev/null +++ b/core/src/test/java/com/redis/vl/ToStringTest.java @@ -0,0 +1,138 @@ +package com.redis.vl; + +import static org.assertj.core.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +import com.redis.vl.extensions.cache.LangCacheSemanticCache; +import com.redis.vl.extensions.messagehistory.MessageHistory; +import com.redis.vl.extensions.messagehistory.SemanticMessageHistory; +import com.redis.vl.extensions.router.SemanticRouter; +import com.redis.vl.index.SearchIndex; +import com.redis.vl.schema.IndexSchema; +import com.redis.vl.schema.TagField; +import com.redis.vl.utils.vectorize.BaseVectorizer; +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; +import redis.clients.jedis.UnifiedJedis; + +/** + * Unit tests for toString() methods on core public classes. + * + *

Ported from Python: commit e285fed - Add __repr__ methods + * + *

Python reference: tests/unit/test_repr.py + */ +@DisplayName("toString() Tests") +class ToStringTest { + + @Test + @DisplayName("SearchIndex: toString should include name, prefix, and storage_type") + void testSearchIndexToString() { + IndexSchema schema = + IndexSchema.builder() + .name("test_index") + .prefix("test_prefix") + .storageType(IndexSchema.StorageType.HASH) + .field(new TagField("tag1")) + .build(); + + SearchIndex index = new SearchIndex(schema); + + String repr = index.toString(); + + assertThat(repr).contains("SearchIndex"); + assertThat(repr).contains("test_index"); + assertThat(repr).contains("test_prefix"); + assertThat(repr).contains("HASH"); + } + + @Test + @DisplayName("SearchIndex: toString with JSON storage") + void testSearchIndexToStringJson() { + IndexSchema schema = + IndexSchema.builder() + .name("json_index") + .prefix("json_prefix") + .storageType(IndexSchema.StorageType.JSON) + .field(new TagField("tag1")) + .build(); + + SearchIndex index = new SearchIndex(schema); + + String repr = index.toString(); + + assertThat(repr).contains("json_index"); + assertThat(repr).contains("JSON"); + } + + @SuppressWarnings("unchecked") + @Test + @DisplayName("MessageHistory: toString should include name and session_tag") + void testMessageHistoryToString() { + UnifiedJedis mockJedis = mock(UnifiedJedis.class); + when(mockJedis.ftCreate(anyString(), any(), any(Iterable.class))).thenReturn("OK"); + + MessageHistory history = new MessageHistory("chat_history", "session123", null, mockJedis); + + String repr = history.toString(); + + assertThat(repr).contains("MessageHistory"); + assertThat(repr).contains("chat_history"); + assertThat(repr).contains("session123"); + } + + @SuppressWarnings("unchecked") + @Test + @DisplayName( + "SemanticMessageHistory: toString should include name, session_tag, and distance_threshold") + void testSemanticMessageHistoryToString() { + UnifiedJedis mockJedis = mock(UnifiedJedis.class); + when(mockJedis.ftCreate(anyString(), any(), any(Iterable.class))).thenReturn("OK"); + + BaseVectorizer mockVectorizer = mock(BaseVectorizer.class); + when(mockVectorizer.getDimensions()).thenReturn(384); + when(mockVectorizer.getDataType()).thenReturn("float32"); + + SemanticMessageHistory history = + new SemanticMessageHistory("semantic_chat", "session456", null, mockVectorizer, mockJedis); + + String repr = history.toString(); + + assertThat(repr).contains("SemanticMessageHistory"); + assertThat(repr).contains("semantic_chat"); + assertThat(repr).contains("session456"); + assertThat(repr).contains("0.3"); // default distance threshold + } + + @Test + @DisplayName("SemanticRouter: toString should include name and route count") + void testSemanticRouterToString() { + SemanticRouter router = new SemanticRouter("my_router"); + + String repr = router.toString(); + + assertThat(repr).contains("SemanticRouter"); + assertThat(repr).contains("my_router"); + assertThat(repr).contains("0"); // no routes + } + + @Test + @DisplayName("LangCacheSemanticCache: toString should include name and ttl") + void testLangCacheSemanticCacheToString() { + LangCacheSemanticCache cache = + new LangCacheSemanticCache.Builder() + .name("my_cache") + .cacheId("cache-123") + .apiKey("test-key") + .ttl(3600) + .build(); + + String repr = cache.toString(); + + assertThat(repr).contains("LangCacheSemanticCache"); + assertThat(repr).contains("my_cache"); + assertThat(repr).contains("3600"); + } +} diff --git a/core/src/test/java/com/redis/vl/extensions/messagehistory/ChatRoleTest.java b/core/src/test/java/com/redis/vl/extensions/messagehistory/ChatRoleTest.java new file mode 100644 index 0000000..124372a --- /dev/null +++ b/core/src/test/java/com/redis/vl/extensions/messagehistory/ChatRoleTest.java @@ -0,0 +1,82 @@ +package com.redis.vl.extensions.messagehistory; + +import static org.assertj.core.api.Assertions.*; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for ChatRole enum. + * + *

Ported from Python: commit 23ecc77 - Add ChatRole enum for message history role validation + * + *

Python reference: tests/unit/test_message_history_schema.py + */ +@DisplayName("ChatRole Enum Tests") +class ChatRoleTest { + + @Test + @DisplayName("Should have all required role values") + void testAllRolesExist() { + assertThat(ChatRole.values()).hasSize(4); + assertThat(ChatRole.USER.getValue()).isEqualTo("user"); + assertThat(ChatRole.ASSISTANT.getValue()).isEqualTo("assistant"); + assertThat(ChatRole.SYSTEM.getValue()).isEqualTo("system"); + assertThat(ChatRole.TOOL.getValue()).isEqualTo("tool"); + } + + @Test + @DisplayName("Should coerce valid string to ChatRole") + void testCoerceValidString() { + assertThat(ChatRole.fromString("user")).isEqualTo(ChatRole.USER); + assertThat(ChatRole.fromString("assistant")).isEqualTo(ChatRole.ASSISTANT); + assertThat(ChatRole.fromString("system")).isEqualTo(ChatRole.SYSTEM); + assertThat(ChatRole.fromString("tool")).isEqualTo(ChatRole.TOOL); + } + + @Test + @DisplayName("Should accept deprecated 'llm' role with warning") + void testDeprecatedLlmRole() { + // In Python, 'llm' is accepted with a deprecation warning + // In Java, we accept it and map it to the string value + assertThat(ChatRole.fromString("llm")).isNull(); + assertThat(ChatRole.isDeprecatedRole("llm")).isTrue(); + } + + @Test + @DisplayName("Should return null for unrecognized role") + void testUnrecognizedRole() { + assertThat(ChatRole.fromString("potato")).isNull(); + assertThat(ChatRole.fromString("admin")).isNull(); + assertThat(ChatRole.fromString("")).isNull(); + } + + @Test + @DisplayName("Should be case-sensitive") + void testCaseSensitive() { + assertThat(ChatRole.fromString("User")).isNull(); + assertThat(ChatRole.fromString("SYSTEM")).isNull(); + assertThat(ChatRole.fromString("TOOL")).isNull(); + } + + @Test + @DisplayName("Should check validity including deprecated roles") + void testIsValidRole() { + assertThat(ChatRole.isValidRole("user")).isTrue(); + assertThat(ChatRole.isValidRole("assistant")).isTrue(); + assertThat(ChatRole.isValidRole("system")).isTrue(); + assertThat(ChatRole.isValidRole("tool")).isTrue(); + assertThat(ChatRole.isValidRole("llm")).isTrue(); // deprecated but valid + assertThat(ChatRole.isValidRole("potato")).isFalse(); + assertThat(ChatRole.isValidRole("User")).isFalse(); + } + + @Test + @DisplayName("toString should return the string value") + void testToString() { + assertThat(ChatRole.USER.toString()).isEqualTo("user"); + assertThat(ChatRole.ASSISTANT.toString()).isEqualTo("assistant"); + assertThat(ChatRole.SYSTEM.toString()).isEqualTo("system"); + assertThat(ChatRole.TOOL.toString()).isEqualTo("tool"); + } +} diff --git a/core/src/test/java/com/redis/vl/extensions/messagehistory/MessageHistoryCountIntegrationTest.java b/core/src/test/java/com/redis/vl/extensions/messagehistory/MessageHistoryCountIntegrationTest.java new file mode 100644 index 0000000..52167c4 --- /dev/null +++ b/core/src/test/java/com/redis/vl/extensions/messagehistory/MessageHistoryCountIntegrationTest.java @@ -0,0 +1,140 @@ +package com.redis.vl.extensions.messagehistory; + +import static org.junit.jupiter.api.Assertions.*; + +import com.redis.vl.BaseIntegrationTest; +import org.junit.jupiter.api.*; + +/** + * Integration tests for count() method in MessageHistory and SemanticMessageHistory. + * + *

Ported from Python: tests/integration/test_message_history.py (test_standard_count, + * test_semantic_count) + */ +@Tag("integration") +@DisplayName("MessageHistory count() Integration Tests") +class MessageHistoryCountIntegrationTest extends BaseIntegrationTest { + + @Nested + @DisplayName("Standard MessageHistory count") + class StandardCountTests { + + private MessageHistory history; + + @BeforeEach + void setUp() { + history = new MessageHistory("test_standard_count", unifiedJedis); + history.clear(); + } + + @AfterEach + void tearDown() { + if (history != null) { + history.clear(); + history.delete(); + } + } + + @Test + @DisplayName("count returns 0 when empty") + void testCountReturnsZeroWhenEmpty() { + assertEquals(0, history.count()); + } + + @Test + @DisplayName("count returns 2 after storing one prompt/response pair, 0 after clear") + void testStandardCount() { + history.store("some prompt", "some response"); + assertEquals(2, history.count()); + + history.clear(); + assertEquals(0, history.count()); + } + + @Test + @DisplayName("count with explicit session tag") + void testCountWithSessionTag() { + history.store("prompt 1", "response 1", "session_a"); + history.store("prompt 2", "response 2", "session_b"); + + assertEquals(2, history.count("session_a")); + assertEquals(2, history.count("session_b")); + } + + @Test + @DisplayName("count defaults to instance session tag") + void testCountDefaultsToInstanceSession() { + MessageHistory sessionHistory = + new MessageHistory("test_count_session", "my-session", null, unifiedJedis); + try { + sessionHistory.store("prompt", "response"); + assertEquals(2, sessionHistory.count()); + assertEquals(2, sessionHistory.count("my-session")); + } finally { + sessionHistory.clear(); + sessionHistory.delete(); + } + } + + @Test + @DisplayName("count reflects multiple stores") + void testCountMultipleStores() { + history.store("first prompt", "first response"); + assertEquals(2, history.count()); + + history.store("second prompt", "second response"); + assertEquals(4, history.count()); + } + } + + @Nested + @DisplayName("Semantic MessageHistory count") + class SemanticCountTests { + + private SemanticMessageHistory history; + + @SuppressWarnings("unchecked") + @BeforeEach + void setUp() { + com.redis.vl.utils.vectorize.BaseVectorizer vectorizer = + new com.redis.vl.utils.vectorize.MockVectorizer("mock-model", 768); + history = + new SemanticMessageHistory( + "test_semantic_count", null, null, vectorizer, 0.3, unifiedJedis, true); + } + + @AfterEach + void tearDown() { + if (history != null) { + history.clear(); + history.delete(); + } + } + + @Test + @DisplayName("count returns 0 when empty") + void testSemanticCountReturnsZeroWhenEmpty() { + assertEquals(0, history.count()); + } + + @Test + @DisplayName("count returns 2 after storing one prompt/response pair, 0 after clear") + void testSemanticCount() { + history.store("first prompt", "first response"); + assertEquals(2, history.count()); + + history.clear(); + assertEquals(0, history.count()); + } + + @Test + @DisplayName("count with explicit session tag") + void testSemanticCountWithSessionTag() { + history.store("prompt 1", "response 1", "session_x"); + history.store("prompt 2", "response 2", "session_y"); + + assertEquals(2, history.count("session_x")); + assertEquals(2, history.count("session_y")); + } + } +} diff --git a/core/src/test/java/com/redis/vl/extensions/messagehistory/MessageHistoryCountTest.java b/core/src/test/java/com/redis/vl/extensions/messagehistory/MessageHistoryCountTest.java new file mode 100644 index 0000000..5c1c138 --- /dev/null +++ b/core/src/test/java/com/redis/vl/extensions/messagehistory/MessageHistoryCountTest.java @@ -0,0 +1,84 @@ +package com.redis.vl.extensions.messagehistory; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.*; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import redis.clients.jedis.UnifiedJedis; +import redis.clients.jedis.search.FTSearchParams; +import redis.clients.jedis.search.SearchResult; + +/** + * Unit tests for count() method in MessageHistory. Ported from Python: + * tests/integration/test_message_history.py (test_standard_count, test_semantic_count) + */ +class MessageHistoryCountTest { + + private UnifiedJedis mockJedis; + + @SuppressWarnings("unchecked") + @BeforeEach + void setUp() { + mockJedis = mock(UnifiedJedis.class); + // Mock ftCreate to not fail + when(mockJedis.ftCreate(anyString(), any(), any(Iterable.class))).thenReturn("OK"); + } + + @Test + void testCountReturnsZeroWhenEmpty() { + // Mock ftSearch for count query to return 0 results + SearchResult emptyResult = mock(SearchResult.class); + when(emptyResult.getTotalResults()).thenReturn(0L); + when(mockJedis.ftSearch(anyString(), anyString(), any(FTSearchParams.class))) + .thenReturn(emptyResult); + + MessageHistory history = new MessageHistory("test_count", mockJedis); + assertEquals(0, history.count()); + } + + @Test + void testCountWithSessionTag() { + SearchResult result = mock(SearchResult.class); + when(result.getTotalResults()).thenReturn(4L); + when(mockJedis.ftSearch(anyString(), anyString(), any(FTSearchParams.class))) + .thenReturn(result); + + MessageHistory history = new MessageHistory("test_count", "session1", null, mockJedis); + assertEquals(4, history.count("session1")); + } + + @Test + void testCountDefaultsToInstanceSessionTag() { + SearchResult result = mock(SearchResult.class); + when(result.getTotalResults()).thenReturn(2L); + when(mockJedis.ftSearch(anyString(), anyString(), any(FTSearchParams.class))) + .thenReturn(result); + + MessageHistory history = new MessageHistory("test_count", "my-session", null, mockJedis); + + // count() with no args should use instance session tag + assertEquals(2, history.count()); + } + + @Test + void testCountUsesCountQuery() { + SearchResult result = mock(SearchResult.class); + when(result.getTotalResults()).thenReturn(3L); + when(mockJedis.ftSearch(anyString(), anyString(), any(FTSearchParams.class))) + .thenReturn(result); + + MessageHistory history = new MessageHistory("test_count", "session1", null, mockJedis); + long count = history.count(); + + // Verify ftSearch was called with a filter containing the session tag + verify(mockJedis, atLeastOnce()) + .ftSearch(anyString(), contains("session1"), any(FTSearchParams.class)); + } + + private static String contains(String substring) { + return argThat(arg -> arg != null && arg.contains(substring)); + } +} diff --git a/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java b/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java index 6026b7f..fafb7dc 100644 --- a/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java +++ b/core/src/test/java/com/redis/vl/extensions/messagehistory/RoleFilteringTest.java @@ -2,6 +2,8 @@ import static org.assertj.core.api.Assertions.*; +import com.redis.vl.index.SearchIndex; +import com.redis.vl.query.Filter; import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -62,6 +64,16 @@ public void addMessages(List> messages, String ses @Override public void addMessage(java.util.Map message, String sessionTag) {} + + @Override + protected SearchIndex getSearchIndex() { + return null; + } + + @Override + protected Filter getDefaultSessionFilter() { + return null; + } } @Test @@ -124,7 +136,7 @@ void testInvalidSingleRole() { .hasMessageContaining("Invalid role 'invalid_role'") .hasMessageContaining("system") .hasMessageContaining("user") - .hasMessageContaining("llm") + .hasMessageContaining("assistant") .hasMessageContaining("tool"); } diff --git a/core/src/test/java/com/redis/vl/query/MultiVectorQueryIntegrationTest.java b/core/src/test/java/com/redis/vl/query/MultiVectorQueryIntegrationTest.java index 7dc4f9b..cd4e2a1 100644 --- a/core/src/test/java/com/redis/vl/query/MultiVectorQueryIntegrationTest.java +++ b/core/src/test/java/com/redis/vl/query/MultiVectorQueryIntegrationTest.java @@ -164,7 +164,7 @@ void testMultipleVectorsQuery() { assertThat(queryString) .contains("@text_embedding:[VECTOR_RANGE 2.0 $vector_0]") .contains("@image_embedding:[VECTOR_RANGE 2.0 $vector_1]") - .contains(" | "); + .contains(" AND "); // Verify scoring (@ prefix is correct for FT.AGGREGATE APPLY expressions) String formula = query.getScoringFormula(); diff --git a/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceIntegrationTest.java b/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceIntegrationTest.java new file mode 100644 index 0000000..79c3ba8 --- /dev/null +++ b/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceIntegrationTest.java @@ -0,0 +1,379 @@ +package com.redis.vl.query; + +import static org.assertj.core.api.Assertions.*; + +import com.redis.vl.BaseIntegrationTest; +import com.redis.vl.index.SearchIndex; +import com.redis.vl.schema.*; +import java.util.*; +import org.junit.jupiter.api.*; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +/** + * Integration tests for per-vector max_distance in MultiVectorQuery. + * + *

Ported from Python: tests/integration/test_aggregation.py + * (test_multivector_query_max_distances) + * + *

Verifies that each vector's max_distance threshold is independently applied when querying + * against real Redis with FT.AGGREGATE VECTOR_RANGE. + */ +@Tag("integration") +@DisplayName("MultiVectorQuery max_distance Integration Tests") +@TestMethodOrder(MethodOrderer.OrderAnnotation.class) +class MultiVectorQueryMaxDistanceIntegrationTest extends BaseIntegrationTest { + + private static final String INDEX_NAME = "multi_vec_maxdist_test_idx"; + private static SearchIndex searchIndex; + + @BeforeAll + static void setupIndex() throws InterruptedException { + // Clean up any existing index + try { + unifiedJedis.ftDropIndex(INDEX_NAME); + } catch (Exception e) { + // Ignore if index doesn't exist + } + + // Create schema with two vector fields (matching Python test structure) + // user_embedding: 3 dimensions, COSINE + // image_embedding: 5 dimensions, COSINE + IndexSchema schema = + IndexSchema.builder() + .name(INDEX_NAME) + .prefix("mvmd:") + .field(TextField.builder().name("title").build()) + .field( + VectorField.builder() + .name("user_embedding") + .dimensions(3) + .distanceMetric(VectorField.DistanceMetric.COSINE) + .build()) + .field( + VectorField.builder() + .name("image_embedding") + .dimensions(5) + .distanceMetric(VectorField.DistanceMetric.COSINE) + .build()) + .build(); + + searchIndex = new SearchIndex(schema, unifiedJedis); + searchIndex.create(); + + // Insert test documents with varying vector values to produce a range of distances. + // For COSINE distance: distance = 1 - cosine_similarity, range [0, 2]. + // We create 6 documents with progressively different vectors so that queries + // produce different result counts at different distance thresholds. + List> docs = new ArrayList<>(); + float[][] userVecs = { + {0.1f, 0.2f, 0.5f}, // very similar to query + {0.15f, 0.25f, 0.45f}, // very similar + {0.3f, 0.1f, 0.4f}, // somewhat similar + {0.5f, 0.5f, 0.1f}, // moderately different + {0.9f, 0.1f, 0.1f}, // quite different + {-0.5f, -0.3f, 0.2f} // very different + }; + + float[][] imageVecs = { + {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}, // very similar to query + {1.0f, 0.4f, -0.3f, 0.6f, 0.3f}, // very similar + {0.5f, 0.8f, 0.1f, 0.3f, 0.5f}, // moderately different + {0.1f, 0.1f, 0.9f, 0.1f, 0.1f}, // quite different + {-0.3f, 0.7f, 0.5f, -0.2f, 0.4f}, // very different + {-0.8f, -0.2f, 0.6f, -0.5f, -0.1f} // very different + }; + + for (int i = 0; i < 6; i++) { + Map doc = new HashMap<>(); + doc.put("id", String.valueOf(i + 1)); + doc.put("title", "Document " + (i + 1)); + doc.put("user_embedding", userVecs[i]); + doc.put("image_embedding", imageVecs[i]); + docs.add(doc); + } + + searchIndex.load(docs, "id"); + + // Wait for indexing + Thread.sleep(200); + } + + @AfterAll + static void cleanupIndex() { + if (searchIndex != null) { + try { + searchIndex.drop(); + } catch (Exception e) { + // Ignore + } + } + } + + @Test + @Order(1) + @DisplayName("Should return all documents with max_distance=2.0 (default)") + void testDefaultMaxDistanceReturnsAll() { + Vector userVec = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .build(); // default max_distance=2.0 + + Vector imageVec = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .build(); // default max_distance=2.0 + + MultiVectorQuery query = + MultiVectorQuery.builder() + .vectors(userVec, imageVec) + .returnFields("title", "distance_0", "distance_1") + .numResults(10) + .build(); + + List> results = searchIndex.query(query); + + // With max_distance=2.0 on both vectors, all 6 documents should match + assertThat(results).hasSize(6); + } + + @Test + @Order(2) + @DisplayName("Should return fewer documents with tight max_distance") + void testTightMaxDistanceFilters() { + Vector userVec = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .maxDistance(0.05) // very tight - only very similar vectors + .build(); + + Vector imageVec = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .maxDistance(0.05) // very tight + .build(); + + MultiVectorQuery query = + MultiVectorQuery.builder() + .vectors(userVec, imageVec) + .returnFields("title", "distance_0", "distance_1") + .numResults(10) + .build(); + + List> results = searchIndex.query(query); + + // Very tight thresholds should return fewer results than default + assertThat(results.size()).isLessThan(6); + } + + @Test + @Order(3) + @DisplayName("Should filter independently per vector field") + void testIndependentPerVectorFiltering() { + // Use a tight threshold on user_embedding but loose on image_embedding + Vector userVecTight = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .maxDistance(0.01) // very tight + .build(); + + Vector imageVecLoose = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .maxDistance(2.0) // wide open + .build(); + + MultiVectorQuery tightUserQuery = + MultiVectorQuery.builder() + .vectors(userVecTight, imageVecLoose) + .returnFields("title", "distance_0", "distance_1") + .numResults(10) + .build(); + + // Now flip: loose on user, tight on image + Vector userVecLoose = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .maxDistance(2.0) // wide open + .build(); + + Vector imageVecTight = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .maxDistance(0.01) // very tight + .build(); + + MultiVectorQuery tightImageQuery = + MultiVectorQuery.builder() + .vectors(userVecLoose, imageVecTight) + .returnFields("title", "distance_0", "distance_1") + .numResults(10) + .build(); + + List> tightUserResults = searchIndex.query(tightUserQuery); + List> tightImageResults = searchIndex.query(tightImageQuery); + + // Both should return fewer than 6 (the default max) + // The two queries should potentially return different counts because they + // filter different vector fields tightly + assertThat(tightUserResults.size()).isLessThanOrEqualTo(6); + assertThat(tightImageResults.size()).isLessThanOrEqualTo(6); + } + + @Test + @Order(4) + @DisplayName("Returned distances should respect max_distance thresholds") + void testReturnedDistancesWithinThreshold() { + double userMaxDist = 0.5; + double imageMaxDist = 0.5; + + Vector userVec = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .maxDistance(userMaxDist) + .build(); + + Vector imageVec = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .maxDistance(imageMaxDist) + .build(); + + MultiVectorQuery query = + MultiVectorQuery.builder() + .vectors(userVec, imageVec) + .returnFields("title", "distance_0", "distance_1", "score_0", "score_1") + .numResults(10) + .build(); + + List> results = searchIndex.query(query); + + // Every returned document should have distances within the thresholds + for (Map result : results) { + Object dist0 = result.get("distance_0"); + Object dist1 = result.get("distance_1"); + if (dist0 != null) { + assertThat(Double.parseDouble(dist0.toString())) + .as("distance_0 for %s", result.get("title")) + .isLessThanOrEqualTo(userMaxDist); + } + if (dist1 != null) { + assertThat(Double.parseDouble(dist1.toString())) + .as("distance_1 for %s", result.get("title")) + .isLessThanOrEqualTo(imageMaxDist); + } + } + } + + @ParameterizedTest(name = "max_distance({0}, {1}) should return results") + @CsvSource({ + "2.0, 2.0", // widest - should return all + "0.5, 0.5", // moderate + "0.1, 0.1", // tight + }) + @Order(5) + @DisplayName("Parametrized: tighter thresholds should return fewer or equal results") + void testTighterThresholdsReturnFewerResults(double maxDist1, double maxDist2) { + Vector userVec = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .maxDistance(maxDist1) + .build(); + + Vector imageVec = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .maxDistance(maxDist2) + .build(); + + MultiVectorQuery query = + MultiVectorQuery.builder() + .vectors(userVec, imageVec) + .returnFields("title", "distance_0", "distance_1") + .numResults(10) + .build(); + + List> results = searchIndex.query(query); + + // Results count should be non-negative + assertThat(results.size()).isGreaterThanOrEqualTo(0); + + // All returned results should have distances within thresholds + for (Map result : results) { + Object dist0 = result.get("distance_0"); + Object dist1 = result.get("distance_1"); + if (dist0 != null) { + assertThat(Double.parseDouble(dist0.toString())).isLessThanOrEqualTo(maxDist1); + } + if (dist1 != null) { + assertThat(Double.parseDouble(dist1.toString())).isLessThanOrEqualTo(maxDist2); + } + } + } + + @Test + @Order(6) + @DisplayName("Monotonicity: wider thresholds return >= results than narrower") + void testMonotonicity() { + // Narrow thresholds + Vector userNarrow = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .maxDistance(0.1) + .build(); + Vector imageNarrow = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .maxDistance(0.1) + .build(); + + MultiVectorQuery narrowQuery = + MultiVectorQuery.builder() + .vectors(userNarrow, imageNarrow) + .returnFields("title") + .numResults(10) + .build(); + + // Wide thresholds + Vector userWide = + Vector.builder() + .vector(new float[] {0.1f, 0.2f, 0.5f}) + .fieldName("user_embedding") + .maxDistance(1.5) + .build(); + Vector imageWide = + Vector.builder() + .vector(new float[] {1.2f, 0.3f, -0.4f, 0.7f, 0.2f}) + .fieldName("image_embedding") + .maxDistance(1.5) + .build(); + + MultiVectorQuery wideQuery = + MultiVectorQuery.builder() + .vectors(userWide, imageWide) + .returnFields("title") + .numResults(10) + .build(); + + List> narrowResults = searchIndex.query(narrowQuery); + List> wideResults = searchIndex.query(wideQuery); + + assertThat(wideResults.size()) + .as("Wider thresholds should return >= results than narrower") + .isGreaterThanOrEqualTo(narrowResults.size()); + } +} diff --git a/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceTest.java b/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceTest.java new file mode 100644 index 0000000..ccf6796 --- /dev/null +++ b/core/src/test/java/com/redis/vl/query/MultiVectorQueryMaxDistanceTest.java @@ -0,0 +1,162 @@ +package com.redis.vl.query; + +import static org.assertj.core.api.Assertions.*; + +import org.junit.jupiter.api.DisplayName; +import org.junit.jupiter.api.Test; + +/** + * Unit tests for configurable max_distance in MultiVectorQuery and Vector. + * + *

Ported from Python: commit 61f0f41 - Expose max_distance as an optional setting in multi + * vector queries + * + *

Python reference: tests/unit/test_aggregation_types.py + */ +@DisplayName("MultiVectorQuery max_distance Tests") +class MultiVectorQueryMaxDistanceTest { + + private static final float[] SAMPLE_VECTOR = {0.1f, 0.2f, 0.3f}; + private static final float[] SAMPLE_VECTOR_2 = {0.4f, 0.5f}; + + @Test + @DisplayName("Vector: Should default max_distance to 2.0") + void testVectorDefaultMaxDistance() { + Vector vector = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").build(); + + assertThat(vector.getMaxDistance()).isEqualTo(2.0); + } + + @Test + @DisplayName("Vector: Should accept custom max_distance") + void testVectorCustomMaxDistance() { + Vector vector = + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").maxDistance(0.5).build(); + + assertThat(vector.getMaxDistance()).isEqualTo(0.5); + } + + @Test + @DisplayName("Vector: Should reject max_distance below 0.0") + void testVectorMaxDistanceBelowZero() { + assertThatThrownBy( + () -> + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").maxDistance(-0.1).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("max_distance must be a value between 0.0 and 2.0"); + } + + @Test + @DisplayName("Vector: Should reject max_distance above 2.0") + void testVectorMaxDistanceAboveTwo() { + assertThatThrownBy( + () -> + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").maxDistance(2.1).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("max_distance must be a value between 0.0 and 2.0"); + } + + @Test + @DisplayName("Vector: Should accept max_distance at boundary 0.0") + void testVectorMaxDistanceAtZero() { + Vector vector = + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").maxDistance(0.0).build(); + + assertThat(vector.getMaxDistance()).isEqualTo(0.0); + } + + @Test + @DisplayName("Vector: Should accept max_distance at boundary 2.0") + void testVectorMaxDistanceAtTwo() { + Vector vector = + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field").maxDistance(2.0).build(); + + assertThat(vector.getMaxDistance()).isEqualTo(2.0); + } + + @Test + @DisplayName("MultiVectorQuery: Should use per-vector max_distance in query string") + void testQueryStringUsesPerVectorMaxDistance() { + Vector vector1 = + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").maxDistance(0.5).build(); + + Vector vector2 = + Vector.builder().vector(SAMPLE_VECTOR_2).fieldName("field_2").maxDistance(1.0).build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(vector1, vector2).build(); + + String queryString = query.toQueryString(); + + // Each vector should use its own max_distance + assertThat(queryString).contains("@field_1:[VECTOR_RANGE 0.5 $vector_0]"); + assertThat(queryString).contains("@field_2:[VECTOR_RANGE 1.0 $vector_1]"); + } + + @Test + @DisplayName("MultiVectorQuery: Should default to 2.0 when max_distance not set") + void testQueryStringDefaultMaxDistance() { + Vector vector1 = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vector(vector1).build(); + + String queryString = query.toQueryString(); + + assertThat(queryString).contains("VECTOR_RANGE 2.0"); + } + + @Test + @DisplayName("MultiVectorQuery: Should use AND join instead of pipe") + void testQueryStringUsesAndJoin() { + Vector vector1 = + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").maxDistance(0.5).build(); + + Vector vector2 = + Vector.builder().vector(SAMPLE_VECTOR_2).fieldName("field_2").maxDistance(1.0).build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(vector1, vector2).build(); + + String queryString = query.toQueryString(); + + // Python changed from " | " to " AND " in commit 61f0f41 + assertThat(queryString).doesNotContain(" | "); + assertThat(queryString).contains(" AND "); + } + + @Test + @DisplayName("MultiVectorQuery: Query string with mixed default and custom max_distance") + void testQueryStringMixedMaxDistance() { + Vector vector1 = Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").build(); // default + + Vector vector2 = + Vector.builder() + .vector(SAMPLE_VECTOR_2) + .fieldName("field_2") + .maxDistance(0.3) + .build(); // custom + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(vector1, vector2).build(); + + String queryString = query.toQueryString(); + + assertThat(queryString).contains("@field_1:[VECTOR_RANGE 2.0 $vector_0]"); + assertThat(queryString).contains("@field_2:[VECTOR_RANGE 0.3 $vector_1]"); + } + + @Test + @DisplayName("MultiVectorQuery: Should preserve full precision in query string") + void testQueryStringPreservesFullPrecision() { + Vector vector1 = + Vector.builder().vector(SAMPLE_VECTOR).fieldName("field_1").maxDistance(0.01).build(); + + Vector vector2 = + Vector.builder().vector(SAMPLE_VECTOR_2).fieldName("field_2").maxDistance(0.05).build(); + + MultiVectorQuery query = MultiVectorQuery.builder().vectors(vector1, vector2).build(); + + String queryString = query.toQueryString(); + + // These would have been truncated to 0.0 and 0.1 with %.1f format + assertThat(queryString).contains("VECTOR_RANGE 0.01"); + assertThat(queryString).contains("VECTOR_RANGE 0.05"); + } +} diff --git a/core/src/test/java/com/redis/vl/query/MultiVectorQueryTest.java b/core/src/test/java/com/redis/vl/query/MultiVectorQueryTest.java index 6e5ffbb..add1004 100644 --- a/core/src/test/java/com/redis/vl/query/MultiVectorQueryTest.java +++ b/core/src/test/java/com/redis/vl/query/MultiVectorQueryTest.java @@ -257,7 +257,7 @@ void testMultiVectorQueryString() { .contains("{$YIELD_DISTANCE_AS: distance_0}") .contains(String.format("@%s:[VECTOR_RANGE 2.0 $vector_1]", field2)) .contains("{$YIELD_DISTANCE_AS: distance_1}") - .contains(" | "); + .contains(" AND "); } @Test