diff --git a/jvector-examples/pom.xml b/jvector-examples/pom.xml
index ae8c77d6d..452bbc62d 100644
--- a/jvector-examples/pom.xml
+++ b/jvector-examples/pom.xml
@@ -147,6 +147,11 @@
7.3.0
test
+
+ io.nosqlbench
+ datatools-vectordata
+ 0.1.22
+
org.junit.jupiter
junit-jupiter-api
diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoader.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoader.java
index d280fbf91..dd42dd42b 100644
--- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoader.java
+++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoader.java
@@ -38,4 +38,6 @@ public interface DataSetLoader {
* @return a {@link DataSet}, if found
*/
Optional loadDataSet(String dataSetName);
+
+ String getName();
}
diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoaderHDF5.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoaderHDF5.java
index 072a9b764..01893df1d 100644
--- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoaderHDF5.java
+++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoaderHDF5.java
@@ -41,17 +41,63 @@
* This dataset loader will get and load hdf5 files from ann-benchmarks.
*/
public class DataSetLoaderHDF5 implements DataSetLoader {
+ private static final org.slf4j.Logger logger = org.slf4j.LoggerFactory.getLogger(DataSetLoaderHDF5.class);
public static final Path HDF5_DIR = Path.of("hdf5");
private static final VectorTypeSupport vectorTypeSupport = VectorizationProvider.getInstance().getVectorTypeSupport();
public static final String HDF5_EXTN = ".hdf5";
+ public static final String NAME = "HDF5";
+ public String getName() {
+ return NAME;
+ }
+
+ private static final java.util.Set KNOWN_DATASETS = java.util.Set.of(
+ "deep-image-96-angular",
+ "fashion-mnist-784-euclidean",
+ "gist-960-euclidean",
+ "glove-25-angular",
+ "glove-50-angular",
+ "glove-100-angular",
+ "glove-200-angular",
+ "kosarak-jaccard",
+ "mnist-784-euclidean",
+ "movielens10m-jaccard",
+ "nytimes-256-angular",
+ "sift-128-euclidean",
+ "lastfm-64-dot",
+ "coco-i2i-512-angular",
+ "coco-t2i-512-angular"
+ );
+
+
/**
* {@inheritDoc}
*/
public Optional loadDataSet(String datasetName) {
+
+ // HDF5 loader does not support profiles
+ if (datasetName.contains(":")) {
+ logger.trace("Dataset '{}' has a profile, which is not supported by the HDF5 loader.", datasetName);
+ return Optional.empty();
+ }
+
+ // If not local, only download if it's explicitly known to be on ann-benchmarks.com
+ if (!KNOWN_DATASETS.contains(datasetName)) {
+ logger.trace("Dataset '{}' not in known list, skipping HDF5 download.", datasetName);
+ return Optional.empty();
+ }
+
+ // If it exists locally, we're good
+ var dsFilePath = HDF5_DIR.resolve(datasetName + HDF5_EXTN);
+ if (Files.exists(dsFilePath)) {
+ logger.trace("Dataset '{}' already downloaded.", datasetName);
+ return Optional.of(readHdf5Data(dsFilePath));
+ }
+
return maybeDownloadHdf5(datasetName).map(this::readHdf5Data);
}
+
private DataSet readHdf5Data(Path path) {
// infer the similarity
@@ -114,16 +160,12 @@ else if (filename.toString().contains("-euclidean")) {
}
private Optional maybeDownloadHdf5(String datasetName) {
-
- var dsFilePath = HDF5_DIR.resolve(datasetName+HDF5_EXTN);
-
- if (Files.exists(dsFilePath)) {
- return Optional.of(dsFilePath);
- }
+ var dsFilePath = HDF5_DIR.resolve(datasetName + HDF5_EXTN);
// Download from https://ann-benchmarks.com/datasetName
var url = "https://ann-benchmarks.com/" + datasetName + HDF5_EXTN;
- System.out.println("Downloading: " + url);
+ logger.info("Downloading: {}", url);
+
HttpURLConnection connection;
while (true) {
@@ -139,7 +181,7 @@ private Optional maybeDownloadHdf5(String datasetName) {
}
if (responseCode == HttpURLConnection.HTTP_MOVED_PERM || responseCode == HttpURLConnection.HTTP_MOVED_TEMP) {
String newUrl = connection.getHeaderField("Location");
- System.out.println("Redirect detected to URL: " + newUrl);
+ logger.info("Redirect detected to URL: {}", newUrl);
url = newUrl;
} else {
break;
diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoaderMFD.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoaderMFD.java
index 7381f0c35..74cc65ea5 100644
--- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoaderMFD.java
+++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoaderMFD.java
@@ -51,10 +51,21 @@ public class DataSetLoaderMFD implements DataSetLoader {
private static final String bucketName = "astra-vector";
private static final List bucketNames = List.of(bucketName, infraBucketName);
+ public static final String NAME = "MFD";
+ public String getName() {
+ return NAME;
+ }
+
/**
* {@inheritDoc}
*/
public Optional loadDataSet(String fileName) {
+
+ if (fileName.contains(":")) {
+ logger.trace("Dataset {} with profile is not supported by MFD loader", fileName);
+ return Optional.empty();
+ }
+
return maybeDownloadFvecs(fileName).map(MultiFileDatasource::load);
}
diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoaderVectordata.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoaderVectordata.java
new file mode 100644
index 000000000..18cb19909
--- /dev/null
+++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSetLoaderVectordata.java
@@ -0,0 +1,292 @@
+/*
+ * Copyright DataStax, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package io.github.jbellis.jvector.example.benchmarks.datasets;
+
+import io.github.jbellis.jvector.graph.RandomAccessVectorValues;
+import io.github.jbellis.jvector.vector.VectorSimilarityFunction;
+import io.github.jbellis.jvector.vector.VectorizationProvider;
+import io.github.jbellis.jvector.vector.types.VectorFloat;
+import io.github.jbellis.jvector.vector.types.VectorTypeSupport;
+import io.nosqlbench.nbdatatools.api.concurrent.ProgressIndicatingFuture;
+import io.nosqlbench.vectordata.VectorTestData;
+import io.nosqlbench.vectordata.discovery.ProfileSelector;
+import io.nosqlbench.vectordata.discovery.TestDataGroup;
+import io.nosqlbench.vectordata.discovery.TestDataSources;
+import io.nosqlbench.vectordata.discovery.vector.VectorTestDataView;
+import io.nosqlbench.vectordata.downloader.Catalog;
+import io.nosqlbench.vectordata.downloader.DatasetEntry;
+import io.nosqlbench.vectordata.downloader.DatasetProfileSpec;
+import io.nosqlbench.vectordata.spec.datasets.types.DistanceFunction;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.IOException;
+import java.util.AbstractList;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.CompletableFuture;
+import java.util.function.LongFunction;
+import java.util.stream.Collectors;
+
+import static io.github.jbellis.jvector.vector.VectorSimilarityFunction.COSINE;
+
+/**
+ * A DataSetLoader that uses the io.nosqlbench.datatools-vectordata library to load datasets.
+ */
+public class DataSetLoaderVectordata implements DataSetLoader {
+ private static final Logger logger = LoggerFactory.getLogger(DataSetLoaderVectordata.class);
+ private static final VectorTypeSupport vts = VectorizationProvider.getInstance().getVectorTypeSupport();
+ private final boolean prebuffer;
+
+ public DataSetLoaderVectordata() {
+ this(false);
+ }
+
+ public DataSetLoaderVectordata(boolean prebuffer) {
+ this.prebuffer = prebuffer;
+ }
+
+ @Override
+ public String getName() {
+ return "VECTORDATA";
+ }
+
+ @Override
+ public Optional loadDataSet(String dataSetName) {
+ try {
+ logger.info("Attempting to load dataset '{}' via vectordata", dataSetName);
+
+ DatasetProfileSpec spec = DatasetProfileSpec.parse(dataSetName);
+
+ TestDataSources tds1 = VectorTestData.catalogs();
+ TestDataSources tds2 = tds1.configure();
+ Catalog c1 = tds2.catalog();
+ Optional entryOpt = c1.findExact(spec.dataset());
+
+ VectorTestDataView view;
+ if (entryOpt.isPresent()) {
+ DatasetEntry entry = entryOpt.get();
+ logger.info("Found dataset '{}' in catalog. URL: {}", spec.dataset(), entry.url());
+ ProfileSelector selector = entry.select();
+ view = spec.profile().map(selector::profile).orElseGet(selector::profile);
+ } else {
+ // Fallback to local load
+ logger.debug("Dataset '{}' not found in catalog, attempting local load", spec.dataset());
+ try {
+ TestDataGroup group = VectorTestData.load(spec.dataset());
+ view = spec.profile().map(group::profile).orElseGet(group::getDefaultProfile);
+ } catch (Exception e) {
+ logger.debug("Local load failed for '{}'", spec.dataset());
+ return Optional.empty();
+ }
+ }
+
+ if (view == null) {
+ return Optional.empty();
+ }
+
+ if (prebuffer) {
+ logger.info("Prebuffering dataset '{}'...", dataSetName);
+ CompletableFuture f = view.prebuffer();
+ if (f instanceof ProgressIndicatingFuture) {
+ System.out.println("blocking until prebuffer completes, with progress reporting...");
+ ((ProgressIndicatingFuture>) f).monitorProgress(System.out, 5000);
+ } else {
+ System.out.println("blocking until prebuffer completes...");
+ }
+ f.get();
+ // Block until data is ready/cached
+ }
+
+ return Optional.of(new VectordataDataSet(dataSetName, mapDistanceFunction(view.getDistanceFunction()), view));
+ } catch (Exception e) {
+ logger.error("Error loading dataset '{}' via vectordata", dataSetName, e);
+ return Optional.empty();
+ }
+ }
+
+ private static class VectordataRavv implements RandomAccessVectorValues {
+ private final int dimension;
+ private final int size;
+ private final LongFunction getter;
+
+ public VectordataRavv(int dimension, int size, LongFunction getter) {
+ this.dimension = dimension;
+ this.size = size;
+ this.getter = getter;
+ }
+
+ @Override
+ public int size() {
+ return size;
+ }
+
+ @Override
+ public int dimension() {
+ return dimension;
+ }
+
+ @Override
+ public VectorFloat> getVector(int nodeId) {
+ return vts.createFloatVector(getter.apply((long) nodeId));
+ }
+
+ @Override
+ public boolean isValueShared() {
+ return false;
+ }
+
+ @Override
+ public RandomAccessVectorValues copy() {
+ return this;
+ }
+ }
+
+ private static class VectordataDataSet implements DataSet {
+ private final String name;
+ private final VectorSimilarityFunction vsf;
+ private final int dimension;
+ private final RandomAccessVectorValues baseRavv;
+ private final List> baseVectors;
+ private final List> queryVectors;
+ private final List extends List> groundTruth;
+
+ public VectordataDataSet(String name, VectorSimilarityFunction vsf, VectorTestDataView view) {
+ this.name = name;
+ this.vsf = vsf;
+
+ var bv = view.getBaseVectors().orElseThrow(() -> new RuntimeException("Base vectors missing in dataset " + name));
+ int bSize = (int) bv.getCount();
+ float[] firstVector = bv.get(0L);
+ this.dimension = firstVector.length;
+
+ this.baseRavv = new VectordataRavv(dimension, bSize, bv::get);
+ this.baseVectors = new AbstractList>() {
+ @Override
+ public VectorFloat> get(int index) {
+ return vts.createFloatVector(bv.get((long) index));
+ }
+
+ @Override
+ public int size() {
+ return bSize;
+ }
+ };
+
+ var qv = view.getQueryVectors().orElseThrow(() -> new RuntimeException("Query vectors missing in dataset " + name));
+ int qSize = (int) qv.getCount();
+ this.queryVectors = new AbstractList>() {
+ @Override
+ public VectorFloat> get(int index) {
+ return vts.createFloatVector(qv.get((long) index));
+ }
+
+ @Override
+ public int size() {
+ return qSize;
+ }
+ };
+
+ var niOpt = view.getNeighborIndices();
+ if (niOpt.isPresent()) {
+ var ni = niOpt.get();
+ int niSize = (int) ni.getCount();
+ this.groundTruth = new AbstractList>() {
+ @Override
+ public List get(int index) {
+ int[] indices = ni.get((long) index);
+ return new AbstractList() {
+ @Override
+ public Integer get(int i) {
+ return indices[i];
+ }
+
+ @Override
+ public int size() {
+ return indices.length;
+ }
+ };
+ }
+
+ @Override
+ public int size() {
+ return niSize;
+ }
+ };
+ } else {
+ logger.warn("Ground truth missing in dataset {}, recall metrics will not be available", name);
+ this.groundTruth = Collections.nCopies(queryVectors.size(), Collections.emptyList());
+ }
+
+ System.out.format("%n%s: %d base and %d query vectors loaded via Vectordata, dimensions %d%n",
+ name, baseVectors.size(), queryVectors.size(), dimension);
+ }
+
+ @Override
+ public int getDimension() {
+ return dimension;
+ }
+
+ @Override
+ public RandomAccessVectorValues getBaseRavv() {
+ return baseRavv;
+ }
+
+ @Override
+ public String getName() {
+ return name;
+ }
+
+ @Override
+ public VectorSimilarityFunction getSimilarityFunction() {
+ return vsf;
+ }
+
+ @Override
+ public List> getBaseVectors() {
+ return baseVectors;
+ }
+
+ @Override
+ public List> getQueryVectors() {
+ return queryVectors;
+ }
+
+ @Override
+ public List extends List> getGroundTruth() {
+ return groundTruth;
+ }
+ }
+
+ private VectorSimilarityFunction mapDistanceFunction(DistanceFunction df) {
+ if (df == null) return COSINE;
+ switch (df) {
+ case COSINE:
+ return COSINE;
+ case DOT_PRODUCT:
+ return VectorSimilarityFunction.DOT_PRODUCT;
+ case EUCLIDEAN:
+ case L2:
+ return VectorSimilarityFunction.EUCLIDEAN;
+ default:
+ logger.warn("Unknown distance function {}, defaulting to COSINE", df);
+ return COSINE;
+ }
+ }
+}
diff --git a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSets.java b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSets.java
index da27c8f2c..28c9412fc 100644
--- a/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSets.java
+++ b/jvector-examples/src/main/java/io/github/jbellis/jvector/example/benchmarks/datasets/DataSets.java
@@ -20,10 +20,8 @@
import org.slf4j.LoggerFactory;
import java.security.InvalidParameterException;
-import java.util.ArrayList;
-import java.util.Collection;
-import java.util.List;
-import java.util.Optional;
+import java.util.*;
+import java.util.stream.Collectors;
public class DataSets {
private static final Logger logger = LoggerFactory.getLogger(DataSets.class);
@@ -31,6 +29,7 @@ public class DataSets {
public static final List defaultLoaders = new ArrayList<>() {{
add(new DataSetLoaderHDF5());
add(new DataSetLoaderMFD());
+ add(new DataSetLoaderVectordata(true));
}};
public static Optional loadDataSet(String dataSetName) {
@@ -44,14 +43,14 @@ public static Optional loadDataSet(String dataSetName, Collection dataSetLoaded = loader.loadDataSet(dataSetName);
if (dataSetLoaded.isPresent()) {
- logger.info("dataset [{}] found with loader [{}]", dataSetName, loader.getClass().getSimpleName());
+ logger.info("dataset [{}] found with loader [{}]", dataSetName, loader.getName());
return dataSetLoaded;
}
}
- logger.warn("Unable to find dataset [{}] with any dataset loader.", dataSetName);
+ logger.warn("Unable to find dataset [{}] with any dataset loader. Loaders tried:{}", dataSetName, loaders.stream().map(DataSetLoader::getName).collect(Collectors.joining(",")));
return Optional.empty();
}
}