-
Notifications
You must be signed in to change notification settings - Fork 96
[torchlib] Consolidate all overloads and prevent new ones from being created #2621
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
ce76722
Consolidate all overloads and prevent new ones from being created
justinchuby 477d9fb
Remove registration of private functions
justinchuby f704a3a
fix typing
justinchuby 8f69c8d
test
justinchuby 15d3c45
msg
justinchuby a1fc13e
Merge branch 'main' into justinchu/consolidate-index
justinchuby 8515502
skip pytest
justinchuby 90c3d02
index_bool is wrong
justinchuby eb9ea10
Merge branch 'main' into justinchu/consolidate-index
justinchuby 0b5f8f9
Merge branch 'main' into justinchu/consolidate-index
justinchuby 98ceb0a
Bump version from 0.5.7 to 0.6.0
justinchuby 257e583
Consolidate index
justinchuby 10065a4
warn
justinchuby a7f027e
wip
justinchuby bf6500d
wip
justinchuby 9bdbf12
Update
justinchuby dcf91f6
updaa
justinchuby 5f48a53
Update tests
justinchuby 4d9ca4d
notes
justinchuby a115fa9
fixme
justinchuby e6c4551
Merge branch 'main' into justinchu/consolidate-index
justinchuby File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -328,7 +328,6 @@ def aten_col2im( | |
| else: # assert len(padding) == 4, already [w, x, y, z] | ||
| pads = padding | ||
|
|
||
| # Only one ONNX op here so didn't write a private function | ||
| return op.Col2Im( | ||
| self, | ||
| output_size, | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5,6 +5,7 @@ | |
| from __future__ import annotations | ||
|
|
||
| import re | ||
| import warnings | ||
| from typing import Any, Callable, Generator, Optional | ||
|
|
||
| import onnxscript | ||
|
|
@@ -22,14 +23,12 @@ class OverloadedFunction: | |
| Attributes: | ||
| name: Name of the op. E.g. "aten::add". | ||
| overloads: Overloads function. | ||
| privates: Private functions not exposed to users. | ||
| complex: Support complex functions. | ||
| """ | ||
|
|
||
| def __init__(self, name: str): | ||
| self.name = name | ||
| self.overloads: list[Any] = [] | ||
| self.privates: list[Any] = [] | ||
| self.complex: list[Any] = [] | ||
|
|
||
|
|
||
|
|
@@ -39,17 +38,26 @@ class Registry: | |
| def __init__(self): | ||
| self._registry: dict[str, OverloadedFunction] = {} | ||
|
|
||
| def register( | ||
| self, func: Any, name: str, *, private: bool = False, complex: bool = False | ||
| ) -> None: | ||
| def register(self, func: Any, name: str, *, complex: bool = False) -> None: | ||
| """Register a function.""" | ||
|
|
||
| if private: | ||
| self._registry.setdefault(name, OverloadedFunction(name)).privates.append(func) | ||
| elif complex: | ||
| self._registry.setdefault(name, OverloadedFunction(name)).complex.append(func) | ||
| overloaded_function = self._registry.setdefault(name, OverloadedFunction(name)) | ||
|
|
||
| if complex: | ||
| if overloaded_function.complex: | ||
| warnings.warn( | ||
| f"Complex overload for '{name}' already registered: {overloaded_function.complex}.", | ||
| stacklevel=3, | ||
| ) | ||
| return | ||
| overloaded_function.complex.append(func) | ||
| else: | ||
| self._registry.setdefault(name, OverloadedFunction(name)).overloads.append(func) | ||
| if overloaded_function.overloads: | ||
| warnings.warn( | ||
| f"Real overload for '{name}' already registered: {overloaded_function.overloads}.", | ||
| stacklevel=3, | ||
| ) | ||
| return | ||
| overloaded_function.overloads.append(func) | ||
|
|
||
| def __getitem__(self, name): | ||
| return self._registry[name] | ||
|
|
@@ -131,7 +139,10 @@ def wrapper( | |
|
|
||
| assert registry is not None | ||
| for name_ in _check_and_normalize_names(name): | ||
| registry.register(processed_func, name_, private=private, complex=complex) | ||
| if private: | ||
| # TODO: Remove the private tag once all functions are no longer private. | ||
| continue | ||
| registry.register(processed_func, name_, complex=complex) | ||
| return processed_func | ||
|
|
||
| return wrapper | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -43,6 +43,11 @@ onnx = ["py.typed"] | |
|
|
||
| [tool.pytest.ini_options] | ||
| addopts = "-rsfEX --tb=short --color=yes" | ||
| norecursedirs = [ | ||
| # Skip test collection because pytest will try to import the modules twice, | ||
| # causing the torchlib registry to complain that functions are redefined. | ||
| "onnxscript/function_libs/torch_lib/ops", | ||
| ] | ||
|
|
||
| [tool.mypy] | ||
| # TODO disallow_incomplete_defs = true | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -728,23 +728,10 @@ def _where_input_wrangler( | |
| # TorchLibOpInfo("is_same_size", core_ops.aten_is_same_size), # no test case in OPS_DB | ||
| # TorchLibOpInfo("is_nonzero", core_ops.aten_is_nonzero), # no test case in OPS_DB | ||
| TorchLibOpInfo("ops.aten.index.Tensor", core_ops.aten_index), | ||
| TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index_bool), | ||
| TorchLibOpInfo( | ||
| "index_put_bool", | ||
| core_ops.aten_index_put_bool, | ||
| input_wrangler=_index_put_input_wrangler, | ||
| ).skip( | ||
| matcher=lambda sample: sample.args[0][0].dtype != torch.bool, | ||
| reason="this Aten overload only supports tensor(bool) as indices", | ||
| ), | ||
| TorchLibOpInfo("ops.aten.index.Tensor.bool", core_ops.aten_index), | ||
| TorchLibOpInfo( | ||
| "index_put", core_ops.aten_index_put, input_wrangler=_index_put_input_wrangler | ||
| ) | ||
| .skip( | ||
| matcher=lambda sample: sample.args[0][0].dtype != torch.int64, | ||
| reason="this Aten overload only supports tensor(int) as indices", | ||
| ) | ||
| .xfail( | ||
| ).skip( | ||
| dtypes=(torch.float16,), | ||
| matcher=lambda sample: sample.kwargs.get("accumulate") is True, | ||
| reason="fixme: ORT only supports float32 when accumulate is True: MLFloat16 data type is not supported with ScatterND when reduction is 'add'", | ||
|
|
@@ -1871,7 +1858,6 @@ def _where_input_wrangler( | |
| ops_test_common.duplicate_opinfo(OPS_DB, "cat", ("concat", "concatenate")) | ||
| ops_test_common.duplicate_opinfo(OPS_DB, "clone", ("lift_fresh_copy",)) | ||
| ops_test_common.duplicate_opinfo(OPS_DB, "div", ("div_mode",)) | ||
| ops_test_common.duplicate_opinfo(OPS_DB, "index_put", ("index_put_bool",)) | ||
| ops_test_common.duplicate_opinfo(OPS_DB, "max", ("max_dim",)) | ||
| ops_test_common.duplicate_opinfo(OPS_DB, "mean", ("mean_dim",)) | ||
| ops_test_common.duplicate_opinfo(OPS_DB, "min", ("min_dim",)) | ||
|
|
||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.