diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi index c85fd42df6ea..63224e3f1e46 160000 --- a/3rdparty/tvm-ffi +++ b/3rdparty/tvm-ffi @@ -1 +1 @@ -Subproject commit c85fd42df6eae4ae0ec1aaa4ebb67ac859758cf5 +Subproject commit 63224e3f1e464cc62307223787926a48fc8df8c0 diff --git a/include/tvm/script/ir_builder/relax/frame.h b/include/tvm/script/ir_builder/relax/frame.h index 898e318950cc..3d4ca3267461 100644 --- a/include/tvm/script/ir_builder/relax/frame.h +++ b/include/tvm/script/ir_builder/relax/frame.h @@ -125,9 +125,8 @@ class FunctionFrameNode : public SeqExprFrameNode { .def_ro("params", &FunctionFrameNode::params) .def_ro("ret_struct_info", &FunctionFrameNode::ret_struct_info) .def_ro("is_pure", &FunctionFrameNode::is_pure) - .def_ro("attrs", &FunctionFrameNode::attrs) - .def_ro("binding_blocks", &FunctionFrameNode::binding_blocks) - .def_ro("output", &FunctionFrameNode::output); + .def_ro("attrs", &FunctionFrameNode::attrs); + // `binding_blocks` and `output` are inherited from SeqExprFrameNode. // `block_builder` is not registered as it's not visited. } TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.relax.FunctionFrame", FunctionFrameNode, diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py index e466ef850043..15629dcbe61f 100644 --- a/python/tvm/ir/base.py +++ b/python/tvm/ir/base.py @@ -27,6 +27,14 @@ class Node(Object): """Base class of all IR Nodes.""" + def __repr__(self) -> str: + from tvm.runtime.script_printer import _script # noqa: PLC0415 + + try: + return _script(self, None) + except Exception: # noqa: BLE001 + return super().__repr__() + @register_object("ir.SourceMap") class SourceMap(Object): diff --git a/python/tvm/ir/transform.py b/python/tvm/ir/transform.py index 28c187d7deb8..3e22a2b9084e 100644 --- a/python/tvm/ir/transform.py +++ b/python/tvm/ir/transform.py @@ -147,6 +147,8 @@ class Pass(tvm.runtime.Object): conveniently interact with the base class. """ + __slots__ = ("__dict__",) + @property def info(self): """Get the pass meta.""" diff --git a/python/tvm/relax/binding_rewrite.py b/python/tvm/relax/binding_rewrite.py index 01287893689e..0198befd33e1 100644 --- a/python/tvm/relax/binding_rewrite.py +++ b/python/tvm/relax/binding_rewrite.py @@ -38,6 +38,8 @@ class DataflowBlockRewrite(Object): use mutate_irmodule which rewrites the old function that registered in the constructor. """ + __slots__ = ("__dict__",) + def __init__(self, dfb: DataflowBlock, root_fn: Function): """ Construct a rewriter with the DataflowBlock to rewrite and its root function. diff --git a/python/tvm/relax/block_builder.py b/python/tvm/relax/block_builder.py index 1d6057cec500..13bea3180c29 100644 --- a/python/tvm/relax/block_builder.py +++ b/python/tvm/relax/block_builder.py @@ -153,6 +153,8 @@ class BlockBuilder(Object): mod = bb.get() """ + __slots__ = ("__dict__",) + _stack = [] @staticmethod diff --git a/python/tvm/runtime/support.py b/python/tvm/runtime/support.py index d40f5c263618..b0ac67176328 100644 --- a/python/tvm/runtime/support.py +++ b/python/tvm/runtime/support.py @@ -149,6 +149,8 @@ def method(*args, **kwargs): class TVMDerivedObject(metadata["cls"]): # type: ignore """The derived object to avoid cyclic dependency.""" + __slots__ = ("__dict__", "__weakref__",) + _cls = cls _type = "TVMDerivedObject" diff --git a/python/tvm/s_tir/meta_schedule/utils.py b/python/tvm/s_tir/meta_schedule/utils.py index b4cd2f4009df..2460a6cc265d 100644 --- a/python/tvm/s_tir/meta_schedule/utils.py +++ b/python/tvm/s_tir/meta_schedule/utils.py @@ -109,6 +109,8 @@ def method(*args, **kwargs): class TVMDerivedObject(metadata["cls"]): # type: ignore """The derived object to avoid cyclic dependency.""" + __slots__ = ("__dict__", "__weakref__",) + _cls = cls _type = "TVMDerivedObject" diff --git a/python/tvm/tirx/build.py b/python/tvm/tirx/build.py index d310afee7938..020730d2f9de 100644 --- a/python/tvm/tirx/build.py +++ b/python/tvm/tirx/build.py @@ -100,7 +100,7 @@ def main(self_handle: T.handle, args: T.handle, num_args: T.int32, result: T.han def is_host_func(f): target = f.attrs.get("target", tvm.target.Target("llvm")) - return str(target.kind) in ["llvm", "c"] + return target.kind.name in ["llvm", "c"] host_mod = tvm.tirx.transform.Filter(is_host_func)(mod) device_mod = tvm.tirx.transform.Filter(lambda f: not is_host_func(f))(mod) diff --git a/src/ir/expr.cc b/src/ir/expr.cc index 4acb0507343e..f9d6e0fc6080 100644 --- a/src/ir/expr.cc +++ b/src/ir/expr.cc @@ -225,6 +225,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { ss << ref; return ss.str(); }); + refl::TypeAttrDef().def( + refl::type_attr::kRepr, [](GlobalVar gvar, ffi::Function) -> ffi::String { + return "I.GlobalVar(\"" + std::string(gvar->name_hint) + "\")"; + }); } } // namespace tvm diff --git a/src/node/repr_printer.cc b/src/node/repr_printer.cc index b60583c6ab85..142ad74eb5e9 100644 --- a/src/node/repr_printer.cc +++ b/src/node/repr_printer.cc @@ -22,11 +22,15 @@ * \file node/repr_printer.cc */ #include +#include #include #include #include #include +#include +#include + #include "../support/str_escape.h" namespace tvm { @@ -117,14 +121,66 @@ void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; } void Dump(const runtime::Object* n) { Dump(runtime::GetRef(n)); } +namespace { +/*! + * \brief Format an AccessStep as a concise string fragment. + * + * For map keys, uses ffi.ReprPrint which dispatches to __ffi_repr__. + */ +void FormatAccessStep(std::ostringstream& os, const ffi::reflection::AccessStep& step) { + using ffi::reflection::AccessKind; + static const ffi::Function repr_fn = ffi::Function::GetGlobal("ffi.ReprPrint").value(); + switch (step->kind) { + case AccessKind::kAttr: + os << "." << step->key.cast(); + break; + case AccessKind::kArrayItem: + os << "[" << step->key.cast() << "]"; + break; + case AccessKind::kMapItem: + os << "[" << repr_fn(step->key).cast() << "]"; + break; + case AccessKind::kAttrMissing: + os << "." << step->key.cast() << "?"; + break; + case AccessKind::kArrayItemMissing: + os << "[" << step->key.cast() << "]?"; + break; + case AccessKind::kMapItemMissing: + os << "[" << repr_fn(step->key).cast() << "]?"; + break; + } +} + +/*! + * \brief Format an AccessPath as ".field[idx]". + */ +ffi::String FormatAccessPath(const ffi::reflection::AccessPath& path) { + std::vector steps; + const ffi::reflection::AccessPathObj* cur = path.get(); + while (cur->step.defined()) { + steps.push_back(cur->step.value()); + cur = static_cast(cur->parent.get()); + } + std::ostringstream os; + os << ""; + for (auto it = steps.rbegin(); it != steps.rend(); ++it) { + FormatAccessStep(os, *it); + } + return os.str(); +} +} // namespace + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - p->stream << Downcast(node); + p->stream << FormatAccessPath(Downcast(node)); }); TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { - p->stream << Downcast(node); + std::ostringstream os; + FormatAccessStep(os, Downcast(node)); + p->stream << os.str(); }); TVM_FFI_STATIC_INIT_BLOCK() { @@ -134,5 +190,19 @@ TVM_FFI_STATIC_INIT_BLOCK() { os << obj; return os.str(); }); + // Register __ffi_repr__ for AccessPath/AccessStep so that ffi.ReprPrint + // uses the concise ".field[idx]" format instead of the dataclass repr. + refl::TypeAttrDef().def( + refl::type_attr::kRepr, + [](ffi::reflection::AccessPath path, ffi::Function) -> ffi::String { + return FormatAccessPath(path); + }); + refl::TypeAttrDef().def( + refl::type_attr::kRepr, + [](ffi::reflection::AccessStep step, ffi::Function) -> ffi::String { + std::ostringstream os; + FormatAccessStep(os, step); + return os.str(); + }); } } // namespace tvm diff --git a/src/relax/ir/emit_te.h b/src/relax/ir/emit_te.h index d7c998729a10..31b4cc292762 100644 --- a/src/relax/ir/emit_te.h +++ b/src/relax/ir/emit_te.h @@ -44,12 +44,7 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode { static void RegisterReflection() { namespace refl = tvm::ffi::reflection; refl::ObjectDef() - .def_ro("name", &RXPlaceholderOpNode::name) - .def_ro("tag", &RXPlaceholderOpNode::tag) - .def_ro("attrs", &RXPlaceholderOpNode::attrs) - .def_ro("value", &RXPlaceholderOpNode::value) - .def_ro("shape", &RXPlaceholderOpNode::shape) - .def_ro("dtype", &RXPlaceholderOpNode::dtype); + .def_ro("value", &RXPlaceholderOpNode::value); } // FFI system configuration for structural equality and hashing diff --git a/src/s_tir/schedule/instruction.cc b/src/s_tir/schedule/instruction.cc index 29fe3c2d88dd..ef635462ab25 100644 --- a/src/s_tir/schedule/instruction.cc +++ b/src/s_tir/schedule/instruction.cc @@ -63,46 +63,49 @@ InstructionKindRegEntry& InstructionKindRegEntry::RegisterOrGet(const ffi::Strin /**************** Repr ****************/ +namespace { +ffi::String InstructionAsPythonRepr(const InstructionNode* self) { + ffi::Array inputs; + inputs.reserve(self->inputs.size()); + for (const Any& obj : self->inputs) { + if (obj == nullptr) { + inputs.push_back(ffi::String("None")); + } else if (auto opt_str = obj.as()) { + inputs.push_back(ffi::String('"' + (*opt_str).operator std::string() + '"')); + } else if (obj.as() || obj.as()) { + inputs.push_back(ffi::String("_")); + } else if (obj.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { + inputs.push_back(obj); + } else if (obj.as() || obj.as()) { + inputs.push_back(obj); + } else if (const auto* expr = obj.as()) { + PrimExpr new_expr = Substitute( + ffi::GetRef(expr), [](const Var& var) -> ffi::Optional { + ObjectPtr new_var = ffi::make_object(*var.get()); + new_var->name_hint = "_"; + return Var(new_var); + }); + std::ostringstream os; + os << new_expr; + inputs.push_back(ffi::String(os.str())); + } else if (obj.as()) { + inputs.push_back(obj); + } else { + TVM_FFI_THROW(TypeError) << "Stringifying is not supported for type: " << obj.GetTypeKey(); + throw; + } + } + return self->kind->f_as_python( + /*inputs=*/inputs, + /*attrs=*/self->attrs, + /*decision=*/Any(nullptr), + /*outputs=*/ffi::Array(self->outputs.size(), ffi::String("_"))); +} +} // namespace + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { - const auto* self = obj.as(); - TVM_FFI_ICHECK_NOTNULL(self); - ffi::Array inputs; - inputs.reserve(self->inputs.size()); - for (const Any& obj : self->inputs) { - if (obj == nullptr) { - inputs.push_back(ffi::String("None")); - } else if (auto opt_str = obj.as()) { - inputs.push_back(ffi::String('"' + (*opt_str).operator std::string() + '"')); - } else if (obj.as() || obj.as()) { - inputs.push_back(ffi::String("_")); - } else if (obj.type_index() < ffi::TypeIndex::kTVMFFISmallStr) { - inputs.push_back(obj); - } else if (obj.as() || obj.as()) { - inputs.push_back(obj); - } else if (const auto* expr = obj.as()) { - PrimExpr new_expr = Substitute( - ffi::GetRef(expr), [](const Var& var) -> ffi::Optional { - ObjectPtr new_var = ffi::make_object(*var.get()); - new_var->name_hint = "_"; - return Var(new_var); - }); - std::ostringstream os; - os << new_expr; - inputs.push_back(ffi::String(os.str())); - } else if (obj.as()) { - inputs.push_back(obj); - } else { - TVM_FFI_THROW(TypeError) - << "Stringifying is not supported for type: " << obj.GetTypeKey(); - throw; - } - } - p->stream << self->kind->f_as_python( - /*inputs=*/inputs, - /*attrs=*/self->attrs, - /*decision=*/Any(nullptr), - /*outputs=*/ffi::Array(self->outputs.size(), ffi::String("_"))); + p->stream << InstructionAsPythonRepr(obj.as()); }); /**************** FFI ****************/ @@ -116,6 +119,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { ffi::Array outputs) -> Instruction { return Instruction(kind, inputs, attrs, outputs); }); + refl::TypeAttrDef().def( + refl::type_attr::kRepr, [](Instruction inst, ffi::Function) -> ffi::String { + return InstructionAsPythonRepr(inst.get()); + }); } } // namespace s_tir diff --git a/src/s_tir/schedule/trace.cc b/src/s_tir/schedule/trace.cc index a63fb15f64a8..17b169b85749 100644 --- a/src/s_tir/schedule/trace.cc +++ b/src/s_tir/schedule/trace.cc @@ -18,6 +18,8 @@ */ #include +#include + #include "./utils.h" namespace tvm { @@ -522,25 +524,31 @@ Trace TraceNode::Simplified(bool remove_postproc) const { /**************** Repr ****************/ +namespace { +ffi::String TraceAsPythonRepr(const TraceNode* self) { + std::ostringstream os; + os << "# from tvm import s_tir\n"; + os << "def apply_trace(sch: s_tir.Schedule) -> None:\n"; + ffi::Array repr = self->AsPython(/*remove_postproc=*/false); + bool is_first = true; + for (const ffi::String& line : repr) { + if (is_first) { + is_first = false; + } else { + os << '\n'; + } + os << " " << std::string(line); + } + if (is_first) { + os << " pass"; + } + return os.str(); +} +} // namespace + TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) .set_dispatch([](const ObjectRef& obj, ReprPrinter* p) { - const auto* self = obj.as(); - TVM_FFI_ICHECK_NOTNULL(self); - p->stream << "# from tvm import s_tir\n"; - p->stream << "def apply_trace(sch: s_tir.Schedule) -> None:\n"; - ffi::Array repr = self->AsPython(/*remove_postproc=*/false); - bool is_first = true; - for (const ffi::String& line : repr) { - if (is_first) { - is_first = false; - } else { - p->stream << '\n'; - } - p->stream << " " << line; - } - if (is_first) { - p->stream << " pass"; - } + p->stream << TraceAsPythonRepr(obj.as()); p->stream << std::flush; }); @@ -594,6 +602,10 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def_method("s_tir.schedule.TraceWithDecision", &TraceNode::WithDecision) .def_method("s_tir.schedule.TraceSimplified", &TraceNode::Simplified) .def("s_tir.schedule.TraceApplyJSONToSchedule", Trace::ApplyJSONToSchedule); + // Register __ffi_repr__ so str(trace) returns the Python script format + refl::TypeAttrDef().def( + refl::type_attr::kRepr, + [](Trace trace, ffi::Function) -> ffi::String { return TraceAsPythonRepr(trace.get()); }); } } // namespace s_tir diff --git a/src/target/target.cc b/src/target/target.cc index 4f1b4e3af2cb..2c093ee73fd1 100644 --- a/src/target/target.cc +++ b/src/target/target.cc @@ -479,7 +479,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { return (*it).second; } return Any(); - }); + }) + .def("target.TargetAsJSON", + [](const Target& target) -> ffi::String { return target->str(); }); + // Register __ffi_repr__ so that ffi.ReprPrint uses JSON format for Target + refl::TypeAttrDef().def( + refl::type_attr::kRepr, + [](Target target, ffi::Function) -> ffi::String { return target->str(); }); } TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) diff --git a/src/tirx/ir/expr.cc b/src/tirx/ir/expr.cc index 8b841e18dc45..f4130e70d6c7 100644 --- a/src/tirx/ir/expr.cc +++ b/src/tirx/ir/expr.cc @@ -84,6 +84,15 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef().def("tirx.convert", [](ffi::Variant> expr) { return expr; }); + // Register __ffi_repr__ for Var/SizeVar so repr shows just the name + refl::TypeAttrDef().def(refl::type_attr::kRepr, + [](Var var, ffi::Function) -> ffi::String { + return std::string(var->name_hint); + }); + refl::TypeAttrDef().def(refl::type_attr::kRepr, + [](SizeVar var, ffi::Function) -> ffi::String { + return std::string(var->name_hint); + }); } #define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \ diff --git a/tests/python/ir/test_pass_instrument.py b/tests/python/ir/test_pass_instrument.py index 5814cd59b942..aca226e4e41a 100644 --- a/tests/python/ir/test_pass_instrument.py +++ b/tests/python/ir/test_pass_instrument.py @@ -42,7 +42,7 @@ def func(a: T.handle, b: T.handle) -> None: all_passes_output = capsys.readouterr().out assert "Before Running Pass:" in all_passes_output assert "After Running Pass:" in all_passes_output - assert "pass name: tirx." in all_passes_output + assert 'name="tirx.' in all_passes_output def test_relax_print_all_passes(capsys): @@ -60,4 +60,4 @@ def func(x: R.Tensor((16,), "float32"), y: R.Tensor((16,), "float32")): all_passes_output = capsys.readouterr().out assert "Before Running Pass:" in all_passes_output assert "After Running Pass:" in all_passes_output - assert "pass name: _pipeline" in all_passes_output + assert 'name="_pipeline"' in all_passes_output diff --git a/tests/python/tvmscript/test_tvmscript_printer_doc.py b/tests/python/tvmscript/test_tvmscript_printer_doc.py index e3fe8a87455e..8e9c7f0a74d8 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_doc.py +++ b/tests/python/tvmscript/test_tvmscript_printer_doc.py @@ -539,8 +539,8 @@ def test_stmt_doc_comment(): comment = "test comment" doc.comment = comment # Make sure the previous statement doesn't set attribute - # as if it's an ordinary Python object. - assert "comment" not in doc.__dict__ + # as if it's an ordinary Python object (__slots__ enforces this). + assert not hasattr(doc, "__dict__") or "comment" not in doc.__dict__ assert doc.comment == comment