Conversation


@neoblizz neoblizz commented Dec 9, 2025

Introduces iris.x APIs.

Submission Checklist

@github-actions github-actions bot added the in-progress (We are working on it) and iris (Iris project issue) labels Dec 9, 2025
@neoblizz neoblizz changed the base branch from main to muhosama/ccl-more December 9, 2025 20:32
@neoblizz neoblizz requested a review from Copilot December 9, 2025 20:33

Copilot AI left a comment


Pull request overview

This PR introduces iris.x, a new module providing device-side tile-level primitives for fine-grained collective operations. Unlike iris.ccl, which handles full tensors with internal tiling, iris.x provides composable functions that users can call from their own kernels, managing tile iteration themselves.

Key Changes:

  • New iris.x module with tile-level communication primitives (all-reduce, all-gather, all-to-all, reduce-scatter)
  • Fused GEMM+Communication operations requiring tritonBLAS (gemm_all_reduce, gemm_all_gather, etc.)
  • Comprehensive test suite for new primitives in tests/x/
  • CI/CD modernization with unified workflow replacing 3 separate workflows
  • Documentation updates and benchmark enhancements
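To make the tile-level idea concrete, here is a plain-Python (CPU) model of a one-shot tile-level all-reduce: each rank walks its tiles, reads that tile from every peer, sums, and stores locally. All names here (WORLD_SIZE, inputs, outputs, TILE) are illustrative for this sketch and are not the iris.x API:

```python
import numpy as np

# CPU model of a one-shot, tile-level all-reduce across 4 "ranks".
# A real kernel would do remote loads via iris and stride tiles by the
# launch grid; here we loop serially to show the data movement only.
WORLD_SIZE = 4
N = 16      # elements per rank
TILE = 4    # tile width

rng = np.random.default_rng(0)
inputs = [rng.standard_normal(N) for _ in range(WORLD_SIZE)]
outputs = [np.zeros(N) for _ in range(WORLD_SIZE)]

num_tiles = N // TILE
for rank in range(WORLD_SIZE):
    for tile_id in range(num_tiles):
        lo, hi = tile_id * TILE, (tile_id + 1) * TILE
        acc = np.zeros(TILE)
        for peer in range(WORLD_SIZE):
            acc += inputs[peer][lo:hi]   # "remote" load of this tile from each peer
        outputs[rank][lo:hi] = acc       # local store of the reduced tile

expected = np.sum(inputs, axis=0)
for rank in range(WORLD_SIZE):
    assert np.allclose(outputs[rank], expected)
```

The point of the module is that the per-tile body above is what the user's kernel controls; iris.x supplies the communication primitive called inside it.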

Reviewed changes

Copilot reviewed 33 out of 33 changed files in this pull request and generated 14 comments.

Summary per file:

  • iris/x/__init__.py: Module initialization exposing all tile-level primitives with optional GEMM operations
  • iris/x/all_reduce.py: Five all-reduce variants (atomic, one-shot, two-shot, spinlock, ring) for different use cases
  • iris/x/all_gather.py: Tile-level all-gather primitive for gathering data from all ranks
  • iris/x/all_to_all.py: Tile-level all-to-all primitive for bidirectional data exchange
  • iris/x/reduce_scatter.py: Tile-level reduce-scatter that reduces and scatters to assigned ranks
  • iris/x/gemm_all_reduce.py: Fused GEMM + all-reduce using tritonBLAS stages
  • iris/x/gemm_all_gather.py: Fused GEMM + all-gather combining computation and communication
  • iris/x/gemm_reduce_scatter.py: Fused GEMM + reduce-scatter for column-parallel workloads
  • iris/x/all_gather_gemm.py: Fused all-gather + GEMM for tensor-parallel workloads
  • iris/x/common.py: Shared utilities for tile indexing and offset computation
  • tests/x/test_*.py: Comprehensive test suite validating all primitives against PyTorch references
  • .github/workflows/iris-tests.yml: New unified test workflow supporting multiple test directories and install methods
  • .github/scripts/run_tests.sh: Updated test runner with tritonBLAS installation for iris.x tests
  • tests/ccl/test_all_reduce.py: Modified to add explicit preamble calls for better test isolation
  • pyproject.toml: Added optional gemm dependency group for tritonBLAS
  • docs/reference/examples.md: Updated documentation with new example references
  • benchmark/ccl/all_to_all/benchmark.py: Added RCCL comparison benchmarking option

@mawad-amd (Collaborator)

@neoblizz we should be able to use an aggregate to clean up the device-side APIs. See https://godbolt.org/z/hY3oWfW1x

Resolved conflicts by accepting main's changes for:
- .gitignore
- benchmark/ccl/*.py files
- docker/Dockerfile
- iris/ccl/*.py files
…eContext

Refactor all tile-based collective operations and fused GEMM operators to use the
new object-oriented API, dramatically simplifying function signatures and
improving code readability.

Changes:
- Collectives: all_gather, all_reduce (4 variants), reduce_scatter, all_to_all
- Fused ops: all_gather_gemm, gemm_all_gather, gemm_all_reduce, gemm_reduce_scatter
- Replace verbose parameter lists with OOP objects (Tile, TensorView, DeviceContext)
- Add tl.constexpr annotations to all GEMM kernel parameters
- Fix iris.load/atomic_add call signatures for correct argument ordering
- Net reduction: -50 lines of code across 8 files
Update all test kernels to use new OOP API (Tile, TensorView, DeviceContext)
and fix critical tile iteration bug causing test failures at scale.

Changes:
- Rename all test kernels to test_x_*_kernel pattern (avoids pytest warnings)
- Update kernel calls to use OOP objects instead of verbose parameters
- Fix tile iteration stride: use tl.num_programs(0) instead of 1 to prevent
  multiple CUs from processing the same tiles (fixes race conditions)
- Fix all_to_all PyTorch reference to use .contiguous() chunks
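The stride fix described above can be modeled in plain Python: striding the tile loop by the grid size partitions tiles disjointly across programs, while a stride of 1 makes every program walk the same tiles. Names below are illustrative, not the kernel's actual identifiers:

```python
def assigned_tiles(pid, num_programs, total_tiles):
    # Correct pattern: start at this program's id, step by the grid size
    # (the tl.num_programs(0) fix), so each tile is visited exactly once.
    return list(range(pid, total_tiles, num_programs))

def buggy_tiles(pid, total_tiles):
    # Buggy pattern (stride 1): every program visits every tile >= pid,
    # so multiple CUs race on the same tiles.
    return list(range(pid, total_tiles, 1))

TOTAL, GRID = 12, 4
correct = [assigned_tiles(p, GRID, TOTAL) for p in range(GRID)]
flat = sorted(t for ts in correct for t in ts)
assert flat == list(range(TOTAL))                 # disjoint and exhaustive
assert len(buggy_tiles(0, TOTAL)) == TOTAL        # program 0 alone touches all tiles
```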
@neoblizz neoblizz marked this pull request as ready for review January 28, 2026 20:57

Copilot AI left a comment


Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 10 comments.

bias_vector = tl.load(
bias_ptr + row_indices * stride_bias, mask=row_indices < M, other=0.0
) # Load Bias vector
acc = add_vector(acc, bias_vector, QUANTIZED=False) # Add bias vector to output accumulator

Copilot AI Jan 28, 2026


The add_vector function is being called with bias_vector as the second argument, but the bias needs to be properly broadcast. The shape appears incorrect: it should be bias_vector[:, None] to broadcast properly across the second dimension, matching the pattern used in gemm_reduce_scatter.py line 179 and gemm_all_gather.py line 205.

Suggested change
acc = add_vector(acc, bias_vector, QUANTIZED=False) # Add bias vector to output accumulator
acc = add_vector(acc, bias_vector[:, None], QUANTIZED=False) # Add bias vector to output accumulator, broadcast across N dimension
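A quick NumPy check shows why the reshape matters: a (BLOCK_M,) bias does not broadcast against a (BLOCK_M, BLOCK_N) accumulator, but bias[:, None] does. Block sizes here are illustrative:

```python
import numpy as np

BLOCK_M, BLOCK_N = 4, 8
acc = np.zeros((BLOCK_M, BLOCK_N))
bias = np.arange(BLOCK_M, dtype=float)   # one bias value per output row

# Without the reshape, (BLOCK_M, BLOCK_N) + (BLOCK_M,) mismatches on the
# trailing axis and fails to broadcast.
try:
    _ = acc + bias
    raised = False
except ValueError:
    raised = True
assert raised

# With the reshape, bias becomes (BLOCK_M, 1) and broadcasts across N,
# so every element of row i receives bias[i].
out = acc + bias[:, None]
assert out.shape == (BLOCK_M, BLOCK_N)
assert np.all(out[2] == 2.0)
```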

tl.assume(stride_cn > 0)

# Determine accumulator dtype based on output type
acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32

Copilot AI Jan 28, 2026


The accumulator dtype logic appears inverted. It should be tl.float32 for most types and tl.int32 only for int8 output. The condition should be: acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32. This matches the pattern in gemm_reduce_scatter.py line 120.

Suggested change
acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32
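The inverted condition is easy to test in isolation. Below is a minimal plain-Python model of the intended mapping (int8 outputs accumulate in int32, everything else in float32), using dtype names as strings rather than the Triton types themselves:

```python
def acc_dtype(out_dtype: str) -> str:
    # Intended rule: int8 outputs use an int32 accumulator; all other
    # output types (fp16, bf16, fp32, ...) accumulate in float32.
    return "int32" if out_dtype == "int8" else "float32"

assert acc_dtype("int8") == "int32"
assert acc_dtype("float16") == "float32"
assert acc_dtype("bfloat16") == "float32"
assert acc_dtype("float32") == "float32"
```

The reviewed line implements the opposite mapping, which would accumulate float outputs in int32.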

tl.assume(stride_ag_n > 0)

# Determine accumulator dtype based on output type
acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32

Copilot AI Jan 28, 2026


The accumulator dtype logic appears inverted. It should be tl.float32 for most types and tl.int32 only for int8 output. The condition should be: acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32. This is inconsistent with the correct logic in gemm_all_gather.py line 113 and gemm_reduce_scatter.py line 120.

Suggested change
acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32

tl.assume(stride_cn > 0)

# Determine accumulator dtype based on output type
acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32

Copilot AI Jan 28, 2026


The accumulator dtype logic appears inverted. It should be tl.float32 for most types and tl.int32 only for int8 output. The condition should be: acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32. This is inconsistent with the correct logic in gemm_all_gather.py line 113.

Suggested change
acc_dtype = tl.int32 if C.type.element_ty != tl.int8 else tl.float32
acc_dtype = tl.float32 if C.type.element_ty != tl.int8 else tl.int32

Comment on lines +172 to +174
# Compute chunk for this step
chunk_id = (ctx.rank - step - 1) % ctx.world_size


Copilot AI Jan 28, 2026


Variable chunk_id is not used.

Suggested change
# Compute chunk for this step
chunk_id = (ctx.rank - step - 1) % ctx.world_size
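Even though the variable is unused here, the ring indexing itself is worth sanity-checking. A plain-Python check that (rank - step - 1) % world_size gives each rank a distinct chunk per step, with full coverage across ranks at every step (world size is illustrative):

```python
WORLD_SIZE = 4

# At every step, the ranks collectively touch every chunk exactly once.
for step in range(WORLD_SIZE - 1):
    chunks_this_step = {(rank - step - 1) % WORLD_SIZE for rank in range(WORLD_SIZE)}
    assert chunks_this_step == set(range(WORLD_SIZE))

# Across the world_size - 1 steps, each rank touches world_size - 1
# distinct chunks (all chunks except its own final one).
per_rank = [{(rank - step - 1) % WORLD_SIZE for step in range(WORLD_SIZE - 1)}
            for rank in range(WORLD_SIZE)]
assert all(len(chunks) == WORLD_SIZE - 1 for chunks in per_rank)
```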

tl.store(dst_tile_ptr, result, mask=mask)

for step in range(ctx.world_size - 1):
send_rank = (ctx.rank + step) % ctx.world_size

Copilot AI Jan 28, 2026


Variable send_rank is not used.

# Launch all_gather_gemm kernel
num_pid_m = (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M
num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
total_tiles = num_pid_m * num_pid_n

Copilot AI Jan 28, 2026


Variable total_tiles is not used.

Suggested change
total_tiles = num_pid_m * num_pid_n
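The grid arithmetic in the reviewed snippet is plain ceil division; a small check (with illustrative sizes) shows the tile counts cover the output even when M and N are not multiples of the block sizes:

```python
def cdiv(a: int, b: int) -> int:
    # Ceil division, the same idiom as (M + BLOCK_SIZE_M - 1) // BLOCK_SIZE_M.
    return (a + b - 1) // b

M, N = 100, 96
BLOCK_SIZE_M, BLOCK_SIZE_N = 32, 32
num_pid_m = cdiv(M, BLOCK_SIZE_M)
num_pid_n = cdiv(N, BLOCK_SIZE_N)

assert (num_pid_m, num_pid_n) == (4, 3)
# The tile grid fully covers the M x N output.
assert num_pid_m * BLOCK_SIZE_M >= M and num_pid_n * BLOCK_SIZE_N >= N
```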

):
"""Kernel that iterates over tiles assigned to this rank and calls reduce_scatter for each."""
pid = tl.program_id(0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)

Copilot AI Jan 28, 2026


Variable num_pid_m is not used.

Suggested change
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)

start_tile = cur_rank * tiles_per_rank
end_tile = start_tile + tiles_per_rank

num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)

Copilot AI Jan 28, 2026


Variable num_pid_m is not used.

Suggested change
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
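The surrounding start_tile/end_tile arithmetic partitions tiles contiguously across ranks. A plain-Python model (names illustrative) confirms that, when the tile count divides evenly, the per-rank ranges are disjoint and cover every tile:

```python
def tile_range(rank: int, tiles_per_rank: int):
    # Rank r owns the contiguous range [r * tiles_per_rank,
    # (r + 1) * tiles_per_rank), matching start_tile/end_tile above.
    start_tile = rank * tiles_per_rank
    end_tile = start_tile + tiles_per_rank
    return start_tile, end_tile

WORLD_SIZE, TILES_PER_RANK = 4, 3
owned = [list(range(*tile_range(r, TILES_PER_RANK))) for r in range(WORLD_SIZE)]

# Disjoint, exhaustive coverage of all world_size * tiles_per_rank tiles.
flat = sorted(t for tiles in owned for t in tiles)
assert flat == list(range(WORLD_SIZE * TILES_PER_RANK))
```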


# Ring reduce-scatter phase
for step in range(ctx.world_size - 1):
send_rank = (ctx.rank - step) % ctx.world_size

Copilot AI Jan 28, 2026


This assignment to 'send_rank' is unnecessary as it is redefined before this value is used.


@mawad-amd mawad-amd left a comment


Looks good. We need to move DeviceContext to iris.py, but that can happen in a different PR. The APIs look good to me. Thanks!

# Try to clone and install in editable mode (more reliable)
if [ ! -d "/tmp/tritonBLAS" ]; then
    echo "Cloning tritonBLAS repository..."
    cd /tmp && git clone https://github.com/ROCm/tritonBLAS.git 2>&1 | tail -3
Collaborator


Can we just add tritonBLAS as an optional dependency, so that people can run this?

WORKDIR $TRITON_PATH
RUN git clone https://github.com/triton-lang/triton.git $TRITON_PATH
RUN git checkout 715f6b1d442601436bf8d462db6ff8e17aec8cfb
RUN git checkout bcbcabdd0cff6539c7168299075992b2a23ff38e
Collaborator


You will need to update the Apptainer file for the runner too.

tile: Tile object with position and dimensions.
src_view: TensorView for input tensor.
dst_view: TensorView for output tensor.
locks_ptr: Pointer to locks array (one lock per tile).
Collaborator


Future work: it would be great if we could get rid of this locks_ptr argument.


Labels

in-progress (We are working on it), iris (Iris project issue)

3 participants