From ce7672264e36abf48b37d9a520720ebe1c58cf43 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 12:31:15 -0700 Subject: [PATCH 01/16] Consolidate all overloads and prevent new ones from being created Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 20 +++++++++++++------ .../function_libs/torch_lib/ops_test_data.py | 18 ++--------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index e837bfadae..162672696e 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4060,7 +4060,9 @@ def _aten_index_onnx( @torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True) -def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType: +def aten_index( + self: TensorType, indices: Sequence[Optional[Union[INT64, BOOL]]] +) -> TensorType: """index.Tensor(Tensor self, Tensor?[] indices) -> Tensor NOTE: Understanding `aten::index` @@ -4080,14 +4082,17 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy None in `indices` are like fillers for dimensions that cannot be removed in the process. """ + # Handle Boolean indexing first + for index in indices: + if index is not None and index.dtype == BOOL.dtype: + return _aten_index_bool(self, indices) index_ranks = [len(index.shape) for index in indices if index is not None] return _aten_index_onnx(self, indices, index_ranks) -@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True) -def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: # pylint: disable=inconsistent-return-statements +def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: # pylint: disable=inconsistent-return-statements index_ranks = [len(index.shape) for index in indices if index is not None] if index_ranks[0] == 1: @@ -4146,7 +4151,7 @@ def aten_index_copy( @torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True) def aten_index_put( self: TReal, - indices: Sequence[INT64], + indices: Sequence[[Union[INT64, BOOL]]], values: TReal, accumulate: bool = False, ) -> TReal: @@ -4155,6 +4160,10 @@ def aten_index_put( See implementation of `torch.onnx.symbolic_opset11.index_put `_. """ + # Handle Boolean indexing first + for index in indices: + if index is not None and index.dtype == BOOL.dtype: + return _aten_index_put_bool(self, indices, values, accumulate=accumulate) def _make_reshape_list_broadcastable(reshape_list, values_shape): # Remove ones until the rank of reshape_list matches values_shape. 
@@ -4232,8 +4241,7 @@ def _make_reshape_list_broadcastable(reshape_list, values_shape): return result -@torch_op("aten::index_put", trace_only=True) -def aten_index_put_bool( +def _aten_index_put_bool( self: TReal, indices: Sequence[BOOL], values: TReal, diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index b60fd8cf31..4a36da2d67 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -721,23 +721,10 @@ def _where_input_wrangler( # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), - TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool), - TorchLibOpInfo( - "index_put_bool", - core_ops.aten_index_put_bool, - input_wrangler=_index_put_input_wrangler, - ).skip( - matcher=lambda sample: sample.args[0][0].dtype != torch.bool, - reason="this Aten overload only supports tensor(bool) as indices", - ), + TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index), TorchLibOpInfo( "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler - ) - .skip( - matcher=lambda sample: sample.args[0][0].dtype != torch.int64, - reason="this Aten overload only supports tensor(int) as indices", - ) - .xfail( + ).xfail( dtypes=(torch.float16,), matcher=lambda sample: sample.kwargs.get("accumulate") is True, reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'", @@ -1806,7 +1793,6 @@ def _where_input_wrangler( ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate")) ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",)) ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",)) -ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",)) ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",)) From 477d9fbb18335620502fefb48a5a641632e59211 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 12:36:30 -0700 Subject: [PATCH 02/16] Remove registration of private functions Signed-off-by: Justin Chu --- .../torch_lib/deduce_type_constraints_test.py | 2 +- .../function_libs/torch_lib/ops/core.py | 2 -- onnxscript/function_libs/torch_lib/ops/nn.py | 1 - .../function_libs/torch_lib/registration.py | 23 ++++++++++--------- 4 files changed, 13 insertions(+), 15 deletions(-) diff --git a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py index a8d15c242a..a2db474acc 100644 --- a/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py +++ b/onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py @@ -21,7 +21,7 @@ def torch_lib_onnx_functions_from_registry() -> Generator[onnxscript.OnnxFunction, None, None]: for op in registration.default_registry.values(): - for func in (*op.overloads, *op.privates, *op.complex): + for func in (*op.overloads, *op.complex): if isinstance(func, onnxscript.OnnxFunction): yield func diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 162672696e..4612c01c27 100644 --- 
a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -3704,7 +3704,6 @@ def aten_grid_sampler( padding_mode_options = ("zeros", "border", "reflection") padding_mode_str = padding_mode_options[padding_mode] - # Only one onnx Op so don't put into private function return op.GridSample( input, grid, @@ -3730,7 +3729,6 @@ def aten_grid_sampler_2d( padding_mode_options = ("zeros", "border", "reflection") padding_mode_str = padding_mode_options[padding_mode] - # Only one onnx Op so don't put into private function return op.GridSample( input, grid, diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index 2a7a46ec28..ab733a7b46 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -330,7 +330,6 @@ def aten_col2im( else: # assert len(padding) == 4, already [w, x, y, z] pads = padding - # Only one ONNX op here so didn't write a private function return op.Col2Im( self, output_size, diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index 162d69d747..c7c0a39634 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -22,14 +22,12 @@ class OverloadedFunction: Attributes: name: Name of the op. E.g. "aten::add". overloads: Overloads function. - privates: Private functions not exposed to users. complex: Support complex functions. """ def __init__(self, name: str): self.name = name self.overloads: list[Any] = [] - self.privates: list[Any] = [] self.complex: list[Any] = [] @@ -39,17 +37,18 @@ class Registry: def __init__(self): self._registry: dict[str, OverloadedFunction] = {} - def register( - self, func: Any, name: str, *, private: bool = False, complex: bool = False - ) -> None: + def register(self, func: Any, name: str, *, complex: bool = False) -> None: """Register a function.""" + overloaded_function = self._registry.setdefault(name, OverloadedFunction(name)) - if private: - self._registry.setdefault(name, OverloadedFunction(name)).privates.append(func) - elif complex: - self._registry.setdefault(name, OverloadedFunction(name)).complex.append(func) + if complex: + if overloaded_function.complex: + raise ValueError(f"Complex overload for '{name}' already registered.") + overloaded_function.complex.append(func) else: - self._registry.setdefault(name, OverloadedFunction(name)).overloads.append(func) + if overloaded_function.overloads: + raise ValueError(f"Real overload for '{name}' already registered.") + overloaded_function.overloads.append(func) def __getitem__(self, name): return self._registry[name] @@ -131,7 +130,9 @@ def wrapper( assert registry is not None for name_ in _check_and_normalize_names(name): - registry.register(processed_func, name_, private=private, complex=complex) + if private: + continue + registry.register(processed_func, name_, complex=complex) return processed_func return wrapper From f704a3a0e5b24d041bb464a205c18de8847127cd Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 12:41:19 -0700 Subject: [PATCH 03/16] fix typing Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 4612c01c27..d37e149524 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ 
b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4082,7 +4082,9 @@ def aten_index( """ # Handle Boolean indexing first for index in indices: - if index is not None and index.dtype == BOOL.dtype: + if index is None: + continue + if index.dtype == BOOL.dtype: return _aten_index_bool(self, indices) index_ranks = [len(index.shape) for index in indices if index is not None] @@ -4149,7 +4151,7 @@ def aten_index_copy( @torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True) def aten_index_put( self: TReal, - indices: Sequence[[Union[INT64, BOOL]]], + indices: Sequence[Optional[Union[INT64, BOOL]]], values: TReal, accumulate: bool = False, ) -> TReal: @@ -4160,7 +4162,9 @@ def aten_index_put( """ # Handle Boolean indexing first for index in indices: - if index is not None and index.dtype == BOOL.dtype: + if index is None: + continue + if index.dtype == BOOL.dtype: return _aten_index_put_bool(self, indices, values, accumulate=accumulate) def _make_reshape_list_broadcastable(reshape_list, values_shape): From 8f69c8daead1d0e7bdbc84879d529eb84ebf2f75 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 12:43:58 -0700 Subject: [PATCH 04/16] test Signed-off-by: Justin Chu --- tests/function_libs/torch_lib/ops_test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 4a36da2d67..f68a9f7d34 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -724,7 +724,7 @@ def _where_input_wrangler( TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index), TorchLibOpInfo( "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler - ).xfail( + ).skip( dtypes=(torch.float16,), matcher=lambda sample: sample.kwargs.get("accumulate") is True, reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'", From 15d3c45db6bfcf697ac64ff4cb6ff9647a5b9c64 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 12:45:43 -0700 Subject: [PATCH 05/16] msg Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/registration.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index c7c0a39634..e4000bc55e 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -43,11 +43,15 @@ def register(self, func: Any, name: str, *, complex: bool = False) -> None: if complex: if overloaded_function.complex: - raise ValueError(f"Complex overload for '{name}' already registered.") + raise ValueError( + f"Complex overload for '{name}' already registered: {overloaded_function.complex}." + ) overloaded_function.complex.append(func) else: if overloaded_function.overloads: - raise ValueError(f"Real overload for '{name}' already registered.") + raise ValueError( + f"Real overload for '{name}' already registered: {overloaded_function.overloads}." 
+ ) overloaded_function.overloads.append(func) def __getitem__(self, name): From 8515502e94e554b0221ba55daf9dad638624fa39 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 22:15:23 -0700 Subject: [PATCH 06/16] skip pytest Signed-off-by: Justin Chu --- pyproject.toml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 4f7edc9bf8..315cc8a6ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,11 @@ onnx = ["py.typed"] [tool.pytest.ini_options] addopts = "-rsfEX --tb=short --color=yes" +norecursedirs = [ + # Skip test collection because pytest will try to import the modules twice, + # causing the torchlib registry to complain that functions are redefined. + "onnxscript/function_libs/torch_lib/ops", +] [tool.mypy] # TODO disallow_incomplete_defs = true From 90c3d0227f45cf5d7b9e5d1ea282187f8fa29945 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Thu, 9 Oct 2025 22:19:37 -0700 Subject: [PATCH 07/16] index_bool is wrong Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 9ae8eb1c43..539947bbbd 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4201,9 +4201,9 @@ def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Ten finla_rank = input_rank - (len(index.shape) - 1) trans_perm = list(range(finla_rank)) trans_perm = trans_perm[-1:] + trans_perm[:-1] - for _ in range(count_of_none): - result = op.Transpose(result, perm=trans_perm) - return result + for _ in range(count_of_none): + result = op.Transpose(result, perm=trans_perm) + return result def aten_index_add( From 257e58332416b30e664d85269abafb9a8a9a3b21 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 19 Dec 2025 09:47:31 -0800 Subject: [PATCH 08/16] Consolidate index Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 015182f7ca..6b4646020c 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4765,11 +4765,8 @@ def aten_index( None in `indices` are like fillers for dimensions that cannot be removed in the process. """ # Handle Boolean indexing first - for index in indices: - if index is None: - continue - if index.dtype == BOOL.dtype: - return _aten_index_bool(self, indices) + if any(index is not None and index.dtype == ir.DataType.BOOL for index in indices): + return _aten_index_bool(self, indices) index_ranks = [len(index.shape) for index in indices if index is not None] @@ -4844,6 +4841,9 @@ def aten_index_put( See implementation of `torch.onnx.symbolic_opset11.index_put `_. """ + if any(index is not None and index.dtype == BOOL.dtype for index in indices): + return _aten_index_put_bool(self, indices, values, accumulate) + # Ensure the number of indices matches the tensor rank by appending trailing Nones. 
self_rank = len(self.shape) if len(indices) < self_rank: From 10065a4d325dcdc906ae7b26a837611c79382bed Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 19 Dec 2025 09:55:12 -0800 Subject: [PATCH 09/16] warn Signed-off-by: Justin Chu --- .../function_libs/torch_lib/registration.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index e4000bc55e..ebc50cd190 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -5,6 +5,7 @@ from __future__ import annotations import re +import warnings from typing import Any, Callable, Generator, Optional import onnxscript @@ -43,15 +44,19 @@ def register(self, func: Any, name: str, *, complex: bool = False) -> None: if complex: if overloaded_function.complex: - raise ValueError( - f"Complex overload for '{name}' already registered: {overloaded_function.complex}." + warnings.warn( + f"Complex overload for '{name}' already registered: {overloaded_function.complex}.", + stacklevel=3, ) + return overloaded_function.complex.append(func) else: if overloaded_function.overloads: - raise ValueError( - f"Real overload for '{name}' already registered: {overloaded_function.overloads}." + warnings.warn( + f"Real overload for '{name}' already registered: {overloaded_function.overloads}.", + stacklevel=3, ) + return overloaded_function.overloads.append(func) def __getitem__(self, name): @@ -101,7 +106,6 @@ def torch_op( *, registry: Optional[Registry] = None, trace_only: bool = False, - private: bool = False, complex: bool = False, ) -> Callable[[Callable], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]: """Register a torch op. @@ -113,8 +117,6 @@ def torch_op( i.e. "aten::relu" instead of "aten::relu.default". registry: Registry to register the function to. If None, the default registry is used. trace_only: Whether the function should only be traced and not compiled. - private: Whether the function is private (not directly exposed). It should - be true for all functions with names starting with "_". complex: Whether the function expects complex-valued inputs. """ if registry is None: @@ -134,8 +136,6 @@ def wrapper( assert registry is not None for name_ in _check_and_normalize_names(name): - if private: - continue registry.register(processed_func, name_, complex=complex) return processed_func From a7f027e25a965cfe863ee46b810d77fbf3089b03 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 19 Dec 2025 10:16:12 -0800 Subject: [PATCH 10/16] wip Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 68 +++++++++++-------- .../function_libs/torch_lib/registration.py | 4 ++ 2 files changed, 44 insertions(+), 28 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 6b4646020c..09494f20a4 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4773,11 +4773,12 @@ def aten_index( return _aten_index_onnx(self, indices, index_ranks) -def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: # pylint: disable=inconsistent-return-statements +def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: index_ranks = [len(index.shape) for index in indices if index is not None] - if index_ranks[0] == 1: - # indices contains scalar only. 
+ # Check if all non-None boolean indices are 1D + if all(rank == 1 for rank in index_ranks): + # All indices are 1D, convert boolean indices to integer indices new_indices = [ op.Transpose(op.NonZero(index), perm=[1, 0]) if index is not None else None for index in indices @@ -4786,31 +4787,42 @@ def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Ten op.Squeeze(index, axes=[1]) if index is not None else None for index in new_indices ] return _aten_index_onnx(self, new_indices, index_ranks) - else: - input_rank = len(self.shape) - # Prepare perm for transposing self tensor. - # In indices, None meaning skip the corresponding dimension, - # so we need to move this dimension to the end of the list. - # After we gathered the final results, we transpose it back. - # For example, - # self's shape is [5, 5, 5, 5], indices is [None, (5, 5)] - # the final result's shape should be [5, 16, 5]. - trans_perm = list(range(input_rank)) - trans_perm.append(trans_perm.pop(0)) - count_of_none = 0 - for index in indices: - if index is None: - self = op.Transpose(self, perm=trans_perm) - count_of_none += 1 - else: - new_indices = op.Transpose(op.NonZero(index), perm=[1, 0]) - result = op.GatherND(self, new_indices, batch_dims=0) - finla_rank = input_rank - (len(index.shape) - 1) - trans_perm = list(range(finla_rank)) - trans_perm = trans_perm[-1:] + trans_perm[:-1] - for _ in range(count_of_none): - result = op.Transpose(result, perm=trans_perm) - return result + + # Handle multi-dimensional boolean indexing + input_rank = len(self.shape) + result = self + + # Count None values before the first non-None index + none_count_before = 0 + for index in indices: + if index is not None: + break + none_count_before += 1 + + # Transpose to move the dimensions corresponding to None indices to the end + if none_count_before > 0: + # Move the first none_count_before dimensions to the end + perm = list(range(none_count_before, input_rank)) + list(range(none_count_before)) + result = op.Transpose(result, perm=perm) + + # Apply GatherND for the first non-None boolean index + for index in indices: + if index is not None: + new_indices = op.Transpose(op.NonZero(index), perm=[1, 0]) + result = op.GatherND(result, new_indices, batch_dims=0) + break + + # Transpose back to put the None dimensions in their original relative positions + if none_count_before > 0: + # After GatherND, the gathered dimension is at the beginning + # We need to move the None dimensions back to their relative positions + final_rank = len(result.shape) + # The gathered results are in dimension 0, and the None dimensions are at the end + # We want to move them back after the gathered dimension + perm = [0] + list(range(final_rank - none_count_before, final_rank)) + list(range(1, final_rank - none_count_before)) + result = op.Transpose(result, perm=perm) + + return result def aten_index_add( diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index ebc50cd190..51bc8e8eb9 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -107,6 +107,7 @@ def torch_op( registry: Optional[Registry] = None, trace_only: bool = False, complex: bool = False, + private: bool = False, ) -> Callable[[Callable], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]: """Register a torch op. 
@@ -136,6 +137,9 @@ def wrapper( assert registry is not None for name_ in _check_and_normalize_names(name): + if private: + # Remove the private tag once all functions are no longer private. + continue registry.register(processed_func, name_, complex=complex) return processed_func From bf6500d25b7e96c9a28de28dd45de4e6ba92f5a4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 19 Dec 2025 10:41:00 -0800 Subject: [PATCH 11/16] wip Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 09494f20a4..bae294b057 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4787,31 +4787,31 @@ def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Ten op.Squeeze(index, axes=[1]) if index is not None else None for index in new_indices ] return _aten_index_onnx(self, new_indices, index_ranks) - + # Handle multi-dimensional boolean indexing input_rank = len(self.shape) result = self - + # Count None values before the first non-None index none_count_before = 0 for index in indices: if index is not None: break none_count_before += 1 - + # Transpose to move the dimensions corresponding to None indices to the end if none_count_before > 0: # Move the first none_count_before dimensions to the end perm = list(range(none_count_before, input_rank)) + list(range(none_count_before)) result = op.Transpose(result, perm=perm) - + # Apply GatherND for the first non-None boolean index for index in indices: if index is not None: new_indices = op.Transpose(op.NonZero(index), perm=[1, 0]) result = op.GatherND(result, new_indices, batch_dims=0) break - + # Transpose back to put the None dimensions in their original relative positions if none_count_before > 0: # After GatherND, the gathered dimension is at the beginning @@ -4821,7 +4821,7 @@ def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Ten # We want to move them back after the gathered dimension perm = [0] + list(range(final_rank - none_count_before, final_rank)) + list(range(1, final_rank - none_count_before)) result = op.Transpose(result, perm=perm) - + return result From 9bdbf121a03d9470a9873e572b0ed7aa3735e1a3 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 19 Dec 2025 11:12:24 -0800 Subject: [PATCH 12/16] Update Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 43 +++---------------- .../function_libs/torch_lib/registration.py | 5 ++- .../function_libs/torch_lib/ops_test_data.py | 7 ++- 3 files changed, 14 insertions(+), 41 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index bae294b057..997168acac 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4696,7 +4696,7 @@ def _aten_index_onnx( if _has_none_in_middle(indices): # If there is None in the middle, Advanced Indexing cannot decide where to put # the new dimensions. So it places them in the front, like GatherND does. - return op.Identity(self) + return self # When the indices are consecutive, Advanced Indexing will place the new dimensions # (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes. 
@@ -4776,9 +4776,8 @@ def aten_index( def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: index_ranks = [len(index.shape) for index in indices if index is not None] - # Check if all non-None boolean indices are 1D if all(rank == 1 for rank in index_ranks): - # All indices are 1D, convert boolean indices to integer indices + # indices contains scalar only. new_indices = [ op.Transpose(op.NonZero(index), perm=[1, 0]) if index is not None else None for index in indices @@ -4788,41 +4787,9 @@ def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Ten ] return _aten_index_onnx(self, new_indices, index_ranks) - # Handle multi-dimensional boolean indexing - input_rank = len(self.shape) - result = self - - # Count None values before the first non-None index - none_count_before = 0 - for index in indices: - if index is not None: - break - none_count_before += 1 - - # Transpose to move the dimensions corresponding to None indices to the end - if none_count_before > 0: - # Move the first none_count_before dimensions to the end - perm = list(range(none_count_before, input_rank)) + list(range(none_count_before)) - result = op.Transpose(result, perm=perm) - - # Apply GatherND for the first non-None boolean index - for index in indices: - if index is not None: - new_indices = op.Transpose(op.NonZero(index), perm=[1, 0]) - result = op.GatherND(result, new_indices, batch_dims=0) - break - - # Transpose back to put the None dimensions in their original relative positions - if none_count_before > 0: - # After GatherND, the gathered dimension is at the beginning - # We need to move the None dimensions back to their relative positions - final_rank = len(result.shape) - # The gathered results are in dimension 0, and the None dimensions are at the end - # We want to move them back after the gathered dimension - perm = [0] + list(range(final_rank - none_count_before, final_rank)) + list(range(1, final_rank - none_count_before)) - result = op.Transpose(result, perm=perm) - - return result + raise NotImplementedError( + "aten::index with boolean indices of rank > 1 is not supported yet." + ) def aten_index_add( diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index 51bc8e8eb9..6d64859706 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -106,7 +106,6 @@ def torch_op( *, registry: Optional[Registry] = None, trace_only: bool = False, - complex: bool = False, private: bool = False, ) -> Callable[[Callable], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]: """Register a torch op. @@ -118,6 +117,8 @@ def torch_op( i.e. "aten::relu" instead of "aten::relu.default". registry: Registry to register the function to. If None, the default registry is used. trace_only: Whether the function should only be traced and not compiled. + private: Whether the function is private (not directly exposed). It should + be true for all functions with names starting with "_". complex: Whether the function expects complex-valued inputs. """ if registry is None: @@ -138,7 +139,7 @@ def wrapper( assert registry is not None for name_ in _check_and_normalize_names(name): if private: - # Remove the private tag once all functions are no longer private. + # TODO: Remove the private tag once all functions are no longer private. 
continue registry.register(processed_func, name_, complex=complex) return processed_func diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index a4d4b33150..22bcbfcf90 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -728,7 +728,12 @@ def _where_input_wrangler( # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), - TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index), + TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index).skip( + matcher=lambda sample: any( + len(index.shape) != 1 for index in sample.args[0] if index is not None + ), + reason="fixme: aten::index with boolean indices of rank > 1 is not supported yet", + ), TorchLibOpInfo( "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler ).skip( From dcf91f6175bce2529a99c16564310f3d5c1816f9 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 19 Dec 2025 11:13:53 -0800 Subject: [PATCH 13/16] updaa Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/registration.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/function_libs/torch_lib/registration.py b/onnxscript/function_libs/torch_lib/registration.py index 6d64859706..077391e5a1 100644 --- a/onnxscript/function_libs/torch_lib/registration.py +++ b/onnxscript/function_libs/torch_lib/registration.py @@ -107,6 +107,7 @@ def torch_op( registry: Optional[Registry] = None, trace_only: bool = False, private: bool = False, + complex: bool = False, ) -> Callable[[Callable], onnxscript.OnnxFunction | onnxscript.values.TracedOnnxFunction]: """Register a torch op. From 5f48a535e97c6096f187b2ba018e9c490214f02f Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 19 Dec 2025 11:20:21 -0800 Subject: [PATCH 14/16] Update tests Signed-off-by: Justin Chu --- .../function_libs/torch_lib/ops/core.py | 31 +++++++++++++++++-- .../function_libs/torch_lib/ops_test_data.py | 7 +---- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 997168acac..da9a3a4e79 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4787,9 +4787,34 @@ def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Ten ] return _aten_index_onnx(self, new_indices, index_ranks) - raise NotImplementedError( - "aten::index with boolean indices of rank > 1 is not supported yet." - ) + else: + input_rank = len(self.shape) + # Prepare perm for transposing self tensor. + # In indices, None meaning skip the corresponding dimension, + # so we need to move this dimension to the end of the list. + # After we gathered the final results, we transpose it back. + # For example, + # self's shape is [5, 5, 5, 5], indices is [None, (5, 5)] + # the final result's shape should be [5, 16, 5]. 
+ trans_perm = list(range(input_rank)) + trans_perm.append(trans_perm.pop(0)) + count_of_none = 0 + for index in indices: + if index is None: + self = op.Transpose(self, perm=trans_perm) + count_of_none += 1 + continue + + new_indices = op.Transpose(op.NonZero(index), perm=[1, 0]) + result = op.GatherND(self, new_indices, batch_dims=0) + final_rank = input_rank - (len(index.shape) - 1) + trans_perm = list(range(final_rank)) + trans_perm = trans_perm[-1:] + trans_perm[:-1] + for _ in range(count_of_none): + result = op.Transpose(result, perm=trans_perm) + # TODO(justinchuby): Even though this logic passes the tests, it still looks strange: + # why does it return early here instead of continuing to process the remaining indices? + return result def aten_index_add( diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 22bcbfcf90..a4d4b33150 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -728,12 +728,7 @@ def _where_input_wrangler( # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), - TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index).skip( - matcher=lambda sample: any( - len(index.shape) != 1 for index in sample.args[0] if index is not None - ), - reason="fixme: aten::index with boolean indices of rank > 1 is not supported yet", - ), + TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index), TorchLibOpInfo( "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler ).skip( From 4d9ca4d1b07ad3f3a8d859432eeecb8576c117d4 Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 19 Dec 2025 11:22:05 -0800 Subject: [PATCH 15/16] notes Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index da9a3a4e79..822b31fa4b 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4814,6 +4814,7 @@ def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Ten result = op.Transpose(result, perm=trans_perm) # TODO(justinchuby): Even though this logic passes the tests, it still looks strange: # why does it return early here instead of continuing to process the remaining indices? + # I think the assumption here is that there can be only one Boolean index in the indices list? 
return result From a115fa9871c30941c740c5cbc2e93d729fc0a19d Mon Sep 17 00:00:00 2001 From: Justin Chu Date: Fri, 19 Dec 2025 11:22:18 -0800 Subject: [PATCH 16/16] fixme Signed-off-by: Justin Chu --- onnxscript/function_libs/torch_lib/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py index 822b31fa4b..1cbb28fbde 100644 --- a/onnxscript/function_libs/torch_lib/ops/core.py +++ b/onnxscript/function_libs/torch_lib/ops/core.py @@ -4812,7 +4812,7 @@ def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> Ten trans_perm = trans_perm[-1:] + trans_perm[:-1] for _ in range(count_of_none): result = op.Transpose(result, perm=trans_perm) - # TODO(justinchuby): Even though this logic passes the tests, it still looks strange: + # FIXME(justinchuby): Even though this logic passes the tests, it still looks strange: # why does it return early here instead of continuing to process the remaining indices? # I think the assumption here is that there can be only one Boolean index in the indices list? return result
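
Note (reviewer sketch, not part of any patch): the rule the series converges on — detect any boolean index, rewrite it into integer indices via a NonZero-style conversion, then fall through to the ordinary integer-index path — is spread across sixteen revisions. Below is a minimal, self-contained sketch of that dispatch, using NumPy as a stand-in for the ONNX NonZero/GatherND ops and covering only the rank-1 masks that the final `_aten_index_bool` fast path handles; the helper name `index_with_optional_bools` is invented for illustration.

    from typing import Optional, Sequence

    import numpy as np

    def index_with_optional_bools(
        x: np.ndarray, indices: Sequence[Optional[np.ndarray]]
    ) -> np.ndarray:
        """Emulate aten::index for rank-1 boolean/integer indices mixed with None."""
        if any(i is not None and i.dtype == np.bool_ for i in indices):
            # Boolean path: np.nonzero turns a rank-1 mask into the integer
            # positions where it is True, mirroring the patches'
            # Squeeze(Transpose(NonZero(index))) rewrite.
            indices = [
                np.nonzero(i)[0] if (i is not None and i.dtype == np.bool_) else i
                for i in indices
            ]
        # Integer path: None leaves the corresponding dimension untouched,
        # which NumPy expresses as a full slice.
        return x[tuple(slice(None) if i is None else i for i in indices)]

    x = np.arange(12.0).reshape(3, 4)
    mask = np.array([True, False, True])
    assert np.array_equal(index_with_optional_bools(x, [mask]), x[mask])
    assert np.array_equal(
        index_with_optional_bools(x, [None, np.array([0, 2])]), x[:, [0, 2]]
    )

The registration changes follow the same consolidation theme: patch 05 made a second registration under the same name a hard ValueError, and patch 09 relaxed it to a warning that keeps the first registration and drops the duplicate. A condensed sketch of that final behavior, with the classes trimmed to just the attributes the diffs show:

    import warnings

    class OverloadedFunction:
        def __init__(self, name: str):
            self.name = name
            self.overloads: list = []
            self.complex: list = []

    class Registry:
        def __init__(self):
            self._registry: dict[str, OverloadedFunction] = {}

        def register(self, func, name: str, *, complex: bool = False) -> None:
            entry = self._registry.setdefault(name, OverloadedFunction(name))
            bucket = entry.complex if complex else entry.overloads
            if bucket:
                # Patch 09: warn and skip the duplicate instead of raising,
                # so importing the ops modules twice (cf. the pytest
                # norecursedirs workaround in patch 06) stays survivable.
                warnings.warn(
                    f"Overload for {name!r} already registered: {bucket}.",
                    stacklevel=3,
                )
                return
            bucket.append(func)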