Skip to content

Commit c2825a1

Browse files
Create approx_nearest_neighbours.py
1 parent 7530a41 commit c2825a1

1 file changed

Lines changed: 119 additions & 0 deletions

File tree

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""
2+
Approximate Nearest Neighbor (ANN) Search
3+
https://en.wikipedia.org/wiki/Nearest_neighbor_search#Approximate_nearest_neighbor
4+
5+
ANN search finds "close enough" vectors instead of the exact nearest neighbor,
6+
which makes it much faster for large datasets.
7+
8+
This implementation uses a simple **random projection hashing** method.
9+
Steps:
10+
1. Generate random hyperplanes to hash vectors into buckets.
11+
2. Place dataset vectors into buckets.
12+
3. For a query vector, look into its bucket (and maybe nearby buckets).
13+
4. Return the approximate nearest neighbor from those candidates.
14+
15+
Each result contains:
16+
1. The nearest (approximate) vector.
17+
2. Its distance from the query vector.
18+
"""
19+
20+
from __future__ import annotations
21+
22+
import math
23+
from collections import defaultdict
24+
import numpy as np
25+
26+
27+
def euclidean(input_a: np.ndarray, input_b: np.ndarray) -> float:
28+
"""
29+
Calculates Euclidean distance between two vectors.
30+
>>> euclidean(np.array([0]), np.array([1]))
31+
1.0
32+
>>> euclidean(np.array([1, 2]), np.array([1, 5]))
33+
3.0
34+
"""
35+
return math.sqrt(sum(pow(a - b, 2) for a, b in zip(input_a, input_b)))
36+
37+
38+
class ANN:
39+
"""
40+
Approximate Nearest Neighbor using random projection hashing.
41+
"""
42+
43+
def __init__(self, dataset: np.ndarray, n_planes: int = 5, seed: int = 42) -> None:
44+
"""
45+
:param dataset: ndarray of shape (n_samples, n_features)
46+
:param n_planes: number of random hyperplanes for hashing
47+
:param seed: random seed for reproducibility
48+
"""
49+
self.dataset = dataset
50+
self.n_planes = n_planes
51+
rng = np.random.default_rng(seed)
52+
self.planes = rng.standard_normal((n_planes, dataset.shape[1]))
53+
self.buckets: dict[str, list[np.ndarray]] = defaultdict(list)
54+
self._build_index()
55+
56+
def _hash_vector(self, vec: np.ndarray) -> str:
57+
"""
58+
Hash a vector based on which side of each hyperplane it falls on.
59+
Returns a bit string.
60+
61+
>>> dataset = np.array([[1, 2]])
62+
>>> ann = ANN(dataset, n_planes=2, seed=0)
63+
>>> h = ann._hash_vector(np.array([1, 2]))
64+
>>> isinstance(h, str)
65+
True
66+
>>> len(h) == ann.n_planes
67+
True
68+
"""
69+
signs = (vec @ self.planes.T) >= 0
70+
return "".join(["1" if s else "0" for s in signs])
71+
72+
def _build_index(self) -> None:
73+
"""
74+
Build hash buckets for all dataset vectors.
75+
76+
>>> dataset = np.array([[0, 0], [1, 1]])
77+
>>> ann = ANN(dataset, n_planes=2, seed=0)
78+
>>> all(isinstance(k, str) for k in ann.buckets.keys())
79+
True
80+
>>> sum(len(v) for v in ann.buckets.values()) == len(dataset)
81+
True
82+
"""
83+
for vec in self.dataset:
84+
h = self._hash_vector(vec)
85+
self.buckets[h].append(vec)
86+
87+
def query(self, query_vectors: np.ndarray) -> list[list[list[float] | float]]:
88+
"""
89+
Find approximate nearest neighbor for query vector(s).
90+
:param query_vectors: ndarray of shape (m, n_features)
91+
:return: list of [nearest_vector, distance]
92+
93+
>>> dataset = np.array([[0, 0], [1, 1], [2, 2], [10, 10]])
94+
>>> ann = ANN(dataset, n_planes=4, seed=0)
95+
>>> ann.query(np.array([[0, 1]])) # doctest: +NORMALIZE_WHITESPACE
96+
[[[0, 0], 1.0]]
97+
"""
98+
results = []
99+
for vec in query_vectors:
100+
h = self._hash_vector(vec)
101+
candidates = self.buckets[h]
102+
103+
if not candidates: # fallback: search entire dataset
104+
candidates = self.dataset
105+
106+
# Approximate NN search among candidates
107+
best_vec = candidates[0]
108+
best_dist = euclidean(vec, best_vec)
109+
for cand in candidates[1:]:
110+
d = euclidean(vec, cand)
111+
if d < best_dist:
112+
best_vec, best_dist = cand, d
113+
results.append([best_vec.tolist(), best_dist])
114+
return results
115+
116+
117+
if __name__ == "__main__":
118+
import doctest
119+
doctest.testmod()

0 commit comments

Comments
 (0)