Skip to content

Commit d9de0d0

Browse files
committed
[FIX] Register __ffi_repr__ for repr-sensitive types
New tvm-ffi uses ffi.ReprPrint (dataclass repr) for CObject.__repr__ instead of TVM's ReprPrinter. Register __ffi_repr__ for types that need custom repr: - Target: returns JSON via TargetNode::str() - AccessPath/AccessStep: returns concise "<root>.field[idx]" format - Trace: returns Python script format via AsPython() - GlobalVar: returns I.GlobalVar("name") - Var/SizeVar: returns the variable name Also fix structural_equal.cc to use << (which now goes through the fixed ReprPrinter dispatch) and update tests for new PassInfo format and __slots__ enforcement.
1 parent 93e58d7 commit d9de0d0

4 files changed

Lines changed: 59 additions & 29 deletions

File tree

src/ir/expr.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
225225
ss << ref;
226226
return ss.str();
227227
});
228+
refl::TypeAttrDef<GlobalVarNode>().def(
229+
refl::type_attr::kRepr, [](GlobalVar gvar, ffi::Function) -> ffi::String {
230+
return "I.GlobalVar(\"" + std::string(gvar->name_hint) + "\")";
231+
});
228232
}
229233

230234
} // namespace tvm

src/node/repr_printer.cc

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -121,22 +121,15 @@ void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; }
121121

122122
void Dump(const runtime::Object* n) { Dump(runtime::GetRef<runtime::ObjectRef>(n)); }
123123

