[torchlib] Consolidate all overloads and prevent new ones from being created #2621
Conversation
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
Pull Request Overview
This PR consolidates overloaded functions in the torch_lib by removing the concept of private functions and preventing new overloads from being created for the same operation name. The changes focus on simplifying the registration system and merging boolean indexing operations with their regular counterparts.
- Removes the `private` parameter and its supporting functionality from the registration system
- Consolidates boolean and regular index operations into unified functions
- Adds validation to prevent duplicate overload registrations (see the sketch after this list)
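As a rough, hypothetical sketch of that duplicate-registration guard (the `Registry` and `torch_op` names below are simplified stand-ins, not the actual onnxscript implementation):

```python
# Minimal sketch, assuming a dict-backed registry keyed by ATen overload name.
from typing import Callable


class Registry:
    """Maps an ATen overload name (e.g. "aten::index.Tensor") to one function."""

    def __init__(self) -> None:
        self._functions: dict[str, Callable] = {}

    def register(self, func: Callable, name: str) -> None:
        # Reject a second registration for the same overload name so that
        # every ATen overload maps to exactly one torchlib function.
        if name in self._functions:
            raise ValueError(
                f"'{name}' is already registered to "
                f"{self._functions[name].__name__}; overloads must be 1-to-1"
            )
        self._functions[name] = func


_REGISTRY = Registry()


def torch_op(names: str | tuple[str, ...]):
    """Simplified decorator: registers one function for each given name."""
    if isinstance(names, str):
        names = (names,)

    def wrapper(func: Callable) -> Callable:
        for name in names:
            _REGISTRY.register(func, name)
        return func

    return wrapper
```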
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| onnxscript/function_libs/torch_lib/registration.py | Removes private function support and adds overload duplication prevention |
| onnxscript/function_libs/torch_lib/ops/core.py | Consolidates index and index_put functions to handle both boolean and integer indexing |
| tests/function_libs/torch_lib/ops_test_data.py | Removes separate boolean index test entries and duplicates |
| onnxscript/function_libs/torch_lib/ops/nn.py | Removes outdated comment about private functions |
| onnxscript/function_libs/tools/torch_lib/deduce_type_constraints_test.py | Updates to exclude removed private functions from iteration |
titaiwangms left a comment:
CI is failing
Codecov Report: ❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main    #2621      +/-   ##
==========================================
- Coverage   70.14%   70.13%   -0.02%
==========================================
  Files         226      226
  Lines       27369    27379      +10
  Branches     2775     2779       +4
==========================================
+ Hits        19199    19201       +2
- Misses       7218     7223       +5
- Partials      952      955       +3
```

☔ View full report in Codecov by Sentry.
Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com>
```diff
-@torch_op(("aten::index.Tensor", "aten::_unsafe_index.Tensor"), trace_only=True)
-def aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:  # pylint: disable=inconsistent-return-statements
+def _aten_index_bool(self: TensorType, indices: Sequence[Optional[BOOL]]) -> TensorType:
```
Check notice — Code scanning / CodeQL: Explicit returns mixed with implicit (fall through) returns (Note)
Copilot Autofix (AI), about 20 hours ago:
In general, to fix “explicit returns mixed with implicit returns” in a function that is annotated to return a non-`None` value, ensure that every control-flow path either explicitly returns a value of the annotated type or raises an exception. You prevent implicit fall-through by adding an explicit `return` at the end of the function or at the end of any branch where fall-through is possible.

For `_aten_index_bool`, the best fix without changing functionality is to make sure the `else:` branch (for non-all-scalar boolean indices) also ends with an explicit return of a `TensorType`. Inside this branch, the function presumably computes or delegates to `_aten_index_onnx` in a way analogous to the scalar-boolean case. We should reuse that result instead of inventing new behavior. Concretely, we can have the `else:` branch finish by returning the value from whatever computation is already there. If the omitted code builds some `processed_indices` and currently calls `_aten_index_onnx(self, processed_indices, index_ranks)` without returning it, we should change that call into `return _aten_index_onnx(...)`. If, instead, the omitted code only prepares data and then falls through, we can add a final `return _aten_index_onnx(self, indices, index_ranks)` at the end of the function so that any path in the `else:` branch still returns a `TensorType`. This keeps semantics aligned with the `aten_index` function (which always delegates to `_aten_index_onnx`).

Given we must not assume unshown code structure, the minimally invasive and safe change within the shown region is to add an explicit `return _aten_index_onnx(self, indices, index_ranks)` as the last statement of `_aten_index_bool`. That guarantees that if execution reaches the bottom of the `else:` branch (or any other path that doesn’t otherwise return/raise), the function will still return a `TensorType` rather than `None`. No new imports or helper methods are needed; `_aten_index_onnx` is already used within this module.
```diff
@@ -4789,6 +4789,9 @@
     else:
         input_rank = len(self.shape)
+        # Additional processing for non-scalar boolean indices is performed above.
+        # Ensure we always return a TensorType instead of implicitly returning None.
+        return _aten_index_onnx(self, indices, index_ranks)
         # Prepare perm for transposing self tensor.
         # In indices, None meaning skip the corresponding dimension,
         # so we need to move this dimension to the end of the list.
```
This PR implements #2580 by combining all overloads in torchlib and removing the ability to register new ones. It is done in a backward-compatible fashion and should work with released versions of PyTorch.
From now on, all logic for a single ATen OpOverload should be implemented by a single torchlib function to ensure a 1-to-1 mapping.
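As a rough illustration of that 1-to-1 pattern (a hypothetical sketch in plain PyTorch, not the actual torchlib code; `aten_index_sketch` is an invented name), a single entry point can normalize boolean masks to integer indices instead of registering a separate boolean overload:

```python
from typing import Optional, Sequence

import torch


def aten_index_sketch(
    self: torch.Tensor, indices: Sequence[Optional[torch.Tensor]]
) -> torch.Tensor:
    """One entry point handling both boolean and integer advanced indexing."""
    keys: list = []
    for index in indices:
        if index is None:
            # In aten::index, None means "do not index this dimension".
            keys.append(slice(None))
        elif index.dtype == torch.bool:
            # A boolean mask selects the positions where it is True;
            # nonzero(as_tuple=True) yields one index tensor per mask dim.
            keys.extend(index.nonzero(as_tuple=True))
        else:
            keys.append(index)
    return self[tuple(keys)]


# Boolean and integer indices now go through the same function.
x = torch.arange(6).reshape(2, 3)
assert torch.equal(aten_index_sketch(x, [torch.tensor([True, False])]), x[:1])
assert torch.equal(aten_index_sketch(x, [torch.tensor([1])]), x[1:2])
```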