Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 60 additions & 49 deletions onnxscript/_internal/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,22 +339,20 @@ def tensor_name_generator() -> str:
def _to_onnx_attr_ref(
self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo]
) -> ir.Attr:
pytype = val.typeinfo
attrtype = ta.pytype_to_attrtype(pytype)
attrtype = val.value.type
attrname = None
if attrtype is onnx.AttributeProto.FLOAT:
if attrtype is ir.AttributeType.FLOAT: # onnx.AttributeProto.FLOAT:
attrname = "value_float"
elif attrtype is onnx.AttributeProto.INT:
elif attrtype is ir.AttributeType.INT:
attrname = "value_int"
elif attrtype is onnx.AttributeProto.STRING:
elif attrtype is ir.AttributeType.STRING:
attrname = "value_string"
elif attrtype is onnx.AttributeProto.INTS:
elif attrtype is ir.AttributeType.INTS:
attrname = "value_ints"
else:
msg = f"Unsupported attribute type {pytype!r}."
msg = f"Unsupported attribute type {attrtype!r}."
fail(info.msg(msg) if info else msg)
attr_type = ir.AttributeType(ta.pytype_to_attrtype(pytype))
return ir.Attr(attrname, attr_type, None, val.value)
return ir.Attr(attrname, attrtype, value=None, ref_attr_name=val.value.name)

