
Conversation

@gdewael (Contributor) commented Dec 12, 2025

This pull request achieves two things:

(1) All Transformer functionality is moved from the default nn.Transformer[En|De]coder to custom layers in depthcharge.transformers.layers and depthcharge.transformers.attn. I have kept the API and functionality of these custom layers as close to the PyTorch ones as possible. Importantly, numerical equivalence (both forward passes and gradients) is ensured via tests/unit_tests/test_transformers/test_custom_layers.py. The custom layers open the door to various additions which are currently missing from the PyTorch library. I'll name a few here:

  • switchable attention backends via F.scaled_dot_product_attention (SDPA), opening the door to FlashAttention kernels;
  • rotary positional embeddings applied to the query and key tensors;
  • more flexible attention masking, e.g. combining a floating-point attn_mask with a binary key_padding_mask.

(I can go on.)
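
To show the shape of the equivalence checks mentioned above, here is a minimal sketch of a forward/gradient comparison between two layers that share the same weights. The PyTorch layer stands in for the custom depthcharge layer so that the sketch runs as-is; the real tests live in tests/unit_tests/test_transformers/test_custom_layers.py.

import torch
import torch.nn as nn

torch.manual_seed(0)

ref = nn.TransformerEncoderLayer(
    d_model=64, nhead=4, dim_feedforward=128, dropout=0.0, batch_first=True
)
# In the real tests, `custom` is the layer from depthcharge.transformers.layers;
# the PyTorch layer is reused here only to keep the sketch self-contained.
custom = nn.TransformerEncoderLayer(
    d_model=64, nhead=4, dim_feedforward=128, dropout=0.0, batch_first=True
)
custom.load_state_dict(ref.state_dict())  # identical weights so outputs are comparable

x = torch.randn(2, 10, 64, requires_grad=True)
y_ref, y_custom = ref(x), custom(x)
assert torch.allclose(y_ref, y_custom, atol=1e-6)  # forward equivalence

y_ref.sum().backward()
grad_ref = x.grad.clone()
x.grad = None
y_custom.sum().backward()
assert torch.allclose(grad_ref, x.grad, atol=1e-6)  # gradient equivalence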

(2) To illustrate, I have added rotary embedding functionality to the SpectrumTransformerEncoder, AnalyteTransformerEncoder, and AnalyteTransformerDecoder as an optional flag.
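
For illustration, a hedged usage sketch of the new flag; the constructor arguments and the rotary keyword name below are assumptions, since this PR only states that rotary embeddings are exposed as an optional flag on the encoder/decoder classes.

import torch
from depthcharge.transformers import SpectrumTransformerEncoder

# Sketch only: the argument names mirror the usual depthcharge encoder signature,
# and the rotary keyword is a hypothetical placeholder for whatever this PR
# calls the optional flag.
model = SpectrumTransformerEncoder(
    d_model=128,
    nhead=8,
    n_layers=2,
    # use_rotary_embeddings=True,  # hypothetical name of the optional flag
)

mz = torch.rand(4, 50) * 2000   # (batch, n_peaks) m/z values
intensity = torch.rand(4, 50)   # matching intensities
embeddings, mask = model(mz, intensity)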

Adding both of these in a single pull request is perhaps a bit much. Please let me know if it would be preferable to split them up.

Important note: I have currently only verified that the Transformer[En|De]coder numerical equivalence tests pass on PyTorch 2.9.1 (the current stable release). I believe it would be better to automate this check across a number of PyTorch versions; the torch version specifier in pyproject.toml could then be adjusted based on where it starts failing. I have little experience with CI, so I'm open to suggestions and pointers on how to integrate this.

codecov bot commented Dec 12, 2025

Codecov Report

❌ Patch coverage is 87.57062% with 44 lines in your changes missing coverage. Please review.
✅ Project coverage is 94.15%. Comparing base (3938944) to head (fb72146).
⚠️ Report is 4 commits behind head on main.

Files with missing lines              Patch %   Missing lines
depthcharge/transformers/layers.py    86.85%    28
depthcharge/transformers/attn.py      85.33%    11
depthcharge/utils.py                  71.42%     4
depthcharge/transformers/spectra.py   85.71%     1
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #78      +/-   ##
==========================================
- Coverage   96.71%   94.15%   -2.56%     
==========================================
  Files          24       27       +3     
  Lines        1004     1352     +348     
==========================================
+ Hits          971     1273     +302     
- Misses         33       79      +46     


Copilot AI (Contributor) left a comment

Pull request overview

This pull request migrates transformer functionality from PyTorch's default nn.Transformer[En|De]coder to custom implementations in depthcharge.transformers.layers and adds rotary position embedding (RoPE) support. The custom layers maintain API compatibility and numerical equivalence with PyTorch implementations while enabling advanced features like FlashAttention via SDPA, rotary embeddings, and flexible attention masking. Comprehensive unit tests verify numerical equivalence between the custom and PyTorch implementations for both forward and backward passes.
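
For background, a minimal standalone sketch of what a RoFormer/RoPE-style rotation does to a query or key tensor; this illustrates the technique itself, not the depthcharge implementation.

import torch

def apply_rope(x: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
    """Rotate consecutive feature pairs of x by a position-dependent angle.

    x: (batch, seq, dim) query or key tensor; dim must be even.
    positions: (batch, seq) positions (token indices, or e.g. m/z values for spectra).
    """
    dim = x.shape[-1]
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=x.dtype) / dim))
    angles = positions[..., None] * inv_freq        # (batch, seq, dim / 2)
    cos, sin = angles.cos(), angles.sin()
    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = torch.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin     # standard 2D rotation of
    out[..., 1::2] = x_even * sin + x_odd * cos     # each (even, odd) pair
    return out

q = torch.randn(2, 10, 64)
pos = torch.arange(10, dtype=torch.float32).expand(2, 10)
q_rot = apply_rope(q, pos)  # the same rotation is applied to the keys before attention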

Key changes:

  • Custom TransformerEncoder, TransformerDecoder, and their layer components with switchable attention backends ("sdpa" or "native")
  • New RotaryEmbedding encoder that applies rotary positional embeddings to query and key tensors
  • Integration of rotary embedding support into SpectrumTransformerEncoder, AnalyteTransformerEncoder, and AnalyteTransformerDecoder via optional parameters

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 9 comments.

Summary per file:

  • depthcharge/encoders/rotary.py: new rotary position embedding implementation using the RoFormer/RoPE methodology
  • depthcharge/encoders/__init__.py: exports the RotaryEmbedding class
  • depthcharge/transformers/attn.py: custom multi-head attention using SDPA with rotary embedding support and configurable kernel backends
  • depthcharge/transformers/layers.py: custom transformer encoder/decoder layers and stacks, numerically equivalent to the PyTorch defaults
  • depthcharge/transformers/spectra.py: updates SpectrumTransformerEncoder to use the custom layers and support rotary embeddings via m/z-based positions
  • depthcharge/transformers/analytes.py: updates the analyte transformers to use the custom layers with rotary embedding support; adds unused helper methods
  • depthcharge/transformers/__init__.py: exports the custom transformer components
  • depthcharge/utils.py: adds the combine_key_pad_and_attn utility for merging float and binary attention masks
  • tests/unit_tests/test_transformers/test_custom_layers.py: comprehensive numerical equivalence tests for the custom vs. PyTorch implementations
  • tests/unit_tests/test_transformers/test_rotary_embedding.py: tests for rotary embedding behavior, including shift invariance and integration with the transformers
  • docs/api/encoders.md: adds a documentation reference for RotaryEmbedding


@bittremieux (Collaborator) commented:

@gdewael As a first pass, can you resolve the Ruff linting issue and address these Copilot comments? Generally only some small changes needed.

gdewael and others added 3 commits December 12, 2025 14:20
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
@gdewael (Contributor, Author) commented Dec 12, 2025

The linting issue is annoying: the GitHub Actions workflow pulls the most recent Ruff version (ruff-0.14.9), but I was using the pre-commit hooks, as suggested in the depthcharge docs, which currently use ruff v0.3.1 (see here). I'm not sure if this pull request is the place to make the GitHub Actions workflow file compatible with the .pre-commit-config.yaml file, but I wanted to document the issue.

@wfondrie (Owner) commented:

Ah that's annoying - I'll update the hooks.

@wfondrie (Owner) left a comment

The code looks great and is well tested, but I think we need to discuss a few conceptual things.

Implementation thoughts:

One thought I've had is that SpectrumTransformerEncoder and AnalyteTransformerEncoder should just be wrappers around any arbitrary Transformer layer - the default could be the vanilla PyTorch implementation.

There are two possible implementations that I can imagine:

  1. The wrapper has a signature that takes the initialized layer modules.
  2. The wrapper has a signature that takes the uninitialized layer class and has **kwargs, which are forwarded to the specified class.

I could see either of these working 🤔
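
A rough sketch of the two signatures, with hypothetical class names and the vanilla PyTorch layer as the default, just to make the distinction concrete:

import copy
import torch.nn as nn

# Option 1: the wrapper receives an already-initialized layer module and
# deep-copies it into a stack (this mirrors what nn.TransformerEncoder does).
class EncoderWrapperOption1(nn.Module):  # hypothetical name
    def __init__(self, layer: nn.Module, n_layers: int):
        super().__init__()
        self.layers = nn.ModuleList(
            [copy.deepcopy(layer) for _ in range(n_layers)]
        )

# Option 2: the wrapper receives the uninitialized layer class and forwards
# **kwargs to it, defaulting to the vanilla PyTorch layer.
class EncoderWrapperOption2(nn.Module):  # hypothetical name
    def __init__(
        self,
        layer_cls: type = nn.TransformerEncoderLayer,
        n_layers: int = 2,
        **kwargs,
    ):
        super().__init__()
        self.layers = nn.ModuleList([layer_cls(**kwargs) for _ in range(n_layers)])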

Requests:

Before merging this PR, I would love to see some data on model performance using the rotary embeddings / custom layers vs. the vanilla PyTorch transformer layers. I would also like to see a speed comparison for training and inference. Thanks!

Comment on lines 176 to 202
if src_mask is not None:
    assert (
        (list(src_mask.size()) == [src.size(1), src.size(1)])
        or (
            list(src_mask.size())
            == [src.size(0), src.size(1), src.size(1)]
        )
        or (
            list(src_mask.size())
            == [src.size(0), self.nhead, src.size(1), src.size(1)]
        )
    ), (
        "`src_mask` should have size (seq_len, seq_len), "
        "(batch_size, seq_len, seq_len), or "
        "(batch_size, nhead, seq_len, seq_len)"
    )
    assert src_mask.dtype == src.dtype, (
        "`src_mask` should have same dtype as `src`"
    )
if src_key_padding_mask is not None:
    assert list(src_key_padding_mask.size()) == [
        src.size(0),
        src.size(1),
    ], "`src_key_padding_mask` should have size (batch_size, seq_len)"
    assert src_key_padding_mask.dtype == torch.bool, (
        "`src_key_padding_mask` should be bool"
    )
@wfondrie (Owner):

question: Are these size assertions necessary? Won't an error be raised anyway if they are the wrong size?

@wfondrie (Owner):

I'd recommend a try...except pattern instead.

@gdewael (Contributor, Author):

The asserts are resolved as part of 0dd98df (changed into ValueErrors and TypeErrors).

The size checks are not strictly necessary: wrong mask sizes do raise an error in F.scaled_dot_product_attention, but I find those errors not very informative for end users: RuntimeError: The size of tensor a (...) must match the size of tensor b (...) at non-singleton dimension ....
Seeing as torch.nn.Transformer[En|De]coder also performs similar checks to raise informative errors, e.g. RuntimeError: The shape of the 2D attn_mask is torch.Size([65, 65]), but should be (64, 64)., I would prefer to keep them in.
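
For reference, a minimal sketch of the kind of check that remains after that change; the exact messages and helper name here are illustrative, not the code from 0dd98df.

import torch

def _check_src_key_padding_mask(src: torch.Tensor, mask: torch.Tensor) -> None:
    # Illustrative shape/dtype validation that raises informative errors instead
    # of relying on the opaque RuntimeError from scaled_dot_product_attention.
    if list(mask.size()) != [src.size(0), src.size(1)]:
        raise ValueError(
            "`src_key_padding_mask` should have size (batch_size, seq_len), "
            f"got {tuple(mask.size())}"
        )
    if mask.dtype != torch.bool:
        raise TypeError("`src_key_padding_mask` should be a bool tensor")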

@gdewael (Contributor, Author) commented Dec 22, 2025

About wrapping arbitrary Transformer[En|De]coder layer objects:

I think this change would be easy to elegantly implement. Some extra thoughts/reservations about this change:

  • Currently, the forward pass documentation of all Transformer objects describes which masks/input arguments are expected and supported. I will note there are already some slight differences with the default torch.nn.Transformer[En|De]coder APIs (listed in the drop-down below). If a custom layer is passed, instead of providing detailed documentation on the expected inputs, we would have to link to the respective documentation of those layers. This decreases accessibility and usability a bit for end users, who would have to start "digging" more to find out what they should pass to the [Spectrum|Analyte]Transformers. An alternative would be to make sure the API is completely consistent between both types of layers, but that would constitute maintenance/upkeep hell.
torch vs proposed layer API
  • attn_mask: the torch version expects (b*nh, s1, s2) or (s1, s2); the proposed implementation expects (b, nh, s1, s2), (b, s1, s2), or (s1, s2). (This is what scaled_dot_product_attention expects, and I personally find the collapsed batch-size and number-of-heads dimension confusing.)
  • The proposed implementation assumes attn_mask is float (instead of flexibly float or binary).
  • is_causal: in the torch version, this argument only provides a hint that the provided attn_mask is a causal one (so that a more efficient implementation path can be used); in the proposed implementation, passing is_causal=True actually applies a causal mask (this way, passing a floating-point attn_mask together with is_causal=True becomes possible).
  • The torch version does not support passing a floating-point attn_mask together with a binary key_padding_mask, whereas the proposed implementation does (see the sketch after this list).
  • If both layers are numerically equivalent in both forwards and gradients, the only actual problem this approach solves is that users can additionally provide their own Transformer layers (instead of the ones provided by the proposed implementation and by torch). I'm not familiar enough with the user base of depthcharge to know if this is a realistic use case.
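
To make the last point concrete, a minimal sketch of merging a boolean key padding mask into a floating-point attention mask; the combine_key_pad_and_attn helper in depthcharge/utils.py does something along these lines, but the exact signature here is illustrative.

import torch

def combine_masks(
    attn_mask: torch.Tensor, key_padding_mask: torch.Tensor
) -> torch.Tensor:
    """Illustrative merge of a float attn_mask with a bool key_padding_mask.

    attn_mask: (seq_len, seq_len) additive float mask.
    key_padding_mask: (batch_size, seq_len) bool mask, True marks padding.
    Returns an additive float mask of shape (batch_size, 1, seq_len, seq_len)
    that broadcasts over heads in F.scaled_dot_product_attention.
    """
    pad = torch.zeros_like(key_padding_mask, dtype=attn_mask.dtype)
    pad = pad.masked_fill(key_padding_mask, float("-inf"))     # (batch, seq)
    # Padding applies along the key dimension for every query and every head.
    return attn_mask[None, None, :, :] + pad[:, None, None, :]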

Model performance (predictive and speed)

I ran a test using metabolomics MS/MS (MassSpecGym) molecular retrieval (a transformer encodes the spectrum, with a contrastive loss on similarity to embedded molecular fingerprints). I'll give you the results first and provide the experiment details at the bottom.
I'll be a bit lazy here and just paste in TensorBoard graphs:

Training loss [figure]

Orange - current depthcharge main / Dark blue - PR version without Rotary / Light blue - PR version with Rotary.

These results are expected because the unit tests ensure forward passes and gradients are equal. The remaining differences come down to initialization and batch stochasticity.

Validation performance (hit rate @ 20): these results should not be taken too seriously, as I intentionally chose a ridiculously large model size to accentuate run-time speed differences, so there is a lot of instability, but they do indicate that the implementation does not break the models. [figure]

Orange - current depthcharge main / Dark blue - PR version without Rotary / Light blue - PR version with Rotary

Runtime (4 epochs on an RTX 3090)

Model                        Runtime
depthcharge main             1.842 hr
PR version without Rotary    1.838 hr
PR version with Rotary       1.931 hr

A note on runtimes: F.scaled_dot_product_attention allows FlashAttention, but it requires float16 or bfloat16 inputs (not currently supported by depthcharge, see #76) and does not allow key_padding_masks (currently passed by default in all depthcharge Transformers). The proposed transformer implementation hence permits additional performance gains in the future, given some further code changes.
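
For context, a sketch of how the FlashAttention backend could eventually be selected once half-precision inputs are supported; this assumes a CUDA GPU and torch >= 2.3, and is not part of the current PR.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustration only: FlashAttention requires (b)float16 inputs and, at the time
# of writing, no additive attention mask derived from a key padding mask.
q = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.bfloat16)

with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    out = F.scaled_dot_product_attention(q, k, v)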

Implementation details:

This test uses a quickly spun-up training script building on some research code (apologies for the state of it). It uses my own dataloader to load MassSpecGym and couples it with depthcharge to encode the spectra. It was run on an RTX 3090 using torch v2.9.1, with a 20-layer, d_model=512 transformer encoder (42.6M parameters).

Reproduction steps

First, download test_script.py. Then:

git clone https://github.com/gdewael/ms-mole.git
conda create --name "msmole" python==3.11
conda activate msmole
pip install -e ms-mole/
pip install -U torch
git clone https://github.com/gdewael/depthcharge.git
cd depthcharge
git fetch
git switch migrate_to_sdpa
pip install -e .
git switch main
cd ..

Then, run depthcharge main version:

python test_script.py

Then, run PR version:

cd depthcharge
git switch migrate_to_sdpa
cd ..
python test_script.py

Then, uncomment RotaryEmbedding from line 139 in test_script.py and run:

python test_script.py

Note that this does not separately benchmark training and inference, but inference is implicitly covered by the validation epochs.

Extra changes

I included a manually triggered GitHub Actions workflow that tests different PyTorch versions. My proposed scaled_dot_product Transformer layer variant is only supported from torch 2.3.0 onwards, due to the use of a torch.nn.attention context manager, but other than that, numerical (forward + gradient) equivalence is guaranteed for all PyTorch versions (see https://github.com/gdewael/depthcharge/actions/runs/20366769005). This new Actions workflow should make it easier to periodically test equivalence as new PyTorch versions are released.