124-
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
125-
.set_dispatch<ffi::reflection::AccessPathObj>([](const ObjectRef& node, ReprPrinter* p) {
126-
p->stream << Downcast<ffi::reflection::AccessPath>(node);
127-
});
128-
129-
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
130-
.set_dispatch<ffi::reflection::AccessStepObj>([](const ObjectRef& node, ReprPrinter* p) {
131-
p->stream << Downcast<ffi::reflection::AccessStep>(node);
132-
});
133-
134124
namespace {
135125
/*!
136126
* \brief Format an AccessStep as a concise string fragment.
127+
*
128+
* For map keys, uses ffi.ReprPrint which dispatches to __ffi_repr__.
137129
*/
138130
void FormatAccessStep(std::ostringstream& os, const ffi::reflection::AccessStep& step) {
139131
using ffi::reflection::AccessKind;
132+
static const ffi::Function repr_fn = ffi::Function::GetGlobal("ffi.ReprPrint").value();
140133
switch (step->kind) {
141134
case AccessKind::kAttr:
142135
os << "." << step->key.cast<ffi::String>();
@@ -145,7 +138,7 @@ void FormatAccessStep(std::ostringstream& os, const ffi::reflection::AccessStep&
145138
os << "[" << step->key.cast<int64_t>() << "]";
146139
break;
147140
case AccessKind::kMapItem:
148-
os << "{" << step->key.cast<ffi::String>() << "}";
141+
os << "[" << repr_fn(step->key).cast<ffi::String>() << "]";
149142
break;
150143
case AccessKind::kAttrMissing:
151144
os << "." << step->key.cast<ffi::String>() << "?";
@@ -154,7 +147,7 @@ void FormatAccessStep(std::ostringstream& os, const ffi::reflection::AccessStep&
154147
os << "[" << step->key.cast<int64_t>() << "]?";
155148
break;
156149
case AccessKind::kMapItemMissing:
157-
os << "{" << step->key.cast<ffi::String>() << "}?";
150+
os << "[" << repr_fn(step->key).cast<ffi::String>() << "]?";
158151
break;
159152
}
160153
}
@@ -178,6 +171,18 @@ ffi::String FormatAccessPath(const ffi::reflection::AccessPath& path) {
178171
}
179172
} // namespace
180173

174+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
175+
.set_dispatch<ffi::reflection::AccessPathObj>([](const ObjectRef& node, ReprPrinter* p) {
176+
p->stream << FormatAccessPath(Downcast<ffi::reflection::AccessPath>(node));
177+
});
178+
179+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
180+
.set_dispatch<ffi::reflection::AccessStepObj>([](const ObjectRef& node, ReprPrinter* p) {
181+
std::ostringstream os;
182+
FormatAccessStep(os, Downcast<ffi::reflection::AccessStep>(node));
183+
p->stream << os.str();
184+
});
185+
181186
TVM_FFI_STATIC_INIT_BLOCK() {
182187
namespace refl = tvm::ffi::reflection;
183188
refl::GlobalDef().def("node.AsRepr", [](ffi::Any obj) {

src/s_tir/schedule/trace.cc

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@
1818
*/
1919
#include <tvm/ffi/reflection/registry.h>
2020

21+
#include <sstream>
22+
2123
#include "./utils.h"
2224

2325
namespace tvm {
@@ -522,25 +524,31 @@ Trace TraceNode::Simplified(bool remove_postproc) const {
522524

523525
/**************** Repr ****************/
524526

527+
namespace {
528+
ffi::String TraceAsPythonRepr(const TraceNode* self) {
529+
std::ostringstream os;
530+
os << "# from tvm import s_tir\n";
531+
os << "def apply_trace(sch: s_tir.Schedule) -> None:\n";
532+
ffi::Array<ffi::String> repr = self->AsPython(/*remove_postproc=*/false);
533+
bool is_first = true;
534+
for (const ffi::String& line : repr) {
535+
if (is_first) {
536+
is_first = false;
537+
} else {
538+
os << '\n';
539+
}
540+
os << " " << std::string(line);
541+
}
542+
if (is_first) {
543+
os << " pass";
544+
}
545+
return os.str();
546+
}
547+
} // namespace
548+
525549
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
526550
.set_dispatch<TraceNode>([](const ObjectRef& obj, ReprPrinter* p) {
527-
const auto* self = obj.as<TraceNode>();
528-
TVM_FFI_ICHECK_NOTNULL(self);
529-
p->stream << "# from tvm import s_tir\n";
530-
p->stream << "def apply_trace(sch: s_tir.Schedule) -> None:\n";
531-
ffi::Array<ffi::String> repr = self->AsPython(/*remove_postproc=*/false);
532-
bool is_first = true;
533-
for (const ffi::String& line : repr) {
534-
if (is_first) {
535-
is_first = false;
536-
} else {
537-
p->stream << '\n';
538-
}
539-
p->stream << " " << line;
540-
}
541-
if (is_first) {
542-
p->stream << " pass";
543-
}
551+
p->stream << TraceAsPythonRepr(obj.as<TraceNode>());
544552
p->stream << std::flush;
545553
});
546554

@@ -594,6 +602,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
594602
.def_method("s_tir.schedule.TraceWithDecision", &TraceNode::WithDecision)
595603
.def_method("s_tir.schedule.TraceSimplified", &TraceNode::Simplified)
596604
.def("s_tir.schedule.TraceApplyJSONToSchedule", Trace::ApplyJSONToSchedule);
605+
// Register __ffi_repr__ so str(trace) returns the Python script format
606+
refl::TypeAttrDef<TraceNode>().def(
607+
refl::type_attr::kRepr,
608+
[](Trace trace, ffi::Function) -> ffi::String { return TraceAsPythonRepr(trace.get()); });
597609
}
598610

599611
} // namespace s_tir

src/tirx/ir/expr.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ TVM_FFI_STATIC_INIT_BLOCK() {
8484
namespace refl = tvm::ffi::reflection;
8585
refl::GlobalDef().def("tirx.convert",
8686
[](ffi::Variant<PrimExpr, ffi::Array<PrimExpr>> expr) { return expr; });
87+
// Register __ffi_repr__ for Var/SizeVar so repr shows just the name
88+
refl::TypeAttrDef<VarNode>().def(refl::type_attr::kRepr,
89+
[](Var var, ffi::Function) -> ffi::String {
90+
return std::string(var->name_hint);
91+
});
92+
refl::TypeAttrDef<SizeVarNode>().def(refl::type_attr::kRepr,
93+
[](SizeVar var, ffi::Function) -> ffi::String {
94+
return std::string(var->name_hint);
95+
});
8796
}
8897

8998
#define TVM_DEFINE_BINOP_CONSTRUCTOR(Name) \

0 commit comments

Comments
 (0)