Migrate Transformer functionality to custom layers and add Rotary Embedding support #78
Conversation
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

@@           Coverage Diff            @@
##             main      #78      +/-  ##
=========================================
- Coverage   96.71%   94.15%    -2.56%
=========================================
  Files          24       27        +3
  Lines        1004     1352      +348
=========================================
+ Hits          971     1273      +302
- Misses         33       79       +46

View full report in Codecov by Sentry.
Pull request overview
This pull request migrates transformer functionality from PyTorch's default nn.Transformer[En|De]coder to custom implementations in depthcharge.transformers.layers and adds rotary position embedding (RoPE) support. The custom layers maintain API compatibility and numerical equivalence with PyTorch implementations while enabling advanced features like FlashAttention via SDPA, rotary embeddings, and flexible attention masking. Comprehensive unit tests verify numerical equivalence between the custom and PyTorch implementations for both forward and backward passes.
Key changes:
- Custom `TransformerEncoder`, `TransformerDecoder`, and their layer components with switchable attention backends (`"sdpa"` or `"native"`); a minimal sketch of the backend switch follows this list
- New `RotaryEmbedding` encoder that applies rotary positional embeddings to query and key tensors
- Integration of rotary embedding support into `SpectrumTransformerEncoder`, `AnalyteTransformerEncoder`, and `AnalyteTransformerDecoder` via optional parameters
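For orientation, here is a minimal sketch of what switching between an SDPA-backed and a "native" attention computation amounts to conceptually. The function name and signature are illustrative and not the actual API in `depthcharge.transformers.attn`:

```python
import torch
import torch.nn.functional as F


def attention(q, k, v, attn_mask=None, backend="sdpa"):
    """Illustrative attention dispatch; not the actual depthcharge API."""
    if backend == "sdpa":
        # Delegate to PyTorch's fused kernels (FlashAttention / memory-efficient
        # attention are used automatically when the inputs allow it).
        return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
    # "native": explicit softmax attention, handy for debugging and for
    # numerical-equivalence checks against the fused path.
    scores = q @ k.transpose(-2, -1) / (q.size(-1) ** 0.5)
    if attn_mask is not None:
        scores = scores + attn_mask  # float mask: -inf where attention is disallowed
    return torch.softmax(scores, dim=-1) @ v
```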
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| `depthcharge/encoders/rotary.py` | New rotary position embedding implementation using RoFormer/RoPE methodology |
| `depthcharge/encoders/__init__.py` | Exports `RotaryEmbedding` class |
| `depthcharge/transformers/attn.py` | Custom multi-head attention using SDPA with rotary embedding support and configurable kernel backends |
| `depthcharge/transformers/layers.py` | Custom transformer encoder/decoder layers and stacks with numerical equivalence to PyTorch defaults |
| `depthcharge/transformers/spectra.py` | Updates `SpectrumTransformerEncoder` to use custom layers and support rotary embeddings via m/z-based positions |
| `depthcharge/transformers/analytes.py` | Updates analyte transformers to use custom layers with rotary embedding support; adds unused helper methods |
| `depthcharge/transformers/__init__.py` | Exports custom transformer components |
| `depthcharge/utils.py` | Adds `combine_key_pad_and_attn` utility for merging float and binary attention masks |
| `tests/unit_tests/test_transformers/test_custom_layers.py` | Comprehensive numerical equivalence tests for custom vs PyTorch implementations |
| `tests/unit_tests/test_transformers/test_rotary_embedding.py` | Tests for rotary embedding behavior including shift invariance and integration with transformers |
| `docs/api/encoders.md` | Adds documentation reference for `RotaryEmbedding` |
@gdewael As a first pass, can you resolve the Ruff linting issue and address these Copilot comments? Generally, only small changes are needed.
The linting issue is annoying: the GitHub Actions workflow pulls the most recent ruff version (
Ah, that's annoying - I'll update the hooks.
wfondrie left a comment:
The code looks great and is well tested, but I think we need to discuss a few conceptual things.
Implementation thoughts:
One thought I've had is that SpectrumTransformerEncoder and AnalyteTransformerEncoder should just be wrappers around any arbitrary Transformer layer - the default could be the vanilla PyTorch implementation.
There are two possible implementations that I can imagine:
- The wrapper has a signature that takes the initialized layer modules.
- The wrapper has a signature that takes the uninitialized layer class and `**kwargs`, which are forwarded to the specified class.
I could see either of these working 🤔
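For concreteness, a rough sketch of the two wrapper signatures under discussion; the class names, defaults, and use of standard PyTorch layers here are illustrative only, not a proposed depthcharge API:

```python
import copy

import torch.nn as nn


# Option 1: the wrapper takes an already-initialized layer module and clones it.
class SpectrumEncoderWrapperA(nn.Module):
    def __init__(self, transformer_layer: nn.Module, n_layers: int = 6):
        super().__init__()
        self.layers = nn.ModuleList(
            copy.deepcopy(transformer_layer) for _ in range(n_layers)
        )


# Option 2: the wrapper takes the uninitialized layer class and **kwargs,
# which are forwarded to that class.
class SpectrumEncoderWrapperB(nn.Module):
    def __init__(
        self, layer_cls=nn.TransformerEncoderLayer, n_layers: int = 6, **kwargs
    ):
        super().__init__()
        self.layers = nn.ModuleList(layer_cls(**kwargs) for _ in range(n_layers))
```

Option 1 keeps the wrapper agnostic to the layer's hyperparameters, while option 2 keeps a flat constructor signature at the cost of forwarding opaque `**kwargs`.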
Requests:
Before merging this PR, I would love to see some data on model performance using the rotary embeddings / custom layers vs. vanilla PyTorch transformer layers. I would also like to see a speed comparison for training and inference. Thanks!
```python
if src_mask is not None:
    assert (
        (list(src_mask.size()) == [src.size(1), src.size(1)])
        or (
            list(src_mask.size())
            == [src.size(0), src.size(1), src.size(1)]
        )
        or (
            list(src_mask.size())
            == [src.size(0), self.nhead, src.size(1), src.size(1)]
        )
    ), (
        "`src_mask` should have size (seq_len, seq_len), "
        "(batch_size, seq_len, seq_len), or "
        "(batch_size, nhead, seq_len, seq_len)"
    )
    assert src_mask.dtype == src.dtype, (
        "`src_mask` should have same dtype as `src`"
    )
if src_key_padding_mask is not None:
    assert list(src_key_padding_mask.size()) == [
        src.size(0),
        src.size(1),
    ], "`src_key_padding_mask` should have size (batch_size, seq_len)"
    assert src_key_padding_mask.dtype == torch.bool, (
        "`src_key_padding_mask` should be bool"
    )
```
question: Are these size assertions necessary? Won't an error be raised anyway if they are the wrong size?
I'd recommend a try...except pattern instead.
The asserts are resolved as part of 0dd98df (changed into ValueErrors and TypeErrors).
The size checks are not strictly necessary, since the wrong mask sizes already raise an error in `F.scaled_dot_product_attention`, but I find that error uninformative for end users: `RuntimeError: The size of tensor a (...) must match the size of tensor b (...) at non-singleton dimension ...`.
Seeing as `torch.nn.Transformer[En|De]coder` also performs similar checks and raises errors, e.g. `RuntimeError: The shape of the 2D attn_mask is torch.Size([65, 65]), but should be (64, 64).`, I would prefer to keep them in.
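For illustration, the kind of explicit check described here might look roughly like the following; the helper name and exact error messages are placeholders, not necessarily what 0dd98df contains:

```python
import torch


def _check_src_mask(src_mask: torch.Tensor, src: torch.Tensor, nhead: int) -> None:
    """Illustrative shape/dtype validation that raises informative errors."""
    batch_size, seq_len = src.size(0), src.size(1)
    allowed_shapes = (
        [seq_len, seq_len],
        [batch_size, seq_len, seq_len],
        [batch_size, nhead, seq_len, seq_len],
    )
    if list(src_mask.size()) not in allowed_shapes:
        raise ValueError(
            "`src_mask` should have size (seq_len, seq_len), "
            "(batch_size, seq_len, seq_len), or "
            "(batch_size, nhead, seq_len, seq_len)"
        )
    if src_mask.dtype != src.dtype:
        raise TypeError("`src_mask` should have the same dtype as `src`")
```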
About wrapping arbitrary `Transformer[En|De]coder` layer objects: I think this change would be easy to implement elegantly. Some extra thoughts/reservations about this change:

torch vs. proposed layer API
Model performance (predictive and speed)

I ran a test using metabolomics MS/MS (MassSpecGym) molecular retrieval (a transformer encodes the spectrum + contrastive loss on similarity with embedded molecular fingerprints). I'll give you the results first and provide experiment details at the bottom.

Training loss

Orange - current. These results are expected because the unit tests ensure forward passes and gradients are equal. The remaining differences are down to initialization and batch stochasticity.

Validation performance (hit rate @ 20)

These should not be taken too seriously, as I intentionally chose a ridiculously large model size in order to accentuate run-time speed differences, so there is a lot of instability, but they indicate that the implementation does not break the models.

Orange - current.

Runtime (4 epochs on an RTX 3090)

A note on runtimes:

Implementation details: This test uses a quickly spun-up training script building on some research code (apologies for the state of it). It uses my own dataloader to load MassSpecGym and couples it with depthcharge to encode spectra. It was run on an RTX 3090 using

Reproduction steps

First, download test_script.py. Then:

Then, run the depthcharge main version:

Then, run the PR version:

Then, uncomment RotaryEmbedding from line 139 in

Note that this does not separately test training/inference, but it is sort of implicitly in there because of the validation epochs.

Extra changes

I included a manual-trigger GitHub Actions workflow that tests out different PyTorch versions. My proposed scaled_dot_product Transformer Layer variant is only supported from


This pull request achieves two things:
(1) All Transformer functionality is moved from the default `nn.Transformer[En|De]coder` to custom layers in `depthcharge.transformers.layers` and `depthcharge.transformers.attn`. I have kept the API and functionality of these custom layers as close to the PyTorch ones as possible. Importantly, numerical equivalence (both forward and grads) is ensured via `tests/unit_tests/test_transformers/test_custom_layers.py`. The custom layers open the door to various additions which are currently missing from the PyTorch library. I'll name a few here:
- `attn_mask` when `key_padding_mask` is not passed (see here).

(I can go on)
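For reference, a minimal sketch of how a boolean key padding mask and a float attention mask can be merged into a single additive mask suitable for `F.scaled_dot_product_attention`; the actual `combine_key_pad_and_attn` utility in `depthcharge/utils.py` may have a different signature and behavior:

```python
import torch


def combine_masks(attn_mask, key_padding_mask, dtype=torch.float32):
    """Merge a float attn_mask and a boolean key_padding_mask into one additive mask.

    attn_mask: float mask of shape (seq_len, seq_len), or None.
    key_padding_mask: bool mask of shape (batch_size, seq_len), True = padded, or None.
    """
    merged = None
    if key_padding_mask is not None:
        # Shape (batch_size, 1, 1, seq_len) so it broadcasts over heads and queries.
        merged = torch.zeros_like(key_padding_mask, dtype=dtype)
        merged = merged.masked_fill(key_padding_mask, float("-inf"))[:, None, None, :]
    if attn_mask is not None:
        merged = attn_mask if merged is None else merged + attn_mask
    return merged
```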
(2) To illustrate, I have added rotary embedding functionality to the `SpectrumTransformerEncoder`, `AnalyteTransformerEncoder`, and `AnalyteTransformerDecoder` as an optional flag.

Adding both of these in a single pull request is perhaps a bit much. Please let me know if it would be preferable to split them up.
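For background, rotary embeddings rotate each pair of channels in the query and key vectors by an angle proportional to the position, so the q/k dot products depend only on relative positions. A minimal sketch of the standard RoFormer-style rotation (not the exact `RotaryEmbedding` implementation in this PR):

```python
import torch


def apply_rope(x, positions, base=10000.0):
    """Rotate channel pairs of x (..., seq_len, dim) by position-dependent angles.

    positions: float tensor of shape (seq_len,), e.g. token indices.
    """
    dim = x.size(-1)
    inv_freq = base ** (-torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    angles = positions[..., None] * inv_freq  # (..., seq_len, dim // 2)
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., 0::2], x[..., 1::2]       # even/odd channels form rotation pairs
    return torch.stack(
        (x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1
    ).flatten(-2)                              # interleave rotated pairs back

# Typical use: q, k = apply_rope(q, pos), apply_rope(k, pos) before attention.
```

Because only angle differences survive the q/k dot product, shifting all positions by a constant leaves the attention scores unchanged, which is the property the shift-invariance test in `tests/unit_tests/test_transformers/test_rotary_embedding.py` exercises; the file summary above indicates the spectrum encoder derives these positions from m/z values.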
Important note: I have currently only verified that the `Transformer[En|De]coder` numerical equivalence tests pass on PyTorch version 2.9.1 (current stable). I believe it would be better to automate this process across a number of PyTorch versions. The torch version specifier in `pyproject.toml` could then be adjusted based on where it starts failing. I have little experience with CI, so I'm open to suggestions and pointers for how to integrate this.