Changes from all commits
30 commits
2adcd0e
Added custom TransformersEncoder/DecoderLayer
gdewael Dec 2, 2025
ef47841
Add custom trf layer unit tests
gdewael Dec 2, 2025
576112e
enable dropout in unit tests
gdewael Dec 2, 2025
7d26c2a
Implement TransformerEncoder and TransformerDecoder
gdewael Dec 2, 2025
87b5cb7
Plug in the custom transformers into definitions
gdewael Dec 2, 2025
8e0c877
amend tests
gdewael Dec 2, 2025
6013a88
Custom SelfAttention module
gdewael Dec 2, 2025
d269438
migrate transformers to sdpa backend default
gdewael Dec 3, 2025
ae84878
adjust tests
gdewael Dec 3, 2025
a01fae7
changed defaults
gdewael Dec 3, 2025
eb13c30
changed defaults
gdewael Dec 3, 2025
609d671
typo
gdewael Dec 3, 2025
82e22ef
rework MHA
gdewael Dec 3, 2025
1386c1a
Amended and simplified tests
gdewael Dec 4, 2025
4c0ec2a
catching some special cases and expanding test coverage
gdewael Dec 4, 2025
9de0966
only allow nested with sdpa backend
gdewael Dec 4, 2025
5b88728
implemented rotary, integration into analytes and spectra pending
gdewael Dec 8, 2025
7bce19d
integrate to frontend
gdewael Dec 9, 2025
3fc273b
rollback nested tensor support
gdewael Dec 12, 2025
3ef9e4c
fixing a plethora of linting issues
gdewael Dec 12, 2025
70b6e7c
Added RotaryEmbedding to docs API
gdewael Dec 12, 2025
e64d50d
fix embarassing mistake
gdewael Dec 12, 2025
56bb8fc
fix copilot review and linting issues.
gdewael Dec 12, 2025
8695963
ruff format 0.14.9
gdewael Dec 12, 2025
ce0681d
Update precommit version
wfondrie Dec 16, 2025
0dd98df
exchange asserts for errors
gdewael Dec 18, 2025
20f00a5
testing pytorch versions
gdewael Dec 19, 2025
72525f0
Trigger workflow in main fork
gdewael Dec 19, 2025
a38da7a
move pytorch version testing to manual trigger workflow
gdewael Dec 19, 2025
fb72146
fix pytorch version testing workflow
gdewael Dec 19, 2025
48 changes: 48 additions & 0 deletions .github/workflows/test-pytorch-versions.yml
@@ -0,0 +1,48 @@
name: Test PyTorch Versions

on:
workflow_dispatch: # Manual trigger only

jobs:
build:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false # Don't cancel other jobs if one fails
matrix:
os: [ubuntu-latest, windows-latest, macos-latest]
torch-version: ["2.0.1", "2.1.2", "2.2.2", "2.3.1", "2.4.1", "2.5.1", "2.6.0", "2.7.1", "2.8.0", "2.9.1"]

steps:
- uses: actions/checkout@v4

- name: Install uv
uses: astral-sh/setup-uv@v5

- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: "3.12"

- name: Create constraint file for PyTorch ${{ matrix.torch-version }}
shell: bash
run: |
echo "torch==${{ matrix.torch-version }}" > constraints.txt

- name: Install dependencies with PyTorch constraint
shell: bash
run: |
uv pip install --system -e . --group dev \
--constraint constraints.txt \
--extra-index-url https://download.pytorch.org/whl/cpu

- name: Verify PyTorch version
shell: bash
run: |
python -c "import torch; print(f'PyTorch version: {torch.__version__}')"

- name: Run transformer tests
shell: bash
run: |
uv run --no-sync pytest \
--verbose \
tests/unit_tests/test_transformers/
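For reference, the matrix above can be reproduced for a single PyTorch version on a local machine. The following is a minimal sketch of the same three steps (constraint file, constrained install, targeted tests); the pinned version is one arbitrary entry from the matrix, and nothing here is part of the depthcharge API. As on the runner, `--system` installs into the active Python environment.

import subprocess
import sys

TORCH_VERSION = "2.9.1"  # any entry from the matrix above

# Step 1: write the constraint file, mirroring the "Create constraint file" step.
with open("constraints.txt", "w") as f:
    f.write(f"torch=={TORCH_VERSION}\n")

# Step 2: constrained editable install against the CPU wheel index.
subprocess.run(
    [
        "uv", "pip", "install", "--system", "-e", ".", "--group", "dev",
        "--constraint", "constraints.txt",
        "--extra-index-url", "https://download.pytorch.org/whl/cpu",
    ],
    check=True,
)

# Step 3: run only the transformer unit tests, as the workflow does.
subprocess.run(
    [sys.executable, "-m", "pytest", "--verbose",
     "tests/unit_tests/test_transformers/"],
    check=True,
)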
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -1,14 +1,14 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.5.0
rev: v6.0.0
hooks:
- id: check-toml
- id: check-yaml
- id: end-of-file-fixer
- id: trailing-whitespace
- id: detect-private-key
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.3.1
rev: v0.14.9
hooks:
# Run the linter.
- id: ruff
1 change: 1 addition & 0 deletions depthcharge/encoders/__init__.py
@@ -1,5 +1,6 @@
"""Avalailable encoders."""

from .rotary import RotaryEmbedding
from .sinusoidal import (
FloatEncoder,
PeakEncoder,
107 changes: 107 additions & 0 deletions depthcharge/encoders/rotary.py
@@ -0,0 +1,107 @@
"""Rotary Position Embeddings (RoPE) for Transformers."""

import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat


class RotaryEmbedding(nn.Module):
"""Rotary embedding as in RoFormer / RoPE.

Applies rotary positional embeddings to query and key tensors in attention.
Unlike additive positional encodings, RoPE rotates the Q and K vectors
based on their position, preserving relative position information in the
attention computation.

Parameters
----------
head_dim : int
Dimension of each attention head.
min_wavelength : float, optional
The minimum wavelength to use.
max_wavelength : float, optional
The maximum wavelength to use.

"""

def __init__(
self,
head_dim: int,
min_wavelength: float = 2 * np.pi,
max_wavelength: float = 20_000 * np.pi,
) -> None:
"""Initialize RotaryEmbedding."""
super().__init__()
self.head_dim = head_dim

base = min_wavelength / (2 * np.pi)
scale = max_wavelength / min_wavelength
thetas = (
1.0
/ base
/ (scale ** (torch.arange(0, head_dim, 2).float() / head_dim))
)
self.register_buffer("thetas", thetas)

def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
positions: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Apply rotary embeddings to query and key tensors.

Parameters
----------
q : torch.Tensor
Query tensor of shape (batch, num_heads, seq_len, head_dim)
k : torch.Tensor
Key tensor of shape (batch, num_heads, seq_len, head_dim)
positions : torch.Tensor, optional
Position values for each element in the sequence.
- If None: Use integer positions torch.arange(seq_len)
- Shape: (batch, seq_len) for per-sample positions
Examples: integer positions [0,1,2,..],
or m/z values [100.5, 150.2, ...]
Default: None

Returns
-------
q_rotated : torch.Tensor
Query with rotary embeddings applied, same shape as input
k_rotated : torch.Tensor
Key with rotary embeddings applied, same shape as input

"""
pos_to_use = self._default_pos(q) if positions is None else positions
if pos_to_use.ndim == 2:
pos_to_use = pos_to_use[:, None].expand(-1, q.size(1), -1)

sin, cos = self._get_rotations(pos_to_use, self.thetas)
q_rot = q * cos.to(q) + self._rotate_every_two(q) * sin.to(q)
k_rot = k * cos.to(k) + self._rotate_every_two(k) * sin.to(k)

return q_rot, k_rot

@staticmethod
def _default_pos(x):
return torch.arange(x.size(-2), device=x.device).expand(*x.shape[:-1])

@staticmethod
def _get_rotations(pos, thetas):
mthetas = pos[..., None] * thetas # (..., seq_len, head_dim/2)

sin, cos = (
repeat(t, "b ... h -> b ... (h j)", j=2).to(thetas)
for t in (mthetas.sin(), mthetas.cos())
)
return sin, cos

@staticmethod
def _rotate_every_two(x):
x = x.clone()
x = rearrange(x, "... (d j) -> ... d j", j=2)
x1, x2 = x.unbind(dim=-1)
x = torch.stack((-x2, x1), dim=-1)
return rearrange(x, "... d j -> ... (d j)")
7 changes: 7 additions & 0 deletions depthcharge/transformers/__init__.py
@@ -4,4 +4,11 @@
AnalyteTransformerDecoder,
AnalyteTransformerEncoder,
)
from .attn import MultiheadAttention
from .layers import (
TransformerDecoder,
TransformerDecoderLayer,
TransformerEncoder,
TransformerEncoderLayer,
)
from .spectra import SpectrumTransformerEncoder
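The new layer classes are importable directly from depthcharge.transformers. A minimal standalone sketch, assuming the custom TransformerEncoder mirrors the forward(src, ...) calling convention of its torch.nn counterpart; the constructor arguments mirror those used in analytes.py below.

import torch
from depthcharge.transformers import TransformerEncoder, TransformerEncoderLayer

layer = TransformerEncoderLayer(
    d_model=128,
    nhead=8,
    dim_feedforward=1024,
    batch_first=True,
    dropout=0.1,
    attention_backend="sdpa",  # or "native"
)
encoder = TransformerEncoder(layer, num_layers=2)

x = torch.randn(2, 16, 128)  # (batch, seq_len, d_model) because batch_first=True
out = encoder(x)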
83 changes: 56 additions & 27 deletions depthcharge/transformers/analytes.py
@@ -4,10 +4,15 @@

import torch

from .. import utils
from ..encoders import PositionalEncoder
from ..mixins import ModelMixin, TransformerMixin
from ..tokenizers import Tokenizer
from .layers import (
TransformerDecoder,
TransformerDecoderLayer,
TransformerEncoder,
TransformerEncoderLayer,
)


class _AnalyteTransformer(torch.nn.Module, ModelMixin, TransformerMixin):
@@ -158,6 +163,12 @@ class AnalyteTransformerEncoder(_AnalyteTransformer):
padding_int : int, optional
The index that represents padding in the input sequence. Required
only if ``n_tokens`` was provided as an ``int``.
attention_backend : str, optional
Attention implementation: "sdpa" (default) or "native".
rotary_embedding : RotaryEmbedding, optional
Rotary position embedding module to apply to Q and K in self-attention.
Only compatible with `attention_backend="sdpa"`.
If ``None``, no rotary embeddings are used.

"""

@@ -171,6 +182,8 @@ def __init__(
dropout: float = 0,
positional_encoder: PositionalEncoder | bool = True,
padding_int: int | None = None,
attention_backend: str = "sdpa",
rotary_embedding: torch.nn.Module | None = None,
) -> None:
"""Initialize an AnalyteEncoder."""
super().__init__(
@@ -185,15 +198,20 @@ def __init__(
)

# The Transformer layers:
layer = torch.nn.TransformerEncoderLayer(
layer = TransformerEncoderLayer(
d_model=self.d_model,
nhead=self.nhead,
dim_feedforward=self.dim_feedforward,
batch_first=True,
dropout=self.dropout,
attention_backend=attention_backend,
rotary_embedding=rotary_embedding,
enable_sdpa_math=True,
enable_sdpa_mem_efficient=True,
enable_sdpa_flash_attention=True,
)

self.transformer_encoder = torch.nn.TransformerEncoder(
self.transformer_encoder = TransformerEncoder(
layer,
num_layers=n_layers,
)
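Putting the new arguments together, a hedged construction sketch: n_tokens and the hyperparameter values are illustrative assumptions, while attention_backend and rotary_embedding are the parameters added in this diff. The rotary embedding operates per attention head, so its head_dim is taken here as d_model // nhead.

from depthcharge.encoders import RotaryEmbedding
from depthcharge.transformers import AnalyteTransformerEncoder

d_model, nhead = 128, 8
encoder = AnalyteTransformerEncoder(
    n_tokens=32,               # illustrative vocabulary size
    d_model=d_model,
    nhead=nhead,
    padding_int=0,             # required when n_tokens is an int
    attention_backend="sdpa",  # rotary embeddings require the SDPA backend
    rotary_embedding=RotaryEmbedding(head_dim=d_model // nhead),
)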
@@ -216,8 +234,8 @@ def forward(
Additional data. These may be used by overwriting the
`global_token_hook()` method in a subclass.
mask : torch.Tensor
Passed to `torch.nn.TransformerEncoder.forward()`. The mask
for the sequence.
Passed to `depthcharge.transformers.TransformerEncoder.forward()`.
The mask for the sequence.
**kwargs : dict
Additional data fields. These may be used by overwriting
the `global_token_hook()` method in a subclass.
@@ -278,6 +296,12 @@ class AnalyteTransformerDecoder(_AnalyteTransformer):
padding_int : int, optional
The index that represents padding in the input sequence. Required
only if ``n_tokens`` was provided as an ``int``.
attention_backend : str, optional
Attention implementation: "sdpa" (default) or "native".
rotary_embedding : RotaryEmbedding, optional
Rotary position embedding module to apply to Q and K in self-attention.
Only compatible with `attention_backend="sdpa"`.
If ``None``, no rotary embeddings are used.

"""

@@ -291,6 +315,8 @@ def __init__(
dropout: float = 0,
positional_encoder: PositionalEncoder | bool = True,
padding_int: int | None = None,
attention_backend: str = "sdpa",
rotary_embedding: torch.nn.Module | None = None,
) -> None:
"""Initialize a AnalyteDecoder."""
super().__init__(
@@ -305,15 +331,20 @@ def __init__(
)

# Additional model components
layer = torch.nn.TransformerDecoderLayer(
layer = TransformerDecoderLayer(
d_model=d_model,
nhead=nhead,
dim_feedforward=dim_feedforward,
batch_first=True,
dropout=dropout,
attention_backend=attention_backend,
rotary_embedding=rotary_embedding,
enable_sdpa_math=True,
enable_sdpa_mem_efficient=True,
enable_sdpa_flash_attention=True,
)

self.transformer_decoder = torch.nn.TransformerDecoder(
self.transformer_decoder = TransformerDecoder(
layer,
num_layers=n_layers,
)
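Both the encoder and decoder above construct their layers with all three enable_sdpa_* flags set to True. Assuming the custom layers route attention through torch.nn.functional.scaled_dot_product_attention, these flags plausibly control which SDPA kernels are eligible, in the spirit of PyTorch's sdpa_kernel context manager (available in recent PyTorch releases). The sketch below shows that standalone mechanism, not depthcharge internals.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, heads, seq_len, head_dim): the layout SDPA expects.
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)

# Allow the flash, memory-efficient, and math kernels, mirroring the three
# enable_sdpa_* flags; PyTorch picks the fastest available one.
with sdpa_kernel(
    [SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH]
):
    out = F.scaled_dot_product_attention(q, k, v)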
@@ -348,15 +379,15 @@ def embed(
The representations from a ``TransformerEncoder``, such as a
``SpectrumTransformerEncoder``.
memory_key_padding_mask : torch.Tensor of shape (batch_size, len_seq)
Passed to `torch.nn.TransformerEncoder.forward()`. The mask that
indicates which elements of ``memory`` are padding.
Passed to `depthcharge.transformers.TransformerEncoder.forward()`.
The mask that indicates which elements of ``memory`` are padding.
memory_mask : torch.Tensor
Passed to `torch.nn.TransformerEncoder.forward()`. The mask
for the memory sequence.
Passed to `depthcharge.transformers.TransformerEncoder.forward()`.
The mask for the memory sequence.
tgt_mask : torch.Tensor or None
Passed to `torch.nn.TransformerEncoder.forward()`. The default
is a mask that is suitable for predicting the next element in
the sequence.
Passed to `depthcharge.transformers.TransformerEncoder.forward()`.
The default is a mask that is suitable for predicting the next
element in the sequence.
**kwargs : dict
Additional data fields. These may be used by overwriting
the `global_token_hook()` method in a subclass.
@@ -388,20 +419,18 @@ def embed(
# Feed through model:
encoded = self.positional_encoder(encoded)

if tgt_mask is None:
tgt_mask = utils.generate_tgt_mask(encoded.shape[1]).to(
self.device
)

return self.transformer_decoder(
output = self.transformer_decoder(
tgt=encoded,
memory=memory,
tgt_mask=tgt_mask,
tgt_is_causal=True,
tgt_key_padding_mask=tgt_key_padding_mask,
memory_key_padding_mask=memory_key_padding_mask,
memory_mask=memory_mask,
)

return output

def score_embeddings(self, embeddings: torch.Tensor) -> torch.Tensor:
"""Score the embeddings to find the most confident tokens.

Expand Down Expand Up @@ -445,15 +474,15 @@ def forward(
The representations from a ``TransformerEncoder``, such as a
``SpectrumTransformerEncoder``.
memory_key_padding_mask : torch.Tensor of shape (batch_size, len_seq)
Passed to `torch.nn.TransformerEncoder.forward()`. The mask that
indicates which elements of ``memory`` are padding.
Passed to `depthcharge.transformers.TransformerEncoder.forward()`.
The mask that indicates which elements of ``memory`` are padding.
memory_mask : torch.Tensor
Passed to `torch.nn.TransformerEncoder.forward()`. The mask
for the memory sequence.
Passed to `depthcharge.transformers.TransformerEncoder.forward()`.
The mask for the memory sequence.
tgt_mask : torch.Tensor or None
Passed to `torch.nn.TransformerEncoder.forward()`. The default
is a mask that is suitable for predicting the next element in
the sequence.
Passed to `depthcharge.transformers.TransformerEncoder.forward()`.
The default is a mask that is suitable for predicting the next
element in the sequence.
**kwargs : dict
Additional data fields. These may be used by overwriting
the `global_token_hook()` method in a subclass.