From 6feb92717f4992166e1b34f9188e0e9df82e3ab8 Mon Sep 17 00:00:00 2001 From: Robin Message Date: Tue, 12 Mar 2024 14:11:05 +0000 Subject: [PATCH 1/3] Speed up find_pairs by using a Numba-optimised kd-tree for searching M to build S --- methods/matching/find_pairs.py | 85 ++++++-- methods/utils/kd_tree.py | 354 +++++++++++++++++++++++++++++++++ test/test_find_pairs.py | 21 +- test/test_kd_tree.py | 176 ++++++++++++++++ 4 files changed, 620 insertions(+), 16 deletions(-) create mode 100644 methods/utils/kd_tree.py create mode 100644 test/test_kd_tree.py diff --git a/methods/matching/find_pairs.py b/methods/matching/find_pairs.py index 7f17782..21d631d 100644 --- a/methods/matching/find_pairs.py +++ b/methods/matching/find_pairs.py @@ -3,11 +3,13 @@ import logging from functools import partial from multiprocessing import Pool, cpu_count, set_start_method +from typing import Any from numba import jit # type: ignore import numpy as np import pandas as pd from methods.common.luc import luc_matching_columns +from methods.utils.kd_tree import make_kdrangetree, make_rumba_tree REPEAT_MATCH_FINDING = 100 DEFAULT_DISTANCE = 10000000.0 @@ -38,7 +40,7 @@ def find_match_iteration( # Methodology 6.5.7: For a 10% sample of K k_set = pd.read_parquet(k_parquet_filename) k_subset = k_set.sample( - frac=0.1, + frac=1, random_state=rng ).reset_index() @@ -76,23 +78,25 @@ def find_match_iteration( hard_match_columns = ['country', 'ecoregion', luc10, luc5, luc0] assert len(hard_match_columns) == HARD_COLUMN_COUNT - # similar to the above, make the hard match columns contiguous float32 numpy arrays - m_dist_hard = np.ascontiguousarray(m_set[hard_match_columns].to_numpy()).astype(np.int32) - k_subset_dist_hard = np.ascontiguousarray(k_subset[hard_match_columns].to_numpy()).astype(np.int32) - # Methodology 6.5.5: S should be 10 times the size of K, in order to achieve this for every # pixel in the subsample (which is 10% the size of K) we select 100 pixels. - required = 100 + required = 10 + + # Find categories in K + hard_match_category_columns = [k[hard_match_columns].to_numpy() for _, k in k_set.iterrows()] + hard_match_categories = {k.tobytes(): k for k in hard_match_category_columns} logging.info("Running make_s_set_mask... required: %d", required) - starting_positions = rng.integers(0, int(m_dist_thresholded.shape[0]), int(k_subset_dist_thresholded.shape[0])) + s_set_mask_true, no_potentials = make_s_set_mask( + rng, + k_set, + m_set, m_dist_thresholded, k_subset_dist_thresholded, - m_dist_hard, - k_subset_dist_hard, - starting_positions, - required + hard_match_columns, + required, + hard_match_categories ) logging.info("Done make_s_set_mask. s_set_mask.shape: %a", {s_set_mask_true.shape}) @@ -173,8 +177,65 @@ def find_match_iteration( logging.info("Finished find match iteration") -@jit(nopython=True, fastmath=True, error_model="numpy") def make_s_set_mask( + rng: np.random.Generator, + k_set: pd.DataFrame, + m_set: pd.DataFrame, + m_dist_thresholded: np.ndarray, + k_subset_dist_thresholded: np.ndarray, + hard_match_columns: list, + required: int, + hard_match_categories: dict[Any, np.ndarray] + ): + s_set_mask_true = np.zeros(m_set.shape[0], dtype=np.bool_) + no_potentials = np.zeros(k_set.shape[0], dtype=np.bool_) + + # Split K and M into those categories and create masks + for values in hard_match_categories.values(): + k_selector = np.all(k_set[hard_match_columns] == values, axis=1) + m_selector = np.all(m_set[hard_match_columns] == values, axis=1) + logging.info(" category: %a |K|: %d |M|: %d", values, k_selector.sum(), m_selector.sum()) + # Make masks for each of those pairs + key_s_set_mask_true, key_no_potentials = make_s_set_mask_rumba_inner( + m_dist_thresholded[m_selector], + k_subset_dist_thresholded[k_selector], + required, + rng + ) + # Merge into one s_set_mask_true + s_set_mask_true[m_selector] = key_s_set_mask_true + # Merge into no_potentials + no_potentials[k_selector] = key_no_potentials + return s_set_mask_true,no_potentials + +def make_s_set_mask_rumba_inner( + m_dist_thresholded: np.ndarray, + k_set_dist_thresholded: np.ndarray, + required: int, + rng: np.random.Generator +): + k_size = k_set_dist_thresholded.shape[0] + m_size = m_dist_thresholded.shape[0] + + s_include = np.zeros(m_size, dtype=np.bool_) + k_miss = np.zeros(k_size, dtype=np.bool_) + + m_tree = make_kdrangetree(m_dist_thresholded, np.ones(m_dist_thresholded.shape[1])) + + rumba_tree = make_rumba_tree(m_tree, m_dist_thresholded) + + for k in range(k_size): + k_row = k_set_dist_thresholded[k] + possible_s = rumba_tree.members_sample(k_row, required, rng) + if len(possible_s) == 0: + k_miss[k] = True + else: + s_include[possible_s] = True + + return s_include, k_miss + +@jit(nopython=True, fastmath=True, error_model="numpy") +def make_s_set_mask_numba( m_dist_thresholded: np.ndarray, k_subset_dist_thresholded: np.ndarray, m_dist_hard: np.ndarray, diff --git a/methods/utils/kd_tree.py b/methods/utils/kd_tree.py new file mode 100644 index 0000000..beb926e --- /dev/null +++ b/methods/utils/kd_tree.py @@ -0,0 +1,354 @@ +import math + +import numpy as np +from numba import float32, int32 # type: ignore +from numba.experimental import jitclass # type: ignore + + +class KDTree: + def __init__(self): + pass + + def contains(self, _range) -> bool: + raise NotImplementedError() + + def depth(self) -> int: + return 1 + + def size(self) -> int: + return 1 + + def count(self) -> int: + return 0 + + def members(self, _range) -> np.ndarray: + raise NotImplementedError() + + def dump(self, _space: str): + raise NotImplementedError() + +class KDLeaf(KDTree): + def __init__(self, point, index): + self.point = point + self.index = index + def contains(self, range) -> bool: + return np.all(range[0] <= self.point) & np.all(range[1] >= self.point) # type: ignore + def members(self, range): + if self.contains(range): + return np.array([self.index]) + return np.empty(0, dtype=np.int_) + def dump(self, space: str): + print(space, f"point {self.point}") + def count(self): + return 1 + +class KDList(KDTree): + def __init__(self, points, indexes): + self.points = points + self.indexes = indexes + def contains(self, range) -> bool: + return np.any(np.all(range[0] <= self.points, axis=1) & np.all(range[1] >= self.points, axis=1)) # type: ignore + def members(self, range): + return self.indexes[np.all(range[0] <= self.points, axis=1) & np.all(range[1] >= self.points, axis=1)] + def dump(self, space: str): + print(space, f"points {self.points}") + def count(self): + return len(self.points) + +class KDSplit(KDTree): + def __init__(self, d: int, value: float, left: KDTree, right: KDTree): + self.d = d + self.value = value + self.left = left + self.right = right + def contains(self, range) -> bool: + l = self.value - range[0, self.d] # Amount on left side + r = range[1, self.d] - self.value # Amount on right side + # Either l or r must be positive, or both + # Pick the biggest first + if l >= r: + if self.left.contains(range): + return True + # Visit the rest if it is inside + if r >= 0: + if self.right.contains(range): + return True + else: + if self.right.contains(range): + return True + # Visit the rest if it is inside + if l >= 0: + if self.left.contains(range): + return True + return False + + def members(self, range): + result = None + if self.value >= range[0, self.d]: + result = self.left.members(range) + if range[1, self.d] >= self.value: + rights = self.right.members(range) + if result is None: + result = rights + else: + result = np.append(result, rights, axis=0) + return result if result is not None else np.empty(0, dtype=np.int_) + + def size(self) -> int: + return 1 + self.left.size() + self.right.size() + + def depth(self) -> int: + return 1 + max(self.left.depth(), self.right.depth()) + + def dump(self, space: str): + print(space, f"split d{self.d} at {self.value}") + print(space + " <") + self.left.dump(space + "\t") + print(space + " >") + self.right.dump(space + "\t") + def count(self): + return self.left.count() + self.right.count() + + +class KDRangeTree: + def __init__(self, tree, widths): + self.tree = tree + self.widths = widths + def contains(self, point) -> bool: + return self.tree.contains(np.array([point - self.widths, point + self.widths])) + def members(self, point) -> np.ndarray: + return self.tree.members(np.array([point - self.widths, point + self.widths])) + def dump(self, space: str): + self.tree.dump(space) + def size(self): + return self.tree.size() + def depth(self): + return self.tree.depth() + def count(self): + return self.tree.count() + +def make_kdrangetree(points, widths): + def make_kdtree_internal(points, indexes): + if len(points) == 1: + return KDLeaf(points[0], indexes[0]) + if len(points) < 30: + return KDList(points, indexes) + # Find split in dimension with most bins + dimensions = points.shape[1] + bins: float = None # type: ignore + chosen_d_min = 0 + chosen_d_max = 0 + chosen_d = 0 + for d in range(dimensions): + d_max = np.max(points[:, d]) + d_min = np.min(points[:, d]) + d_range = d_max - d_min + d_bins = d_range / widths[d] + if bins is None or d_bins > bins: + bins = d_bins + chosen_d = d + chosen_d_max = d_max + chosen_d_min = d_min + + if bins < 1.3: + print(bins) + print(chosen_d, chosen_d_min, chosen_d_max) + # No split is very worthwhile, so dump points + return KDList(points, indexes) + + split_at = np.median(points[:, chosen_d]) + # Avoid degenerate cases + if split_at == chosen_d_max or split_at == chosen_d_min: + split_at = (chosen_d_max + chosen_d_min) / 2 + + left_side = points[:, chosen_d] <= split_at + right_side = ~left_side + lefts = points[left_side] + rights = points[right_side] + lefts_indexes = indexes[left_side] + rights_indexes = indexes[right_side] + return KDSplit(chosen_d, split_at, make_kdtree_internal(lefts, lefts_indexes), make_kdtree_internal(rights, rights_indexes)) + indexes = np.arange(len(points)) + return KDRangeTree(make_kdtree_internal(points, indexes), widths) + +@jitclass([('ds', int32[:]), ('values', float32[:]), ('items', int32[:]), ('lefts', int32[:]), ('rights', int32[:]), ('rows', float32[:, :]), ('dimensions', int32), ('widths', float32[:])]) +class RumbaTree: + def __init__(self, ds: np.ndarray, values: np.ndarray, items: np.ndarray, lefts: np.ndarray, rights: np.ndarray, rows: np.ndarray, dimensions: int, widths: np.ndarray): + self.ds = ds + self.values = values + self.items = items + self.lefts = lefts + self.rights = rights + self.rows = rows + self.dimensions = dimensions + self.widths = widths + def members(self, point: np.ndarray): + low = point - self.widths + high = point + self.widths + queue = [0] + finds = [] + while len(queue) > 0: + pos = queue.pop() + d = self.ds[pos] + value = self.values[pos] + if math.isnan(value): + i = d + item = self.items[i] + while item != -1: + # Check item + found = True + for d in range(self.dimensions): + value = self.rows[item, d] + if value < low[d]: + found = False + break + if value > high[d]: + found = False + break + if found: + finds.append(item) + i += 1 + item = self.items[i] + else: + if value <= high[d]: + queue.append(self.rights[pos]) + if value >= low[d]: + queue.append(self.lefts[pos]) + return finds + def count_members(self, point: np.ndarray): + low = point - self.widths + high = point + self.widths + queue = [0] + count = 0 + while len(queue) > 0: + pos = queue.pop() + d = self.ds[pos] + value = self.values[pos] + if math.isnan(value): + i = d + item = self.items[i] + while item != -1: + # Check item + found = True + for d in range(self.dimensions): + value = self.rows[item, d] + if value < low[d]: + found = False + break + if value > high[d]: + found = False + break + if found: + count += 1 + i += 1 + item = self.items[i] + else: + if value <= high[d]: + queue.append(self.rights[pos]) + if value >= low[d]: + queue.append(self.lefts[pos]) + return count + def members_sample(self, point: np.ndarray, count: int, rng: np.random.Generator): + low = point - self.widths + high = point + self.widths + queue = [0] + finds: list[int] = [] + found_count = 0 + rand_state = rng.integers(0, 0xFFFF_FFFF_FFFF_FFFF, 4, np.uint64, True) + while len(queue) > 0: + pos = queue.pop() + d = self.ds[pos] + value = self.values[pos] + if math.isnan(value): + i = d + item: int = self.items[i] # type: ignore + while item != -1: + # Check item + found = True + for d in range(self.dimensions): + value = self.rows[item, d] + if value < low[d]: + found = False + break + if value > high[d]: + found = False + break + if found: + found_count += 1 + if len(finds) < count: + finds.append(item) + else: + # Replace a random item in finds based on on-line search probability + # Source: https://prng.di.unimi.it/xoshiro256plus.c + rand = rand_state[0] + rand_state[3] + t = rand_state[1] << 17 + rand_state[2] ^= rand_state[0] + rand_state[3] ^= rand_state[1] + rand_state[1] ^= rand_state[2] + rand_state[0] ^= rand_state[3] + rand_state[2] ^= t + rand_state[3] = (rand_state[3] >> 45) | (rand_state[3] << 19) + + pos = rand % found_count + + if pos < count: + finds[pos] = item + i += 1 + item = self.items[i] # type: ignore + else: + if value <= high[d]: + queue.append(self.rights[pos]) + if value >= low[d]: + queue.append(self.lefts[pos]) + return finds + +NAN = float('nan') +def make_rumba_tree(tree: KDRangeTree, rows: np.ndarray): + ds: list[int] = [] + values = [] + items: list[int] = [] + lefts = [] + rights = [] + widths = None + def recurse(node): + nonlocal widths + if isinstance(node, KDSplit): + pos = len(ds) + ds.append(node.d) + values.append(node.value) + lefts.append(pos + 1) # Next node we output will be left + rights.append(0xDEADBEEF) # Put placeholder here... + recurse(node.left) + rights[pos] = len(ds) # ..and fixup right to be the next node we output + recurse(node.right) + elif isinstance(node, KDList): + values.append(NAN) + ds.append(len(items)) + lefts.append(-1) # Specific invalid values for debugging an errors in tree build + rights.append(-2) + for item in node.indexes: + items.append(item) + items.append(-1) + elif isinstance(node, KDLeaf): + values.append(NAN) + ds.append(len(items)) + lefts.append(-3) + rights.append(-4) + items.append(node.index) + items.append(-1) + elif isinstance(node, KDRangeTree): + widths = node.widths + recurse(node.tree) + recurse(tree) + if widths is None: + raise ValueError(f"Expected KDRangeTree, got {tree}") + return RumbaTree( + np.array(ds, dtype=np.int32), + np.array(values, dtype=np.float32), + np.array(items, dtype=np.int32), + np.array(lefts, dtype=np.int32), + np.array(rights, dtype=np.int32), + np.ascontiguousarray(rows, dtype=np.float32), + rows.shape[1], + np.ascontiguousarray(widths, dtype=np.float32), + ) diff --git a/test/test_find_pairs.py b/test/test_find_pairs.py index ebb8427..fb75bf8 100644 --- a/test/test_find_pairs.py +++ b/test/test_find_pairs.py @@ -1,5 +1,6 @@ import numpy as np from scipy.spatial.distance import mahalanobis +import pandas as pd from methods.matching import find_pairs @@ -54,16 +55,28 @@ def test_make_s_set_mask(): [0.5, 0.6], ], dtype=np.float32) - starting_positions = np.array([1, 2, 3, 4]) + rng = np.random.default_rng(42) + k_set = np.array([k_dist_hard, k_dist_thresholded]) + k_set = np.moveaxis(k_set, 0, 1).reshape(-1, k_dist_hard.shape[1] + k_dist_thresholded.shape[1]) + k_set = pd.DataFrame(k_set) + + m_set = np.array([m_dist_hard, m_dist_thresholded]) + m_set = np.moveaxis(m_set, 0, 1).reshape(-1, m_dist_hard.shape[1] + m_dist_thresholded.shape[1]) + m_set = pd.DataFrame(m_set) + + hard_match_columns = list(range(k_dist_hard.shape[1])) + hard_match_categories = {k.tobytes(): k for k in k_dist_hard} # calculate using make_s_set_mask s_subset_mask, misses = find_pairs.make_s_set_mask( + rng, + k_set, + m_set, m_dist_thresholded, k_dist_thresholded, - m_dist_hard, - k_dist_hard, - starting_positions, + hard_match_columns, 2, + hard_match_categories, ) assert (s_subset_mask == [True, False, False, False, False]).all() diff --git a/test/test_kd_tree.py b/test/test_kd_tree.py new file mode 100644 index 0000000..e9e263a --- /dev/null +++ b/test/test_kd_tree.py @@ -0,0 +1,176 @@ +import math +from time import time +import numpy as np +import pandas as pd + +from methods.common.luc import luc_matching_columns +from methods.utils.kd_tree import KDRangeTree, make_kdrangetree, make_rumba_tree + +ALLOWED_VARIATION = np.array([ + 200, + 2.5, + 10, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, + 0.1, +]) + +def test_kd_tree_matches_as_expected(): + def build_rects(items): + rects = [] + for item in items: + lefts = [] + rights = [] + for dimension, value in enumerate(item): + width = ALLOWED_VARIATION[dimension] + if width < 0: + fraction = -width + width = value * fraction + lefts.append(value - width) + rights.append(value + width) + rects.append([lefts, rights]) + return np.array(rects) + + expected_fraction = 1 / 100 # This proportion of pixels we end up matching + + def build_kdranged_tree_for_k(k_rows) -> KDRangeTree: + return make_kdrangetree(np.array([( + row.elevation, + row.slope, + row.access, + row["cpc0_u"], + row["cpc0_d"], + row["cpc5_u"], + row["cpc5_d"], + row["cpc10_u"], + row["cpc10_d"], + ) for row in k_rows + ]), ALLOWED_VARIATION) + + + luc0, luc5, luc10 = luc_matching_columns(2012) + source_pixels = pd.read_parquet("./test/data/1201-k.parquet") + + # Split source_pixels into classes + source_rows = [] + for _, row in source_pixels.iterrows(): + key = (int(row.ecoregion) << 16) | (int(row[luc0]) << 10) | (int(row[luc5]) << 5) | (int(row[luc10])) + if key != 1967137: + continue + source_rows.append(row) + + source = np.array([ + [ + row.elevation, + row.slope, + row.access, + row["cpc0_u"], + row["cpc0_d"], + row["cpc5_u"], + row["cpc5_d"], + row["cpc10_u"], + row["cpc10_d"], + ] for row in source_rows + ]) + + # Invent an array of values that matches the expected_fraction + length = 10000 + np.random.seed(42) + + ranges = np.transpose(np.array([ + np.min(source, axis=0), + np.max(source, axis=0) + ])) + + # Safe ranges (exclude 10% of outliers) + safe_ranges = np.transpose(np.array([ + np.quantile(source, 0.05, axis=0), + np.quantile(source, 0.95, axis=0) + ])) + + # Need to put an estimate here of how much of the area inside those 90% bounds is actually filled + filled_fraction = 0.775 + + # Proportion of values that should fall inside each dimension + inside_fraction = expected_fraction * math.pow(1 / filled_fraction, len(ranges)) + inside_length = math.ceil(length * inside_fraction) + inside_values = np.random.uniform(safe_ranges[:, 0], safe_ranges[:, 1], (inside_length, len(ranges))) + + widths = ranges[:, 1] - ranges[:, 0] + range_extension = 100 * widths # Width extension makes it very unlikely a random value will be inside + outside_ranges = np.transpose([ranges[:, 0] - range_extension, ranges[:, 1] + range_extension]) + + outside_length = length - inside_length + outside_values = np.random.uniform(outside_ranges[:, 0], outside_ranges[:, 1], (outside_length, len(ranges))) + + test_values = np.append(inside_values, outside_values, axis=0) + + def do_np_matching(): + source_rects = build_rects(source) + found = 0 + for i in range(length): + pos = np.all((test_values[i] >= source_rects[:, 0]) & (test_values[i] <= source_rects[:, 1]), axis=1) + found += 1 if np.any(pos) else 0 + return found + + def speed_of(what, func): + expected_finds = 946 + start = time() + value = func() + end = time() + assert value == expected_finds, f"Got wrong value {value} for method {what}, expected {expected_finds}" + print(what, ": ", (end - start) / length, "per call") + + print("making tree... (this will take a few seconds)") + start = time() + kd_tree = build_kdranged_tree_for_k(source_rows) + print("build time", time() - start) + print("tree depth", kd_tree.depth()) + print("tree size", kd_tree.size()) + + def do_kdrange_tree_matching(): + found = 0 + for i in range(length): + found += 1 if len(kd_tree.members(test_values[i])) > 0 else 0 + return found + + rumba_tree = make_rumba_tree(kd_tree, source) + + def do_rumba_tree_matching(): + found = 0 + for i in range(length): + found += 1 if len(rumba_tree.members(test_values[i])) > 0 else 0 + return found + + test_np_matching = False # This is slow but a useful check so I don't want to delete it + if test_np_matching: + speed_of("NP matching", do_np_matching) + speed_of("KD Tree matching", do_kdrange_tree_matching) + speed_of("Rumba matching", do_rumba_tree_matching) + +def test_rumba_tree_sampling(): + """Check that the rumba tree members_sample function returns a uniform random sample. + + Actually only tests the mean converges to the middle index over a series of runs. + """ + data = np.arange(3000).reshape((-1, 3)) + + # Build a tree + centre = np.array([1500, 1500, 1500]) + + kd_tree = make_kdrangetree(data, centre) + rumba_tree = make_rumba_tree(kd_tree, data) + + assert 1000 == rumba_tree.count_members(centre) + + means = [] + for seed in range(100): + for i in range(100): + sample = rumba_tree.members_sample(centre, i + 1, np.random.default_rng(100 * seed + i)) + means.append(np.mean(sample)) + mean = np.mean(np.array(means)) + mean_difference = abs(mean - 500) + assert mean_difference < 1 From dc69ba1fac2eec9052eeb338d022b1d1589eaac1 Mon Sep 17 00:00:00 2001 From: Robin Message Date: Fri, 15 Mar 2024 15:37:26 +0000 Subject: [PATCH 2/3] Removed debug print statements accidentally left in --- methods/utils/kd_tree.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/methods/utils/kd_tree.py b/methods/utils/kd_tree.py index beb926e..2678ac7 100644 --- a/methods/utils/kd_tree.py +++ b/methods/utils/kd_tree.py @@ -151,8 +151,6 @@ def make_kdtree_internal(points, indexes): chosen_d_min = d_min if bins < 1.3: - print(bins) - print(chosen_d, chosen_d_min, chosen_d_max) # No split is very worthwhile, so dump points return KDList(points, indexes) From cd63f60144ed6d55308d2e4b19d3de643dd240d0 Mon Sep 17 00:00:00 2001 From: Robin Message Date: Tue, 16 Apr 2024 16:09:38 +0100 Subject: [PATCH 3/3] Fixes from code review and better comments on RumbaTree --- methods/matching/find_pairs.py | 77 +++++---------------------- methods/utils/kd_tree.py | 97 ++++++++++++++++++++++++++++++++-- test/test_find_pairs.py | 2 +- 3 files changed, 107 insertions(+), 69 deletions(-) diff --git a/methods/matching/find_pairs.py b/methods/matching/find_pairs.py index 21d631d..0461f23 100644 --- a/methods/matching/find_pairs.py +++ b/methods/matching/find_pairs.py @@ -3,7 +3,6 @@ import logging from functools import partial from multiprocessing import Pool, cpu_count, set_start_method -from typing import Any from numba import jit # type: ignore import numpy as np import pandas as pd @@ -39,10 +38,8 @@ def find_match_iteration( # Methodology 6.5.7: For a 10% sample of K k_set = pd.read_parquet(k_parquet_filename) - k_subset = k_set.sample( - frac=1, - random_state=rng - ).reset_index() + # TODO: This assumes the methodolgy is being updated to 100% of K + k_subset = k_set logging.info("Loading M from %s", m_parquet_filename) m_set = pd.read_parquet(m_parquet_filename) @@ -78,13 +75,16 @@ def find_match_iteration( hard_match_columns = ['country', 'ecoregion', luc10, luc5, luc0] assert len(hard_match_columns) == HARD_COLUMN_COUNT - # Methodology 6.5.5: S should be 10 times the size of K, in order to achieve this for every - # pixel in the subsample (which is 10% the size of K) we select 100 pixels. + # Methodology 6.5.5: S should be 10 times the size of K, in order to achieve this + # we select 10 pixels for each K. + # TODO: This assumes the methodolgy is being updated to 100% of K required = 10 - # Find categories in K - hard_match_category_columns = [k[hard_match_columns].to_numpy() for _, k in k_set.iterrows()] - hard_match_categories = {k.tobytes(): k for k in hard_match_category_columns} + # Find the unique categories in K + hard_match_category_values = [k[hard_match_columns].to_numpy() for _, k in k_set.iterrows()] + # Use a dictionary comprehension to find the unique values for the category columns + # and then convert that into a list + hard_match_categories = list({k.tobytes(): k for k in hard_match_category_values}.values()) logging.info("Running make_s_set_mask... required: %d", required) @@ -185,13 +185,13 @@ def make_s_set_mask( k_subset_dist_thresholded: np.ndarray, hard_match_columns: list, required: int, - hard_match_categories: dict[Any, np.ndarray] + hard_match_categories: list[np.ndarray] ): s_set_mask_true = np.zeros(m_set.shape[0], dtype=np.bool_) no_potentials = np.zeros(k_set.shape[0], dtype=np.bool_) # Split K and M into those categories and create masks - for values in hard_match_categories.values(): + for values in hard_match_categories: k_selector = np.all(k_set[hard_match_columns] == values, axis=1) m_selector = np.all(m_set[hard_match_columns] == values, axis=1) logging.info(" category: %a |K|: %d |M|: %d", values, k_selector.sum(), m_selector.sum()) @@ -234,59 +234,6 @@ def make_s_set_mask_rumba_inner( return s_include, k_miss -@jit(nopython=True, fastmath=True, error_model="numpy") -def make_s_set_mask_numba( - m_dist_thresholded: np.ndarray, - k_subset_dist_thresholded: np.ndarray, - m_dist_hard: np.ndarray, - k_subset_dist_hard: np.ndarray, - starting_positions: np.ndarray, - required: int -): - m_size = m_dist_thresholded.shape[0] - k_size = k_subset_dist_thresholded.shape[0] - - s_include = np.zeros(m_size, dtype=np.bool_) - k_miss = np.zeros(k_size, dtype=np.bool_) - - for k in range(k_size): - matches = 0 - k_row = k_subset_dist_thresholded[k, :] - k_hard = k_subset_dist_hard[k] - - for index in range(m_size): - m_index = (index + starting_positions[k]) % m_size - - m_row = m_dist_thresholded[m_index, :] - m_hard = m_dist_hard[m_index] - - should_include = True - - # check that every element of m_hard matches k_hard - hard_equals = True - for j in range(m_hard.shape[0]): - if m_hard[j] != k_hard[j]: - hard_equals = False - - if not hard_equals: - should_include = False - else: - for j in range(m_row.shape[0]): - if abs(m_row[j] - k_row[j]) > 1.0: - should_include = False - - if should_include: - s_include[m_index] = True - matches += 1 - - # Don't find any more M's - if matches == required: - break - - k_miss[k] = matches == 0 - - return s_include, k_miss - # Function which returns a boolean array indicating whether all values in a row are true @jit(nopython=True, fastmath=True, error_model="numpy") def rows_all_true(rows: np.ndarray): diff --git a/methods/utils/kd_tree.py b/methods/utils/kd_tree.py index 2678ac7..2182886 100644 --- a/methods/utils/kd_tree.py +++ b/methods/utils/kd_tree.py @@ -6,28 +6,42 @@ class KDTree: + """ + A k-d tree represents points in a K-dimensional space. + + We expect to do range searches to find points that match a range on all k dimensions. + """ def __init__(self): pass - def contains(self, _range) -> bool: + def contains(self, range) -> bool: + """Does the tree contain a point in range?""" raise NotImplementedError() def depth(self) -> int: + """The height of the deepest node in the tree.""" return 1 def size(self) -> int: + """The number of nodes in the tree.""" return 1 def count(self) -> int: + """The number of points in the tree.""" return 0 - def members(self, _range) -> np.ndarray: + def members(self, range) -> np.ndarray: + """Return a list of all members in range.""" raise NotImplementedError() def dump(self, _space: str): + """Return a string representation of the tree for debugging.""" raise NotImplementedError() class KDLeaf(KDTree): + """ + A leaf repesents a single point in the tree. + """ def __init__(self, point, index): self.point = point self.index = index @@ -43,6 +57,11 @@ def count(self): return 1 class KDList(KDTree): + """ + A list node repesents a list of points in the tree. + + This is an optimisation for when linear search becomes quicker than walking the tree. + """ def __init__(self, points, indexes): self.points = points self.indexes = indexes @@ -56,6 +75,13 @@ def count(self): return len(self.points) class KDSplit(KDTree): + """ + A split node represents am axis-aligned binary split in the tree. + + The tree splits into tree disjoint subtrees at value in dimension d, called left and right. + All the functions that take a range must consider the need to look into both left and right + if the range intersects the split point. + """ def __init__(self, d: int, value: float, left: KDTree, right: KDTree): self.d = d self.value = value @@ -111,6 +137,8 @@ def count(self): class KDRangeTree: + """Wrap up a KDTree with a fixed width range for queries. + """ def __init__(self, tree, widths): self.tree = tree self.widths = widths @@ -128,6 +156,12 @@ def count(self): return self.tree.count() def make_kdrangetree(points, widths): + """Make a KDRangeTree containing points that is queried with ranges of width width. + + We recursively split up the data by the most-discriminating dimension. + If no dimension splits the data up very much, or there are less than a cut-off number of + points, we output a list node instead of a split. + """ def make_kdtree_internal(points, indexes): if len(points) == 1: return KDLeaf(points[0], indexes[0]) @@ -171,6 +205,11 @@ def make_kdtree_internal(points, indexes): @jitclass([('ds', int32[:]), ('values', float32[:]), ('items', int32[:]), ('lefts', int32[:]), ('rights', int32[:]), ('rows', float32[:, :]), ('dimensions', int32), ('widths', float32[:])]) class RumbaTree: + """A RumbaTree is a KDRangeTree which is optimised with Numba. + + Instead of pointers, the various members of the tree are serialised into arrays to make + traversal in Numba easy. + """ def __init__(self, ds: np.ndarray, values: np.ndarray, items: np.ndarray, lefts: np.ndarray, rights: np.ndarray, rows: np.ndarray, dimensions: int, widths: np.ndarray): self.ds = ds self.values = values @@ -181,6 +220,14 @@ def __init__(self, ds: np.ndarray, values: np.ndarray, items: np.ndarray, lefts: self.dimensions = dimensions self.widths = widths def members(self, point: np.ndarray): + """ + Return all the items that are within widths of point. + + We use a queue to start at the top of the tree and, for each node, decide if the left or + right need to be processed, adding them to the queue if so. + For leaf nodes (marked with a value of NaN), we instead process the list of items at that + node and return all of the items that are within widths of point. + """ low = point - self.widths high = point + self.widths queue = [0] @@ -214,6 +261,11 @@ def members(self, point: np.ndarray): queue.append(self.lefts[pos]) return finds def count_members(self, point: np.ndarray): + """ + Return a count of items that are within widths of point. + + Mostly just for debugging. + """ low = point - self.widths high = point + self.widths queue = [0] @@ -247,11 +299,23 @@ def count_members(self, point: np.ndarray): queue.append(self.lefts[pos]) return count def members_sample(self, point: np.ndarray, count: int, rng: np.random.Generator): + """ + Return up to count items that are within widths of point, selected at random. + + As with members, we use a queue to start at the top of the tree and, for each node, decide + if the left or right need to be processed, adding them to the queue if so. + + However, for leaf nodes, instead of taking all items that match, we use algorithm R for + resoivoir sampling to select up to count items at random. + https://en.wikipedia.org/wiki/Reservoir_sampling#Simple:_Algorithm_R + """ low = point - self.widths high = point + self.widths queue = [0] finds: list[int] = [] + # We need to track how many items have been found for the sampling. found_count = 0 + # rng is noticable slow, so we seed 256-bit inline xoshiro generator here. rand_state = rng.integers(0, 0xFFFF_FFFF_FFFF_FFFF, 4, np.uint64, True) while len(queue) > 0: pos = queue.pop() @@ -272,11 +336,29 @@ def members_sample(self, point: np.ndarray, count: int, rng: np.random.Generator found = False break if found: + # Keep track of how many items are found for the probabilities to work out. found_count += 1 + # For each item we find, decide here whether to include it in the output. if len(finds) < count: + # If we haven't found count items yet, we provisionaly take this item. finds.append(item) else: - # Replace a random item in finds based on on-line search probability + # Replace a random item in finds based on algorithm R. + + # When we find the N+1th item, by induction, we assume we have selected + # count of the previous N items with 1/N probability. + # + # We want to select this item with count/N+1 probability + # and we want it to replace one of the previously selected + # items with a uniform 1/count probability. + # We generate a uniform random number, pos, from [0, N+1) + # and select this item if that number is less than count + # (which corresponds to our count/N+1 probability of selection). + # The position to replace the item is then given by pos, + # given we now know it is a uniform random number in [0, count) + # (which again matches our desired distribution). + + # Generate a random number # Source: https://prng.di.unimi.it/xoshiro256plus.c rand = rand_state[0] + rand_state[3] t = rand_state[1] << 17 @@ -287,8 +369,10 @@ def members_sample(self, point: np.ndarray, count: int, rng: np.random.Generator rand_state[2] ^= t rand_state[3] = (rand_state[3] >> 45) | (rand_state[3] << 19) + # Work out where this item should fall. pos = rand % found_count + # Replace an existing item if we should include this item. if pos < count: finds[pos] = item i += 1 @@ -302,6 +386,13 @@ def members_sample(self, point: np.ndarray, count: int, rng: np.random.Generator NAN = float('nan') def make_rumba_tree(tree: KDRangeTree, rows: np.ndarray): + """ + Make a RumbaTree from a KDRangeTree. + + This just involved walking the tree and flattening it out into arrays. + The code is slightly fiddly because we don't know how big a node is until we have walked it, + so we output a placeholder for the right node and update it after walking the left node. + """ ds: list[int] = [] values = [] items: list[int] = [] diff --git a/test/test_find_pairs.py b/test/test_find_pairs.py index fb75bf8..e8e2fb2 100644 --- a/test/test_find_pairs.py +++ b/test/test_find_pairs.py @@ -65,7 +65,7 @@ def test_make_s_set_mask(): m_set = pd.DataFrame(m_set) hard_match_columns = list(range(k_dist_hard.shape[1])) - hard_match_categories = {k.tobytes(): k for k in k_dist_hard} + hard_match_categories = list({k.tobytes(): k for k in k_dist_hard}.values()) # calculate using make_s_set_mask s_subset_mask, misses = find_pairs.make_s_set_mask(