Optimize SoftmaxCrossEntropy (Dense & Sparse) using OpenMP and raw pointers #686
Conversation
Summary of Changes: Hello @nazar-kl, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request introduces significant performance enhancements to the softmax_cross_entropy and sparse_softmax_cross_entropy loss functions.
Code Review
This pull request significantly optimizes the softmax_cross_entropy and sparse_softmax_cross_entropy functions by using raw pointers, fusing operations, and adding OpenMP parallelization. The performance gains shown in the benchmarks are impressive. The code is well-written and the logic for optimization is sound. I have identified a critical issue where a division by zero can occur if an empty batch is provided. Additionally, I've suggested further performance improvements by fusing the inner loops to perform calculations in a single pass, which would also enhance robustness.
when compileOption("boundChecks"):
  check_input_target(input, target)

let batch_size = input.shape[0]
The code does not handle the case where batch_size is 0. This will lead to a division-by-zero error at the end of the function (result /= T(batch_size)). You should add a check at the beginning to handle this edge case and return a sensible value, like 0.
let batch_size = input.shape[0]
if batch_size == 0: return 0.T
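For context, a minimal sketch of where this guard would sit at the top of the dense function; the proc signature and the check_input_target helper are taken from the quoted context, everything else is an illustrative assumption rather than the PR's actual code.

proc softmax_cross_entropy[T](input, target: Tensor[T]): T =
  when compileOption("boundChecks"):
    check_input_target(input, target)

  let batch_size = input.shape[0]
  if batch_size == 0:
    # Empty batch: nothing to average, and it avoids the final
    # `result /= T(batch_size)` dividing by zero.
    return T(0)

  # ... fused per-row loop over batch_size rows, then result /= T(batch_size) ...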
# TODO: term rewriting macro for auto fusion

let batch_size = input.shape[0]
The code does not handle the case where batch_size is 0. This will lead to a division-by-zero error at the end of the function (result /= T(batch_size)). You should add a check at the beginning to handle this edge case and return a sensible value, like 0.
let batch_size = input.shape[0]
if batch_size == 0: return 0.T
var max_val = inp_ptr[row_inp_idx]
for j in 1 ..< features:
  let val = inp_ptr[row_inp_idx + j * inp_s1]
  if val > max_val:
    max_val = val

var sum_exp: T = 0
var row_dot: T = 0

for j in 0 ..< features:
  let val = inp_ptr[row_inp_idx + j * inp_s1]
  let t_val = tgt_ptr[row_tgt_idx + j * tgt_s1]

  sum_exp += exp(val - max_val)
  row_dot += val * t_val
You are iterating over each row's features twice: once to find the maximum value, and a second time to compute sum_exp and row_dot. This can be optimized into a single pass by using a streaming algorithm to calculate the max value and sum_exp simultaneously. This improves cache usage and performance, and also makes the code robust for cases where features is 0.
This approach is already used elsewhere in the codebase (e.g., streaming_max_sumexp in p_logsumexp.nim).
var max_val = -Inf.T
var sum_exp: T = 0
var row_dot: T = 0
for j in 0 ..< features:
  let val = inp_ptr[row_inp_idx + j * inp_s1]
  let t_val = tgt_ptr[row_tgt_idx + j * tgt_s1]
  row_dot += val * t_val
  if val <= max_val:
    sum_exp += exp(val - max_val)
  else:
    # Found new max, rescale previous sum_exp
    sum_exp = sum_exp * exp(max_val - val) + 1.T
    max_val = val
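For reference, the three accumulators from the fused loop combine directly into the per-row loss. The sketch below shows that final step; the accumulation into result and the later division by batch_size are assumed from the surrounding function and are not part of the quoted hunks.

# With a probability-vector target (the t_val values of a row sum to 1):
#   loss_i = log(sum_j exp(x_ij)) - sum_j t_ij * x_ij
#          = max_val + ln(sum_exp) - row_dot
result += max_val + ln(sum_exp) - row_dot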
var max_val = inp_ptr[row_inp_idx]
for j in 1 ..< features:
  let val = inp_ptr[row_inp_idx + j * inp_s1]
  if val > max_val:
    max_val = val

var sum_exp: T = 0
for j in 0 ..< features:
  sum_exp += exp(inp_ptr[row_inp_idx + j * inp_s1] - max_val)
Similar to the dense version, you are iterating over each row's features twice. This can be optimized into a single pass using a streaming algorithm to calculate the max value and sum_exp simultaneously. This improves performance and makes the code more robust, especially for the features = 0 edge case.
This approach is already used elsewhere in the codebase (e.g., streaming_max_sumexp in p_logsumexp.nim).
var max_val = -Inf.T
var sum_exp: T = 0
for j in 0 ..< features:
  let val = inp_ptr[row_inp_idx + j * inp_s1]
  if val <= max_val:
    sum_exp += exp(val - max_val)
  else:
    # Found new max, rescale previous sum_exp
    sum_exp = sum_exp * exp(max_val - val) + 1.T
    max_val = val
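Likewise for the sparse variant, only the logit at the target class index is needed to finish the per-row loss. A hedged sketch follows; `label` is a placeholder for this row's target class index, since the quoted hunks do not show how the PR reads it.

#   loss_i = log(sum_j exp(x_ij)) - x_i[label_i]
#          = max_val + ln(sum_exp) - logit at the target class
let logit_at_label = inp_ptr[row_inp_idx + label * inp_s1]  # `label` is a hypothetical name
result += max_val + ln(sum_exp) - logit_at_label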
I replaced the high-level tensor operations with raw pointer access and manual strides to remove allocation overhead. I also fused the max, sum_exp, and dot-product steps into a single loop to improve cache locality, and added OpenMP parallelization over the batch dimension.
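To make that description concrete, here is a minimal, hedged sketch of a fused per-row loop with OpenMP over the batch dimension. The proc name, parameter list, and the per-row buffer (used here instead of an OpenMP reduction clause) are illustrative assumptions, not the PR's actual code.

import math

# Sketch only. Assumes inp_ptr/tgt_ptr are `ptr UncheckedArray[T]` views into the
# tensors, strides are passed in explicitly, and OpenMP is enabled at compile time
# (e.g. Arraymancer's -d:openmp build flag, which adds -fopenmp).
proc fused_softmax_ce_sketch[T: SomeFloat](
    inp_ptr, tgt_ptr: ptr UncheckedArray[T],
    batch_size, features: int,
    inp_s0, inp_s1, tgt_s0, tgt_s1: int): T =
  if batch_size == 0:
    return T(0)                        # guard against division by zero below
  var per_row = newSeq[T](batch_size)  # one slot per row instead of an OpenMP reduction clause
  for i in 0 || (batch_size - 1):      # `||` emits an OpenMP "parallel for" over the batch
    let row_inp_idx = i * inp_s0
    let row_tgt_idx = i * tgt_s0
    # Single fused pass: streaming max, sum of exponentials, and dot product
    var max_val = T(-Inf)
    var sum_exp = T(0)
    var row_dot = T(0)
    for j in 0 ..< features:
      let val = inp_ptr[row_inp_idx + j * inp_s1]
      let t_val = tgt_ptr[row_tgt_idx + j * tgt_s1]
      row_dot += val * t_val
      if val <= max_val:
        sum_exp += exp(val - max_val)
      else:
        # Found a new max, rescale the previously accumulated sum_exp
        sum_exp = sum_exp * exp(max_val - val) + T(1)
        max_val = val
    per_row[i] = max_val + ln(sum_exp) - row_dot
  result = sum(per_row) / T(batch_size)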
Benchmarks on my CPU (i5-3470, batch 128, classes 1000):
Before:
After:
Verified with test (test_nnp_loss).