def _to_onnx_var(
self,
Expand All @@ -369,7 +367,7 @@ def _to_onnx_var(
result = self.emit(
[result_name], values.Op(self.default_opset, "Constant"), [], [attr]
)
if ta.base_type_is_bool(val.typeinfo):
if val.as_bool:
# ONNX attributes use an int-encoding for bools, but ONNX tensor types
# distinguish between int and bool. So we cast the int tensor to a bool tensor,
# to promote a (python) bool attribute to a ONNX bool tensor.
Expand All @@ -384,8 +382,9 @@ def _to_onnx_var(
)
self._castable.add(result_name)
return result
if isinstance(val, values.Dynamic):
return val.value
if isinstance(val, values.SymbolValue):
if isinstance(val.value, ir.Value):
return val.value
# Assume value is a python-value convertible to a tensor
# TODO: check if value is convertible to a TensorProto, so that we can
# produce a better error _message otherwise
Expand Down Expand Up @@ -534,29 +533,44 @@ def _translate_attr(

if isinstance(expr, ast.Name):
val = self._lookup(expr.id, self._source_of(expr))
if isinstance(val, values.AttrRef):
attr_type = ir.AttributeType(ta.pytype_to_attrtype(val.typeinfo))
attr_ref = ir.Attr(attr_name, attr_type, None, val.value)
if attr_meta is not None and (attr_ref.type != attr_meta.type):
self.fail(
expr,
f"Attribute type '{attr_ref.type}' does not match expected type '{attr_meta.type}'",
if isinstance(val, values.SymbolValue):
val = val.value
if isinstance(val, ir.Attr):
# A reference to an attribute parameter:
attr = val
attr_ref = ir.Attr(
attr_name, attr.type, value=None, ref_attr_name=attr.name
)
return attr_ref
if isinstance(val, irbuilder.IRFunction):
# Check that outer-scope variables referenced by function have same value
# at function-definition site and use-as-attribute site, to avoid errors.
for pyvar, previous in val.outer_scope_variables:
current = self._lookup(pyvar, self._source_of(expr))
if current.value != previous.value:
if attr_meta is not None and (attr.type != attr_meta.type):
self.fail(
expr,
f"Outer scope variable '{pyvar}' referenced by function "
f"'{expr.id!r}' modified.",
f"Attribute type '{attr_ref.type}' does not match expected type '{attr_meta.type}'",
)

# Create GraphProto attribute
val = val.to_graph_proto()
return attr_ref
if isinstance(val, irbuilder.IRFunction):
# A reference to a nested-function: convert to GraphProto and use it.
irfunction = val
# Check that outer-scope variables referenced by function have same value
# at function-definition site and use-as-attribute site, to avoid errors.
for pyvar, previous in irfunction.outer_scope_variables:
current = self._lookup(pyvar, self._source_of(expr))
if current.value != previous.value:
self.fail(
expr,
f"Outer scope variable '{pyvar}' referenced by function "
f"'{expr.id!r}' modified.",
)
# Create GraphProto attribute
val = irfunction.to_graph_proto()
if isinstance(val, ir.Value):
self.fail(expr, f"Cannot use ir.Value '{expr.id}' as an attribute.")
else:
# Treat as a constant python-value, to be converted below.
pass
else:
# This must be a reference to an outer-scope python-value, typically a constant.
# The value will be converted to an ONNX attribute value below.
pass
else:
val = self._eval_constant_expr(expr)

Expand Down Expand Up @@ -1045,7 +1059,7 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None:
typeinfo = None
if typeinfo is not None:
set_type_info(t, typeinfo)
var = values.Dynamic(t, values.DynamicKind.Intermediate, info, typeinfo)
var = values.SymbolValue(t, info)
self._bind(lhs, var)
elif isinstance(lhs, ast.Tuple):
# Assignments of the form "x, y, z = op.SomeOp(...)"
Expand All @@ -1068,9 +1082,7 @@ def generate_onnx_name(x: ast.AST):
for x, output in zip(lhs.elts, outputs):
self._bind(
x.id,
values.Dynamic(
output, values.DynamicKind.Intermediate, self._source_of(x)
),
values.SymbolValue(output, self._source_of(x)),
)
else:
self.fail(lhs, f"Unsupported construct in LHS of assignment: '{type(lhs)!r}'")
Expand Down Expand Up @@ -1117,10 +1129,11 @@ def ret(exp, i, suffix):
preferred_name = f"return_val{suffix}"
return_var = self._translate_expr(exp, preferred_name) # TODO(rama)
val = self._lookup(return_var.name, self._source_of(exp), False)
if val and val.kind == values.DynamicKind.Input:
# In ONNX, a graph-input cannot be an output of the graph.
# We need to insert a copy.
return_var = self._emit_copy(return_var, preferred_name)
if isinstance(val, values.SymbolValue) and isinstance(val.value, ir.Value):
if val.value.is_graph_input():
# In ONNX, a graph-input cannot be an output of the graph.
# We need to insert a copy.
return_var = self._emit_copy(return_var, preferred_name)
for prev_output in self._current_fn.outputs:
if prev_output.name == return_var.name:
# ONNX does not allow duplicate output names.
Expand Down Expand Up @@ -1190,7 +1203,7 @@ def rename(x):
for x, y in zip(live_defs, if_outputs):
self._bind(
x,
values.Dynamic(y, values.DynamicKind.Intermediate, self._source_of(stmt)),
values.SymbolValue(y, self._source_of(stmt)),
)

def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
Expand Down Expand Up @@ -1257,7 +1270,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
self._current_fn.append_parameter(onnx_loop_var)
self._bind(
python_loop_var_name,
values.Dynamic(onnx_loop_var, values.DynamicKind.Loop, self._source_of(loop_stmt)),
values.SymbolValue(onnx_loop_var, self._source_of(loop_stmt)),
)

self._current_fn.append_parameter(
Expand All @@ -1278,9 +1291,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None:
)
self._bind(
pv,
values.Dynamic(
values.SymbolValue(
ir.Value(name=onnx_var_name),
values.DynamicKind.Loop,
self._source_of(loop_stmt),
),
)
Expand Down Expand Up @@ -1376,7 +1388,7 @@ def rename(x):
if isinstance(loop_outputs, ir.Value):
loop_outputs = [loop_outputs]
for x, loop_output in zip(outputs, loop_outputs):
self._bind(x, values.Dynamic(loop_output, values.DynamicKind.Output, info))
self._bind(x, values.SymbolValue(loop_output, info))

def _translate_block(
self,
Expand Down Expand Up @@ -1431,7 +1443,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None:
function_ir.outer_scope_variables = [
(var, self._lookup(var, self._source_of(fn))) for var in outer_scope_vars
]
self._bind(fn.name, function_ir)
self._bind(fn.name, values.SymbolValue(function_ir, self._source_of(fn)))
# TODO: Does not yet handle nested functions within nested functions.
self._current_fn.add_nested_function(function_ir)

Expand Down Expand Up @@ -1459,16 +1471,15 @@ def _translate_function_signature_common(
attribute_type = ta.pytype_to_attrtype(typeinfo)
attr = ir.Attr(x.arg, ir.AttributeType(attribute_type), default_value, None)
self._current_fn.append_parameter(attr)
self._bind(x.arg, values.AttrRef(x.arg, typeinfo, self._source_of(x)))
as_bool = ta.base_type_is_bool(typeinfo)
self._bind(x.arg, values.AttrRef(attr, as_bool, self._source_of(x)))
else:
onnx_parameter = make_value(x.arg, typeinfo, self._source_of(x))
self._current_fn.append_parameter(onnx_parameter)
self._used_vars.add(x.arg)
self._bind(
x.arg,
values.Dynamic(
onnx_parameter, values.DynamicKind.Input, self._source_of(x)
),
values.SymbolValue(onnx_parameter, self._source_of(x)),
)
if fn.returns:
type_annotation = self._eval_constant_expr(fn.returns)
Expand Down
67 changes: 14 additions & 53 deletions onnxscript/_internal/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,7 @@
import logging
import types
import typing
from enum import IntFlag
from typing import ( # type: ignore[attr-defined]
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Expand All @@ -22,7 +20,6 @@
Protocol,
Sequence,
TypeVar,
_GenericAlias,
)

import onnx
Expand All @@ -36,9 +33,6 @@
from onnxscript.ir import _schemas
from onnxscript.onnx_types import ONNXType

if TYPE_CHECKING:
from onnxscript._internal.type_annotation import TypeAnnotationValue

_R = TypeVar("_R")
_P = ParamSpec("_P")

Expand Down Expand Up @@ -853,61 +847,28 @@ def ThresholdedRelu(X, alpha: float):
* To represent constant-values, translated into ONNX constants.
"""

def __init__(self, info: sourceinfo.SourceInfo) -> None:
def __init__(self, value: Any, info: sourceinfo.SourceInfo) -> None:
"""
Initializes SymbolValue.

Arguments:
value: The value bound to a python variable in a script.
info: source-location information for error-messages/debugging
"""
if not isinstance(info, sourceinfo.SourceInfo):
raise TypeError(f"info must be of type sourceinfo.SourceInfo not {type(info)!r}.")
self.value = value
self.info = info


class AttrRef(SymbolValue):
def __init__(
self, attr_name: str, typeinfo: _GenericAlias, info: sourceinfo.SourceInfo
) -> None:
def __init__(self, attr: ir.Attr, as_bool: bool, info: sourceinfo.SourceInfo) -> None:
"""Initializes AttrRef.

Arguments:
attr_name: name of the attribute-parameter
typeinfo: type annotation of the attribute.
op's attributes in ONNX are usually single type or list of single type.
attr: An ir.Attr representing the attribute-parameter
as_bool: Whether the attribute is to be interpreted as a bool type (represented as int in ONNX)
info: for debugging use.
"""
super().__init__(info)
self.value = attr_name
self.typeinfo = typeinfo
if not isinstance(typeinfo, (type, _GenericAlias)):
# typing._GenericAlias for List[int] and List[str], etc.
raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.")
self.typeinfo = typeinfo


class DynamicKind(IntFlag):
Unknown = 0
Input = 1
Output = 2
Intermediate = 4
Loop = 8


class Dynamic(SymbolValue):
def __init__(
self,
onnx_var: ir.Value,
kind: DynamicKind,
info: sourceinfo.SourceInfo,
typeinfo: TypeAnnotationValue | None = None,
) -> None:
"""Represents an ir.Value with some extra information.

Arguments:
onnx_var: the name of the ONNX variable used to represent this value
kind: the DynamicKind of this variable
info: source-location information for error-messages/debugging
typeinfo: type-information for the value
"""
super().__init__(info)
assert isinstance(kind, DynamicKind)
if not isinstance(onnx_var, ir.Value):
raise TypeError(f"onnx_var must be of type ir.Value not {type(onnx_var)!r}.")
self.value = onnx_var
self.kind = kind
self.typeinfo = typeinfo
super().__init__(attr, info)
self.as_bool = as_bool
4 changes: 0 additions & 4 deletions onnxscript/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@

from onnxscript._internal.values import (
AttrRef,
Dynamic,
DynamicKind,
OnnxClosure,
OnnxFunction,
Op,
Expand All @@ -21,8 +19,6 @@

__all__ = [
"AttrRef",
"Dynamic",
"DynamicKind",
"OnnxClosure",
"OnnxFunction",
"Op",
Expand Down
Loading