Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 116 additions & 0 deletions src/compiler/intrinsics/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -177,3 +177,119 @@ efunc(::typeof(Intrinsics.atomic_add), effects::CC.Effects) =
# Codegen: lower the `atomic_add` intrinsic to an atomic ADD read-modify-write.
emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add), args) =
    emit_atomic_rmw!(ctx, args, AtomicADD)

# ============================================================================
# Tile-indexed atomic operations
# These take pre-computed pointer tiles, value tiles, and masks.
# Used by the public API for tile-indexed atomic operations.
# ============================================================================

# Shared codegen helper for tile-indexed atomic RMW operations.
#
# Expects `args` = (ptr_tile, val, mask, memory_order, memory_scope); the last
# two must be compile-time constants. Emits the RMW op, threads the memory
# token through `ctx.token`, and returns a CGVal holding the old values.
function emit_atomic_rmw_tile!(ctx::CGCtx, args::AbstractVector, mode::AtomicRMWMode)
    builder = ctx.cb
    types = ctx.tt

    # Materialize the pointer-tile, value-tile, and mask operands.
    ptrs = emit_value!(ctx, args[1])
    ptrs === nothing && throw(IRError("tile-indexed atomic RMW requires ptr_tile"))
    vals = emit_value!(ctx, args[2])
    vals === nothing && throw(IRError("tile-indexed atomic RMW requires value"))
    mask = emit_value!(ctx, args[3])
    mask === nothing && throw(IRError("tile-indexed atomic RMW requires mask"))

    # Ordering and scope must be known at compile time.
    memory_order = @something get_constant(ctx, args[4]) throw(IRError("tile-indexed atomic RMW requires constant memory_order"))
    memory_scope = @something get_constant(ctx, args[5]) throw(IRError("tile-indexed atomic RMW requires constant memory_scope"))

    # The result tile mirrors the value tile's shape and element type.
    shape = vals.shape
    elem_type = eltype(vals.jltype)
    dtype = julia_to_tile_dtype!(types, elem_type)
    result_tile_type = tile_type!(types, dtype, collect(shape))
    token_type = Token(types)

    # Integer ADD is auto-promoted to the float variant for AbstractFloat
    # element types; all other modes pass through unchanged.
    actual_mode = (mode == AtomicADD && elem_type <: AbstractFloat) ? AtomicADDF : mode

    mem_ordering = memory_order_to_semantics(memory_order)
    mem_scope = memory_scope_to_scope(memory_scope)

    old_val, new_token = encode_AtomicRMWPtrOp!(builder, result_tile_type, token_type,
                                                ptrs.v, vals.v, actual_mode;
                                                mask=mask.v,
                                                token=ctx.token,
                                                memory_ordering=mem_ordering,
                                                memory_scope=mem_scope)
    # Thread the new token so later memory ops are ordered after this one.
    ctx.token = new_token

    CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
end

# Tile-indexed atomic exchange
@intrinsic atomic_xchg_tile(ptr_tile, val, mask, memory_order, memory_scope)
# Inference: the result (old values) has the same widened type as `val`.
function tfunc(𝕃, ::typeof(Intrinsics.atomic_xchg_tile), @nospecialize(ptrs), @nospecialize(val), @nospecialize args...)
    CC.widenconst(val)
end
# Effects: atomics mutate memory, so the op is never effect-free.
efunc(::typeof(Intrinsics.atomic_xchg_tile), effects::CC.Effects) =
    CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
# Codegen: lower through the shared tile-indexed RMW helper with XCHG mode.
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_xchg_tile), args)
    emit_atomic_rmw_tile!(ctx, args, AtomicXCHG)
end

# Tile-indexed atomic addition
@intrinsic atomic_add_tile(ptr_tile, val, mask, memory_order, memory_scope)
# Inference: the result (old values) has the same widened type as `val`.
function tfunc(𝕃, ::typeof(Intrinsics.atomic_add_tile), @nospecialize(ptrs), @nospecialize(val), @nospecialize args...)
    CC.widenconst(val)
end
# Effects: atomics mutate memory, so the op is never effect-free.
efunc(::typeof(Intrinsics.atomic_add_tile), effects::CC.Effects) =
    CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
# Codegen: lower through the shared tile-indexed RMW helper; the helper
# promotes integer ADD to float ADD for AbstractFloat element types.
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_add_tile), args)
    emit_atomic_rmw_tile!(ctx, args, AtomicADD)
end

