Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm-ffi
Submodule tvm-ffi updated 89 files
+2 −2 .github/actions/build-wheel-for-publish/action.yml
+2 −2 .github/workflows/ci_mainline_only.yml
+4 −4 .github/workflows/ci_test.yml
+3 −3 .github/workflows/torch_c_dlpack.yml
+1 −0 .gitignore
+3 −2 CMakeLists.txt
+59 −0 KEYS
+2 −0 cmake/Utils/Library.cmake
+1 −1 docs/.rstcheck.cfg
+1 −1 docs/concepts/object_and_class.rst
+715 −0 docs/concepts/structural_eq_hash.rst
+7 −1 docs/conf.py
+467 −0 docs/guides/dataclass_reflection.rst
+1 −1 docs/guides/export_func_cls.rst
+2 −0 docs/index.rst
+2 −0 examples/python_packaging/python/my_ffi_extension/_ffi_api.py
+1 −1 examples/python_packaging/run_example.py
+51 −2 include/tvm/ffi/c_api.h
+2 −1 include/tvm/ffi/container/tensor.h
+14 −0 include/tvm/ffi/extra/cuda/cubin_launcher.h
+74 −0 include/tvm/ffi/extra/cuda/internal/unified_api.h
+113 −0 include/tvm/ffi/extra/dataclass.h
+33 −4 include/tvm/ffi/reflection/accessor.h
+64 −13 include/tvm/ffi/reflection/creator.h
+246 −0 include/tvm/ffi/reflection/init.h
+9 −7 include/tvm/ffi/reflection/overload.h
+240 −30 include/tvm/ffi/reflection/registry.h
+1 −0 pyproject.toml
+93 −66 python/tvm_ffi/__init__.py
+16 −0 python/tvm_ffi/_ffi_api.py
+0 −1 python/tvm_ffi/access_path.py
+0 −26 python/tvm_ffi/container.py
+40 −3 python/tvm_ffi/core.pyi
+39 −3 python/tvm_ffi/cython/base.pxi
+4 −0 python/tvm_ffi/cython/core.pyx
+5 −1 python/tvm_ffi/cython/device.pxi
+1 −0 python/tvm_ffi/cython/dtype.pxi
+4 −3 python/tvm_ffi/cython/error.pxi
+24 −24 python/tvm_ffi/cython/function.pxi
+322 −99 python/tvm_ffi/cython/object.pxi
+806 −0 python/tvm_ffi/cython/pyclass_type_converter.pxi
+2 −2 python/tvm_ffi/cython/string.pxi
+7 −5 python/tvm_ffi/cython/tensor.pxi
+20 −7 python/tvm_ffi/cython/tvm_ffi_python_helpers.h
+831 −16 python/tvm_ffi/cython/type_info.pxi
+3 −4 python/tvm_ffi/dataclasses/__init__.py
+0 −210 python/tvm_ffi/dataclasses/_utils.py
+65 −160 python/tvm_ffi/dataclasses/c_class.py
+183 −108 python/tvm_ffi/dataclasses/field.py
+550 −0 python/tvm_ffi/dataclasses/py_class.py
+7 −5 python/tvm_ffi/module.py
+285 −18 python/tvm_ffi/registry.py
+3 −0 python/tvm_ffi/structural.py
+3 −1 python/tvm_ffi/stub/codegen.py
+116 −3 python/tvm_ffi/stub/utils.py
+12 −0 python/tvm_ffi/testing/__init__.py
+326 −19 python/tvm_ffi/testing/testing.py
+9 −3 rust/tvm-ffi-sys/src/c_api.rs
+13 −8 rust/tvm-ffi/src/collections/tensor.rs
+6 −2 rust/tvm-ffi/src/device.rs
+6 −1 rust/tvm-ffi/src/function.rs
+7 −4 rust/tvm-ffi/tests/test_device.rs
+7 −3 rust/tvm-ffi/tests/test_function.rs
+2 −7 src/ffi/container.cc
+2,004 −0 src/ffi/extra/dataclass.cc
+0 −176 src/ffi/extra/deep_copy.cc
+5 −12 src/ffi/extra/reflection_extra.cc
+0 −401 src/ffi/extra/repr_print.cc
+7 −14 src/ffi/extra/serialization.cc
+14 −7 src/ffi/function.cc
+66 −5 src/ffi/object.cc
+12 −20 src/ffi/object_internal.h
+304 −45 src/ffi/testing/testing.cc
+183 −2 tests/cpp/test_reflection.cc
+6 −6 tests/cpp/testing_object.h
+87 −0 tests/python/test_cubin_launcher.py
+319 −0 tests/python/test_dataclass_c_class.py
+1,350 −0 tests/python/test_dataclass_compare.py
+507 −25 tests/python/test_dataclass_copy.py
+1,008 −0 tests/python/test_dataclass_hash.py
+890 −0 tests/python/test_dataclass_init.py
+4,563 −0 tests/python/test_dataclass_py_class.py
+156 −4 tests/python/test_dataclass_repr.py
+0 −151 tests/python/test_dataclasses_c_class.py
+210 −2 tests/python/test_object.py
+5 −5 tests/python/test_serialization.py
+47 −0 tests/python/test_structural.py
+361 −0 tests/python/test_structural_py_class.py
+4,679 −0 tests/python/test_type_converter.py
5 changes: 2 additions & 3 deletions include/tvm/script/ir_builder/relax/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,8 @@ class FunctionFrameNode : public SeqExprFrameNode {
.def_ro("params", &FunctionFrameNode::params)
.def_ro("ret_struct_info", &FunctionFrameNode::ret_struct_info)
.def_ro("is_pure", &FunctionFrameNode::is_pure)
.def_ro("attrs", &FunctionFrameNode::attrs)
.def_ro("binding_blocks", &FunctionFrameNode::binding_blocks)
.def_ro("output", &FunctionFrameNode::output);
.def_ro("attrs", &FunctionFrameNode::attrs);
// `binding_blocks` and `output` are inherited from SeqExprFrameNode.
// `block_builder` is not registered as it's not visited.
}
TVM_FFI_DECLARE_OBJECT_INFO_FINAL("script.ir_builder.relax.FunctionFrame", FunctionFrameNode,
Expand Down
8 changes: 8 additions & 0 deletions python/tvm/ir/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,14 @@
class Node(Object):
"""Base class of all IR Nodes."""

def __repr__(self) -> str:
from tvm.runtime.script_printer import _script # noqa: PLC0415

try:
return _script(self, None)
except Exception: # noqa: BLE001
return super().__repr__()


@register_object("ir.SourceMap")
class SourceMap(Object):
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/ir/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ class Pass(tvm.runtime.Object):
conveniently interact with the base class.
"""

__slots__ = ("__dict__",)

@property
def info(self):
"""Get the pass meta."""
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relax/binding_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class DataflowBlockRewrite(Object):
use mutate_irmodule which rewrites the old function that registered in the constructor.
"""

__slots__ = ("__dict__",)

def __init__(self, dfb: DataflowBlock, root_fn: Function):
"""
Construct a rewriter with the DataflowBlock to rewrite and its root function.
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relax/block_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ class BlockBuilder(Object):
mod = bb.get()
"""

__slots__ = ("__dict__",)

_stack = []

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/runtime/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def method(*args, **kwargs):
class TVMDerivedObject(metadata["cls"]): # type: ignore
"""The derived object to avoid cyclic dependency."""

__slots__ = ("__dict__", "__weakref__",)

_cls = cls
_type = "TVMDerivedObject"

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/s_tir/meta_schedule/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@ def method(*args, **kwargs):
class TVMDerivedObject(metadata["cls"]): # type: ignore
"""The derived object to avoid cyclic dependency."""

__slots__ = ("__dict__", "__weakref__",)

_cls = cls
_type = "TVMDerivedObject"

Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tirx/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def main(self_handle: T.handle, args: T.handle, num_args: T.int32, result: T.han

def is_host_func(f):
target = f.attrs.get("target", tvm.target.Target("llvm"))
return str(target.kind) in ["llvm", "c"]
return target.kind.name in ["llvm", "c"]

host_mod = tvm.tirx.transform.Filter(is_host_func)(mod)
device_mod = tvm.tirx.transform.Filter(lambda f: not is_host_func(f))(mod)
Expand Down
4 changes: 4 additions & 0 deletions src/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
ss << ref;
return ss.str();
});
refl::TypeAttrDef<GlobalVarNode>().def(
refl::type_attr::kRepr, [](GlobalVar gvar, ffi::Function) -> ffi::String {
return "I.GlobalVar(\"" + std::string(gvar->name_hint) + "\")";
});
}

} // namespace tvm
74 changes: 72 additions & 2 deletions src/node/repr_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,15 @@
* \file node/repr_printer.cc
*/
#include <tvm/ffi/function.h>
#include <tvm/ffi/reflection/access_path.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/node/cast.h>
#include <tvm/node/repr_printer.h>
#include <tvm/runtime/device_api.h>

