Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions selector/methods/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,16 @@ def algorithm(self, x, max_size) -> Union[List, Iterable]:
# bv will serve as a mask to discard points within radius r of previously selected points
bv = np.zeros(n_samples)
candidates = list(range(n_samples))

# min_dists[i] holds the distance from point i to its nearest selected point;
# start every entry at infinity so the first computed distance always replaces it
min_dists = np.full(n_samples, np.inf)

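# compute the distance from every point to each initially selected point
# NOTE(review): np.linalg.norm is the Euclidean norm, while the ball queries use p=self.p — confirm this is intended for p != 2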
for idx in selected:
dists = np.linalg.norm(x - x[idx], axis=1)
min_dists = np.minimum(min_dists, dists)

# determine which points are within radius r of initial point
# note: workers=-1 uses all available processors/CPUs
index_remove = tree.query_ball_point(
Expand All @@ -384,20 +394,19 @@ def algorithm(self, x, max_size) -> Union[List, Iterable]:
except ValueError:
sublist = candidates.compressed()

# create a new kd-tree for nearest neighbor lookup with candidates
new_tree = spatial.KDTree(x[selected])
# query the kd-tree for nearest neighbors to selected samples
# note: workers=-1 uses all available processors/CPUs
search, _ = new_tree.query(x[sublist], eps=self.eps, p=self.p, workers=-1)
# identify the nearest neighbor with the largest distance from previously selected samples
best_idx = sublist[np.argmax(search)]
# select and append the candidate that is farthest from its nearest selected point
best_idx = sublist[np.argmax(min_dists[sublist])]
selected.append(best_idx)

count += 1
if count > max_size:
# do this if you have reached the maximum number of points selected
return selected

# Update min_dists array: calculate distances from newly selected point
new_dists = np.linalg.norm(x - x[best_idx], axis=1)
min_dists = np.minimum(min_dists, new_dists)

# eliminate all samples within radius r of the selected sample
index_remove = tree.query_ball_point(
x[best_idx], self.r, eps=self.eps, p=self.p, workers=-1
Expand Down
Loading