Skip to content

Conversation

@jonahsamost
Copy link

We optimize this kernel with a sort of checkpointing (see CHECKPOINT_INTERVAL in the code). It reduces intermediate buffer writes/reads by 4x. The backward recomputes values between checkpoints instead of reading them. The forward only writes every CHECKPOINT_INTERVAL amount of times. Also use fast math intrinsics where we can. End up with ~1.5x forward and ~1.2x backward speedups (about ~1.3x overall).

Note these benchmarks were done on an L4. The results may be different depending on your GPU's roofline characteristics. Speedups are better on memory bandwidth constrained GPUs

Forward Benchmarks

Config Original Optimized Speedup
B=512, T=64, H=256 1.131ms 0.731ms 1.55x
B=512, T=64, H=384 1.722ms 1.117ms 1.54x
B=512, T=64, H=512 2.180ms 1.429ms 1.53x
B=512, T=96, H=256 1.693ms 1.091ms 1.55x
B=768, T=64, H=256 1.673ms 1.094ms 1.53x
B=768, T=64, H=512 3.281ms 2.128ms 1.54x
B=1024, T=64, H=256 2.199ms 1.432ms 1.54x
B=1024, T=64, H=512 4.349ms 2.828ms 1.54x
B=1024, T=96, H=384 4.931ms 3.238ms 1.52x
B=1024, T=128, H=256 4.394ms 2.857ms 1.54x
B=1536, T=64, H=512 6.542ms 4.237ms 1.54x
B=2048, T=64, H=256 4.390ms 2.840ms 1.55x
B=2048, T=64, H=512 8.708ms 5.636ms 1.55x
B=2048, T=96, H=512 13.110ms 8.444ms 1.55x
B=512, T=128, H=512 4.380ms 2.836ms 1.54x
B=1024, T=91, H=384 4.646ms 3.057ms 1.52x
B=1536, T=77, H=512 7.801ms 5.076ms 1.54x

Backward Benchmarks

Config Original Checkpointed Speedup
B=512, T=64, H=256 1.978ms 1.621ms 1.22x
B=512, T=64, H=384 2.946ms 2.416ms 1.22x
B=512, T=64, H=512 3.836ms 3.171ms 1.21x
B=512, T=96, H=256 2.890ms 2.420ms 1.19x
B=768, T=64, H=256 2.890ms 2.387ms 1.21x
B=768, T=64, H=512 5.778ms 4.724ms 1.22x
B=1024, T=64, H=256 3.846ms 3.161ms 1.22x
B=1024, T=64, H=512 7.653ms 6.306ms 1.21x
B=1024, T=96, H=384 8.679ms 7.130ms 1.22x
B=1024, T=128, H=256 7.700ms 6.307ms 1.22x
B=1536, T=64, H=512 11.500ms 9.388ms 1.22x
B=2048, T=64, H=256 7.661ms 6.274ms 1.22x
B=2048, T=64, H=512 15.188ms 12.506ms 1.21x
B=2048, T=96, H=512 22.803ms 18.730ms 1.22x
B=512, T=128, H=512 7.640ms 6.278ms 1.22x
B=1024, T=91, H=384 8.176ms 6.721ms 1.22x
B=1536, T=77, H=512 13.715ms 11.275ms 1.22x

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant