diff --git a/onnxscript/_internal/converter.py b/onnxscript/_internal/converter.py index 82388a19cf..d9c367115a 100644 --- a/onnxscript/_internal/converter.py +++ b/onnxscript/_internal/converter.py @@ -339,22 +339,20 @@ def tensor_name_generator() -> str: def _to_onnx_attr_ref( self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo] ) -> ir.Attr: - pytype = val.typeinfo - attrtype = ta.pytype_to_attrtype(pytype) + attrtype = val.value.type attrname = None - if attrtype is onnx.AttributeProto.FLOAT: + if attrtype is ir.AttributeType.FLOAT: # onnx.AttributeProto.FLOAT: attrname = "value_float" - elif attrtype is onnx.AttributeProto.INT: + elif attrtype is ir.AttributeType.INT: attrname = "value_int" - elif attrtype is onnx.AttributeProto.STRING: + elif attrtype is ir.AttributeType.STRING: attrname = "value_string" - elif attrtype is onnx.AttributeProto.INTS: + elif attrtype is ir.AttributeType.INTS: attrname = "value_ints" else: - msg = f"Unsupported attribute type {pytype!r}." + msg = f"Unsupported attribute type {attrtype!r}." fail(info.msg(msg) if info else msg) - attr_type = ir.AttributeType(ta.pytype_to_attrtype(pytype)) - return ir.Attr(attrname, attr_type, None, val.value) + return ir.Attr(attrname, attrtype, value=None, ref_attr_name=val.value.name) def _to_onnx_var( self, @@ -369,7 +367,7 @@ def _to_onnx_var( result = self.emit( [result_name], values.Op(self.default_opset, "Constant"), [], [attr] ) - if ta.base_type_is_bool(val.typeinfo): + if val.as_bool: # ONNX attributes use an int-encoding for bools, but ONNX tensor types # distinguish between int and bool. So we cast the int tensor to a bool tensor, # to promote a (python) bool attribute to a ONNX bool tensor. @@ -384,8 +382,9 @@ def _to_onnx_var( ) self._castable.add(result_name) return result - if isinstance(val, values.Dynamic): - return val.value + if isinstance(val, values.SymbolValue): + if isinstance(val.value, ir.Value): + return val.value # Assume value is a python-value convertible to a tensor # TODO: check if value is convertible to a TensorProto, so that we can # produce a better error _message otherwise @@ -534,29 +533,44 @@ def _translate_attr( if isinstance(expr, ast.Name): val = self._lookup(expr.id, self._source_of(expr)) - if isinstance(val, values.AttrRef): - attr_type = ir.AttributeType(ta.pytype_to_attrtype(val.typeinfo)) - attr_ref = ir.Attr(attr_name, attr_type, None, val.value) - if attr_meta is not None and (attr_ref.type != attr_meta.type): - self.fail( - expr, - f"Attribute type '{attr_ref.type}' does not match expected type '{attr_meta.type}'", + if isinstance(val, values.SymbolValue): + val = val.value + if isinstance(val, ir.Attr): + # A reference to an attribute parameter: + attr = val + attr_ref = ir.Attr( + attr_name, attr.type, value=None, ref_attr_name=attr.name ) - return attr_ref - if isinstance(val, irbuilder.IRFunction): - # Check that outer-scope variables referenced by function have same value - # at function-definition site and use-as-attribute site, to avoid errors. - for pyvar, previous in val.outer_scope_variables: - current = self._lookup(pyvar, self._source_of(expr)) - if current.value != previous.value: + if attr_meta is not None and (attr.type != attr_meta.type): self.fail( expr, - f"Outer scope variable '{pyvar}' referenced by function " - f"'{expr.id!r}' modified.", + f"Attribute type '{attr_ref.type}' does not match expected type '{attr_meta.type}'", ) - - # Create GraphProto attribute - val = val.to_graph_proto() + return attr_ref + if isinstance(val, irbuilder.IRFunction): + # A reference to a nested-function: convert to GraphProto and use it. + irfunction = val + # Check that outer-scope variables referenced by function have same value + # at function-definition site and use-as-attribute site, to avoid errors. + for pyvar, previous in irfunction.outer_scope_variables: + current = self._lookup(pyvar, self._source_of(expr)) + if current.value != previous.value: + self.fail( + expr, + f"Outer scope variable '{pyvar}' referenced by function " + f"'{expr.id!r}' modified.", + ) + # Create GraphProto attribute + val = irfunction.to_graph_proto() + if isinstance(val, ir.Value): + self.fail(expr, f"Cannot use ir.Value '{expr.id}' as an attribute.") + else: + # Treat as a constant python-value, to be converted below. + pass + else: + # This must be a reference to an outer-scope python-value, typically a constant. + # The value will be converted to an ONNX attribute value below. + pass else: val = self._eval_constant_expr(expr) @@ -1045,7 +1059,7 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None: typeinfo = None if typeinfo is not None: set_type_info(t, typeinfo) - var = values.Dynamic(t, values.DynamicKind.Intermediate, info, typeinfo) + var = values.SymbolValue(t, info) self._bind(lhs, var) elif isinstance(lhs, ast.Tuple): # Assignments of the form "x, y, z = op.SomeOp(...)" @@ -1068,9 +1082,7 @@ def generate_onnx_name(x: ast.AST): for x, output in zip(lhs.elts, outputs): self._bind( x.id, - values.Dynamic( - output, values.DynamicKind.Intermediate, self._source_of(x) - ), + values.SymbolValue(output, self._source_of(x)), ) else: self.fail(lhs, f"Unsupported construct in LHS of assignment: '{type(lhs)!r}'") @@ -1117,10 +1129,11 @@ def ret(exp, i, suffix): preferred_name = f"return_val{suffix}" return_var = self._translate_expr(exp, preferred_name) # TODO(rama) val = self._lookup(return_var.name, self._source_of(exp), False) - if val and val.kind == values.DynamicKind.Input: - # In ONNX, a graph-input cannot be an output of the graph. - # We need to insert a copy. - return_var = self._emit_copy(return_var, preferred_name) + if isinstance(val, values.SymbolValue) and isinstance(val.value, ir.Value): + if val.value.is_graph_input(): + # In ONNX, a graph-input cannot be an output of the graph. + # We need to insert a copy. + return_var = self._emit_copy(return_var, preferred_name) for prev_output in self._current_fn.outputs: if prev_output.name == return_var.name: # ONNX does not allow duplicate output names. @@ -1190,7 +1203,7 @@ def rename(x): for x, y in zip(live_defs, if_outputs): self._bind( x, - values.Dynamic(y, values.DynamicKind.Intermediate, self._source_of(stmt)), + values.SymbolValue(y, self._source_of(stmt)), ) def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: @@ -1257,7 +1270,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: self._current_fn.append_parameter(onnx_loop_var) self._bind( python_loop_var_name, - values.Dynamic(onnx_loop_var, values.DynamicKind.Loop, self._source_of(loop_stmt)), + values.SymbolValue(onnx_loop_var, self._source_of(loop_stmt)), ) self._current_fn.append_parameter( @@ -1278,9 +1291,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ) self._bind( pv, - values.Dynamic( + values.SymbolValue( ir.Value(name=onnx_var_name), - values.DynamicKind.Loop, self._source_of(loop_stmt), ), ) @@ -1376,7 +1388,7 @@ def rename(x): if isinstance(loop_outputs, ir.Value): loop_outputs = [loop_outputs] for x, loop_output in zip(outputs, loop_outputs): - self._bind(x, values.Dynamic(loop_output, values.DynamicKind.Output, info)) + self._bind(x, values.SymbolValue(loop_output, info)) def _translate_block( self, @@ -1431,7 +1443,7 @@ def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None: function_ir.outer_scope_variables = [ (var, self._lookup(var, self._source_of(fn))) for var in outer_scope_vars ] - self._bind(fn.name, function_ir) + self._bind(fn.name, values.SymbolValue(function_ir, self._source_of(fn))) # TODO: Does not yet handle nested functions within nested functions. self._current_fn.add_nested_function(function_ir) @@ -1459,16 +1471,15 @@ def _translate_function_signature_common( attribute_type = ta.pytype_to_attrtype(typeinfo) attr = ir.Attr(x.arg, ir.AttributeType(attribute_type), default_value, None) self._current_fn.append_parameter(attr) - self._bind(x.arg, values.AttrRef(x.arg, typeinfo, self._source_of(x))) + as_bool = ta.base_type_is_bool(typeinfo) + self._bind(x.arg, values.AttrRef(attr, as_bool, self._source_of(x))) else: onnx_parameter = make_value(x.arg, typeinfo, self._source_of(x)) self._current_fn.append_parameter(onnx_parameter) self._used_vars.add(x.arg) self._bind( x.arg, - values.Dynamic( - onnx_parameter, values.DynamicKind.Input, self._source_of(x) - ), + values.SymbolValue(onnx_parameter, self._source_of(x)), ) if fn.returns: type_annotation = self._eval_constant_expr(fn.returns) diff --git a/onnxscript/_internal/values.py b/onnxscript/_internal/values.py index f9f0958de3..b6e0b33eb6 100644 --- a/onnxscript/_internal/values.py +++ b/onnxscript/_internal/values.py @@ -11,9 +11,7 @@ import logging import types import typing -from enum import IntFlag from typing import ( # type: ignore[attr-defined] - TYPE_CHECKING, Any, Callable, ClassVar, @@ -22,7 +20,6 @@ Protocol, Sequence, TypeVar, - _GenericAlias, ) import onnx @@ -36,9 +33,6 @@ from onnxscript.ir import _schemas from onnxscript.onnx_types import ONNXType -if TYPE_CHECKING: - from onnxscript._internal.type_annotation import TypeAnnotationValue - _R = TypeVar("_R") _P = ParamSpec("_P") @@ -853,61 +847,28 @@ def ThresholdedRelu(X, alpha: float): * To represent constant-values, translated into ONNX constants. """ - def __init__(self, info: sourceinfo.SourceInfo) -> None: + def __init__(self, value: Any, info: sourceinfo.SourceInfo) -> None: + """ + Initializes SymbolValue. + + Arguments: + value: The value bound to a python variable in a script. + info: source-location information for error-messages/debugging + """ if not isinstance(info, sourceinfo.SourceInfo): raise TypeError(f"info must be of type sourceinfo.SourceInfo not {type(info)!r}.") + self.value = value self.info = info class AttrRef(SymbolValue): - def __init__( - self, attr_name: str, typeinfo: _GenericAlias, info: sourceinfo.SourceInfo - ) -> None: + def __init__(self, attr: ir.Attr, as_bool: bool, info: sourceinfo.SourceInfo) -> None: """Initializes AttrRef. Arguments: - attr_name: name of the attribute-parameter - typeinfo: type annotation of the attribute. - op's attributes in ONNX are usually single type or list of single type. + attr: An ir.Attr representing the attribute-parameter + as_bool: Whether the attribute is to be interpreted as a bool type (represented as int in ONNX) info: for debugging use. """ - super().__init__(info) - self.value = attr_name - self.typeinfo = typeinfo - if not isinstance(typeinfo, (type, _GenericAlias)): - # typing._GenericAlias for List[int] and List[str], etc. - raise TypeError(f"Expecting a type not f{type(typeinfo)} for typeinfo.") - self.typeinfo = typeinfo - - -class DynamicKind(IntFlag): - Unknown = 0 - Input = 1 - Output = 2 - Intermediate = 4 - Loop = 8 - - -class Dynamic(SymbolValue): - def __init__( - self, - onnx_var: ir.Value, - kind: DynamicKind, - info: sourceinfo.SourceInfo, - typeinfo: TypeAnnotationValue | None = None, - ) -> None: - """Represents an ir.Value with some extra information. - - Arguments: - onnx_var: the name of the ONNX variable used to represent this value - kind: the DynamicKind of this variable - info: source-location information for error-messages/debugging - typeinfo: type-information for the value - """ - super().__init__(info) - assert isinstance(kind, DynamicKind) - if not isinstance(onnx_var, ir.Value): - raise TypeError(f"onnx_var must be of type ir.Value not {type(onnx_var)!r}.") - self.value = onnx_var - self.kind = kind - self.typeinfo = typeinfo + super().__init__(attr, info) + self.as_bool = as_bool diff --git a/onnxscript/values.py b/onnxscript/values.py index 0b8bd2519a..2759dd5fdf 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -7,8 +7,6 @@ from onnxscript._internal.values import ( AttrRef, - Dynamic, - DynamicKind, OnnxClosure, OnnxFunction, Op, @@ -21,8 +19,6 @@ __all__ = [ "AttrRef", - "Dynamic", - "DynamicKind", "OnnxClosure", "OnnxFunction", "Op",