Metal backend: SDPA metal implementation #16086
Conversation
Stack from ghstack (oldest at bottom):
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16086
Note: Links to docs will display an error until the docs builds have been completed.
⏳ No Failures, 44 Pending as of commit 7bf5730 with merge base c00d726.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Pull request overview
This pull request replaces the MPSGraph-based implementation of Scaled Dot Product Attention (SDPA) with a custom Metal kernel implementation, ported from PyTorch and influenced by MLX.
Key Changes
- Custom Metal kernel: Implements a one-pass SDPA algorithm embedded as a 200+ line inline shader, with template instantiations for float, half, and bfloat types across head dimensions of 64, 96, and 128 (a reference sketch of the one-pass algorithm follows this list)
- Enhanced Metal API: Adds new `setArg` overloads for uint32_t, float, bool, and uint3 types, plus a new `dispatchThreadgroups` method for explicit threadgroup dispatch
- Stride-aware computation: The new kernel handles transposed tensor layouts by decomposing batch and head indices and using explicit stride information
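For reference, the one-pass algorithm the kernel implements is essentially an online-softmax loop over keys and values. The sketch below is a plain C++ reference under assumed layouts, not the actual Metal shader; the pointer names, `k_stride`/`v_stride` parameters, and additive `mask` handling are illustrative assumptions.

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Reference one-pass SDPA for a single query row (illustrative only; the
// Metal kernel templates this over dtype and a compile-time head_dim).
// Keys/values are read once; the softmax is computed online.
void sdpa_one_pass_row(
    const float* q,        // [head_dim]
    const float* k,        // [seq_len, head_dim], row stride = k_stride
    const float* v,        // [seq_len, head_dim], row stride = v_stride
    const float* mask,     // optional additive float mask [seq_len], may be null
    float* out,            // [head_dim]
    int seq_len,
    int head_dim,
    std::ptrdiff_t k_stride,
    std::ptrdiff_t v_stride,
    float scale) {
  float running_max = -INFINITY;           // max score seen so far
  float running_sum = 0.0f;                // sum of exp(score - running_max)
  std::vector<float> acc(head_dim, 0.0f);  // unnormalized output accumulator

  for (int j = 0; j < seq_len; ++j) {
    const float* k_row = k + j * k_stride;
    const float* v_row = v + j * v_stride;

    float score = 0.0f;
    for (int d = 0; d < head_dim; ++d) {
      score += q[d] * k_row[d];
    }
    score *= scale;
    if (mask != nullptr) {
      score += mask[j];  // additive floating-point attention mask
    }

    // Online softmax update: rescale the previous accumulator when the max grows.
    float new_max = std::max(running_max, score);
    float correction = std::exp(running_max - new_max);
    float p = std::exp(score - new_max);
    for (int d = 0; d < head_dim; ++d) {
      acc[d] = acc[d] * correction + p * v_row[d];
    }
    running_sum = running_sum * correction + p;
    running_max = new_max;
  }

  for (int d = 0; d < head_dim; ++d) {
    out[d] = acc[d] / running_sum;
  }
}
```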
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 14 comments.
| File | Description |
|---|---|
| backends/apple/metal/runtime/shims/et_metal_ops.mm | Replaces ~400 lines of MPSGraph code with inline Metal shader source and direct kernel dispatch; adds shader library caching |
| backends/apple/metal/runtime/shims/et_metal.mm | Implements new setArg overloads for scalar types and uint3 structs; adds dispatchThreadgroups for explicit threadgroup control |
| backends/apple/metal/runtime/shims/et_metal.h | Declares new Metal kernel function methods for argument setting and threadgroup dispatch |
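As a rough illustration of the kind of interface the table describes for et_metal.h, a declaration-level sketch follows. The class, struct, and method names here are assumptions based on the descriptions above, not the actual headers.

```cpp
#include <cstdint>

// Hypothetical sketch only; the real names in the PR may differ.
struct ETUint3 {
  uint32_t x, y, z;
};

class MetalKernelFunctionSketch {
 public:
  // Scalar argument overloads: each binds a small value at a buffer index,
  // typically implemented via setBytes on the compute command encoder.
  void setArg(unsigned index, uint32_t value);
  void setArg(unsigned index, float value);
  void setArg(unsigned index, bool value);
  void setArg(unsigned index, ETUint3 value);

  // Explicit dispatch: launch `groups` threadgroups, each with
  // `threads_per_group` threads, instead of deriving the grid from a
  // single global size.
  void dispatchThreadgroups(ETUint3 groups, ETUint3 threads_per_group);
};
```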
mergennachin left a comment:
- On `int32_t is_causal,`: Are you ignoring "is_causal" altogether?
- Can we compile the Metal shader at build time? Isn't it JIT compiling?
if (head_dim != 64 && head_dim != 96 && head_dim != 128) {
  ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported head_dim %lld (must be 64, 96, or 128)", head_dim);
  throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel - must be exactly 64, 96, or 128");
}
What's the main reason for limiting it to only these head sizes?
The head size is a compile-time constant. See the template instantiation lines:
#define INSTANTIATE_SDPA_VECTOR_HEADS(DTYPE) \
INSTANTIATE_SDPA_VECTOR(DTYPE, 64, 64); \
INSTANTIATE_SDPA_VECTOR(DTYPE, 96, 96); \
INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128);
This is also how MLX does it, and I am pretty sure it is done for performance reasons (loop unrolling, register optimization, etc).
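To make the compile-time argument concrete, here is a small standalone sketch (not the actual shader code) of why templating on the head dimension matters: with the trip count known at compile time, the inner dot-product loop can be fully unrolled and the accumulator kept in registers, and the explicit instantiations play the same role as the macro above.

```cpp
// Illustrative only: HEAD_DIM is a template (compile-time) constant,
// so the compiler can fully unroll the loop.
template <typename T, int HEAD_DIM>
float dot_q_k(const T* q, const T* k) {
  float acc = 0.0f;
#pragma unroll
  for (int d = 0; d < HEAD_DIM; ++d) {
    acc += static_cast<float>(q[d]) * static_cast<float>(k[d]);
  }
  return acc;
}

// Explicit instantiations, analogous to INSTANTIATE_SDPA_VECTOR_HEADS above:
// only these head sizes get compiled, which is why the runtime check rejects
// anything other than 64, 96, or 128.
template float dot_q_k<float, 64>(const float*, const float*);
template float dot_q_k<float, 96>(const float*, const float*);
template float dot_q_k<float, 128>(const float*, const float*);
```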
Yes, for now. I added a check, returning an error in that case.
We don't yet have that infrastructure in the Metal backend. I don't think that's a priority, though, because the AOTI-generated kernels are also JIT compiled, and the amount of code we are JIT compiling there is much larger than what we compile for the single SDPA kernel. So I don't think it would make any difference right now.
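For context on the JIT point, runtime compilation with shader-library caching (as described for et_metal_ops.mm in the file table above) generally follows the pattern below. This is a minimal Objective-C++ sketch using the standard Metal API; the helper name is illustrative, not the actual function in the PR.

```objc
#import <Metal/Metal.h>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Illustrative helper: compile a shader source string once at runtime,
// then serve subsequent requests from a cache.
static id<MTLLibrary> getOrCompileLibrary(
    id<MTLDevice> device, const std::string& source) {
  static std::unordered_map<std::string, id<MTLLibrary>> cache;
  auto it = cache.find(source);
  if (it != cache.end()) {
    return it->second;  // already JIT-compiled earlier in this process
  }
  NSError* error = nil;
  id<MTLLibrary> lib =
      [device newLibraryWithSource:[NSString stringWithUTF8String:source.c_str()]
                           options:nil
                             error:&error];
  if (lib == nil) {
    throw std::runtime_error("Metal shader JIT compilation failed");
  }
  cache[source] = lib;
  return lib;
}
```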
Replaces the MPSGraph-based SDPA implementation with a custom Metal kernel implementation (adapted from the MLX implementation, with several modifications to support transposed middle dimensions and floating-point attention masks).
Speeds up Voxtral/Whisper by 2-3x.
Fixes a BFloat16 issue on macOS 26.1.