# Tile-indexed atomic compare-and-swap
@intrinsic atomic_cas_tile(ptr_tile, expected, desired, mask, memory_order, memory_scope)
# Inference: the result (old values) has the same widened type as `expected`.
function tfunc(𝕃, ::typeof(Intrinsics.atomic_cas_tile), @nospecialize(ptrs), @nospecialize(expected), @nospecialize args...)
    CC.widenconst(expected)
end
# Effects: atomics mutate memory, so the op is never effect-free.
efunc(::typeof(Intrinsics.atomic_cas_tile), effects::CC.Effects) =
    CC.Effects(effects; effect_free=CC.ALWAYS_FALSE)
# Codegen for tile-indexed compare-and-swap.
#
# args = (ptr_tile, expected, desired, mask, memory_order, memory_scope); the
# ordering/scope arguments must be compile-time constants. Emits the CAS op,
# threads the memory token through `ctx.token`, and returns a CGVal of the old
# values read from memory.
function emit_intrinsic!(ctx::CGCtx, ::typeof(Intrinsics.atomic_cas_tile), args)
    builder = ctx.cb
    types = ctx.tt

    # Materialize the four tile operands.
    ptrs = emit_value!(ctx, args[1])
    ptrs === nothing && throw(IRError("tile-indexed atomic CAS requires ptr_tile"))
    expected = emit_value!(ctx, args[2])
    expected === nothing && throw(IRError("tile-indexed atomic CAS requires expected value"))
    desired = emit_value!(ctx, args[3])
    desired === nothing && throw(IRError("tile-indexed atomic CAS requires desired value"))
    mask = emit_value!(ctx, args[4])
    mask === nothing && throw(IRError("tile-indexed atomic CAS requires mask"))

    # Ordering and scope must be known at compile time.
    memory_order = @something get_constant(ctx, args[5]) throw(IRError("tile-indexed atomic CAS requires constant memory_order"))
    memory_scope = @something get_constant(ctx, args[6]) throw(IRError("tile-indexed atomic CAS requires constant memory_scope"))

    # The result tile mirrors the expected-value tile's shape and element type.
    shape = expected.shape
    elem_type = eltype(expected.jltype)
    dtype = julia_to_tile_dtype!(types, elem_type)
    result_tile_type = tile_type!(types, dtype, collect(shape))
    token_type = Token(types)

    mem_ordering = memory_order_to_semantics(memory_order)
    mem_scope = memory_scope_to_scope(memory_scope)

    old_val, new_token = encode_AtomicCASPtrOp!(builder, result_tile_type, token_type,
                                                ptrs.v, expected.v, desired.v;
                                                mask=mask.v,
                                                token=ctx.token,
                                                memory_ordering=mem_ordering,
                                                memory_scope=mem_scope)
    # Thread the new token so later memory ops are ordered after this one.
    ctx.token = new_token

    CGVal(old_val, result_tile_type, Tile{elem_type, Tuple{shape...}}, collect(shape))
end
185 changes: 185 additions & 0 deletions src/language/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,188 @@ old_val = ct.atomic_add(counters, idx, Int32(1))
memory_scope::Int=MemScope.Device) where {T}
Intrinsics.atomic_add(array, index - One(), val, memory_order, memory_scope)
end

# ============================================================================
# Tile-indexed atomic operations
# These accept Tile indices to perform atomic operations on multiple elements.
# ============================================================================

# --- Pointer/mask helper (N-dimensional) ---

# Compute (ptr_tile, mask, S) for element-wise atomic ops: broadcast the
# 1-indexed index tiles to a common shape S, build a tile of element pointers,
# and a bounds mask that disables out-of-range lanes.
@inline function _atomic_ptrs_mask(array::TileArray{T, N},
                                   indices::NTuple{N, Tile{<:Integer}}) where {T, N}
    # Convert each index tile to 0-indexed.
    indices_0 = ntuple(Val(N)) do d
        indices[d] .- one(eltype(indices[d]))
    end

    # Broadcast all index tiles to a common shape.
    S = reduce(broadcast_shape, ntuple(d -> size(indices[d]), Val(N)))

    # Broadcast and convert to Int32 so all dimensions use one index type.
    indices_i32 = ntuple(Val(N)) do d
        convert(Tile{Int32}, broadcast_to(indices_0[d], S))
    end

    # Linear offset: sum(idx[d] * stride[d]) over all dimensions.
    # NOTE(review): assumes `array.strides` is expressed in elements (not
    # bytes) — confirm against `Intrinsics.offset` semantics.
    linear_idx = reduce(.+, ntuple(Val(N)) do d
        indices_i32[d] .* broadcast_to(Tile(array.strides[d]), S)
    end)

    ptr_tile = Intrinsics.offset(array.ptr, linear_idx)

    # Bounds mask: 0 <= idx[d] < size(array, d) for all d; out-of-range lanes
    # are masked off at the atomic op rather than clamped.
    zero_bc = broadcast_to(Tile(Int32(0)), S)
    mask = reduce(.&, ntuple(Val(N)) do d
        (indices_i32[d] .>= zero_bc) .& (indices_i32[d] .< broadcast_to(Tile(size(array, d)), S))
    end)

    (ptr_tile, mask, S)
