A toy implementation of Tensor Parallel (TP) + Sequence Parallel (SP) with communication-computation overlap for educational purposes.
Companion blog post: Tensor Parallel + Sequence Parallel — A Deep Dive
Based on the original implementation by @xiabingquan.
```bash
# Run unit tests (needs >= 4 GPUs)
python -m tensor_parallel_toy test
# Run profiling (no TP / TP / TP+overlap, needs >= 4 GPUs)
python -m tensor_parallel_toy profile
# Or via shell scripts
bash run_tests.sh
bash run_profile.sh
```

| File | Description |
|---|---|
| `config.py` | `ModelConfig` dataclass |
| `initialize.py` | Distributed init, global seed, CPU weight init |
| `parallel_linear.py` | `ColumnParallelLinear`, `RowParallelLinear`, overlap variants, comm helpers |
| `model.py` | `RMSNorm`, `Attention`, `MLP`, `TransformerBlock`, `Transformer` |
| `utils.py` | Shared distributed testing utilities |
| `test_tp.py` | Unit tests (TP vs no-TP, overlap vs no-overlap) |
| `profile_memory.py` | Training loop profiling (loss, grad norm, step time, peak memory) |
| `__init__.py` | Package exports |
| `__main__.py` | CLI entry point |
- TP Linear: Custom `autograd.Function` per layer. The forward pass saves the AllGathered input so the backward pass can reuse it. The `use_overlap` flag switches between the basic and overlap autograd Functions within the same Module.
- Overlap AG+GEMM: Ring P2P exchange (`dist.batch_isend_irecv`) with pipelined GEMM on dual CUDA streams.
- Overlap GEMM+RS: Chunked GEMM with per-chunk `dist.reduce(dst=i)` (not `reduce_scatter`, which would scatter at the wrong granularity).
- RMSNorm grads: Replicated weights need an `AllReduce` after backward, since each rank only sees `s/n` positions.
- Weight init: Global `torch.manual_seed` → all ranks draw from the same RNG in the same order → sharded weights are consistent with the full model.
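The TP Linear note can be illustrated with a single-process NumPy simulation of a sequence-parallel column-parallel linear layer: the forward pass AllGathers the sequence shards once and keeps the result, and the backward pass reuses it for the weight gradient before reduce-scattering the input gradient. Each list index plays the role of one rank, weights use the `(out, in)` layout of `nn.Linear`, and the helpers (`all_gather`, `reduce_scatter`, `column_parallel_forward`, `column_parallel_backward`) are hypothetical stand-ins for the collectives and the custom `autograd.Function`, not this package's code.

```python
import numpy as np

def all_gather(shards):
    # Simulated AllGather: concatenate every rank's sequence shard (axis 0).
    return np.concatenate(shards, axis=0)

def reduce_scatter(partials, n):
    # Simulated ReduceScatter: sum all ranks' partials, then split by sequence.
    return np.split(sum(partials), n, axis=0)

def column_parallel_forward(x_shards, w_shards):
    # x_shards[r]: (s/n, h) sequence shard; w_shards[r]: (out/n, h) weight shard.
    x_full = all_gather(x_shards)            # gathered ONCE in forward
    outs = [x_full @ w.T for w in w_shards]  # rank r output: (s, out/n)
    return outs, x_full                      # x_full is saved for backward

def column_parallel_backward(grad_outs, w_shards, saved_x_full, n):
    # Weight grad reuses the saved gathered input: no second AllGather.
    grad_ws = [g.T @ saved_x_full for g in grad_outs]
    # Each rank's grad w.r.t. the full input is only a partial sum over its
    # output columns, so sum across ranks and re-shard along the sequence dim.
    partial_dx = [g @ w for g, w in zip(grad_outs, w_shards)]
    grad_x_shards = reduce_scatter(partial_dx, n)
    return grad_x_shards, grad_ws
```

Concatenating the per-rank outputs along the feature dim, and the gradients along their sharded dims, should reproduce the unsharded `x @ w.T` forward and backward.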
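The ring AG+GEMM overlap comes down to a fixed schedule: at step `k`, rank `r` holds the sequence chunk that originated on rank `(r - k) mod n`, multiplies it by its weight shard, and passes it along the ring. The NumPy sketch below runs that schedule sequentially to check that it matches AllGather followed by a single GEMM. It does not model the two CUDA streams or `dist.batch_isend_irecv`, the send direction (rank `r` to `r + 1`) is an assumption, and `ring_ag_gemm` is a hypothetical name.

```python
import numpy as np

def ring_ag_gemm(x_shards, w_shards):
    # x_shards[r]: (rows, h) local sequence chunk; w_shards[r]: (out/n, h).
    n = len(x_shards)
    rows = x_shards[0].shape[0]
    # out[r] collects rank r's full-sequence output, one chunk at a time.
    out = [np.zeros((rows * n, w.shape[0])) for w in w_shards]
    buf = list(x_shards)  # buf[r]: the chunk rank r currently holds
    for step in range(n):
        for r in range(n):
            src = (r - step) % n  # rank that originally owned buf[r]
            out[r][src * rows:(src + 1) * rows] = buf[r] @ w_shards[r].T
        if step < n - 1:
            # Ring exchange: each rank receives the chunk held by rank r - 1.
            # In a real run these P2P ops would be in flight during the GEMM
            # above, which is where the overlap comes from.
            buf = [buf[(r - 1) % n] for r in range(n)]
    return out
```

Only `n - 1` exchanges are needed, because every rank starts with its own chunk already in hand.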
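For GEMM+RS in a row-parallel layer, every rank holds the full sequence but only a slice of the hidden dim, so its GEMM output is a partial sum. The sketch below chunks the GEMM along the sequence and sums chunk `i` onto rank `i`, mimicking per-chunk `dist.reduce(dst=i)`. For contrast, it also shows one reading of the granularity problem: a `reduce_scatter` applied to each chunk splits that chunk into `n` sub-pieces, so rank `j` would end up with the `j`-th sub-piece of every chunk. This is a single-process NumPy simulation with hypothetical helper names, and it assumes the number of chunks equals the world size.

```python
import numpy as np

def chunked_gemm_reduce(x_shards, w_shards):
    # x_shards[r]: (s, h/n) full sequence, rank r's hidden slice.
    # w_shards[r]: (out, h/n), so x_shards[r] @ w_shards[r].T is a PARTIAL sum.
    n = len(x_shards)
    rows = x_shards[0].shape[0] // n  # chunk size; assumes s divisible by n
    result = [None] * n
    for i in range(n):
        # Every rank computes its partial product for sequence chunk i only.
        partials = [x[i * rows:(i + 1) * rows] @ w.T for x, w in zip(x_shards, w_shards)]
        # Summing onto rank i mimics dist.reduce(dst=i). In a real run this
        # reduce would overlap with the GEMM for chunk i + 1.
        result[i] = sum(partials)
    return result  # result[i] is rank i's whole sequence chunk

def per_chunk_reduce_scatter(x_shards, w_shards):
    # Contrast: reduce_scatter on each chunk splits chunk i's sum into n
    # sub-pieces, so rank j collects sub-piece j of EVERY chunk.
    n = len(x_shards)
    rows = x_shards[0].shape[0] // n  # also assumes rows divisible by n
    result = [[] for _ in range(n)]
    for i in range(n):
        summed = sum(x[i * rows:(i + 1) * rows] @ w.T for x, w in zip(x_shards, w_shards))
        for j, piece in enumerate(np.split(summed, n, axis=0)):
            result[j].append(piece)
    return [np.concatenate(pieces, axis=0) for pieces in result]
```

`chunked_gemm_reduce` yields the same per-rank layout as a single `reduce_scatter` over the whole output, while the per-chunk `reduce_scatter` variant interleaves rows from different chunks.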
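The RMSNorm weight gradient is a sum over positions, which is why a rank that sees only `s/n` positions holds just a partial gradient. The NumPy check below flattens positions into a 2-D `(positions, hidden)` array and uses hypothetical helper names: each local gradient differs from the full one, but their sum (what an AllReduce with SUM computes) matches it.

```python
import numpy as np

def rmsnorm_weight_grad(x, grad_y, eps=1e-6):
    # x, grad_y: (positions, hidden). Since y = x / rms(x) * weight,
    # dL/dweight is the sum over positions of grad_y * x_hat.
    x_hat = x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return np.sum(grad_y * x_hat, axis=0)

def all_reduce_sum(tensors):
    # Simulated AllReduce (SUM): every rank ends up with the total.
    return sum(tensors)
```

Because the normalization is per position, sharding along the sequence leaves each `x_hat` row unchanged; only the final sum over positions is split across ranks.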
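The weight-init note can be checked the same way. In the sketch below, every simulated rank seeds an identical generator, draws the full weights in the same order, and keeps only its shard, so the shards concatenate back to the full-model weights. It substitutes NumPy's `default_rng` for `torch.manual_seed`, and the helper names are hypothetical. The naive alternative, drawing only the shard shape from the same seed, gives every rank identical values.

```python
import numpy as np

def init_full_weights(seed, shapes):
    # Same seed on every rank, and the FULL tensors are drawn in the same
    # order, so every rank consumes an identical stream of random numbers.
    rng = np.random.default_rng(seed)
    return [rng.standard_normal(shape) for shape in shapes]

def column_shard(full, rank, world_size):
    # Column-parallel layout: split the output-feature dim (dim 0 of (out, in)).
    return np.split(full, world_size, axis=0)[rank]

def naive_shard_init(seed, shard_shape):
    # Anti-pattern: drawing only the shard shape from the same seed gives
    # every rank the same numbers, so the shards do not form the full weight.
    return np.random.default_rng(seed).standard_normal(shard_shape)
```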
- Python >= 3.8
- PyTorch >= 2.1
- Multiple CUDA GPUs (>= 4 for unit tests and profiling)