#include <sstream>
#include <vector>

#include "../support/str_escape.h"

namespace tvm {
Expand Down Expand Up @@ -117,14 +121,66 @@ void Dump(const runtime::ObjectRef& n) { std::cerr << n << "\n"; }

void Dump(const runtime::Object* n) { Dump(runtime::GetRef<runtime::ObjectRef>(n)); }

namespace {
/*!
* \brief Format an AccessStep as a concise string fragment.
*
* For map keys, uses ffi.ReprPrint which dispatches to __ffi_repr__.
*/
void FormatAccessStep(std::ostringstream& os, const ffi::reflection::AccessStep& step) {
using ffi::reflection::AccessKind;
static const ffi::Function repr_fn = ffi::Function::GetGlobal("ffi.ReprPrint").value();
switch (step->kind) {
case AccessKind::kAttr:
os << "." << step->key.cast<ffi::String>();
break;
case AccessKind::kArrayItem:
os << "[" << step->key.cast<int64_t>() << "]";
break;
case AccessKind::kMapItem:
os << "[" << repr_fn(step->key).cast<ffi::String>() << "]";
break;
case AccessKind::kAttrMissing:
os << "." << step->key.cast<ffi::String>() << "?";
break;
case AccessKind::kArrayItemMissing:
os << "[" << step->key.cast<int64_t>() << "]?";
break;
case AccessKind::kMapItemMissing:
os << "[" << repr_fn(step->key).cast<ffi::String>() << "]?";
break;
}
}

/*!
* \brief Format an AccessPath as "<root>.field[idx]".
*/
ffi::String FormatAccessPath(const ffi::reflection::AccessPath& path) {
std::vector<ffi::reflection::AccessStep> steps;
const ffi::reflection::AccessPathObj* cur = path.get();
while (cur->step.defined()) {
steps.push_back(cur->step.value());
cur = static_cast<const ffi::reflection::AccessPathObj*>(cur->parent.get());
}
std::ostringstream os;
os << "<root>";
for (auto it = steps.rbegin(); it != steps.rend(); ++it) {
FormatAccessStep(os, *it);
}
return os.str();
}
} // namespace

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ffi::reflection::AccessPathObj>([](const ObjectRef& node, ReprPrinter* p) {
p->stream << Downcast<ffi::reflection::AccessPath>(node);
p->stream << FormatAccessPath(Downcast<ffi::reflection::AccessPath>(node));
});

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ffi::reflection::AccessStepObj>([](const ObjectRef& node, ReprPrinter* p) {
p->stream << Downcast<ffi::reflection::AccessStep>(node);
std::ostringstream os;
FormatAccessStep(os, Downcast<ffi::reflection::AccessStep>(node));
p->stream << os.str();
});

TVM_FFI_STATIC_INIT_BLOCK() {
Expand All @@ -134,5 +190,19 @@ TVM_FFI_STATIC_INIT_BLOCK() {
os << obj;
return os.str();
});
// Register __ffi_repr__ for AccessPath/AccessStep so that ffi.ReprPrint
// uses the concise "<root>.field[idx]" format instead of the dataclass repr.
refl::TypeAttrDef<ffi::reflection::AccessPathObj>().def(
refl::type_attr::kRepr,
[](ffi::reflection::AccessPath path, ffi::Function) -> ffi::String {
return FormatAccessPath(path);
});
refl::TypeAttrDef<ffi::reflection::AccessStepObj>().def(
refl::type_attr::kRepr,
[](ffi::reflection::AccessStep step, ffi::Function) -> ffi::String {
std::ostringstream os;
FormatAccessStep(os, step);
return os.str();
});
}
} // namespace tvm
7 changes: 1 addition & 6 deletions src/relax/ir/emit_te.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,7 @@ class RXPlaceholderOpNode : public te::PlaceholderOpNode {
static void RegisterReflection() {
namespace refl = tvm::ffi::reflection;
refl::ObjectDef<RXPlaceholderOpNode>()
.def_ro("name", &RXPlaceholderOpNode::name)
.def_ro("tag", &RXPlaceholderOpNode::tag)
.def_ro("attrs", &RXPlaceholderOpNode::attrs)
.def_ro("value", &RXPlaceholderOpNode::value)
.def_ro("shape", &RXPlaceholderOpNode::shape)
.def_ro("dtype", &RXPlaceholderOpNode::dtype);
.def_ro("value", &RXPlaceholderOpNode::value);
}

// FFI system configuration for structural equality and hashing
Expand Down
83 changes: 45 additions & 38 deletions src/s_tir/schedule/instruction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,46 +63,49 @@ InstructionKindRegEntry& InstructionKindRegEntry::RegisterOrGet(const ffi::Strin

/**************** Repr ****************/

namespace {
ffi::String InstructionAsPythonRepr(const InstructionNode* self) {
ffi::Array<Any> inputs;
inputs.reserve(self->inputs.size());
for (const Any& obj : self->inputs) {
if (obj == nullptr) {
inputs.push_back(ffi::String("None"));
} else if (auto opt_str = obj.as<ffi::String>()) {
inputs.push_back(ffi::String('"' + (*opt_str).operator std::string() + '"'));
} else if (obj.as<SBlockRVNode>() || obj.as<LoopRVNode>()) {
inputs.push_back(ffi::String("_"));
} else if (obj.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
inputs.push_back(obj);
} else if (obj.as<IntImmNode>() || obj.as<FloatImmNode>()) {
inputs.push_back(obj);
} else if (const auto* expr = obj.as<PrimExprNode>()) {
PrimExpr new_expr = Substitute(
ffi::GetRef<PrimExpr>(expr), [](const Var& var) -> ffi::Optional<PrimExpr> {
ObjectPtr<VarNode> new_var = ffi::make_object<VarNode>(*var.get());
new_var->name_hint = "_";
return Var(new_var);
});
std::ostringstream os;
os << new_expr;
inputs.push_back(ffi::String(os.str()));
} else if (obj.as<IndexMapNode>()) {
inputs.push_back(obj);
} else {
TVM_FFI_THROW(TypeError) << "Stringifying is not supported for type: " << obj.GetTypeKey();
throw;
}
}
return self->kind->f_as_python(
/*inputs=*/inputs,
/*attrs=*/self->attrs,
/*decision=*/Any(nullptr),
/*outputs=*/ffi::Array<ffi::String>(self->outputs.size(), ffi::String("_")));
}
} // namespace

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<InstructionNode>([](const ObjectRef& obj, ReprPrinter* p) {
const auto* self = obj.as<InstructionNode>();
TVM_FFI_ICHECK_NOTNULL(self);
ffi::Array<Any> inputs;
inputs.reserve(self->inputs.size());
for (const Any& obj : self->inputs) {
if (obj == nullptr) {
inputs.push_back(ffi::String("None"));
} else if (auto opt_str = obj.as<ffi::String>()) {
inputs.push_back(ffi::String('"' + (*opt_str).operator std::string() + '"'));
} else if (obj.as<SBlockRVNode>() || obj.as<LoopRVNode>()) {
inputs.push_back(ffi::String("_"));
} else if (obj.type_index() < ffi::TypeIndex::kTVMFFISmallStr) {
inputs.push_back(obj);
} else if (obj.as<IntImmNode>() || obj.as<FloatImmNode>()) {
inputs.push_back(obj);
} else if (const auto* expr = obj.as<PrimExprNode>()) {
PrimExpr new_expr = Substitute(
ffi::GetRef<PrimExpr>(expr), [](const Var& var) -> ffi::Optional<PrimExpr> {
ObjectPtr<VarNode> new_var = ffi::make_object<VarNode>(*var.get());
new_var->name_hint = "_";
return Var(new_var);
});
std::ostringstream os;
os << new_expr;
inputs.push_back(ffi::String(os.str()));
} else if (obj.as<IndexMapNode>()) {
inputs.push_back(obj);
} else {
TVM_FFI_THROW(TypeError)
<< "Stringifying is not supported for type: " << obj.GetTypeKey();
throw;
}
}
p->stream << self->kind->f_as_python(
/*inputs=*/inputs,
/*attrs=*/self->attrs,
/*decision=*/Any(nullptr),
/*outputs=*/ffi::Array<ffi::String>(self->outputs.size(), ffi::String("_")));
p->stream << InstructionAsPythonRepr(obj.as<InstructionNode>());
});

/**************** FFI ****************/
Expand All @@ -116,6 +119,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
ffi::Array<Any> outputs) -> Instruction {
return Instruction(kind, inputs, attrs, outputs);
});
refl::TypeAttrDef<InstructionNode>().def(
refl::type_attr::kRepr, [](Instruction inst, ffi::Function) -> ffi::String {
return InstructionAsPythonRepr(inst.get());
});
}

} // namespace s_tir
Expand Down
46 changes: 29 additions & 17 deletions src/s_tir/schedule/trace.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
*/
#include <tvm/ffi/reflection/registry.h>

