
Conversation

@manuelcandales (Contributor) commented Dec 4, 2025

Replaces the MPSGraph-based SDPA implementation with a Metal kernel implementation (adapted from the MLX implementation, with several modifications to support transposed middle dimensions and floating-point attention masks).

Speeds up Voxtral/Whisper by 2-3x.

Fixes a BFloat16 issue on macOS 26.1.

[ghstack-poisoned]
@manuelcandales (Contributor, Author) commented Dec 4, 2025

Stack from ghstack (oldest at bottom):

Copilot AI review requested due to automatic review settings December 4, 2025 21:11
@pytorch-bot (bot) commented Dec 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/16086

Note: Links to docs will display an error until the docs builds have been completed.

⏳ No Failures, 44 Pending

As of commit 7bf5730 with merge base c00d726:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

manuelcandales added a commit that referenced this pull request Dec 4, 2025
ghstack-source-id: aa77e4b
ghstack-comment-id: 3614336034
Pull-Request: #16086
@meta-cla bot added the "CLA Signed" label Dec 4, 2025
@github-actions (bot) commented Dec 4, 2025

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

Copilot AI (Contributor) left a comment

Pull request overview

This pull request replaces the MPSGraph-based implementation of Scaled Dot Product Attention (SDPA) with a custom Metal kernel implementation, ported from PyTorch and influenced by MLX.

Key Changes

  • Custom Metal kernel: Implements a one-pass SDPA algorithm embedded as a 200+ line inline shader, with template instantiations for float, half, and bfloat types across head dimensions 64, 96, and 128
  • Enhanced Metal API: Adds new setArg overloads for uint32_t, float, bool, and uint3 types, plus a new dispatchThreadgroups method for explicit threadgroup dispatch (see the sketch after this list)
  • Stride-aware computation: The new kernel handles transposed tensor layouts by decomposing batch and head indices and using explicit stride information
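
A minimal sketch of what these API additions might look like. Only the method names (setArg, dispatchThreadgroups) and the argument types (uint32_t, float, bool, uint3) come from this PR; the class name and exact signatures below are assumptions for illustration.

#include <cstdint>

// Simd-style {x, y, z} triple, as used for threadgroup dimensions
struct uint3 { uint32_t x, y, z; };

class MetalKernelFunction {
 public:
  // New scalar overloads, alongside the existing buffer-based setArg
  void setArg(unsigned index, uint32_t value);
  void setArg(unsigned index, float value);
  void setArg(unsigned index, bool value);
  void setArg(unsigned index, const uint3& value);

  // Explicit dispatch: the caller chooses both the grid of threadgroups
  // and the threads per threadgroup, instead of a flat thread count
  void dispatchThreadgroups(const uint3& threadgroupsPerGrid,
                            const uint3& threadsPerThreadgroup);
};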

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 14 comments.

File Description
backends/apple/metal/runtime/shims/et_metal_ops.mm — Replaces ~400 lines of MPSGraph code with inline Metal shader source and direct kernel dispatch; adds shader library caching (see the sketch after this table)
backends/apple/metal/runtime/shims/et_metal.mm — Implements the new setArg overloads for scalar types and uint3 structs; adds dispatchThreadgroups for explicit threadgroup control
backends/apple/metal/runtime/shims/et_metal.h — Declares the new Metal kernel function methods for argument setting and threadgroup dispatch
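
For the shader-library caching noted in the et_metal_ops.mm row, a sketch along these lines would fit. Every name here (get_sdpa_library, SDPA_SHADER_SOURCE) is assumed; the review only states that caching was added.

// Compile the inline SDPA shader source once per process and reuse the
// resulting MTLLibrary on subsequent calls (hypothetical sketch).
static id<MTLLibrary> get_sdpa_library(id<MTLDevice> device) {
  static id<MTLLibrary> library = nil;
  if (library == nil) {
    NSError* error = nil;
    NSString* source = [NSString stringWithUTF8String:SDPA_SHADER_SOURCE];
    library = [device newLibraryWithSource:source options:nil error:&error];
    if (library == nil) {
      ET_LOG(Error, "Failed to compile SDPA shader: %s",
             error.localizedDescription.UTF8String);
    }
  }
  return library;
}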


@mergennachin (Contributor) left a comment

Comment on lines +1315 to +1318
if (head_dim != 64 && head_dim != 96 && head_dim != 128) {
ET_LOG(Error, "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: Unsupported head_dim %lld (must be 64, 96, or 128)", head_dim);
throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel - must be exactly 64, 96, or 128");
}

What's the main reason for only limiting to these head sizes?

@manuelcandales (Contributor, Author) replied:

The head size is a compile-time constant. See the template instantiation lines:

#define INSTANTIATE_SDPA_VECTOR_HEADS(DTYPE)        \
  INSTANTIATE_SDPA_VECTOR(DTYPE, 64, 64);           \
  INSTANTIATE_SDPA_VECTOR(DTYPE, 96, 96);           \
  INSTANTIATE_SDPA_VECTOR(DTYPE, 128, 128);

This is also how MLX does it, and I am pretty sure it is done for performance reasons (loop unrolling, register optimization, etc.).
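
On the host side, kernel selection then reduces to mapping a dtype and head_dim onto one of the pre-instantiated kernel names. A hypothetical sketch follows; the naming scheme and the helper are illustrative, not taken from the PR.

#include <cstdint>
#include <stdexcept>
#include <string>

std::string sdpa_kernel_name(const std::string& dtype, int64_t head_dim) {
  // Only the head dims instantiated by the macro above exist in the library
  if (head_dim != 64 && head_dim != 96 && head_dim != 128) {
    throw std::runtime_error("Unsupported head_dim for Metal SDPA kernel");
  }
  // Naming scheme assumed for illustration
  return "sdpa_vector_" + dtype + "_" + std::to_string(head_dim);
}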

@pytorch deleted 5 comments from Copilot AI Dec 8, 2025
[ghstack-poisoned]
manuelcandales added a commit that referenced this pull request Dec 8, 2025
ghstack-source-id: c614a71
ghstack-comment-id: 3614336034
Pull-Request: #16086
@manuelcandales (Contributor, Author) commented:
  • Are you ignoring "is_causal" altogether?

Yes, for now. I added a check that returns Error::NotImplemented when is_causal is true.
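
A minimal sketch of that guard, using the ET_LOG call style and function name visible in the quoted snippet above; the exact placement in et_metal_ops.mm is assumed.

// Reject causal attention until the Metal kernel supports it
if (is_causal) {
  ET_LOG(Error,
         "aoti_torch_mps__scaled_dot_product_attention_math_for_mps: "
         "is_causal=true is not yet supported by the Metal SDPA kernel");
  return Error::NotImplemented;
}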

@manuelcandales (Contributor, Author) commented:

  • Can we compile the Metal shader at build time? Isn't it JIT compiling?

We don't yet have that infrastructure in the Metal backend. I don't think it's a priority, though: AOTI-generated kernels are also JIT-compiled, and the amount of code we JIT-compile there is much larger than the single SDPA kernel. So I don't think it would make any difference right now.

@manuelcandales manuelcandales merged commit e793135 into main Dec 8, 2025
158 of 159 checks passed
@manuelcandales manuelcandales deleted the gh/manuelcandales/150/head branch December 8, 2025 21:03