diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java index ba537ed06..6a0ba0f52 100644 --- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java +++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/util/AccuracyMetrics.java @@ -17,11 +17,9 @@ package io.github.jbellis.jvector.example.util; import io.github.jbellis.jvector.graph.SearchResult; - -import java.util.Arrays; +import java.util.HashSet; import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.IntStream; +import java.util.Set; /** * Computes accuracy metrics, such as recall and mean average precision. @@ -41,43 +39,44 @@ public static double recallFromSearchResults(List> gt, L if (gt.size() != retrieved.size()) { throw new IllegalArgumentException("Insufficient ground truth for the number of retrieved elements"); } - Long correctCount = IntStream.range(0, gt.size()) - .mapToObj(i -> topKCorrect(gt.get(i), retrieved.get(i), kGT, kRetrieved)) - .reduce(0L, Long::sum); + + long correctCount = 0; + for (int i = 0; i < gt.size(); i++) { + correctCount += topKCorrect(gt.get(i), retrieved.get(i), kGT, kRetrieved); + } + return (double) correctCount / (kGT * gt.size()); } - private static long topKCorrect(List gt, List retrieved, int kGT, int kRetrieved) { + private static long topKCorrect(List gt, SearchResult retrieved, int kGT, int kRetrieved) { + // Exception validation + var nodes = retrieved.getNodes(); if (kGT > kRetrieved) { throw new IllegalArgumentException("kGT: " + kGT + " > kRetrieved: " + kRetrieved); } if (kGT > gt.size()) { throw new IllegalArgumentException("kGT: " + kGT + " > Gt size: " + gt.size()); } - if (kRetrieved > retrieved.size()) { - throw new IllegalArgumentException("kRetrieved: " + kRetrieved + " > retrieved size: " + retrieved.size()); + if (kRetrieved > nodes.length) { + throw new IllegalArgumentException("kRetrieved: " + kRetrieved + " > retrieved size: " + nodes.length); } - var gtView = crop(gt, kGT); - var retrievedView = crop(retrieved, kRetrieved); - - if (gtView.size() > retrieved.size()) { - return gtView.stream().filter(retrievedView::contains).count(); - } else { - return retrievedView.stream().filter(gtView::contains).count(); + // Build HashSet with explicit capacity to avoid rehashing. + // Load factor is 0.75, so sized to kGT / 0.75. + Set gtSet = new HashSet<>((int) (kGT / 0.75f) + 1); + for (int i = 0; i < kGT; i++) { + gtSet.add(gt.get(i)); } - } - private static long topKCorrect(List gt, SearchResult retrieved, int kGT, int kRetrieved) { - var temp = Arrays.stream(retrieved.getNodes()).mapToInt(nodeScore -> nodeScore.node) - .boxed() - .collect(Collectors.toList()); - return topKCorrect(gt, temp, kGT, kRetrieved); - } + // Manual primitive loop for speed (no Stream setup). + int hits = 0; + for (int i = 0; i < kRetrieved; i++) { + if (gtSet.contains(nodes[i].node)) { + hits++; + } + } - private static List crop(List list, int k) { - int count = Math.min(list.size(), k); - return list.subList(0, count); + return hits; } /** @@ -89,33 +88,34 @@ private static List crop(List list, int k) { * @return the average precision */ public static double averagePrecisionAtK(List gt, SearchResult retrieved, int k) { - var retrievedTemp = Arrays.stream(retrieved.getNodes()).mapToInt(nodeScore -> nodeScore.node) - .boxed() - .collect(Collectors.toList()); - + var nodes = retrieved.getNodes(); if (k > gt.size()) { throw new IllegalArgumentException("k: " + k + " > Gt size: " + gt.size()); } - if (k > retrievedTemp.size()) { - throw new IllegalArgumentException("k: " + k + " > retrieved size: " + retrievedTemp.size()); + if (k > nodes.length) { + throw new IllegalArgumentException("k: " + k + " > retrieved size: " + nodes.length); } - var gtView = crop(gt, k); - var retrievedView = crop(retrievedTemp, k); + // Sized hashset used for performance. + Set gtSet = new HashSet<>((int) (k / 0.75f) + 1); + for (int i = 0; i < k; i++) { + gtSet.add(gt.get(i)); + } - double score = 0.; - int num_hits = 0; - int i = 0; + // Handles potential duplicates in O(1). + Set seen = new HashSet<>((int) (k / 0.75f) + 1); - for (var p : retrievedView) { - if (gtView.contains(p) && !retrievedView.subList(0, i).contains(p)) { - num_hits += 1; - score += num_hits / (i + 1.0); + double score = 0.; + int hits = 0; + for (int i = 0; i < k; i++) { + int p = nodes[i].node; + if (gtSet.contains(p) && seen.add(p)) { + hits++; + score += (double) hits / (i + 1); } - i++; } - return score / gtView.size(); + return score / k; } /** @@ -130,10 +130,12 @@ public static double meanAveragePrecisionAtK(List> gt, L if (gt.size() != retrieved.size()) { throw new IllegalArgumentException("Insufficient ground truth for the number of retrieved elements"); } - Double apk = IntStream.range(0, gt.size()) - .mapToObj(i -> averagePrecisionAtK(gt.get(i), retrieved.get(i), k)) - .reduce(0., Double::sum); - return apk / gt.size(); - } + double totalAp = 0; + for (int i = 0; i < gt.size(); i++) { + totalAp += averagePrecisionAtK(gt.get(i), retrieved.get(i), k); + } + + return totalAp / gt.size(); + } }