diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py
index 1defac3e53..0afbd36d60 100644
--- a/onnxscript/_internal/autocast.py
+++ b/onnxscript/_internal/autocast.py
@@ -189,30 +189,27 @@ def get_type_info(x):
 def static_cast_inputs(
     converter_: converter.Converter,
     op_schema: Optional[OpSchema],
-    args: Sequence[Optional[converter.Variable]],
-) -> tuple[str, ...]:
+    args: Sequence[Optional[ir.Value]],
+) -> tuple[ir.Value, ...]:
     """Used for autocast during script-translation.
 
     This is meant to transform expressions like "Add(X, 1)" to "Add(X, CastLike(1, X))"
     Polymorphic constants (like 0 and 1) are cast to the type of other operands as needed.
     """
 
-    def get_type_info(x: Optional[converter.Variable]) -> Optional[converter.Variable]:
+    def get_type_info(x: Optional[ir.Value]) -> Optional[ir.Value]:
         """Returns x back if x can serve as the target-type for a cast (as the second
         argument of CastLike) and None otherwise.
 
-        In the expression "Add(X, 1), 1 is castable, while X can serve as the target-type.
+        In the expression "Add(X, 1)", 1 is castable, while X can serve as the target-type.
         """
-        return None if x is None or x.is_castable else x
+        return None if x is None or converter_.is_castable(x.name) else x
 
-    def cast_like(
-        x: Optional[converter.Variable], y: Optional[converter.Variable]
-    ) -> Optional[str]:
+    def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[ir.Value]:
         if x is None:
             return None
-        if x.is_castable and y is not None:
+        if converter_.is_castable(x.name) and y is not None:
             # Polymorphic constant x is cast to the type of y:
             x_cast = converter_.generate_unique_name(f"{x.name}_cast")
-            converter_.emit([x_cast], "CastLike", [x.name, y.name])
-            return x_cast
-        return x.name
+            return converter_.emit1([x_cast], "CastLike", [x, y])
+        return x
 
     return cast_inputs(get_type_info, cast_like, op_schema, args)
diff --git a/onnxscript/_internal/utils.py b/onnxscript/_internal/utils.py
index ce2b657cfd..c5e05b3d92 100644
--- a/onnxscript/_internal/utils.py
+++ b/onnxscript/_internal/utils.py
@@ -7,6 +7,7 @@
 
 import numpy as np
 import onnx
+import onnx_ir as ir
 
 from onnxscript import tensor
 
@@ -87,6 +88,41 @@ def value_to_type_proto(val):
     raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.")
 
 
+def value_to_type(val):
+    """Return a (type, shape) pair giving the ONNX IR type of a python-value."""
+    if isinstance(val, (np.ndarray, tensor.Tensor)):
+        elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype)  # noqa: TID251
+        shape = val.shape
+        return (ir.TensorType(elem_type), shape)
+    elif isinstance(val, int):
+        elem_type = onnx.TensorProto.INT32
+        shape = []
+        return (ir.TensorType(elem_type), shape)
+    elif isinstance(val, (float, np.float32)):
+        elem_type = onnx.TensorProto.FLOAT
+        shape = []
+        return (ir.TensorType(elem_type), shape)
+    elif isinstance(val, list):
+        if len(val) > 0:
+            type, shape = value_to_type(val[0])
+            return ir.SequenceType(type), shape
+        # Edge-case. Cannot determine a suitable ONNX type for an empty list.
+        # Should be using a typed-value instead.
+        # Treated as a sequence of tensors of float-type.
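+        # Editor's illustration of the resulting mapping (derived from the
+        # branches above; not part of the original patch):
+        #   value_to_type([])     -> (SequenceType(TensorType(FLOAT)), None)
+        #   value_to_type([1, 2]) -> (SequenceType(TensorType(INT32)), [])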
+ return ir.SequenceType(ir.TensorType(onnx.TensorProto.FLOAT)), None + if isinstance(val, numbers.Number): + nparray = np.array(val) + elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype) # noqa: TID251 + return ir.TensorType(elem_type), [] + raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.") + + +def value_to_ir_value(name: str, val) -> ir.Value: + """Return an ir.Value representation of a python-value.""" + type, shape = value_to_type(val) + return ir.Value(name=name, type=type, shape=shape) + + def values_to_value_infos(name_values): """Create a list of ValueInfoProto from a list of (name, value) pairs, skipping any None values. @@ -96,3 +132,10 @@ def values_to_value_infos(name_values): for (name, val) in name_values if val is not None ] + + +def values_to_ir_values(name_values): + """Create a list of ir.Value from a list of (name, value) pairs, + skipping any None values. + """ + return [value_to_ir_value(name, val) for (name, val) in name_values if val is not None] diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 3e87c366ad..abef235092 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -17,6 +17,7 @@ ) import onnx +import onnx_ir as ir import onnxscript from onnxscript import irbuilder, onnx_types, sourceinfo, values @@ -80,30 +81,6 @@ def ignore(cond, msg): } -class Variable: - """Represents an ONNX variable. - - TODO(rama): Consider merging this with IRVar. However, "castable" is specific to this - converter. - """ - - def __init__(self, name: str, castable: bool = False): - """Initialize the instance. - - Args: - name: Name of the ONNX variable - castable: Whether this variable is castable to a desired target type. - Used for ONNX variables representing constants created from python values - like 0 or 1 or 0.5 which are treated as polymorphic values castable to other - types as needed. - """ - self.name = name - self.is_castable = castable - - def __str__(self) -> str: - return self.name - - if TYPE_CHECKING: # The type-alias LocalSymValue represents the types of values that local names in a # script-function may be bound to during translation, (ONNX IR values). @@ -136,6 +113,28 @@ def __str__(self) -> str: OnnxVarName = str +def set_type_info(value: ir.Value, typeinfo: ta.TypeAnnotationValue) -> None: + """Sets the type information on an IR value.""" + try: + type_and_shape = ir.from_proto(typeinfo.to_type_proto()) + value.type = type_and_shape.type + value.shape = type_and_shape.shape + except AttributeError: + # TODO: This needs to be fixed. + pass + value.meta["typeinfo"] = typeinfo + + +def make_value( + varname: str, typeinfo: ta.TypeAnnotationValue, source_info: sourceinfo.SourceInfo +) -> ir.Value: + value = ir.Value(name=varname) + value.meta.setdefault("sourceinfo", source_info) + if typeinfo is not None: + set_type_info(value, typeinfo) + return value + + class Converter: """Main class to translate python code into ONNX operators. @@ -163,13 +162,11 @@ class :class:`onnxscript.irbuilder.IRBuilder` is used def __init__( self, - ir_builder: Optional[irbuilder.IRBuilder] = None, opset: Optional[values.Opset] = None, global_names: Optional[dict[str, Any]] = None, source: Optional[str] = None, default_opset: Optional[values.Opset] = None, ): - self.ir_builder = ir_builder or irbuilder.IRBuilder() self.source = source if global_names is not None: # We make a copy in case function eval modifies it. 
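[Editor's note] With the `Variable` wrapper removed, the "castable" flag no longer travels with the value itself; the converter records it by name in the `_castable` set introduced in the next hunk. A rough before/after sketch, using only names that appear in this patch (not runnable on its own):

    # before: the flag lived on the wrapper object
    v = Variable(onnx_name, castable=True)
    if v.is_castable:
        ...

    # after: plain ir.Value, with the converter tracking castability by name
    self._castable.add(value.name)
    if self.is_castable(value.name):
        ...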
@@ -184,6 +181,11 @@ def __init__( self._used_vars: set[str] = set() self._locals: List[Dict[str, LocalSymValue]] = [{}] self._analyzer: analysis.AstAnalyzer | None = None + self._castable: set[str] = set() + + def is_castable(self, var_name: str) -> bool: + """Returns True if the variable with the given name represents a polymorphic constant.""" + return var_name in self._castable @property def analyzer(self) -> analysis.AstAnalyzer: @@ -264,7 +266,7 @@ def _enter_scope(self, name: str, parent_node: ast.AST): The block is translated into a nested-scope in ONNX. """ self._outer.insert(0, self._current_fn) - self._current_fn = self.ir_builder.new_function(name) + self._current_fn = irbuilder.IRFunction(name) self._locals.insert(0, {}) logger.debug("Converter:_enter_scope:%d:node:%s", len(self._locals), type(parent_node)) @@ -306,7 +308,10 @@ def generate_unique_name(self, candidate: str = "tmp") -> str: def _make_onnx_attr( self, attrname: str, attrval: Any, attrtype: int | None = None - ) -> irbuilder.IRAttributeValue: + ) -> ir.Attr: + if isinstance(attrval, ir.Graph): + return ir.Attr(attrname, ir.AttributeType.GRAPH, attrval) + def tensor_name_generator() -> str: """Return name to be used for tensor, if we need to create one.""" return self.generate_unique_name(f"attr_{attrname}") @@ -314,11 +319,11 @@ def tensor_name_generator() -> str: proto = autocast.pyvalue_to_onnx_attribute( attrname, attrval, tensor_name_generator, attrtype ) - return self.ir_builder.make_attr(proto) + return ir.from_proto(proto) def _to_onnx_attr_ref( self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo] - ) -> irbuilder.IRAttributeValue: + ) -> ir.Attr: pytype = val.typeinfo attrtype = ta.pytype_to_attrtype(pytype) attrname = None @@ -333,72 +338,86 @@ def _to_onnx_attr_ref( else: msg = f"Unsupported attribute type {pytype!r}." fail(info.msg(msg) if info else msg) - return self.ir_builder.make_attr_ref(attrname, val.value, pytype) + attr_type = ir.AttributeType(ta.pytype_to_attrtype(pytype)) + return ir.Attr(attrname, attr_type, None, val.value) def _to_onnx_var( self, val: values.SymbolValue | PyValue, target: Optional[PreferredName] = None, info: Optional[sourceinfo.SourceInfo] = None, - ) -> Variable: + ) -> ir.Value: if isinstance(val, values.AttrRef): # promote attribute to value - result = self.generate_unique_name(target or "tmp") + result_name = self.generate_unique_name(target or "tmp") attr = self._to_onnx_attr_ref(val, info) - self.emit([result], values.Op(self.default_opset, "Constant"), [], [attr]) + result = self.emit( + [result_name], values.Op(self.default_opset, "Constant"), [], [attr] + ) if ta.base_type_is_bool(val.typeinfo): # ONNX attributes use an int-encoding for bools, but ONNX tensor types # distinguish between int and bool. So we cast the int tensor to a bool tensor, # to promote a (python) bool attribute to a ONNX bool tensor. 
- result_as_bool = self.generate_unique_name(result + "_as_bool") + result_as_bool = self.generate_unique_name(result_name + "_as_bool") cast_attr = self._make_onnx_attr("to", onnx_types.BOOL.dtype) - self.emit( + self._castable.add(result_as_bool) + return self.emit1( [result_as_bool], values.Op(self.default_opset, "Cast"), [result], [cast_attr], ) - return Variable(result_as_bool, True) - return Variable(result, True) + self._castable.add(result_name) + return result if isinstance(val, values.Dynamic): - return Variable(val.value) + return val.value # Assume value is a python-value convertible to a tensor # TODO: check if value is convertible to a TensorProto, so that we can # produce a better error _message otherwise return self._emit_const(val, target or "tmp", info) - def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> Variable: + def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> ir.Value: return self._to_onnx_var(self._lookup(py_var, info), target=py_var, info=info) def emit( self, outputs: Sequence[str], callee: values.Op | str, - inputs: Sequence[Optional[str]], + inputs: Sequence[Optional[ir.Value]], attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None, - sub_functions: Optional[dict[str, onnx.FunctionProto]] = None, - ): + ) -> Sequence[ir.Value] | ir.Value: if not isinstance(callee, values.Op): callee = values.Op(self.default_opset, callee) if attrs is None: attrs = [] - if sub_functions is None: - sub_functions = {} - self.ir_builder.add_stmt( - self._current_fn, - outputs, - callee, - inputs, - attrs, - sub_functions, + output_values = [ir.Value(name=o) for o in outputs] + node = ir.Node( + domain=callee.opset.domain, + version=callee.opset.version, + op_type=callee.name, + inputs=inputs, + outputs=output_values, + attributes=attrs, ) + if not isinstance(callee, values.Op): + raise TypeError(f"Unexpected type {type(callee)} for callee.") + node.meta.setdefault("callee", callee) + self._current_fn.append_node(node) + + return output_values if len(output_values) > 1 else output_values[0] + + def emit1(self, *args, **kwargs) -> ir.Value: + r = self.emit(*args, **kwargs) + if not isinstance(r, ir.Value): + raise TypeError(f"Expected single ONNX IR Value, got {type(r)!r}.") + return r def _emit_const( self, pyvalue: PyValue, suggested_name: Optional[PreferredName], info: sourceinfo.SourceInfo, - ) -> Variable: + ) -> ir.Value: if suggested_name is None: if isinstance(pyvalue, int): if pyvalue >= 0: @@ -420,14 +439,13 @@ def _emit_const( except ValueError as e: fail(info.msg(str(e))) attr = self._make_onnx_attr("value", tensor) - self.emit([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) - return Variable(ovar, True) + self._castable.add(ovar) + return self.emit1([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) - def _emit_copy(self, original_var: str, suggested_name: str) -> str: + def _emit_copy(self, original_var: ir.Value, suggested_name: str) -> ir.Value: """Emits a copy statement, using the ONNX Identity operator.""" new_var = self.generate_unique_name(suggested_name) - self.emit([new_var], "Identity", [original_var]) - return new_var + return self.emit([new_var], "Identity", [original_var]) def _is_constant_expr(self, node: ast.AST) -> None: if isinstance(node, ast.UnaryOp): @@ -474,6 +492,16 @@ def _eval_constant_expr(self, expr: ast.AST) -> PyValue: ) ) from e + def _get_type_annotation(self, annotation: ast.Expr) -> Optional[ta.TypeAnnotationValue]: + typeinfo = 
self._eval_constant_expr(annotation)
+        if not ta.is_valid_type(typeinfo):
+            self.warn(
+                annotation,
+                "Unsupported type annotation.",
+            )
+            typeinfo = None
+        return typeinfo
+
     def _translate_attr(
         self,
         attr_name: str,
@@ -492,7 +520,8 @@ def _translate_attr(
         if isinstance(expr, ast.Name):
             val = self._lookup(expr.id, self._source_of(expr))
             if isinstance(val, values.AttrRef):
-                attr_ref = self.ir_builder.make_attr_ref(attr_name, val.value, val.typeinfo)
+                attr_type = ir.AttributeType(ta.pytype_to_attrtype(val.typeinfo))
+                attr_ref = ir.Attr(attr_name, attr_type, None, val.value)
                 if attr_meta is not None and (attr_ref.type != attr_meta.type):
                     self.fail(
                         expr,
@@ -537,14 +566,15 @@ def _translate_docstring(self, node: ast.Expr) -> None:
         if hasattr(node.value, "value"):
             # python 3.8+
-            return self.ir_builder.add_docstring(self._current_fn, node.value.value)
-        raise TypeError(
-            f"Unexpected type {type(node)!r} for node. Unsupoorted version of python."
-        )
+            self._current_fn.doc_string = node.value.value
+        else:
+            raise TypeError(
+                f"Unexpected type {type(node)!r} for node. Unsupported version of Python."
+            )
 
     def _translate_expr(
         self, node: ast.AST, target: Optional[PreferredName] = None
-    ) -> Variable:
+    ) -> ir.Value:
         """Expression-translation generates "IR statements/nodes" that compute the value of
         the expression into a target-variable, and returns the variable that is
         assigned this value.
         """
@@ -567,16 +597,15 @@
             raise ValueError(
                 self._message(node, f"Unsupported expression type {type(node)!r}.")
             )
-        if isinstance(r, Variable):
+        if isinstance(r, ir.Value):
             return r
         callee, args, attrs = r
         target = "tmp" if target is None else target
         assert isinstance(target, str)
         result = self.generate_unique_name(target)
-        self.emit([result], callee, args, attrs)
-        return Variable(result)
+        return self.emit1([result], callee, args, attrs)
 
-    def _translate_opt_expr(self, node: ast.expr) -> Optional[Variable]:
+    def _translate_opt_expr(self, node: ast.expr) -> Optional[ir.Value]:
         """Translation of an expression where "None" is permitted (eg., for an optional argument).
         None is represented as a Constant in Python 3.9+.
         """
@@ -586,7 +615,7 @@
 
     def _translate_subscript_expr(
         self, node: ast.Subscript, target: Optional[PreferredName]
-    ) -> Variable:
+    ) -> ir.Value:
         """List of supported syntaxes is below.
 
         `A` is a tensor or an expression equivalent to a tensor.
@@ -633,15 +662,15 @@
 
         # Create cached int constants:
         # TODO: Do this at a graph-scope level.
-        cached_int_consts = {}
+        cached_int_consts: dict[int, ir.Value] = {}
 
-        def const_1d(value, name: Optional[str] = None):
+        def const_1d(value, name: Optional[str] = None) -> ir.Value:
             nonlocal cached_int_consts
             if value not in cached_int_consts:
                 cached_int_consts[value] = self._emit_const([value], name, info)
             return cached_int_consts[value]
 
-        def one_1d():
+        def one_1d() -> ir.Value:
             return const_1d(1)
 
         # Max/min 64-bit int values are used to represent default values for start/stop in Slice.
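[Editor's note] `const_1d` memoizes the 1-D INT64 constants used as Slice bounds, so a subscript such as `A[1:3]` emits one `Constant` node per distinct bound within the expression. A sketch of the emitted pattern (names illustrative, not part of the patch):

    start = const_1d(1)   # emits Constant(value=[1]) on first use, cached afterwards
    end = const_1d(3)     # emits Constant(value=[3])
    # ...the cached values are later combined into Slice(A, start, end, axes, steps)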
@@ -650,7 +679,7 @@ def one_1d():
 
         def translate_slice_component(
             node_arg, default_value: Optional[int] = None
-        ) -> tuple[str, Optional[int]]:
+        ) -> tuple[ir.Value, Optional[int]]:
             """Translate optional start/stop/step component of a Slice expression."""
             if node_arg is None:
                 if default_value is None:
@@ -667,17 +696,17 @@ def translate_slice_component(
                 else:
                     raise RuntimeError(f"Slice component type must be int, not {type(cst)}")
             else:
-                name = self._translate_expr(node_arg).name
-                reshaped = self.generate_unique_name(f"{name}_reshaped")
-                self.emit(
+                value = self._translate_expr(node_arg)
+                reshaped = self.generate_unique_name(f"{value.name}_reshaped")
+                reshaped_value = self.emit1(
                     [reshaped],
                     values.Op(self.default_opset, "Reshape"),
-                    [name, one_1d().name],
+                    [value, one_1d()],
                     [],
                 )
-                return reshaped, None
+                return reshaped_value, None
 
-        def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
+        def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value]:
             """Translate slice-expression of the form from:to:step."""
             step_name, step = translate_slice_component(slice_expr.step, 1)
             if step is None:
@@ -720,8 +749,8 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
                     non_scalar_indices.append((axis, elt))
         if not (sliced_indices or scalar_indices or non_scalar_indices):
             # Edge case: no index specified. Eg. A[:, :]
-            self.emit([target], "Identity", [var_name])
-            return Variable(target)
+            return self.emit1([target], "Identity", [var])
+
         if sliced_indices or len(scalar_indices) > 1:
             # We emit a Slice operation if we have any indices like 1:5:2 or if the number of
             # scalar indices (like 2) is more than 1.
@@ -751,52 +780,52 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]:
                     inputs = translate_slice(element)
                     starts.append(inputs[0])
                     ends.append(inputs[1])
-                    axes.append(axis_var.name)
+                    axes.append(axis_var)
                     steps.append(inputs[2])
 
             if len(starts) > 1:
                 axis_0_attr = self._make_onnx_attr("axis", 0)
                 start_name = self.generate_unique_name(f"{var_name}_start")
-                self.emit([start_name], "Concat", starts, [axis_0_attr])
+                start_value = self.emit([start_name], "Concat", starts, [axis_0_attr])
                 end_name = self.generate_unique_name(f"{var_name}_end")
-                self.emit([end_name], "Concat", ends, [axis_0_attr])
+                end_value = self.emit([end_name], "Concat", ends, [axis_0_attr])
                 axes_name = self.generate_unique_name(f"{var_name}_axis")
-                self.emit([axes_name], "Concat", axes, [axis_0_attr])
+                axes_value = self.emit([axes_name], "Concat", axes, [axis_0_attr])
                 steps_name = self.generate_unique_name(f"{var_name}_step")
-                self.emit([steps_name], "Concat", steps, [axis_0_attr])
+                steps_value = self.emit([steps_name], "Concat", steps, [axis_0_attr])
             else:
-                start_name = starts[0]
-                end_name = ends[0]
-                axes_name = axes[0]
-                steps_name = steps[0]
+                start_value = starts[0]
+                end_value = ends[0]
+                axes_value = axes[0]
+                steps_value = steps[0]
 
             if squeezed_axes:
                 sliced_name = self.generate_unique_name(f"{var_name}_sliced")
-                self.emit(
+                sliced_value = self.emit(
                     [sliced_name],
                     "Slice",
-                    [var_name, start_name, end_name, axes_name, steps_name],
+                    [var, start_value, end_value, axes_value, steps_value],
                 )
                 squeezed_axes = self._emit_const(squeezed_axes, "squeezed_axes", info)
 
                 if non_scalar_indices:  # use temporary to store result of squeeze
-                    result = self.generate_unique_name(f"{var_name}_squeezed")
+                    result_name = self.generate_unique_name(f"{var_name}_squeezed")
                 else:  # store squeezed result in final target
-                    result = target
+                    result_name = target
 
-                
self.emit([result], "Squeeze", [sliced_name, squeezed_axes]) + result = self.emit([result_name], "Squeeze", [sliced_value, squeezed_axes]) else: if non_scalar_indices: # use temporary to store result of Slice - result = self.generate_unique_name(f"{var_name}_sliced") + result_name = self.generate_unique_name(f"{var_name}_sliced") else: # store result of Slice in final target - result = target - slice_inputs = [var_name, start_name, end_name, axes_name, steps_name] - self.emit([result], "Slice", slice_inputs) + result_name = target + slice_inputs = [var, start_value, end_value, axes_value, steps_value] + result = self.emit1([result_name], "Slice", slice_inputs) else: - result = var_name + result = var non_scalar_indices.extend(scalar_indices) if non_scalar_indices: last_axis, _ = non_scalar_indices[-1] @@ -812,12 +841,13 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: gathered = self.generate_unique_name(f"{var_name}_axis_{axis}") else: # store result of Gather in final target gathered = target - self.emit([gathered], "Gather", [str(result), index_value], [axis_attr]) - result = gathered + result = self.emit1([gathered], "Gather", [result, index_value], [axis_attr]) - return Variable(result) + return result - def _translate_call_expr(self, node: ast.Call): + def _translate_call_expr( + self, node: ast.Call + ) -> tuple[values.Op, list[Optional[ir.Value]], list[irbuilder.IRAttributeValue]]: """Translates a call-expression.""" callee = self._translate_callee_expr(node.func) param_schemas = callee.param_schemas() @@ -844,7 +874,7 @@ def _translate_call_expr(self, node: ast.Call): attrs = [attr for attr in attrs if attr is not None] return callee, args, attrs - def _cast_like_binary_expression(self, op, left, right): + def _cast_like_binary_expression(self, op, left, right) -> tuple[ir.Value, ir.Value]: schema = op.op_schema return autocast.static_cast_inputs(self, schema, (left, right)) @@ -911,13 +941,13 @@ def _translate_compare_expr(self, node): left, right = self._cast_like_binary_expression(op, left, right) if opname == "NotEqual": tmp = self.generate_unique_name() - self.emit([tmp], op, [left, right]) + tmp_value = self.emit1([tmp], op, [left, right]) not_op = values.Op(self.default_opset, "Not") - return not_op, [tmp], [] + return not_op, [tmp_value], [] return op, [left, right], [] - def _translate_name_expr(self, node: ast.Name) -> Variable: + def _translate_name_expr(self, node: ast.Name) -> ir.Value: return self._py_var_to_onnx_var(node.id, self._source_of(node)) # pylint: disable=inconsistent-return-statements @@ -946,10 +976,7 @@ def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable if isinstance(node, ast.Name): function_name = node.id found = self._lookup(function_name, self._source_of(node), raise_exception=False) - if isinstance(found, onnxscript.OnnxFunction): - self._current_fn.add_called_function(found) - return found - if isinstance(found, values.Op): + if isinstance(found, (values.Op, onnxscript.OnnxFunction)): return found if not found: if function_name not in self.default_opset: @@ -996,11 +1023,13 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None: # Assignments of the form "x = SomeExpression" info = self._source_of(lhs) lhs = lhs.id - t = self._translate_expr(rhs, lhs).name + t = self._translate_expr(rhs, lhs) if isinstance(stmt, ast.AnnAssign): - typeinfo = self._eval_constant_expr(stmt.annotation) + typeinfo = self._get_type_annotation(stmt.annotation) else: typeinfo = None + if typeinfo is not None: + 
set_type_info(t, typeinfo) var = values.Dynamic(t, values.DynamicKind.Intermediate, info, typeinfo) self._bind(lhs, var) elif isinstance(lhs, ast.Tuple): @@ -1015,17 +1044,19 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None: def generate_onnx_name(x: ast.AST): if not isinstance(x, ast.Name): self.fail(x, f"LHS must be a Name for unpacking, found: '{type(x)!r}'") - onnx_name = self.generate_unique_name(x.id) + return self.generate_unique_name(x.id) + + output_names = [generate_onnx_name(x) for x in lhs.elts] + outputs = self.emit(output_names, callee, inputs, attrs) + if isinstance(outputs, ir.Value): + outputs = [outputs] + for x, output in zip(lhs.elts, outputs): self._bind( x.id, values.Dynamic( - onnx_name, values.DynamicKind.Intermediate, self._source_of(x) + output, values.DynamicKind.Intermediate, self._source_of(x) ), ) - return onnx_name - - outputs = [generate_onnx_name(x) for x in lhs.elts] - self.emit(outputs, callee, inputs, attrs) else: self.fail(lhs, f"Unsupported construct in LHS of assignment: '{type(lhs)!r}'") @@ -1069,14 +1100,14 @@ def check_num_outputs(n): def ret(exp, i, suffix): preferred_name = f"return_val{suffix}" - return_var = self._translate_expr(exp, preferred_name).name - val = self._lookup(return_var, self._source_of(exp), False) + return_var = self._translate_expr(exp, preferred_name) # TODO(rama) + val = self._lookup(return_var.name, self._source_of(exp), False) if val and val.kind == values.DynamicKind.Input: # In ONNX, a graph-input cannot be an output of the graph. # We need to insert a copy. return_var = self._emit_copy(return_var, preferred_name) for prev_output in self._current_fn.outputs: - if prev_output.name == return_var: + if prev_output.name == return_var.name: # ONNX does not allow duplicate output names. return_var = self._emit_copy(return_var, f"{return_var}_copy") break @@ -1084,7 +1115,9 @@ def ret(exp, i, suffix): t = None else: t = self.returntype[i] - self.ir_builder.add_output(self._current_fn, return_var, t, self._source_of(stmt)) + self._current_fn.outputs.append( + make_value(return_var.name, t, self._source_of(stmt)) + ) return return_var val = stmt.value @@ -1114,42 +1147,40 @@ def _translate_if_stmt(self, stmt: ast.If) -> None: # due to some existing usage. 
live_def_set = live_out.intersection(live_def_set) live_defs = list(live_def_set) - test = self._translate_expr(stmt.test, "cond").name + test = self._translate_expr(stmt.test, "cond") lineno = self._source_of(stmt).lineno - thenGraph, sub_fct_then = self._translate_block( + thenGraph = self._translate_block( stmt.body, f"thenGraph_{lineno}", live_defs, parent_stmt=stmt ) thenAttr = self._make_onnx_attr("then_branch", thenGraph) - elseGraph, sub_fct_else = self._translate_block( + elseGraph = self._translate_block( stmt.orelse, f"elseGraph_{lineno}", live_defs, parent_stmt=stmt ) elseAttr = self._make_onnx_attr("else_branch", elseGraph) def rename(x): - r = self.generate_unique_name(x) - self._bind( - x, - values.Dynamic(r, values.DynamicKind.Intermediate, self._source_of(stmt)), - ) - return r + return self.generate_unique_name(x) # no break condition renamed = [rename(x) for x in live_defs] if not renamed: self.fail(stmt, "A subgraph for a test do not have any output variable.") - sub_functions = {} - sub_functions.update(sub_fct_then) - sub_functions.update(sub_fct_else) - if renamed == [test]: + if renamed == [test.name]: self.fail(stmt, f"Input and output cannot be the same {renamed!r}.") - self.emit( + if_outputs = self.emit( renamed, values.Op(self.default_opset, "If"), [test], [thenAttr, elseAttr], - sub_functions=sub_functions, ) + if isinstance(if_outputs, ir.Value): + if_outputs = [if_outputs] + for x, y in zip(live_defs, if_outputs): + self._bind( + x, + values.Dynamic(y, values.DynamicKind.Intermediate, self._source_of(stmt)), + ) def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: # loop-variable @@ -1169,11 +1200,11 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: if not iter.args or len(iter.args) != 1: self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.") assert not iter.keywords, "Unsupported loop bound." - o_loop_bound = self._translate_expr(iter.args[0], "loop_bound").name - o_cond_var = self.generate_unique_name("cond_in") + o_loop_bound = self._translate_expr(iter.args[0], "loop_bound") + o_cond_var = ir.Value(name=self.generate_unique_name("cond_in")) # TODO(Rama) i_cond_var = o_cond_var cond_while = None - o_loop_condition = "" # No condition for a for loop. + o_loop_condition = None # No condition for a for loop. 
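+            # Editor's note: a counted "for x in range(n)" loop lowers to an
+            # ONNX Loop whose trip-count input is the range bound; the optional
+            # condition input stays absent (None) for plain counted loops.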
elif isinstance(loop_stmt, ast.While): test = loop_stmt.test if not isinstance(test, ast.Name): @@ -1183,9 +1214,9 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: "it should be 'while :'.", ) p_loop_var = "infinite_loop" - o_loop_bound = "" - i_cond_var = test.id - cond_while = test.id + o_loop_bound = None + i_cond_var = ir.Value(name=test.id) # TODO(Rama) + cond_while = ir.Value(name=test.id) # TODO(Rama) o_cond_var = None o_loop_condition = self._translate_name_expr(test) # we need to go through all the instructions to see @@ -1207,22 +1238,26 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: # build loop_body self._enter_scope("loop_body", loop_stmt) o_loop_var = self.generate_unique_name(p_loop_var) - self.ir_builder.add_input( - self._current_fn, - o_loop_var, - onnx_types.INT64, - self._source_of(loop_stmt), + self._current_fn.append_parameter( + make_value( + o_loop_var, + onnx_types.INT64, + self._source_of(loop_stmt), + ) ) self._bind( p_loop_var, - values.Dynamic(o_loop_var, values.DynamicKind.Loop, self._source_of(loop_stmt)), + values.Dynamic( + ir.Value(name=o_loop_var), values.DynamicKind.Loop, self._source_of(loop_stmt) + ), ) - self.ir_builder.add_input( - self._current_fn, - i_cond_var, - onnx_types.BOOL, - self._source_of(loop_stmt), + self._current_fn.append_parameter( + make_value( + i_cond_var.name, + onnx_types.BOOL, + self._source_of(loop_stmt), + ) ) for pv in loop_state_vars: @@ -1230,15 +1265,17 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: # TODO: retrieve the annotation for variable pv is any is specified. # typeinfo = self._eval_constant_expr(pv.annotation) typeinfo = None - self.ir_builder.add_input( - self._current_fn, ov, typeinfo, self._source_of(loop_stmt) + self._current_fn.append_parameter( + make_value(ov, typeinfo, self._source_of(loop_stmt)) ) self._bind( pv, - values.Dynamic(ov, values.DynamicKind.Loop, self._source_of(loop_stmt)), + values.Dynamic( + ir.Value(name=ov), values.DynamicKind.Loop, self._source_of(loop_stmt) + ), ) - condition_name = None + condition_name: ir.Value | None = None operator_name = "Identity" for i, s in enumerate(loop_stmt.body): # We first need to intercept a break instruction in test block. 
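[Editor's note] For `while cond:` loops the condition is threaded through the body graph: it enters as the `cond_in` input and must be produced again as the `cond_out` output of every iteration. Roughly, the lowered body has this shape (illustrative sketch, assuming the names used in this patch):

    loop_body(iteration_num: INT64, cond_in: BOOL, ...state):
        ...                        # translated loop-body statements
        cond_out = Identity(cond)  # the default; negated instead when a
                                   # conditional break is intercepted above
        return cond_out, ...state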
@@ -1272,13 +1309,13 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: if cond_while is not None: # Loop while current_scope = self._current_scope() - if cond_while not in current_scope: + if cond_while.name not in current_scope: self.fail( loop_stmt, - f"Unable to find condition variable {cond_while!r} in known " + f"Unable to find condition variable {cond_while.name} in known " f"variables {list(current_scope)!r}.", ) - o_cond_var = current_scope[cond_while].value + o_cond_var = current_scope[cond_while.name].value self.emit( [o_cond_out], @@ -1287,15 +1324,16 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: [], ) - self.ir_builder.add_output( - self._current_fn, - o_cond_out, - onnx_types.BOOL, - self._source_of(loop_stmt), + self._current_fn.outputs.append( + make_value( + o_cond_out, + onnx_types.BOOL, + self._source_of(loop_stmt), + ) ) for pv in loop_state_vars: - ov = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)).name - if ov not in self._current_fn.assigned_names: + ov = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) + if ov.name not in self._current_fn.assigned_names: # When converting the loop-body into a graph, we need to handle # identity assignments of the form "x = y" inside the loop body # specially if y represents a value computed outside the loop body. @@ -1304,31 +1342,31 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ov = self._emit_copy(ov, pv) # TODO: retrieve variable type for the annotation if any. typeinfo = None - self.ir_builder.add_output( - self._current_fn, ov, typeinfo, self._source_of(loop_stmt) + self._current_fn.outputs.append( + make_value(ov.name, typeinfo, self._source_of(loop_stmt)) ) body = self._exit_scope() inputs = [o_loop_bound, o_loop_condition] + [ - self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)).name - for pv in loop_state_vars + self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars ] - graph, sub_functions = body.to_graph_and_functions() - attrs = [self._make_onnx_attr("body", graph)] + attrs = [self._make_onnx_attr("body", body.graph)] info = self._source_of(loop_stmt) def rename(x): r = self.generate_unique_name(x) - self._bind(x, values.Dynamic(r, values.DynamicKind.Output, info)) return r - onnx_outputs = [rename(x) for x in outputs] - self.emit( - onnx_outputs, + onnx_output_names = [rename(x) for x in outputs] + loop_outputs = self.emit( + onnx_output_names, "Loop", inputs, attrs, - sub_functions=sub_functions, ) + if isinstance(loop_outputs, ir.Value): + loop_outputs = [loop_outputs] + for x, loop_output in zip(outputs, loop_outputs): + self._bind(x, values.Dynamic(loop_output, values.DynamicKind.Output, info)) def _translate_block( self, @@ -1346,16 +1384,17 @@ def _translate_block( for pvar in live_defs: if pvar in self._current_scope(): pv_val = self._current_scope()[pvar] - output = self._to_onnx_var(pv_val, pvar).name - if output not in self._current_fn.assigned_names: + output = self._to_onnx_var(pv_val, pvar) + if output.name not in self._current_fn.assigned_names: # To return an outer-scope variable, an ONNX Graph has to # use an explicit copy via Identity. 
                output = self._emit_copy(output, pvar)
-                self.ir_builder.add_output(
-                    self._current_fn,
-                    output,
-                    pv_val.typeinfo,
-                    source,
+                self._current_fn.outputs.append(
+                    make_value(
+                        output.name,
+                        pv_val.typeinfo,
+                        source,
+                    )
                 )
             else:
                 pv_val = None
@@ -1366,17 +1405,17 @@
             if pv_val is None:
                 self.fail(
                     stmts[0],
                     f"Variable {pvar} is not assigned a value along a conditional "
                     f"branch, known variables: {list(self._locals)}.",
                 )
             # introduce a copy
-            ovar = self._emit_copy(self._to_onnx_var(pv_val, pvar).name, pvar)
+            ovar = self._emit_copy(self._to_onnx_var(pv_val, pvar), pvar)
             # TODO: retrieve the annotation if any.
             typeinfo = None
-            self.ir_builder.add_output(self._current_fn, ovar, typeinfo, source)
+            self._current_fn.outputs.append(make_value(ovar.name, typeinfo, source))
         graph = self._exit_scope()
-        return graph.to_graph_and_functions()
+        return graph.graph
 
     def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None:
         """Translate a nested function definition."""
@@ -1407,32 +1446,25 @@ def _translate_function_signature_common(
             else:
                 default_value = None
             if x.annotation:
-                typeinfo = self._eval_constant_expr(x.annotation)
-                if not ta.is_valid_type(typeinfo):
-                    self.warn(
-                        x.annotation,
-                        f"Unsupported type annotation for argument {x.arg}.",
-                    )
-                    typeinfo = None
+                typeinfo = self._get_type_annotation(x.annotation)
             else:
                 # The code can only be exported as a function.
                 typeinfo = None
             if typeinfo and ta.is_attr_type(typeinfo):
-                self.ir_builder.add_attr_parameter(
-                    self._current_fn,
-                    x.arg,
-                    ta.pytype_to_attrtype(typeinfo),
-                    default_value,
-                )
+                attribute_type = ta.pytype_to_attrtype(typeinfo)
+                attr = ir.Attr(x.arg, ir.AttributeType(attribute_type), default_value, None)
+                self._current_fn.append_parameter(attr)
                 self._bind(x.arg, values.AttrRef(x.arg, typeinfo, self._source_of(x)))
             else:
-                self.ir_builder.add_input(
-                    self._current_fn, x.arg, typeinfo, self._source_of(x)
+                self._current_fn.append_parameter(
+                    make_value(x.arg, typeinfo, self._source_of(x))
                 )
                 self._used_vars.add(x.arg)
                 self._bind(
                     x.arg,
-                    values.Dynamic(x.arg, values.DynamicKind.Input, self._source_of(x)),
+                    values.Dynamic(
+                        ir.Value(name=x.arg), values.DynamicKind.Input, self._source_of(x)
+                    ),
                 )
         if fn.returns:
             type_annotation = self._eval_constant_expr(fn.returns)
@@ -1468,10 +1500,9 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
         if opset:
             self._set_default_opset(opset, stmt)
         domain = self.this_module.domain
-        self._current_fn = self.ir_builder.new_function(stmt.name, domain, True)
+        self._current_fn = irbuilder.IRFunction(stmt.name, domain)
         self._analyzer = analysis.AstAnalyzer(stmt, self._message, self.globals)
         fn_ir = self._translate_function_def_common(stmt)
-        fn_ir.debug_print()
         self.this_module.add_function_def(fn_ir)
         self._analyzer = None
         return fn_ir
@@ -1480,5 +1511,5 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction:
     def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction:
         """Translate a (top-level) function signature."""
         domain = self.this_module.domain
-        self._current_fn = self.ir_builder.new_function(fn.name, domain, True)
+        self._current_fn = irbuilder.IRFunction(fn.name, domain)
         return self._translate_function_signature_common(fn)
diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py
index a35711aea9..b6cada54b0 100644
--- a/onnxscript/converter_test.py
+++ b/onnxscript/converter_test.py
@@ 
-235,7 +235,7 @@ def test_unary_op(self): def test_subfunction_check_model(self): from tests.models import subfunction - model = subfunction.MyElu.function_ir.to_model_proto(producer_name="p2o") + model = subfunction.MyElu.to_model_proto(producer_name="p2o") model = onnx.shape_inference.infer_shapes(model) onnx.checker.check_model(model) @@ -740,6 +740,37 @@ def model(x: FLOAT[10]) -> FLOAT[10]: model_false = make_model(False) onnxscript.testing.assert_isomorphic(model_false, sub_model.to_model_proto()) + def test_type_annotation(self): + """Test that type annotations are processed correctly.""" + + @script() + def model(x: FLOAT[10]) -> FLOAT[10]: + temp: FLOAT[10] = op.Add(x, x) + y = op.Mul(temp, temp) + return y + + model_proto = model.to_model_proto() + input_type = model_proto.graph.input[0].type.tensor_type + output_type = model_proto.graph.output[0].type.tensor_type + temp_value_info = None + for value_info in model_proto.graph.value_info: + if value_info.name == "temp": + temp_value_info = value_info + break + self.assertIsNotNone(temp_value_info, "ValueInfo for 'temp' not found in graph.") + temp_type = temp_value_info.type.tensor_type + self.assertEqual(temp_type.elem_type, onnx.TensorProto.FLOAT) + self.assertEqual(len(temp_type.shape.dim), 1) + self.assertEqual(temp_type.shape.dim[0].dim_value, 10) + + self.assertEqual(input_type.elem_type, onnx.TensorProto.FLOAT) + self.assertEqual(len(input_type.shape.dim), 1) + self.assertEqual(input_type.shape.dim[0].dim_value, 10) + + self.assertEqual(output_type.elem_type, onnx.TensorProto.FLOAT) + self.assertEqual(len(output_type.shape.dim), 1) + self.assertEqual(output_type.shape.dim[0].dim_value, 10) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/evaluator.py b/onnxscript/evaluator.py index 1d87ee135e..327440243e 100644 --- a/onnxscript/evaluator.py +++ b/onnxscript/evaluator.py @@ -23,7 +23,7 @@ import onnx.reference from typing_extensions import TypeAlias -from onnxscript import irbuilder, onnx_opset, tensor, values +from onnxscript import onnx_opset, tensor, values from onnxscript._internal import autocast, param_manipulation, utils UserModeValue: TypeAlias = Union[Optional[np.ndarray], Sequence["UserModeValue"]] @@ -217,7 +217,7 @@ def adapt_attributes( if use_graph_attribute: adapted_attributes[k] = v.function_ir.to_graph_proto() for pyvar, onnxvar in v.function_ir.outer_scope_variables: - closure[onnxvar.value] = v.frame.f_locals[pyvar] + closure[onnxvar.value.name] = v.frame.f_locals[pyvar] else: adapted_attributes[k] = v.function elif callable(v): @@ -447,7 +447,7 @@ def make_tensor_name() -> str: model = onnx.helper.make_model( # noqa: TID251 graph, opset_imports=[opset_id], - ir_version=irbuilder.select_ir_version(schema.since_version, domain=schema.domain), + ir_version=values.select_ir_version(schema.since_version, domain=schema.domain), ) model = onnx.shape_inference.infer_shapes(model) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 4274bf2062..c67ecc4be5 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -1,564 +1,96 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-# ruff: noqa: TID251 from __future__ import annotations -import dataclasses -import io import logging import warnings -from typing import Any, Optional, Protocol, Sequence, Union +from typing import Any, Sequence, Union import onnx -from onnx import ValueInfoProto, helper -from onnx.defs import onnx_opset_version +import onnx_ir as ir -import onnxscript -from onnxscript import type_annotation as ta +import onnxscript.type_annotation from onnxscript import values -from onnxscript._internal import version_utils -from onnxscript.onnx_types import ONNXType -from onnxscript.sourceinfo import SourceInfo - -# A simple IR (Function, Stmt, Attr, Var): logger = logging.getLogger("onnxscript") +TypeAnnotationValue = onnxscript.type_annotation.TypeAnnotationValue -def _format(seq: Sequence[Any], prefix: str, sep: str, suffix: str, formatter=str): - """Formats a sequence of objects into a string.""" - return prefix + sep.join([formatter(x) for x in seq]) + suffix - - -def select_ir_version(version: int, domain: str = "") -> int: - """Selects a suitable ONNX ir_version for a given opset version.""" - if domain == "": - domain = "ai.onnx" - if (domain, version) not in helper.OP_SET_ID_VERSION_MAP: - return max(v for k, v in helper.OP_SET_ID_VERSION_MAP.items() if k[0] == "ai.onnx") - return helper.OP_SET_ID_VERSION_MAP[domain, version] - - -class IRType: - def __init__(self): - self.onnx_type = onnx.TypeProto() - - def to_type_proto(self): - return self.onnx_type - - def __repr__(self) -> str: - return "IRType()" - - -class IRTensorType(IRType): - def __init__(self, elem_type: onnx.TensorProto.DataType) -> None: - super().__init__() - self.onnx_type.tensor_type.elem_type = elem_type - - def __repr__(self) -> str: - return f"IRTensorType({self.onnx_type.tensor_type.elem_type})" - - -class IRTypeLike(Protocol): - def to_type_proto(self) -> onnx.TypeProto: - """Converts IR type representation to onnx.TypeProto""" - - -class IRVar: - """A variable (representing a formal parameter).""" - - def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) -> None: - if not isinstance(varname, str): - raise TypeError(f"varname must be a string not {type(varname)!r}.") - self.name = varname - self.info = sourceinfo - self.typeinfo = typeinfo - - def __str__(self): - return self.name - - def __repr__(self): - return f"{self.__class__.__name__}({self.name!r}, {self.typeinfo!r})" - - def typed_str(self): - return f"{self.name} : {self.typeinfo}" - - def to_value_info(self, use_default_type: bool = True): - """Converts the content of this class into :class:`onnx.ValueInfoProto`. - - Args: - use_default_type: if True, use a default type if an explicit type - is not known. Otherwise, returns a ValueInfoProto without type. - - Returns: - an instance of :class:`onnx.ValueInfoProto` - """ - if self.name is None: - raise ValueError(self.info.msg("name cannot be None.")) - value_info_proto = ValueInfoProto() - value_info_proto.name = self.name - if self.typeinfo is not None: - value_info_proto.type.CopyFrom(self.typeinfo.to_type_proto()) - elif use_default_type: - value_info_proto.type.CopyFrom(IRType().to_type_proto()) - return value_info_proto - - -def _opt_var_to_str(x): - return "" if x is None else str(x) - - -class IRAttributeValue: - """An attribute value (representing an actual parameter). - - Attributes: - name: The name of the attribute. - type: The type of the attribute. - attr_proto: The attribute proto. 
- """ - - def __init__(self, attrproto: onnx.AttributeProto) -> None: - self.attr_proto = attrproto - - def __str__(self): - if self.attr_proto.HasField("ref_attr_name"): - return f"{self.attr_proto.name} = @{self.attr_proto.ref_attr_name}" - # self.name + " = " + self.value - return helper.printable_attribute(self.attr_proto) - - @property - def name(self) -> str: - return self.attr_proto.name - - @property - def type(self) -> onnx.AttributeProto.AttributeType: - return self.attr_proto.type - - -@dataclasses.dataclass(frozen=True) -class IRAttributeParameter: - """An attribute parameter (representing a formal parameter). - - It may or may not carry a default value. - - Attributes: - name: The name of the attribute. - type: The type of the attribute. - default_value: The default value of the attribute. - has_default: Whether the attribute has a default value. - attr_proto: The attribute proto. - """ - - name: str - type: onnx.AttributeProto.AttributeType - default_value: str | int | float | None = None - - # TODO(justinchuby): Validate the default_value is the same type as specified in AttributeType. - - def __str__(self): - if self.has_default: - return helper.printable_attribute(self.attr_proto) - # TODO(justinchuby): Include a readable type name. - return self.name - - @property - def has_default(self): - return self.default_value is not None - - @property - def attr_proto(self) -> onnx.AttributeProto: - if not self.has_default: - raise ValueError( - "Attribute has no default value. Only attributes with default " - "values can be converted to AttributeProto." - ) - if version_utils.onnx_older_than("1.15"): - # TODO(after 1.14 is deprecated): Remove this branch. - # Argument 'attr_type' was added after version 1.14. - return helper.make_attribute(self.name, self.default_value) - # pylint: disable=unexpected-keyword-arg - return helper.make_attribute(self.name, self.default_value, attr_type=self.type) # type: ignore[call-arg] - # pylint: enable=unexpected-keyword-arg - - -class IRStmt: - def __init__( - self, - result: Sequence[str], - callee: values.Op, - args: Sequence[Optional[str]], - attrs: Sequence[IRAttributeValue], - sub_functions=None, - ) -> None: - if not isinstance(callee, values.Op): - raise TypeError(f"Unexpected type {type(callee)} for callee.") - self.result = result - self.callee = callee - self.args = args - self.attrs = attrs - self.functions = sub_functions or {} - - def __str__(self): - if isinstance(self.result, str): - logger.debug("unexpected str type for self.result where type(self)=%r", type(self)) - lhs = ", ".join(self.result) - attrs = "" - if self.attrs: - attrs = _format(self.attrs, "<", ", ", ">") - - args = _format(self.args, "(", ", ", ")", _opt_var_to_str) - domain = self.callee.opset.domain - opname = self.callee.name - callee = f"{domain}.{opname}" if (domain != "") else opname - return f"{lhs} = {callee} {attrs}{args}" - - def debug_print(self): - if logger.isEnabledFor(logging.DEBUG): - logger.debug("%s: %s", type(self), self) - def to_node_proto(self, node_name: str) -> onnx.NodeProto: - n = helper.make_node( - self.callee.name, - [_opt_var_to_str(x) for x in self.args], - [str(x) for x in self.result], - domain=self.callee.opset.domain, - name=node_name, - ) - for a in self.attrs: - n.attribute.append(a.attr_proto) - return n - - @property - def output_names(self) -> Sequence[str]: - """Returns the list of variables assigned to by this statement.""" - return [str(x) for x in self.result] - - -class IRFunction: +class IRFunction(ir.Function): 
"""Represents a function in the IR.""" def __init__(self, name: str, domain: str = "") -> None: - self.domain = domain - self.name = name - self.outputs: list[IRVar] = [] - self.stmts: list[IRStmt] = [] - self.called_functions: dict[str, onnx.FunctionProto] = {} - self.docstring: str = "" + graph = ir.Graph(inputs=[], outputs=[], nodes=[], name=name) + super().__init__(domain, name, graph=graph, attributes=[]) + self.ordered_inputs_and_attrs: list[Union[ir.Value, ir.Attr]] = [] + # a dictionary of nested function-definitions self.nested_functions: dict[str, IRFunction] = {} self.outer_scope_variables: dict[Any, Any] = {} - self.ordered_inputs_and_attrs: list[Union[IRVar, IRAttributeParameter]] = [] @property def assigned_names(self) -> Sequence[str]: """Returns the list of variables assigned to by this function.""" - return [v for stmt in self.stmts for v in stmt.output_names] - - @property - def inputs(self) -> Sequence[IRVar]: - return [var for var in self.ordered_inputs_and_attrs if isinstance(var, IRVar)] + return [v.name for n in self for v in n.outputs] @property - def attrs(self) -> Sequence[IRAttributeParameter]: - return [ - attr - for attr in self.ordered_inputs_and_attrs - if isinstance(attr, IRAttributeParameter) - ] - - def __str__(self): - attrs = _format(self.attrs, "<", ", ", ">") if self.attrs else "" - inputs = _format([x.typed_str() for x in self.inputs], "(", ", ", ")") - outputs = _format([x.typed_str() for x in self.outputs], "(", ", ", ")") - stmts = _format(self.stmts, "\n{\n ", "\n ", "\n}\n") - return f"{self.name} {attrs}{inputs} => {outputs}{stmts}" - - def append_docstring(self, docstring): - self.docstring += docstring - - def append_stmt(self, stmt: IRStmt) -> None: - self.stmts.append(stmt) - - def append_input(self, name: IRVar) -> None: - self.ordered_inputs_and_attrs.append(name) - - def append_output(self, name: IRVar) -> None: - self.outputs.append(name) - - def add_attr_parameter(self, attr: IRAttributeParameter) -> None: - self.ordered_inputs_and_attrs.append(attr) - - def debug_print(self): - if logger.isEnabledFor(logging.DEBUG): - st = io.StringIO() - for s in self.stmts: - for attr in s.attrs: - if attr.attr_proto.HasField("g"): - st.write(helper.printable_graph(attr.attr_proto.g)) - st.write("\n") - - def add_called_function(self, fun: values.OnnxFunction) -> None: - for name, fct in fun.function_ir.called_functions.items(): - if name in self.called_functions: - continue - self.called_functions[name] = fct - if fun.name in self.called_functions: - # Already added. - return - try: - proto = fun.to_function_proto() - except (TypeError, AttributeError) as e: - raise TypeError(f"Issue with type f{type(fun)}.") from e - self.called_functions[fun.name] = proto - - def add_nested_function(self, fun: IRFunction) -> None: - self.nested_functions[fun.name] = fun - - def to_model_proto( - self, - functions=None, - io_types: Optional[ONNXType] = None, - input_types: Optional[Sequence[ONNXType]] = None, - output_types: Optional[Sequence[ONNXType]] = None, - value_infos: dict[str, ONNXType] | None = None, - opset_version: int | None = None, - **kwargs, - ) -> onnx.ModelProto: - """Converts this instance into a `onnx.ModelProto`. - - Args: - functions: A list of functions to include in the model. - By default, all functions called at least once are included. - io_types: When specified, all the inputs/outputs of the model - are set to be of this type. 
- input_types: When specified, all the inputs of the model - are set to be of the corresponding type in this list. - output_types: When specified, all the outputs of the model - are set to be of the corresponding type in this list. - value_infos: A dictionary mapping intermediate variable names to ONNX types. - Used to set value_info for intermediate variables. - opset_version: The standard opset version to use for the model if it - cannot be inferred. Otherwise defaults to the current opset version. - kwargs: Additional parameters given to function :func:`onnx.helper.make_model`. - - Returns: - An instance of :class:`onnx.ModelProto`. - """ - value_infos = ( - [ - onnx.helper.make_value_info(name, type.to_type_proto()) - for name, type in value_infos.items() - ] - if value_infos - else None - ) - graph, sub_functions = self.to_graph_and_functions( - use_default_type=False, value_infos=value_infos - ) - if io_types is not None: - for input in graph.input: - if not input.HasField("type"): - input.type.CopyFrom(io_types.to_type_proto()) - for output in graph.output: - if not output.HasField("type"): - output.type.CopyFrom(io_types.to_type_proto()) - if input_types is not None: - for input, type in zip(graph.input, input_types): - input.type.CopyFrom(type.to_type_proto()) - if output_types is not None: - for output, type in zip(graph.output, output_types): - output.type.CopyFrom(type.to_type_proto()) - if functions is None: - functions = sub_functions.values() + def attrs(self) -> Sequence[ir.Attr]: + return [attr for attr in self.ordered_inputs_and_attrs if isinstance(attr, ir.Attr)] + + def append_node(self, node: ir.Node) -> None: + count = len(self) + node.name = f"n{count}" + self.append(node) + domain = node.domain + version = node.version + if domain not in self.opset_imports: + self.opset_imports[domain] = version else: - - def to_proto(f): - if isinstance(f, onnx.FunctionProto): - return f - if isinstance(f, onnxscript.OnnxFunction): - return f.to_function_proto() - raise TypeError("Expected a value of type FunctionProto of OnnxFunction") - - functions = [to_proto(f) for f in functions] - - opsets = {} - for n in self.stmts: - if n.callee.opset.domain not in opsets: - opsets[n.callee.opset.domain] = n.callee.opset.version - - for proto in functions: - if proto.domain not in opsets: - opsets[proto.domain] = 1 - # TODO(rama): Handle conflicts with appropriate error/warning message. - for opset in proto.opset_import: - if opset.domain not in opsets: - opsets[opset.domain] = opset.version - - if "" not in opsets: - # No operator is using the standard opset. - # Use the specified version if provided or the default value. - opsets[""] = opset_version if opset_version is not None else onnx_opset_version() - - if "ir_version" not in kwargs: - kwargs["ir_version"] = select_ir_version(opsets[""]) - opset_imports = [ - onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items() - ] - - return helper.make_model( - graph, opset_imports=opset_imports, functions=functions, **kwargs - ) - - def to_graph_and_functions( - self, - use_default_type: bool = True, - value_infos: Sequence[ValueInfoProto] | None = None, - ) -> tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]: - """Converts this instance into a `onnx.GraphProto` and a map from - function-name to `onnx.FunctionProto`. 
- - Args: - use_default_type: if True, the function uses a default type - for inputs and outputs that do not have a type - value_infos: a sequence of :class:`onnx.ValueInfoProto` to be added - to the graph. - - Returns: - a pair of a :class:`onnx.GraphProto` and list of :class:`onnx.FunctionProto` - """ - called_functions: dict[str, onnx.FunctionProto] = {} - for s in self.stmts: - called_functions.update(s.functions) - called_functions.update(self.called_functions) - graph = helper.make_graph( - [s.to_node_proto(f"n{i}") for i, s in enumerate(self.stmts)], - self.name, - [x.to_value_info(use_default_type) for x in self.inputs], - [y.to_value_info(use_default_type) for y in self.outputs], - value_info=value_infos, - ) - return graph, called_functions - - def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto: - """Converts this instance into a `onnx.GraphProto`. - - Args: - use_default_type: if True, the function uses a default type - for inputs and outputs that do not have a type - - Returns: - an instance of :class:`onnx.GraphProto` - """ - graph, _ = self.to_graph_and_functions(use_default_type=use_default_type) - return graph - - def get_opset_import(self) -> dict[str, int]: - func_opset_imports = {} - for s in self.stmts: - if s.callee.opset.domain not in func_opset_imports: - func_opset_imports[s.callee.opset.domain] = s.callee.opset.version - elif func_opset_imports[s.callee.opset.domain] != s.callee.opset.version: + existing_version = self.opset_imports[domain] + if existing_version != version: warnings.warn( - f"There is a version conflict in domain: {s.callee.opset.domain!r}, " - f"with {self.name!r}.", + f"Version conflict: domain: {domain!r}, " + f"versions {existing_version} and {version} used.", category=UserWarning, - stacklevel=1, + stacklevel=2, ) - return func_opset_imports - def to_function_proto(self) -> onnx.FunctionProto: - """Converts this instance into a `onnx.FunctionProto`. - - Note: Default values for attributes are an experimental feature in ONNX. - Conversion ignores default values for attributes if the ONNX version installed - doesn't support it. - """ - opsets = self.get_opset_import() - nodes = [s.to_node_proto(f"n{i}") for i, s in enumerate(self.stmts)] - for n in nodes: - if n.domain not in opsets: - opsets[n.domain] = 1 # TODO: how to get n.version? 
- opset_imports = [ - onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items() - ] - - attribute_names = [attr.name for attr in self.attrs if not attr.has_default] - - f = helper.make_function( - self.domain, - self.name, - inputs=[x.name for x in self.inputs], - outputs=[y.name for y in self.outputs], - nodes=nodes, - opset_imports=opset_imports, # TODO - attributes=attribute_names, - doc_string=self.docstring, - ) - # In protobuf 4.x fields aren't defined as class attribute so it should check instance attribute instead - if hasattr(f, "attribute_proto"): - f.attribute_proto.extend( - [attr.attr_proto for attr in self.attrs if attr.has_default] - ) - return f - - -# IRBuilder: abstracts out details of the IR in the python-to-IR converter - - -class IRBuilder: - def __init__(self): - self.functions = {} + def append_parameter(self, parameter: ir.Value | ir.Attr) -> None: + self.ordered_inputs_and_attrs.append(parameter) + if isinstance(parameter, ir.Value): + self.inputs.append(parameter) + else: + if not isinstance(parameter, ir.Attr): + raise TypeError(f"Expected ir.Value or ir.Attr, got {type(parameter)}") + self.attributes.add(parameter) - def new_function(self, name: str, domain: str = "", register: bool = False) -> IRFunction: - if register and (domain, name) in self.functions: - raise RuntimeError(f"Function '{name}' already exists in domain '{domain}'.") - function = IRFunction(name, domain) - if register: - self.functions[domain, name] = function - return function + def add_nested_function(self, fun: IRFunction) -> None: + self.nested_functions[fun.name] = fun - def add_docstring(self, fn: IRFunction, docstring: str): - fn.append_docstring(docstring) + def get_called_functions(self) -> dict[str, onnx.FunctionProto]: + called_functions: dict[str, values.OnnxFunction] = {} - def add_stmt( - self, - fn: IRFunction, - results: Sequence[str], - callee: values.Op, - args: Sequence[Optional[str]], - attrs: Sequence[IRAttributeValue], - sub_functions=None, - ) -> None: - stmt = IRStmt(results, callee, args, attrs, sub_functions=sub_functions) - fn.append_stmt(stmt) + def visit(function_ir: IRFunction): + for node in ir.traversal.RecursiveGraphIterator(function_ir.graph): + callee = node.meta.get("callee", None) + if isinstance(callee, values.OnnxFunction): + add(callee) - def add_input( - self, fn: IRFunction, varname: str, type: IRTypeLike, info: SourceInfo - ) -> None: - var = IRVar(varname, type, info) - fn.append_input(var) + def add(f: values.OnnxFunction): + if f.name in called_functions: + return + called_functions[f.name] = f + visit(f.function_ir) - def add_attr_parameter( - self, - fn: IRFunction, - varname: str, - attribute_type: onnx.AttributeProto.AttributeType, - default_value: int | float | str | None, - ) -> None: - fn.add_attr_parameter(IRAttributeParameter(varname, attribute_type, default_value)) + visit(self) - def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None: - var = IRVar(varname, typeinfo, sourceinfo) - fn.append_output(var) + return {name: f.to_function_proto() for name, f in called_functions.items()} - def make_attr(self, attrproto: onnx.AttributeProto) -> IRAttributeValue: - return IRAttributeValue(attrproto) + def to_graph_proto(self) -> onnx.GraphProto: + """Converts this instance into a `onnx.GraphProto`.""" + return ir.to_proto(self.graph) - def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue: - proto = onnx.AttributeProto() - proto.name = attrname - proto.ref_attr_name = 
diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py
index f82702d557..43033b9b4f 100644
--- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py
+++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py
@@ -66,7 +66,7 @@ def _run(
 
         @script()
         def _fused_matmul_div(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]:
-            C = 0.6
+            C = op.Constant(value_float=0.6)
             ab = ms_op.FusedMatMul(A, B, alpha=0.4, transA=1)
             out = op.Div(ab, C)
             return out
@@ -74,7 +74,7 @@ def _fused_matmul_div(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]:
 
         @script()
        def _matmul_div(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]:
-            C = 0.8
+            C = op.Constant(value_float=0.8)
             ab = op.MatMul(A, B)
             out = op.Div(ab, C)
             return out
@@ -82,7 +82,7 @@ def _matmul_div(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]:
 
         @script()
         def _matmul_div_div(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]:
-            C = 0.6
+            C = op.Constant(value_float=0.6)
             ab = op.MatMul(A, B)
             abd = op.Div(ab, C)
             out = op.Div(abd, C)
diff --git a/onnxscript/values.py b/onnxscript/values.py
index 1897ae14d5..4c33664226 100644
--- a/onnxscript/values.py
+++ b/onnxscript/values.py
@@ -1,5 +1,8 @@
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT License.
+
+# ruff: noqa: TID251
+
 from __future__ import annotations
 
 import dataclasses
@@ -23,17 +26,31 @@
 import onnx
 import onnx.defs
+import onnx_ir as ir
 from typing_extensions import ParamSpec
 
 from onnxscript import converter as converter_module
 from onnxscript import irbuilder, sourceinfo, type_annotation
 from onnxscript._internal import ast_utils, deprecation
 from onnxscript.ir import _schemas
+from onnxscript.onnx_types import ONNXType
 
 _R = TypeVar("_R")
 _P = ParamSpec("_P")
 
 
+def select_ir_version(version: int, domain: str = "") -> int:
+    """Selects a suitable ONNX ir_version for a given opset version."""
+    if domain == "":
+        domain = "ai.onnx"
+    if (domain, version) not in onnx.helper.OP_SET_ID_VERSION_MAP:
+        return max(
+            v for k, v in onnx.helper.OP_SET_ID_VERSION_MAP.items() if k[0] == "ai.onnx"
+        )
+    required_min_version = onnx.helper.OP_SET_ID_VERSION_MAP[domain, version]
+    return max(required_min_version, 10)
+
+
 _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = {
     onnx.defs.OpSchema.AttrType.FLOAT: float,
     onnx.defs.OpSchema.AttrType.INT: int,
@@ -176,7 +193,7 @@ def _get_attribute_value(attr_proto: onnx.AttributeProto) -> Any:
     """Get the default value of an ONNX attribute."""
     if attr_proto.type == onnx.AttributeProto.UNDEFINED:
         return _EmptyDefault
-    return onnx.helper.get_attribute_value(attr_proto)  # noqa: TID251
+    return onnx.helper.get_attribute_value(attr_proto)
 
 
 def _param_schemas_from_op_schema(
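For orientation, a hedged usage sketch of `select_ir_version` as introduced above. `onnx.helper.OP_SET_ID_VERSION_MAP` is the onnx helper table mapping `(domain, opset_version)` to the minimum required ir_version; the exact numbers returned depend on the installed onnx release:

```python
import onnx.helper

from onnxscript.values import select_ir_version  # location per this patch

# Known standard opset: the ir_version required by that opset (at least 10).
print(select_ir_version(18))
# Unknown (e.g. future) opset: falls back to the newest ir_version known
# for the "ai.onnx" domain.
print(select_ir_version(10_000))
```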
@@ -208,23 +225,28 @@ def _param_schemas_from_op_schema(
     return tuple(schemas)
 
 
-def _param_schema_from_function_ir_input(input: irbuilder.IRVar):
-    if type_annotation.is_optional(input.typeinfo):
+def _typeinfo(var: ir.Value) -> Any:
+    return var.meta.get("typeinfo")
+
+
+def _param_schema_from_function_ir_input(input: ir.Value):
+    typeinfo = _typeinfo(input)
+    if type_annotation.is_optional(typeinfo):
         required = False
     else:
         required = True
-    return ParamSchema(name=input.name, type=input.typeinfo, is_input=True, required=required)
+    return ParamSchema(name=input.name, type=typeinfo, is_input=True, required=required)
 
 
-def _param_schema_from_function_ir_attr(attr: irbuilder.IRAttributeParameter):
+def _param_schema_from_function_ir_attr(attr: ir.Attr):
     return ParamSchema(
         name=attr.name,
         type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE.get(
             onnx.defs.OpSchema.AttrType(attr.type)  # type: ignore[call-arg]
         ),
-        default=_EmptyDefault if attr.default_value is None else attr.default_value,
+        default=_EmptyDefault if attr.value is None else attr.value,
         is_input=False,
-        required=not attr.has_default,
+        required=attr.value is None,
     )
 
 
@@ -240,10 +262,10 @@ def _param_schemas_from_function_ir(
     # ONNX OpSchema and FunctionProto does not support interleaving inputs and attributes.
     # This is by design. See more at https://github.com/microsoft/onnxscript/issues/771.
     for arg in function_ir.ordered_inputs_and_attrs:
-        if isinstance(arg, irbuilder.IRVar):
+        if isinstance(arg, ir.Value):
             # input
             schemas.append(_param_schema_from_function_ir_input(arg))
-        elif isinstance(arg, irbuilder.IRAttributeParameter):
+        elif isinstance(arg, ir.Attr):
             # attr
             schemas.append(_param_schema_from_function_ir_attr(arg))
         else:
@@ -392,8 +414,8 @@ def _op_schema_from_function_ir(
     """Construct an ONNX OpSchema from an IRFunction."""
 
     # Find all distinct types in the inputs and outputs
-    distinct_types = {arg.typeinfo for arg in function_ir.inputs}.union(
-        {arg.typeinfo for arg in function_ir.outputs}
+    distinct_types = {_typeinfo(arg) for arg in function_ir.inputs}.union(
+        {_typeinfo(arg) for arg in function_ir.outputs}
     )
     # Create a mapping from type to a unique name
     type_to_constraint = {}
@@ -407,10 +429,10 @@ def _op_schema_from_function_ir(
     formal_inputs = [
         onnx.defs.OpSchema.FormalParameter(
             arg.name,
-            type_to_constraint[arg.typeinfo].name,
+            type_to_constraint[_typeinfo(arg)].name,
             param_option=(
                 onnx.defs.OpSchema.FormalParameterOption.Optional
-                if type_annotation.is_optional(arg.typeinfo)
+                if type_annotation.is_optional(_typeinfo(arg))
                 else onnx.defs.OpSchema.FormalParameterOption.Single
             ),
             # TODO(justinchu): Check this is_homogeneous thing
@@ -421,10 +443,10 @@ def _op_schema_from_function_ir(
     formal_outputs = [
         onnx.defs.OpSchema.FormalParameter(
             arg.name,
-            type_to_constraint[arg.typeinfo].name,
+            type_to_constraint[_typeinfo(arg)].name,
             param_option=(
                 onnx.defs.OpSchema.FormalParameterOption.Optional
-                if type_annotation.is_optional(arg.typeinfo)
+                if type_annotation.is_optional(_typeinfo(arg))
                 else onnx.defs.OpSchema.FormalParameterOption.Single
             ),
             # TODO(justinchu): Check this is_homogeneous thing
@@ -436,7 +458,7 @@ def _op_schema_from_function_ir(
         function_ir.name,
         opset.domain,
         since_version=opset.version,
-        doc=function_ir.docstring,
+        doc=function_ir.doc_string or "",
         inputs=formal_inputs,
         outputs=formal_outputs,
         type_constraints=[constraint.as_tuple() for constraint in type_to_constraint.values()],
@@ -447,15 +469,15 @@ def _op_schema_from_function_ir(
                     type=onnx.defs.OpSchema.AttrType(attr.type),  # type: ignore[call-arg]
                 )
                 for attr in function_ir.attrs
-                if not attr.has_default
+                if attr.value is None
             ],
             *[
                 onnx.defs.OpSchema.Attribute(
                     attr.name,
-                    default_value=attr.attr_proto,
+                    default_value=ir.to_proto(attr),
                )
                 for attr in function_ir.attrs
-                if attr.has_default
+                if attr.value is not None
             ],
         ],
     )
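The `_typeinfo` helper above depends on the converter having stashed the python type annotation in the value's metadata (see `set_type_info` earlier in this patch, which writes `value.meta["typeinfo"]`). A minimal sketch of that round-trip; the `FLOAT[4, 4]` annotation is just an illustrative choice:

```python
import onnx_ir as ir

from onnxscript import FLOAT

x = ir.Value(name="X")
x.meta["typeinfo"] = FLOAT[4, 4]   # what set_type_info records
typeinfo = x.meta.get("typeinfo")  # what _typeinfo reads back
assert typeinfo is not None
```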
@@ -591,7 +613,7 @@ def to_function_proto(self) -> onnx.FunctionProto:
 
     def to_model_proto(self, **kwargs):
         """Converts the function into :class:`onnx.ModelProto`."""
         if self.function_ir.attrs and any(
-            not attr.has_default for attr in self.function_ir.attrs
+            attr.value is None for attr in self.function_ir.attrs
         ):
             raise ValueError(
                 "A function with required attributes cannot be exported as a model."
@@ -603,7 +625,111 @@ def to_model_proto(self, **kwargs):
         # Merge kwargs specified in script-decorator with those specified in this call.
         merged_kw_args = {**self.kwargs, **kwargs}
-        return self.function_ir.to_model_proto(**merged_kw_args)
+        return self._to_model_proto(**merged_kw_args)
+
+    def _to_model_proto(
+        self,
+        functions=None,
+        io_types: Optional[ONNXType] = None,
+        input_types: Optional[Sequence[ONNXType]] = None,
+        output_types: Optional[Sequence[ONNXType]] = None,
+        value_infos: dict[str, ONNXType] | None = None,
+        opset_version: int | None = None,
+        **kwargs,
+    ) -> onnx.ModelProto:
+        """Converts this instance into an `onnx.ModelProto`.
+
+        Args:
+            functions: A list of functions to include in the model.
+                By default, all functions called at least once are included.
+            io_types: When specified, any model input or output that does not
+                already have a type is set to this type.
+            input_types: When specified, the inputs of the model
+                are set to the corresponding types in this list.
+            output_types: When specified, the outputs of the model
+                are set to the corresponding types in this list.
+            value_infos: A dictionary mapping intermediate variable names to ONNX types.
+                Used to set value_info for intermediate variables.
+            opset_version: The version to use for the standard ("") opset when it
+                cannot be inferred from the function; defaults to the current
+                onnx opset version.
+            kwargs: Additional fields (e.g. ``producer_name``) set directly on the
+                resulting ModelProto.
+
+        Returns:
+            An instance of :class:`onnx.ModelProto`.
+        """
+        # Identify functions to include in the model
+        if functions is None:
+            sub_functions = self.function_ir.get_called_functions()
+            functions = sub_functions.values()
+        else:
+
+            def to_proto(f):
+                if isinstance(f, onnx.FunctionProto):
+                    return f
+                if isinstance(f, OnnxFunction):
+                    return f.to_function_proto()
+                raise TypeError("Expected a value of type FunctionProto or OnnxFunction")
+
+            functions = [to_proto(f) for f in functions]
+
+        # Determine opset imports
+        opsets = self.function_ir.graph.opset_imports
+
+        for proto in functions:
+            if proto.domain not in opsets:
+                opsets[proto.domain] = 1
+            # TODO(rama): Handle conflicts with appropriate error/warning message.
+            for opset in proto.opset_import:
+                if opset.domain not in opsets:
+                    opsets[opset.domain] = opset.version
+
+        if "" not in opsets:
+            # No operator is using the standard opset.
+            # Use the specified version if provided, or the default value.
+            opsets[""] = (
+                opset_version if opset_version is not None else onnx.defs.onnx_opset_version()
+            )
+
+        # Determine ir_version
+        if "ir_version" in kwargs:
+            ir_version = kwargs.pop("ir_version")
+        else:
+            ir_version = select_ir_version(opsets[""])
+
+        # Create the model
+        model = ir.Model(self.function_ir.graph, ir_version=ir_version)
+        model_proto = ir.to_proto(model)
+        model_proto.functions.extend(functions)
+
+        # Set additional type information if provided
+        graph = model_proto.graph
+
+        if value_infos:
+            graph.value_info.extend(
+                [
+                    onnx.helper.make_value_info(name, type.to_type_proto())
+                    for name, type in value_infos.items()
+                ]
+            )
+
+        if io_types is not None:
+            for input in graph.input:
+                if not input.HasField("type"):
+                    input.type.CopyFrom(io_types.to_type_proto())
+            for output in graph.output:
+                if not output.HasField("type"):
+                    output.type.CopyFrom(io_types.to_type_proto())
+        if input_types is not None:
+            for input, type in zip(graph.input, input_types):
+                input.type.CopyFrom(type.to_type_proto())
+        if output_types is not None:
+            for output, type in zip(graph.output, output_types):
+                output.type.CopyFrom(type.to_type_proto())
+
+        for k, v in kwargs.items():
+            setattr(model_proto, k, v)
+
+        return model_proto
 
 
 class TracedOnnxFunction(Op):
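A usage sketch of the rewritten export path (the `@script` function here is hypothetical): `ir_version` is popped from kwargs and handled explicitly, while leftover kwargs such as `producer_name` are applied to the ModelProto via the `setattr` loop above:

```python
from onnxscript import FLOAT, script
from onnxscript import opset18 as op

@script()
def double(X: FLOAT[4]) -> FLOAT[4]:
    return op.Add(X, X)

# producer_name lands on the proto via setattr; ir_version is threaded
# through ir.Model rather than set as a plain field assignment.
model = double.to_model_proto(producer_name="example", ir_version=10)
```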
+ opsets[""] = ( + opset_version if opset_version is not None else onnx.defs.onnx_opset_version() + ) + + # Determine ir_version + if "ir_version" in kwargs: + ir_version = kwargs.pop("ir_version") + else: + ir_version = select_ir_version(opsets[""]) + + # Create the model + model = ir.Model(self.function_ir.graph, ir_version=ir_version) + model_proto = ir.to_proto(model) + model_proto.functions.extend(functions) + + # Set additional type information if provided + graph = model_proto.graph + + if value_infos: + graph.value_info.extend( + [ + onnx.helper.make_value_info(name, type.to_type_proto()) + for name, type in value_infos.items() + ] + ) + + if io_types is not None: + for input in graph.input: + if not input.HasField("type"): + input.type.CopyFrom(io_types.to_type_proto()) + for output in graph.output: + if not output.HasField("type"): + output.type.CopyFrom(io_types.to_type_proto()) + if input_types is not None: + for input, type in zip(graph.input, input_types): + input.type.CopyFrom(type.to_type_proto()) + if output_types is not None: + for output, type in zip(graph.output, output_types): + output.type.CopyFrom(type.to_type_proto()) + + for k, v in kwargs.items(): + setattr(model_proto, k, v) + + return model_proto class TracedOnnxFunction(Op): @@ -758,7 +884,7 @@ class DynamicKind(IntFlag): class Dynamic(SymbolValue): def __init__( - self, onnx_var: str, kind: DynamicKind, info: sourceinfo.SourceInfo, typeinfo=None + self, onnx_var: ir.Value, kind: DynamicKind, info: sourceinfo.SourceInfo, typeinfo=None ) -> None: """Initializes Dynamic. @@ -770,6 +896,8 @@ def __init__( """ super().__init__(info) assert isinstance(kind, DynamicKind) + if not isinstance(onnx_var, ir.Value): + raise TypeError(f"onnx_var must be of type ir.Value not {type(onnx_var)!r}.") self.value = onnx_var self.kind = kind self.typeinfo = typeinfo diff --git a/tests/common/onnx_script_test_case.py b/tests/common/onnx_script_test_case.py index 3a46a870a0..e209579749 100644 --- a/tests/common/onnx_script_test_case.py +++ b/tests/common/onnx_script_test_case.py @@ -144,9 +144,7 @@ def _create_model_from_param( # there is not way from the onnx test case's model and feed to get TypeProto # in order to build a model. # we have to resolve the TypeProto from script function. - local_function_model_proto = param.function.function_ir.to_model_proto( - ir_version=ir_version - ) + local_function_model_proto = param.function.to_model_proto(ir_version=ir_version) input_value_infos = [] for i, input in enumerate(local_function_model_proto.graph.input): vi = copy.deepcopy(input) @@ -202,7 +200,7 @@ def run_converter_test( param, onnx_case_model, ir_version=ir_version ) else: - model = param.function.function_ir.to_model_proto( + model = param.function.to_model_proto( producer_name="call_clip", ir_version=ir_version ) try: