Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,11 @@ public void update(String key, String field, Object value) {
+ "Delete and re-create the entry instead.");
}

@Override
public String toString() {
return String.format("LangCacheSemanticCache(name='%s', ttl=%s)", name, ttl);
}

/** Builder for LangCacheSemanticCache. */
public static class Builder {
private String name = "langcache";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,12 @@ public void resetStatistics() {
missCount.set(0);
}

@Override
public String toString() {
return String.format(
"SemanticCache(name='%s', distance_threshold=%s, ttl=%s)", name, distanceThreshold, ttl);
}

/** Builder for SemanticCache. */
public static class Builder {
private String name;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,14 @@
import static com.redis.vl.extensions.Constants.*;

import com.github.f4b6a3.ulid.UlidCreator;
import com.redis.vl.index.SearchIndex;
import com.redis.vl.query.CountQuery;
import com.redis.vl.query.Filter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

/**
* Base class for message history implementations.
Expand All @@ -16,8 +19,7 @@
*/
public abstract class BaseMessageHistory {

/** Valid role values for message filtering. */
private static final Set<String> VALID_ROLES = Set.of("system", "user", "llm", "tool");
private static final Logger logger = Logger.getLogger(BaseMessageHistory.class.getName());

protected final String name;
protected final String sessionTag;
Expand All @@ -34,6 +36,48 @@ protected BaseMessageHistory(String name, String sessionTag) {
this.sessionTag = (sessionTag != null) ? sessionTag : UlidCreator.getUlid().toString();
}

/**
* Get the underlying search index.
*
* @return The search index used by this message history
*/
protected abstract SearchIndex getSearchIndex();

/**
* Get the default session filter.
*
* @return The filter for the default session tag
*/
protected abstract Filter getDefaultSessionFilter();

/**
* Count the number of messages in the conversation history.
*
* <p>Matches Python count() from base_history.py
*
* @return The number of messages for the default session
*/
public long count() {
return count(null);
}

/**
* Count the number of messages in the conversation history for a specific session.
*
* <p>Matches Python count(session_tag=...) from base_history.py
*
* @param sessionTag The session tag to count messages for (null uses default session)
* @return The number of messages
*/
public long count(String sessionTag) {
Filter sessionFilter =
(sessionTag != null)
? Filter.tag(SESSION_FIELD_NAME, sessionTag)
: getDefaultSessionFilter();
CountQuery query = new CountQuery(sessionFilter);
return getSearchIndex().count(query);
}

/** Clears the chat message history. */
public abstract void clear();

Expand Down Expand Up @@ -143,10 +187,7 @@ protected List<String> validateRoles(Object role) {
// Handle single role string
if (role instanceof String) {
String roleStr = (String) role;
if (!VALID_ROLES.contains(roleStr)) {
throw new IllegalArgumentException(
String.format("Invalid role '%s'. Valid roles are: %s", roleStr, VALID_ROLES));
}
validateSingleRole(roleStr);
return List.of(roleStr);
}

Expand All @@ -166,10 +207,7 @@ protected List<String> validateRoles(Object role) {
"role list must contain only strings, found: " + r.getClass().getSimpleName());
}
String roleStr = (String) r;
if (!VALID_ROLES.contains(roleStr)) {
throw new IllegalArgumentException(
String.format("Invalid role '%s'. Valid roles are: %s", roleStr, VALID_ROLES));
}
validateSingleRole(roleStr);
validatedRoles.add(roleStr);
}

Expand All @@ -179,6 +217,28 @@ protected List<String> validateRoles(Object role) {
throw new IllegalArgumentException("role must be a String, List<String>, or null");
}

/**
* Validate a single role string using ChatRole enum.
*
* @param roleStr The role string to validate
* @throws IllegalArgumentException if the role is not valid
*/
private void validateSingleRole(String roleStr) {
if (ChatRole.isDeprecatedRole(roleStr)) {
logger.warning(
String.format(
"Role '%s' is a deprecated value. Update to valid roles: %s.",
roleStr, java.util.Arrays.toString(ChatRole.values())));
return;
}
if (!ChatRole.isValidRole(roleStr)) {
throw new IllegalArgumentException(
String.format(
"Invalid role '%s'. Valid roles are: %s",
roleStr, java.util.Arrays.toString(ChatRole.values())));
}
}

public String getName() {
return name;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package com.redis.vl.extensions.messagehistory;

import java.util.Set;

/**
* Enumeration of valid chat message roles.
*
* <p>Ported from Python: redisvl/extensions/message_history/schema.py (commit 23ecc77)
*
* <p>This enum serves as the single source of truth for valid roles in chat message history. The
* deprecated "llm" role is accepted for backward compatibility but is not a member of this enum.
*/
public enum ChatRole {
USER("user"),
ASSISTANT("assistant"),
SYSTEM("system"),
TOOL("tool");

/** Deprecated role values that are accepted for backward compatibility. */
private static final Set<String> DEPRECATED_ROLES = Set.of("llm");

private final String value;

ChatRole(String value) {
this.value = value;
}

/**
* Get the string value of this role.
*
* @return The role string value
*/
public String getValue() {
return value;
}

/**
* Coerce a string to a ChatRole enum value.
*
* @param role The role string
* @return The matching ChatRole, or null if not a standard role
*/
public static ChatRole fromString(String role) {
if (role == null || role.isEmpty()) {
return null;
}
for (ChatRole chatRole : values()) {
if (chatRole.value.equals(role)) {
return chatRole;
}
}
return null;
}

/**
* Check if a string is a valid role (including deprecated roles).
*
* @param role The role string to check
* @return true if the role is valid (either standard or deprecated)
*/
public static boolean isValidRole(String role) {
return fromString(role) != null || isDeprecatedRole(role);
}

/**
* Check if a string is a deprecated role value.
*
* @param role The role string to check
* @return true if the role is deprecated
*/
public static boolean isDeprecatedRole(String role) {
return DEPRECATED_ROLES.contains(role);
}

@Override
public String toString() {
return value;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,16 @@ public MessageHistory(String name, String sessionTag, String prefix, UnifiedJedi
this.defaultSessionFilter = Filter.tag(SESSION_FIELD_NAME, this.sessionTag);
}

@Override
protected SearchIndex getSearchIndex() {
return index;
}

@Override
protected Filter getDefaultSessionFilter() {
return defaultSessionFilter;
}

@Override
public void clear() {
index.clear();
Expand Down Expand Up @@ -269,4 +279,9 @@ public void addMessages(List<Map<String, String>> messages) {
public SearchIndex getIndex() {
return index;
}

@Override
public String toString() {
return String.format("MessageHistory(name='%s', session_tag='%s')", name, sessionTag);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,16 @@ public SearchIndex getIndex() {
return index;
}

@Override
protected SearchIndex getSearchIndex() {
return index;
}

@Override
protected Filter getDefaultSessionFilter() {
return defaultSessionFilter;
}

@Override
public void clear() {
index.clear();
Expand Down Expand Up @@ -490,6 +500,13 @@ private Filter combineWithRoleFilter(Filter sessionFilter, List<String> roles) {
}
}

@Override
public String toString() {
return String.format(
"SemanticMessageHistory(name='%s', session_tag='%s', distance_threshold=%s)",
name, sessionTag, distanceThreshold);
}

/**
* Format messages with metadata deserialization support.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,11 @@ public SemanticRouter build() {
}
}

@Override
public String toString() {
return String.format("SemanticRouter(name='%s', routes=%d)", name, routes.size());
}

/**
* Get the list of route names. Ported from Python: route_names property (line 187)
*
Expand Down
7 changes: 7 additions & 0 deletions core/src/main/java/com/redis/vl/index/SearchIndex.java
Original file line number Diff line number Diff line change
Expand Up @@ -1726,6 +1726,13 @@ private List<Map<String, Object>> processHybridResult(HybridResult result) {
return processed;
}

@Override
public String toString() {
return String.format(
"SearchIndex(name='%s', prefix='%s', storage_type='%s')",
getName(), getPrefix(), getStorageType());
}

/**
* Execute multiple search queries in batch
*
Expand Down
17 changes: 7 additions & 10 deletions core/src/main/java/com/redis/vl/query/MultiVectorQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,6 @@
@Getter
public final class MultiVectorQuery extends AggregationQuery {

/** Distance threshold for VECTOR_RANGE (hardcoded at 2.0 to include all eligible documents) */
private static final double DISTANCE_THRESHOLD = 2.0;

private final List<Vector> vectors;
private final Filter filterExpression;
private final List<String> returnFields;
Expand Down Expand Up @@ -111,8 +108,8 @@ public static Builder builder() {
/**
* Build the Redis query string for multi-vector search.
*
* <p>Format: {@code @field1:[VECTOR_RANGE 2.0 $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} |
* @field2:[VECTOR_RANGE 2.0 $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}}
* <p>Format: {@code @field1:[VECTOR_RANGE max_dist $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} AND
* @field2:[VECTOR_RANGE max_dist $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}}
*
* @return Query string
*/
Expand All @@ -123,8 +120,8 @@ public String toQueryString() {
/**
* Build the Redis query string for multi-vector search.
*
* <p>Format: {@code @field1:[VECTOR_RANGE 2.0 $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} |
* @field2:[VECTOR_RANGE 2.0 $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}}
* <p>Format: {@code @field1:[VECTOR_RANGE max_dist $vector_0]=>{$YIELD_DISTANCE_AS: distance_0} AND
* @field2:[VECTOR_RANGE max_dist $vector_1]=>{$YIELD_DISTANCE_AS: distance_1}}
*
* @return Query string
*/
Expand All @@ -136,12 +133,12 @@ public String buildQueryString() {
Vector v = vectors.get(i);
String rangeQuery =
String.format(
"@%s:[VECTOR_RANGE %.1f $vector_%d]=>{$YIELD_DISTANCE_AS: distance_%d}",
v.getFieldName(), DISTANCE_THRESHOLD, i, i);
"@%s:[VECTOR_RANGE %s $vector_%d]=>{$YIELD_DISTANCE_AS: distance_%d}",
v.getFieldName(), v.getMaxDistance(), i, i);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Scientific notation in query string for small maxDistance

Medium Severity

Using %s to format the double maxDistance into the Redis query string calls Double.toString(), which produces scientific notation (e.g., "5.0E-4") for values below 0.001. Since the validation allows any value in [0.0, 2.0], a user setting maxDistance(0.0005) would generate a VECTOR_RANGE 5.0E-4 query fragment that Redis likely cannot parse, causing a query failure. The previous code used %.1f which always produced plain decimal output. A format like BigDecimal.valueOf(...).toPlainString() would preserve precision without risking scientific notation.

Fix in Cursor Fix in Web

rangeQueries.add(rangeQuery);
}

String baseQuery = String.join(" | ", rangeQueries);
String baseQuery = String.join(" AND ", rangeQueries);

// Add filter expression if present
if (filterExpression != null) {
Expand Down
Loading
Loading