Skip to content

Add autotuning (experimental)#95

Draft
AntonOresten wants to merge 6 commits into JuliaGPU:main from
AntonOresten:autotune
Draft

Add autotuning (experimental)#95
AntonOresten wants to merge 6 commits into JuliaGPU:main from
AntonOresten:autotune

Conversation

@AntonOresten
Copy link
Contributor

No description provided.

@AntonOresten
Copy link
Contributor Author

AntonOresten commented Feb 24, 2026

Noticed a very significant slowdown when the order argument was specified in ct.load (#72). I'm testing a scoped inference cache and it seems to work:

using Base.ScopedValues: ScopedValue, with
const _SCOPED_INF_CACHE = ScopedValue{Vector{CC.InferenceResult}}()

function cuTileInterpreter(cache::CacheView; always_inline::Bool=true)
    method_table = get_method_table_view(cache.world)
    inf_cache = isassigned(_SCOPED_INF_CACHE) ? _SCOPED_INF_CACHE[] : Vector{CC.InferenceResult}()
	...
Click to see test script with output
# MWE: Shared inference cache for cuTileInterpreter
#
# Simulates the autotuning pattern: one MI, one CacheView, many
# const-seeded inference passes with different Constant values.

using cuTile
import cuTile: Constant, _SCOPED_INF_CACHE, cuTileInterpreter, cuTileMethodTable,
               CuTileResults, emit_ir
import Core.Compiler as CC
using CompilerCaching: CacheView, method_instance, typeinf!
using Base.ScopedValues: with

const ct = cuTile

# ── Two kernels: with and without `order` kwarg ──────────────────────
# Copy one tile of `A` into `B`, explicitly passing the `order` kwarg —
# the variant that exercises the slow inference path reported in #72.
function kernel_order(A::ct.TileArray, B::ct.TileArray, TILE_M::Int, TILE_N::Int)
    block = ct.bid(1)
    t = ct.load(A, (block, 1i32), (TILE_M, TILE_N);
                order=(1, 2), padding_mode=ct.PaddingMode.Zero)
    ct.store(B, (block, 1i32), t)
    return nothing
end

# Same tile copy as `kernel_order`, but without the `order` kwarg —
# the baseline (fast-inference) variant.
function kernel_no_order(A::ct.TileArray, B::ct.TileArray, TILE_M::Int, TILE_N::Int)
    block = ct.bid(1)
    t = ct.load(A, (block, 1i32), (TILE_M, TILE_N);
                padding_mode=ct.PaddingMode.Zero)
    ct.store(B, (block, 1i32), t)
    return nothing
end

# ── Benchmark: mirrors the real autotuning pattern ───────────────────
# Run the autotuning-shaped inference workload for kernel `f`:
# one generic inference pass, then one const-seeded pass per (TILE_M, TILE_N)
# config, all against a single CacheView. When `use_scoped_cache` is true,
# all passes share one `Vector{CC.InferenceResult}` via `_SCOPED_INF_CACHE`.
function bench_inference(f, configs; use_scoped_cache::Bool)
    world = Base.get_world_counter()
    base_argtypes = Tuple{ct.TileArray{Float32, 2}, ct.TileArray{Float32, 2}, Int, Int}
    # Resolve the MethodInstance against the cuTile method table first,
    # falling back to the default table when it is not found there.
    mi = something(
        method_instance(f, base_argtypes; world, method_table=cuTileMethodTable),
        method_instance(f, base_argtypes; world),
    )

    # One CacheView for all configs (same as real autotuning)
    # NOTE(review): `time_ns()` in the key appears intended to make each
    # benchmark call start from a cold cache — confirm against CacheView's
    # keying semantics.
    opts = (sm_arch="sm_120", opt_level=3, num_ctas=nothing, occupancy=nothing)
    cache = CacheView{CuTileResults}((:mwe, opts, objectid(f), use_scoped_cache, time_ns()), world)

    function do_infer()
        # Generic inference once (establishes the CI)
        interp = cuTileInterpreter(cache)
        typeinf!(cache, interp, mi)

        # Const-seeded inference per config (the hot path)
        for (tile_m, tile_n) in configs
            # Tile sizes lifted to `Const` so inference specializes per
            # config; the two TileArray arguments stay abstract.
            const_argtypes = Any[
                CC.Const(f),
                ct.TileArray{Float32, 2},
                ct.TileArray{Float32, 2},
                CC.Const(tile_m),
                CC.Const(tile_n),
            ]
            # Fresh interpreter per config; with the scoped cache set, all
            # of them still share one inference-result vector.
            interp2 = cuTileInterpreter(cache)
            typeinf!(cache, interp2, mi, const_argtypes)
        end
    end

    if use_scoped_cache
        with(_SCOPED_INF_CACHE => CC.InferenceResult[]) do
            do_infer()
        end
    else
        do_infer()
    end
end

# ── Config grid (match real autotuning scale) ────────────────────────
# Full cross product of tile sizes; the 2-D comprehension is laid out
# column-major, so `vec` yields the same order as the nested loop
# (inner index `n` varies fastest).
tile_ms = [16, 32, 64, 128, 256]
tile_ns = [16, 32, 64, 128, 256]
configs = vec([(m, n) for n in tile_ns, m in tile_ms])
println("Configs: $(length(configs))")

# ── Warmup ───────────────────────────────────────────────────────────
# Compile both kernels once (single config) so timed runs below do not
# include first-call compilation of the benchmark machinery itself.
print("Warmup... ")
for kernel in (kernel_order, kernel_no_order)
    bench_inference(kernel, configs[1:1]; use_scoped_cache=false)
end
println("done")

# Format an elapsed time as total seconds plus per-config milliseconds.
pretty(t) = "$(round(t, digits=3))s  ($(round(t / length(configs) * 1000, digits=1))ms/cfg)"

# Two passes: Run 1 includes residual compilation, Run 2 is steady state.
for run_name in ("Run 1", "Run 2")
    println("\n──── $run_name ────")

    println("\n  ═══ WITH order=(1,2) ═══")
    fresh_order = @elapsed bench_inference(kernel_order, configs; use_scoped_cache=false)
    println("    Fresh:  $(pretty(fresh_order))")
    shared_order = @elapsed bench_inference(kernel_order, configs; use_scoped_cache=true)
    println("    Shared: $(pretty(shared_order))")
    println("    Speedup: $(round(fresh_order / shared_order, digits=1))x")

    println("\n  ═══ WITHOUT order ═══")
    fresh_plain = @elapsed bench_inference(kernel_no_order, configs; use_scoped_cache=false)
    println("    Fresh:  $(pretty(fresh_plain))")
    shared_plain = @elapsed bench_inference(kernel_no_order, configs; use_scoped_cache=true)
    println("    Shared: $(pretty(shared_plain))")
    println("    Speedup: $(round(fresh_plain / shared_plain, digits=1))x")

    println("\n  ═══ Summary ═══")
    println("    order overhead (fresh):  $(round(fresh_order / fresh_plain, digits=2))x")
    println("    order overhead (shared): $(round(shared_order / shared_plain, digits=2))x")
end
Configs: 25
Warmup... done

──── Run 1 ────

  ═══ WITH order=(1,2) ═══
    Fresh:  0.656s  (26.2ms/cfg)
    Shared: 0.31s  (12.4ms/cfg)
    Speedup: 2.1x

  ═══ WITHOUT order ═══
    Fresh:  0.322s  (12.9ms/cfg)
    Shared: 0.293s  (11.7ms/cfg)
    Speedup: 1.1x

  ═══ Summary ═══
    order overhead (fresh):  2.04x
    order overhead (shared): 1.06x

──── Run 2 ────

  ═══ WITH order=(1,2) ═══
    Fresh:  0.29s  (11.6ms/cfg)
    Shared: 0.292s  (11.7ms/cfg)
    Speedup: 1.0x

  ═══ WITHOUT order ═══
    Fresh:  0.287s  (11.5ms/cfg)
    Shared: 0.291s  (11.6ms/cfg)
    Speedup: 1.0x

  ═══ Summary ═══
    order overhead (fresh):  1.01x
    order overhead (shared): 1.0x

This matters a lot because it was making the codegen stage for a matmul 10x slower, even though the resulting Tile IR was identical.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant