Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,9 @@
package io.github.jbellis.jvector.example.util;

import io.github.jbellis.jvector.graph.SearchResult;

import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.Set;

/**
* Computes accuracy metrics, such as recall and mean average precision.
Expand All @@ -41,43 +39,44 @@ public static double recallFromSearchResults(List<? extends List<Integer>> gt, L
if (gt.size() != retrieved.size()) {
throw new IllegalArgumentException("Insufficient ground truth for the number of retrieved elements");
}
Long correctCount = IntStream.range(0, gt.size())
.mapToObj(i -> topKCorrect(gt.get(i), retrieved.get(i), kGT, kRetrieved))
.reduce(0L, Long::sum);

long correctCount = 0;
for (int i = 0; i < gt.size(); i++) {
correctCount += topKCorrect(gt.get(i), retrieved.get(i), kGT, kRetrieved);
}

return (double) correctCount / (kGT * gt.size());
}

private static long topKCorrect(List<Integer> gt, List<Integer> retrieved, int kGT, int kRetrieved) {
private static long topKCorrect(List<Integer> gt, SearchResult retrieved, int kGT, int kRetrieved) {
// Exception validation
var nodes = retrieved.getNodes();
if (kGT > kRetrieved) {
throw new IllegalArgumentException("kGT: " + kGT + " > kRetrieved: " + kRetrieved);
}
if (kGT > gt.size()) {
throw new IllegalArgumentException("kGT: " + kGT + " > Gt size: " + gt.size());
}
if (kRetrieved > retrieved.size()) {
throw new IllegalArgumentException("kRetrieved: " + kRetrieved + " > retrieved size: " + retrieved.size());
if (kRetrieved > nodes.length) {
throw new IllegalArgumentException("kRetrieved: " + kRetrieved + " > retrieved size: " + nodes.length);
}

var gtView = crop(gt, kGT);
var retrievedView = crop(retrieved, kRetrieved);

if (gtView.size() > retrieved.size()) {
return gtView.stream().filter(retrievedView::contains).count();
} else {
return retrievedView.stream().filter(gtView::contains).count();
// Build HashSet with explicit capacity to avoid rehashing.
// Load factor is 0.75, so sized to kGT / 0.75.
Set<Integer> gtSet = new HashSet<>((int) (kGT / 0.75f) + 1);
for (int i = 0; i < kGT; i++) {
gtSet.add(gt.get(i));
}
}

private static long topKCorrect(List<Integer> gt, SearchResult retrieved, int kGT, int kRetrieved) {
var temp = Arrays.stream(retrieved.getNodes()).mapToInt(nodeScore -> nodeScore.node)
.boxed()
.collect(Collectors.toList());
return topKCorrect(gt, temp, kGT, kRetrieved);
}
// Manual primitive loop for speed (no Stream setup).
int hits = 0;
for (int i = 0; i < kRetrieved; i++) {
if (gtSet.contains(nodes[i].node)) {
hits++;
Comment on lines +74 to +75
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering if we should handle dupes in retrieved like we do in case of averagePrecisionAtK. That said, the earlier version of this function didn't handle duplicates either.

}
}

private static List<Integer> crop(List<Integer> list, int k) {
int count = Math.min(list.size(), k);
return list.subList(0, count);
return hits;
}

/**
Expand All @@ -89,33 +88,34 @@ private static List<Integer> crop(List<Integer> list, int k) {
* @return the average precision
*/
public static double averagePrecisionAtK(List<Integer> gt, SearchResult retrieved, int k) {
var retrievedTemp = Arrays.stream(retrieved.getNodes()).mapToInt(nodeScore -> nodeScore.node)
.boxed()
.collect(Collectors.toList());

var nodes = retrieved.getNodes();
if (k > gt.size()) {
throw new IllegalArgumentException("k: " + k + " > Gt size: " + gt.size());
}
if (k > retrievedTemp.size()) {
throw new IllegalArgumentException("k: " + k + " > retrieved size: " + retrievedTemp.size());
if (k > nodes.length) {
throw new IllegalArgumentException("k: " + k + " > retrieved size: " + nodes.length);
}

var gtView = crop(gt, k);
var retrievedView = crop(retrievedTemp, k);
// Sized hashset used for performance.
Set<Integer> gtSet = new HashSet<>((int) (k / 0.75f) + 1);
for (int i = 0; i < k; i++) {
gtSet.add(gt.get(i));
}

double score = 0.;
int num_hits = 0;
int i = 0;
// Handles potential duplicates in O(1).
Set<Integer> seen = new HashSet<>((int) (k / 0.75f) + 1);

for (var p : retrievedView) {
if (gtView.contains(p) && !retrievedView.subList(0, i).contains(p)) {
num_hits += 1;
score += num_hits / (i + 1.0);
double score = 0.;
int hits = 0;
for (int i = 0; i < k; i++) {
int p = nodes[i].node;
if (gtSet.contains(p) && seen.add(p)) {
hits++;
score += (double) hits / (i + 1);
}
i++;
}

return score / gtView.size();
return score / k;
}

/**
Expand All @@ -130,10 +130,12 @@ public static double meanAveragePrecisionAtK(List<? extends List<Integer>> gt, L
if (gt.size() != retrieved.size()) {
throw new IllegalArgumentException("Insufficient ground truth for the number of retrieved elements");
}
Double apk = IntStream.range(0, gt.size())
.mapToObj(i -> averagePrecisionAtK(gt.get(i), retrieved.get(i), k))
.reduce(0., Double::sum);
return apk / gt.size();
}

double totalAp = 0;
for (int i = 0; i < gt.size(); i++) {
totalAp += averagePrecisionAtK(gt.get(i), retrieved.get(i), k);
}

return totalAp / gt.size();
}
}
Loading