Added padding_idx=None option and new test cases for aten_embedding_bag #2549
Conversation
|
@microsoft-github-policy-service agree |
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #2549      +/-   ##
==========================================
- Coverage   70.07%   70.03%   -0.04%
==========================================
  Files         226      226
  Lines       27276    27285       +9
  Branches     2754     2756       +2
==========================================
- Hits        19113    19110       -3
- Misses       7213     7224      +11
- Partials      950      951       +1
|
@justinchuby I can handle the linting issues, but I’m confused about the other CI failures — could you help? |
|
@justinchuby I've updated the code and fixed the issues. Can you review? |
|
@crypto-a I would love to see this PR go in! Are you still planning on contributing it? |
|
Changes LGTM. CI is reporting
|
|
@crypto-a could you take a look at the CI errors and rebase from main? Thanks |
|
Will take a look at it |
|
The CI is failing because my refactored masking-based approach is producing different numerical results than the original algorithm for mean/max aggregation modes, causing test failures. I am working on fixing the implementation |
|
@justinchuby, I’ve fixed the CI issues. I had to remove the test cases I added myself. It turns out the scenarios my code needed were already covered by the OpInfo tests. All tests are now passing on my end. Could you please rerun CI? Thanks! |
|
Tests are passing, the branch is up to date, and it’s ready to be merged |
# Modify this section ##########################################################
def _embedding_bag_input_wrangler(
Final round of reviews: Is this necessary? I think it should accept a None input?
    sparse: bool = False,
    per_sample_weights: Optional[TFloat] = None,
    include_last_offset: bool = False,
    padding_idx: Optional[int] = None,
I checked https://github.com/pytorch/pytorch/blob/8656dea039bd0a31952a6a9792566e70b07429dc/aten/src/ATen/native/native_functions.yaml#L2372-L2378 and found that embedding_bag does not have padding_idx as a parameter. I think you only need to update aten_embedding_bag_padding_idx
Fix Issues #2219, #2385 and the first part of #2489
This commit adds new test cases and the necessary implementation changes to correctly support the padding_idx=None option in the aten_embedding_bag operator. This aligns the ONNX Script operator with PyTorch's native behavior and expands test coverage for this feature.

Key Changes:

core.py: The aten_embedding_bag_padding_idx function has been updated to handle padding_idx=None. This new code routes the operation to the standard aten_embedding_bag implementation when no padding indices are specified.

extra_opinfo.py: Two new OpInfo definitions, test_embedding_bag_with_padding_idx_none and test_embedding_bag_with_padding_idx_int, have been added to the OP_DB list. These provide input samples to test the new and existing padding_idx functionality.

ops_test_data.py: The TESTED_TORCHLIB_OPS tuple has been updated to include the new tests, ensuring they are discovered and executed by the test runner.
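
The routing described for core.py can be illustrated with a minimal pure-Python sketch. This is not the torchlib implementation: the function names (embedding_bag_sketch, embedding_bag_padding_idx_sketch) and helpers are hypothetical, plain lists stand in for tensors, and options such as per_sample_weights and include_last_offset are left out. It assumes that indices equal to padding_idx are excluded from the reduction and that an empty (or fully padded) bag yields zeros.

```python
from typing import List, Optional, Sequence

Matrix = Sequence[Sequence[float]]


def _split_bags(indices: Sequence[int], offsets: Sequence[int]) -> List[List[int]]:
    """Split the flat index list into bags; offsets mark where each bag starts."""
    bounds = list(offsets) + [len(indices)]
    return [list(indices[s:e]) for s, e in zip(bounds[:-1], bounds[1:])]


def _reduce_bags(weight: Matrix, bags: List[List[int]], mode: str) -> List[List[float]]:
    """Reduce the embedding rows of each bag with sum, mean, or max."""
    dim = len(weight[0])
    out = []
    for bag in bags:
        if not bag:
            # Assumption: an empty bag produces a zero vector.
            out.append([0.0] * dim)
            continue
        columns = list(zip(*(weight[i] for i in bag)))
        if mode == "sum":
            out.append([sum(c) for c in columns])
        elif mode == "mean":
            out.append([sum(c) / len(bag) for c in columns])
        elif mode == "max":
            out.append([max(c) for c in columns])
        else:
            raise ValueError(f"unsupported mode: {mode}")
    return out


def embedding_bag_sketch(weight, indices, offsets, mode="sum"):
    """Hypothetical stand-in for the base op, which has no padding_idx parameter."""
    return _reduce_bags(weight, _split_bags(indices, offsets), mode)


def embedding_bag_padding_idx_sketch(
    weight, indices, offsets, mode="sum", padding_idx: Optional[int] = None
):
    """Hypothetical stand-in for the padding_idx variant."""
    if padding_idx is None:
        # No padding index given: route to the base implementation.
        return embedding_bag_sketch(weight, indices, offsets, mode)
    # Drop padding entries so they affect neither the result nor the mean divisor.
    bags = [[i for i in bag if i != padding_idx] for bag in _split_bags(indices, offsets)]
    return _reduce_bags(weight, bags, mode)
```

Filtering padding entries before the reduction, rather than zeroing their rows, keeps the mean divisor and the max result independent of the padded positions. The sketch does not handle negative padding_idx values.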