#include <sstream>

#include "./utils.h"

namespace tvm {
Expand Down Expand Up @@ -522,25 +524,31 @@ Trace TraceNode::Simplified(bool remove_postproc) const {

/**************** Repr ****************/

namespace {
ffi::String TraceAsPythonRepr(const TraceNode* self) {
std::ostringstream os;
os << "# from tvm import s_tir\n";
os << "def apply_trace(sch: s_tir.Schedule) -> None:\n";
ffi::Array<ffi::String> repr = self->AsPython(/*remove_postproc=*/false);
bool is_first = true;
for (const ffi::String& line : repr) {
if (is_first) {
is_first = false;
} else {
os << '\n';
}
os << " " << std::string(line);
}
if (is_first) {
os << " pass";
}
return os.str();
}
} // namespace

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<TraceNode>([](const ObjectRef& obj, ReprPrinter* p) {
const auto* self = obj.as<TraceNode>();
TVM_FFI_ICHECK_NOTNULL(self);
p->stream << "# from tvm import s_tir\n";
p->stream << "def apply_trace(sch: s_tir.Schedule) -> None:\n";
ffi::Array<ffi::String> repr = self->AsPython(/*remove_postproc=*/false);
bool is_first = true;
for (const ffi::String& line : repr) {
if (is_first) {
is_first = false;
} else {
p->stream << '\n';
}
p->stream << " " << line;
}
if (is_first) {
p->stream << " pass";
}
p->stream << TraceAsPythonRepr(obj.as<TraceNode>());
p->stream << std::flush;
});

Expand Down Expand Up @@ -594,6 +602,10 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def_method("s_tir.schedule.TraceWithDecision", &TraceNode::WithDecision)
.def_method("s_tir.schedule.TraceSimplified", &TraceNode::Simplified)
.def("s_tir.schedule.TraceApplyJSONToSchedule", Trace::ApplyJSONToSchedule);
// Register __ffi_repr__ so str(trace) returns the Python script format
refl::TypeAttrDef<TraceNode>().def(
refl::type_attr::kRepr,
[](Trace trace, ffi::Function) -> ffi::String { return TraceAsPythonRepr(trace.get()); });
}

} // namespace s_tir
Expand Down
8 changes: 7 additions & 1 deletion src/target/target.cc
Original file line number Diff line number Diff line change
Expand Up @@ -479,7 +479,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
return (*it).second;
}
return Any();
});
})
.def("target.TargetAsJSON",
[](const Target& target) -> ffi::String { return target->str(); });
// Register __ffi_repr__ so that ffi.ReprPrint uses JSON format for Target
refl::TypeAttrDef<TargetNode>().def(
refl::type_attr::kRepr,
[](Target target, ffi::Function) -> ffi::String { return target->str(); });
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
Expand Down
Loading
Loading