iris.x: Device-side communication + .x APIs.
#296
base: main
Conversation
Pull request overview
This PR introduces iris.x, a new module providing device-side tile-level primitives for fine-grained collective operations. Unlike iris.ccl, which handles full tensors with internal tiling, iris.x provides composable functions that users can call from their own kernels so they can manage tile iteration themselves (a usage sketch follows the change list below).
Key Changes:
- New iris.x module with tile-level communication primitives (all-reduce, all-gather, all-to-all, reduce-scatter)
- Fused GEMM + communication operations requiring tritonBLAS (gemm_all_reduce, gemm_all_gather, etc.)
- Comprehensive test suite for the new primitives in tests/x/
- CI/CD modernization with a unified workflow replacing 3 separate workflows
- Documentation updates and benchmark enhancements
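To make the composability concrete, here is a rough sketch of how a user-owned kernel might drive a tile-level primitive. Everything iris.x-specific in it is an assumption drawn from this PR's description (the Tile/TensorView/DeviceContext objects and the all_reduce signature), so treat it as an illustration of the call structure, not the final API.

```python
import triton
import triton.language as tl

# Hypothetical import; the module layout and signature are assumptions based on
# this PR's description (tile-level primitives exposed from iris.x).
from iris.x import all_reduce


@triton.jit
def fused_compute_all_reduce_kernel(
    src_view, dst_view,              # TensorView objects (assumed API)
    ctx,                             # DeviceContext carrying rank/world_size (assumed API)
    M, N,
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
):
    # The user kernel, not iris.ccl, decides how tiles are iterated: each program
    # strides over the global tile space and calls the primitive once per tile.
    pid = tl.program_id(0)
    total_tiles = tl.cdiv(M, BLOCK_SIZE_M) * tl.cdiv(N, BLOCK_SIZE_N)
    for tile_id in range(pid, total_tiles, tl.num_programs(0)):
        # ... per-tile compute (e.g., a GEMM stage) would go here ...
        # Passing the raw tile index is a placeholder for whatever Tile object
        # the real API expects; the argument order is also an assumption.
        all_reduce(tile_id, src_view, dst_view, ctx)
```

The contrast with iris.ccl is the loop: there a single call owns the whole tensor and the tiling stays internal, while here the caller decides which program touches which tile.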
Reviewed changes
Copilot reviewed 33 out of 33 changed files in this pull request and generated 14 comments.
Summary per file:

| File | Description |
|---|---|
| iris/x/__init__.py | Module initialization exposing all tile-level primitives with optional GEMM operations |
| iris/x/all_reduce.py | Five all-reduce variants (atomic, one-shot, two-shot, spinlock, ring) for different use cases |
| iris/x/all_gather.py | Tile-level all-gather primitive for gathering data from all ranks |
| iris/x/all_to_all.py | Tile-level all-to-all primitive for bidirectional data exchange |
| iris/x/reduce_scatter.py | Tile-level reduce-scatter that reduces and scatters to assigned ranks |
| iris/x/gemm_all_reduce.py | Fused GEMM + all-reduce using tritonBLAS stages |
| iris/x/gemm_all_gather.py | Fused GEMM + all-gather combining computation and communication |
| iris/x/gemm_reduce_scatter.py | Fused GEMM + reduce-scatter for column-parallel workloads |
| iris/x/all_gather_gemm.py | Fused all-gather + GEMM for tensor-parallel workloads |
| iris/x/common.py | Shared utilities for tile indexing and offset computation |
| tests/x/test_*.py | Comprehensive test suite validating all primitives against PyTorch references |
| .github/workflows/iris-tests.yml | New unified test workflow supporting multiple test directories and install methods |
| .github/scripts/run_tests.sh | Updated test runner with tritonBLAS installation for iris.x tests |
| tests/ccl/test_all_reduce.py | Modified to add explicit preamble calls for better test isolation |
| pyproject.toml | Added optional gemm dependency group for tritonBLAS |
| docs/reference/examples.md | Updated documentation with new example references |
| benchmark/ccl/all_to_all/benchmark.py | Added RCCL comparison benchmarking option |
@neoblizz we should be able to use
Resolved conflicts by accepting main's changes for:
- .gitignore
- benchmark/ccl/*.py files
- docker/Dockerfile
- iris/ccl/*.py files
…eContext

Refactor all tile-based collective operations and fused GEMM operators to use the new object-oriented API, dramatically simplifying function signatures and improving code readability.

Changes:
- Collectives: all_gather, all_reduce (4 variants), reduce_scatter, all_to_all
- Fused ops: all_gather_gemm, gemm_all_gather, gemm_all_reduce, gemm_reduce_scatter
- Replace verbose parameter lists with OOP objects (Tile, TensorView, DeviceContext)
- Add tl.constexpr annotations to all GEMM kernel parameters
- Fix iris.load/atomic_add call signatures for correct argument ordering
- Net reduction: -50 lines of code across 8 files
Update all test kernels to use the new OOP API (Tile, TensorView, DeviceContext) and fix a critical tile iteration bug causing test failures at scale.

Changes:
- Rename all test kernels to the test_x_*_kernel pattern (avoids pytest warnings)
- Update kernel calls to use OOP objects instead of verbose parameters
- Fix tile iteration stride: use tl.num_programs(0) instead of 1 to prevent multiple CUs from processing the same tiles (fixes race conditions)
- Fix the all_to_all PyTorch reference to use .contiguous() chunks
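To make the stride fix concrete, here is a minimal standalone sketch of the corrected iteration pattern. The kernel body and names are illustrative, not code from this PR, and running it requires a GPU with Triton installed.

```python
import torch
import triton
import triton.language as tl


@triton.jit
def tile_visit_count_kernel(visits_ptr, total_tiles):
    pid = tl.program_id(0)
    # Buggy pattern: range(pid, total_tiles, 1) makes every program walk nearly
    # the whole tile space, so multiple CUs race on the same tiles.
    # Fixed pattern: stride by the grid size so each tile has exactly one owner.
    for tile_id in range(pid, total_tiles, tl.num_programs(0)):
        tl.atomic_add(visits_ptr + tile_id, 1)


# Quick check: with the fixed stride, every tile is visited exactly once.
total_tiles = 1024
visits = torch.zeros(total_tiles, dtype=torch.int32, device="cuda")
tile_visit_count_kernel[(64,)](visits, total_tiles)
assert bool((visits == 1).all())
```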
Pull request overview
Copilot reviewed 24 out of 24 changed files in this pull request and generated 10 comments.
```python
bias_vector = tl.load(
    bias_ptr + row_indices * stride_bias, mask=row_indices < M, other=0.0
)  # Load Bias vector
acc = add_vector(acc, bias_vector, QUANTIZED=False)  # Add bias vector to output accumulator
```

Copilot (AI), Jan 28, 2026:
The add_vector function is being called with bias_vector as the second argument, but the bias needs to be properly broadcast. The shape appears incorrect: it should be bias_vector[:, None] to broadcast across the second dimension, matching the pattern used in gemm_reduce_scatter.py line 179 and gemm_all_gather.py line 205.
Suggested change:
```diff
- acc = add_vector(acc, bias_vector, QUANTIZED=False)  # Add bias vector to output accumulator
+ acc = add_vector(acc, bias_vector[:, None], QUANTIZED=False)  # Add bias vector to output accumulator, broadcast across N dimension
```
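The same broadcasting rule can be sanity-checked outside Triton with plain PyTorch tensors; the block sizes below are arbitrary and only for illustration.

```python
import torch

BLOCK_M, BLOCK_N = 4, 8
acc = torch.zeros(BLOCK_M, BLOCK_N)                 # accumulator tile, shape [BLOCK_M, BLOCK_N]
bias = torch.arange(BLOCK_M, dtype=torch.float32)   # one bias value per output row, shape [BLOCK_M]

# bias[:, None] has shape [BLOCK_M, 1], which broadcasts across the N dimension;
# adding the 1-D bias directly would broadcast along the wrong axis (or fail
# outright whenever BLOCK_M != BLOCK_N).
out = acc + bias[:, None]
assert out.shape == (BLOCK_M, BLOCK_N)
assert torch.equal(out[:, 0], bias)                 # every column carries the per-row bias
```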
iris/x/gemm_all_reduce.py (outdated)
```python
tl.assume(stride_cn > 0)

# Determine accumulator dtype based on output type
acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
```

Copilot (AI), Jan 28, 2026:
The accumulator dtype logic appears inverted. It should be tl.float32 for most types and tl.int32 only for int8 output. The condition should be: acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32. This matches the pattern in gemm_reduce_scatter.py line 120.
Suggested change:
```diff
- acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
+ acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32
```
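For the reasoning behind the suggested ordering: when the output (and typically the inputs) are int8, the accumulator has to be a wider integer type, while every other supported dtype accumulates in float32. A small CPU-side check of the overflow argument (sizes chosen arbitrarily):

```python
import torch

# Worst-case int8 dot product over a K=128 reduction: the true sum lands far
# outside int8 range, which is why int8 GEMMs accumulate in int32.
a = torch.full((1, 128), 127, dtype=torch.int32)  # int8-range values, widened so matmul works on CPU
b = torch.full((128, 1), 127, dtype=torch.int32)
acc = (a @ b).item()                              # 127 * 127 * 128 = 2,064,512
assert acc == 2_064_512
assert acc > torch.iinfo(torch.int8).max          # would overflow an int8 accumulator
```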
iris/x/all_gather_gemm.py (outdated)
```python
tl.assume(stride_ag_n > 0)

# Determine accumulator dtype based on output type
acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
```

Copilot (AI), Jan 28, 2026:
The accumulator dtype logic appears inverted. It should be tl.float32 for most types and tl.int32 only for int8 output. The condition should be: acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32. This is inconsistent with the correct logic in gemm_all_gather.py line 113 and gemm_reduce_scatter.py line 120.
Suggested change:
```diff
- acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
+ acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32
```
iris/x/gemm_reduce_scatter.py (outdated)
```python
tl.assume(stride_cn > 0)

# Determine accumulator dtype based on output type
acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
```

Copilot (AI), Jan 28, 2026:
The accumulator dtype logic appears inverted. It should be tl.float32 for most types and tl.int32 only for int8 output. The condition should be: acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32. This is inconsistent with the correct logic in gemm_all_gather.py line 113.
Suggested change:
```diff
- acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
+ acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32
```
```python
# Compute chunk for this step
chunk_id = (ctx.rank - step - 1) % ctx.world_size
```

Copilot (AI), Jan 28, 2026:
Variable chunk_id is not used.
Suggested change:
```diff
- # Compute chunk for this step
- chunk_id = (ctx.rank - step - 1) % ctx.world_size
```
```python
tl.store(dst_tile_ptr, result, mask=mask)

for step in range(ctx.world_size - 1):
    send_rank = (ctx.rank + step) % ctx.world_size
```

Copilot (AI), Jan 28, 2026:
Variable send_rank is not used.
tests/x/test_all_gather_gemm.py (outdated)
```python
# Launch all_gather_gemm kernel
num_pid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
total_tiles = num_pid_m * num_pid_n
```

Copilot (AI), Jan 28, 2026:
Variable total_tiles is not used.
Suggested change:
```diff
- total_tiles = num_pid_m * num_pid_n
```
```python
):
    """Kernel that iterates over tiles assigned to this rank and calls reduce_scatter for each."""
    pid = tl.program_id(0)
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
```

Copilot (AI), Jan 28, 2026:
Variable num_pid_m is not used.
Suggested change:
```diff
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
```
```python
start_tile = cur_rank * tiles_per_rank
end_tile = start_tile + tiles_per_rank

num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
```

Copilot (AI), Jan 28, 2026:
Variable num_pid_m is not used.
Suggested change:
```diff
- num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
```
```python
# Ring reduce-scatter phase
for step in range(ctx.world_size - 1):
    send_rank = (ctx.rank - step) % ctx.world_size
```

Copilot (AI), Jan 28, 2026:
This assignment to 'send_rank' is unnecessary as it is redefined before this value is used.
mawad-amd left a comment:
Looks good. We need to move DeviceContext to iris.py, but that can happen in a different PR. The APIs look good to me. Thanks!
```bash
# Try to clone and install in editable mode (more reliable)
if [ ! -d \"/tmp/tritonBLAS\" ]; then
    echo \"Cloning tritonBLAS repository...\"
    cd /tmp && git clone https://github.com/ROCm/tritonBLAS.git 2>&1 | tail -3
```
Can we just add tritonBLAS as an optional dependency so that people can run this?
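A minimal sketch of the guarded-import pattern an optional `gemm` dependency group would enable; the import name, flag, and helper below are hypothetical, not code from this PR.

```python
# Hypothetical module-level guard (e.g., in iris/x/__init__.py): expose the fused
# GEMM + communication ops only when tritonBLAS is installed, so the plain
# tile-level collectives keep working without it.
try:
    import tritonblas  # assumed import name; pulled in by an optional `gemm` extra
    _HAS_TRITONBLAS = True
except ImportError:
    _HAS_TRITONBLAS = False


def _require_tritonblas() -> None:
    """Hypothetical helper the fused GEMM entry points could call before running."""
    if not _HAS_TRITONBLAS:
        raise ImportError(
            "Fused GEMM + communication ops need tritonBLAS; install it via the "
            "optional dependency group, e.g. `pip install 'iris[gemm]'`."
        )
```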
```dockerfile
WORKDIR $TRITON_PATH
RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH
RUN git checkout 715f6b1d442601436bf8d462db6ff8e17aec8cfb
RUN git checkout bcbcabdd0cff6539c7168299075992b2a23ff38e
```
You will need to update the apptainer file too for the runner.
```python
    tile: Tile object with position and dimensions.
    src_view: TensorView for input tensor.
    dst_view: TensorView for output tensor.
    locks_ptr: Pointer to locks array (one lock per tile).
```
Future work: it would be great if we could get rid of this locks_ptr argument.
Introduces iris.x APIs.

Submission Checklist