From 5caddee7f64728c4ea9523ecded1f3110e6bf6fa Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 26 Mar 2026 10:40:07 +0000 Subject: [PATCH 1/6] [FFI] Bump tvm-ffi to 63224e3 and fix regressions Update tvm-ffi submodule from c85fd42 (#471) to 63224e3 (#512). Fix two regressions from the bump: - Add no-op __init__ to relax.TEPlaceholderOp to prevent duplicate parameter name error from auto-init wiring (#491) - Use query_imports=True in Executable.__call__ to find main in imported sub-modules --- 3rdparty/tvm-ffi | 2 +- python/tvm/relax/expr.py | 3 +++ python/tvm/runtime/executable.py | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi index c85fd42df6ea..63224e3f1e46 160000 --- a/3rdparty/tvm-ffi +++ b/3rdparty/tvm-ffi @@ -1 +1 @@ -Subproject commit c85fd42df6eae4ae0ec1aaa4ebb67ac859758cf5 +Subproject commit 63224e3f1e464cc62307223787926a48fc8df8c0 diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 7c7bcc2aeaa7..8dd52de1b9fa 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -1203,6 +1203,9 @@ def const( class TEPlaceholderOp(tvm.te.tensor.Operation): """The placeholder op that represents a relax expression.""" + def __init__(self): + pass + def te_tensor( value: Expr, tir_var_map: dict[tvm.tirx.Var, tvm.tirx.PrimExpr], name: str = "rxplaceholder" diff --git a/python/tvm/runtime/executable.py b/python/tvm/runtime/executable.py index 212896ccb2b7..2323ed02ca35 100644 --- a/python/tvm/runtime/executable.py +++ b/python/tvm/runtime/executable.py @@ -41,7 +41,7 @@ def __getitem__(self, name: str) -> PackedFunc: def __call__(self, *args, **kwargs) -> Any: """Call the executable.""" - return self.jit().main(*args, **kwargs) + return self.jit().get_function("main", query_imports=True)(*args, **kwargs) def jit( self, From ed32081afec38556eb97de56c35218bfcc223d83 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 26 Mar 2026 11:35:10 +0000 Subject: [PATCH 2/6] [FFI] Fix repr override, host/device split, and Module.__getattr__ - Add Node.__repr__ to use TVMScript printer instead of dataclass repr - Fix host/device split: use target.kind.name instead of str(target.kind) - Override Module.__getattr__ to use query_imports=True --- python/tvm/ir/base.py | 8 ++++++++ python/tvm/runtime/executable.py | 2 +- python/tvm/runtime/module.py | 9 +++++++++ python/tvm/tirx/build.py | 2 +- 4 files changed, 19 insertions(+), 2 deletions(-) diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index e466ef850043..15629dcbe61f 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -27,6 +27,14 @@ class Node(Object): """Base class of all IR Nodes.""" + def __repr__(self) -> str: + from tvm.runtime.script_printer import _script # noqa: PLC0415 + + try: + return _script(self, None) + except Exception: # noqa: BLE001 + return super().__repr__() + @register_object("ir.SourceMap") class SourceMap(Object): diff --git a/python/tvm/runtime/executable.py b/python/tvm/runtime/executable.py index 2323ed02ca35..212896ccb2b7 100644 --- a/python/tvm/runtime/executable.py +++ b/python/tvm/runtime/executable.py @@ -41,7 +41,7 @@ def __getitem__(self, name: str) -> PackedFunc: def __call__(self, *args, **kwargs) -> Any: """Call the executable.""" - return self.jit().get_function("main", query_imports=True)(*args, **kwargs) + return self.jit().main(*args, **kwargs) def jit( self, diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index 7586e0df1576..e9d9098a53d0 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -108,6 +108,15 @@ def __str__(self): class Module(_Module): """Runtime Module.""" + def __getattr__(self, name: str): + """Get function from the module, searching imported modules.""" + try: + func = self.get_function(name, query_imports=True) + except AttributeError as exc: + raise AttributeError(f"Module has no function '{name}'") from exc + setattr(self, name, func) + return func + def _collect_from_import_tree(self, filter_func): """Helper function to collect modules from the tree matching a filter_func, then return it. diff --git a/python/tvm/tirx/build.py b/python/tvm/tirx/build.py index d310afee7938..020730d2f9de 100644 --- a/python/tvm/tirx/build.py +++ b/python/tvm/tirx/build.py @@ -100,7 +100,7 @@ def main(self_handle: T.handle, args: T.handle, num_args: T.int32, result: T.han def is_host_func(f): target = f.attrs.get("target", tvm.target.Target("llvm")) - return str(target.kind) in ["llvm", "c"] + return target.kind.name in ["llvm", "c"] host_mod = tvm.tirx.transform.Filter(is_host_func)(mod) device_mod = tvm.tirx.transform.Filter(lambda f: not is_host_func(f))(mod) From de81d8fac7f3e7ab287d75ddebf44749df8f9c0b Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 26 Mar 2026 12:22:12 +0000 Subject: [PATCH 3/6] [FFI] Fix slots enforcement and remove duplicate field declarations - Remove duplicate field declarations from C++ child types (RXPlaceholderOpNode, FunctionFrameNode) so auto-init works correctly - Add __slots__ = ("__dict__",) to Pass, BlockBuilder, and TVMDerivedObject to support instance attributes under slots enforcement - Revert unnecessary Module.__getattr__ override with query_imports=True --- include/tvm/script/ir_builder/relax/frame.h | 5 ++--- python/tvm/ir/transform.py | 2 ++ python/tvm/relax/block_builder.py | 2 ++ python/tvm/relax/expr.py | 3 --- python/tvm/runtime/module.py | 9 --------- python/tvm/runtime/support.py | 2 ++ python/tvm/s_tir/meta_schedule/utils.py | 2 ++ src/relax/ir/emit_te.h | 7 +------ 8 files changed, 11 insertions(+), 21 deletions(-) diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 898e318950cc..3d4ca3267461 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -125,9 +125,8 @@ class FunctionFrameNode : public SeqExprFrameNode { .def_ro("params", &FunctionFrameNode::params) .def_ro("ret_struct_info", &FunctionFrameNode::ret_struct_info) .def_ro("is_pure", &FunctionFrameNode::is_pure) - .def_ro("attrs", &FunctionFrameNode::attrs) - .def_ro("binding_blocks", &FunctionFrameNode::binding_blocks) - .def_ro("output", &FunctionFrameNode::output); + .def_ro("attrs", &FunctionFrameNode::attrs); + // `binding_blocks` and `output` are inherited from SeqExprFrameNode. // `block_builder` is not registered as it's not visited. } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.relax.FunctionFrame", FunctionFrameNode, diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 28c187d7deb8..3e22a2b9084e 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -147,6 +147,8 @@ class Pass(tvm.runtime.Object): conveniently interact with the base class. """ + __slots__ = ("__dict__",) + @property def info(self): """Get the pass meta.""" diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 1d6057cec500..13bea3180c29 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -153,6 +153,8 @@ class BlockBuilder(Object): mod = bb.get() """ + __slots__ = ("__dict__",) + _stack = [] @staticmethod diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py index 8dd52de1b9fa..7c7bcc2aeaa7 100644 --- a/python/tvm/relax/expr.py +++ b/python/tvm/relax/expr.py @@ -1203,9 +1203,6 @@ def const( class TEPlaceholderOp(tvm.te.tensor.Operation): """The placeholder op that represents a relax expression.""" - def __init__(self): - pass - def te_tensor( value: Expr, tir_var_map: dict[tvm.tirx.Var, tvm.tirx.PrimExpr], name: str = "rxplaceholder" diff --git a/python/tvm/runtime/module.py b/python/tvm/runtime/module.py index e9d9098a53d0..7586e0df1576 100644 --- a/python/tvm/runtime/module.py +++ b/python/tvm/runtime/module.py @@ -108,15 +108,6 @@ def __str__(self): class Module(_Module): """Runtime Module.""" - def __getattr__(self, name: str): - """Get function from the module, searching imported modules.""" - try: - func = self.get_function(name, query_imports=True) - except AttributeError as exc: - raise AttributeError(f"Module has no function '{name}'") from exc - setattr(self, name, func) - return func - def _collect_from_import_tree(self, filter_func): """Helper function to collect modules from the tree matching a filter_func, then return it. diff --git a/python/tvm/runtime/support.py b/python/tvm/runtime/support.py index d40f5c263618..b0ac67176328 100644 --- a/python/tvm/runtime/support.py +++ b/python/tvm/runtime/support.py @@ -149,6 +149,8 @@ def method(*args, **kwargs): class TVMDerivedObject(metadata["cls"]): # type: ignore """The derived object to avoid cyclic dependency.""" + __slots__ = ("__dict__", "__weakref__",) + _cls = cls _type = "TVMDerivedObject" diff --git a/python/tvm/s_tir/meta_schedule/utils.py b/python/tvm/s_tir/meta_schedule/utils.py index b4cd2f4009df..834ebb7ef39d 100644 --- a/python/tvm/s_tir/meta_schedule/utils.py +++ b/python/tvm/s_tir/meta_schedule/utils.py @@ -109,6 +109,8 @@ def method(*args, **kwargs): class TVMDerivedObject(metadata["cls"]): # type: ignore """The derived object to avoid cyclic dependency.""" + __slots__ = ("__dict__",) + _cls = cls _type = "TVMDerivedObject" diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h index d7c998729a10..31b4cc292762 100644 --- a/src/relax/ir/emit_te.h +++ b/src/relax/ir/emit_te.h @@ -44,12 +44,7 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("name", &RXPlaceholderOpNode::name) - .def_ro("tag", &RXPlaceholderOpNode::tag) - .def_ro("attrs", &RXPlaceholderOpNode::attrs) - .def_ro("value", &RXPlaceholderOpNode::value) - .def_ro("shape", &RXPlaceholderOpNode::shape) - .def_ro("dtype", &RXPlaceholderOpNode::dtype); + .def_ro("value", &RXPlaceholderOpNode::value); } // FFI system configuration for structural equality and hashing From 6896ffc817acc166bd4c21f165bcb50bef7b9977 Mon Sep 17 00:00:00 2001 From: tqchen Date: Thu, 26 Mar 2026 18:39:28 +0000 Subject: [PATCH 4/6] [FFI] Fix remaining slots issues: DataflowBlockRewrite and weakref --- python/tvm/relax/binding_rewrite.py | 2 ++ python/tvm/s_tir/meta_schedule/utils.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/relax/binding_rewrite.py b/python/tvm/relax/binding_rewrite.py index 01287893689e..0198befd33e1 100644 --- a/python/tvm/relax/binding_rewrite.py +++ b/python/tvm/relax/binding_rewrite.py @@ -38,6 +38,8 @@ class DataflowBlockRewrite(Object): use mutate_irmodule which rewrites the old function that registered in the constructor. """ + __slots__ = ("__dict__",) + def __init__(self, dfb: DataflowBlock, root_fn: Function): """ Construct a rewriter with the DataflowBlock to rewrite and its root function. diff --git a/python/tvm/s_tir/meta_schedule/utils.py b/python/tvm/s_tir/meta_schedule/utils.py index 834ebb7ef39d..2460a6cc265d 100644 --- a/python/tvm/s_tir/meta_schedule/utils.py +++ b/python/tvm/s_tir/meta_schedule/utils.py @@ -109,7 +109,7 @@ def method(*args, **kwargs): class TVMDerivedObject(metadata["cls"]): # type: ignore """The derived object to avoid cyclic dependency.""" - __slots__ = ("__dict__",) + __slots__ = ("__dict__", "__weakref__",) _cls = cls _type = "TVMDerivedObject" From 93e58d7ec403981b5593a0e2d229e06491929d33 Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 27 Mar 2026 02:36:27 +0000 Subject: [PATCH 5/6] [FIX] Register __ffi_repr__ for Target, AccessPath, AccessStep New tvm-ffi uses ffi.ReprPrint (dataclass repr) for CObject.__repr__ instead of TVM's ReprPrinter. This caused: - Target: str(target) returned dataclass repr instead of JSON - AccessPath: structural equality error messages showed verbose repr - PassInfo: pass name format changed in instrument output - ExprStmtDoc: __dict__ not available with __slots__ enforcement Fix by registering __ffi_repr__ TypeAttr for Target (returns JSON via TargetNode::str()), AccessPath and AccessStep (returns concise path format). Update tests for new PassInfo format and __slots__. --- src/node/repr_printer.cc | 65 +++++++++++++++++++ src/target/target.cc | 8 ++- tests/python/ir/test_pass_instrument.py | 4 +- .../tvmscript/test_tvmscript_printer_doc.py | 4 +- 4 files changed, 76 insertions(+), 5 deletions(-) diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index b60583c6ab85..dd9b31d3afa8 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -22,11 +22,15 @@ * \file node/repr_printer.cc */ #include +#include #include #include #include #include +#include +#include + #include "../support/str_escape.h" namespace tvm { @@ -127,6 +131,53 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << Downcast(node); }); +namespace { +/*! + * \brief Format an AccessStep as a concise string fragment. + */ +void FormatAccessStep(std::ostringstream& os, const ffi::reflection::AccessStep& step) { + using ffi::reflection::AccessKind; + switch (step->kind) { + case AccessKind::kAttr: + os << "." << step->key.cast(); + break; + case AccessKind::kArrayItem: + os << "[" << step->key.cast() << "]"; + break; + case AccessKind::kMapItem: + os << "{" << step->key.cast() << "}"; + break; + case AccessKind::kAttrMissing: + os << "." << step->key.cast() << "?"; + break; + case AccessKind::kArrayItemMissing: + os << "[" << step->key.cast() << "]?"; + break; + case AccessKind::kMapItemMissing: + os << "{" << step->key.cast() << "}?"; + break; + } +} + +/*! + * \brief Format an AccessPath as ".field[idx]". + */ +ffi::String FormatAccessPath(const ffi::reflection::AccessPath& path) { + std::vector steps; + const ffi::reflection::AccessPathObj* cur = path.get(); + while (cur->step.defined()) { + steps.push_back(cur->step.value()); + cur = static_cast(cur->parent.get()); + } + std::ostringstream os; + os << ""; + for (auto it = steps.rbegin(); it != steps.rend(); ++it) { + FormatAccessStep(os, *it); + } + return os.str(); +} +} // namespace + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("node.AsRepr", [](ffi::Any obj) { @@ -134,5 +185,19 @@ TVM_FFI_STATIC_INIT_BLOCK() { os << obj; return os.str(); }); + // Register __ffi_repr__ for AccessPath/AccessStep so that ffi.ReprPrint + // uses the concise ".field[idx]" format instead of the dataclass repr. + refl::TypeAttrDef().def( + refl::type_attr::kRepr, + [](ffi::reflection::AccessPath path, ffi::Function) -> ffi::String { + return FormatAccessPath(path); + }); + refl::TypeAttrDef().def( + refl::type_attr::kRepr, + [](ffi::reflection::AccessStep step, ffi::Function) -> ffi::String { + std::ostringstream os; + FormatAccessStep(os, step); + return os.str(); + }); } } // namespace tvm diff --git a/src/target/target.cc b/src/target/target.cc index 4f1b4e3af2cb..2c093ee73fd1 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -479,7 +479,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { return (*it).second; } return Any(); - }); + }) + .def("target.TargetAsJSON", + [](const Target& target) -> ffi::String { return target->str(); }); + // Register __ffi_repr__ so that ffi.ReprPrint uses JSON format for Target + refl::TypeAttrDef().def( + refl::type_attr::kRepr, + [](Target target, ffi::Function) -> ffi::String { return target->str(); }); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/tests/python/ir/test_pass_instrument.py b/tests/python/ir/test_pass_instrument.py index 5814cd59b942..aca226e4e41a 100644 --- a/tests/python/ir/test_pass_instrument.py +++ b/tests/python/ir/test_pass_instrument.py @@ -42,7 +42,7 @@ def func(a: T.handle, b: T.handle) -> None: all_passes_output = capsys.readouterr().out assert "Before Running Pass:" in all_passes_output assert "After Running Pass:" in all_passes_output - assert "pass name: tirx." in all_passes_output + assert 'name="tirx.' in all_passes_output def test_relax_print_all_passes(capsys): @@ -60,4 +60,4 @@ def func(x: R.Tensor((16,), "float32"), y: R.Tensor((16,), "float32")): all_passes_output = capsys.readouterr().out assert "Before Running Pass:" in all_passes_output assert "After Running Pass:" in all_passes_output - assert "pass name: _pipeline" in all_passes_output + assert 'name="_pipeline"' in all_passes_output diff --git a/tests/python/tvmscript/test_tvmscript_printer_doc.py b/tests/python/tvmscript/test_tvmscript_printer_doc.py index e3fe8a87455e..8e9c7f0a74d8 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_doc.py +++ b/tests/python/tvmscript/test_tvmscript_printer_doc.py @@ -539,8 +539,8 @@ def test_stmt_doc_comment(): comment = "test comment" doc.comment = comment # Make sure the previous statement doesn't set attribute - # as if it's an ordinary Python object. - assert "comment" not in doc.__dict__ + # as if it's an ordinary Python object (__slots__ enforces this). + assert not hasattr(doc, "__dict__") or "comment" not in doc.__dict__ assert doc.comment == comment From 7a267e5cd16ff46b708a402b596a3d1e825c386a Mon Sep 17 00:00:00 2001 From: tqchen Date: Fri, 27 Mar 2026 12:12:52 +0000 Subject: [PATCH 6/6] [FIX] Register __ffi_repr__ for repr-sensitive types New tvm-ffi uses ffi.ReprPrint (dataclass repr) for CObject.__repr__ instead of TVM's ReprPrinter. Register __ffi_repr__ for types that need custom repr: - Target: returns JSON via TargetNode::str() - AccessPath/AccessStep: returns concise ".field[idx]" format - Trace: returns Python script format via AsPython() - GlobalVar: returns I.GlobalVar("name") - Var/SizeVar: returns the variable name Also fix structural_equal.cc to use << (which now goes through the fixed ReprPrinter dispatch) and update tests for new PassInfo format and __slots__ enforcement. --- src/ir/expr.cc | 4 ++ src/node/repr_printer.cc | 29 ++++++----- src/s_tir/schedule/instruction.cc | 83 +++++++++++++++++-------------- src/s_tir/schedule/trace.cc | 46 ++++++++++------- src/tirx/ir/expr.cc | 9 ++++ 5 files changed, 104 insertions(+), 67 deletions(-) diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 4acb0507343e..f9d6e0fc6080 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -225,6 +225,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { ss << ref; return ss.str(); }); + refl::TypeAttrDef().def( + refl::type_attr::kRepr, [](GlobalVar gvar, ffi::Function) -> ffi::String { + return "I.GlobalVar(\"" + std::string(gvar->name_hint) + "\")"; + }); } } // namespace tvm diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index dd9b31d3afa8..142ad74eb5e9 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -121,22 +121,15 @@ void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; } void Dump(const runtime::Object* n) { Dump(runtime::GetRef(n)); } -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - p->stream << Downcast(node); - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - p->stream << Downcast(node); - }); - namespace { /*! * \brief Format an AccessStep as a concise string fragment. + * + * For map keys, uses ffi.ReprPrint which dispatches to __ffi_repr__. */ void FormatAccessStep(std::ostringstream& os, const ffi::reflection::AccessStep& step) { using ffi::reflection::AccessKind; + static const ffi::Function repr_fn = ffi::Function::GetGlobal("ffi.ReprPrint").value(); switch (step->kind) { case AccessKind::kAttr: os << "." << step->key.cast(); @@ -145,7 +138,7 @@ void FormatAccessStep(std::ostringstream& os, const ffi::reflection::AccessStep& os << "[" << step->key.cast() << "]"; break; case AccessKind::kMapItem: - os << "{" << step->key.cast() << "}"; + os << "[" << repr_fn(step->key).cast() << "]"; break; case AccessKind::kAttrMissing: os << "." << step->key.cast() << "?"; @@ -154,7 +147,7 @@ void FormatAccessStep(std::ostringstream& os, const ffi::reflection::AccessStep& os << "[" << step->key.cast() << "]?"; break; case AccessKind::kMapItemMissing: - os << "{" << step->key.cast() << "}?"; + os << "[" << repr_fn(step->key).cast() << "]?"; break; } } @@ -178,6 +171,18 @@ ffi::String FormatAccessPath(const ffi::reflection::AccessPath& path) { } } // namespace +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + p->stream << FormatAccessPath(Downcast(node)); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + std::ostringstream os; + FormatAccessStep(os, Downcast(node)); + p->stream << os.str(); + }); + TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("node.AsRepr", [](ffi::Any obj) { diff --git a/src/s_tir/schedule/instruction.cc b/src/s_tir/schedule/instruction.cc index 29fe3c2d88dd..ef635462ab25 100644 --- a/src/s_tir/schedule/instruction.cc +++ b/src/s_tir/schedule/instruction.cc @@ -63,46 +63,49 @@ InstructionKindRegEntry& InstructionKindRegEntry::RegisterOrGet(const ffi::Strin /**************** Repr ****************/ +namespace { +ffi::String InstructionAsPythonRepr(const InstructionNode* self) { + ffi::Array inputs; + inputs.reserve(self->inputs.size()); + for (const Any& obj : self->inputs) { + if (obj == nullptr) { + inputs.push_back(ffi::String("None")); + } else if (auto opt_str = obj.as()) { + inputs.push_back(ffi::String('"' + (*opt_str).operator std::string() + '"')); + } else if (obj.as() || obj.as()) { + inputs.push_back(ffi::String("_")); + } else if (obj.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { + inputs.push_back(obj); + } else if (obj.as() || obj.as()) { + inputs.push_back(obj); + } else if (const auto* expr = obj.as()) { + PrimExpr new_expr = Substitute( + ffi::GetRef(expr), [](const Var& var) -> ffi::Optional { + ObjectPtr new_var = ffi::make_object(*var.get()); + new_var->name_hint = "_"; + return Var(new_var); + }); + std::ostringstream os; + os << new_expr; + inputs.push_back(ffi::String(os.str())); + } else if (obj.as()) { + inputs.push_back(obj); + } else { + TVM_FFI_THROW(TypeError) << "Stringifying is not supported for type: " << obj.GetTypeKey(); + throw; + } + } + return self->kind->f_as_python( + /*inputs=*/inputs, + /*attrs=*/self->attrs, + /*decision=*/Any(nullptr), + /*outputs=*/ffi::Array(self->outputs.size(), ffi::String("_"))); +} +} // namespace + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { - const auto* self = obj.as(); - TVM_FFI_ICHECK_NOTNULL(self); - ffi::Array inputs; - inputs.reserve(self->inputs.size()); - for (const Any& obj : self->inputs) { - if (obj == nullptr) { - inputs.push_back(ffi::String("None")); - } else if (auto opt_str = obj.as()) { - inputs.push_back(ffi::String('"' + (*opt_str).operator std::string() + '"')); - } else if (obj.as() || obj.as()) { - inputs.push_back(ffi::String("_")); - } else if (obj.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { - inputs.push_back(obj); - } else if (obj.as() || obj.as()) { - inputs.push_back(obj); - } else if (const auto* expr = obj.as()) { - PrimExpr new_expr = Substitute( - ffi::GetRef(expr), [](const Var& var) -> ffi::Optional { - ObjectPtr new_var = ffi::make_object(*var.get()); - new_var->name_hint = "_"; - return Var(new_var); - }); - std::ostringstream os; - os << new_expr; - inputs.push_back(ffi::String(os.str())); - } else if (obj.as()) { - inputs.push_back(obj); - } else { - TVM_FFI_THROW(TypeError) - << "Stringifying is not supported for type: " << obj.GetTypeKey(); - throw; - } - } - p->stream << self->kind->f_as_python( - /*inputs=*/inputs, - /*attrs=*/self->attrs, - /*decision=*/Any(nullptr), - /*outputs=*/ffi::Array(self->outputs.size(), ffi::String("_"))); + p->stream << InstructionAsPythonRepr(obj.as()); }); /**************** FFI ****************/ @@ -116,6 +119,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { ffi::Array outputs) -> Instruction { return Instruction(kind, inputs, attrs, outputs); }); + refl::TypeAttrDef().def( + refl::type_attr::kRepr, [](Instruction inst, ffi::Function) -> ffi::String { + return InstructionAsPythonRepr(inst.get()); + }); } } // namespace s_tir diff --git a/src/s_tir/schedule/trace.cc b/src/s_tir/schedule/trace.cc index a63fb15f64a8..17b169b85749 100644 --- a/src/s_tir/schedule/trace.cc +++ b/src/s_tir/schedule/trace.cc @@ -18,6 +18,8 @@ */ #include +#include + #include "./utils.h" namespace tvm { @@ -522,25 +524,31 @@ Trace TraceNode::Simplified(bool remove_postproc) const { /**************** Repr ****************/ +namespace { +ffi::String TraceAsPythonRepr(const TraceNode* self) { + std::ostringstream os; + os << "# from tvm import s_tir\n"; + os << "def apply_trace(sch: s_tir.Schedule) -> None:\n"; + ffi::Array repr = self->AsPython(/*remove_postproc=*/false); + bool is_first = true; + for (const ffi::String& line : repr) { + if (is_first) { + is_first = false; + } else { + os << '\n'; + } + os << " " << std::string(line); + } + if (is_first) { + os << " pass"; + } + return os.str(); +} +} // namespace + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { - const auto* self = obj.as(); - TVM_FFI_ICHECK_NOTNULL(self); - p->stream << "# from tvm import s_tir\n"; - p->stream << "def apply_trace(sch: s_tir.Schedule) -> None:\n"; - ffi::Array repr = self->AsPython(/*remove_postproc=*/false); - bool is_first = true; - for (const ffi::String& line : repr) { - if (is_first) { - is_first = false; - } else { - p->stream << '\n'; - } - p->stream << " " << line; - } - if (is_first) { - p->stream << " pass"; - } + p->stream << TraceAsPythonRepr(obj.as()); p->stream << std::flush; }); @@ -594,6 +602,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_method("s_tir.schedule.TraceWithDecision", &TraceNode::WithDecision) .def_method("s_tir.schedule.TraceSimplified", &TraceNode::Simplified) .def("s_tir.schedule.TraceApplyJSONToSchedule", Trace::ApplyJSONToSchedule); + // Register __ffi_repr__ so str(trace) returns the Python script format + refl::TypeAttrDef().def( + refl::type_attr::kRepr, + [](Trace trace, ffi::Function) -> ffi::String { return TraceAsPythonRepr(trace.get()); }); } } // namespace s_tir diff --git a/src/tirx/ir/expr.cc b/src/tirx/ir/expr.cc index 8b841e18dc45..f4130e70d6c7 100644 --- a/src/tirx/ir/expr.cc +++ b/src/tirx/ir/expr.cc @@ -84,6 +84,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tirx.convert", [](ffi::Variant> expr) { return expr; }); + // Register __ffi_repr__ for Var/SizeVar so repr shows just the name + refl::TypeAttrDef().def(refl::type_attr::kRepr, + [](Var var, ffi::Function) -> ffi::String { + return std::string(var->name_hint); + }); + refl::TypeAttrDef().def(refl::type_attr::kRepr, + [](SizeVar var, ffi::Function) -> ffi::String { + return std::string(var->name_hint); + }); } #define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \