@@ -21,7 +21,7 @@

def torch_lib_onnx_functions_from_registry() -> Generator[onnxscript.OnnxFunction, None, None]:
for op in registration.default_registry.values():
for func in (*op.overloads, *op.privates, *op.complex):
for func in (*op.overloads, *op.complex):
if isinstance(func, onnxscript.OnnxFunction):
yield func

47 changes: 28 additions & 19 deletions onnxscript/function_libs/torch_lib/ops/core.py
@@ -4382,7 +4382,6 @@
padding_mode_options = ("zeros", "border", "reflection")
padding_mode_str = padding_mode_options[padding_mode]

# Only one onnx Op so don't put into private function
return op.GridSample(
input,
grid,
@@ -4408,7 +4407,6 @@
padding_mode_options = ("zeros", "border", "reflection")
padding_mode_str = padding_mode_options[padding_mode]

# Only one onnx Op so don't put into private function
return op.GridSample(
input,
grid,
@@ -4698,7 +4696,7 @@
if _has_none_in_middle(indices):
# If there is None in the middle, Advanced Indexing cannot decide where to put
# the new dimensions. So it places them in the front, like GatherND does.
return op.Identity(self)
return self

# When the indices are consecutive, Advanced Indexing will place the new dimensions
# (aka. the broadcasted shape) in the middle, replacing the original [x1, ..., xk] axes.
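To make the placement rule in the comments above concrete, here is a minimal PyTorch sketch (illustrative only, not part of this diff; it relies on standard PyTorch/NumPy advanced-indexing semantics):

# Minimal sketch of the placement rule described above (standard
# PyTorch/NumPy advanced-indexing semantics; not part of this diff).
import torch

x = torch.arange(24).reshape(2, 3, 4)
idx = torch.tensor([0, 1])

# None (a full slice) in the middle: broadcast dims move to the FRONT.
print(x[idx, :, idx].shape)  # torch.Size([2, 3])

# Consecutive indices: broadcast dims replace the indexed middle axes.
print(x[:, idx, idx].shape)  # torch.Size([2, 2])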
@@ -4744,7 +4742,9 @@


@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorType:
def aten_index(
self: TensorType, indices: Sequence[Optional[Union[INT64, BOOL]]]
) -> TensorType:
"""index.Tensor(Tensor self, Tensor?[] indices) -> Tensor

NOTE: Understanding `aten::index`
@@ -4764,17 +4764,19 @@

None in `indices` are like fillers for dimensions that cannot be removed in the process.
"""
# Handle Boolean indexing first
if any(index is not None and index.dtype == ir.DataType.BOOL for index in indices):
return _aten_index_bool(self, indices)

index_ranks = [len(index.shape) for index in indices if index is not None]

return _aten_index_onnx(self, indices, index_ranks)


@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType: # pylint: disable=inconsistent-return-statements
def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:

Check notice: Code scanning / CodeQL

Explicit returns mixed with implicit (fall through) returns (severity: Note)

Mixing implicit and explicit returns may indicate an error, as implicit returns always return None.

Copilot Autofix (AI)

In general, to fix “explicit returns mixed with implicit returns” in a function that is annotated to return a non-None value, ensure that every control-flow path either explicitly returns a value of the annotated type or raises an exception. You prevent implicit fall-through by adding an explicit return at the end of the function or at the end of any branch where fall-through is possible.

For _aten_index_bool, the best fix without changing functionality is to make sure the else: branch (for non-all-scalar boolean indices) also ends with an explicit return of a TensorType. Inside this branch, the function presumably computes or delegates to _aten_index_onnx in a way analogous to the scalar-boolean case. We should reuse that result instead of inventing new behavior. Concretely, we can have the else: branch finish by returning the value from whatever computation is already there. If the omitted code builds some processed_indices and currently calls _aten_index_onnx(self, processed_indices, index_ranks) without returning it, we should change that call into return _aten_index_onnx(...). If, instead, the omitted code only prepares data and then falls through, we can add a final return _aten_index_onnx(self, indices, index_ranks) at the end of the function so that any path in the else: branch still returns a TensorType. This keeps semantics aligned with the aten_index function (which always delegates to _aten_index_onnx).

Given we must not assume unshown code structure, the minimally invasive and safe change within the shown region is to add an explicit return _aten_index_onnx(self, indices, index_ranks) as the last statement of _aten_index_bool. That guarantees that if execution reaches the bottom of the else: branch (or any other path that doesn’t otherwise return/raise), the function will still return a TensorType rather than None. No new imports or helper methods are needed; _aten_index_onnx is already used within this module.

Suggested changeset 1: onnxscript/function_libs/torch_lib/ops/core.py

Autofix patch. Run the following command in your local git repository to apply it:
cat << 'EOF' | git apply
diff --git a/onnxscript/function_libs/torch_lib/ops/core.py b/onnxscript/function_libs/torch_lib/ops/core.py
--- a/onnxscript/function_libs/torch_lib/ops/core.py
+++ b/onnxscript/function_libs/torch_lib/ops/core.py
@@ -4789,6 +4789,9 @@
 
     else:
         input_rank = len(self.shape)
+        # Additional processing for non-scalar boolean indices is performed above.
+        # Ensure we always return a TensorType instead of implicitly returning None.
+        return _aten_index_onnx(self, indices, index_ranks)
         # Prepare perm for transposing self tensor.
         # In indices, None meaning skip the corresponding dimension,
         # so we need to move this dimension to the end of the list.
EOF
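For reference, a minimal sketch of the pattern CodeQL flags here (a hypothetical function, not from this PR): a function annotated to return a value but with a branch that falls through and implicitly returns None:

# Hypothetical example of the mixed-returns pattern CodeQL reports.
def head(xs: list[int]) -> int:
    if xs:
        return xs[0]
    # No return here: this branch implicitly returns None,
    # contradicting the -> int annotation. CodeQL flags the mix.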
index_ranks = [len(index.shape) for index in indices if index is not None]

if index_ranks[0] == 1:
if all(rank == 1 for rank in index_ranks):
# indices contains scalar only.
new_indices = [
op.Transpose(op.NonZero(index), perm=[1, 0]) if index is not None else None
@@ -4784,6 +4786,7 @@
op.Squeeze(index, axes=[1]) if index is not None else None for index in new_indices
]
return _aten_index_onnx(self, new_indices, index_ranks)

else:
input_rank = len(self.shape)
# Prepare perm for transposing self tensor.
@@ -4800,15 +4803,19 @@
if index is None:
self = op.Transpose(self, perm=trans_perm)
count_of_none += 1
else:
new_indices = op.Transpose(op.NonZero(index), perm=[1, 0])
result = op.GatherND(self, new_indices, batch_dims=0)
finla_rank = input_rank - (len(index.shape) - 1)
trans_perm = list(range(finla_rank))
trans_perm = trans_perm[-1:] + trans_perm[:-1]
for _ in range(count_of_none):
result = op.Transpose(result, perm=trans_perm)
return result
continue

new_indices = op.Transpose(op.NonZero(index), perm=[1, 0])
result = op.GatherND(self, new_indices, batch_dims=0)
final_rank = input_rank - (len(index.shape) - 1)
trans_perm = list(range(final_rank))
trans_perm = trans_perm[-1:] + trans_perm[:-1]
for _ in range(count_of_none):
result = op.Transpose(result, perm=trans_perm)
# FIXME(justinchuby): Even though this logic passes the tests, it still looks strange:
# why does it return early here instead of continuing to process the remaining indices?
# I think the assumption here is that there can be only one Boolean index in the indices list?
return result


def aten_index_add(
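An aside on the NonZero-based conversion used in the new code above: a hedged NumPy stand-in (not part of this diff) showing why the Transpose(perm=[1, 0]) produces GatherND-style rows:

# NumPy stand-in for the bool -> integer-index conversion above
# (ONNX NonZero returns indices with shape [rank, num_true]).
import numpy as np

mask = np.array([[True, False],
                 [False, True]])
nonzero = np.stack(np.nonzero(mask))  # shape (rank, num_true) == (2, 2)
rows = nonzero.T                      # shape (num_true, rank): one coordinate
                                      # per row, the layout GatherND consumes
print(rows)  # [[0 0]
             #  [1 1]]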
@@ -4830,7 +4837,7 @@
@torch_op(("aten::index_put", "aten::_unsafe_index_put"), trace_only=True)
def aten_index_put(
self: TReal,
indices: Sequence[INT64],
indices: Sequence[Optional[Union[INT64, BOOL]]],
values: TReal,
accumulate: bool = False,
) -> TReal:
@@ -4839,6 +4846,9 @@
See implementation of `torch.onnx.symbolic_opset11.index_put
<https://github.com/pytorch/pytorch/blob/main/torch/onnx/symbolic_opset11.py#L212>`_.
"""
if any(index is not None and index.dtype == BOOL.dtype for index in indices):
return _aten_index_put_bool(self, indices, values, accumulate)

# Ensure the number of indices matches the tensor rank by appending trailing Nones.
self_rank = len(self.shape)
if len(indices) < self_rank:
@@ -4971,8 +4981,7 @@
return result


@torch_op("aten::index_put", trace_only=True)
def aten_index_put_bool(
def _aten_index_put_bool(
self: TReal,
indices: Sequence[BOOL],
values: TReal,
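To see what the new Boolean dispatch in aten_index_put targets, here is a hedged PyTorch-side usage sketch (illustrative; the tensors and values are made up):

# x[mask] = v in functional form: this is the call pattern the new
# Boolean branch of aten_index_put handles (illustrative values).
import torch

x = torch.zeros(2, 3)
mask = torch.tensor([[True, False, True],
                     [False, True, False]])
out = torch.ops.aten.index_put(x, [mask], torch.tensor(1.0))
print(out)  # 1.0 at masked positions, 0.0 elsewhere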
1 change: 0 additions & 1 deletion onnxscript/function_libs/torch_lib/ops/nn.py
@@ -328,7 +328,6 @@ def aten_col2im(
else: # assert len(padding) == 4, already [w, x, y, z]
pads = padding

# Only one ONNX op here so didn't write a private function
return op.Col2Im(
self,
output_size,
35 changes: 23 additions & 12 deletions onnxscript/function_libs/torch_lib/registration.py
@@ -5,6 +5,7 @@
from __future__ import annotations

import re
import warnings
from typing import Any, Callable, Generator, Optional

import onnxscript
@@ -22,14 +23,12 @@ class OverloadedFunction:
Attributes:
name: Name of the op. E.g. "aten::add".
overloads: Overloads function.
privates: Private functions not exposed to users.
complex: Support complex functions.
"""

def __init__(self, name: str):
self.name = name
self.overloads: list[Any] = []
self.privates: list[Any] = []
self.complex: list[Any] = []


@@ -39,17 +38,26 @@ class Registry:
def __init__(self):
self._registry: dict[str, OverloadedFunction] = {}

def register(
self, func: Any, name: str, *, private: bool = False, complex: bool = False
) -> None:
def register(self, func: Any, name: str, *, complex: bool = False) -> None:
"""Register a function."""

if private:
self._registry.setdefault(name, OverloadedFunction(name)).privates.append(func)
elif complex:
self._registry.setdefault(name, OverloadedFunction(name)).complex.append(func)
overloaded_function = self._registry.setdefault(name, OverloadedFunction(name))

if complex:
if overloaded_function.complex:
warnings.warn(
f"Complex overload for '{name}' already registered: {overloaded_function.complex}.",
stacklevel=3,
)
return
overloaded_function.complex.append(func)
else:
self._registry.setdefault(name, OverloadedFunction(name)).overloads.append(func)
if overloaded_function.overloads:
warnings.warn(
f"Real overload for '{name}' already registered: {overloaded_function.overloads}.",
stacklevel=3,
)
return
overloaded_function.overloads.append(func)

def __getitem__(self, name):
return self._registry[name]
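A short sketch of the new duplicate-registration behavior (assuming the Registry class as modified above; the registered callables are placeholders):

# Registering two real (non-complex) overloads under the same name now
# warns and keeps the first registration (placeholder callables).
import warnings

registry = Registry()
registry.register(lambda: None, "aten::demo")
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    registry.register(lambda: None, "aten::demo")
print(caught[0].message)  # Real overload for 'aten::demo' already registered: [...]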
@@ -131,7 +139,10 @@ def wrapper(

assert registry is not None
for name_ in _check_and_normalize_names(name):
registry.register(processed_func, name_, private=private, complex=complex)
if private:
# TODO: Remove the private tag once all functions are no longer private.
continue
registry.register(processed_func, name_, complex=complex)
return processed_func

return wrapper
5 changes: 5 additions & 0 deletions pyproject.toml
@@ -43,6 +43,11 @@ onnx = ["py.typed"]

[tool.pytest.ini_options]
addopts = "-rsfEX --tb=short --color=yes"
norecursedirs = [
# Skip test collection because pytest will try to import the modules twice,
# causing the torchlib registry to complain that functions are redefined.
"onnxscript/function_libs/torch_lib/ops",
]

[tool.mypy]
# TODO disallow_incomplete_defs = true
18 changes: 2 additions & 16 deletions tests/function_libs/torch_lib/ops_test_data.py
@@ -728,23 +728,10 @@ def _where_input_wrangler(
# TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB
# TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB
TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index),
TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool),
TorchLibOpInfo(
"index_put_bool",
core_ops.aten_index_put_bool,
input_wrangler=_index_put_input_wrangler,
).skip(
matcher=lambda sample: sample.args[0][0].dtype != torch.bool,
reason="this Aten overload only supports tensor(bool) as indices",
),
TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index),
TorchLibOpInfo(
"index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler
)
.skip(
matcher=lambda sample: sample.args[0][0].dtype != torch.int64,
reason="this Aten overload only supports tensor(int) as indices",
)
.xfail(
).skip(
dtypes=(torch.float16,),
matcher=lambda sample: sample.kwargs.get("accumulate") is True,
reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'",
@@ -1871,7 +1858,6 @@ def _where_input_wrangler(
ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate"))
ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",))
ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",))
ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",))
ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",))
ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",))