From a2c7917dd3b4eb90f5bf38e615a226d081cd541e Mon Sep 17 00:00:00 2001 From: AntonOresten Date: Sat, 21 Feb 2026 21:40:20 +0100 Subject: [PATCH] Insert trailing ones instead of leading ones when broadcasting --- examples/layernorm.jl | 4 ++-- src/compiler/intrinsics/core.jl | 8 ++++---- test/execution/broadcast.jl | 21 +++++++++++++++++++++ 3 files changed, 27 insertions(+), 6 deletions(-) diff --git a/examples/layernorm.jl b/examples/layernorm.jl index 1daf05e..3f7af6f 100644 --- a/examples/layernorm.jl +++ b/examples/layernorm.jl @@ -45,7 +45,7 @@ function layer_norm_fwd(X::ct.TileArray{Float32, 2}, W::ct.TileArray{Float32, 1} while j <= num_tiles tx = ct.load(X, (bid_m, j), (1, TILE_N); padding_mode=ct.PaddingMode.Zero) # Mask for valid elements - mask = ct.broadcast_to(((j - Int32(1)) * Int32(TILE_N) .+ ct.arange((TILE_N,), Int32)) .<= N, (1, TILE_N)) + mask = reshape(((j - Int32(1)) * Int32(TILE_N) .+ ct.arange((TILE_N,), Int32)) .<= N, (1, TILE_N)) centered_tx = ifelse.(mask, tx .- mean, 0.0f0) var = var .+ (centered_tx .^ 2.0f0) j += Int32(1) @@ -96,7 +96,7 @@ bid_m and j are 1-indexed (block ID and tile index). indices = ct.arange((TILE_N,), Int32) offset = (j - Int32(1)) * Int32(TILE_N) global_indices = offset .+ indices - mask = ct.broadcast_to(global_indices .<= N, (1, TILE_N)) + mask = reshape(global_indices .<= N, (1, TILE_N)) xhat_masked = ifelse.(mask, xhat, 0.0f0) wdy_masked = ifelse.(mask, wdy, 0.0f0) diff --git a/src/compiler/intrinsics/core.jl b/src/compiler/intrinsics/core.jl index c18b7e1..dde2348 100644 --- a/src/compiler/intrinsics/core.jl +++ b/src/compiler/intrinsics/core.jl @@ -63,7 +63,7 @@ end """ broadcast_tile_to_shape!(cb, tt, tv::CGVal, target_shape::Vector{Int}, dtype::TypeId) -> Value -Broadcast a tile to a target shape by inserting ReshapeOp (for leading 1s) and BroadcastOp. +Broadcast a tile to a target shape by inserting ReshapeOp (for trailing 1s) and BroadcastOp. Returns the value after broadcasting, or the original value if shapes already match. """ function broadcast_tile_to_shape!(cb::CodeBuilder, tt::TypeTable, tv::CGVal, @@ -78,11 +78,11 @@ function broadcast_tile_to_shape!(cb::CodeBuilder, tt::TypeTable, tv::CGVal, current_val = tv.v current_shape = src_shape - # Step 1: Add leading 1s via ReshapeOp if needed (dimension mismatch) + # Step 1: Add trailing 1s via ReshapeOp if needed (dimension mismatch) + # Follows Julia convention: (n,) pads to (n, 1) — first dimension aligns. if length(current_shape) < length(target_shape) - # Prepend 1s to match target ndim n_extra = length(target_shape) - length(current_shape) - new_shape = vcat(fill(1, n_extra), current_shape) + new_shape = vcat(current_shape, fill(1, n_extra)) reshaped_type = tile_type!(tt, dtype, new_shape) current_val = encode_ReshapeOp!(cb, reshaped_type, current_val) current_shape = new_shape diff --git a/test/execution/broadcast.jl b/test/execution/broadcast.jl index 94dbbf8..00055ae 100644 --- a/test/execution/broadcast.jl +++ b/test/execution/broadcast.jl @@ -191,6 +191,27 @@ end end end +@testset "1D-to-2D broadcast: (64,) .+ (64, 128)" begin + function broadcast_1d_2d_kernel(a::ct.TileArray{Float32,1}, b::ct.TileArray{Float32,2}, + c::ct.TileArray{Float32,2}) + col = ct.load(a, 1, (64,)) # 1D: (64,) + tile = ct.load(b, (1, 1), (64, 128)) # 2D: (64, 128) + result = col .+ tile # broadcast (64,) → (64, 1) → (64, 128) + ct.store(c, (1, 1), result) + return + end + + m, n = 64, 128 + a = CUDA.rand(Float32, m) + b = CUDA.rand(Float32, m, n) + c = CUDA.zeros(Float32, m, n) + + ct.launch(broadcast_1d_2d_kernel, 1, a, b, c) + + expected = Array(a) .+ Array(b) # Julia: (64,) broadcasts along dim 1 + @test Array(c) ≈ expected +end + end @testset "comparison operations" begin