use python version agnostic binding for mxfp8 cuda kernels #3471
🔗 Helpful Links: see artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3471
❌ 6 New Failures as of commit e576e41 with merge base 08e5e20.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
atalman left a comment:
lgtm
spoke offline, we need to be using TORCH_LIBRARY and not use pybind
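For context, a minimal sketch of what a TORCH_LIBRARY-based registration looks like, as opposed to a PYBIND11_MODULE. The op name and signature here are illustrative, not the actual torchao binding; the point is that ops registered this way are resolved through torch's dispatcher, so the built .so depends only on libtorch symbols, not on a specific CPython ABI:

```cpp
#include <torch/library.h>

// Hypothetical op: quantize along dim1 with the fixed mxfp8 block size.
// (Illustrative signature -- not the actual torchao kernel interface.)
at::Tensor mxfp8_quantize_dim1(const at::Tensor& input);

// Registration via TORCH_LIBRARY instead of PYBIND11_MODULE: the exported
// symbols come from libtorch's dispatcher, not from CPython.
TORCH_LIBRARY(torchao, m) {
  m.def("mxfp8_quantize_dim1(Tensor input) -> Tensor");
}

TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
  m.impl("mxfp8_quantize_dim1", mxfp8_quantize_dim1);
}
```

From Python the op is then reachable as `torch.ops.torchao.mxfp8_quantize_dim1`, with no pybind extension module to import.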
    not is_cuda_version_at_least(12, 8),
    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
)
def test_cuda_mx_dim1_invalid_block_size():
deleting this test since a block_size of 32 is now hard-coded in the python wrapper for the kernel, since we always use 32 for mxfp8
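A sketch of the wrapper shape being described, with hypothetical names (`_kernel` is a stand-in for the real CUDA op binding): block_size is fixed at 32 inside the wrapper, so there is no longer a user-supplied value to validate, which is why the invalid-block-size test no longer applies.

```python
MXFP8_BLOCK_SIZE = 32  # hard-coded: mxfp8 always uses 32-element blocks

def _kernel(tensor, block_size):
    # Stand-in for the real CUDA kernel binding (hypothetical).
    assert block_size == MXFP8_BLOCK_SIZE
    return tensor

def mxfp8_quantize_dim1(tensor):
    # block_size is no longer an argument; the wrapper always passes 32,
    # so an invalid block size cannot reach the kernel from this path.
    return _kernel(tensor, block_size=MXFP8_BLOCK_SIZE)
```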
you should be able to use nm -D to also investigate the symbols and ensure there are none from python
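For example (the .so path below is a placeholder; point it at the extension from your installed torchao), any `Py*`/`_Py*` entries in the dynamic symbol table mean the binary is tied to a specific CPython ABI:

```shell
# Grep the dynamic symbol table for CPython symbols; prints any matches,
# otherwise reports the binary is clean. Path is hypothetical.
check_py_symbols() {
  nm -D "$1" 2>/dev/null | grep -E '_?Py[A-Za-z_]+' || echo "no python symbols"
}

check_py_symbols torchao/_C_mxfp8.abi3.so
```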
current CI failures will be resolved once this rollback in upstream pytorch is included in the next torch nightly: pytorch/pytorch#169985
Summary
- use the python version agnostic binding approach from torchao/ops.py, as already done for other CUDA C++ extensions
- kernels now land in the torchao._C_mxfp8 .so file (lands in build/ dir) instead of the separate torchao.prototype.mxfp8_cuda extension (which landed in torchao/prototype)

Context
While doing the 0.15.0 torchao release and testing the test build for CUDA 12.8, I found the "torchao.prototype.mxfp8_cuda" C++ extension cannot be found (import error: module not found). We only build the extension for CUDA 12.8+, so I checked the logs, and they indicate it was built: https://github.com/pytorch/ao/actions/runs/20046209265/job/57498462190

So I then checked the local installation itself, and I do see a .so file for the extension in the torchao/prototype dir, so it is definitely being built.

I tried asking Claude about this, and it says the extension built for python 3.10 must match the python version in the conda env due to ABI incompatibility (I'm using python 3.12). As a test, I tried a fresh conda env with python 3.10; instead of "module not found", I get an "undefined symbol" error, which does seem to indicate some python ABI issue.

Asking @drisspg, he said we should be building with a py agnostic flag. I looked into this, and we do this for other C++ extensions but not mxfp8_cuda, so I am fairly certain this is the root cause and this PR will fix the issue.
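As a sketch of what a py-version-agnostic build flag looks like in a setuptools-based build (this mirrors the general CPython stable-ABI mechanism, not torchao's actual setup.py; the module name comes from this PR, the source path and version pin are illustrative):

```python
from setuptools import Extension

# Building against the CPython stable ABI (Py_LIMITED_API) produces one
# .so that loads on any CPython >= the pinned version, instead of a binary
# tied to the exact python used at build time -- which is the mismatch
# that caused the undefined-symbol error described above.
ext = Extension(
    "torchao._C_mxfp8",                                # module name from this PR
    sources=["csrc/mxfp8_binding.cpp"],                # hypothetical source path
    define_macros=[("Py_LIMITED_API", "0x030A0000")],  # stable ABI, CPython 3.10+
    py_limited_api=True,                               # tags the built .so as abi3
)
```

Note this only makes pybind-style bindings portable across python versions; registering the kernels via TORCH_LIBRARY, as suggested in the review comment above, removes the CPython dependence entirely.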