end

# 1D convenience: wrap a single index `Tile` into a 1-tuple and reuse the
# N-dimensional helper.
@inline _atomic_ptrs_mask(array::TileArray{T, 1}, indices::Tile{<:Integer}) where {T} =
    _atomic_ptrs_mask(array, (indices,))

# --- RMW operations (atomic_add, atomic_xchg) ---

# (public-name suffix, compiler intrinsic) pairs used below to stamp out the
# tile-indexed RMW front-end methods (`atomic_add`, `atomic_xchg`).
const _ATOMIC_RMW_OPS = (
    (:add, :atomic_add_tile),
    (:xchg, :atomic_xchg_tile),
)

# Stamp out the public RMW entry points for each (op, intrinsic) pair:
# N-D tuple-of-tiles indexing with either a scalar or a tile value, plus
# 1D conveniences taking a single index tile.
for (op, intrinsic) in _ATOMIC_RMW_OPS
    fname = Symbol(:atomic_, op)

    # N-D with scalar value: the scalar is broadcast to the common index shape.
    @eval @inline function $fname(array::TileArray{T, N},
                                  indices::NTuple{N, Tile{<:Integer}}, val::T;
                                  memory_order::Int=MemoryOrder.AcqRel,
                                  memory_scope::Int=MemScope.Device) where {T, N}
        ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
        val_tile = broadcast_to(Tile(val), S)
        Intrinsics.$intrinsic(ptr_tile, val_tile, mask, memory_order, memory_scope)
    end

    # N-D with tile value: the value tile is broadcast to the common index shape.
    @eval @inline function $fname(array::TileArray{T, N},
                                  indices::NTuple{N, Tile{<:Integer}}, val::Tile{T};
                                  memory_order::Int=MemoryOrder.AcqRel,
                                  memory_scope::Int=MemScope.Device) where {T, N}
        ptr_tile, mask, S = _atomic_ptrs_mask(array, indices)
        val_bc = broadcast_to(val, S)
        Intrinsics.$intrinsic(ptr_tile, val_bc, mask, memory_order, memory_scope)
    end

    # 1D convenience: single Tile index, scalar value.
    @eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{<:Integer}, val::T;
                                  memory_order::Int=MemoryOrder.AcqRel,
                                  memory_scope::Int=MemScope.Device) where {T}
        $fname(array, (indices,), val; memory_order, memory_scope)
    end

    # 1D convenience: single Tile index, tile value.
    @eval @inline function $fname(array::TileArray{T, 1}, indices::Tile{<:Integer}, val::Tile{T};
                                  memory_order::Int=MemoryOrder.AcqRel,
                                  memory_scope::Int=MemScope.Device) where {T}
        $fname(array, (indices,), val; memory_order, memory_scope)
    end

end

# --- CAS operations (separate due to different signature) ---

# N-dimensional CAS with scalar expected/desired values: both scalars are
# lifted to tiles broadcast over the common index shape.
@inline function atomic_cas(array::TileArray{T, N},
                            indices::NTuple{N, Tile{<:Integer}},
                            expected::T, desired::T;
                            memory_order::Int=MemoryOrder.AcqRel,
                            memory_scope::Int=MemScope.Device) where {T, N}
    ptrs, mask, shape = _atomic_ptrs_mask(array, indices)
    Intrinsics.atomic_cas_tile(ptrs,
                               broadcast_to(Tile(expected), shape),
                               broadcast_to(Tile(desired), shape),
                               mask, memory_order, memory_scope)
end

