diff --git a/src/tirx/transform/lower_device_kernel_launch.cc b/src/tirx/transform/lower_device_kernel_launch.cc
index 3ff4cf17c585..fea8d458b963 100644
--- a/src/tirx/transform/lower_device_kernel_launch.cc
+++ b/src/tirx/transform/lower_device_kernel_launch.cc
@@ -104,6 +104,17 @@ class DeviceInfoCollector : public StmtVisitor {
     return extent.value();
   }
 
+  void VisitStmt_(const BindNode* op) final {
+    // Track Bind definitions so that thread_extent values and
+    // dyn_shmem_size expressions that reference locally-bound
+    // variables (e.g. CSE variables) can be inlined back to
+    // expressions over function parameters. Substitute earlier
+    // bindings into the value to handle chains (cse_v2 = f(cse_v1)).
+    PrimExpr value = bind_map_.size() ? Substitute(op->value, bind_map_) : op->value;
+    bind_map_.Set(op->var, value);
+    StmtVisitor::VisitStmt_(op);
+  }
+
   void VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::thread_extent) {
       IterVar iv = Downcast<IterVar>(op->node);
@@ -113,7 +124,10 @@ class DeviceInfoCollector : public StmtVisitor {
       if (!defined_thread.count(iv.get())) {
         defined_thread.insert(iv.get());
         info_.launch_params.push_back(iv->thread_tag);
-        thread_extent.Set(iv->thread_tag, op->value);
+        // Inline any locally-bound variables (e.g. from CSE) so
+        // that the extent is expressible in terms of function params.
+        PrimExpr value = bind_map_.size() ? Substitute(op->value, bind_map_) : op->value;
+        thread_extent.Set(iv->thread_tag, value);
       }
     }
 
@@ -133,6 +147,10 @@ class DeviceInfoCollector : public StmtVisitor {
       }
       dyn_size *= op->buffer->dtype.bytes();
 
+      // Inline any locally-bound variables (e.g. from CSE).
+      if (bind_map_.size()) {
+        dyn_size = Substitute(dyn_size, bind_map_);
+      }
       dyn_shmem_size = dyn_size;
     }
     StmtVisitor::VisitStmt_(op);
@@ -146,6 +164,8 @@ class DeviceInfoCollector : public StmtVisitor {
   ffi::Map<ffi::String, PrimExpr> thread_extent;
   // The amount of dynamic shared memory used
   ffi::Optional<PrimExpr> dyn_shmem_size{std::nullopt};
+  // Accumulated Bind definitions for inlining into extent/size expressions.
+  ffi::Map<Var, PrimExpr> bind_map_;
 };
 
 class ReturnRemover : public StmtExprMutator {
diff --git a/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py b/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py
index 6d77d7e87164..3dab487ab59f 100644
--- a/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py
+++ b/tests/python/tirx-transform/test_tir_transform_device_kernel_launch.py
@@ -223,5 +223,57 @@ def kernel(A_data: T.handle("float32")):
     tvm.ir.assert_structural_equal(After, Expected)
 
 
+def test_bind_before_thread_extent():
+    """DeviceInfoCollector inlines Bind-defined variables in thread extents.
+
+    When CSE (or another pass) inserts Bind statements before
+    thread_extent AttrStmts, the extent value may reference a
+    locally-bound variable instead of function parameters.
+    LowerDeviceKernelLaunch must inline these bindings so that the
+    launch argument is expressible in terms of the caller's arguments.
+    """
+
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer(16, "float32"), n: T.int32):
+            T.func_attr({"target": T.target("llvm")})
+            Before.kernel(A.data, n)
+
+        @T.prim_func
+        def kernel(A_data: T.handle("float32"), n: T.int32):
+            T.func_attr({"target": T.target("cuda"), "global_symbol": "kernel"})
+            A = T.decl_buffer(16, dtype="float32", data=A_data)
+            v: T.int32 = n + 1
+            i = T.launch_thread("threadIdx.x", v)
+            A[i] = 0.0
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def main(A: T.Buffer(16, "float32"), n: T.int32):
+            T.func_attr({"target": T.target("llvm")})
+            T.call_packed("kernel", A.data, n, n + 1)
+
+        @T.prim_func
+        def kernel(A_data: T.handle("float32"), n: T.int32):
+            T.func_attr(
+                {
+                    "target": T.target("cuda"),
+                    "calling_conv": 2,
+                    "tirx.kernel_launch_params": ["threadIdx.x"],
+                    "global_symbol": "kernel",
+                    "tirx.is_global_func": True,
+                }
+            )
+            A = T.decl_buffer(16, dtype="float32", data=A_data)
+            v: T.int32 = n + 1
+            i = T.launch_thread("threadIdx.x", v)
+            A[i] = 0.0
+
+    After = tvm.tirx.transform.LowerDeviceKernelLaunch()(Before)
+    tvm.ir.assert_structural_equal(After, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()