From 26010dabf49f6fa96852009637b2330ff69b5054 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 7 Nov 2025 11:55:08 -0800 Subject: [PATCH 01/43] Removable castable field Signed-off-by: Ganesan Ramalingam --- onnxscript/_internal/autocast.py | 4 ++-- onnxscript/converter.py | 16 +++++++++++----- 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index 1defac3e53..898c2be5f1 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -201,14 +201,14 @@ def get_type_info(x: Optional[converter.Variable]) -> Optional[converter.Variabl argument of CastLike) and None otherwise. In the expression "Add(X, 1), 1 is castable, while X can serve as the target-type. """ - return None if x is None or x.is_castable else x + return None if x is None or converter_.is_castable(x.name) else x def cast_like( x: Optional[converter.Variable], y: Optional[converter.Variable] ) -> Optional[str]: if x is None: return None - if x.is_castable and y is not None: + if converter_.is_castable(x.name) and y is not None: # Polymorphic constant x is cast to the type of y: x_cast = converter_.generate_unique_name(f"{x.name}_cast") converter_.emit([x_cast], "CastLike", [x.name, y.name]) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 3e87c366ad..9eb9a5d13d 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -87,7 +87,7 @@ class Variable: converter. """ - def __init__(self, name: str, castable: bool = False): + def __init__(self, name: str): """Initialize the instance. Args: @@ -98,7 +98,6 @@ def __init__(self, name: str, castable: bool = False): types as needed. """ self.name = name - self.is_castable = castable def __str__(self) -> str: return self.name @@ -184,6 +183,10 @@ def __init__( self._used_vars: set[str] = set() self._locals: List[Dict[str, LocalSymValue]] = [{}] self._analyzer: analysis.AstAnalyzer | None = None + self._castable: set[str] = set() + + def is_castable(self, var_name: str) -> bool: + return var_name in self._castable @property def analyzer(self) -> analysis.AstAnalyzer: @@ -358,8 +361,10 @@ def _to_onnx_var( [result], [cast_attr], ) - return Variable(result_as_bool, True) - return Variable(result, True) + self._castable.add(result_as_bool) + return Variable(result_as_bool) + self._castable.add(result) + return Variable(result) if isinstance(val, values.Dynamic): return Variable(val.value) # Assume value is a python-value convertible to a tensor @@ -421,7 +426,8 @@ def _emit_const( fail(info.msg(str(e))) attr = self._make_onnx_attr("value", tensor) self.emit([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) - return Variable(ovar, True) + self._castable.add(ovar) + return Variable(ovar) def _emit_copy(self, original_var: str, suggested_name: str) -> str: """Emits a copy statement, using the ONNX Identity operator.""" From e4d8b8f8d37b1a74748603a4fc589fdf177f1ade Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 7 Nov 2025 13:18:58 -0800 Subject: [PATCH 02/43] Prepare to replace Variable by ir.Value Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 60 ++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 27 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 9eb9a5d13d..ee69de25e5 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -83,8 +83,7 @@ def ignore(cond, msg): class Variable: """Represents an ONNX variable. - TODO(rama): Consider merging this with IRVar. However, "castable" is specific to this - converter. + TODO(rama): Consider merging this with IRVar. """ def __init__(self, name: str): @@ -186,6 +185,7 @@ def __init__( self._castable: set[str] = set() def is_castable(self, var_name: str) -> bool: + """Returns True if the variable with the given name represents a polymorphic constant.""" return var_name in self._castable @property @@ -346,25 +346,26 @@ def _to_onnx_var( ) -> Variable: if isinstance(val, values.AttrRef): # promote attribute to value - result = self.generate_unique_name(target or "tmp") + result_name = self.generate_unique_name(target or "tmp") attr = self._to_onnx_attr_ref(val, info) - self.emit([result], values.Op(self.default_opset, "Constant"), [], [attr]) + result = self.emit( + [result_name], values.Op(self.default_opset, "Constant"), [], [attr] + ) if ta.base_type_is_bool(val.typeinfo): # ONNX attributes use an int-encoding for bools, but ONNX tensor types # distinguish between int and bool. So we cast the int tensor to a bool tensor, # to promote a (python) bool attribute to a ONNX bool tensor. - result_as_bool = self.generate_unique_name(result + "_as_bool") + result_as_bool = self.generate_unique_name(result_name + "_as_bool") cast_attr = self._make_onnx_attr("to", onnx_types.BOOL.dtype) - self.emit( + self._castable.add(result_as_bool) + return self.emit1( [result_as_bool], values.Op(self.default_opset, "Cast"), [result], [cast_attr], ) - self._castable.add(result_as_bool) - return Variable(result_as_bool) - self._castable.add(result) - return Variable(result) + self._castable.add(result_name) + return result if isinstance(val, values.Dynamic): return Variable(val.value) # Assume value is a python-value convertible to a tensor @@ -382,7 +383,7 @@ def emit( inputs: Sequence[Optional[str]], attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None, sub_functions: Optional[dict[str, onnx.FunctionProto]] = None, - ): + ) -> Sequence[Variable] | Variable: if not isinstance(callee, values.Op): callee = values.Op(self.default_opset, callee) if attrs is None: @@ -397,6 +398,14 @@ def emit( attrs, sub_functions, ) + if len(outputs) == 1: + return Variable(outputs[0]) + return [Variable(o) for o in outputs] + + def emit1(self, *args, **kwargs) -> Variable: + r = self.emit(*args, **kwargs) + assert isinstance(r, Variable) + return r def _emit_const( self, @@ -425,9 +434,8 @@ def _emit_const( except ValueError as e: fail(info.msg(str(e))) attr = self._make_onnx_attr("value", tensor) - self.emit([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) self._castable.add(ovar) - return Variable(ovar) + return self.emit1([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) def _emit_copy(self, original_var: str, suggested_name: str) -> str: """Emits a copy statement, using the ONNX Identity operator.""" @@ -579,8 +587,7 @@ def _translate_expr( target = "tmp" if target is None else target assert isinstance(target, str) result = self.generate_unique_name(target) - self.emit([result], callee, args, attrs) - return Variable(result) + return self.emit1([result], callee, args, attrs) def _translate_opt_expr(self, node: ast.expr) -> Optional[Variable]: """Translation of an expression where "None" is permitted (eg., for an optional argument). @@ -726,8 +733,8 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: non_scalar_indices.append((axis, elt)) if not (sliced_indices or scalar_indices or non_scalar_indices): # Edge case: no index specified. Eg. A[:, :] - self.emit([target], "Identity", [var_name]) - return Variable(target) + return self.emit1([target], "Identity", [var_name]) + if sliced_indices or len(scalar_indices) > 1: # We emit a Slice operation if we have any indices like 1:5:2 or if the number of # scalar indices (like 2) is more than 1. @@ -789,20 +796,20 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: squeezed_axes = self._emit_const(squeezed_axes, "squeezed_axes", info) if non_scalar_indices: # use temporary to store result of squeeze - result = self.generate_unique_name(f"{var_name}_squeezed") + result_name = self.generate_unique_name(f"{var_name}_squeezed") else: # store squeezed result in final target - result = target + result_name = target - self.emit([result], "Squeeze", [sliced_name, squeezed_axes]) + result = self.emit([result_name], "Squeeze", [sliced_name, squeezed_axes]) else: if non_scalar_indices: # use temporary to store result of Slice - result = self.generate_unique_name(f"{var_name}_sliced") + result_name = self.generate_unique_name(f"{var_name}_sliced") else: # store result of Slice in final target - result = target + result_name = target slice_inputs = [var_name, start_name, end_name, axes_name, steps_name] - self.emit([result], "Slice", slice_inputs) + result = self.emit1([result_name], "Slice", slice_inputs) else: - result = var_name + result = var non_scalar_indices.extend(scalar_indices) if non_scalar_indices: last_axis, _ = non_scalar_indices[-1] @@ -818,10 +825,9 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: gathered = self.generate_unique_name(f"{var_name}_axis_{axis}") else: # store result of Gather in final target gathered = target - self.emit([gathered], "Gather", [str(result), index_value], [axis_attr]) - result = gathered + result = self.emit1([gathered], "Gather", [result.name, index_value], [axis_attr]) - return Variable(result) + return result def _translate_call_expr(self, node: ast.Call): """Translates a call-expression.""" From c8708efb09c555e39cd7a46533cc23eb8f4baa34 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 7 Nov 2025 17:36:54 -0800 Subject: [PATCH 03/43] Create ir.Values, first pass Signed-off-by: Ganesan Ramalingam --- onnxscript/_internal/autocast.py | 5 +- onnxscript/converter.py | 102 +++++++++++++------------------ 2 files changed, 45 insertions(+), 62 deletions(-) diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index 898c2be5f1..911ecbb024 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -211,8 +211,7 @@ def cast_like( if converter_.is_castable(x.name) and y is not None: # Polymorphic constant x is cast to the type of y: x_cast = converter_.generate_unique_name(f"{x.name}_cast") - converter_.emit([x_cast], "CastLike", [x.name, y.name]) - return x_cast - return x.name + return converter_.emit1([x_cast], "CastLike", [x, y]) + return x return cast_inputs(get_type_info, cast_like, op_schema, args) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index ee69de25e5..42c626167d 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -17,6 +17,7 @@ ) import onnx +import onnx_ir as ir import onnxscript from onnxscript import irbuilder, onnx_types, sourceinfo, values @@ -80,26 +81,7 @@ def ignore(cond, msg): } -class Variable: - """Represents an ONNX variable. - - TODO(rama): Consider merging this with IRVar. - """ - - def __init__(self, name: str): - """Initialize the instance. - - Args: - name: Name of the ONNX variable - castable: Whether this variable is castable to a desired target type. - Used for ONNX variables representing constants created from python values - like 0 or 1 or 0.5 which are treated as polymorphic values castable to other - types as needed. - """ - self.name = name - - def __str__(self) -> str: - return self.name +Variable = ir.Value if TYPE_CHECKING: @@ -380,10 +362,11 @@ def emit( self, outputs: Sequence[str], callee: values.Op | str, - inputs: Sequence[Optional[str]], + inputs: Sequence[Optional[Variable]], attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None, sub_functions: Optional[dict[str, onnx.FunctionProto]] = None, ) -> Sequence[Variable] | Variable: + assert all(isinstance(i, ir.Value) for i in inputs if i is not None) if not isinstance(callee, values.Op): callee = values.Op(self.default_opset, callee) if attrs is None: @@ -394,13 +377,13 @@ def emit( self._current_fn, outputs, callee, - inputs, + [x.name for x in inputs], attrs, sub_functions, ) if len(outputs) == 1: - return Variable(outputs[0]) - return [Variable(o) for o in outputs] + return ir.Value(name=outputs[0]) + return [ir.Value(name=o) for o in outputs] def emit1(self, *args, **kwargs) -> Variable: r = self.emit(*args, **kwargs) @@ -437,7 +420,7 @@ def _emit_const( self._castable.add(ovar) return self.emit1([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) - def _emit_copy(self, original_var: str, suggested_name: str) -> str: + def _emit_copy(self, original_var: Variable, suggested_name: str) -> str: """Emits a copy statement, using the ONNX Identity operator.""" new_var = self.generate_unique_name(suggested_name) self.emit([new_var], "Identity", [original_var]) @@ -646,15 +629,15 @@ def _translate_subscript_expr( # Create cached int constants: # TODO: Do this at a graph-scope level. - cached_int_consts = {} + cached_int_consts: dict[int, Variable] = {} - def const_1d(value, name: Optional[str] = None): + def const_1d(value, name: Optional[str] = None) -> Variable: nonlocal cached_int_consts if value not in cached_int_consts: cached_int_consts[value] = self._emit_const([value], name, info) return cached_int_consts[value] - def one_1d(): + def one_1d() -> Variable: return const_1d(1) # Max/min 64-bit int values are used to represent default values for start/stop in Slice. @@ -663,7 +646,7 @@ def one_1d(): def translate_slice_component( node_arg, default_value: Optional[int] = None - ) -> tuple[str, Optional[int]]: + ) -> tuple[Variable, Optional[int]]: """Translate optional start/stop/step component of a Slice expression.""" if node_arg is None: if default_value is None: @@ -682,15 +665,15 @@ def translate_slice_component( else: name = self._translate_expr(node_arg).name reshaped = self.generate_unique_name(f"{name}_reshaped") - self.emit( + reshaped_value = self.emit1( [reshaped], values.Op(self.default_opset, "Reshape"), [name, one_1d().name], [], ) - return reshaped, None + return reshaped_value, None - def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: + def translate_slice(slice_expr: ast.Slice) -> tuple[Variable, Variable, Variable]: """Translate slice-expression of the form from:to:step.""" step_name, step = translate_slice_component(slice_expr.step, 1) if step is None: @@ -764,7 +747,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: inputs = translate_slice(element) starts.append(inputs[0]) ends.append(inputs[1]) - axes.append(axis_var.name) + axes.append(axis_var) steps.append(inputs[2]) if len(starts) > 1: @@ -788,10 +771,10 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: if squeezed_axes: sliced_name = self.generate_unique_name(f"{var_name}_sliced") - self.emit( + sliced_value = self.emit( [sliced_name], "Slice", - [var_name, start_name, end_name, axes_name, steps_name], + [var, start_name, end_name, axes_name, steps_name], ) squeezed_axes = self._emit_const(squeezed_axes, "squeezed_axes", info) @@ -800,7 +783,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: else: # store squeezed result in final target result_name = target - result = self.emit([result_name], "Squeeze", [sliced_name, squeezed_axes]) + result = self.emit([result_name], "Squeeze", [sliced_value, squeezed_axes]) else: if non_scalar_indices: # use temporary to store result of Slice result_name = self.generate_unique_name(f"{var_name}_sliced") @@ -829,7 +812,9 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[str, str, str]: return result - def _translate_call_expr(self, node: ast.Call): + def _translate_call_expr( + self, node: ast.Call + ) -> tuple[values.Op, list[Optional[Variable]], list[irbuilder.IRAttributeValue]]: """Translates a call-expression.""" callee = self._translate_callee_expr(node.func) param_schemas = callee.param_schemas() @@ -856,7 +841,7 @@ def _translate_call_expr(self, node: ast.Call): attrs = [attr for attr in attrs if attr is not None] return callee, args, attrs - def _cast_like_binary_expression(self, op, left, right): + def _cast_like_binary_expression(self, op, left, right) -> tuple[Variable, Variable]: schema = op.op_schema return autocast.static_cast_inputs(self, schema, (left, right)) @@ -1081,14 +1066,14 @@ def check_num_outputs(n): def ret(exp, i, suffix): preferred_name = f"return_val{suffix}" - return_var = self._translate_expr(exp, preferred_name).name - val = self._lookup(return_var, self._source_of(exp), False) + return_var = self._translate_expr(exp, preferred_name) # TODO(rama) + val = self._lookup(return_var.name, self._source_of(exp), False) if val and val.kind == values.DynamicKind.Input: # In ONNX, a graph-input cannot be an output of the graph. # We need to insert a copy. return_var = self._emit_copy(return_var, preferred_name) for prev_output in self._current_fn.outputs: - if prev_output.name == return_var: + if prev_output.name == return_var.name: # ONNX does not allow duplicate output names. return_var = self._emit_copy(return_var, f"{return_var}_copy") break @@ -1126,7 +1111,7 @@ def _translate_if_stmt(self, stmt: ast.If) -> None: # due to some existing usage. live_def_set = live_out.intersection(live_def_set) live_defs = list(live_def_set) - test = self._translate_expr(stmt.test, "cond").name + test = self._translate_expr(stmt.test, "cond") lineno = self._source_of(stmt).lineno thenGraph, sub_fct_then = self._translate_block( stmt.body, f"thenGraph_{lineno}", live_defs, parent_stmt=stmt @@ -1153,7 +1138,7 @@ def rename(x): sub_functions = {} sub_functions.update(sub_fct_then) sub_functions.update(sub_fct_else) - if renamed == [test]: + if renamed == [test.name]: self.fail(stmt, f"Input and output cannot be the same {renamed!r}.") self.emit( renamed, @@ -1181,11 +1166,11 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: if not iter.args or len(iter.args) != 1: self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.") assert not iter.keywords, "Unsupported loop bound." - o_loop_bound = self._translate_expr(iter.args[0], "loop_bound").name - o_cond_var = self.generate_unique_name("cond_in") + o_loop_bound = self._translate_expr(iter.args[0], "loop_bound") + o_cond_var = ir.Value(self.generate_unique_name("cond_in")) # TODO(Rama) i_cond_var = o_cond_var cond_while = None - o_loop_condition = "" # No condition for a for loop. + o_loop_condition = None # No condition for a for loop. elif isinstance(loop_stmt, ast.While): test = loop_stmt.test if not isinstance(test, ast.Name): @@ -1195,9 +1180,9 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: "it should be 'while :'.", ) p_loop_var = "infinite_loop" - o_loop_bound = "" - i_cond_var = test.id - cond_while = test.id + o_loop_bound = None + i_cond_var = ir.Value(test.id) # TODO(Rama) + cond_while = ir.Value(test.id) # TODO(Rama) o_cond_var = None o_loop_condition = self._translate_name_expr(test) # we need to go through all the instructions to see @@ -1250,7 +1235,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: values.Dynamic(ov, values.DynamicKind.Loop, self._source_of(loop_stmt)), ) - condition_name = None + condition_name: Variable | None = None operator_name = "Identity" for i, s in enumerate(loop_stmt.body): # We first need to intercept a break instruction in test block. @@ -1306,8 +1291,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: self._source_of(loop_stmt), ) for pv in loop_state_vars: - ov = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)).name - if ov not in self._current_fn.assigned_names: + ov = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) + if ov.name not in self._current_fn.assigned_names: # When converting the loop-body into a graph, we need to handle # identity assignments of the form "x = y" inside the loop body # specially if y represents a value computed outside the loop body. @@ -1317,12 +1302,11 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: # TODO: retrieve variable type for the annotation if any. typeinfo = None self.ir_builder.add_output( - self._current_fn, ov, typeinfo, self._source_of(loop_stmt) + self._current_fn, ov.name, typeinfo, self._source_of(loop_stmt) ) body = self._exit_scope() inputs = [o_loop_bound, o_loop_condition] + [ - self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)).name - for pv in loop_state_vars + self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars ] graph, sub_functions = body.to_graph_and_functions() attrs = [self._make_onnx_attr("body", graph)] @@ -1358,14 +1342,14 @@ def _translate_block( for pvar in live_defs: if pvar in self._current_scope(): pv_val = self._current_scope()[pvar] - output = self._to_onnx_var(pv_val, pvar).name - if output not in self._current_fn.assigned_names: + output = self._to_onnx_var(pv_val, pvar) + if output.name not in self._current_fn.assigned_names: # To return an outer-scope variable, an ONNX Graph has to # use an explicit copy via Identity. output = self._emit_copy(output, pvar) self.ir_builder.add_output( self._current_fn, - output, + output.name, pv_val.typeinfo, source, ) @@ -1382,7 +1366,7 @@ def _translate_block( f"branch, known variables: {list(self._locals)}.", ) # introduce a copy - ovar = self._emit_copy(self._to_onnx_var(pv_val, pvar).name, pvar) + ovar = self._emit_copy(self._to_onnx_var(pv_val, pvar), pvar) # TODO: retrieve the annotation if any. typeinfo = None From 2738841a93c5f6979924e6badc6f3c415bf0bfed Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 7 Nov 2025 17:47:53 -0800 Subject: [PATCH 04/43] A minor fix --- onnxscript/converter.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 42c626167d..97b442b0db 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -377,7 +377,7 @@ def emit( self._current_fn, outputs, callee, - [x.name for x in inputs], + [(x.name if x is not None else None) for x in inputs], attrs, sub_functions, ) @@ -808,7 +808,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[Variable, Variable, Variable gathered = self.generate_unique_name(f"{var_name}_axis_{axis}") else: # store result of Gather in final target gathered = target - result = self.emit1([gathered], "Gather", [result.name, index_value], [axis_attr]) + result = self.emit1([gathered], "Gather", [result, index_value], [axis_attr]) return result @@ -1081,7 +1081,7 @@ def ret(exp, i, suffix): t = None else: t = self.returntype[i] - self.ir_builder.add_output(self._current_fn, return_var, t, self._source_of(stmt)) + self.ir_builder.add_output(self._current_fn, return_var.name, t, self._source_of(stmt)) return return_var val = stmt.value From 0e8efef99fadb691bfbc565629daf62be2ef23f1 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 7 Nov 2025 22:45:27 -0800 Subject: [PATCH 05/43] Various bug fixes Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 95 ++++++++++++++++++++++------------------- onnxscript/evaluator.py | 2 +- onnxscript/values.py | 5 ++- 3 files changed, 57 insertions(+), 45 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 97b442b0db..a9731bb422 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -349,7 +349,7 @@ def _to_onnx_var( self._castable.add(result_name) return result if isinstance(val, values.Dynamic): - return Variable(val.value) + return val.value # Assume value is a python-value convertible to a tensor # TODO: check if value is convertible to a TensorProto, so that we can # produce a better error _message otherwise @@ -366,7 +366,9 @@ def emit( attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None, sub_functions: Optional[dict[str, onnx.FunctionProto]] = None, ) -> Sequence[Variable] | Variable: - assert all(isinstance(i, ir.Value) for i in inputs if i is not None) + for i, x in enumerate(inputs): + if (x is not None) and not isinstance(x, ir.Value): + raise TypeError(f"Expected ONNX IR Value for input {i}, got {type(x)!r}.") if not isinstance(callee, values.Op): callee = values.Op(self.default_opset, callee) if attrs is None: @@ -420,11 +422,10 @@ def _emit_const( self._castable.add(ovar) return self.emit1([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) - def _emit_copy(self, original_var: Variable, suggested_name: str) -> str: + def _emit_copy(self, original_var: Variable, suggested_name: str) -> Variable: """Emits a copy statement, using the ONNX Identity operator.""" new_var = self.generate_unique_name(suggested_name) - self.emit([new_var], "Identity", [original_var]) - return new_var + return self.emit([new_var], "Identity", [original_var]) def _is_constant_expr(self, node: ast.AST) -> None: if isinstance(node, ast.UnaryOp): @@ -663,12 +664,12 @@ def translate_slice_component( else: raise RuntimeError(f"Slice component type must be int, not {type(cst)}") else: - name = self._translate_expr(node_arg).name - reshaped = self.generate_unique_name(f"{name}_reshaped") + value = self._translate_expr(node_arg) + reshaped = self.generate_unique_name(f"{value.name}_reshaped") reshaped_value = self.emit1( [reshaped], values.Op(self.default_opset, "Reshape"), - [name, one_1d().name], + [value, one_1d()], [], ) return reshaped_value, None @@ -753,28 +754,28 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[Variable, Variable, Variable if len(starts) > 1: axis_0_attr = self._make_onnx_attr("axis", 0) start_name = self.generate_unique_name(f"{var_name}_start") - self.emit([start_name], "Concat", starts, [axis_0_attr]) + start_value = self.emit([start_name], "Concat", starts, [axis_0_attr]) end_name = self.generate_unique_name(f"{var_name}_end") - self.emit([end_name], "Concat", ends, [axis_0_attr]) + end_value = self.emit([end_name], "Concat", ends, [axis_0_attr]) axes_name = self.generate_unique_name(f"{var_name}_axis") - self.emit([axes_name], "Concat", axes, [axis_0_attr]) + axes_value = self.emit([axes_name], "Concat", axes, [axis_0_attr]) steps_name = self.generate_unique_name(f"{var_name}_step") - self.emit([steps_name], "Concat", steps, [axis_0_attr]) + steps_value = self.emit([steps_name], "Concat", steps, [axis_0_attr]) else: - start_name = starts[0] - end_name = ends[0] - axes_name = axes[0] - steps_name = steps[0] + start_value = starts[0] + end_value = ends[0] + axes_value = axes[0] + steps_value = steps[0] if squeezed_axes: sliced_name = self.generate_unique_name(f"{var_name}_sliced") sliced_value = self.emit( [sliced_name], "Slice", - [var, start_name, end_name, axes_name, steps_name], + [var, start_value, end_value, axes_value, steps_value], ) squeezed_axes = self._emit_const(squeezed_axes, "squeezed_axes", info) @@ -789,7 +790,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[Variable, Variable, Variable result_name = self.generate_unique_name(f"{var_name}_sliced") else: # store result of Slice in final target result_name = target - slice_inputs = [var_name, start_name, end_name, axes_name, steps_name] + slice_inputs = [var, start_value, end_value, axes_value, steps_value] result = self.emit1([result_name], "Slice", slice_inputs) else: result = var @@ -993,7 +994,7 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None: # Assignments of the form "x = SomeExpression" info = self._source_of(lhs) lhs = lhs.id - t = self._translate_expr(rhs, lhs).name + t = self._translate_expr(rhs, lhs) if isinstance(stmt, ast.AnnAssign): typeinfo = self._eval_constant_expr(stmt.annotation) else: @@ -1012,17 +1013,19 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None: def generate_onnx_name(x: ast.AST): if not isinstance(x, ast.Name): self.fail(x, f"LHS must be a Name for unpacking, found: '{type(x)!r}'") - onnx_name = self.generate_unique_name(x.id) + return self.generate_unique_name(x.id) + + output_names = [generate_onnx_name(x) for x in lhs.elts] + outputs = self.emit(output_names, callee, inputs, attrs) + if isinstance(outputs, ir.Value): + outputs = [outputs] + for x, output in zip(lhs.elts, outputs): self._bind( x.id, values.Dynamic( - onnx_name, values.DynamicKind.Intermediate, self._source_of(x) + output, values.DynamicKind.Intermediate, self._source_of(x) ), ) - return onnx_name - - outputs = [generate_onnx_name(x) for x in lhs.elts] - self.emit(outputs, callee, inputs, attrs) else: self.fail(lhs, f"Unsupported construct in LHS of assignment: '{type(lhs)!r}'") @@ -1123,12 +1126,7 @@ def _translate_if_stmt(self, stmt: ast.If) -> None: elseAttr = self._make_onnx_attr("else_branch", elseGraph) def rename(x): - r = self.generate_unique_name(x) - self._bind( - x, - values.Dynamic(r, values.DynamicKind.Intermediate, self._source_of(stmt)), - ) - return r + return self.generate_unique_name(x) # no break condition renamed = [rename(x) for x in live_defs] @@ -1140,13 +1138,21 @@ def rename(x): sub_functions.update(sub_fct_else) if renamed == [test.name]: self.fail(stmt, f"Input and output cannot be the same {renamed!r}.") - self.emit( + if_outputs = self.emit( renamed, values.Op(self.default_opset, "If"), [test], [thenAttr, elseAttr], sub_functions=sub_functions, ) + if isinstance(if_outputs, ir.Value): + if_outputs = [if_outputs] + for x, y in zip(live_defs, if_outputs): + self._bind( + x, + values.Dynamic(y, values.DynamicKind.Intermediate, self._source_of(stmt)), + ) + def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: # loop-variable @@ -1167,7 +1173,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: self.fail(loop_stmt, "Unsupported loop bound, it should be 'range(?)'.") assert not iter.keywords, "Unsupported loop bound." o_loop_bound = self._translate_expr(iter.args[0], "loop_bound") - o_cond_var = ir.Value(self.generate_unique_name("cond_in")) # TODO(Rama) + o_cond_var = ir.Value(name=self.generate_unique_name("cond_in")) # TODO(Rama) i_cond_var = o_cond_var cond_while = None o_loop_condition = None # No condition for a for loop. @@ -1181,8 +1187,8 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ) p_loop_var = "infinite_loop" o_loop_bound = None - i_cond_var = ir.Value(test.id) # TODO(Rama) - cond_while = ir.Value(test.id) # TODO(Rama) + i_cond_var = ir.Value(name=test.id) # TODO(Rama) + cond_while = ir.Value(name=test.id) # TODO(Rama) o_cond_var = None o_loop_condition = self._translate_name_expr(test) # we need to go through all the instructions to see @@ -1212,12 +1218,12 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ) self._bind( p_loop_var, - values.Dynamic(o_loop_var, values.DynamicKind.Loop, self._source_of(loop_stmt)), + values.Dynamic(ir.Value(name=o_loop_var), values.DynamicKind.Loop, self._source_of(loop_stmt)), ) self.ir_builder.add_input( self._current_fn, - i_cond_var, + i_cond_var.name, onnx_types.BOOL, self._source_of(loop_stmt), ) @@ -1232,7 +1238,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ) self._bind( pv, - values.Dynamic(ov, values.DynamicKind.Loop, self._source_of(loop_stmt)), + values.Dynamic(ir.Value(name=ov), values.DynamicKind.Loop, self._source_of(loop_stmt)), ) condition_name: Variable | None = None @@ -1314,17 +1320,20 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: def rename(x): r = self.generate_unique_name(x) - self._bind(x, values.Dynamic(r, values.DynamicKind.Output, info)) return r - onnx_outputs = [rename(x) for x in outputs] - self.emit( - onnx_outputs, + onnx_output_names = [rename(x) for x in outputs] + loop_outputs = self.emit( + onnx_output_names, "Loop", inputs, attrs, sub_functions=sub_functions, ) + if isinstance(loop_outputs, ir.Value): + loop_outputs = [loop_outputs] + for x, loop_output in zip(outputs, loop_outputs): + self._bind(x, values.Dynamic(loop_output, values.DynamicKind.Output, info)) def _translate_block( self, @@ -1428,7 +1437,7 @@ def _translate_function_signature_common( self._used_vars.add(x.arg) self._bind( x.arg, - values.Dynamic(x.arg, values.DynamicKind.Input, self._source_of(x)), + values.Dynamic(ir.Value(name=x.arg), values.DynamicKind.Input, self._source_of(x)), ) if fn.returns: type_annotation = self._eval_constant_expr(fn.returns) diff --git a/onnxscript/evaluator.py b/onnxscript/evaluator.py index 1d87ee135e..a644108a78 100644 --- a/onnxscript/evaluator.py +++ b/onnxscript/evaluator.py @@ -217,7 +217,7 @@ def adapt_attributes( if use_graph_attribute: adapted_attributes[k] = v.function_ir.to_graph_proto() for pyvar, onnxvar in v.function_ir.outer_scope_variables: - closure[onnxvar.value] = v.frame.f_locals[pyvar] + closure[onnxvar.value.name] = v.frame.f_locals[pyvar] else: adapted_attributes[k] = v.function elif callable(v): diff --git a/onnxscript/values.py b/onnxscript/values.py index 1897ae14d5..deeca21e58 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -23,6 +23,7 @@ import onnx import onnx.defs +import onnx_ir as ir from typing_extensions import ParamSpec from onnxscript import converter as converter_module @@ -758,7 +759,7 @@ class DynamicKind(IntFlag): class Dynamic(SymbolValue): def __init__( - self, onnx_var: str, kind: DynamicKind, info: sourceinfo.SourceInfo, typeinfo=None + self, onnx_var: ir.Value, kind: DynamicKind, info: sourceinfo.SourceInfo, typeinfo=None ) -> None: """Initializes Dynamic. @@ -770,6 +771,8 @@ def __init__( """ super().__init__(info) assert isinstance(kind, DynamicKind) + if not isinstance(onnx_var, ir.Value): + raise TypeError(f"onnx_var must be of type ir.Value not {type(onnx_var)!r}.") self.value = onnx_var self.kind = kind self.typeinfo = typeinfo From a577f1a94c9cfa81e720ce9d08922c81f3e8c92f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sat, 8 Nov 2025 06:27:05 -0800 Subject: [PATCH 06/43] Final cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 31 +++++++++++++++++++------------ 1 file changed, 19 insertions(+), 12 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index a9731bb422..a17ed3317c 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -368,7 +368,7 @@ def emit( ) -> Sequence[Variable] | Variable: for i, x in enumerate(inputs): if (x is not None) and not isinstance(x, ir.Value): - raise TypeError(f"Expected ONNX IR Value for input {i}, got {type(x)!r}.") + raise TypeError(f"Expected ONNX IR Value for input {i}, got {type(x)!r}.") if not isinstance(callee, values.Op): callee = values.Op(self.default_opset, callee) if attrs is None: @@ -909,9 +909,9 @@ def _translate_compare_expr(self, node): left, right = self._cast_like_binary_expression(op, left, right) if opname == "NotEqual": tmp = self.generate_unique_name() - self.emit([tmp], op, [left, right]) + tmp_value = self.emit1([tmp], op, [left, right]) not_op = values.Op(self.default_opset, "Not") - return not_op, [tmp], [] + return not_op, [tmp_value], [] return op, [left, right], [] @@ -1084,7 +1084,9 @@ def ret(exp, i, suffix): t = None else: t = self.returntype[i] - self.ir_builder.add_output(self._current_fn, return_var.name, t, self._source_of(stmt)) + self.ir_builder.add_output( + self._current_fn, return_var.name, t, self._source_of(stmt) + ) return return_var val = stmt.value @@ -1152,7 +1154,6 @@ def rename(x): x, values.Dynamic(y, values.DynamicKind.Intermediate, self._source_of(stmt)), ) - def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: # loop-variable @@ -1218,7 +1219,9 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ) self._bind( p_loop_var, - values.Dynamic(ir.Value(name=o_loop_var), values.DynamicKind.Loop, self._source_of(loop_stmt)), + values.Dynamic( + ir.Value(name=o_loop_var), values.DynamicKind.Loop, self._source_of(loop_stmt) + ), ) self.ir_builder.add_input( @@ -1238,7 +1241,9 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ) self._bind( pv, - values.Dynamic(ir.Value(name=ov), values.DynamicKind.Loop, self._source_of(loop_stmt)), + values.Dynamic( + ir.Value(name=ov), values.DynamicKind.Loop, self._source_of(loop_stmt) + ), ) condition_name: Variable | None = None @@ -1275,13 +1280,13 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: if cond_while is not None: # Loop while current_scope = self._current_scope() - if cond_while not in current_scope: + if cond_while.name not in current_scope: self.fail( loop_stmt, - f"Unable to find condition variable {cond_while!r} in known " + f"Unable to find condition variable {cond_while.name} in known " f"variables {list(current_scope)!r}.", ) - o_cond_var = current_scope[cond_while].value + o_cond_var = current_scope[cond_while.name].value self.emit( [o_cond_out], @@ -1379,7 +1384,7 @@ def _translate_block( # TODO: retrieve the annotation if any. typeinfo = None - self.ir_builder.add_output(self._current_fn, ovar, typeinfo, source) + self.ir_builder.add_output(self._current_fn, ovar.name, typeinfo, source) graph = self._exit_scope() return graph.to_graph_and_functions() @@ -1437,7 +1442,9 @@ def _translate_function_signature_common( self._used_vars.add(x.arg) self._bind( x.arg, - values.Dynamic(ir.Value(name=x.arg), values.DynamicKind.Input, self._source_of(x)), + values.Dynamic( + ir.Value(name=x.arg), values.DynamicKind.Input, self._source_of(x) + ), ) if fn.returns: type_annotation = self._eval_constant_expr(fn.returns) From dbda8e4270dc1a2a7cddf499c62b446e2f5170c2 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sat, 8 Nov 2025 08:09:53 -0800 Subject: [PATCH 07/43] Fix matmul fusion testcase Signed-off-by: Ganesan Ramalingam --- .../rewriter/ort_fusions/fused_matmul_rule_sets_test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py index f82702d557..43033b9b4f 100644 --- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py @@ -66,7 +66,7 @@ def _run( @script() def _fused_matmul_div(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]: - C = 0.6 + C = op.Constant(value_float=0.6) ab = ms_op.FusedMatMul(A, B, alpha=0.4, transA=1) out = op.Div(ab, C) return out @@ -74,7 +74,7 @@ def _fused_matmul_div(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]: @script() def _matmul_div(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]: - C = 0.8 + C = op.Constant(value_float=0.8) ab = op.MatMul(A, B) out = op.Div(ab, C) return out @@ -82,7 +82,7 @@ def _matmul_div(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]: @script() def _matmul_div_div(A: FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]: - C = 0.6 + C = op.Constant(value_float=0.6) ab = op.MatMul(A, B) abd = op.Div(ab, C) out = op.Div(abd, C) From 85937ed6f04f3fd278589a5357d877879aa43366 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sat, 8 Nov 2025 14:25:32 -0800 Subject: [PATCH 08/43] First updates to builder Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 11 +++++------ onnxscript/irbuilder.py | 10 +++++++--- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index a17ed3317c..491931f759 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -375,21 +375,20 @@ def emit( attrs = [] if sub_functions is None: sub_functions = {} - self.ir_builder.add_stmt( + output_values = self.ir_builder.add_stmt( self._current_fn, outputs, callee, - [(x.name if x is not None else None) for x in inputs], + inputs, attrs, sub_functions, ) - if len(outputs) == 1: - return ir.Value(name=outputs[0]) - return [ir.Value(name=o) for o in outputs] + return output_values if len(output_values) > 1 else output_values[0] def emit1(self, *args, **kwargs) -> Variable: r = self.emit(*args, **kwargs) - assert isinstance(r, Variable) + if not isinstance(r, ir.Value): + raise TypeError(f"Expected single ONNX IR Value, got {type(r)!r}.") return r def _emit_const( diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 4274bf2062..f109f8ee45 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -10,6 +10,7 @@ from typing import Any, Optional, Protocol, Sequence, Union import onnx +import onnx_ir as ir from onnx import ValueInfoProto, helper from onnx.defs import onnx_opset_version @@ -525,12 +526,15 @@ def add_stmt( fn: IRFunction, results: Sequence[str], callee: values.Op, - args: Sequence[Optional[str]], + inputs: Sequence[Optional[ir.Value]], attrs: Sequence[IRAttributeValue], sub_functions=None, - ) -> None: - stmt = IRStmt(results, callee, args, attrs, sub_functions=sub_functions) + ) -> Sequence[ir.Value]: + input_names = [(x.name if x is not None else None) for x in inputs] + stmt = IRStmt(results, callee, input_names, attrs, sub_functions=sub_functions) fn.append_stmt(stmt) + output_values = [ir.Value(name=o) for o in results] + return output_values def add_input( self, fn: IRFunction, varname: str, type: IRTypeLike, info: SourceInfo From 9534b951c7b721437bccac3f185e73b0f084acec Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 10 Nov 2025 20:30:33 -0800 Subject: [PATCH 09/43] Create nodes --- onnxscript/irbuilder.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index f109f8ee45..65ccf56cc5 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -185,6 +185,7 @@ def attr_proto(self) -> onnx.AttributeProto: class IRStmt: def __init__( self, + node: ir.Node, result: Sequence[str], callee: values.Op, args: Sequence[Optional[str]], @@ -193,6 +194,7 @@ def __init__( ) -> None: if not isinstance(callee, values.Op): raise TypeError(f"Unexpected type {type(callee)} for callee.") + self.node = node self.result = result self.callee = callee self.args = args @@ -530,8 +532,18 @@ def add_stmt( attrs: Sequence[IRAttributeValue], sub_functions=None, ) -> Sequence[ir.Value]: + output_values = [ir.Value(name=o) for o in results] input_names = [(x.name if x is not None else None) for x in inputs] - stmt = IRStmt(results, callee, input_names, attrs, sub_functions=sub_functions) + attributes = [ir.from_proto(a.attr_proto) for a in attrs] + node = ir.Node( + domain=callee.opset.domain, + version=callee.opset.version, + op_type=callee.name, + inputs = inputs, + outputs=output_values, + attributes=attributes, + ) + stmt = IRStmt(node, results, callee, input_names, attrs, sub_functions=sub_functions) fn.append_stmt(stmt) output_values = [ir.Value(name=o) for o in results] return output_values From 957be40dd35dbdb50dbb340660da345edde3a4d7 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 14 Nov 2025 16:12:32 -0800 Subject: [PATCH 10/43] Adding ir Graph step 1 Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 4 ++-- onnxscript/irbuilder.py | 9 +++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 491931f759..4e5c45ee95 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -1318,7 +1318,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: inputs = [o_loop_bound, o_loop_condition] + [ self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars ] - graph, sub_functions = body.to_graph_and_functions() + graph, sub_functions = body._to_graph_and_functions() attrs = [self._make_onnx_attr("body", graph)] info = self._source_of(loop_stmt) @@ -1385,7 +1385,7 @@ def _translate_block( typeinfo = None self.ir_builder.add_output(self._current_fn, ovar.name, typeinfo, source) graph = self._exit_scope() - return graph.to_graph_and_functions() + return graph._to_graph_and_functions() def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None: """Translate a nested function definition.""" diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 65ccf56cc5..60ba8f09ac 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -241,6 +241,7 @@ class IRFunction: """Represents a function in the IR.""" def __init__(self, name: str, domain: str = "") -> None: + self.ir_graph = ir.Graph(inputs=[], outputs=[], nodes=[], name=name) self.domain = domain self.name = name self.outputs: list[IRVar] = [] @@ -355,7 +356,7 @@ def to_model_proto( if value_infos else None ) - graph, sub_functions = self.to_graph_and_functions( + graph, sub_functions = self._to_graph_and_functions( use_default_type=False, value_infos=value_infos ) if io_types is not None: @@ -412,7 +413,7 @@ def to_proto(f): graph, opset_imports=opset_imports, functions=functions, **kwargs ) - def to_graph_and_functions( + def _to_graph_and_functions( self, use_default_type: bool = True, value_infos: Sequence[ValueInfoProto] | None = None, @@ -452,7 +453,7 @@ def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto: Returns: an instance of :class:`onnx.GraphProto` """ - graph, _ = self.to_graph_and_functions(use_default_type=use_default_type) + graph, _ = self._to_graph_and_functions(use_default_type=use_default_type) return graph def get_opset_import(self) -> dict[str, int]: @@ -539,7 +540,7 @@ def add_stmt( domain=callee.opset.domain, version=callee.opset.version, op_type=callee.name, - inputs = inputs, + inputs=inputs, outputs=output_values, attributes=attributes, ) From 09a0c165c3c6a5574cbc94b3d54811f5de16a72b Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 14 Nov 2025 16:28:41 -0800 Subject: [PATCH 11/43] IRFunction cleanup docstring Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 60ba8f09ac..9253172c81 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -247,12 +247,16 @@ def __init__(self, name: str, domain: str = "") -> None: self.outputs: list[IRVar] = [] self.stmts: list[IRStmt] = [] self.called_functions: dict[str, onnx.FunctionProto] = {} - self.docstring: str = "" # a dictionary of nested function-definitions self.nested_functions: dict[str, IRFunction] = {} self.outer_scope_variables: dict[Any, Any] = {} self.ordered_inputs_and_attrs: list[Union[IRVar, IRAttributeParameter]] = [] + @property + def docstring(self) -> str: + """Returns the docstring of this function.""" + return self.ir_graph.doc_string or "" + @property def assigned_names(self) -> Sequence[str]: """Returns the list of variables assigned to by this function.""" @@ -277,9 +281,6 @@ def __str__(self): stmts = _format(self.stmts, "\n{\n ", "\n ", "\n}\n") return f"{self.name} {attrs}{inputs} => {outputs}{stmts}" - def append_docstring(self, docstring): - self.docstring += docstring - def append_stmt(self, stmt: IRStmt) -> None: self.stmts.append(stmt) @@ -522,7 +523,7 @@ def new_function(self, name: str, domain: str = "", register: bool = False) -> I return function def add_docstring(self, fn: IRFunction, docstring: str): - fn.append_docstring(docstring) + fn.ir_graph.doc_string = docstring def add_stmt( self, From 1d73d52b269f5b8515dedbe2c22d5f8d9537bdf7 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 14 Nov 2025 17:18:30 -0800 Subject: [PATCH 12/43] Further partial cleanup of IR Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 9253172c81..11aad3beb0 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -186,25 +186,23 @@ class IRStmt: def __init__( self, node: ir.Node, - result: Sequence[str], callee: values.Op, - args: Sequence[Optional[str]], attrs: Sequence[IRAttributeValue], sub_functions=None, ) -> None: if not isinstance(callee, values.Op): raise TypeError(f"Unexpected type {type(callee)} for callee.") self.node = node - self.result = result self.callee = callee - self.args = args self.attrs = attrs self.functions = sub_functions or {} + @property + def args(self) -> Sequence[Optional[str]]: + return [x.name if x is not None else None for x in self.node.inputs] + def __str__(self): - if isinstance(self.result, str): - logger.debug("unexpected str type for self.result where type(self)=%r", type(self)) - lhs = ", ".join(self.result) + lhs = ", ".join(self.output_names) attrs = "" if self.attrs: attrs = _format(self.attrs, "<", ", ", ">") @@ -223,7 +221,7 @@ def to_node_proto(self, node_name: str) -> onnx.NodeProto: n = helper.make_node( self.callee.name, [_opt_var_to_str(x) for x in self.args], - [str(x) for x in self.result], + self.output_names, domain=self.callee.opset.domain, name=node_name, ) @@ -234,7 +232,8 @@ def to_node_proto(self, node_name: str) -> onnx.NodeProto: @property def output_names(self) -> Sequence[str]: """Returns the list of variables assigned to by this statement.""" - return [str(x) for x in self.result] + return [x.name for x in self.node.outputs] + class IRFunction: @@ -243,7 +242,6 @@ class IRFunction: def __init__(self, name: str, domain: str = "") -> None: self.ir_graph = ir.Graph(inputs=[], outputs=[], nodes=[], name=name) self.domain = domain - self.name = name self.outputs: list[IRVar] = [] self.stmts: list[IRStmt] = [] self.called_functions: dict[str, onnx.FunctionProto] = {} @@ -257,6 +255,11 @@ def docstring(self) -> str: """Returns the docstring of this function.""" return self.ir_graph.doc_string or "" + @property + def name(self) -> str: + """Returns the name of this function.""" + return self.ir_graph.name + @property def assigned_names(self) -> Sequence[str]: """Returns the list of variables assigned to by this function.""" @@ -535,7 +538,6 @@ def add_stmt( sub_functions=None, ) -> Sequence[ir.Value]: output_values = [ir.Value(name=o) for o in results] - input_names = [(x.name if x is not None else None) for x in inputs] attributes = [ir.from_proto(a.attr_proto) for a in attrs] node = ir.Node( domain=callee.opset.domain, @@ -545,7 +547,7 @@ def add_stmt( outputs=output_values, attributes=attributes, ) - stmt = IRStmt(node, results, callee, input_names, attrs, sub_functions=sub_functions) + stmt = IRStmt(node, callee, attrs, sub_functions=sub_functions) fn.append_stmt(stmt) output_values = [ir.Value(name=o) for o in results] return output_values From ddce3b3302561b7b00168e520c5c05f9a813a3de Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 14 Nov 2025 17:46:41 -0800 Subject: [PATCH 13/43] Ir builder cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 11aad3beb0..81ec482060 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -119,6 +119,8 @@ class IRAttributeValue: """ def __init__(self, attrproto: onnx.AttributeProto) -> None: + if not isinstance(attrproto, onnx.AttributeProto): + raise TypeError(f"Expected onnx.AttributeProto not {type(attrproto)!r}.") self.attr_proto = attrproto def __str__(self): @@ -187,20 +189,22 @@ def __init__( self, node: ir.Node, callee: values.Op, - attrs: Sequence[IRAttributeValue], sub_functions=None, ) -> None: if not isinstance(callee, values.Op): raise TypeError(f"Unexpected type {type(callee)} for callee.") self.node = node self.callee = callee - self.attrs = attrs self.functions = sub_functions or {} @property def args(self) -> Sequence[Optional[str]]: return [x.name if x is not None else None for x in self.node.inputs] - + + @property + def attrs(self) -> Sequence[IRAttributeValue]: + return [IRAttributeValue(ir.to_proto(a)) for a in self.node.attributes.values()] + def __str__(self): lhs = ", ".join(self.output_names) attrs = "" @@ -233,7 +237,6 @@ def to_node_proto(self, node_name: str) -> onnx.NodeProto: def output_names(self) -> Sequence[str]: """Returns the list of variables assigned to by this statement.""" return [x.name for x in self.node.outputs] - class IRFunction: @@ -259,7 +262,7 @@ def docstring(self) -> str: def name(self) -> str: """Returns the name of this function.""" return self.ir_graph.name - + @property def assigned_names(self) -> Sequence[str]: """Returns the list of variables assigned to by this function.""" @@ -547,7 +550,7 @@ def add_stmt( outputs=output_values, attributes=attributes, ) - stmt = IRStmt(node, callee, attrs, sub_functions=sub_functions) + stmt = IRStmt(node, callee, sub_functions=sub_functions) fn.append_stmt(stmt) output_values = [ir.Value(name=o) for o in results] return output_values From 7b8cd9b28d954cb2a5216d6d9096bc8a926395ea Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 14 Nov 2025 19:44:22 -0800 Subject: [PATCH 14/43] IR cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 82 +++++++++++------------------------------ 1 file changed, 21 insertions(+), 61 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 81ec482060..df967fb8ca 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -74,6 +74,13 @@ def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) - self.name = varname self.info = sourceinfo self.typeinfo = typeinfo + if typeinfo is None: + self.value = ir.Value(name=varname) + else: + type_and_shape = ir.from_proto(typeinfo.to_type_proto()) + self.value = ir.Value( + name=varname, type=type_and_shape.type, shape=type_and_shape.shape + ) def __str__(self): return self.name @@ -109,33 +116,7 @@ def _opt_var_to_str(x): return "" if x is None else str(x) -class IRAttributeValue: - """An attribute value (representing an actual parameter). - - Attributes: - name: The name of the attribute. - type: The type of the attribute. - attr_proto: The attribute proto. - """ - - def __init__(self, attrproto: onnx.AttributeProto) -> None: - if not isinstance(attrproto, onnx.AttributeProto): - raise TypeError(f"Expected onnx.AttributeProto not {type(attrproto)!r}.") - self.attr_proto = attrproto - - def __str__(self): - if self.attr_proto.HasField("ref_attr_name"): - return f"{self.attr_proto.name} = @{self.attr_proto.ref_attr_name}" - # self.name + " = " + self.value - return helper.printable_attribute(self.attr_proto) - - @property - def name(self) -> str: - return self.attr_proto.name - - @property - def type(self) -> onnx.AttributeProto.AttributeType: - return self.attr_proto.type +IRAttributeValue = ir.Attr @dataclasses.dataclass(frozen=True) @@ -202,35 +183,19 @@ def args(self) -> Sequence[Optional[str]]: return [x.name if x is not None else None for x in self.node.inputs] @property - def attrs(self) -> Sequence[IRAttributeValue]: - return [IRAttributeValue(ir.to_proto(a)) for a in self.node.attributes.values()] + def attrs(self) -> Sequence[ir.Attr]: + return list(self.node.attributes.values()) def __str__(self): - lhs = ", ".join(self.output_names) - attrs = "" - if self.attrs: - attrs = _format(self.attrs, "<", ", ", ">") - - args = _format(self.args, "(", ", ", ")", _opt_var_to_str) - domain = self.callee.opset.domain - opname = self.callee.name - callee = f"{domain}.{opname}" if (domain != "") else opname - return f"{lhs} = {callee} {attrs}{args}" + return str(self.node) def debug_print(self): if logger.isEnabledFor(logging.DEBUG): logger.debug("%s: %s", type(self), self) def to_node_proto(self, node_name: str) -> onnx.NodeProto: - n = helper.make_node( - self.callee.name, - [_opt_var_to_str(x) for x in self.args], - self.output_names, - domain=self.callee.opset.domain, - name=node_name, - ) - for a in self.attrs: - n.attribute.append(a.attr_proto) + n = ir.to_proto(self.node) + n.name = node_name return n @property @@ -537,11 +502,11 @@ def add_stmt( results: Sequence[str], callee: values.Op, inputs: Sequence[Optional[ir.Value]], - attrs: Sequence[IRAttributeValue], + attrs: Sequence[ir.Attr], sub_functions=None, ) -> Sequence[ir.Value]: output_values = [ir.Value(name=o) for o in results] - attributes = [ir.from_proto(a.attr_proto) for a in attrs] + attributes = attrs # [ir.from_proto(a.attr_proto) for a in attrs] node = ir.Node( domain=callee.opset.domain, version=callee.opset.version, @@ -574,14 +539,9 @@ def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None var = IRVar(varname, typeinfo, sourceinfo) fn.append_output(var) - def make_attr(self, attrproto: onnx.AttributeProto) -> IRAttributeValue: - return IRAttributeValue(attrproto) - - def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> IRAttributeValue: - proto = onnx.AttributeProto() - proto.name = attrname - proto.ref_attr_name = refname - attr_type = ta.pytype_to_attrtype(pytype) - assert attr_type is not None - proto.type = attr_type - return IRAttributeValue(proto) + def make_attr(self, attrproto: onnx.AttributeProto) -> ir.Attr: + return ir.from_proto(attrproto) + + def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> ir.Attr: + attr_type = ir.AttributeType(ta.pytype_to_attrtype(pytype)) + return ir.Attr(attrname, attr_type, None, refname) From fae6609059f3f16955908774a47b31fcb9e3d0d4 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 14 Nov 2025 22:07:24 -0800 Subject: [PATCH 15/43] IR builder cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 45 +++++++++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 15 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index df967fb8ca..d989547731 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -74,7 +74,7 @@ def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) - self.name = varname self.info = sourceinfo self.typeinfo = typeinfo - if typeinfo is None: + if typeinfo is None or not hasattr(typeinfo, "to_type_proto"): self.value = ir.Value(name=varname) else: type_and_shape = ir.from_proto(typeinfo.to_type_proto()) @@ -135,6 +135,7 @@ class IRAttributeParameter: name: str type: onnx.AttributeProto.AttributeType + attr: ir.Attr default_value: str | int | float | None = None # TODO(justinchuby): Validate the default_value is the same type as specified in AttributeType. @@ -193,9 +194,8 @@ def debug_print(self): if logger.isEnabledFor(logging.DEBUG): logger.debug("%s: %s", type(self), self) - def to_node_proto(self, node_name: str) -> onnx.NodeProto: + def to_node_proto(self) -> onnx.NodeProto: n = ir.to_proto(self.node) - n.name = node_name return n @property @@ -208,8 +208,8 @@ class IRFunction: """Represents a function in the IR.""" def __init__(self, name: str, domain: str = "") -> None: - self.ir_graph = ir.Graph(inputs=[], outputs=[], nodes=[], name=name) - self.domain = domain + graph = ir.Graph(inputs=[], outputs=[], nodes=[], name=name) + self.ir_function = ir.Function(domain, name, graph=graph, attributes=[]) self.outputs: list[IRVar] = [] self.stmts: list[IRStmt] = [] self.called_functions: dict[str, onnx.FunctionProto] = {} @@ -218,15 +218,20 @@ def __init__(self, name: str, domain: str = "") -> None: self.outer_scope_variables: dict[Any, Any] = {} self.ordered_inputs_and_attrs: list[Union[IRVar, IRAttributeParameter]] = [] + @property + def domain(self) -> str: + """Returns the domain of this function.""" + return self.ir_function.domain + @property def docstring(self) -> str: """Returns the docstring of this function.""" - return self.ir_graph.doc_string or "" + return self.ir_function.doc_string or "" @property def name(self) -> str: """Returns the name of this function.""" - return self.ir_graph.name + return self.ir_function.name @property def assigned_names(self) -> Sequence[str]: @@ -253,16 +258,23 @@ def __str__(self): return f"{self.name} {attrs}{inputs} => {outputs}{stmts}" def append_stmt(self, stmt: IRStmt) -> None: + count = len(self.stmts) + node_name = f"n{count}" + stmt.node.name = node_name self.stmts.append(stmt) + self.ir_function.append(stmt.node) - def append_input(self, name: IRVar) -> None: - self.ordered_inputs_and_attrs.append(name) + def append_input(self, var: IRVar) -> None: + self.ordered_inputs_and_attrs.append(var) + self.ir_function.inputs.append(var.value) - def append_output(self, name: IRVar) -> None: - self.outputs.append(name) + def append_output(self, var: IRVar) -> None: + self.outputs.append(var) + self.ir_function.outputs.append(var.value) def add_attr_parameter(self, attr: IRAttributeParameter) -> None: self.ordered_inputs_and_attrs.append(attr) + self.ir_function.attributes.add(attr.attr) def debug_print(self): if logger.isEnabledFor(logging.DEBUG): @@ -407,7 +419,7 @@ def _to_graph_and_functions( called_functions.update(s.functions) called_functions.update(self.called_functions) graph = helper.make_graph( - [s.to_node_proto(f"n{i}") for i, s in enumerate(self.stmts)], + [s.to_node_proto() for s in self.stmts], self.name, [x.to_value_info(use_default_type) for x in self.inputs], [y.to_value_info(use_default_type) for y in self.outputs], @@ -450,7 +462,7 @@ def to_function_proto(self) -> onnx.FunctionProto: doesn't support it. """ opsets = self.get_opset_import() - nodes = [s.to_node_proto(f"n{i}") for i, s in enumerate(self.stmts)] + nodes = [s.to_node_proto() for s in self.stmts] for n in nodes: if n.domain not in opsets: opsets[n.domain] = 1 # TODO: how to get n.version? @@ -494,7 +506,7 @@ def new_function(self, name: str, domain: str = "", register: bool = False) -> I return function def add_docstring(self, fn: IRFunction, docstring: str): - fn.ir_graph.doc_string = docstring + fn.ir_function.doc_string = docstring def add_stmt( self, @@ -533,7 +545,10 @@ def add_attr_parameter( attribute_type: onnx.AttributeProto.AttributeType, default_value: int | float | str | None, ) -> None: - fn.add_attr_parameter(IRAttributeParameter(varname, attribute_type, default_value)) + attr = ir.Attr(varname, ir.AttributeType(attribute_type), None, None) + fn.add_attr_parameter( + IRAttributeParameter(varname, attribute_type, attr, default_value) + ) def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None: var = IRVar(varname, typeinfo, sourceinfo) From 6b874078397141a41c0af836fd20743a546c6e94 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 14 Nov 2025 22:29:12 -0800 Subject: [PATCH 16/43] IR builder cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 4 ++-- onnxscript/irbuilder.py | 51 +++++++++++++++++++++-------------------- 2 files changed, 28 insertions(+), 27 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 4e5c45ee95..491931f759 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -1318,7 +1318,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: inputs = [o_loop_bound, o_loop_condition] + [ self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars ] - graph, sub_functions = body._to_graph_and_functions() + graph, sub_functions = body.to_graph_and_functions() attrs = [self._make_onnx_attr("body", graph)] info = self._source_of(loop_stmt) @@ -1385,7 +1385,7 @@ def _translate_block( typeinfo = None self.ir_builder.add_output(self._current_fn, ovar.name, typeinfo, source) graph = self._exit_scope() - return graph._to_graph_and_functions() + return graph.to_graph_and_functions() def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None: """Translate a nested function definition.""" diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index d989547731..bb2f63f692 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -340,7 +340,7 @@ def to_model_proto( if value_infos else None ) - graph, sub_functions = self._to_graph_and_functions( + graph, sub_functions = self.to_graph_and_functions( use_default_type=False, value_infos=value_infos ) if io_types is not None: @@ -397,7 +397,7 @@ def to_proto(f): graph, opset_imports=opset_imports, functions=functions, **kwargs ) - def _to_graph_and_functions( + def to_graph_and_functions( self, use_default_type: bool = True, value_infos: Sequence[ValueInfoProto] | None = None, @@ -437,11 +437,11 @@ def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto: Returns: an instance of :class:`onnx.GraphProto` """ - graph, _ = self._to_graph_and_functions(use_default_type=use_default_type) + graph, _ = self.to_graph_and_functions(use_default_type=use_default_type) return graph def get_opset_import(self) -> dict[str, int]: - func_opset_imports = {} + func_opset_imports = self.ir_function.opset_imports for s in self.stmts: if s.callee.opset.domain not in func_opset_imports: func_opset_imports[s.callee.opset.domain] = s.callee.opset.version @@ -466,27 +466,28 @@ def to_function_proto(self) -> onnx.FunctionProto: for n in nodes: if n.domain not in opsets: opsets[n.domain] = 1 # TODO: how to get n.version? - opset_imports = [ - onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items() - ] - - attribute_names = [attr.name for attr in self.attrs if not attr.has_default] - - f = helper.make_function( - self.domain, - self.name, - inputs=[x.name for x in self.inputs], - outputs=[y.name for y in self.outputs], - nodes=nodes, - opset_imports=opset_imports, # TODO - attributes=attribute_names, - doc_string=self.docstring, - ) - # In protobuf 4.x fields aren't defined as class attribute so it should check instance attribute instead - if hasattr(f, "attribute_proto"): - f.attribute_proto.extend( - [attr.attr_proto for attr in self.attrs if attr.has_default] - ) + f = ir.to_proto(self.ir_function) + # opset_imports = [ + # onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items() + # ] + + # attribute_names = [attr.name for attr in self.attrs if not attr.has_default] + + # f = helper.make_function( + # self.domain, + # self.name, + # inputs=[x.name for x in self.inputs], + # outputs=[y.name for y in self.outputs], + # nodes=nodes, + # opset_imports=opset_imports, # TODO + # attributes=attribute_names, + # doc_string=self.docstring, + # ) + # # In protobuf 4.x fields aren't defined as class attribute so it should check instance attribute instead + # if hasattr(f, "attribute_proto"): + # f.attribute_proto.extend( + # [attr.attr_proto for attr in self.attrs if attr.has_default] + # ) return f From 0047814363fdb9c5e1c7c870a9160149e0a8ec0c Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sat, 15 Nov 2025 08:34:27 -0800 Subject: [PATCH 17/43] Fix attribute error Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index bb2f63f692..dc76988fcb 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -74,13 +74,16 @@ def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) - self.name = varname self.info = sourceinfo self.typeinfo = typeinfo - if typeinfo is None or not hasattr(typeinfo, "to_type_proto"): + if typeinfo is None: self.value = ir.Value(name=varname) else: - type_and_shape = ir.from_proto(typeinfo.to_type_proto()) - self.value = ir.Value( - name=varname, type=type_and_shape.type, shape=type_and_shape.shape - ) + try: + type_and_shape = ir.from_proto(typeinfo.to_type_proto()) + self.value = ir.Value( + name=varname, type=type_and_shape.type, shape=type_and_shape.shape + ) + except AttributeError: + self.value = ir.Value(name=varname) def __str__(self): return self.name From 1917af16ffdda4c0a9d2ac86e5cd4db8ff9f75b5 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sat, 15 Nov 2025 16:52:28 -0800 Subject: [PATCH 18/43] Default value for attr parameter Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index dc76988fcb..06955ae918 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -549,7 +549,7 @@ def add_attr_parameter( attribute_type: onnx.AttributeProto.AttributeType, default_value: int | float | str | None, ) -> None: - attr = ir.Attr(varname, ir.AttributeType(attribute_type), None, None) + attr = ir.Attr(varname, ir.AttributeType(attribute_type), default_value, None) fn.add_attr_parameter( IRAttributeParameter(varname, attribute_type, attr, default_value) ) From 6b01c988d10ddd9953f5cb802aba103f71560d3d Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sun, 16 Nov 2025 08:22:07 -0800 Subject: [PATCH 19/43] Fix to-graph-proto Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 35 ++++------------------------------- 1 file changed, 4 insertions(+), 31 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 06955ae918..2ade5f5eaf 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -421,13 +421,9 @@ def to_graph_and_functions( for s in self.stmts: called_functions.update(s.functions) called_functions.update(self.called_functions) - graph = helper.make_graph( - [s.to_node_proto() for s in self.stmts], - self.name, - [x.to_value_info(use_default_type) for x in self.inputs], - [y.to_value_info(use_default_type) for y in self.outputs], - value_info=value_infos, - ) + graph = self.to_graph_proto(use_default_type=use_default_type) + if value_infos: + graph.value_info.extend(value_infos) return graph, called_functions def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto: @@ -440,8 +436,7 @@ def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto: Returns: an instance of :class:`onnx.GraphProto` """ - graph, _ = self.to_graph_and_functions(use_default_type=use_default_type) - return graph + return ir.to_proto(self.ir_function.graph) def get_opset_import(self) -> dict[str, int]: func_opset_imports = self.ir_function.opset_imports @@ -470,30 +465,8 @@ def to_function_proto(self) -> onnx.FunctionProto: if n.domain not in opsets: opsets[n.domain] = 1 # TODO: how to get n.version? f = ir.to_proto(self.ir_function) - # opset_imports = [ - # onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items() - # ] - - # attribute_names = [attr.name for attr in self.attrs if not attr.has_default] - - # f = helper.make_function( - # self.domain, - # self.name, - # inputs=[x.name for x in self.inputs], - # outputs=[y.name for y in self.outputs], - # nodes=nodes, - # opset_imports=opset_imports, # TODO - # attributes=attribute_names, - # doc_string=self.docstring, - # ) - # # In protobuf 4.x fields aren't defined as class attribute so it should check instance attribute instead - # if hasattr(f, "attribute_proto"): - # f.attribute_proto.extend( - # [attr.attr_proto for attr in self.attrs if attr.has_default] - # ) return f - # IRBuilder: abstracts out details of the IR in the python-to-IR converter From 74ca99647f3c550b6e65b9e3d530e470ac8ecc37 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sun, 16 Nov 2025 09:05:31 -0800 Subject: [PATCH 20/43] Opset imports Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 50 ++++++++++++++++++++++++++--------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 2ade5f5eaf..2b75440ad6 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -266,6 +266,18 @@ def append_stmt(self, stmt: IRStmt) -> None: stmt.node.name = node_name self.stmts.append(stmt) self.ir_function.append(stmt.node) + domain = stmt.node.domain + version = stmt.node.version + if domain not in self.ir_function.opset_imports: + self.ir_function.opset_imports[domain] = version + else: + existing_version = self.ir_function.opset_imports[domain] + if existing_version != version: + warnings.warn( + f"Version conflict: domain: {domain!r}, " + f"versions {existing_version} and {version} used.", + category=UserWarning, + ) def append_input(self, var: IRVar) -> None: self.ordered_inputs_and_attrs.append(var) @@ -436,21 +448,22 @@ def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto: Returns: an instance of :class:`onnx.GraphProto` """ + del use_default_type # currently not used return ir.to_proto(self.ir_function.graph) - def get_opset_import(self) -> dict[str, int]: - func_opset_imports = self.ir_function.opset_imports - for s in self.stmts: - if s.callee.opset.domain not in func_opset_imports: - func_opset_imports[s.callee.opset.domain] = s.callee.opset.version - elif func_opset_imports[s.callee.opset.domain] != s.callee.opset.version: - warnings.warn( - f"There is a version conflict in domain: {s.callee.opset.domain!r}, " - f"with {self.name!r}.", - category=UserWarning, - stacklevel=1, - ) - return func_opset_imports + # def get_opset_import(self) -> dict[str, int]: + # func_opset_imports = self.ir_function.opset_imports + # for s in self.stmts: + # if s.callee.opset.domain not in func_opset_imports: + # func_opset_imports[s.callee.opset.domain] = s.callee.opset.version + # elif func_opset_imports[s.callee.opset.domain] != s.callee.opset.version: + # warnings.warn( + # f"There is a version conflict in domain: {s.callee.opset.domain!r}, " + # f"with {self.name!r}.", + # category=UserWarning, + # stacklevel=1, + # ) + # return func_opset_imports def to_function_proto(self) -> onnx.FunctionProto: """Converts this instance into a `onnx.FunctionProto`. @@ -459,14 +472,15 @@ def to_function_proto(self) -> onnx.FunctionProto: Conversion ignores default values for attributes if the ONNX version installed doesn't support it. """ - opsets = self.get_opset_import() - nodes = [s.to_node_proto() for s in self.stmts] - for n in nodes: - if n.domain not in opsets: - opsets[n.domain] = 1 # TODO: how to get n.version? + # opsets = self.get_opset_import() + # nodes = [s.to_node_proto() for s in self.stmts] + # for n in nodes: + # if n.domain not in opsets: + # opsets[n.domain] = 1 # TODO: how to get n.version? f = ir.to_proto(self.ir_function) return f + # IRBuilder: abstracts out details of the IR in the python-to-IR converter From d855c2c67fab8d854264dae3d83273a0f9a3d76b Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sun, 16 Nov 2025 10:07:13 -0800 Subject: [PATCH 21/43] Minor fix Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 44 +++++------------------------------------ 1 file changed, 5 insertions(+), 39 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 2b75440ad6..62f59eacda 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -293,12 +293,7 @@ def add_attr_parameter(self, attr: IRAttributeParameter) -> None: def debug_print(self): if logger.isEnabledFor(logging.DEBUG): - st = io.StringIO() - for s in self.stmts: - for attr in s.attrs: - if attr.attr_proto.HasField("g"): - st.write(helper.printable_graph(attr.attr_proto.g)) - st.write("\n") + logger.debug(str(self.ir_function)) def add_called_function(self, fun: values.OnnxFunction) -> None: for name, fct in fun.function_ir.called_functions.items(): @@ -384,10 +379,7 @@ def to_proto(f): functions = [to_proto(f) for f in functions] - opsets = {} - for n in self.stmts: - if n.callee.opset.domain not in opsets: - opsets[n.callee.opset.domain] = n.callee.opset.version + opsets = self.ir_function.opset_imports.copy() for proto in functions: if proto.domain not in opsets: @@ -442,8 +434,7 @@ def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto: """Converts this instance into a `onnx.GraphProto`. Args: - use_default_type: if True, the function uses a default type - for inputs and outputs that do not have a type + use_default_type: Unused. Returns: an instance of :class:`onnx.GraphProto` @@ -451,34 +442,9 @@ def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto: del use_default_type # currently not used return ir.to_proto(self.ir_function.graph) - # def get_opset_import(self) -> dict[str, int]: - # func_opset_imports = self.ir_function.opset_imports - # for s in self.stmts: - # if s.callee.opset.domain not in func_opset_imports: - # func_opset_imports[s.callee.opset.domain] = s.callee.opset.version - # elif func_opset_imports[s.callee.opset.domain] != s.callee.opset.version: - # warnings.warn( - # f"There is a version conflict in domain: {s.callee.opset.domain!r}, " - # f"with {self.name!r}.", - # category=UserWarning, - # stacklevel=1, - # ) - # return func_opset_imports - def to_function_proto(self) -> onnx.FunctionProto: - """Converts this instance into a `onnx.FunctionProto`. - - Note: Default values for attributes are an experimental feature in ONNX. - Conversion ignores default values for attributes if the ONNX version installed - doesn't support it. - """ - # opsets = self.get_opset_import() - # nodes = [s.to_node_proto() for s in self.stmts] - # for n in nodes: - # if n.domain not in opsets: - # opsets[n.domain] = 1 # TODO: how to get n.version? - f = ir.to_proto(self.ir_function) - return f + """Converts this instance into a `onnx.FunctionProto`.""" + return ir.to_proto(self.ir_function) # IRBuilder: abstracts out details of the IR in the python-to-IR converter From 4372082bad99221b824a5c56d888974f1dc0a60b Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sun, 16 Nov 2025 10:08:01 -0800 Subject: [PATCH 22/43] Run lint Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 62f59eacda..66ace617cc 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -4,7 +4,6 @@ from __future__ import annotations import dataclasses -import io import logging import warnings from typing import Any, Optional, Protocol, Sequence, Union From cf1c12adf59b2d1b1c80a789772454b0fa33a2e0 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sun, 16 Nov 2025 15:42:45 -0800 Subject: [PATCH 23/43] More cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 33 ++++++-------- onnxscript/irbuilder.py | 97 +++++++++++++++-------------------------- 2 files changed, 48 insertions(+), 82 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 491931f759..e6c65cd051 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -291,7 +291,10 @@ def generate_unique_name(self, candidate: str = "tmp") -> str: def _make_onnx_attr( self, attrname: str, attrval: Any, attrtype: int | None = None - ) -> irbuilder.IRAttributeValue: + ) -> ir.Attr: + if isinstance(attrval, ir.Graph): + return ir.Attr(attrname, ir.AttributeType.GRAPH, attrval) + def tensor_name_generator() -> str: """Return name to be used for tensor, if we need to create one.""" return self.generate_unique_name(f"attr_{attrname}") @@ -303,7 +306,7 @@ def tensor_name_generator() -> str: def _to_onnx_attr_ref( self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo] - ) -> irbuilder.IRAttributeValue: + ) -> ir.Attr: pytype = val.typeinfo attrtype = ta.pytype_to_attrtype(pytype) attrname = None @@ -364,7 +367,6 @@ def emit( callee: values.Op | str, inputs: Sequence[Optional[Variable]], attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None, - sub_functions: Optional[dict[str, onnx.FunctionProto]] = None, ) -> Sequence[Variable] | Variable: for i, x in enumerate(inputs): if (x is not None) and not isinstance(x, ir.Value): @@ -373,15 +375,12 @@ def emit( callee = values.Op(self.default_opset, callee) if attrs is None: attrs = [] - if sub_functions is None: - sub_functions = {} output_values = self.ir_builder.add_stmt( self._current_fn, outputs, callee, inputs, attrs, - sub_functions, ) return output_values if len(output_values) > 1 else output_values[0] @@ -943,10 +942,10 @@ def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable if isinstance(node, ast.Name): function_name = node.id found = self._lookup(function_name, self._source_of(node), raise_exception=False) - if isinstance(found, onnxscript.OnnxFunction): - self._current_fn.add_called_function(found) - return found - if isinstance(found, values.Op): + # if isinstance(found, onnxscript.OnnxFunction): + # self._current_fn.add_called_function(found) + # return found + if isinstance(found, (values.Op, onnxscript.OnnxFunction)): return found if not found: if function_name not in self.default_opset: @@ -1117,11 +1116,11 @@ def _translate_if_stmt(self, stmt: ast.If) -> None: live_defs = list(live_def_set) test = self._translate_expr(stmt.test, "cond") lineno = self._source_of(stmt).lineno - thenGraph, sub_fct_then = self._translate_block( + thenGraph = self._translate_block( stmt.body, f"thenGraph_{lineno}", live_defs, parent_stmt=stmt ) thenAttr = self._make_onnx_attr("then_branch", thenGraph) - elseGraph, sub_fct_else = self._translate_block( + elseGraph = self._translate_block( stmt.orelse, f"elseGraph_{lineno}", live_defs, parent_stmt=stmt ) elseAttr = self._make_onnx_attr("else_branch", elseGraph) @@ -1134,9 +1133,6 @@ def rename(x): if not renamed: self.fail(stmt, "A subgraph for a test do not have any output variable.") - sub_functions = {} - sub_functions.update(sub_fct_then) - sub_functions.update(sub_fct_else) if renamed == [test.name]: self.fail(stmt, f"Input and output cannot be the same {renamed!r}.") if_outputs = self.emit( @@ -1144,7 +1140,6 @@ def rename(x): values.Op(self.default_opset, "If"), [test], [thenAttr, elseAttr], - sub_functions=sub_functions, ) if isinstance(if_outputs, ir.Value): if_outputs = [if_outputs] @@ -1318,8 +1313,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: inputs = [o_loop_bound, o_loop_condition] + [ self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars ] - graph, sub_functions = body.to_graph_and_functions() - attrs = [self._make_onnx_attr("body", graph)] + attrs = [self._make_onnx_attr("body", body.ir_function.graph)] info = self._source_of(loop_stmt) def rename(x): @@ -1332,7 +1326,6 @@ def rename(x): "Loop", inputs, attrs, - sub_functions=sub_functions, ) if isinstance(loop_outputs, ir.Value): loop_outputs = [loop_outputs] @@ -1385,7 +1378,7 @@ def _translate_block( typeinfo = None self.ir_builder.add_output(self._current_fn, ovar.name, typeinfo, source) graph = self._exit_scope() - return graph.to_graph_and_functions() + return graph.ir_function.graph def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None: """Translate a nested function definition.""" diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 66ace617cc..b37d0aa43b 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -20,8 +20,6 @@ from onnxscript.onnx_types import ONNXType from onnxscript.sourceinfo import SourceInfo -# A simple IR (Function, Stmt, Attr, Var): - logger = logging.getLogger("onnxscript") @@ -178,8 +176,7 @@ def __init__( if not isinstance(callee, values.Op): raise TypeError(f"Unexpected type {type(callee)} for callee.") self.node = node - self.callee = callee - self.functions = sub_functions or {} + node.meta.setdefault("callee", callee) @property def args(self) -> Sequence[Optional[str]]: @@ -213,8 +210,7 @@ def __init__(self, name: str, domain: str = "") -> None: graph = ir.Graph(inputs=[], outputs=[], nodes=[], name=name) self.ir_function = ir.Function(domain, name, graph=graph, attributes=[]) self.outputs: list[IRVar] = [] - self.stmts: list[IRStmt] = [] - self.called_functions: dict[str, onnx.FunctionProto] = {} + # a dictionary of nested function-definitions self.nested_functions: dict[str, IRFunction] = {} self.outer_scope_variables: dict[Any, Any] = {} @@ -238,7 +234,7 @@ def name(self) -> str: @property def assigned_names(self) -> Sequence[str]: """Returns the list of variables assigned to by this function.""" - return [v for stmt in self.stmts for v in stmt.output_names] + return [v.name for n in self.ir_function for v in n.outputs] @property def inputs(self) -> Sequence[IRVar]: @@ -253,20 +249,14 @@ def attrs(self) -> Sequence[IRAttributeParameter]: ] def __str__(self): - attrs = _format(self.attrs, "<", ", ", ">") if self.attrs else "" - inputs = _format([x.typed_str() for x in self.inputs], "(", ", ", ")") - outputs = _format([x.typed_str() for x in self.outputs], "(", ", ", ")") - stmts = _format(self.stmts, "\n{\n ", "\n ", "\n}\n") - return f"{self.name} {attrs}{inputs} => {outputs}{stmts}" - - def append_stmt(self, stmt: IRStmt) -> None: - count = len(self.stmts) + return str(self.ir_function) + + def append_stmt(self, node: ir.Node) -> None: + count = len(self.ir_function) node_name = f"n{count}" - stmt.node.name = node_name - self.stmts.append(stmt) - self.ir_function.append(stmt.node) - domain = stmt.node.domain - version = stmt.node.version + self.ir_function.append(node) + domain = node.domain + version = node.version if domain not in self.ir_function.opset_imports: self.ir_function.opset_imports[domain] = version else: @@ -276,6 +266,7 @@ def append_stmt(self, stmt: IRStmt) -> None: f"Version conflict: domain: {domain!r}, " f"versions {existing_version} and {version} used.", category=UserWarning, + stacklevel=2, ) def append_input(self, var: IRVar) -> None: @@ -294,20 +285,6 @@ def debug_print(self): if logger.isEnabledFor(logging.DEBUG): logger.debug(str(self.ir_function)) - def add_called_function(self, fun: values.OnnxFunction) -> None: - for name, fct in fun.function_ir.called_functions.items(): - if name in self.called_functions: - continue - self.called_functions[name] = fct - if fun.name in self.called_functions: - # Already added. - return - try: - proto = fun.to_function_proto() - except (TypeError, AttributeError) as e: - raise TypeError(f"Issue with type f{type(fun)}.") from e - self.called_functions[fun.name] = proto - def add_nested_function(self, fun: IRFunction) -> None: self.nested_functions[fun.name] = fun @@ -349,9 +326,10 @@ def to_model_proto( if value_infos else None ) - graph, sub_functions = self.to_graph_and_functions( - use_default_type=False, value_infos=value_infos - ) + sub_functions = self.get_called_functions() + graph = self.to_graph_proto(use_default_type=False) + if value_infos: + graph.value_info.extend(value_infos) if io_types is not None: for input in graph.input: if not input.HasField("type"): @@ -403,31 +381,24 @@ def to_proto(f): graph, opset_imports=opset_imports, functions=functions, **kwargs ) - def to_graph_and_functions( - self, - use_default_type: bool = True, - value_infos: Sequence[ValueInfoProto] | None = None, - ) -> tuple[onnx.GraphProto, dict[str, onnx.FunctionProto]]: - """Converts this instance into a `onnx.GraphProto` and a map from - function-name to `onnx.FunctionProto`. + def get_called_functions(self) -> dict[str, onnx.FunctionProto]: + called_functions: dict[str, values.OnnxFunction] = {} - Args: - use_default_type: if True, the function uses a default type - for inputs and outputs that do not have a type - value_infos: a sequence of :class:`onnx.ValueInfoProto` to be added - to the graph. + def visit(function_ir: IRFunction): + for node in ir.traversal.RecursiveGraphIterator(function_ir.ir_function.graph): + callee = node.meta.get("callee", None) + if isinstance(callee, values.OnnxFunction): + add(callee) - Returns: - a pair of a :class:`onnx.GraphProto` and list of :class:`onnx.FunctionProto` - """ - called_functions: dict[str, onnx.FunctionProto] = {} - for s in self.stmts: - called_functions.update(s.functions) - called_functions.update(self.called_functions) - graph = self.to_graph_proto(use_default_type=use_default_type) - if value_infos: - graph.value_info.extend(value_infos) - return graph, called_functions + def add(f: values.OnnxFunction): + if f.name in called_functions: + return + called_functions[f.name] = f + visit(f.function_ir) + + visit(self) + + return {name: f.to_function_proto() for name, f in called_functions.items()} def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto: """Converts this instance into a `onnx.GraphProto`. @@ -483,8 +454,10 @@ def add_stmt( outputs=output_values, attributes=attributes, ) - stmt = IRStmt(node, callee, sub_functions=sub_functions) - fn.append_stmt(stmt) + if not isinstance(callee, values.Op): + raise TypeError(f"Unexpected type {type(callee)} for callee.") + node.meta.setdefault("callee", callee) + fn.append_stmt(node) output_values = [ir.Value(name=o) for o in results] return output_values From 29ec1391cc423ddce1dbef0297e10299b76e9d74 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sun, 16 Nov 2025 18:02:51 -0800 Subject: [PATCH 24/43] More cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 101 +++------------------------------------- onnxscript/values.py | 12 ++--- 2 files changed, 12 insertions(+), 101 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index b37d0aa43b..6a0096b69a 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -3,7 +3,6 @@ # ruff: noqa: TID251 from __future__ import annotations -import dataclasses import logging import warnings from typing import Any, Optional, Protocol, Sequence, Union @@ -16,7 +15,6 @@ import onnxscript from onnxscript import type_annotation as ta from onnxscript import values -from onnxscript._internal import version_utils from onnxscript.onnx_types import ONNXType from onnxscript.sourceinfo import SourceInfo @@ -116,91 +114,7 @@ def _opt_var_to_str(x): return "" if x is None else str(x) -IRAttributeValue = ir.Attr - - -@dataclasses.dataclass(frozen=True) -class IRAttributeParameter: - """An attribute parameter (representing a formal parameter). - - It may or may not carry a default value. - - Attributes: - name: The name of the attribute. - type: The type of the attribute. - default_value: The default value of the attribute. - has_default: Whether the attribute has a default value. - attr_proto: The attribute proto. - """ - - name: str - type: onnx.AttributeProto.AttributeType - attr: ir.Attr - default_value: str | int | float | None = None - - # TODO(justinchuby): Validate the default_value is the same type as specified in AttributeType. - - def __str__(self): - if self.has_default: - return helper.printable_attribute(self.attr_proto) - # TODO(justinchuby): Include a readable type name. - return self.name - - @property - def has_default(self): - return self.default_value is not None - - @property - def attr_proto(self) -> onnx.AttributeProto: - if not self.has_default: - raise ValueError( - "Attribute has no default value. Only attributes with default " - "values can be converted to AttributeProto." - ) - if version_utils.onnx_older_than("1.15"): - # TODO(after 1.14 is deprecated): Remove this branch. - # Argument 'attr_type' was added after version 1.14. - return helper.make_attribute(self.name, self.default_value) - # pylint: disable=unexpected-keyword-arg - return helper.make_attribute(self.name, self.default_value, attr_type=self.type) # type: ignore[call-arg] - # pylint: enable=unexpected-keyword-arg - - -class IRStmt: - def __init__( - self, - node: ir.Node, - callee: values.Op, - sub_functions=None, - ) -> None: - if not isinstance(callee, values.Op): - raise TypeError(f"Unexpected type {type(callee)} for callee.") - self.node = node - node.meta.setdefault("callee", callee) - - @property - def args(self) -> Sequence[Optional[str]]: - return [x.name if x is not None else None for x in self.node.inputs] - - @property - def attrs(self) -> Sequence[ir.Attr]: - return list(self.node.attributes.values()) - - def __str__(self): - return str(self.node) - - def debug_print(self): - if logger.isEnabledFor(logging.DEBUG): - logger.debug("%s: %s", type(self), self) - - def to_node_proto(self) -> onnx.NodeProto: - n = ir.to_proto(self.node) - return n - - @property - def output_names(self) -> Sequence[str]: - """Returns the list of variables assigned to by this statement.""" - return [x.name for x in self.node.outputs] +IRAttributeParameter = ir.Attr class IRFunction: @@ -251,9 +165,9 @@ def attrs(self) -> Sequence[IRAttributeParameter]: def __str__(self): return str(self.ir_function) - def append_stmt(self, node: ir.Node) -> None: + def append_node(self, node: ir.Node) -> None: count = len(self.ir_function) - node_name = f"n{count}" + node.name = f"n{count}" self.ir_function.append(node) domain = node.domain version = node.version @@ -279,7 +193,7 @@ def append_output(self, var: IRVar) -> None: def add_attr_parameter(self, attr: IRAttributeParameter) -> None: self.ordered_inputs_and_attrs.append(attr) - self.ir_function.attributes.add(attr.attr) + self.ir_function.attributes.add(attr) def debug_print(self): if logger.isEnabledFor(logging.DEBUG): @@ -457,8 +371,7 @@ def add_stmt( if not isinstance(callee, values.Op): raise TypeError(f"Unexpected type {type(callee)} for callee.") node.meta.setdefault("callee", callee) - fn.append_stmt(node) - output_values = [ir.Value(name=o) for o in results] + fn.append_node(node) return output_values def add_input( @@ -475,9 +388,7 @@ def add_attr_parameter( default_value: int | float | str | None, ) -> None: attr = ir.Attr(varname, ir.AttributeType(attribute_type), default_value, None) - fn.add_attr_parameter( - IRAttributeParameter(varname, attribute_type, attr, default_value) - ) + fn.add_attr_parameter(attr) def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None: var = IRVar(varname, typeinfo, sourceinfo) diff --git a/onnxscript/values.py b/onnxscript/values.py index deeca21e58..ec66bb5cda 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -223,9 +223,9 @@ def _param_schema_from_function_ir_attr(attr: irbuilder.IRAttributeParameter): type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE.get( onnx.defs.OpSchema.AttrType(attr.type) # type: ignore[call-arg] ), - default=_EmptyDefault if attr.default_value is None else attr.default_value, + default=_EmptyDefault if attr.value is None else attr.value, is_input=False, - required=not attr.has_default, + required=attr.value is None, ) @@ -448,15 +448,15 @@ def _op_schema_from_function_ir( type=onnx.defs.OpSchema.AttrType(attr.type), # type: ignore[call-arg] ) for attr in function_ir.attrs - if not attr.has_default + if attr.value is None ], *[ onnx.defs.OpSchema.Attribute( attr.name, - default_value=attr.attr_proto, + default_value=ir.to_proto(attr), ) for attr in function_ir.attrs - if attr.has_default + if attr.value is not None ], ], ) @@ -592,7 +592,7 @@ def to_function_proto(self) -> onnx.FunctionProto: def to_model_proto(self, **kwargs): """Converts the function into :class:`onnx.ModelProto`.""" if self.function_ir.attrs and any( - not attr.has_default for attr in self.function_ir.attrs + attr.value is None for attr in self.function_ir.attrs ): raise ValueError( "A function with required attributes cannot be exported as a model." From 95a76a21ae0da975e1933e0c0aae1f15ce980463 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sun, 16 Nov 2025 20:46:52 -0800 Subject: [PATCH 25/43] More cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 125 +++++++++++++--------------------------- onnxscript/values.py | 25 ++++---- 2 files changed, 56 insertions(+), 94 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 6a0096b69a..77efbbb81b 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -5,14 +5,15 @@ import logging import warnings -from typing import Any, Optional, Protocol, Sequence, Union +from typing import Any, Optional, Sequence, Union import onnx import onnx_ir as ir -from onnx import ValueInfoProto, helper +from onnx import helper from onnx.defs import onnx_opset_version import onnxscript +import onnxscript.type_annotation from onnxscript import type_annotation as ta from onnxscript import values from onnxscript.onnx_types import ONNXType @@ -55,66 +56,7 @@ def __repr__(self) -> str: return f"IRTensorType({self.onnx_type.tensor_type.elem_type})" -class IRTypeLike(Protocol): - def to_type_proto(self) -> onnx.TypeProto: - """Converts IR type representation to onnx.TypeProto""" - - -class IRVar: - """A variable (representing a formal parameter).""" - - def __init__(self, varname: str, typeinfo: IRTypeLike, sourceinfo: SourceInfo) -> None: - if not isinstance(varname, str): - raise TypeError(f"varname must be a string not {type(varname)!r}.") - self.name = varname - self.info = sourceinfo - self.typeinfo = typeinfo - if typeinfo is None: - self.value = ir.Value(name=varname) - else: - try: - type_and_shape = ir.from_proto(typeinfo.to_type_proto()) - self.value = ir.Value( - name=varname, type=type_and_shape.type, shape=type_and_shape.shape - ) - except AttributeError: - self.value = ir.Value(name=varname) - - def __str__(self): - return self.name - - def __repr__(self): - return f"{self.__class__.__name__}({self.name!r}, {self.typeinfo!r})" - - def typed_str(self): - return f"{self.name} : {self.typeinfo}" - - def to_value_info(self, use_default_type: bool = True): - """Converts the content of this class into :class:`onnx.ValueInfoProto`. - - Args: - use_default_type: if True, use a default type if an explicit type - is not known. Otherwise, returns a ValueInfoProto without type. - - Returns: - an instance of :class:`onnx.ValueInfoProto` - """ - if self.name is None: - raise ValueError(self.info.msg("name cannot be None.")) - value_info_proto = ValueInfoProto() - value_info_proto.name = self.name - if self.typeinfo is not None: - value_info_proto.type.CopyFrom(self.typeinfo.to_type_proto()) - elif use_default_type: - value_info_proto.type.CopyFrom(IRType().to_type_proto()) - return value_info_proto - - -def _opt_var_to_str(x): - return "" if x is None else str(x) - - -IRAttributeParameter = ir.Attr +TypeAnnotationValue = onnxscript.type_annotation.TypeAnnotationValue class IRFunction: @@ -123,12 +65,15 @@ class IRFunction: def __init__(self, name: str, domain: str = "") -> None: graph = ir.Graph(inputs=[], outputs=[], nodes=[], name=name) self.ir_function = ir.Function(domain, name, graph=graph, attributes=[]) - self.outputs: list[IRVar] = [] + self.ordered_inputs_and_attrs: list[Union[ir.Value, ir.Attr]] = [] # a dictionary of nested function-definitions self.nested_functions: dict[str, IRFunction] = {} self.outer_scope_variables: dict[Any, Any] = {} - self.ordered_inputs_and_attrs: list[Union[IRVar, IRAttributeParameter]] = [] + + @property + def outputs(self) -> Sequence[ir.Value]: + return self.ir_function.outputs @property def domain(self) -> str: @@ -151,16 +96,14 @@ def assigned_names(self) -> Sequence[str]: return [v.name for n in self.ir_function for v in n.outputs] @property - def inputs(self) -> Sequence[IRVar]: - return [var for var in self.ordered_inputs_and_attrs if isinstance(var, IRVar)] + def inputs(self) -> Sequence[ir.Value]: + return ( + self.ir_function.inputs + ) # [var for var in self.ordered_inputs_and_attrs if isinstance(var, IRVar)] @property - def attrs(self) -> Sequence[IRAttributeParameter]: - return [ - attr - for attr in self.ordered_inputs_and_attrs - if isinstance(attr, IRAttributeParameter) - ] + def attrs(self) -> Sequence[ir.Attr]: + return [attr for attr in self.ordered_inputs_and_attrs if isinstance(attr, ir.Attr)] def __str__(self): return str(self.ir_function) @@ -183,15 +126,14 @@ def append_node(self, node: ir.Node) -> None: stacklevel=2, ) - def append_input(self, var: IRVar) -> None: + def append_input(self, var: ir.Value) -> None: self.ordered_inputs_and_attrs.append(var) - self.ir_function.inputs.append(var.value) + self.ir_function.inputs.append(var) - def append_output(self, var: IRVar) -> None: - self.outputs.append(var) - self.ir_function.outputs.append(var.value) + def append_output(self, var: ir.Value) -> None: + self.ir_function.outputs.append(var) - def add_attr_parameter(self, attr: IRAttributeParameter) -> None: + def add_attr_parameter(self, attr: ir.Attr) -> None: self.ordered_inputs_and_attrs.append(attr) self.ir_function.attributes.add(attr) @@ -334,6 +276,24 @@ def to_function_proto(self) -> onnx.FunctionProto: # IRBuilder: abstracts out details of the IR in the python-to-IR converter +def _make_value( + varname: str, typeinfo: TypeAnnotationValue, sourceinfo: SourceInfo +) -> ir.Value: + if typeinfo is None: + value = ir.Value(name=varname) + else: + try: + type_and_shape = ir.from_proto(typeinfo.to_type_proto()) + value = ir.Value( + name=varname, type=type_and_shape.type, shape=type_and_shape.shape + ) + except AttributeError: + value = ir.Value(name=varname) + value.meta.setdefault("sourceinfo", sourceinfo) + value.meta.setdefault("typeinfo", typeinfo) + return value + + class IRBuilder: def __init__(self): self.functions = {} @@ -356,7 +316,6 @@ def add_stmt( callee: values.Op, inputs: Sequence[Optional[ir.Value]], attrs: Sequence[ir.Attr], - sub_functions=None, ) -> Sequence[ir.Value]: output_values = [ir.Value(name=o) for o in results] attributes = attrs # [ir.from_proto(a.attr_proto) for a in attrs] @@ -375,10 +334,9 @@ def add_stmt( return output_values def add_input( - self, fn: IRFunction, varname: str, type: IRTypeLike, info: SourceInfo + self, fn: IRFunction, varname: str, type: TypeAnnotationValue, info: SourceInfo ) -> None: - var = IRVar(varname, type, info) - fn.append_input(var) + fn.append_input(_make_value(varname, type, info)) def add_attr_parameter( self, @@ -391,8 +349,7 @@ def add_attr_parameter( fn.add_attr_parameter(attr) def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None: - var = IRVar(varname, typeinfo, sourceinfo) - fn.append_output(var) + fn.append_output(_make_value(varname, typeinfo, sourceinfo)) def make_attr(self, attrproto: onnx.AttributeProto) -> ir.Attr: return ir.from_proto(attrproto) diff --git a/onnxscript/values.py b/onnxscript/values.py index ec66bb5cda..f875807d5d 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -209,12 +209,17 @@ def _param_schemas_from_op_schema( return tuple(schemas) +def _typeinfo(var: ir.Value) -> Any: + return var.meta.get("typeinfo") + + def _param_schema_from_function_ir_input(input: irbuilder.IRVar): - if type_annotation.is_optional(input.typeinfo): + typeinfo = _typeinfo(input) + if type_annotation.is_optional(typeinfo): required = False else: required = True - return ParamSchema(name=input.name, type=input.typeinfo, is_input=True, required=required) + return ParamSchema(name=input.name, type=typeinfo, is_input=True, required=required) def _param_schema_from_function_ir_attr(attr: irbuilder.IRAttributeParameter): @@ -241,10 +246,10 @@ def _param_schemas_from_function_ir( # ONNX OpSchema and FunctionProto does not support interleaving inputs and attributes. # This is by design. See more at https://github.com/microsoft/onnxscript/issues/771. for arg in function_ir.ordered_inputs_and_attrs: - if isinstance(arg, irbuilder.IRVar): + if isinstance(arg, ir.Value): # input schemas.append(_param_schema_from_function_ir_input(arg)) - elif isinstance(arg, irbuilder.IRAttributeParameter): + elif isinstance(arg, ir.Attr): # attr schemas.append(_param_schema_from_function_ir_attr(arg)) else: @@ -393,8 +398,8 @@ def _op_schema_from_function_ir( """Construct an ONNX OpSchema from an IRFunction.""" # Find all distinct types in the inputs and outputs - distinct_types = {arg.typeinfo for arg in function_ir.inputs}.union( - {arg.typeinfo for arg in function_ir.outputs} + distinct_types = {_typeinfo(arg) for arg in function_ir.inputs}.union( + {_typeinfo(arg) for arg in function_ir.outputs} ) # Create a mapping from type to a unique name type_to_constraint = {} @@ -408,10 +413,10 @@ def _op_schema_from_function_ir( formal_inputs = [ onnx.defs.OpSchema.FormalParameter( arg.name, - type_to_constraint[arg.typeinfo].name, + type_to_constraint[_typeinfo(arg)].name, param_option=( onnx.defs.OpSchema.FormalParameterOption.Optional - if type_annotation.is_optional(arg.typeinfo) + if type_annotation.is_optional(_typeinfo(arg)) else onnx.defs.OpSchema.FormalParameterOption.Single ), # TODO(justinchu): Check this is_homogeneous thing @@ -422,10 +427,10 @@ def _op_schema_from_function_ir( formal_outputs = [ onnx.defs.OpSchema.FormalParameter( arg.name, - type_to_constraint[arg.typeinfo].name, + type_to_constraint[_typeinfo(arg)].name, param_option=( onnx.defs.OpSchema.FormalParameterOption.Optional - if type_annotation.is_optional(arg.typeinfo) + if type_annotation.is_optional(_typeinfo(arg)) else onnx.defs.OpSchema.FormalParameterOption.Single ), # TODO(justinchu): Check this is_homogeneous thing From 053e7c56033b9dda531b6e8e85b5b950045d9658 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sun, 16 Nov 2025 20:49:20 -0800 Subject: [PATCH 26/43] Remove unused Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 20 -------------------- 1 file changed, 20 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 77efbbb81b..55968e9759 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -36,26 +36,6 @@ def select_ir_version(version: int, domain: str = "") -> int: return helper.OP_SET_ID_VERSION_MAP[domain, version] -class IRType: - def __init__(self): - self.onnx_type = onnx.TypeProto() - - def to_type_proto(self): - return self.onnx_type - - def __repr__(self) -> str: - return "IRType()" - - -class IRTensorType(IRType): - def __init__(self, elem_type: onnx.TensorProto.DataType) -> None: - super().__init__() - self.onnx_type.tensor_type.elem_type = elem_type - - def __repr__(self) -> str: - return f"IRTensorType({self.onnx_type.tensor_type.elem_type})" - - TypeAnnotationValue = onnxscript.type_annotation.TypeAnnotationValue From 46239f1a7688acc807cd1005efb5a804ea1d898d Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Sun, 16 Nov 2025 22:03:09 -0800 Subject: [PATCH 27/43] Minor cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 3 +-- onnxscript/values.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 55968e9759..823891e52b 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -14,7 +14,6 @@ import onnxscript import onnxscript.type_annotation -from onnxscript import type_annotation as ta from onnxscript import values from onnxscript.onnx_types import ONNXType from onnxscript.sourceinfo import SourceInfo @@ -335,5 +334,5 @@ def make_attr(self, attrproto: onnx.AttributeProto) -> ir.Attr: return ir.from_proto(attrproto) def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> ir.Attr: - attr_type = ir.AttributeType(ta.pytype_to_attrtype(pytype)) + attr_type = ir.AttributeType(onnxscript.type_annotation.pytype_to_attrtype(pytype)) return ir.Attr(attrname, attr_type, None, refname) diff --git a/onnxscript/values.py b/onnxscript/values.py index f875807d5d..f4423d0804 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -213,7 +213,7 @@ def _typeinfo(var: ir.Value) -> Any: return var.meta.get("typeinfo") -def _param_schema_from_function_ir_input(input: irbuilder.IRVar): +def _param_schema_from_function_ir_input(input: ir.Value): typeinfo = _typeinfo(input) if type_annotation.is_optional(typeinfo): required = False @@ -222,7 +222,7 @@ def _param_schema_from_function_ir_input(input: irbuilder.IRVar): return ParamSchema(name=input.name, type=typeinfo, is_input=True, required=required) -def _param_schema_from_function_ir_attr(attr: irbuilder.IRAttributeParameter): +def _param_schema_from_function_ir_attr(attr: ir.Attr): return ParamSchema( name=attr.name, type=_ATTRIBUTE_TYPE_TO_PYTHON_TYPE.get( From dda79770fb6dd239c4dbd744dd74515fe01c7eaa Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 17 Nov 2025 05:22:15 -0800 Subject: [PATCH 28/43] Move to_model_proto Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 1 - onnxscript/converter_test.py | 2 +- onnxscript/irbuilder.py | 120 ++------------------------ onnxscript/values.py | 110 ++++++++++++++++++++++- tests/common/onnx_script_test_case.py | 4 +- 5 files changed, 119 insertions(+), 118 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index e6c65cd051..71e19c7929 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -1475,7 +1475,6 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: self._current_fn = self.ir_builder.new_function(stmt.name, domain, True) self._analyzer = analysis.AstAnalyzer(stmt, self._message, self.globals) fn_ir = self._translate_function_def_common(stmt) - fn_ir.debug_print() self.this_module.add_function_def(fn_ir) self._analyzer = None return fn_ir diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py index a35711aea9..63cdfd2939 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/converter_test.py @@ -235,7 +235,7 @@ def test_unary_op(self): def test_subfunction_check_model(self): from tests.models import subfunction - model = subfunction.MyElu.function_ir.to_model_proto(producer_name="p2o") + model = subfunction.MyElu.to_model_proto(producer_name="p2o") model = onnx.shape_inference.infer_shapes(model) onnx.checker.check_model(model) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 823891e52b..0f0b01d47b 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -9,30 +9,23 @@ import onnx import onnx_ir as ir -from onnx import helper -from onnx.defs import onnx_opset_version -import onnxscript import onnxscript.type_annotation from onnxscript import values -from onnxscript.onnx_types import ONNXType from onnxscript.sourceinfo import SourceInfo logger = logging.getLogger("onnxscript") -def _format(seq: Sequence[Any], prefix: str, sep: str, suffix: str, formatter=str): - """Formats a sequence of objects into a string.""" - return prefix + sep.join([formatter(x) for x in seq]) + suffix - - def select_ir_version(version: int, domain: str = "") -> int: """Selects a suitable ONNX ir_version for a given opset version.""" if domain == "": domain = "ai.onnx" - if (domain, version) not in helper.OP_SET_ID_VERSION_MAP: - return max(v for k, v in helper.OP_SET_ID_VERSION_MAP.items() if k[0] == "ai.onnx") - return helper.OP_SET_ID_VERSION_MAP[domain, version] + if (domain, version) not in onnx.helper.OP_SET_ID_VERSION_MAP: + return max( + v for k, v in onnx.helper.OP_SET_ID_VERSION_MAP.items() if k[0] == "ai.onnx" + ) + return onnx.helper.OP_SET_ID_VERSION_MAP[domain, version] TypeAnnotationValue = onnxscript.type_annotation.TypeAnnotationValue @@ -76,9 +69,7 @@ def assigned_names(self) -> Sequence[str]: @property def inputs(self) -> Sequence[ir.Value]: - return ( - self.ir_function.inputs - ) # [var for var in self.ordered_inputs_and_attrs if isinstance(var, IRVar)] + return self.ir_function.inputs @property def attrs(self) -> Sequence[ir.Attr]: @@ -116,106 +107,9 @@ def add_attr_parameter(self, attr: ir.Attr) -> None: self.ordered_inputs_and_attrs.append(attr) self.ir_function.attributes.add(attr) - def debug_print(self): - if logger.isEnabledFor(logging.DEBUG): - logger.debug(str(self.ir_function)) - def add_nested_function(self, fun: IRFunction) -> None: self.nested_functions[fun.name] = fun - def to_model_proto( - self, - functions=None, - io_types: Optional[ONNXType] = None, - input_types: Optional[Sequence[ONNXType]] = None, - output_types: Optional[Sequence[ONNXType]] = None, - value_infos: dict[str, ONNXType] | None = None, - opset_version: int | None = None, - **kwargs, - ) -> onnx.ModelProto: - """Converts this instance into a `onnx.ModelProto`. - - Args: - functions: A list of functions to include in the model. - By default, all functions called at least once are included. - io_types: When specified, all the inputs/outputs of the model - are set to be of this type. - input_types: When specified, all the inputs of the model - are set to be of the corresponding type in this list. - output_types: When specified, all the outputs of the model - are set to be of the corresponding type in this list. - value_infos: A dictionary mapping intermediate variable names to ONNX types. - Used to set value_info for intermediate variables. - opset_version: The standard opset version to use for the model if it - cannot be inferred. Otherwise defaults to the current opset version. - kwargs: Additional parameters given to function :func:`onnx.helper.make_model`. - - Returns: - An instance of :class:`onnx.ModelProto`. - """ - value_infos = ( - [ - onnx.helper.make_value_info(name, type.to_type_proto()) - for name, type in value_infos.items() - ] - if value_infos - else None - ) - sub_functions = self.get_called_functions() - graph = self.to_graph_proto(use_default_type=False) - if value_infos: - graph.value_info.extend(value_infos) - if io_types is not None: - for input in graph.input: - if not input.HasField("type"): - input.type.CopyFrom(io_types.to_type_proto()) - for output in graph.output: - if not output.HasField("type"): - output.type.CopyFrom(io_types.to_type_proto()) - if input_types is not None: - for input, type in zip(graph.input, input_types): - input.type.CopyFrom(type.to_type_proto()) - if output_types is not None: - for output, type in zip(graph.output, output_types): - output.type.CopyFrom(type.to_type_proto()) - if functions is None: - functions = sub_functions.values() - else: - - def to_proto(f): - if isinstance(f, onnx.FunctionProto): - return f - if isinstance(f, onnxscript.OnnxFunction): - return f.to_function_proto() - raise TypeError("Expected a value of type FunctionProto of OnnxFunction") - - functions = [to_proto(f) for f in functions] - - opsets = self.ir_function.opset_imports.copy() - - for proto in functions: - if proto.domain not in opsets: - opsets[proto.domain] = 1 - # TODO(rama): Handle conflicts with appropriate error/warning message. - for opset in proto.opset_import: - if opset.domain not in opsets: - opsets[opset.domain] = opset.version - - if "" not in opsets: - # No operator is using the standard opset. - # Use the specified version if provided or the default value. - opsets[""] = opset_version if opset_version is not None else onnx_opset_version() - - if "ir_version" not in kwargs: - kwargs["ir_version"] = select_ir_version(opsets[""]) - opset_imports = [ - onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items() - ] - - return helper.make_model( - graph, opset_imports=opset_imports, functions=functions, **kwargs - ) - def get_called_functions(self) -> dict[str, onnx.FunctionProto]: called_functions: dict[str, values.OnnxFunction] = {} @@ -297,7 +191,7 @@ def add_stmt( attrs: Sequence[ir.Attr], ) -> Sequence[ir.Value]: output_values = [ir.Value(name=o) for o in results] - attributes = attrs # [ir.from_proto(a.attr_proto) for a in attrs] + attributes = attrs node = ir.Node( domain=callee.opset.domain, version=callee.opset.version, diff --git a/onnxscript/values.py b/onnxscript/values.py index f4423d0804..fa924186e7 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -30,11 +30,23 @@ from onnxscript import irbuilder, sourceinfo, type_annotation from onnxscript._internal import ast_utils, deprecation from onnxscript.ir import _schemas +from onnxscript.onnx_types import ONNXType _R = TypeVar("_R") _P = ParamSpec("_P") +def select_ir_version(version: int, domain: str = "") -> int: + """Selects a suitable ONNX ir_version for a given opset version.""" + if domain == "": + domain = "ai.onnx" + if (domain, version) not in onnx.helper.OP_SET_ID_VERSION_MAP: + return max( + v for k, v in onnx.helper.OP_SET_ID_VERSION_MAP.items() if k[0] == "ai.onnx" + ) + return onnx.helper.OP_SET_ID_VERSION_MAP[domain, version] + + _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = { onnx.defs.OpSchema.AttrType.FLOAT: float, onnx.defs.OpSchema.AttrType.INT: int, @@ -609,7 +621,103 @@ def to_model_proto(self, **kwargs): # Merge kwargs specified in script-decorator with those specified in this call. merged_kw_args = {**self.kwargs, **kwargs} - return self.function_ir.to_model_proto(**merged_kw_args) + return self._to_model_proto(**merged_kw_args) + + def _to_model_proto( + self, + functions=None, + io_types: Optional[ONNXType] = None, + input_types: Optional[Sequence[ONNXType]] = None, + output_types: Optional[Sequence[ONNXType]] = None, + value_infos: dict[str, ONNXType] | None = None, + opset_version: int | None = None, + **kwargs, + ) -> onnx.ModelProto: + """Converts this instance into a `onnx.ModelProto`. + + Args: + functions: A list of functions to include in the model. + By default, all functions called at least once are included. + io_types: When specified, all the inputs/outputs of the model + are set to be of this type. + input_types: When specified, all the inputs of the model + are set to be of the corresponding type in this list. + output_types: When specified, all the outputs of the model + are set to be of the corresponding type in this list. + value_infos: A dictionary mapping intermediate variable names to ONNX types. + Used to set value_info for intermediate variables. + opset_version: The standard opset version to use for the model if it + cannot be inferred. Otherwise defaults to the current opset version. + kwargs: Additional parameters given to function :func:`onnx.helper.make_model`. + + Returns: + An instance of :class:`onnx.ModelProto`. + """ + value_infos = ( + [ + onnx.helper.make_value_info(name, type.to_type_proto()) + for name, type in value_infos.items() + ] + if value_infos + else None + ) + + graph = self.function_ir.to_graph_proto(use_default_type=False) + if value_infos: + graph.value_info.extend(value_infos) + if io_types is not None: + for input in graph.input: + if not input.HasField("type"): + input.type.CopyFrom(io_types.to_type_proto()) + for output in graph.output: + if not output.HasField("type"): + output.type.CopyFrom(io_types.to_type_proto()) + if input_types is not None: + for input, type in zip(graph.input, input_types): + input.type.CopyFrom(type.to_type_proto()) + if output_types is not None: + for output, type in zip(graph.output, output_types): + output.type.CopyFrom(type.to_type_proto()) + if functions is None: + sub_functions = self.function_ir.get_called_functions() + functions = sub_functions.values() + else: + + def to_proto(f): + if isinstance(f, onnx.FunctionProto): + return f + if isinstance(f, OnnxFunction): + return f.to_function_proto() + raise TypeError("Expected a value of type FunctionProto of OnnxFunction") + + functions = [to_proto(f) for f in functions] + + opsets = self.function_ir.ir_function.opset_imports.copy() + + for proto in functions: + if proto.domain not in opsets: + opsets[proto.domain] = 1 + # TODO(rama): Handle conflicts with appropriate error/warning message. + for opset in proto.opset_import: + if opset.domain not in opsets: + opsets[opset.domain] = opset.version + + if "" not in opsets: + # No operator is using the standard opset. + # Use the specified version if provided or the default value. + opsets[""] = ( + opset_version if opset_version is not None else onnx.defs.onnx_opset_version() + ) + + if "ir_version" not in kwargs: + kwargs["ir_version"] = select_ir_version(opsets[""]) + opset_imports = [ + onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items() + ] + + return onnx.helper.make_model( + graph, opset_imports=opset_imports, functions=functions, **kwargs + ) class TracedOnnxFunction(Op): diff --git a/tests/common/onnx_script_test_case.py b/tests/common/onnx_script_test_case.py index 3a46a870a0..ecb8cd7fdc 100644 --- a/tests/common/onnx_script_test_case.py +++ b/tests/common/onnx_script_test_case.py @@ -144,7 +144,7 @@ def _create_model_from_param( # there is not way from the onnx test case's model and feed to get TypeProto # in order to build a model. # we have to resolve the TypeProto from script function. - local_function_model_proto = param.function.function_ir.to_model_proto( + local_function_model_proto = param.function.to_model_proto( ir_version=ir_version ) input_value_infos = [] @@ -202,7 +202,7 @@ def run_converter_test( param, onnx_case_model, ir_version=ir_version ) else: - model = param.function.function_ir.to_model_proto( + model = param.function.to_model_proto( producer_name="call_clip", ir_version=ir_version ) try: From 4527f92f4d8e0b6b4b551bf0b824501b52eddf02 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 17 Nov 2025 05:31:11 -0800 Subject: [PATCH 29/43] More cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 4 ++-- onnxscript/irbuilder.py | 53 +++++++++++++---------------------------- onnxscript/values.py | 2 +- 3 files changed, 19 insertions(+), 40 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 71e19c7929..e5a9ea590d 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -1313,7 +1313,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: inputs = [o_loop_bound, o_loop_condition] + [ self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars ] - attrs = [self._make_onnx_attr("body", body.ir_function.graph)] + attrs = [self._make_onnx_attr("body", body.graph)] info = self._source_of(loop_stmt) def rename(x): @@ -1378,7 +1378,7 @@ def _translate_block( typeinfo = None self.ir_builder.add_output(self._current_fn, ovar.name, typeinfo, source) graph = self._exit_scope() - return graph.ir_function.graph + return graph.graph def _translate_nested_function_def(self, fn: ast.FunctionDef) -> None: """Translate a nested function definition.""" diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 0f0b01d47b..675c8f3c78 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -31,63 +31,42 @@ def select_ir_version(version: int, domain: str = "") -> int: TypeAnnotationValue = onnxscript.type_annotation.TypeAnnotationValue -class IRFunction: +class IRFunction(ir.Function): """Represents a function in the IR.""" def __init__(self, name: str, domain: str = "") -> None: graph = ir.Graph(inputs=[], outputs=[], nodes=[], name=name) - self.ir_function = ir.Function(domain, name, graph=graph, attributes=[]) + super().__init__(domain, name, graph=graph, attributes=[]) self.ordered_inputs_and_attrs: list[Union[ir.Value, ir.Attr]] = [] # a dictionary of nested function-definitions self.nested_functions: dict[str, IRFunction] = {} self.outer_scope_variables: dict[Any, Any] = {} - @property - def outputs(self) -> Sequence[ir.Value]: - return self.ir_function.outputs - - @property - def domain(self) -> str: - """Returns the domain of this function.""" - return self.ir_function.domain - @property def docstring(self) -> str: """Returns the docstring of this function.""" - return self.ir_function.doc_string or "" - - @property - def name(self) -> str: - """Returns the name of this function.""" - return self.ir_function.name + return self.doc_string or "" @property def assigned_names(self) -> Sequence[str]: """Returns the list of variables assigned to by this function.""" - return [v.name for n in self.ir_function for v in n.outputs] - - @property - def inputs(self) -> Sequence[ir.Value]: - return self.ir_function.inputs + return [v.name for n in self for v in n.outputs] @property def attrs(self) -> Sequence[ir.Attr]: return [attr for attr in self.ordered_inputs_and_attrs if isinstance(attr, ir.Attr)] - def __str__(self): - return str(self.ir_function) - def append_node(self, node: ir.Node) -> None: - count = len(self.ir_function) + count = len(self) node.name = f"n{count}" - self.ir_function.append(node) + self.append(node) domain = node.domain version = node.version - if domain not in self.ir_function.opset_imports: - self.ir_function.opset_imports[domain] = version + if domain not in self.opset_imports: + self.opset_imports[domain] = version else: - existing_version = self.ir_function.opset_imports[domain] + existing_version = self.opset_imports[domain] if existing_version != version: warnings.warn( f"Version conflict: domain: {domain!r}, " @@ -98,14 +77,14 @@ def append_node(self, node: ir.Node) -> None: def append_input(self, var: ir.Value) -> None: self.ordered_inputs_and_attrs.append(var) - self.ir_function.inputs.append(var) + self.inputs.append(var) def append_output(self, var: ir.Value) -> None: - self.ir_function.outputs.append(var) + self.outputs.append(var) def add_attr_parameter(self, attr: ir.Attr) -> None: self.ordered_inputs_and_attrs.append(attr) - self.ir_function.attributes.add(attr) + self.attributes.add(attr) def add_nested_function(self, fun: IRFunction) -> None: self.nested_functions[fun.name] = fun @@ -114,7 +93,7 @@ def get_called_functions(self) -> dict[str, onnx.FunctionProto]: called_functions: dict[str, values.OnnxFunction] = {} def visit(function_ir: IRFunction): - for node in ir.traversal.RecursiveGraphIterator(function_ir.ir_function.graph): + for node in ir.traversal.RecursiveGraphIterator(function_ir.graph): callee = node.meta.get("callee", None) if isinstance(callee, values.OnnxFunction): add(callee) @@ -139,11 +118,11 @@ def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto: an instance of :class:`onnx.GraphProto` """ del use_default_type # currently not used - return ir.to_proto(self.ir_function.graph) + return ir.to_proto(self.graph) def to_function_proto(self) -> onnx.FunctionProto: """Converts this instance into a `onnx.FunctionProto`.""" - return ir.to_proto(self.ir_function) + return ir.to_proto(self) # IRBuilder: abstracts out details of the IR in the python-to-IR converter @@ -180,7 +159,7 @@ def new_function(self, name: str, domain: str = "", register: bool = False) -> I return function def add_docstring(self, fn: IRFunction, docstring: str): - fn.ir_function.doc_string = docstring + fn.doc_string = docstring def add_stmt( self, diff --git a/onnxscript/values.py b/onnxscript/values.py index fa924186e7..6ab565da25 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -692,7 +692,7 @@ def to_proto(f): functions = [to_proto(f) for f in functions] - opsets = self.function_ir.ir_function.opset_imports.copy() + opsets = self.function_ir.opset_imports.copy() for proto in functions: if proto.domain not in opsets: From fd3c14a44e89564f3ab68723a77502abc2736d45 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 17 Nov 2025 05:36:01 -0800 Subject: [PATCH 30/43] More cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 675c8f3c78..fa224a97ed 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -108,16 +108,8 @@ def add(f: values.OnnxFunction): return {name: f.to_function_proto() for name, f in called_functions.items()} - def to_graph_proto(self, use_default_type: bool = True) -> onnx.GraphProto: - """Converts this instance into a `onnx.GraphProto`. - - Args: - use_default_type: Unused. - - Returns: - an instance of :class:`onnx.GraphProto` - """ - del use_default_type # currently not used + def to_graph_proto(self) -> onnx.GraphProto: + """Converts this instance into a `onnx.GraphProto`.""" return ir.to_proto(self.graph) def to_function_proto(self) -> onnx.FunctionProto: From 09bc3d3a82effaed8b4b529518fe8b13125dc0c8 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 17 Nov 2025 05:44:02 -0800 Subject: [PATCH 31/43] More cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/irbuilder.py | 19 ++++++++++--------- onnxscript/values.py | 2 +- tests/common/onnx_script_test_case.py | 4 +--- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index fa224a97ed..d717a954fb 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -75,17 +75,18 @@ def append_node(self, node: ir.Node) -> None: stacklevel=2, ) - def append_input(self, var: ir.Value) -> None: - self.ordered_inputs_and_attrs.append(var) - self.inputs.append(var) + def append_parameter(self, parameter: ir.Value | ir.Attr) -> None: + self.ordered_inputs_and_attrs.append(parameter) + if isinstance(parameter, ir.Value): + self.inputs.append(parameter) + else: + if not isinstance(parameter, ir.Attr): + raise TypeError(f"Expected ir.Value or ir.Attr, got {type(parameter)}") + self.attributes.add(parameter) def append_output(self, var: ir.Value) -> None: self.outputs.append(var) - def add_attr_parameter(self, attr: ir.Attr) -> None: - self.ordered_inputs_and_attrs.append(attr) - self.attributes.add(attr) - def add_nested_function(self, fun: IRFunction) -> None: self.nested_functions[fun.name] = fun @@ -180,7 +181,7 @@ def add_stmt( def add_input( self, fn: IRFunction, varname: str, type: TypeAnnotationValue, info: SourceInfo ) -> None: - fn.append_input(_make_value(varname, type, info)) + fn.append_parameter(_make_value(varname, type, info)) def add_attr_parameter( self, @@ -190,7 +191,7 @@ def add_attr_parameter( default_value: int | float | str | None, ) -> None: attr = ir.Attr(varname, ir.AttributeType(attribute_type), default_value, None) - fn.add_attr_parameter(attr) + fn.append_parameter(attr) def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None: fn.append_output(_make_value(varname, typeinfo, sourceinfo)) diff --git a/onnxscript/values.py b/onnxscript/values.py index 6ab565da25..a6ff57b0f6 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -662,7 +662,7 @@ def _to_model_proto( else None ) - graph = self.function_ir.to_graph_proto(use_default_type=False) + graph = self.function_ir.to_graph_proto() if value_infos: graph.value_info.extend(value_infos) if io_types is not None: diff --git a/tests/common/onnx_script_test_case.py b/tests/common/onnx_script_test_case.py index ecb8cd7fdc..e209579749 100644 --- a/tests/common/onnx_script_test_case.py +++ b/tests/common/onnx_script_test_case.py @@ -144,9 +144,7 @@ def _create_model_from_param( # there is not way from the onnx test case's model and feed to get TypeProto # in order to build a model. # we have to resolve the TypeProto from script function. - local_function_model_proto = param.function.to_model_proto( - ir_version=ir_version - ) + local_function_model_proto = param.function.to_model_proto(ir_version=ir_version) input_value_infos = [] for i, input in enumerate(local_function_model_proto.graph.input): vi = copy.deepcopy(input) From 5e17ae1e87f4b38b4bdfae4a88dae8e145392c1f Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 17 Nov 2025 09:01:02 -0800 Subject: [PATCH 32/43] Address lint warning Signed-off-by: Ganesan Ramalingam --- onnxscript/values.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/onnxscript/values.py b/onnxscript/values.py index a6ff57b0f6..c50cdcadf8 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -1,5 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. + +# ruff: noqa: TID251 + from __future__ import annotations import dataclasses @@ -189,7 +192,7 @@ def _get_attribute_value(attr_proto: onnx.AttributeProto) -> Any: """Get the default value of an ONNX attribute.""" if attr_proto.type == onnx.AttributeProto.UNDEFINED: return _EmptyDefault - return onnx.helper.get_attribute_value(attr_proto) # noqa: TID251 + return onnx.helper.get_attribute_value(attr_proto) def _param_schemas_from_op_schema( From 3e4b59c795aa18f76caa1a880df8f20baf2aa79c Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 17 Nov 2025 09:05:16 -0800 Subject: [PATCH 33/43] Remove unused code Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index e5a9ea590d..6717f31d2e 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -942,9 +942,6 @@ def _translate_callee_expr(self, node: ast.AST) -> values.Op: # pylint: disable if isinstance(node, ast.Name): function_name = node.id found = self._lookup(function_name, self._source_of(node), raise_exception=False) - # if isinstance(found, onnxscript.OnnxFunction): - # self._current_fn.add_called_function(found) - # return found if isinstance(found, (values.Op, onnxscript.OnnxFunction)): return found if not found: From cdfc749cb9b534bc55179a3fadcddc579defd762 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 17 Nov 2025 09:16:55 -0800 Subject: [PATCH 34/43] More cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 48 ++++++++++++++++++----------------------- onnxscript/irbuilder.py | 3 +-- 2 files changed, 22 insertions(+), 29 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 6717f31d2e..a482287ae7 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -81,9 +81,6 @@ def ignore(cond, msg): } -Variable = ir.Value - - if TYPE_CHECKING: # The type-alias LocalSymValue represents the types of values that local names in a # script-function may be bound to during translation, (ONNX IR values). @@ -328,7 +325,7 @@ def _to_onnx_var( val: values.SymbolValue | PyValue, target: Optional[PreferredName] = None, info: Optional[sourceinfo.SourceInfo] = None, - ) -> Variable: + ) -> ir.Value: if isinstance(val, values.AttrRef): # promote attribute to value result_name = self.generate_unique_name(target or "tmp") @@ -358,19 +355,16 @@ def _to_onnx_var( # produce a better error _message otherwise return self._emit_const(val, target or "tmp", info) - def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> Variable: + def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> ir.Value: return self._to_onnx_var(self._lookup(py_var, info), target=py_var, info=info) def emit( self, outputs: Sequence[str], callee: values.Op | str, - inputs: Sequence[Optional[Variable]], + inputs: Sequence[Optional[ir.Value]], attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None, - ) -> Sequence[Variable] | Variable: - for i, x in enumerate(inputs): - if (x is not None) and not isinstance(x, ir.Value): - raise TypeError(f"Expected ONNX IR Value for input {i}, got {type(x)!r}.") + ) -> Sequence[ir.Value] | ir.Value: if not isinstance(callee, values.Op): callee = values.Op(self.default_opset, callee) if attrs is None: @@ -384,7 +378,7 @@ def emit( ) return output_values if len(output_values) > 1 else output_values[0] - def emit1(self, *args, **kwargs) -> Variable: + def emit1(self, *args, **kwargs) -> ir.Value: r = self.emit(*args, **kwargs) if not isinstance(r, ir.Value): raise TypeError(f"Expected single ONNX IR Value, got {type(r)!r}.") @@ -395,7 +389,7 @@ def _emit_const( pyvalue: PyValue, suggested_name: Optional[PreferredName], info: sourceinfo.SourceInfo, - ) -> Variable: + ) -> ir.Value: if suggested_name is None: if isinstance(pyvalue, int): if pyvalue >= 0: @@ -420,7 +414,7 @@ def _emit_const( self._castable.add(ovar) return self.emit1([ovar], values.Op(self.default_opset, "Constant"), [], [attr]) - def _emit_copy(self, original_var: Variable, suggested_name: str) -> Variable: + def _emit_copy(self, original_var: ir.Value, suggested_name: str) -> ir.Value: """Emits a copy statement, using the ONNX Identity operator.""" new_var = self.generate_unique_name(suggested_name) return self.emit([new_var], "Identity", [original_var]) @@ -540,7 +534,7 @@ def _translate_docstring(self, node: ast.Expr) -> None: def _translate_expr( self, node: ast.AST, target: Optional[PreferredName] = None - ) -> Variable: + ) -> ir.Value: """Expression-translation generates "IR statements/nodes" that compute the value of the expression into a target-variable, and returns the variable that is assigned this value. @@ -563,7 +557,7 @@ def _translate_expr( raise ValueError( self._message(node, f"Unsupported expression type {type(node)!r}.") ) - if isinstance(r, Variable): + if isinstance(r, ir.Value): return r callee, args, attrs = r target = "tmp" if target is None else target @@ -571,7 +565,7 @@ def _translate_expr( result = self.generate_unique_name(target) return self.emit1([result], callee, args, attrs) - def _translate_opt_expr(self, node: ast.expr) -> Optional[Variable]: + def _translate_opt_expr(self, node: ast.expr) -> Optional[ir.Value]: """Translation of an expression where "None" is permitted (eg., for an optional argument). None is represented as a Constant in Python 3.9+. """ @@ -581,7 +575,7 @@ def _translate_opt_expr(self, node: ast.expr) -> Optional[Variable]: def _translate_subscript_expr( self, node: ast.Subscript, target: Optional[PreferredName] - ) -> Variable: + ) -> ir.Value: """List of supported syntaxes is below. `A` is a tensor or an expression equivalent to a tensor. @@ -628,15 +622,15 @@ def _translate_subscript_expr( # Create cached int constants: # TODO: Do this at a graph-scope level. - cached_int_consts: dict[int, Variable] = {} + cached_int_consts: dict[int, ir.Value] = {} - def const_1d(value, name: Optional[str] = None) -> Variable: + def const_1d(value, name: Optional[str] = None) -> ir.Value: nonlocal cached_int_consts if value not in cached_int_consts: cached_int_consts[value] = self._emit_const([value], name, info) return cached_int_consts[value] - def one_1d() -> Variable: + def one_1d() -> ir.Value: return const_1d(1) # Max/min 64-bit int values are used to represent default values for start/stop in Slice. @@ -645,7 +639,7 @@ def one_1d() -> Variable: def translate_slice_component( node_arg, default_value: Optional[int] = None - ) -> tuple[Variable, Optional[int]]: + ) -> tuple[ir.Value, Optional[int]]: """Translate optional start/stop/step component of a Slice expression.""" if node_arg is None: if default_value is None: @@ -672,7 +666,7 @@ def translate_slice_component( ) return reshaped_value, None - def translate_slice(slice_expr: ast.Slice) -> tuple[Variable, Variable, Variable]: + def translate_slice(slice_expr: ast.Slice) -> tuple[ir.Value, ir.Value, ir.Value]: """Translate slice-expression of the form from:to:step.""" step_name, step = translate_slice_component(slice_expr.step, 1) if step is None: @@ -813,7 +807,7 @@ def translate_slice(slice_expr: ast.Slice) -> tuple[Variable, Variable, Variable def _translate_call_expr( self, node: ast.Call - ) -> tuple[values.Op, list[Optional[Variable]], list[irbuilder.IRAttributeValue]]: + ) -> tuple[values.Op, list[Optional[ir.Value]], list[irbuilder.IRAttributeValue]]: """Translates a call-expression.""" callee = self._translate_callee_expr(node.func) param_schemas = callee.param_schemas() @@ -840,7 +834,7 @@ def _translate_call_expr( attrs = [attr for attr in attrs if attr is not None] return callee, args, attrs - def _cast_like_binary_expression(self, op, left, right) -> tuple[Variable, Variable]: + def _cast_like_binary_expression(self, op, left, right) -> tuple[ir.Value, ir.Value]: schema = op.op_schema return autocast.static_cast_inputs(self, schema, (left, right)) @@ -913,7 +907,7 @@ def _translate_compare_expr(self, node): return op, [left, right], [] - def _translate_name_expr(self, node: ast.Name) -> Variable: + def _translate_name_expr(self, node: ast.Name) -> ir.Value: return self._py_var_to_onnx_var(node.id, self._source_of(node)) # pylint: disable=inconsistent-return-statements @@ -1237,7 +1231,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ), ) - condition_name: Variable | None = None + condition_name: ir.Value | None = None operator_name = "Identity" for i, s in enumerate(loop_stmt.body): # We first need to intercept a break instruction in test block. @@ -1365,7 +1359,7 @@ def _translate_block( if pv_val is None: self.fail( stmts[0], - f"Variable {pvar} is not assigned a value along a conditional " + f"ir.Value {pvar} is not assigned a value along a conditional " f"branch, known variables: {list(self._locals)}.", ) # introduce a copy diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index d717a954fb..7e5241d03a 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -163,14 +163,13 @@ def add_stmt( attrs: Sequence[ir.Attr], ) -> Sequence[ir.Value]: output_values = [ir.Value(name=o) for o in results] - attributes = attrs node = ir.Node( domain=callee.opset.domain, version=callee.opset.version, op_type=callee.name, inputs=inputs, outputs=output_values, - attributes=attributes, + attributes=attrs, ) if not isinstance(callee, values.Op): raise TypeError(f"Unexpected type {type(callee)} for callee.") From a46dcc27417f439fb3df54bbacce42e93ae17a20 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Mon, 17 Nov 2025 10:39:14 -0800 Subject: [PATCH 35/43] Fix lint issue Signed-off-by: Ganesan Ramalingam --- onnxscript/_internal/autocast.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/onnxscript/_internal/autocast.py b/onnxscript/_internal/autocast.py index 911ecbb024..0afbd36d60 100644 --- a/onnxscript/_internal/autocast.py +++ b/onnxscript/_internal/autocast.py @@ -189,23 +189,21 @@ def get_type_info(x): def static_cast_inputs( converter_: converter.Converter, op_schema: Optional[OpSchema], - args: Sequence[Optional[converter.Variable]], + args: Sequence[Optional[ir.Value]], ) -> tuple[str, ...]: """Used for autocast during script-translation. This is meant to transform expressions like "Add(X, 1)" to "Add(X, CastLike(1, X))" Polymorphic constants (like 0 and 1) are cast to the type of other operands as needed. """ - def get_type_info(x: Optional[converter.Variable]) -> Optional[converter.Variable]: + def get_type_info(x: Optional[ir.Value]) -> Optional[ir.Value]: """Returns x back if x can serve as the target-type for a cast (as the second argument of CastLike) and None otherwise. In the expression "Add(X, 1), 1 is castable, while X can serve as the target-type. """ return None if x is None or converter_.is_castable(x.name) else x - def cast_like( - x: Optional[converter.Variable], y: Optional[converter.Variable] - ) -> Optional[str]: + def cast_like(x: Optional[ir.Value], y: Optional[ir.Value]) -> Optional[str]: if x is None: return None if converter_.is_castable(x.name) and y is not None: From e0281a3ed7111c8abf25d948b4d9746331ba97ab Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 18 Nov 2025 16:56:33 -0800 Subject: [PATCH 36/43] Add support for type annotation Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 22 ++++++++++++++-------- onnxscript/converter_test.py | 31 +++++++++++++++++++++++++++++++ onnxscript/irbuilder.py | 25 ++++++++++++++----------- 3 files changed, 59 insertions(+), 19 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index a482287ae7..8b35b2c327 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -464,6 +464,16 @@ def _eval_constant_expr(self, expr: ast.AST) -> PyValue: ) ) from e + def _get_type_annotation(self, annotation: ast.Expr) -> Optional[ta.TypeAnnotationValue]: + typeinfo = self._eval_constant_expr(annotation) + if not ta.is_valid_type(typeinfo): + self.warn( + annotation, + "Unsupported type annotation.", + ) + typeinfo = None + return typeinfo + def _translate_attr( self, attr_name: str, @@ -985,9 +995,11 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None: lhs = lhs.id t = self._translate_expr(rhs, lhs) if isinstance(stmt, ast.AnnAssign): - typeinfo = self._eval_constant_expr(stmt.annotation) + typeinfo = self._get_type_annotation(stmt.annotation) else: typeinfo = None + if typeinfo is not None: + irbuilder.set_type_info(t, typeinfo) var = values.Dynamic(t, values.DynamicKind.Intermediate, info, typeinfo) self._bind(lhs, var) elif isinstance(lhs, ast.Tuple): @@ -1400,13 +1412,7 @@ def _translate_function_signature_common( else: default_value = None if x.annotation: - typeinfo = self._eval_constant_expr(x.annotation) - if not ta.is_valid_type(typeinfo): - self.warn( - x.annotation, - f"Unsupported type annotation for argument {x.arg}.", - ) - typeinfo = None + typeinfo = self._get_type_annotation(x.annotation) else: # The code can only be exported as a function. typeinfo = None diff --git a/onnxscript/converter_test.py b/onnxscript/converter_test.py index 63cdfd2939..b6cada54b0 100644 --- a/onnxscript/converter_test.py +++ b/onnxscript/converter_test.py @@ -740,6 +740,37 @@ def model(x: FLOAT[10]) -> FLOAT[10]: model_false = make_model(False) onnxscript.testing.assert_isomorphic(model_false, sub_model.to_model_proto()) + def test_type_annotation(self): + """Test that type annotations are processed correctly.""" + + @script() + def model(x: FLOAT[10]) -> FLOAT[10]: + temp: FLOAT[10] = op.Add(x, x) + y = op.Mul(temp, temp) + return y + + model_proto = model.to_model_proto() + input_type = model_proto.graph.input[0].type.tensor_type + output_type = model_proto.graph.output[0].type.tensor_type + temp_value_info = None + for value_info in model_proto.graph.value_info: + if value_info.name == "temp": + temp_value_info = value_info + break + self.assertIsNotNone(temp_value_info, "ValueInfo for 'temp' not found in graph.") + temp_type = temp_value_info.type.tensor_type + self.assertEqual(temp_type.elem_type, onnx.TensorProto.FLOAT) + self.assertEqual(len(temp_type.shape.dim), 1) + self.assertEqual(temp_type.shape.dim[0].dim_value, 10) + + self.assertEqual(input_type.elem_type, onnx.TensorProto.FLOAT) + self.assertEqual(len(input_type.shape.dim), 1) + self.assertEqual(input_type.shape.dim[0].dim_value, 10) + + self.assertEqual(output_type.elem_type, onnx.TensorProto.FLOAT) + self.assertEqual(len(output_type.shape.dim), 1) + self.assertEqual(output_type.shape.dim[0].dim_value, 10) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 7e5241d03a..438bdb0d0a 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -121,21 +121,24 @@ def to_function_proto(self) -> onnx.FunctionProto: # IRBuilder: abstracts out details of the IR in the python-to-IR converter +def set_type_info(value: ir.Value, typeinfo: TypeAnnotationValue) -> None: + """Sets the type information on an IR value.""" + try: + type_and_shape = ir.from_proto(typeinfo.to_type_proto()) + value.type = type_and_shape.type + value.shape = type_and_shape.shape + except AttributeError: + pass + value.meta["typeinfo"] = typeinfo + + def _make_value( varname: str, typeinfo: TypeAnnotationValue, sourceinfo: SourceInfo ) -> ir.Value: - if typeinfo is None: - value = ir.Value(name=varname) - else: - try: - type_and_shape = ir.from_proto(typeinfo.to_type_proto()) - value = ir.Value( - name=varname, type=type_and_shape.type, shape=type_and_shape.shape - ) - except AttributeError: - value = ir.Value(name=varname) + value = ir.Value(name=varname) value.meta.setdefault("sourceinfo", sourceinfo) - value.meta.setdefault("typeinfo", typeinfo) + if typeinfo is not None: + set_type_info(value, typeinfo) return value From c7d1eb55d28b0ea8764d44e7a0fa4e332b449a8a Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 18 Nov 2025 21:28:25 -0800 Subject: [PATCH 37/43] ir builder cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 8 +++++--- onnxscript/irbuilder.py | 7 ------- 2 files changed, 5 insertions(+), 10 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 8b35b2c327..de6fd98108 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -299,7 +299,7 @@ def tensor_name_generator() -> str: proto = autocast.pyvalue_to_onnx_attribute( attrname, attrval, tensor_name_generator, attrtype ) - return self.ir_builder.make_attr(proto) + return ir.from_proto(proto) def _to_onnx_attr_ref( self, val: values.AttrRef, info: Optional[sourceinfo.SourceInfo] @@ -318,7 +318,8 @@ def _to_onnx_attr_ref( else: msg = f"Unsupported attribute type {pytype!r}." fail(info.msg(msg) if info else msg) - return self.ir_builder.make_attr_ref(attrname, val.value, pytype) + attr_type = ir.AttributeType(ta.pytype_to_attrtype(pytype)) + return ir.Attr(attrname, attr_type, None, val.value) def _to_onnx_var( self, @@ -492,7 +493,8 @@ def _translate_attr( if isinstance(expr, ast.Name): val = self._lookup(expr.id, self._source_of(expr)) if isinstance(val, values.AttrRef): - attr_ref = self.ir_builder.make_attr_ref(attr_name, val.value, val.typeinfo) + attr_type = ir.AttributeType(ta.pytype_to_attrtype(val.typeinfo)) + attr_ref = ir.Attr(attr_name, attr_type, None, val.value) if attr_meta is not None and (attr_ref.type != attr_meta.type): self.fail( expr, diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 438bdb0d0a..e4568ebe7c 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -197,10 +197,3 @@ def add_attr_parameter( def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None: fn.append_output(_make_value(varname, typeinfo, sourceinfo)) - - def make_attr(self, attrproto: onnx.AttributeProto) -> ir.Attr: - return ir.from_proto(attrproto) - - def make_attr_ref(self, attrname: str, refname: str, pytype: type) -> ir.Attr: - attr_type = ir.AttributeType(onnxscript.type_annotation.pytype_to_attrtype(pytype)) - return ir.Attr(attrname, attr_type, None, refname) From fa95a935ceb7e874e913853a546189d699bd27b3 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 19 Nov 2025 17:12:10 -0800 Subject: [PATCH 38/43] Cleanup IRBuilder Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 132 +++++++++++++++++++++++++++++----------- onnxscript/irbuilder.py | 56 +---------------- 2 files changed, 99 insertions(+), 89 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index de6fd98108..97cd65b17d 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -113,6 +113,27 @@ def ignore(cond, msg): OnnxVarName = str +def set_type_info(value: ir.Value, typeinfo: ta.TypeAnnotationValue) -> None: + """Sets the type information on an IR value.""" + try: + type_and_shape = ir.from_proto(typeinfo.to_type_proto()) + value.type = type_and_shape.type + value.shape = type_and_shape.shape + except AttributeError: + pass + value.meta["typeinfo"] = typeinfo + + +def make_value( + varname: str, typeinfo: ta.TypeAnnotationValue, source_info: sourceinfo.SourceInfo +) -> ir.Value: + value = ir.Value(name=varname) + value.meta.setdefault("sourceinfo", source_info) + if typeinfo is not None: + set_type_info(value, typeinfo) + return value + + class Converter: """Main class to translate python code into ONNX operators. @@ -140,13 +161,12 @@ class :class:`onnxscript.irbuilder.IRBuilder` is used def __init__( self, - ir_builder: Optional[irbuilder.IRBuilder] = None, opset: Optional[values.Opset] = None, global_names: Optional[dict[str, Any]] = None, source: Optional[str] = None, default_opset: Optional[values.Opset] = None, ): - self.ir_builder = ir_builder or irbuilder.IRBuilder() + self.ir_builder = irbuilder.IRBuilder() self.source = source if global_names is not None: # We make a copy in case function eval modifies it. @@ -246,7 +266,7 @@ def _enter_scope(self, name: str, parent_node: ast.AST): The block is translated into a nested-scope in ONNX. """ self._outer.insert(0, self._current_fn) - self._current_fn = self.ir_builder.new_function(name) + self._current_fn = self.new_function(name) self._locals.insert(0, {}) logger.debug("Converter:_enter_scope:%d:node:%s", len(self._locals), type(parent_node)) @@ -359,6 +379,61 @@ def _to_onnx_var( def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> ir.Value: return self._to_onnx_var(self._lookup(py_var, info), target=py_var, info=info) + def new_function(self, name: str, domain: str = "", register: bool = False) -> irbuilder.IRFunction: + if register and (domain, name) in self.ir_builder.functions: + raise RuntimeError(f"Function '{name}' already exists in domain '{domain}'.") + function = irbuilder.IRFunction(name, domain) + if register: + self.ir_builder.functions[domain, name] = function + return function + + def add_stmt( + self, + results: Sequence[str], + callee: values.Op, + inputs: Sequence[Optional[ir.Value]], + attrs: Sequence[ir.Attr], + ) -> Sequence[ir.Value]: + output_values = [ir.Value(name=o) for o in results] + node = ir.Node( + domain=callee.opset.domain, + version=callee.opset.version, + op_type=callee.name, + inputs=inputs, + outputs=output_values, + attributes=attrs, + ) + if not isinstance(callee, values.Op): + raise TypeError(f"Unexpected type {type(callee)} for callee.") + node.meta.setdefault("callee", callee) + self._current_fn.append_node(node) + return output_values + + def add_attr_parameter( + self, + varname: str, + attribute_type: onnx.AttributeProto.AttributeType, + default_value: int | float | str | None, + ) -> None: + attr = ir.Attr(varname, ir.AttributeType(attribute_type), default_value, None) + self._current_fn.append_parameter(attr) + + def add_input( + self, + varname: str, + typeinfo: ta.TypeAnnotationValue, + source_info: sourceinfo.SourceInfo, + ) -> None: + self._current_fn.append_parameter(make_value(varname, typeinfo, source_info)) + + def add_output( + self, + varname: str, + typeinfo: ta.TypeAnnotationValue, + source_info: sourceinfo.SourceInfo, + ) -> None: + self._current_fn.append_output(make_value(varname, typeinfo, source_info)) + def emit( self, outputs: Sequence[str], @@ -370,8 +445,7 @@ def emit( callee = values.Op(self.default_opset, callee) if attrs is None: attrs = [] - output_values = self.ir_builder.add_stmt( - self._current_fn, + output_values = self.add_stmt( outputs, callee, inputs, @@ -539,10 +613,11 @@ def _translate_attr( def _translate_docstring(self, node: ast.Expr) -> None: if hasattr(node.value, "value"): # python 3.8+ - return self.ir_builder.add_docstring(self._current_fn, node.value.value) - raise TypeError( - f"Unexpected type {type(node)!r} for node. Unsupoorted version of python." - ) + self._current_fn.doc_string = node.value.value + else: + raise TypeError( + f"Unexpected type {type(node)!r} for node. Unsupoorted version of python." + ) def _translate_expr( self, node: ast.AST, target: Optional[PreferredName] = None @@ -1001,7 +1076,7 @@ def assign(lhs: ast.AST, rhs: ast.AST) -> None: else: typeinfo = None if typeinfo is not None: - irbuilder.set_type_info(t, typeinfo) + set_type_info(t, typeinfo) var = values.Dynamic(t, values.DynamicKind.Intermediate, info, typeinfo) self._bind(lhs, var) elif isinstance(lhs, ast.Tuple): @@ -1087,9 +1162,7 @@ def ret(exp, i, suffix): t = None else: t = self.returntype[i] - self.ir_builder.add_output( - self._current_fn, return_var.name, t, self._source_of(stmt) - ) + self.add_output(return_var.name, t, self._source_of(stmt)) return return_var val = stmt.value @@ -1210,8 +1283,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: # build loop_body self._enter_scope("loop_body", loop_stmt) o_loop_var = self.generate_unique_name(p_loop_var) - self.ir_builder.add_input( - self._current_fn, + self.add_input( o_loop_var, onnx_types.INT64, self._source_of(loop_stmt), @@ -1223,8 +1295,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ), ) - self.ir_builder.add_input( - self._current_fn, + self.add_input( i_cond_var.name, onnx_types.BOOL, self._source_of(loop_stmt), @@ -1235,9 +1306,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: # TODO: retrieve the annotation for variable pv is any is specified. # typeinfo = self._eval_constant_expr(pv.annotation) typeinfo = None - self.ir_builder.add_input( - self._current_fn, ov, typeinfo, self._source_of(loop_stmt) - ) + self.add_input(ov, typeinfo, self._source_of(loop_stmt)) self._bind( pv, values.Dynamic( @@ -1294,8 +1363,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: [], ) - self.ir_builder.add_output( - self._current_fn, + self.add_output( o_cond_out, onnx_types.BOOL, self._source_of(loop_stmt), @@ -1311,9 +1379,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ov = self._emit_copy(ov, pv) # TODO: retrieve variable type for the annotation if any. typeinfo = None - self.ir_builder.add_output( - self._current_fn, ov.name, typeinfo, self._source_of(loop_stmt) - ) + self.add_output(ov.name, typeinfo, self._source_of(loop_stmt)) body = self._exit_scope() inputs = [o_loop_bound, o_loop_condition] + [ self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars @@ -1358,8 +1424,7 @@ def _translate_block( # To return an outer-scope variable, an ONNX Graph has to # use an explicit copy via Identity. output = self._emit_copy(output, pvar) - self.ir_builder.add_output( - self._current_fn, + self.add_output( output.name, pv_val.typeinfo, source, @@ -1381,7 +1446,7 @@ def _translate_block( # TODO: retrieve the annotation if any. typeinfo = None - self.ir_builder.add_output(self._current_fn, ovar.name, typeinfo, source) + self.add_output(ovar.name, typeinfo, source) graph = self._exit_scope() return graph.graph @@ -1419,17 +1484,14 @@ def _translate_function_signature_common( # The code can only be exported as a function. typeinfo = None if typeinfo and ta.is_attr_type(typeinfo): - self.ir_builder.add_attr_parameter( - self._current_fn, + self.add_attr_parameter( x.arg, ta.pytype_to_attrtype(typeinfo), default_value, ) self._bind(x.arg, values.AttrRef(x.arg, typeinfo, self._source_of(x))) else: - self.ir_builder.add_input( - self._current_fn, x.arg, typeinfo, self._source_of(x) - ) + self.add_input(x.arg, typeinfo, self._source_of(x)) self._used_vars.add(x.arg) self._bind( x.arg, @@ -1471,7 +1533,7 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: if opset: self._set_default_opset(opset, stmt) domain = self.this_module.domain - self._current_fn = self.ir_builder.new_function(stmt.name, domain, True) + self._current_fn = self.new_function(stmt.name, domain, True) self._analyzer = analysis.AstAnalyzer(stmt, self._message, self.globals) fn_ir = self._translate_function_def_common(stmt) self.this_module.add_function_def(fn_ir) @@ -1482,5 +1544,5 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: """Translate a (top-level) function signature.""" domain = self.this_module.domain - self._current_fn = self.ir_builder.new_function(fn.name, domain, True) + self._current_fn = self.new_function(fn.name, domain, True) return self._translate_function_signature_common(fn) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index e4568ebe7c..c131528feb 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -5,7 +5,7 @@ import logging import warnings -from typing import Any, Optional, Sequence, Union +from typing import Any, Sequence, Union import onnx import onnx_ir as ir @@ -132,7 +132,7 @@ def set_type_info(value: ir.Value, typeinfo: TypeAnnotationValue) -> None: value.meta["typeinfo"] = typeinfo -def _make_value( +def make_value( varname: str, typeinfo: TypeAnnotationValue, sourceinfo: SourceInfo ) -> ir.Value: value = ir.Value(name=varname) @@ -145,55 +145,3 @@ def _make_value( class IRBuilder: def __init__(self): self.functions = {} - - def new_function(self, name: str, domain: str = "", register: bool = False) -> IRFunction: - if register and (domain, name) in self.functions: - raise RuntimeError(f"Function '{name}' already exists in domain '{domain}'.") - function = IRFunction(name, domain) - if register: - self.functions[domain, name] = function - return function - - def add_docstring(self, fn: IRFunction, docstring: str): - fn.doc_string = docstring - - def add_stmt( - self, - fn: IRFunction, - results: Sequence[str], - callee: values.Op, - inputs: Sequence[Optional[ir.Value]], - attrs: Sequence[ir.Attr], - ) -> Sequence[ir.Value]: - output_values = [ir.Value(name=o) for o in results] - node = ir.Node( - domain=callee.opset.domain, - version=callee.opset.version, - op_type=callee.name, - inputs=inputs, - outputs=output_values, - attributes=attrs, - ) - if not isinstance(callee, values.Op): - raise TypeError(f"Unexpected type {type(callee)} for callee.") - node.meta.setdefault("callee", callee) - fn.append_node(node) - return output_values - - def add_input( - self, fn: IRFunction, varname: str, type: TypeAnnotationValue, info: SourceInfo - ) -> None: - fn.append_parameter(_make_value(varname, type, info)) - - def add_attr_parameter( - self, - fn: IRFunction, - varname: str, - attribute_type: onnx.AttributeProto.AttributeType, - default_value: int | float | str | None, - ) -> None: - attr = ir.Attr(varname, ir.AttributeType(attribute_type), default_value, None) - fn.append_parameter(attr) - - def add_output(self, fn: IRFunction, varname: str, typeinfo, sourceinfo) -> None: - fn.append_output(_make_value(varname, typeinfo, sourceinfo)) From 86e10dd56fd1dc752f299ff82847ce46e2714ee0 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Wed, 19 Nov 2025 22:07:45 -0800 Subject: [PATCH 39/43] More cleanup Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 139 +++++++++++++++------------------------- 1 file changed, 53 insertions(+), 86 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index 97cd65b17d..b198229a75 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -266,7 +266,7 @@ def _enter_scope(self, name: str, parent_node: ast.AST): The block is translated into a nested-scope in ONNX. """ self._outer.insert(0, self._current_fn) - self._current_fn = self.new_function(name) + self._current_fn = irbuilder.IRFunction(name) self._locals.insert(0, {}) logger.debug("Converter:_enter_scope:%d:node:%s", len(self._locals), type(parent_node)) @@ -379,22 +379,18 @@ def _to_onnx_var( def _py_var_to_onnx_var(self, py_var: str, info: sourceinfo.SourceInfo) -> ir.Value: return self._to_onnx_var(self._lookup(py_var, info), target=py_var, info=info) - def new_function(self, name: str, domain: str = "", register: bool = False) -> irbuilder.IRFunction: - if register and (domain, name) in self.ir_builder.functions: - raise RuntimeError(f"Function '{name}' already exists in domain '{domain}'.") - function = irbuilder.IRFunction(name, domain) - if register: - self.ir_builder.functions[domain, name] = function - return function - - def add_stmt( + def emit( self, - results: Sequence[str], - callee: values.Op, + outputs: Sequence[str], + callee: values.Op | str, inputs: Sequence[Optional[ir.Value]], - attrs: Sequence[ir.Attr], - ) -> Sequence[ir.Value]: - output_values = [ir.Value(name=o) for o in results] + attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None, + ) -> Sequence[ir.Value] | ir.Value: + if not isinstance(callee, values.Op): + callee = values.Op(self.default_opset, callee) + if attrs is None: + attrs = [] + output_values = [ir.Value(name=o) for o in outputs] node = ir.Node( domain=callee.opset.domain, version=callee.opset.version, @@ -407,50 +403,7 @@ def add_stmt( raise TypeError(f"Unexpected type {type(callee)} for callee.") node.meta.setdefault("callee", callee) self._current_fn.append_node(node) - return output_values - def add_attr_parameter( - self, - varname: str, - attribute_type: onnx.AttributeProto.AttributeType, - default_value: int | float | str | None, - ) -> None: - attr = ir.Attr(varname, ir.AttributeType(attribute_type), default_value, None) - self._current_fn.append_parameter(attr) - - def add_input( - self, - varname: str, - typeinfo: ta.TypeAnnotationValue, - source_info: sourceinfo.SourceInfo, - ) -> None: - self._current_fn.append_parameter(make_value(varname, typeinfo, source_info)) - - def add_output( - self, - varname: str, - typeinfo: ta.TypeAnnotationValue, - source_info: sourceinfo.SourceInfo, - ) -> None: - self._current_fn.append_output(make_value(varname, typeinfo, source_info)) - - def emit( - self, - outputs: Sequence[str], - callee: values.Op | str, - inputs: Sequence[Optional[ir.Value]], - attrs: Optional[Sequence[irbuilder.IRAttributeValue]] = None, - ) -> Sequence[ir.Value] | ir.Value: - if not isinstance(callee, values.Op): - callee = values.Op(self.default_opset, callee) - if attrs is None: - attrs = [] - output_values = self.add_stmt( - outputs, - callee, - inputs, - attrs, - ) return output_values if len(output_values) > 1 else output_values[0] def emit1(self, *args, **kwargs) -> ir.Value: @@ -1162,7 +1115,9 @@ def ret(exp, i, suffix): t = None else: t = self.returntype[i] - self.add_output(return_var.name, t, self._source_of(stmt)) + self._current_fn.append_output( + make_value(return_var.name, t, self._source_of(stmt)) + ) return return_var val = stmt.value @@ -1283,10 +1238,12 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: # build loop_body self._enter_scope("loop_body", loop_stmt) o_loop_var = self.generate_unique_name(p_loop_var) - self.add_input( - o_loop_var, - onnx_types.INT64, - self._source_of(loop_stmt), + self._current_fn.append_parameter( + make_value( + o_loop_var, + onnx_types.INT64, + self._source_of(loop_stmt), + ) ) self._bind( p_loop_var, @@ -1295,10 +1252,12 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ), ) - self.add_input( - i_cond_var.name, - onnx_types.BOOL, - self._source_of(loop_stmt), + self._current_fn.append_parameter( + make_value( + i_cond_var.name, + onnx_types.BOOL, + self._source_of(loop_stmt), + ) ) for pv in loop_state_vars: @@ -1306,7 +1265,9 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: # TODO: retrieve the annotation for variable pv is any is specified. # typeinfo = self._eval_constant_expr(pv.annotation) typeinfo = None - self.add_input(ov, typeinfo, self._source_of(loop_stmt)) + self._current_fn.append_parameter( + make_value(ov, typeinfo, self._source_of(loop_stmt)) + ) self._bind( pv, values.Dynamic( @@ -1363,10 +1324,12 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: [], ) - self.add_output( - o_cond_out, - onnx_types.BOOL, - self._source_of(loop_stmt), + self._current_fn.append_output( + make_value( + o_cond_out, + onnx_types.BOOL, + self._source_of(loop_stmt), + ) ) for pv in loop_state_vars: ov = self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) @@ -1379,7 +1342,9 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ov = self._emit_copy(ov, pv) # TODO: retrieve variable type for the annotation if any. typeinfo = None - self.add_output(ov.name, typeinfo, self._source_of(loop_stmt)) + self._current_fn.append_output( + make_value(ov.name, typeinfo, self._source_of(loop_stmt)) + ) body = self._exit_scope() inputs = [o_loop_bound, o_loop_condition] + [ self._py_var_to_onnx_var(pv, self._source_of(loop_stmt)) for pv in loop_state_vars @@ -1424,10 +1389,12 @@ def _translate_block( # To return an outer-scope variable, an ONNX Graph has to # use an explicit copy via Identity. output = self._emit_copy(output, pvar) - self.add_output( - output.name, - pv_val.typeinfo, - source, + self._current_fn.append_output( + make_value( + output.name, + pv_val.typeinfo, + source, + ) ) else: pv_val = None @@ -1446,7 +1413,7 @@ def _translate_block( # TODO: retrieve the annotation if any. typeinfo = None - self.add_output(ovar.name, typeinfo, source) + self._current_fn.append_output(make_value(ovar.name, typeinfo, source)) graph = self._exit_scope() return graph.graph @@ -1484,14 +1451,14 @@ def _translate_function_signature_common( # The code can only be exported as a function. typeinfo = None if typeinfo and ta.is_attr_type(typeinfo): - self.add_attr_parameter( - x.arg, - ta.pytype_to_attrtype(typeinfo), - default_value, - ) + attribute_type = ta.pytype_to_attrtype(typeinfo) + attr = ir.Attr(x.arg, ir.AttributeType(attribute_type), default_value, None) + self._current_fn.append_parameter(attr) self._bind(x.arg, values.AttrRef(x.arg, typeinfo, self._source_of(x))) else: - self.add_input(x.arg, typeinfo, self._source_of(x)) + self._current_fn.append_parameter( + make_value(x.arg, typeinfo, self._source_of(x)) + ) self._used_vars.add(x.arg) self._bind( x.arg, @@ -1533,7 +1500,7 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: if opset: self._set_default_opset(opset, stmt) domain = self.this_module.domain - self._current_fn = self.new_function(stmt.name, domain, True) + self._current_fn = irbuilder.IRFunction(stmt.name, domain) self._analyzer = analysis.AstAnalyzer(stmt, self._message, self.globals) fn_ir = self._translate_function_def_common(stmt) self.this_module.add_function_def(fn_ir) @@ -1544,5 +1511,5 @@ def translate_function_def(self, stmt: ast.FunctionDef) -> irbuilder.IRFunction: def translate_function_signature(self, fn: ast.FunctionDef) -> irbuilder.IRFunction: """Translate a (top-level) function signature.""" domain = self.this_module.domain - self._current_fn = self.new_function(fn.name, domain, True) + self._current_fn = irbuilder.IRFunction(fn.name, domain) return self._translate_function_signature_common(fn) From 62627cbf8b7ef76d85eb47799c251045a37bbd8a Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Tue, 25 Nov 2025 22:24:24 -0500 Subject: [PATCH 40/43] Cleanup irbuilder Signed-off-by: Ganesan Ramalingam --- onnxscript/converter.py | 12 +++++----- onnxscript/evaluator.py | 2 +- onnxscript/irbuilder.py | 49 ----------------------------------------- onnxscript/values.py | 2 +- 4 files changed, 8 insertions(+), 57 deletions(-) diff --git a/onnxscript/converter.py b/onnxscript/converter.py index b198229a75..abef235092 100644 --- a/onnxscript/converter.py +++ b/onnxscript/converter.py @@ -120,6 +120,7 @@ def set_type_info(value: ir.Value, typeinfo: ta.TypeAnnotationValue) -> None: value.type = type_and_shape.type value.shape = type_and_shape.shape except AttributeError: + # TODO: This needs to be fixed. pass value.meta["typeinfo"] = typeinfo @@ -166,7 +167,6 @@ def __init__( source: Optional[str] = None, default_opset: Optional[values.Opset] = None, ): - self.ir_builder = irbuilder.IRBuilder() self.source = source if global_names is not None: # We make a copy in case function eval modifies it. @@ -1115,7 +1115,7 @@ def ret(exp, i, suffix): t = None else: t = self.returntype[i] - self._current_fn.append_output( + self._current_fn.outputs.append( make_value(return_var.name, t, self._source_of(stmt)) ) return return_var @@ -1324,7 +1324,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: [], ) - self._current_fn.append_output( + self._current_fn.outputs.append( make_value( o_cond_out, onnx_types.BOOL, @@ -1342,7 +1342,7 @@ def _translate_loop_stmt(self, loop_stmt: Union[ast.For, ast.While]) -> None: ov = self._emit_copy(ov, pv) # TODO: retrieve variable type for the annotation if any. typeinfo = None - self._current_fn.append_output( + self._current_fn.outputs.append( make_value(ov.name, typeinfo, self._source_of(loop_stmt)) ) body = self._exit_scope() @@ -1389,7 +1389,7 @@ def _translate_block( # To return an outer-scope variable, an ONNX Graph has to # use an explicit copy via Identity. output = self._emit_copy(output, pvar) - self._current_fn.append_output( + self._current_fn.outputs.append( make_value( output.name, pv_val.typeinfo, @@ -1413,7 +1413,7 @@ def _translate_block( # TODO: retrieve the annotation if any. typeinfo = None - self._current_fn.append_output(make_value(ovar.name, typeinfo, source)) + self._current_fn.outputs.append(make_value(ovar.name, typeinfo, source)) graph = self._exit_scope() return graph.graph diff --git a/onnxscript/evaluator.py b/onnxscript/evaluator.py index a644108a78..a606700708 100644 --- a/onnxscript/evaluator.py +++ b/onnxscript/evaluator.py @@ -447,7 +447,7 @@ def make_tensor_name() -> str: model = onnx.helper.make_model( # noqa: TID251 graph, opset_imports=[opset_id], - ir_version=irbuilder.select_ir_version(schema.since_version, domain=schema.domain), + ir_version=values.select_ir_version(schema.since_version, domain=schema.domain), ) model = onnx.shape_inference.infer_shapes(model) diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index c131528feb..0a449bfa8c 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -16,21 +16,8 @@ logger = logging.getLogger("onnxscript") - -def select_ir_version(version: int, domain: str = "") -> int: - """Selects a suitable ONNX ir_version for a given opset version.""" - if domain == "": - domain = "ai.onnx" - if (domain, version) not in onnx.helper.OP_SET_ID_VERSION_MAP: - return max( - v for k, v in onnx.helper.OP_SET_ID_VERSION_MAP.items() if k[0] == "ai.onnx" - ) - return onnx.helper.OP_SET_ID_VERSION_MAP[domain, version] - - TypeAnnotationValue = onnxscript.type_annotation.TypeAnnotationValue - class IRFunction(ir.Function): """Represents a function in the IR.""" @@ -43,11 +30,6 @@ def __init__(self, name: str, domain: str = "") -> None: self.nested_functions: dict[str, IRFunction] = {} self.outer_scope_variables: dict[Any, Any] = {} - @property - def docstring(self) -> str: - """Returns the docstring of this function.""" - return self.doc_string or "" - @property def assigned_names(self) -> Sequence[str]: """Returns the list of variables assigned to by this function.""" @@ -84,9 +66,6 @@ def append_parameter(self, parameter: ir.Value | ir.Attr) -> None: raise TypeError(f"Expected ir.Value or ir.Attr, got {type(parameter)}") self.attributes.add(parameter) - def append_output(self, var: ir.Value) -> None: - self.outputs.append(var) - def add_nested_function(self, fun: IRFunction) -> None: self.nested_functions[fun.name] = fun @@ -117,31 +96,3 @@ def to_function_proto(self) -> onnx.FunctionProto: """Converts this instance into a `onnx.FunctionProto`.""" return ir.to_proto(self) - -# IRBuilder: abstracts out details of the IR in the python-to-IR converter - - -def set_type_info(value: ir.Value, typeinfo: TypeAnnotationValue) -> None: - """Sets the type information on an IR value.""" - try: - type_and_shape = ir.from_proto(typeinfo.to_type_proto()) - value.type = type_and_shape.type - value.shape = type_and_shape.shape - except AttributeError: - pass - value.meta["typeinfo"] = typeinfo - - -def make_value( - varname: str, typeinfo: TypeAnnotationValue, sourceinfo: SourceInfo -) -> ir.Value: - value = ir.Value(name=varname) - value.meta.setdefault("sourceinfo", sourceinfo) - if typeinfo is not None: - set_type_info(value, typeinfo) - return value - - -class IRBuilder: - def __init__(self): - self.functions = {} diff --git a/onnxscript/values.py b/onnxscript/values.py index c50cdcadf8..028b0728df 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -457,7 +457,7 @@ def _op_schema_from_function_ir( function_ir.name, opset.domain, since_version=opset.version, - doc=function_ir.docstring, + doc=function_ir.doc_string or "", inputs=formal_inputs, outputs=formal_outputs, type_constraints=[constraint.as_tuple() for constraint in type_to_constraint.values()], From 811bb0ab9e8ea45a4df75b7d3adc857afcf82152 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 18 Dec 2025 16:43:39 -0800 Subject: [PATCH 41/43] minor fixes Signed-off-by: Ganesan Ramalingam --- onnxscript/_internal/utils.py | 43 +++++++++++++++++++++++++++++++++++ onnxscript/irbuilder.py | 4 +--- 2 files changed, 44 insertions(+), 3 deletions(-) diff --git a/onnxscript/_internal/utils.py b/onnxscript/_internal/utils.py index ce2b657cfd..17028d583e 100644 --- a/onnxscript/_internal/utils.py +++ b/onnxscript/_internal/utils.py @@ -7,6 +7,7 @@ import numpy as np import onnx +import onnx_ir as ir from onnxscript import tensor @@ -86,6 +87,38 @@ def value_to_type_proto(val): return onnx.helper.make_tensor_type_proto(elem_type, []) # noqa: TID251 raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.") +def value_to_type(val): + """Return an ir.Value representation of a python-value.""" + if isinstance(val, (np.ndarray, tensor.Tensor)): + elem_type = onnx.helper.np_dtype_to_tensor_dtype(val.dtype) # noqa: TID251 + shape = val.shape + return (ir.TensorType(elem_type), shape) + elif isinstance(val, int): + elem_type = onnx.TensorProto.INT32 + shape = [] + return (ir.TensorType(elem_type), shape) + elif isinstance(val, (float, np.float32)): + elem_type = onnx.TensorProto.FLOAT + shape = [] + return (ir.TensorType(elem_type), shape) + elif isinstance(val, list): + if len(val) > 0: + type, shape = value_to_type(val[0]) + return ir.SequenceType(type), shape + # Edge-case. Cannot determine a suitable ONNX type for an empty list. + # Should be using a typed-value instead. + # Treated as a sequence of tensors of float-type. + return ir.SequenceType(ir.TensorType(onnx.TensorProto.FLOAT)), None + if isinstance(val, numbers.Number): + nparray = np.array(val) + elem_type = onnx.helper.np_dtype_to_tensor_dtype(nparray.dtype) # noqa: TID251 + return ir.TensorType(elem_type), [] + raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.") + +def value_to_ir_value(name: str, val) -> ir.Value: + """Return an ir.Value representation of a python-value.""" + type, shape = value_to_type(val) + return ir.Value(name=name, type=type, shape=shape) def values_to_value_infos(name_values): """Create a list of ValueInfoProto from a list of (name, value) pairs, @@ -96,3 +129,13 @@ def values_to_value_infos(name_values): for (name, val) in name_values if val is not None ] + +def values_to_ir_values(name_values): + """Create a list of ir.Value from a list of (name, value) pairs, + skipping any None values. + """ + return [ + value_to_ir_value(name, val) + for (name, val) in name_values + if val is not None + ] \ No newline at end of file diff --git a/onnxscript/irbuilder.py b/onnxscript/irbuilder.py index 0a449bfa8c..c67ecc4be5 100644 --- a/onnxscript/irbuilder.py +++ b/onnxscript/irbuilder.py @@ -1,6 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -# ruff: noqa: TID251 from __future__ import annotations import logging @@ -12,12 +11,12 @@ import onnxscript.type_annotation from onnxscript import values -from onnxscript.sourceinfo import SourceInfo logger = logging.getLogger("onnxscript") TypeAnnotationValue = onnxscript.type_annotation.TypeAnnotationValue + class IRFunction(ir.Function): """Represents a function in the IR.""" @@ -95,4 +94,3 @@ def to_graph_proto(self) -> onnx.GraphProto: def to_function_proto(self) -> onnx.FunctionProto: """Converts this instance into a `onnx.FunctionProto`.""" return ir.to_proto(self) - From 4f97263b32c5a9ecef52553d4127353b79905835 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Thu, 18 Dec 2025 17:07:24 -0800 Subject: [PATCH 42/43] Address PR feedback --- onnxscript/_internal/utils.py | 10 +++++----- onnxscript/evaluator.py | 2 +- onnxscript/values.py | 3 ++- 3 files changed, 8 insertions(+), 7 deletions(-) diff --git a/onnxscript/_internal/utils.py b/onnxscript/_internal/utils.py index 17028d583e..c5e05b3d92 100644 --- a/onnxscript/_internal/utils.py +++ b/onnxscript/_internal/utils.py @@ -87,6 +87,7 @@ def value_to_type_proto(val): return onnx.helper.make_tensor_type_proto(elem_type, []) # noqa: TID251 raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.") + def value_to_type(val): """Return an ir.Value representation of a python-value.""" if isinstance(val, (np.ndarray, tensor.Tensor)): @@ -115,11 +116,13 @@ def value_to_type(val): return ir.TensorType(elem_type), [] raise ValueError(f"Value of type {type(val)} is invalid as an ONNX input/output.") + def value_to_ir_value(name: str, val) -> ir.Value: """Return an ir.Value representation of a python-value.""" type, shape = value_to_type(val) return ir.Value(name=name, type=type, shape=shape) + def values_to_value_infos(name_values): """Create a list of ValueInfoProto from a list of (name, value) pairs, skipping any None values. @@ -130,12 +133,9 @@ def values_to_value_infos(name_values): if val is not None ] + def values_to_ir_values(name_values): """Create a list of ir.Value from a list of (name, value) pairs, skipping any None values. """ - return [ - value_to_ir_value(name, val) - for (name, val) in name_values - if val is not None - ] \ No newline at end of file + return [value_to_ir_value(name, val) for (name, val) in name_values if val is not None] diff --git a/onnxscript/evaluator.py b/onnxscript/evaluator.py index a606700708..327440243e 100644 --- a/onnxscript/evaluator.py +++ b/onnxscript/evaluator.py @@ -23,7 +23,7 @@ import onnx.reference from typing_extensions import TypeAlias -from onnxscript import irbuilder, onnx_opset, tensor, values +from onnxscript import onnx_opset, tensor, values from onnxscript._internal import autocast, param_manipulation, utils UserModeValue: TypeAlias = Union[Optional[np.ndarray], Sequence["UserModeValue"]] diff --git a/onnxscript/values.py b/onnxscript/values.py index 028b0728df..6257e9d762 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -47,7 +47,8 @@ def select_ir_version(version: int, domain: str = "") -> int: return max( v for k, v in onnx.helper.OP_SET_ID_VERSION_MAP.items() if k[0] == "ai.onnx" ) - return onnx.helper.OP_SET_ID_VERSION_MAP[domain, version] + required_min_version = onnx.helper.OP_SET_ID_VERSION_MAP[domain, version] + return max(required_min_version, 10) _ATTRIBUTE_TYPE_TO_PYTHON_TYPE = { From ac77c9225854292f38df0bcccf09bfddde1c44b1 Mon Sep 17 00:00:00 2001 From: Ganesan Ramalingam Date: Fri, 19 Dec 2025 09:19:23 -0800 Subject: [PATCH 43/43] Address PR feedback Signed-off-by: Ganesan Ramalingam --- onnxscript/values.py | 76 ++++++++++++++++++++++++-------------------- 1 file changed, 42 insertions(+), 34 deletions(-) diff --git a/onnxscript/values.py b/onnxscript/values.py index 6257e9d762..4c33664226 100644 --- a/onnxscript/values.py +++ b/onnxscript/values.py @@ -657,31 +657,7 @@ def _to_model_proto( Returns: An instance of :class:`onnx.ModelProto`. """ - value_infos = ( - [ - onnx.helper.make_value_info(name, type.to_type_proto()) - for name, type in value_infos.items() - ] - if value_infos - else None - ) - - graph = self.function_ir.to_graph_proto() - if value_infos: - graph.value_info.extend(value_infos) - if io_types is not None: - for input in graph.input: - if not input.HasField("type"): - input.type.CopyFrom(io_types.to_type_proto()) - for output in graph.output: - if not output.HasField("type"): - output.type.CopyFrom(io_types.to_type_proto()) - if input_types is not None: - for input, type in zip(graph.input, input_types): - input.type.CopyFrom(type.to_type_proto()) - if output_types is not None: - for output, type in zip(graph.output, output_types): - output.type.CopyFrom(type.to_type_proto()) + # Identify functions to include in the model if functions is None: sub_functions = self.function_ir.get_called_functions() functions = sub_functions.values() @@ -696,7 +672,8 @@ def to_proto(f): functions = [to_proto(f) for f in functions] - opsets = self.function_ir.opset_imports.copy() + # Determine opset imports + opsets = self.function_ir.graph.opset_imports for proto in functions: if proto.domain not in opsets: @@ -713,15 +690,46 @@ def to_proto(f): opset_version if opset_version is not None else onnx.defs.onnx_opset_version() ) - if "ir_version" not in kwargs: - kwargs["ir_version"] = select_ir_version(opsets[""]) - opset_imports = [ - onnx.helper.make_opsetid(domain, version) for domain, version in opsets.items() - ] + # Determine ir_version + if "ir_version" in kwargs: + ir_version = kwargs.pop("ir_version") + else: + ir_version = select_ir_version(opsets[""]) + + # Create the model + model = ir.Model(self.function_ir.graph, ir_version=ir_version) + model_proto = ir.to_proto(model) + model_proto.functions.extend(functions) - return onnx.helper.make_model( - graph, opset_imports=opset_imports, functions=functions, **kwargs - ) + # Set additional type information if provided + graph = model_proto.graph + + if value_infos: + graph.value_info.extend( + [ + onnx.helper.make_value_info(name, type.to_type_proto()) + for name, type in value_infos.items() + ] + ) + + if io_types is not None: + for input in graph.input: + if not input.HasField("type"): + input.type.CopyFrom(io_types.to_type_proto()) + for output in graph.output: + if not output.HasField("type"): + output.type.CopyFrom(io_types.to_type_proto()) + if input_types is not None: + for input, type in zip(graph.input, input_types): + input.type.CopyFrom(type.to_type_proto()) + if output_types is not None: + for output, type in zip(graph.output, output_types): + output.type.CopyFrom(type.to_type_proto()) + + for k, v in kwargs.items(): + setattr(model_proto, k, v) + + return model_proto class TracedOnnxFunction(Op):