Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/tirx/transform/lower_device_kernel_launch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,17 @@ class DeviceInfoCollector : public StmtVisitor {
return extent.value();
}

void VisitStmt_(const BindNode* op) final {
// Track Bind definitions so that thread_extent values and
// dyn_shmem_size expressions that reference locally-bound
// variables (e.g. CSE variables) can be inlined back to
// expressions over function parameters. Substitute earlier
// bindings into the value to handle chains (cse_v2 = f(cse_v1)).
PrimExpr value = bind_map_.size() ? Substitute(op->value, bind_map_) : op->value;
bind_map_.Set(op->var, value);
StmtVisitor::VisitStmt_(op);
}

void VisitStmt_(const AttrStmtNode* op) final {
if (op->attr_key == attr::thread_extent) {
IterVar iv = Downcast<IterVar>(op->node);
Expand All @@ -113,7 +124,10 @@ class DeviceInfoCollector : public StmtVisitor {
if (!defined_thread.count(iv.get())) {
defined_thread.insert(iv.get());
info_.launch_params.push_back(iv->thread_tag);
thread_extent.Set(iv->thread_tag, op->value);
// Inline any locally-bound variables (e.g. from CSE) so
// that the extent is expressible in terms of function params.
PrimExpr value = bind_map_.size() ? Substitute(op->value, bind_map_) : op->value;
thread_extent.Set(iv->thread_tag, value);
}
}

Expand All @@ -133,6 +147,10 @@ class DeviceInfoCollector : public StmtVisitor {
}
dyn_size *= op->buffer->dtype.bytes();

// Inline any locally-bound variables (e.g. from CSE).
if (bind_map_.size()) {
dyn_size = Substitute(dyn_size, bind_map_);
}
dyn_shmem_size = dyn_size;
}
StmtVisitor::VisitStmt_(op);
Expand All @@ -146,6 +164,8 @@ class DeviceInfoCollector : public StmtVisitor {
ffi::Map<ffi::String, PrimExpr> thread_extent;
// The amount of dynamic shared memory used
ffi::Optional<PrimExpr> dyn_shmem_size{std::nullopt};
// Accumulated Bind definitions for inlining into extent/size expressions.
ffi::Map<Var, PrimExpr> bind_map_;
};

class ReturnRemover : public StmtExprMutator {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,5 +223,57 @@ def kernel(A_data: T.handle("float32")):
tvm.ir.assert_structural_equal(After, Expected)


def test_bind_before_thread_extent():
"""DeviceInfoCollector inlines Bind-defined variables in thread extents.

When CSE (or another pass) inserts Bind statements before
thread_extent AttrStmts, the extent value may reference a
locally-bound variable instead of function parameters.
LowerDeviceKernelLaunch must inline these bindings so that the
launch argument is expressible in terms of the caller's arguments.
"""

@I.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer(16, "float32"), n: T.int32):
T.func_attr({"target": T.target("llvm")})
Before.kernel(A.data, n)

@T.prim_func
def kernel(A_data: T.handle("float32"), n: T.int32):
T.func_attr({"target": T.target("cuda"), "global_symbol": "kernel"})
A = T.decl_buffer(16, dtype="float32", data=A_data)
v: T.int32 = n + 1
i = T.launch_thread("threadIdx.x", v)
A[i] = 0.0

@I.ir_module
class Expected:
@T.prim_func
def main(A: T.Buffer(16, "float32"), n: T.int32):
T.func_attr({"target": T.target("llvm")})
T.call_packed("kernel", A.data, n, n + 1)

@T.prim_func
def kernel(A_data: T.handle("float32"), n: T.int32):
T.func_attr(
{
"target": T.target("cuda"),
"calling_conv": 2,
"tirx.kernel_launch_params": ["threadIdx.x"],
"global_symbol": "kernel",
"tirx.is_global_func": True,
}
)
A = T.decl_buffer(16, dtype="float32", data=A_data)
v: T.int32 = n + 1
i = T.launch_thread("threadIdx.x", v)
A[i] = 0.0

After = tvm.tirx.transform.LowerDeviceKernelLaunch()(Before)
tvm.ir.assert_structural_equal(After, Expected)


if __name__ == "__main__":
tvm.testing.main()
Loading