Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/layernorm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ function layer_norm_fwd(X::ct.TileArray{Float32, 2}, W::ct.TileArray{Float32, 1}
while j <= num_tiles
tx = ct.load(X, (bid_m, j), (1, TILE_N); padding_mode=ct.PaddingMode.Zero)
# Mask for valid elements
mask = ct.broadcast_to(((j - Int32(1)) * Int32(TILE_N) .+ ct.arange((TILE_N,), Int32)) .<= N, (1, TILE_N))
mask = reshape(((j - Int32(1)) * Int32(TILE_N) .+ ct.arange((TILE_N,), Int32)) .<= N, (1, TILE_N))
centered_tx = ifelse.(mask, tx .- mean, 0.0f0)
var = var .+ (centered_tx .^ 2.0f0)
j += Int32(1)
Expand Down Expand Up @@ -96,7 +96,7 @@ bid_m and j are 1-indexed (block ID and tile index).
indices = ct.arange((TILE_N,), Int32)
offset = (j - Int32(1)) * Int32(TILE_N)
global_indices = offset .+ indices
mask = ct.broadcast_to(global_indices .<= N, (1, TILE_N))
mask = reshape(global_indices .<= N, (1, TILE_N))

xhat_masked = ifelse.(mask, xhat, 0.0f0)
wdy_masked = ifelse.(mask, wdy, 0.0f0)
Expand Down
8 changes: 4 additions & 4 deletions src/compiler/intrinsics/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ end
"""
broadcast_tile_to_shape!(cb, tt, tv::CGVal, target_shape::Vector{Int}, dtype::TypeId) -> Value

Broadcast a tile to a target shape by inserting ReshapeOp (for leading 1s) and BroadcastOp.
Broadcast a tile to a target shape by inserting ReshapeOp (for trailing 1s) and BroadcastOp.
Returns the value after broadcasting, or the original value if shapes already match.
"""
function broadcast_tile_to_shape!(cb::CodeBuilder, tt::TypeTable, tv::CGVal,
Expand All @@ -78,11 +78,11 @@ function broadcast_tile_to_shape!(cb::CodeBuilder, tt::TypeTable, tv::CGVal,
current_val = tv.v
current_shape = src_shape

# Step 1: Add leading 1s via ReshapeOp if needed (dimension mismatch)
# Step 1: Add trailing 1s via ReshapeOp if needed (dimension mismatch)
# Follows Julia convention: (n,) pads to (n, 1) — first dimension aligns.
if length(current_shape) < length(target_shape)
# Prepend 1s to match target ndim
n_extra = length(target_shape) - length(current_shape)
new_shape = vcat(fill(1, n_extra), current_shape)
new_shape = vcat(current_shape, fill(1, n_extra))
reshaped_type = tile_type!(tt, dtype, new_shape)
current_val = encode_ReshapeOp!(cb, reshaped_type, current_val)
current_shape = new_shape
Expand Down
21 changes: 21 additions & 0 deletions test/execution/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,27 @@ end
end
end

@testset "1D-to-2D broadcast: (64,) .+ (64, 128)" begin
function broadcast_1d_2d_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,2},
c::ct.TileArray{Float32,2})
col = ct.load(a, 1, (64,)) # 1D: (64,)
tile = ct.load(b, (1, 1), (64, 128)) # 2D: (64, 128)
result = col .+ tile # broadcast (64,) → (64, 1) → (64, 128)
ct.store(c, (1, 1), result)
return
end

m, n = 64, 128
a = CUDA.rand(Float32, m)
b = CUDA.rand(Float32, m, n)
c = CUDA.zeros(Float32, m, n)

ct.launch(broadcast_1d_2d_kernel, 1, a, b, c)

expected = Array(a) .+ Array(b) # Julia: (64,) broadcasts along dim 1
@test Array(c) ≈ expected
end

end

@testset "comparison operations" begin
Expand Down