# N-dimensional CAS with tile-valued expected/desired: both tiles are
# broadcast to the common index shape before the intrinsic call.
@inline function atomic_cas(array::TileArray{T, N},
                            indices::NTuple{N, Tile{<:Integer}},
                            expected::Tile{T}, desired::Tile{T};
                            memory_order::Int=MemoryOrder.AcqRel,
                            memory_scope::Int=MemScope.Device) where {T, N}
    ptrs, mask, shape = _atomic_ptrs_mask(array, indices)
    Intrinsics.atomic_cas_tile(ptrs,
                               broadcast_to(expected, shape),
                               broadcast_to(desired, shape),
                               mask, memory_order, memory_scope)
end

# 1D convenience: accept a single index `Tile` instead of a 1-tuple
# (scalar expected/desired).
@inline atomic_cas(array::TileArray{T, 1}, indices::Tile{<:Integer},
                   expected::T, desired::T;
                   memory_order::Int=MemoryOrder.AcqRel,
                   memory_scope::Int=MemScope.Device) where {T} =
    atomic_cas(array, (indices,), expected, desired; memory_order, memory_scope)

# 1D convenience: accept a single index `Tile` instead of a 1-tuple
# (tile expected/desired).
@inline atomic_cas(array::TileArray{T, 1}, indices::Tile{<:Integer},
                   expected::Tile{T}, desired::Tile{T};
                   memory_order::Int=MemoryOrder.AcqRel,
                   memory_scope::Int=MemScope.Device) where {T} =
    atomic_cas(array, (indices,), expected, desired; memory_order, memory_scope)

# ============================================================================
# Tile-space atomic operations
# These accept tile-space integer indices (like store) to atomically operate
# on contiguous tile-shaped blocks of an array.
# ============================================================================

# --- Pointer/mask helper for tile-space indexing ---

# Compute (ptr_tile, mask) for a tile-space atomic: `index` selects a
# tile-shaped block of `array` (like `store`), and `Shape` is the static
# tile shape lifted into the type domain via `Val`.
@inline function _tile_space_ptrs_mask(array::TileArray{T, N},
                                       index::NTuple{N, Integer},
                                       ::Val{Shape}) where {T, N, Shape}
    # Build per-dimension element index tiles (1-indexed).
    # For dim d: arange [1..Shape[d]], reshaped so it broadcasts only along
    # dim d, plus the base offset of the selected tile-space block.
    # NOTE(review): the `idx - 1` below relies on `arange` producing 1-based
    # values — confirm against its definition.
    idx_tiles = ntuple(Val(N)) do d
        bcast_shape = ntuple(i -> i == d ? Shape[d] : 1, Val(N))
        base = Int32((index[d] - 1) * Shape[d])
        reshape(arange((Shape[d],), Int32), bcast_shape) .+ Tile(base)
    end

    # 0-indexed linear offset: sum((idx[d] - 1) * stride[d]).
    linear_idx = reduce(.+, ntuple(Val(N)) do d
        (idx_tiles[d] .- Tile(Int32(1))) .* Tile(array.strides[d])
    end)

    ptr_tile = Intrinsics.offset(array.ptr, linear_idx)

    # Bounds mask: 1 <= idx[d] <= size(array, d) for all d; lanes outside the
    # array are masked off at the atomic op.
    mask = reduce(.&, ntuple(Val(N)) do d
        (idx_tiles[d] .>= Tile(Int32(1))) .& (idx_tiles[d] .<= Tile(size(array, d)))
    end)

    (ptr_tile, mask)
end

# --- Tile-space atomic_add ---

# Tile-space N-dimensional atomic add: atomically accumulate `tile` into the
# tile-shaped block of `array` selected by the tile-space `index` (like store).
@inline function atomic_add(array::TileArray{T, N},
                            index::NTuple{N, Integer}, tile::Tile{T};
                            memory_order::Int=MemoryOrder.AcqRel,
                            memory_scope::Int=MemScope.Device) where {T, N}
    shaped = _reshape_to_rank(tile, Val(N))
    ptrs, mask = _tile_space_ptrs_mask(array, index, Val(size(shaped)))
    Intrinsics.atomic_add_tile(ptrs, shaped, mask, memory_order, memory_scope)
end

# 1D convenience: a plain integer index is wrapped into a 1-tuple.
@inline atomic_add(array::TileArray{T, 1}, index::Integer, tile::Tile{T};
                   memory_order::Int=MemoryOrder.AcqRel,
                   memory_scope::Int=MemScope.Device) where {T} =
    atomic_add(array, (index,), tile; memory_order, memory_scope)
Loading
Loading