Fused scan kernel optimization #465
Open
This PR optimizes the fused scan kernel with a form of checkpointing (see CHECKPOINT_INTERVAL in the code), which reduces intermediate buffer writes and reads by 4x. The forward pass writes intermediate values only once every CHECKPOINT_INTERVAL steps, and the backward pass recomputes the values between checkpoints instead of reading them. It also uses fast-math intrinsics where possible. The result is a ~1.5x forward and ~1.2x backward speedup (~1.3x overall).
Note: these benchmarks were run on an L4. Results may differ depending on your GPU's roofline characteristics; speedups are larger on memory-bandwidth-constrained GPUs.
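For reviewers who want the idea in isolation, here is a minimal CPU sketch of the checkpointing scheme described above. It is not the kernel from this PR. It assumes a toy linear recurrence `h_t = a_t * h_{t-1} + b_t`, a loss that depends only on the final state, and `CHECKPOINT_INTERVAL = 4` (a guess consistent with the 4x reduction). The function names `scan_forward` and `scan_backward` are hypothetical.

```python
# Toy illustration only: a scalar linear scan, not the fused GPU kernel.
CHECKPOINT_INTERVAL = 4  # assumed value; the real constant lives in the kernel code


def scan_forward(a, b, h0=0.0):
    """Run h_t = a_t * h_{t-1} + b_t, storing only one state per segment."""
    checkpoints = []
    h = h0
    for t in range(len(a)):
        if t % CHECKPOINT_INTERVAL == 0:
            checkpoints.append(h)  # state entering step t
        h = a[t] * h + b[t]
    return h, checkpoints


def scan_backward(a, b, checkpoints, grad_h_final):
    """Backpropagate through the scan, recomputing states inside each segment."""
    n = len(a)
    grad_a = [0.0] * n
    grad_b = [0.0] * n
    g = grad_h_final  # gradient with respect to the state after step t
    for seg in reversed(range(len(checkpoints))):
        start = seg * CHECKPOINT_INTERVAL
        end = min(start + CHECKPOINT_INTERVAL, n)
        # Recompute the state entering each step of this segment from its checkpoint.
        prev_states = []
        h = checkpoints[seg]
        for t in range(start, end):
            prev_states.append(h)
            h = a[t] * h + b[t]
        # Walk the segment in reverse, using the recomputed states.
        for t in reversed(range(start, end)):
            grad_b[t] = g
            grad_a[t] = g * prev_states[t - start]
            g = g * a[t]
    return grad_a, grad_b, g  # g is now the gradient with respect to h0
```

In this sketch the forward pass stores one state per `CHECKPOINT_INTERVAL` steps instead of one per step, and the backward pass pays for that with a short recomputation loop per segment. That trade tends to help when the kernel is limited by memory bandwidth rather than compute.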
Forward Benchmarks
Backward Benchmarks