diff --git a/README.md b/README.md index c4ed90814..bc39bd5a3 100644 --- a/README.md +++ b/README.md @@ -11,7 +11,7 @@ There are two broad categories of ANN index: Graph-based indexes tend to be simpler to implement and faster, but more importantly they can be constructed and updated incrementally. This makes them a much better fit for a general-purpose index than partitioning approaches that only work on static datasets that are completely specified up front. That is why all the major commercial vector indexes use graph approaches. JVector is a graph index that merges the DiskANN and HNSW family trees. -JVector borrows the hierarchical structure from HNSW, and uses Vamana (the algorithm behind DiskANN) within each layer. +JVector borrows the hierarchical structure from HNSW, and uses Vamana (the algorithm behind DiskANN) within each layer. ## JVector Architecture diff --git a/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQDistanceCalculationMutableVectorBenchmark.java b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQDistanceCalculationMutableVectorBenchmark.java new file mode 100644 index 000000000..f3def2cab --- /dev/null +++ b/benchmarks-jmh/src/main/java/io/github/jbellis/jvector/bench/PQDistanceCalculationMutableVectorBenchmark.java @@ -0,0 +1,144 @@ +/* + * Copyright DataStax, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.github.jbellis.jvector.bench; + +import io.github.jbellis.jvector.graph.ListRandomAccessVectorValues; +import io.github.jbellis.jvector.graph.RandomAccessVectorValues; +import io.github.jbellis.jvector.graph.similarity.BuildScoreProvider; +import io.github.jbellis.jvector.graph.similarity.ScoreFunction; +import io.github.jbellis.jvector.graph.similarity.SearchScoreProvider; +import io.github.jbellis.jvector.quantization.MutablePQVectors; +import io.github.jbellis.jvector.quantization.PQVectors; +import io.github.jbellis.jvector.quantization.ProductQuantization; +import io.github.jbellis.jvector.vector.VectorSimilarityFunction; +import io.github.jbellis.jvector.vector.VectorizationProvider; +import io.github.jbellis.jvector.vector.types.VectorFloat; +import io.github.jbellis.jvector.vector.types.VectorTypeSupport; +import org.openjdk.jmh.annotations.*; +import org.openjdk.jmh.infra.Blackhole; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; + +/** + * Benchmark that compares the distance calculation of mutable Product Quantized vectors vs full precision vectors. 
+ */ +@BenchmarkMode(Mode.AverageTime) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Thread) +@Fork(value = 1, jvmArgsAppend = {"--add-modules=jdk.incubator.vector", "--enable-preview", "-Djvector.experimental.enable_native_vectorization=false"}) +@Warmup(iterations = 2) +@Measurement(iterations = 3) +@Threads(1) +public class PQDistanceCalculationMutableVectorBenchmark { + private static final Logger log = LoggerFactory.getLogger(PQDistanceCalculationMutableVectorBenchmark.class); + private static final VectorTypeSupport VECTOR_TYPE_SUPPORT = VectorizationProvider.getInstance().getVectorTypeSupport(); + + private List> vectors; + private PQVectors pqVectors; + private List> queryVectors; + private ProductQuantization pq; + private BuildScoreProvider buildScoreProvider; + + @Param({"1536"}) + private int dimension; + + @Param({"10000"}) + private int vectorCount; + + @Param({"100"}) + private int queryCount; + + @Param({ "16","32", "64","96", "192"}) + private int M; // Number of subspaces for PQ + + @Param + private VectorSimilarityFunction vsf; + + @Setup + public void setup() throws IOException { + log.info("Creating dataset with dimension: {}, vector count: {}, query count: {}", dimension, vectorCount, queryCount); + + // Create random vectors + vectors = new ArrayList<>(vectorCount); + for (int i = 0; i < vectorCount; i++) { + vectors.add(createRandomVector(dimension)); + } + + // Create query vectors + queryVectors = new ArrayList<>(queryCount); + for (int i = 0; i < queryCount; i++) { + queryVectors.add(createRandomVector(dimension)); + } + + RandomAccessVectorValues ravv = new ListRandomAccessVectorValues(vectors, dimension); + // Create Mutable PQ vectors + pq = ProductQuantization.compute(ravv, M, 256, true); + pqVectors = new MutablePQVectors(pq); + // build the index vector-at-a-time (on disk) + for (int ordinal = 0; ordinal < vectors.size(); ordinal++) + { + VectorFloat v = vectors.get(ordinal); + // compress the new vector and add it to the 
PQVectors + ((MutablePQVectors)pqVectors).encodeAndSet(ordinal, v); + } + buildScoreProvider = BuildScoreProvider.pqBuildScoreProvider(vsf, pqVectors); + log.info("Created dataset with dimension: {}, vector count: {}, query count: {}", dimension, vectorCount, queryCount); + } + + @Benchmark + public void scoreCalculation(Blackhole blackhole) { + float totalSimilarity = 0; + + for (VectorFloat query : queryVectors) { + + ScoreFunction.ApproximateScoreFunction asf = pqVectors.scoreFunctionFor(query, vsf); + for (int i = 0; i < vectorCount; i++) { + float similarity = asf.similarityTo(i); + totalSimilarity += similarity; + } + } + + blackhole.consume(totalSimilarity); + } + + @Benchmark + public void diversityCalculation(Blackhole blackhole) { + float totalSimilarity = 0; + + for (int q = 0; q < queryCount; q++) { + for (int i = 0; i < vectorCount; i++) { + final ScoreFunction sf = buildScoreProvider.diversityProviderFor(i).scoreFunction(); + float similarity = sf.similarityTo(q); + totalSimilarity += similarity; + } + } + + blackhole.consume(totalSimilarity); + } + + private VectorFloat createRandomVector(int dimension) { + VectorFloat vector = VECTOR_TYPE_SUPPORT.createFloatVector(dimension); + for (int i = 0; i < dimension; i++) { + vector.set(i, (float) Math.random()); + } + return vector; + } +} diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java index 538632da0..9bd0a0295 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/quantization/PQVectors.java @@ -228,33 +228,16 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat q, var encodedChunk = getChunk(node2); var encodedOffset = getOffsetInChunk(node2); // compute the dot product of the query and the codebook centroids corresponding to the encoded points - float dp 
= 0; - for (int m = 0; m < subspaceCount; m++) { - int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); - int centroidLength = pq.subvectorSizesAndOffsets[m][0]; - int centroidOffset = pq.subvectorSizesAndOffsets[m][1]; - dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength); - } + float dp = VectorUtil.pqScoreDotProduct(pq.codebooks, pq.subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); // scale to [0, 1] return (1 + dp) / 2; }; case COSINE: - float norm1 = VectorUtil.dotProduct(centeredQuery, centeredQuery); return (node2) -> { var encodedChunk = getChunk(node2); var encodedOffset = getOffsetInChunk(node2); - // compute the dot product of the query and the codebook centroids corresponding to the encoded points - float sum = 0; - float norm2 = 0; - for (int m = 0; m < subspaceCount; m++) { - int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); - int centroidLength = pq.subvectorSizesAndOffsets[m][0]; - int centroidOffset = pq.subvectorSizesAndOffsets[m][1]; - var codebookOffset = centroidIndex * centroidLength; - sum += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, centeredQuery, centroidOffset, centroidLength); - norm2 += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], codebookOffset, centroidLength); - } - float cosine = sum / (float) Math.sqrt(norm1 * norm2); + // compute the cosine of the query and the codebook centroids corresponding to the encoded points + float cosine = VectorUtil.pqScoreCosine(pq.codebooks, pq.subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); // scale to [0, 1] return (1 + cosine) / 2; }; @@ -263,13 +246,7 @@ public ScoreFunction.ApproximateScoreFunction scoreFunctionFor(VectorFloat q, var encodedChunk = getChunk(node2); var encodedOffset = getOffsetInChunk(node2); // compute the euclidean distance between the query and 
the codebook centroids corresponding to the encoded points - float sum = 0; - for (int m = 0; m < subspaceCount; m++) { - int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); - int centroidLength = pq.subvectorSizesAndOffsets[m][0]; - int centroidOffset = pq.subvectorSizesAndOffsets[m][1]; - sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength); - } + float sum = VectorUtil.pqScoreEuclidean(pq.codebooks, pq.subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); // scale to [0, 1] return 1 / (1 + sum); }; @@ -290,40 +267,16 @@ public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, Ve var node2Chunk = getChunk(node2); var node2Offset = getOffsetInChunk(node2); // compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points - float dp = 0; - for (int m = 0; m < subspaceCount; m++) { - int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); - int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); - int centroidLength = pq.subvectorSizesAndOffsets[m][0]; - dp += VectorUtil.dotProduct(pq.codebooks[m], centroidIndex1 * centroidLength, pq.codebooks[m], centroidIndex2 * centroidLength, centroidLength); - } + float dp = VectorUtil.pqScoreDotProduct(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); // scale to [0, 1] return (1 + dp) / 2; }; case COSINE: - float norm1 = 0.0f; - for (int m1 = 0; m1 < subspaceCount; m1++) { - int centroidIndex = Byte.toUnsignedInt(node1Chunk.get(m1 + node1Offset)); - int centroidLength = pq.subvectorSizesAndOffsets[m1][0]; - var codebookOffset = centroidIndex * centroidLength; - norm1 += VectorUtil.dotProduct(pq.codebooks[m1], codebookOffset, pq.codebooks[m1], codebookOffset, centroidLength); - } - final float norm1final = norm1; return (node2) -> { 
var node2Chunk = getChunk(node2); var node2Offset = getOffsetInChunk(node2); // compute the dot product of the query and the codebook centroids corresponding to the encoded points - float sum = 0; - float norm2 = 0; - for (int m = 0; m < subspaceCount; m++) { - int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); - int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); - int centroidLength = pq.subvectorSizesAndOffsets[m][0]; - int codebookOffset = centroidIndex2 * centroidLength; - sum += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], centroidIndex1 * centroidLength, centroidLength); - norm2 += VectorUtil.dotProduct(pq.codebooks[m], codebookOffset, pq.codebooks[m], codebookOffset, centroidLength); - } - float cosine = sum / (float) Math.sqrt(norm1final * norm2); + float cosine = VectorUtil.pqScoreCosine(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); // scale to [0, 1] return (1 + cosine) / 2; }; @@ -332,13 +285,7 @@ public ScoreFunction.ApproximateScoreFunction diversityFunctionFor(int node1, Ve var node2Chunk = getChunk(node2); var node2Offset = getOffsetInChunk(node2); // compute the euclidean distance between the query and the codebook centroids corresponding to the encoded points - float sum = 0; - for (int m = 0; m < subspaceCount; m++) { - int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); - int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); - int centroidLength = pq.subvectorSizesAndOffsets[m][0]; - sum += VectorUtil.squareL2Distance(pq.codebooks[m], centroidIndex1 * centroidLength, pq.codebooks[m], centroidIndex2 * centroidLength, centroidLength); - } + float sum = VectorUtil.pqScoreEuclidean(pq.codebooks, pq.subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); // scale to [0, 1] return 1 / (1 + sum); }; diff --git 
a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java index 5843dc5f6..44eae8130 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/DefaultVectorUtilSupport.java @@ -547,4 +547,90 @@ public float nvqUniformLoss(VectorFloat vector, float minValue, float maxValu return squaredSum; } + @Override + public float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float dp = 0; + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + dp += dotProduct(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength); + } + return dp; +} + + + @Override + public float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float sum = 0; + float aMagnitude = 0; + float bMagnitude = 0; + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + sum += dotProduct(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength); + aMagnitude += dotProduct(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex1 * centroidLength, centroidLength); + bMagnitude += dotProduct(codebooks[m], 
centroidIndex2 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength); + } + return (float)(sum / Math.sqrt(aMagnitude * bMagnitude)); + } + + @Override + public float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float sum = 0; + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + + sum += squareDistance(codebooks[m], centroidIndex1 * centroidLength, codebooks[m], centroidIndex2 * centroidLength, centroidLength); + } + return sum; + + } + + @Override + public float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float dp = 0; + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int centroidOffset = subvectorSizesAndOffsets[m][1]; + dp += dotProduct(codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength); + } + return dp; + } + + @Override + public float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float sum = 0; + float aMagnitude = 0; + float bMagnitude = 0; + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int centroidOffset = subvectorSizesAndOffsets[m][1]; + var codebookOffset = centroidIndex * centroidLength; + sum += dotProduct(codebooks[m], codebookOffset, 
centeredQuery, centroidOffset, centroidLength); + aMagnitude += dotProduct(codebooks[m], codebookOffset, codebooks[m], codebookOffset, centroidLength); + bMagnitude += dotProduct(centeredQuery, centroidOffset, centeredQuery, centroidOffset, centroidLength); + } + float cosine = sum / (float) Math.sqrt(aMagnitude * bMagnitude); + return cosine; + } + + @Override + public float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float sum = 0; + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int centroidOffset = subvectorSizesAndOffsets[m][1]; + sum += squareDistance(codebooks[m], centroidIndex * centroidLength, centeredQuery, centroidOffset, centroidLength); + } + return sum; + } + } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java index 83cb5885b..48189dcd0 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtil.java @@ -235,4 +235,29 @@ public static float nvqLoss(VectorFloat vector, float growthRate, float midpo public static float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits) { return impl.nvqUniformLoss(vector, minValue, maxValue, nBits); } + + public static float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + return impl.pqScoreDotProduct(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + + public static float pqScoreCosine(VectorFloat[] codebooks, int[][] 
subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + return impl.pqScoreCosine(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + + public static float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + return impl.pqScoreEuclidean(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + + public static float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + return impl.pqScoreDotProduct(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + + public static float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + return impl.pqScoreCosine(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + + public static float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + return impl.pqScoreEuclidean(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + } diff --git a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java index d8223ab12..6de2ed445 100644 --- a/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java +++ b/jvector-base/src/main/java/io/github/jbellis/jvector/vector/VectorUtilSupport.java @@ -235,4 +235,78 
@@ default float pqDecodedCosineSimilarity(ByteSequence encoded, int encodedOffs */ float nvqUniformLoss(VectorFloat vector, float minValue, float maxValue, int nBits); + /** + * Calculates the dotproduct for an array of codebooks, uses diversityFunction. + * @param codebooks array of codebooks + * @param subvectorSizesAndOffsets contains dimensions and size of codebooks + * @param node1Chunk centroid vector for node1's subvectors + * @param node1Offset offset into ByteSequence of node1 + * @param node2Chunk centroid vector for node2's subvectors + * @param node2Offset offset into ByteSequence of node2 + * @param subspaceCount the number of PQ subspaces + * @return the dot product + */ + float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount); + + /** + * Calculates cosine for an array of codebooks, uses diversityFunction. + * @param codebooks array of codebooks + * @param subvectorSizesAndOffsets contains dimensions and size of codebooks + * @param node1Chunk centroid vector for node1's subvectors + * @param node1Offset offset into ByteSequence of node1 + * @param node2Chunk centroid vector for node2's subvectors + * @param node2Offset offset into ByteSequence of node2 + * @param subspaceCount the number of PQ subspaces + * @return the cosine value + */ + float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount); + + /** + * Calculates the Euclidean distance for an array of codebooks, uses diversityFunction. 
+ * @param codebooks array of codebooks + * @param subvectorSizesAndOffsets contains dimensions and size of codebooks + * @param node1Chunk centroid vector for node1's subvectors + * @param node1Offset offset into ByteSequence of node1 + * @param node2Chunk centroid vector for node2's subvectors + * @param node2Offset offset into ByteSequence of node2 + * @param subspaceCount the number of PQ subspaces + * @return the Euclidean distance + */ + float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount); + + /** + * Overloaded function to calculate the dotproduct for an array of codebooks, uses scoreFunction. + * @param codebooks array of codebooks + * @param subvectorSizesAndOffsets contains dimensions and size of codebooks + * @param encodedChunk centroid vector for encoded point + * @param encodedOffset offset into ByteSequence of encoded vector + * @param centeredQuery query + * @param subspaceCount the number of PQ subspaces + * @return the dotproduct + */ + float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount); + + /** + * Overloaded function to calculate cosine for an array of codebooks, uses scoreFunction. 
+ * @param codebooks array of codebooks + * @param subvectorSizesAndOffsets contains dimensions and size of codebooks + * @param encodedChunk centroid vector for encoded point + * @param encodedOffset offset into ByteSequence of encoded vector + * @param centeredQuery query + * @param subspaceCount the number of PQ subspaces + * @return the cosine value + */ + float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery,int subspaceCount); + + /** + * Overloaded function to calculate the Euclidean distance for an array of codebooks, uses scoreFunction. + * @param codebooks array of codebooks + * @param subvectorSizesAndOffsets contains dimensions and size of codebooks + * @param encodedChunk centroid vector for encoded point + * @param encodedOffset offset into ByteSequence of encoded vector + * @param centeredQuery query + * @param subspaceCount the number of PQ subspaces + * @return the Euclidean distance + */ + float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount); } diff --git a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java index dc18c4bf2..2a65859ac 100644 --- a/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java +++ b/jvector-twenty/src/main/java/io/github/jbellis/jvector/vector/PanamaVectorUtilSupport.java @@ -1543,4 +1543,1102 @@ public void calculatePartialSums(VectorFloat codebook, int codebookIndex, int public float pqDecodedCosineSimilarity(ByteSequence encoded, int clusterCount, VectorFloat partialSums, VectorFloat aMagnitude, float bMagnitude) { return pqDecodedCosineSimilarity(encoded, 0, encoded.length(), clusterCount, partialSums, aMagnitude, bMagnitude); } + + 
/*-------------------- Score functions--------------------*/ + //adding SPECIES_64 & SPECIES_128 for completeness, will it get there? + /** + * Computes the Euclidean distance between two PQ-encoded vectors + */ + float pqScoreEuclideanPreferred(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreEuclidean256(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int 
subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreEuclidean128(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_128); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength); + int 
length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreEuclidean64(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_64); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += 
FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - codebooks[m].get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + /** + * Computes the Euclidean distance between two PQ-encoded vectors, automatically selecting the best vector species. + * @param codebooks Array of codebook vectors for each subspace. + * @param subvectorSizesAndOffsets Array containing sizes and offsets for each subvector. + * @param node1Chunk Byte sequence representing the first PQ-encoded vector. + * @param node1Offset Offset in the first vector. + * @param node2Chunk Byte sequence representing the second PQ-encoded vector. + * @param node2Offset Offset in the second vector. + * @param subspaceCount Number of subspaces. + * @return Euclidean distance between the two PQ vectors. 
+ */ + @Override + public float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + //Since centroid length can vary, picking the first entry in the array which is the largest one + if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_PREFERRED.length()) { + return pqScoreEuclideanPreferred(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + else if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_256.length()) { + return pqScoreEuclidean256( codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_128.length()) { + return pqScoreEuclidean128( codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + return pqScoreEuclidean64( codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + + /** + * Computes the Euclidean distance between a PQ-encoded vector and a centered query vector + */ + float pqScoreEuclideanPreferred(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int length1 = centroidIndex * centroidLength; + int length2 = subvectorSizesAndOffsets[m][1]; + final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength) ; + + if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], 
length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, centeredQuery, length2); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, centeredQuery, length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - centeredQuery.get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreEuclidean256(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int length1 = centroidIndex * centroidLength; + int length2 = subvectorSizesAndOffsets[m][1]; + final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength) ; + + if (centroidLength == FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, centeredQuery, length2); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, centeredQuery, length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + 
for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - centeredQuery.get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreEuclidean128(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_128); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int length1 = centroidIndex * centroidLength; + int length2 = subvectorSizesAndOffsets[m][1]; + final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength) ; + + if (centroidLength == FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, centeredQuery, length2); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, centeredQuery, length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - centeredQuery.get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreEuclidean64(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_64); + + for (int m = 0; m < subspaceCount; 
m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + int length1 = centroidIndex * centroidLength; + int length2 = subvectorSizesAndOffsets[m][1]; + final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength) ; + + if (centroidLength == FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, centeredQuery, length2); + var diff = a.sub(b); + sum = diff.fma(diff,sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, centeredQuery, length2 + i); + var diff = a.sub(b); + sum = diff.fma(diff, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + var diff = codebooks[m].get(length1 + i) - centeredQuery.get(length2 + i); + res += MathUtil.square(diff); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + /** + * Overloaded function which computes the Euclidean distance between PQ-encoded vector and centered query vector + * @param codebooks Array of codebook vectors for each subspace. + * @param subvectorSizesAndOffsets Array containing sizes and offsets for each subvector. + * @param encodedChunk Byte sequence representing the PQ-encoded vector. + * @param encodedOffset Offset in the encoded vector. + * @param centeredQuery Centered query vector. + * @param subspaceCount Number of subspaces. + * @return Euclidean distance between the PQ vector and the query. 
+ */ + @Override + public float pqScoreEuclidean(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + //Since centroid length can vary, picking the first entry in the array which is the largest one + if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_PREFERRED.length()) { + return pqScoreEuclideanPreferred(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + else if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_256.length()) { + return pqScoreEuclidean256(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_128.length()) { + return pqScoreEuclidean128(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + return pqScoreEuclidean64(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + + /** + * Computes the dot product score between two PQ-encoded vectors + */ + float pqScoreDotProductPreferred(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = 
fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2); + sum = a.fma(b, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2 + i); + sum = a.fma(b, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + res += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreDotProduct256(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2); + sum = a.fma(b, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2 + i); + sum = a.fma(b, sum); + } + // Process the tail + for (; i 
< centroidLength ; ++i) { + res += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreDotProduct128(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_128); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2); + sum = a.fma(b, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2 + i); + sum = a.fma(b, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + res += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreDotProduct64(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_64); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = 
Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2); + sum = a.fma(b, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2 + i); + sum = a.fma(b, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + res += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + /** + * Computes the dot product score between two PQ-encoded vectors + * @param codebooks Array of codebook vectors for each subspace. + * @param subvectorSizesAndOffsets Array containing sizes and offsets for each subvector. + * @param node1Chunk Byte sequence representing the first PQ-encoded vector. + * @param node1Offset Offset in the first vector. + * @param node2Chunk Byte sequence representing the second PQ-encoded vector. + * @param node2Offset Offset in the second vector. + * @param subspaceCount Number of subspaces. + * @return Dot product of the two PQ vectors. 
+ */ + @Override + public float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + //Since centroid length can vary, picking the first entry in the array which is the largest one + if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_PREFERRED.length()) { + return pqScoreDotProductPreferred(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + else if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_256.length()) { + return pqScoreDotProduct256(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_128.length()) { + return pqScoreDotProduct128(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + return pqScoreDotProduct64(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); + } + + /** + * Computes the dot product score between PQ-encoded vector and centered query vector + */ + float pqScoreDotProductPreferred(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength); + int length1 = centroidIndex * centroidLength; + int length2 = subvectorSizesAndOffsets[m][1]; + + if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], 
length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, centeredQuery, length2); + sum = a.fma(b, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, centeredQuery, length2 + i); + sum = a.fma(b, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + res += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreDotProduct256(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_256); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength); + int length1 = centroidIndex * centroidLength; + int length2 = subvectorSizesAndOffsets[m][1]; + + if (centroidLength == FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, centeredQuery, length2); + sum = a.fma(b, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, centeredQuery, length2 + i); + sum = a.fma(b, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + res += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i); + } + } + } + res += 
sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreDotProduct128(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_128); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength); + int length1 = centroidIndex * centroidLength; + int length2 = subvectorSizesAndOffsets[m][1]; + + if (centroidLength == FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, centeredQuery, length2); + sum = a.fma(b, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, centeredQuery, length2 + i); + sum = a.fma(b, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + res += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + float pqScoreDotProduct64(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + float res = 0; + FloatVector sum = FloatVector.zero(FloatVector.SPECIES_64); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength); + int length1 = centroidIndex * 
centroidLength; + int length2 = subvectorSizesAndOffsets[m][1]; + + if (centroidLength == FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, centeredQuery, length2); + sum = a.fma(b, sum); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_64.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, centeredQuery, length2 + i); + sum = a.fma(b, sum); + } + // Process the tail + for (; i < centroidLength ; ++i) { + res += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i); + } + } + } + res += sum.reduceLanes(VectorOperators.ADD); + return res; + } + + /** + * Overloaded function which computes the dot product between PQ-encoded vector and centered query vector + * @param codebooks Array of codebook vectors for each subspace. + * @param subvectorSizesAndOffsets Array containing sizes and offsets for each subvector. + * @param encodedChunk Byte sequence representing the PQ-encoded vector. + * @param encodedOffset Offset in the encoded vector. + * @param centeredQuery Centered query vector. + * @param subspaceCount Number of subspaces. + * @return Dot product between the PQ vector and the query. 
+ */ + @Override + public float pqScoreDotProduct(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) { + //Since centroid length can vary, picking the first entry in the array which is the largest one + if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_PREFERRED.length()) { + return pqScoreDotProductPreferred(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + else if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_256.length()) { + return pqScoreDotProduct256(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_128.length()) { + return pqScoreDotProduct128(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + return pqScoreDotProduct64(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); + } + + /** + * Computes the cosine similarity between two PQ-encoded vectors + */ + float pqScoreCosinePreferred(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float sum = 0; + float aMagnitude = 0; + float bMagnitude = 0 ; + FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_PREFERRED); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength); + int length1 = 
centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2); + vSum = a.fma(b, vSum); + vaMagnitude = a.fma(a, vaMagnitude); + vbMagnitude = b.fma(b, vbMagnitude); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length2 + i); + vSum = a.fma(b, vSum); + vaMagnitude = a.fma(a, vaMagnitude); + vbMagnitude = b.fma(b, vbMagnitude); + } + // Process the tail + for (; i < centroidLength ; ++i) { + sum += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i); + aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i); + bMagnitude += codebooks[m].get(length2 + i) * codebooks[m].get(length2 + i); + } + } + } + sum += vSum.reduceLanes(VectorOperators.ADD); + aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD); + bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD); + return (float)(sum / Math.sqrt(aMagnitude * bMagnitude)); + } + + float pqScoreCosine256(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float sum = 0; + float aMagnitude = 0; + float bMagnitude = 0 ; + FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_256); + FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256); + FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_256); + + for (int m = 0; m < subspaceCount; m++) { + int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); + int centroidIndex2 = 
Byte.toUnsignedInt(node2Chunk.get(m + node2Offset)); + int centroidLength = subvectorSizesAndOffsets[m][0]; + final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength); + int length1 = centroidIndex1 * centroidLength; + int length2 = centroidIndex2 * centroidLength; + + if (centroidLength == FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2); + vSum = a.fma(b, vSum); + vaMagnitude = a.fma(a, vaMagnitude); + vbMagnitude = b.fma(b, vbMagnitude); + } + else { + int i = 0; + for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) { + FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1 + i); + FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length2 + i); + vSum = a.fma(b, vSum); + vaMagnitude = a.fma(a, vaMagnitude); + vbMagnitude = b.fma(b, vbMagnitude); + } + // Process the tail + for (; i < centroidLength ; ++i) { + sum += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i); + aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i); + bMagnitude += codebooks[m].get(length2 + i) * codebooks[m].get(length2 + i); + } + } + } + sum += vSum.reduceLanes(VectorOperators.ADD); + aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD); + bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD); + return (float)(sum / Math.sqrt(aMagnitude * bMagnitude)); + } + + float pqScoreCosine128(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) { + float sum = 0; + float aMagnitude = 0; + float bMagnitude = 0 ; + FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_128); + FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_128); + FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_128); + 
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset));
+            int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
+            int centroidLength = subvectorSizesAndOffsets[m][0];
+            final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength);
+            int length1 = centroidIndex1 * centroidLength;
+            int length2 = centroidIndex2 * centroidLength;
+
+            if (centroidLength == FloatVector.SPECIES_128.length()) {
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2);
+                vSum = a.fma(b, vSum);
+                vaMagnitude = a.fma(a, vaMagnitude);
+                vbMagnitude = b.fma(b, vbMagnitude);
+            }
+            else {
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length2 + i);
+                    vSum = a.fma(b, vSum);
+                    vaMagnitude = a.fma(a, vaMagnitude);
+                    vbMagnitude = b.fma(b, vbMagnitude);
+                }
+                // Process the tail
+                for (; i < centroidLength ; ++i) {
+                    sum += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i);
+                    aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i);
+                    bMagnitude += codebooks[m].get(length2 + i) * codebooks[m].get(length2 + i);
+                }
+            }
+        }
+        sum += vSum.reduceLanes(VectorOperators.ADD);
+        aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD);
+        bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD);
+        return (float)(sum / Math.sqrt(aMagnitude * bMagnitude));
+    }
+    /**
+     * Cosine similarity between two PQ-encoded vectors, accumulated with {@code FloatVector.SPECIES_64}.
+     * For every subspace {@code m} the two centroids selected by the code bytes are read from
+     * {@code codebooks[m]}; their dot product and both squared magnitudes are accumulated, and the
+     * result is {@code dot / sqrt(|a|^2 * |b|^2)}.
+     * NOTE(review): {@code fromVectorFloat} is defined outside this view; presumably it loads one
+     * species-wide vector starting at the given offset - confirm.
+     * NOTE(review): there is no guard for a zero magnitude; the division then yields a non-finite value.
+     */
+    float pqScoreCosine64(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) {
+        float sum = 0;
+        float aMagnitude = 0;
+        float bMagnitude = 0 ;
+        FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_64);
+        FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_64);
+        FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_64);
+
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex1 = Byte.toUnsignedInt(node1Chunk.get(m + node1Offset)); // code byte read as unsigned (0..255)
+            int centroidIndex2 = Byte.toUnsignedInt(node2Chunk.get(m + node2Offset));
+            int centroidLength = subvectorSizesAndOffsets[m][0]; // column 0 is used as the centroid length of subspace m
+            final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength); // largest multiple of the lane count <= centroidLength
+            int length1 = centroidIndex1 * centroidLength; // start of the first centroid inside codebooks[m]
+            int length2 = centroidIndex2 * centroidLength; // start of the second centroid inside codebooks[m]
+
+            if (centroidLength == FloatVector.SPECIES_64.length()) { // centroid is exactly one vector wide: one load each, no loop
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2);
+                vSum = a.fma(b, vSum); // vSum += a * b, lane-wise
+                vaMagnitude = a.fma(a, vaMagnitude);
+                vbMagnitude = b.fma(b, vbMagnitude);
+            }
+            else { // general case: whole vectors first, then a scalar tail
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_64.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length2 + i);
+                    vSum = a.fma(b, vSum);
+                    vaMagnitude = a.fma(a, vaMagnitude);
+                    vbMagnitude = b.fma(b, vbMagnitude);
+                }
+                // Process the tail: elements past the last whole vector go into the scalar accumulators
+                for (; i < centroidLength ; ++i) {
+                    sum += codebooks[m].get(length1 + i) * codebooks[m].get(length2 + i);
+                    aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i);
+                    bMagnitude += codebooks[m].get(length2 + i) * codebooks[m].get(length2 + i);
+                }
+            }
+        }
+        sum += vSum.reduceLanes(VectorOperators.ADD); // fold the vector lanes into the scalar totals
+        aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD);
+        bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD);
+        return (float)(sum / Math.sqrt(aMagnitude * bMagnitude)); // float product is widened for sqrt; only the final value is narrowed
+    }
+
+    /**
+     * Computes the cosine similarity between two PQ-encoded vectors.
+     * Tries, in order, the SPECIES_PREFERRED, 256-bit and 128-bit implementations and uses the first
+     * whose lane count is not larger than {@code subvectorSizesAndOffsets[0][0]}; otherwise the
+     * 64-bit implementation is used. {@code subvectorSizesAndOffsets} must have at least one entry,
+     * because entry 0 is read unconditionally.
+     * @param codebooks Array of codebook vectors for each subspace.
+     * @param subvectorSizesAndOffsets Array containing sizes and offsets for each subvector.
+     * @param node1Chunk Byte sequence representing the first PQ-encoded vector.
+     * @param node1Offset Offset in the first vector.
+     * @param node2Chunk Byte sequence representing the second PQ-encoded vector.
+     * @param node2Offset Offset in the second vector.
+     * @param subspaceCount Number of subspaces.
+     * @return cosine similarity between two PQ vectors.
+     */
+    @Override
+    public float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence node1Chunk, int node1Offset, ByteSequence node2Chunk, int node2Offset, int subspaceCount) {
+        // Centroid length can vary per subspace; dispatch on the first entry, which is expected to be the largest one
+        if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_PREFERRED.length()) {
+            return pqScoreCosinePreferred(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
+        }
+        else if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_256.length()) {
+            return pqScoreCosine256(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
+        }
+        else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_128.length()) {
+            return pqScoreCosine128(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount);
+        }
+        return pqScoreCosine64(codebooks, subvectorSizesAndOffsets, node1Chunk, node1Offset, node2Chunk, node2Offset, subspaceCount); // narrowest fallback
+    }
+
+    /**
+     * Computes the cosine similarity between a PQ-encoded vector and a centered query vector,
+     * accumulated with {@code FloatVector.SPECIES_PREFERRED}. For each subspace {@code m} the centroid
+     * selected by the code byte is read from {@code codebooks[m]} and paired with the query elements
+     * that start at {@code subvectorSizesAndOffsets[m][1]}.
+     * NOTE(review): there is no guard for a zero magnitude; the division then yields a non-finite value.
+     */
+    float pqScoreCosinePreferred(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) {
+        float sum = 0;
+        float aMagnitude = 0;
+        float bMagnitude = 0 ;
+        FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
+        FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
+        FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_PREFERRED);
+
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); // code byte read as unsigned (0..255)
+            int centroidLength = subvectorSizesAndOffsets[m][0]; // column 0 is used as the centroid length of subspace m
+            final int vectorizedLength = FloatVector.SPECIES_PREFERRED.loopBound(centroidLength); // largest multiple of the lane count <= centroidLength
+            int length1 = centroidIndex * centroidLength; // start of the selected centroid inside codebooks[m]
+            int length2 = subvectorSizesAndOffsets[m][1]; // column 1 is used as the start of this subspace inside centeredQuery
+
+            if (centroidLength == FloatVector.SPECIES_PREFERRED.length()) { // centroid is exactly one vector wide: one load each, no loop
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, centeredQuery, length2);
+                vSum = a.fma(b, vSum); // vSum += a * b, lane-wise
+                vaMagnitude = a.fma(a, vaMagnitude);
+                vbMagnitude = b.fma(b, vbMagnitude);
+            }
+            else { // general case: whole vectors first, then a scalar tail
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_PREFERRED.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_PREFERRED, codebooks[m], length1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_PREFERRED, centeredQuery, length2 + i);
+                    vSum = a.fma(b, vSum);
+                    vaMagnitude = a.fma(a, vaMagnitude);
+                    vbMagnitude = b.fma(b, vbMagnitude);
+                }
+                // Process the tail: elements past the last whole vector go into the scalar accumulators
+                for (; i < centroidLength ; ++i) {
+                    sum += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i);
+                    aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i);
+                    bMagnitude += centeredQuery.get(length2 + i) * centeredQuery.get(length2 + i);
+                }
+            }
+        }
+        sum += vSum.reduceLanes(VectorOperators.ADD); // fold the vector lanes into the scalar totals
+        aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD);
+        bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD);
+        return (float)(sum / Math.sqrt(aMagnitude * bMagnitude)); // float product is widened for sqrt; only the final value is narrowed
+    }
+    /**
+     * Same computation as the SPECIES_PREFERRED query variant, accumulated with
+     * {@code FloatVector.SPECIES_256}: cosine similarity between a PQ-encoded vector and a centered
+     * query vector. A subspace whose centroid is shorter than the lane count is handled entirely by
+     * the scalar tail loop, because {@code loopBound} is then 0.
+     * NOTE(review): there is no guard for a zero magnitude; the division then yields a non-finite value.
+     */
+    float pqScoreCosine256(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) {
+        float sum = 0;
+        float aMagnitude = 0;
+        float bMagnitude = 0 ;
+        FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_256);
+        FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_256);
+        FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_256);
+
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); // code byte read as unsigned (0..255)
+            int centroidLength = subvectorSizesAndOffsets[m][0]; // column 0 is used as the centroid length of subspace m
+            final int vectorizedLength = FloatVector.SPECIES_256.loopBound(centroidLength); // largest multiple of the lane count <= centroidLength
+            int length1 = centroidIndex * centroidLength; // start of the selected centroid inside codebooks[m]
+            int length2 = subvectorSizesAndOffsets[m][1]; // column 1 is used as the start of this subspace inside centeredQuery
+
+            if (centroidLength == FloatVector.SPECIES_256.length()) { // centroid is exactly one vector wide: one load each, no loop
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, centeredQuery, length2);
+                vSum = a.fma(b, vSum); // vSum += a * b, lane-wise
+                vaMagnitude = a.fma(a, vaMagnitude);
+                vbMagnitude = b.fma(b, vbMagnitude);
+            }
+            else { // general case: whole vectors first, then a scalar tail
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_256.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_256, codebooks[m], length1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_256, centeredQuery, length2 + i);
+                    vSum = a.fma(b, vSum);
+                    vaMagnitude = a.fma(a, vaMagnitude);
+                    vbMagnitude = b.fma(b, vbMagnitude);
+                }
+                // Process the tail: elements past the last whole vector go into the scalar accumulators
+                for (; i < centroidLength ; ++i) {
+                    sum += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i);
+                    aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i);
+                    bMagnitude += centeredQuery.get(length2 + i) * centeredQuery.get(length2 + i);
+                }
+            }
+        }
+
+        sum += vSum.reduceLanes(VectorOperators.ADD); // fold the vector lanes into the scalar totals
+        aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD);
+        bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD);
+        return (float)(sum / Math.sqrt(aMagnitude * bMagnitude)); // float product is widened for sqrt; only the final value is narrowed
+    }
+    /**
+     * Same computation as the SPECIES_PREFERRED query variant, accumulated with
+     * {@code FloatVector.SPECIES_128}: cosine similarity between a PQ-encoded vector and a centered
+     * query vector. A subspace whose centroid is shorter than the lane count is handled entirely by
+     * the scalar tail loop, because {@code loopBound} is then 0.
+     * NOTE(review): there is no guard for a zero magnitude; the division then yields a non-finite value.
+     */
+    float pqScoreCosine128(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) {
+        float sum = 0;
+        float aMagnitude = 0;
+        float bMagnitude = 0 ;
+        FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_128);
+        FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_128);
+        FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_128);
+
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); // code byte read as unsigned (0..255)
+            int centroidLength = subvectorSizesAndOffsets[m][0]; // column 0 is used as the centroid length of subspace m
+            final int vectorizedLength = FloatVector.SPECIES_128.loopBound(centroidLength); // largest multiple of the lane count <= centroidLength
+            int length1 = centroidIndex * centroidLength; // start of the selected centroid inside codebooks[m]
+            int length2 = subvectorSizesAndOffsets[m][1]; // column 1 is used as the start of this subspace inside centeredQuery
+
+            if (centroidLength == FloatVector.SPECIES_128.length()) { // centroid is exactly one vector wide: one load each, no loop
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, centeredQuery, length2);
+                vSum = a.fma(b, vSum); // vSum += a * b, lane-wise
+                vaMagnitude = a.fma(a, vaMagnitude);
+                vbMagnitude = b.fma(b, vbMagnitude);
+            }
+            else { // general case: whole vectors first, then a scalar tail
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_128.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_128, codebooks[m], length1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_128, centeredQuery, length2 + i);
+                    vSum = a.fma(b, vSum);
+                    vaMagnitude = a.fma(a, vaMagnitude);
+                    vbMagnitude = b.fma(b, vbMagnitude);
+                }
+                // Process the tail: elements past the last whole vector go into the scalar accumulators
+                for (; i < centroidLength ; ++i) {
+                    sum += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i);
+                    aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i);
+                    bMagnitude += centeredQuery.get(length2 + i) * centeredQuery.get(length2 + i);
+                }
+            }
+        }
+        sum += vSum.reduceLanes(VectorOperators.ADD); // fold the vector lanes into the scalar totals
+        aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD);
+        bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD);
+        return (float)(sum / Math.sqrt(aMagnitude * bMagnitude)); // float product is widened for sqrt; only the final value is narrowed
+    }
+    /**
+     * Same computation as the SPECIES_PREFERRED query variant, accumulated with
+     * {@code FloatVector.SPECIES_64}: cosine similarity between a PQ-encoded vector and a centered
+     * query vector. This is the variant the dispatcher falls back to when no wider species applies.
+     * NOTE(review): there is no guard for a zero magnitude; the division then yields a non-finite value.
+     */
+    float pqScoreCosine64(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) {
+        float sum = 0;
+        float aMagnitude = 0;
+        float bMagnitude = 0 ;
+        FloatVector vSum = FloatVector.zero(FloatVector.SPECIES_64);
+        FloatVector vaMagnitude = FloatVector.zero(FloatVector.SPECIES_64);
+        FloatVector vbMagnitude = FloatVector.zero(FloatVector.SPECIES_64);
+
+        for (int m = 0; m < subspaceCount; m++) {
+            int centroidIndex = Byte.toUnsignedInt(encodedChunk.get(m + encodedOffset)); // code byte read as unsigned (0..255)
+            int centroidLength = subvectorSizesAndOffsets[m][0]; // column 0 is used as the centroid length of subspace m
+            final int vectorizedLength = FloatVector.SPECIES_64.loopBound(centroidLength); // largest multiple of the lane count <= centroidLength
+            int length1 = centroidIndex * centroidLength; // start of the selected centroid inside codebooks[m]
+            int length2 = subvectorSizesAndOffsets[m][1]; // column 1 is used as the start of this subspace inside centeredQuery
+
+            if (centroidLength == FloatVector.SPECIES_64.length()) { // centroid is exactly one vector wide: one load each, no loop
+                FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1);
+                FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, centeredQuery, length2);
+                vSum = a.fma(b, vSum); // vSum += a * b, lane-wise
+                vaMagnitude = a.fma(a, vaMagnitude);
+                vbMagnitude = b.fma(b, vbMagnitude);
+            }
+            else { // general case: whole vectors first, then a scalar tail
+                int i = 0;
+                for (; i < vectorizedLength; i += FloatVector.SPECIES_64.length()) {
+                    FloatVector a = fromVectorFloat(FloatVector.SPECIES_64, codebooks[m], length1 + i);
+                    FloatVector b = fromVectorFloat(FloatVector.SPECIES_64, centeredQuery, length2 + i);
+                    vSum = a.fma(b, vSum);
+                    vaMagnitude = a.fma(a, vaMagnitude);
+                    vbMagnitude = b.fma(b, vbMagnitude);
+                }
+                // Process the tail: elements past the last whole vector go into the scalar accumulators
+                for (; i < centroidLength ; ++i) {
+                    sum += codebooks[m].get(length1 + i) * centeredQuery.get(length2 + i);
+                    aMagnitude += codebooks[m].get(length1 + i) * codebooks[m].get(length1 + i);
+                    bMagnitude += centeredQuery.get(length2 + i) * centeredQuery.get(length2 + i);
+                }
+            }
+        }
+        sum += vSum.reduceLanes(VectorOperators.ADD); // fold the vector lanes into the scalar totals
+        aMagnitude += vaMagnitude.reduceLanes(VectorOperators.ADD);
+        bMagnitude += vbMagnitude.reduceLanes(VectorOperators.ADD);
+        return (float)(sum / Math.sqrt(aMagnitude * bMagnitude)); // float product is widened for sqrt; only the final value is narrowed
+    }
+
+    /**
+     * Overload that computes the cosine similarity between a PQ-encoded vector and a centered query vector.
+     * Tries, in order, the SPECIES_PREFERRED, 256-bit and 128-bit implementations and uses the first
+     * whose lane count is not larger than {@code subvectorSizesAndOffsets[0][0]}; otherwise the
+     * 64-bit implementation is used. {@code subvectorSizesAndOffsets} must have at least one entry,
+     * because entry 0 is read unconditionally.
+     * @param codebooks Array of codebook vectors for each subspace.
+     * @param subvectorSizesAndOffsets Array containing sizes and offsets for each subvector.
+     * @param encodedChunk Byte sequence representing the PQ-encoded vector.
+     * @param encodedOffset Offset in the encoded vector.
+     * @param centeredQuery Centered query vector.
+     * @param subspaceCount Number of subspaces.
+     * @return Cosine similarity between the PQ vector and the query.
+     */
+    @Override
+    public float pqScoreCosine(VectorFloat[] codebooks, int[][] subvectorSizesAndOffsets, ByteSequence encodedChunk, int encodedOffset, VectorFloat centeredQuery, int subspaceCount) {
+        // Centroid length can vary per subspace; dispatch on the first entry, which is expected to be the largest one
+        if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_PREFERRED.length()) {
+            return pqScoreCosinePreferred(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
+        }
+        else if(subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_256.length()) {
+            return pqScoreCosine256(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
+        }
+        else if (subvectorSizesAndOffsets[0][0] >= FloatVector.SPECIES_128.length()) {
+            return pqScoreCosine128(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount);
+        }
+        return pqScoreCosine64(codebooks, subvectorSizesAndOffsets, encodedChunk, encodedOffset, centeredQuery, subspaceCount); // narrowest fallback
+    }
 }
\ No newline at end of file