Speed up find_pairs by using a Numba-optimised kd-tree for searching M to build S #104
base: main
Conversation
mdales left a comment:
LGTM overall at a first pass.
methods/matching/find_pairs.py (outdated):
```diff
  k_set = pd.read_parquet(k_parquet_filename)
  k_subset = k_set.sample(
-     frac=0.1,
+     frac=1,
```
Isn't this just the same as k_subset = k_set?
Yes; I didn't clean this up yet as I wasn't sure if we definitely wanted to change to 100% of K instead of 10%.
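One nuance for anyone reading along: `sample(frac=1)` isn't quite a plain assignment; it returns every row, but in shuffled order. A minimal sketch of the difference:

```python
import pandas as pd
import numpy as np

df = pd.DataFrame({"x": range(5)})
rng = np.random.default_rng(42)

# sample(frac=1) keeps every row but shuffles their order, so it is only
# interchangeable with `df` itself when downstream code ignores row order.
shuffled = df.sample(frac=1, random_state=rng).reset_index(drop=True)
print(shuffled["x"].tolist())  # all five values, in a random order
```

So replacing it with `k_subset = k_set` also drops the shuffle, which only matters if anything downstream is order-sensitive.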
methods/matching/find_pairs.py (outdated):
```python
# Find categories in K
hard_match_category_columns = [k[hard_match_columns].to_numpy() for _, k in k_set.iterrows()]
hard_match_categories = {k.tobytes(): k for k in hard_match_category_columns}
```
I think this needs a comment about what's happening: I had to work through it with real data to figure out the trick being used to get unique columns. Given you never use the keys again, I'd rather you called `.values()` here, rather than in `make_s_set_mask`, as that would make it more obvious you're using this to find unique sets of columns (assuming I understand what's happening here).
Fair point, I'll tidy this
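For anyone else puzzling over the same lines: the dict keyed on `tobytes()` is a standard way to deduplicate NumPy rows, since arrays themselves are not hashable. A small illustration (the variable names here are mine, not the PR's):

```python
import numpy as np

rows = [np.array([1, 0, 2], dtype=np.int32),
        np.array([1, 0, 2], dtype=np.int32),  # duplicate of the first row
        np.array([0, 1, 1], dtype=np.int32)]

# Arrays are unhashable, so use their raw bytes as dict keys; duplicate
# rows collapse onto the same key, leaving one array per unique
# combination of hard-match column values.
unique_rows = list({r.tobytes(): r for r in rows}.values())
print(unique_rows)  # two distinct category combinations remain
```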
methods/matching/find_pairs.py (outdated):
```python
    return s_include, k_miss


@jit(nopython=True, fastmath=True, error_model="numpy")
def make_s_set_mask_numba(
```
Very happy to delete this version.
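For anyone skimming, a standalone illustration of what the decorator's flags do (this function is mine, not the PR's):

```python
from numba import jit
import numpy as np

# nopython=True       compile fully to machine code; fail rather than fall
#                     back to the slow object mode
# fastmath=True       let LLVM reorder and fuse float ops, relaxing strict
#                     IEEE 754 (fine for distance comparisons, not for code
#                     that needs exact rounding or inf/nan propagation)
# error_model="numpy" follow NumPy semantics for things like division by
#                     zero (produce inf/nan) rather than raising exceptions
@jit(nopython=True, fastmath=True, error_model="numpy")
def sq_distances(points: np.ndarray, target: np.ndarray) -> np.ndarray:
    out = np.empty(points.shape[0])
    for i in range(points.shape[0]):
        d = 0.0
        for j in range(points.shape[1]):
            diff = points[i, j] - target[j]
            d += diff * diff
        out[i] = d
    return out

pts = np.random.default_rng(0).random((4, 3))
print(sq_distances(pts, pts[0]))  # first entry is 0.0
```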
methods/matching/find_pairs.py (outdated):
```python
k_subset_dist_hard = np.ascontiguousarray(k_subset[hard_match_columns].to_numpy()).astype(np.int32)

# Methodology 6.5.5: S should be 10 times the size of K; in order to achieve this for every
# pixel in the subsample (which is 10% the size of K) we select 100 pixels.
```
Comment needs updating
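For the arithmetic behind the stale comment: with a 10% subsample, 0.1 × |K| pixels × 100 candidates each = 10 × |K|, which is the size of S the methodology asks for. If the PR settles on 100% of K, the per-pixel count would presumably drop to 10 to keep |S| = 10 × |K|.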
```python
                if value >= low[d]:
                    queue.append(self.lefts[pos])
        return count

    def members_sample(self, point: np.ndarray, count: int, rng: np.random.Generator):
```
I have to confess that, due to the lack of comments, I only skim-reviewed this to try to work out what count was achieving, and then gave up. That's fine at the prototype stage, but before we merge this, some comments on the API/algorithm would be useful, as I think this is quite nuanced.
I've added some docstrings and comments; hopefully that covers what is needed, but please do come back to me on anything else.
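For context, here is a minimal sketch of the kind of flattened, queue-driven kd-tree walk the fragment above appears to come from. The array names `lefts`/`rights` follow the diff; everything else is an illustrative reconstruction, not the PR's actual implementation, and it only shows a range count rather than the sampling the PR also does:

```python
import numpy as np

class FlatKDTree:
    """Toy flattened kd-tree: nodes live in parallel arrays rather than
    Python objects, the layout that makes Numba compilation practical."""

    def __init__(self, points: np.ndarray):
        n, self.ndim = points.shape
        self.points = points
        self.lefts = np.full(n, -1, dtype=np.int64)   # left child index, -1 = none
        self.rights = np.full(n, -1, dtype=np.int64)  # right child index
        self.dims = np.zeros(n, dtype=np.int64)       # split dimension per node
        self.root = self._build(np.arange(n), 0)

    def _build(self, idx: np.ndarray, depth: int) -> int:
        d = depth % self.ndim
        order = idx[np.argsort(self.points[idx, d])]
        mid = len(order) // 2
        node = order[mid]                # median point becomes this node
        self.dims[node] = d
        if mid > 0:
            self.lefts[node] = self._build(order[:mid], depth + 1)
        if mid + 1 < len(order):
            self.rights[node] = self._build(order[mid + 1:], depth + 1)
        return node

    def count_in_box(self, low: np.ndarray, high: np.ndarray) -> int:
        """Count points with low <= p <= high, walking the tree with an
        explicit queue (as in the fragment above) instead of recursion."""
        count = 0
        queue = [self.root]
        while queue:
            pos = queue.pop()
            p = self.points[pos]
            if np.all((low <= p) & (p <= high)):
                count += 1
            d = self.dims[pos]
            value = p[d]
            # Only descend into children whose half-space can reach the box.
            if value >= low[d] and self.lefts[pos] != -1:
                queue.append(self.lefts[pos])
            if value <= high[d] and self.rights[pos] != -1:
                queue.append(self.rights[pos])
        return count

pts = np.random.default_rng(0).random((100, 2))
tree = FlatKDTree(pts)
print(tree.count_in_box(np.array([0.2, 0.2]), np.array([0.5, 0.5])))
```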
Thanks for reviewing this @mdales, and you're right, it could do with some more comments in the gnarly bits and a general tidy-up. I'll do that as soon as I can and bounce it back to you (probably after Easter, unfortunately).

@mdales I think I've fixed the stuff you reviewed and improved the comments on the other parts.
mdales left a comment:
LGTM, just a couple of things it'd be nice to tidy up.
```python
    random_state=rng
).reset_index()
# TODO: This assumes the methodology is being updated to 100% of K
k_subset = k_set
```
Can we just collapse this change throughout? Then, when this is merged, we bump the versions of both the code and the methodology.
Will do when I merge
```python
rand = rand_state[0] + rand_state[3]
t = rand_state[1] << 17
rand_state[2] ^= rand_state[0]
rand_state[3] ^= rand_state[1]
rand_state[1] ^= rand_state[2]
rand_state[0] ^= rand_state[3]
rand_state[2] ^= t
rand_state[3] = (rand_state[3] >> 45) | (rand_state[3] << 19)
```
Sad that we have to do this, but I see it's for performance reasons. Can we at least pull this code out so that we're not baking the particular PRNG into the algorithm? The methodology does not require this specific algorithm; we've just chosen it for performance.
I'm not sure I quite understand: do you mean pull it out into a function (and hope Numba inlines it), with a comment saying this is a random number generator where no specific algorithm is required and this one was chosen for speed? Or something else? (I suspect we can only do this if Numba does inline it correctly, and since we'd be passing the state around I'm not sure it'll be particularly clearer.)
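Something like this is presumably what was meant; a sketch assuming the state is a `uint64[4]` array as in the diff (`xoshiro_next` is an illustrative name, and the shift constants are copied from the diff rather than from a reference implementation):

```python
import numpy as np
from numba import njit

@njit(inline="always")  # ask Numba to inline at the IR level, avoiding call overhead
def xoshiro_next(rand_state: np.ndarray) -> np.uint64:
    """One step of the xoshiro-style generator from the diff. The methodology
    does not require this particular algorithm; any fast PRNG would do, and
    this one was chosen purely for speed."""
    rand = rand_state[0] + rand_state[3]
    t = rand_state[1] << np.uint64(17)
    rand_state[2] ^= rand_state[0]
    rand_state[3] ^= rand_state[1]
    rand_state[1] ^= rand_state[2]
    rand_state[0] ^= rand_state[3]
    rand_state[2] ^= t
    rand_state[3] = (rand_state[3] >> np.uint64(45)) | (rand_state[3] << np.uint64(19))
    return rand

state = np.array([1, 2, 3, 4], dtype=np.uint64)
print(xoshiro_next(state))
```

The state still gets passed around, but at least the choice of generator is isolated in one named function rather than spread through the matching loop.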
This has none of the memory-sharing optimisations that would make it easier to run in parallel (or possible at all for larger projects). It also has none of the optimisations we talked about for splitting M into 100 or so subsets and picking from them randomly.
However, it does seem to pick reasonable pairs and produces better SMD numbers than the current version.
It's also worth noting that this code moves to using 100% of K instead of 10%; is that still what we want? It was part of the original motivation for this change.
I'm happy to talk it through with anyone, or do any further testing you want to suggest, to make sure it is robust before I add the memory-sharing optimisation (which requires a fair bit of restructuring to thread everything through, but shouldn't change the output).