
Conversation

@danielvegamyhre (Contributor) commented Dec 9, 2025

Summary

  • Migrate from PyBind to the TORCH_LIBRARY API, which is Python version agnostic
  • Update setup.py to use the Python limited API (version-agnostic) tag
  • Use the pattern seen in torchao/ops.py for other CUDA C++ extensions:
    • Build a torchao._C_mxfp8 .so file (lands in the build/ dir) instead of a separate torchao.prototype.mxfp8_cuda extension (which landed in torchao/prototype)
    • Define new op schemas for the 2d quantization and 3d quantization kernels
    • Update custom ops, meta functions, and custom sharding registrations to wrap the new custom op (see the sketch after this list)
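
For illustration, here is a minimal sketch of the torchao/ops.py-style Python-side pattern this migration follows. The op name torchao::mxfp8_quantize_2d, its signature, output dtypes, and shapes are hypothetical placeholders, not the actual schema added in this PR:

    import torch

    # The C++ extension registers the op schema and CUDA kernel via
    # TORCH_LIBRARY / TORCH_LIBRARY_IMPL, so the shared library exposes no
    # Python-version-specific symbols; the Python side only wraps the op.

    def mxfp8_quantize_2d(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Thin wrapper around the custom op; fixed parameters such as the
        # MXFP8 block size can be hard-coded here instead of being user-facing.
        return torch.ops.torchao.mxfp8_quantize_2d.default(x)

    # Meta/"fake" implementation so torch.compile and shape propagation work
    # without launching the CUDA kernel.
    @torch.library.register_fake("torchao::mxfp8_quantize_2d")
    def _(x: torch.Tensor):
        rows, cols = x.shape
        data = torch.empty(rows, cols, dtype=torch.float8_e4m3fn, device=x.device)
        scales = torch.empty(rows, cols // 32, dtype=torch.uint8, device=x.device)
        return data, scales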

Context

While doing the 0.15.0 torchao release and testing the test build for CUDA 12.8, I found that the "torchao.prototype.mxfp8_cuda" C++ extension cannot be found (import error, module not found). We only build the extension for CUDA 12.8+, so I checked the logs, and they indicate it was built: https://github.com/pytorch/ao/actions/runs/20046209265/job/57498462190

I then checked the local installation itself, and I do see a .so file for the extension in the torchao/prototype dir, so it is definitely being built.

I tried asking Claude about this, and it says the extension built for Python 3.10 must match the Python version in the conda env due to ABI incompatibility (I'm using Python 3.12). As a test, I tried a fresh conda env with Python 3.10, and instead of module not found, I get an undefined symbol error, so that does seem to indicate some Python ABI issue.

Asking @drisspg, he said we should be building with a Python-version-agnostic flag. I looked into this: we are already doing it for other C++ extensions but not for mxfp8_cuda, so I am fairly certain this is the root cause and this PR will fix the issue.
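
As a rough illustration (not the literal diff in this PR), building a C++/CUDA extension against the Python limited API with setuptools and torch.utils.cpp_extension looks roughly like the sketch below; the source paths and the cp39 minimum tag are assumptions:

    from setuptools import setup
    from torch.utils.cpp_extension import BuildExtension, CUDAExtension

    ext = CUDAExtension(
        name="torchao._C_mxfp8",
        # hypothetical source paths, for illustration only
        sources=[
            "torchao/csrc/cuda/mx_kernels/mxfp8_extension.cpp",
            "torchao/csrc/cuda/mx_kernels/mxfp8_quantize.cu",
        ],
        # Build against the stable CPython ABI so one .so works across Python
        # versions (requires a torch/setuptools version supporting this flag).
        py_limited_api=True,
    )

    setup(
        name="torchao",
        ext_modules=[ext],
        cmdclass={"build_ext": BuildExtension},
        # Tag the wheel as abi3 so a single binary covers Python >= 3.9.
        options={"bdist_wheel": {"py_limited_api": "cp39"}},
    )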

@pytorch-bot bot commented Dec 9, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/3471

Note: Links to docs will display an error until the docs builds have been completed.

❌ 6 New Failures

As of commit e576e41 with merge base 08e5e20:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label Dec 9, 2025
@danielvegamyhre added the topic: bug fix label and removed the CLA Signed label Dec 9, 2025
@atalman (Contributor) left a comment

lgtm

@drisspg (Contributor) commented Dec 9, 2025

Spoke offline, we need to be using TORCH_LIBRARY and not PyBind.
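
For context, a rough sketch of the Python-side difference, using a hypothetical .so path and op name (not the exact code in this repo):

    import glob
    import torch

    # Old approach (PyBind11 extension module; its symbols are tied to one
    # CPython version):
    #   from torchao.prototype import mxfp8_cuda
    #   out, scales = mxfp8_cuda.quantize(...)

    # New approach (ops registered in C++ via TORCH_LIBRARY, loaded once and
    # exposed under torch.ops):
    so_path = glob.glob("torchao/_C_mxfp8*.so")[0]  # hypothetical location
    torch.ops.load_library(so_path)
    # e.g. torch.ops.torchao.mxfp8_quantize_2d(x)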

@meta-cla bot added the CLA Signed label Dec 9, 2025
@danielvegamyhre changed the title from "use python version agnostic c++ extension for mxfp8_cuda" to "use python version agnostic c++ binding for mxfp8 cuda kernels" Dec 10, 2025
@danielvegamyhre changed the title from "use python version agnostic c++ binding for mxfp8 cuda kernels" to "use python version agnostic python binding for mxfp8 cuda kernels" Dec 10, 2025
@danielvegamyhre changed the title from "use python version agnostic python binding for mxfp8 cuda kernels" to "use python version agnostic binding for mxfp8 cuda kernels" Dec 10, 2025
@pytest.mark.skipif(
    not is_cuda_version_at_least(12, 8),
    reason="CUDA version >= 12.8 required for MXFP8 CUDA kernels",
)
def test_cuda_mx_dim1_invalid_block_size():
@danielvegamyhre (Contributor, Author) commented

Deleting this test, since a block_size of 32 is now hard-coded in the Python wrapper for the kernel; we always use that block size for MXFP8.
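
As an illustration of why the test is no longer needed, the wrapper might look something like this (function and op names are hypothetical placeholders):

    import torch

    MXFP8_BLOCK_SIZE = 32  # MXFP8 always uses a block size of 32

    def mxfp8_quantize_dim1(x: torch.Tensor):
        # The block size is fixed inside the wrapper rather than accepted
        # (and validated) as a user-facing argument, so there is no longer an
        # "invalid block_size" code path to test.
        return torch.ops.torchao.mxfp8_quantize_dim1.default(x, MXFP8_BLOCK_SIZE)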

@drisspg (Contributor) commented Dec 10, 2025

You should also be able to use nm -D to inspect the symbols in the .so and ensure there are none from Python.

@danielvegamyhre (Contributor, Author) commented

Current CI failures will be resolved once this rollback in upstream PyTorch is included in the next torch nightly: pytorch/pytorch#169985
