71 changes: 60 additions & 11 deletions depthcharge/encoders/sinusoidal.py
@@ -1,6 +1,7 @@
"""Simple encoders for input into Transformers and the like."""

import math
import warnings

import einops
import numpy as np
@@ -43,16 +44,16 @@ def __init__(

self.learnable_wavelengths = learnable_wavelengths

# Get dimensions for equations:
d_sin = math.ceil(d_model / 2)
d_cos = d_model - d_sin
self._min_wavelength = min_wavelength
self._max_wavelength = max_wavelength
self._d_model = d_model

base = min_wavelength / (2 * np.pi)
scale = max_wavelength / min_wavelength
sin_exp = torch.arange(0, d_sin).float() / (d_sin - 1)
cos_exp = (torch.arange(d_sin, d_model).float() - d_sin) / (d_cos - 1)
sin_term = base * (scale**sin_exp)
cos_term = base * (scale**cos_exp)
# Dummy buffer to track the model's dtype for output conversion
self.register_buffer(
"_dtype_tracker", torch.zeros(1).float(), persistent=False
)

sin_term, cos_term = self._compute_wavelength_terms()

if not self.learnable_wavelengths:
self.register_buffer("sin_term", sin_term)
@@ -75,9 +76,46 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
The encoded features for the floating point numbers.

"""
if X.dtype.itemsize < torch.float32.itemsize:
warnings.warn(
f"Input tensor has dtype {X.dtype} which is lower precision "
f"than float32. This may lead to numerical precision issues "
f"with large input values (e.g. mz arrays). "
f"It is highly recommended to use float32 inputs.",
)

sin_mz = torch.sin(X[:, :, None] / self.sin_term)
cos_mz = torch.cos(X[:, :, None] / self.cos_term)
return torch.cat([sin_mz, cos_mz], axis=-1)
encoded = torch.cat([sin_mz, cos_mz], axis=-1)

return encoded.to(self._dtype_tracker.dtype)

def _compute_wavelength_terms(self) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute sin and cos wavelength terms for encoding."""
d_sin = math.ceil(self._d_model / 2)
d_cos = self._d_model - d_sin
base = self._min_wavelength / (2 * np.pi)
scale = self._max_wavelength / self._min_wavelength
sin_exp = torch.arange(0, d_sin).float() / (d_sin - 1)
cos_exp = (torch.arange(d_sin, self._d_model).float() - d_sin) / (
d_cos - 1
)
sin_term = base * (scale**sin_exp)
cos_term = base * (scale**cos_exp)
return sin_term, cos_term

def _apply(self, fn):
"""Override _apply to keep encoding buffers at float32 precision."""
super()._apply(fn)

# if not learnable, reconstruct the buffers at float32 precision
if not self.learnable_wavelengths:
device = self.sin_term.device
sin_term, cos_term = self._compute_wavelength_terms()
self.sin_term = sin_term.to(device)
self.cos_term = cos_term.to(device)

return self


class PeakEncoder(torch.nn.Module):
@@ -213,5 +251,16 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:

sin_pos = torch.sin(sin_in / self.sin_term)
cos_pos = torch.cos(cos_in / self.cos_term)
encoded = torch.cat([sin_pos, cos_pos], axis=2)
encoded = torch.cat([sin_pos, cos_pos], axis=2).to(
self._dtype_tracker.dtype
)

# Warn if input dtype doesn't match model dtype
if X.dtype != encoded.dtype:
warnings.warn(
f"Input tensor dtype ({X.dtype}) does not match model dtype "
f"({encoded.dtype}). The addition will implicitly cast to a "
f"common dtype, which may cause unexpected behavior.",
)

return encoded + X
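Illustrative usage (not part of the diff above): a minimal sketch of the dtype behavior these changes aim for, assuming FloatEncoder is importable from depthcharge.encoders and keeps the constructor signature used in the tests below. After converting the module with .bfloat16(), the overridden _apply() rebuilds the non-learnable sin/cos wavelength buffers at float32, so the division inside forward() stays in full precision, while the _dtype_tracker buffer follows the model dtype and the output is cast back to it.

import torch

from depthcharge.encoders import FloatEncoder  # assumed import path

enc = FloatEncoder(8, 0.1, 10).bfloat16()

# Non-learnable wavelength buffers are reconstructed at float32 by _apply().
assert enc.sin_term.dtype == torch.float32
assert enc.cos_term.dtype == torch.float32

# The dummy tracker buffer follows the model dtype, so the encoded output
# is cast to bfloat16 on the way out.
X = torch.tensor([[100.0, 500.0, 1000.0]])  # float32 input, no precision warning
Y = enc(X)
assert Y.dtype == torch.bfloat16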
82 changes: 82 additions & 0 deletions tests/unit_tests/test_encoders/test_sinusoidal.py
@@ -1,5 +1,7 @@
"""Test the encoders."""

import warnings

import numpy as np
import pytest
import torch
@@ -96,3 +98,83 @@ def test_both_sinusoid():
assert isinstance(enc.int_encoder, FloatEncoder)
assert isinstance(enc.int_encoder.sin_term, torch.nn.Parameter)
assert isinstance(enc.mz_encoder.sin_term, torch.nn.Parameter)


def test_float_encoder_dtype_conversion():
"""Test FloatEncoder buffers stay at float32 after dtype conversion."""
enc_bf16 = FloatEncoder(8, 0.1, 10).bfloat16()
X = torch.tensor([[0.0, 0.1, 10, 0.256]])

# In the case of static wavelengths, sin/cos_term will always be float32
assert enc_bf16.sin_term.dtype == torch.float32
assert enc_bf16.cos_term.dtype == torch.float32
assert enc_bf16._dtype_tracker.dtype == torch.bfloat16
Y = enc_bf16(X.bfloat16())
assert Y.dtype == torch.bfloat16


def test_float_encoder_learnable_dtype_conversion():
"""Test learnable parameters follow dtype conversion."""
enc_bf16 = FloatEncoder(8, 0.1, 10, learnable_wavelengths=True).bfloat16()
X = torch.tensor([[0.0, 0.1, 10, 0.256]])

# In the case of learnable wavelengths, sin/cos_term will be cast
assert enc_bf16.sin_term.dtype == torch.bfloat16
assert enc_bf16.cos_term.dtype == torch.bfloat16
assert enc_bf16._dtype_tracker.dtype == torch.bfloat16
Y = enc_bf16(X.bfloat16())
assert Y.dtype == torch.bfloat16


def test_float_encoder_precision_warning():
"""Test warning is issued for lower precision inputs."""
enc = FloatEncoder(8, 0.1, 10)
X = torch.tensor([[0.0, 0.1, 10, 0.256]])

# Should warn for bfloat16
with pytest.warns(UserWarning, match="lower precision than float32"):
enc(X.bfloat16())

# Should warn for float16
with pytest.warns(UserWarning, match="lower precision than float32"):
enc(X.half())

# Should not warn for float32
with warnings.catch_warnings():
warnings.simplefilter("error")
enc(X.float())


def test_positional_encoder_dtype_mismatch_warning():
"""Test PositionalEncoder warns on dtype mismatch."""
enc_bf16 = PositionalEncoder(8, 1, 8).bfloat16()
X = torch.zeros(2, 5, 8)

assert enc_bf16.sin_term.dtype == torch.float32

with pytest.warns(UserWarning, match="does not match model dtype"):
enc_bf16(X.float())

with warnings.catch_warnings():
warnings.simplefilter("error")
enc_bf16(X.bfloat16())


def test_float_encoder_numerical_stability():
"""Test that internal computation uses float32 for numerical stability."""
enc = FloatEncoder(8, 0.001, 10000)

# Use moderate values to test stability
X = torch.tensor([[10.0, 50.0, 100.0]])

Y_f32 = enc(X.float())

enc_bf16 = enc.bfloat16()
Y_bf16_f32_input = enc_bf16(X.float())

torch.testing.assert_close(
Y_f32,
Y_bf16_f32_input.float(),
rtol=1e-3,
atol=1e-3,
)
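For context (not part of the diff): the precision warning in FloatEncoder.forward() exists because half-precision formats cannot represent typical m/z-scale values with useful fractional resolution, and that quantization happens before the sinusoidal encoding is ever computed. A quick sketch of the effect:

import torch

mz = torch.tensor([1000.123456])  # a typical m/z-scale value

# bfloat16 carries 8 bits of significand precision: the spacing between
# representable values near 1000 is 4.0, so the fractional part is lost.
print(mz.bfloat16().float())  # tensor([1000.])

# float16 is finer (spacing of 0.5 near 1000) but still far too coarse
# for m/z data.
print(mz.half().float())  # tensor([1000.])

# float32 keeps roughly seven significant digits.
print(mz.float())  # tensor([1000.1235])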