From caf27e83f2a2d846e15ff0a0fe4639d9ca2fe7c4 Mon Sep 17 00:00:00 2001
From: Gaetan De Waele
Date: Mon, 22 Dec 2025 09:54:53 +0100
Subject: [PATCH] bfloat16 support

---
 depthcharge/encoders/sinusoidal.py            | 71 +++++++++++++---
 .../test_encoders/test_sinusoidal.py          | 82 +++++++++++++++++++
 2 files changed, 142 insertions(+), 11 deletions(-)

diff --git a/depthcharge/encoders/sinusoidal.py b/depthcharge/encoders/sinusoidal.py
index e5f38ed..8ac1af9 100644
--- a/depthcharge/encoders/sinusoidal.py
+++ b/depthcharge/encoders/sinusoidal.py
@@ -1,6 +1,7 @@
 """Simple encoders for input into Transformers and the like."""
 import math
+import warnings
 
 import einops
 import numpy as np
 import torch
@@ -43,16 +44,16 @@ def __init__(
         self.learnable_wavelengths = learnable_wavelengths
 
-        # Get dimensions for equations:
-        d_sin = math.ceil(d_model / 2)
-        d_cos = d_model - d_sin
+        self._min_wavelength = min_wavelength
+        self._max_wavelength = max_wavelength
+        self._d_model = d_model
 
-        base = min_wavelength / (2 * np.pi)
-        scale = max_wavelength / min_wavelength
-        sin_exp = torch.arange(0, d_sin).float() / (d_sin - 1)
-        cos_exp = (torch.arange(d_sin, d_model).float() - d_sin) / (d_cos - 1)
-        sin_term = base * (scale**sin_exp)
-        cos_term = base * (scale**cos_exp)
+        # Dummy buffer to track the model's dtype for output conversion
+        self.register_buffer(
+            "_dtype_tracker", torch.zeros(1).float(), persistent=False
+        )
+
+        sin_term, cos_term = self._compute_wavelength_terms()
 
         if not self.learnable_wavelengths:
             self.register_buffer("sin_term", sin_term)
             self.register_buffer("cos_term", cos_term)
@@ -75,9 +76,46 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
             The encoded features for the floating point numbers.
 
         """
+        if X.dtype.itemsize < torch.float32.itemsize:
+            warnings.warn(
+                f"Input tensor has dtype {X.dtype} which is lower precision "
+                f"than float32. This may lead to numerical precision issues "
+                f"with large input values (e.g. mz arrays). "
+                f"It is highly recommended to use float32 inputs.",
+            )
+
         sin_mz = torch.sin(X[:, :, None] / self.sin_term)
         cos_mz = torch.cos(X[:, :, None] / self.cos_term)
-        return torch.cat([sin_mz, cos_mz], axis=-1)
+        encoded = torch.cat([sin_mz, cos_mz], axis=-1)
+
+        return encoded.to(self._dtype_tracker.dtype)
+
+    def _compute_wavelength_terms(self) -> tuple[torch.Tensor, torch.Tensor]:
+        """Compute sin and cos wavelength terms for encoding."""
+        d_sin = math.ceil(self._d_model / 2)
+        d_cos = self._d_model - d_sin
+        base = self._min_wavelength / (2 * np.pi)
+        scale = self._max_wavelength / self._min_wavelength
+        sin_exp = torch.arange(0, d_sin).float() / (d_sin - 1)
+        cos_exp = (torch.arange(d_sin, self._d_model).float() - d_sin) / (
+            d_cos - 1
+        )
+        sin_term = base * (scale**sin_exp)
+        cos_term = base * (scale**cos_exp)
+        return sin_term, cos_term
+
+    def _apply(self, fn):
+        """Override _apply to keep encoding buffers at float32 precision."""
+        super()._apply(fn)
+
+        # if not learnable, reconstruct the buffers at float32 precision
+        if not self.learnable_wavelengths:
+            device = self.sin_term.device
+            sin_term, cos_term = self._compute_wavelength_terms()
+            self.sin_term = sin_term.to(device)
+            self.cos_term = cos_term.to(device)
+
+        return self
 
 
 class PeakEncoder(torch.nn.Module):
@@ -213,5 +251,16 @@ def forward(self, X: torch.Tensor) -> torch.Tensor:
 
         sin_pos = torch.sin(sin_in / self.sin_term)
         cos_pos = torch.cos(cos_in / self.cos_term)
-        encoded = torch.cat([sin_pos, cos_pos], axis=2)
+        encoded = torch.cat([sin_pos, cos_pos], axis=2).to(
+            self._dtype_tracker.dtype
+        )
+
+        # Warn if input dtype doesn't match model dtype
+        if X.dtype != encoded.dtype:
+            warnings.warn(
+                f"Input tensor dtype ({X.dtype}) does not match model dtype "
+                f"({encoded.dtype}). The addition will implicitly cast to a "
+                f"common dtype, which may cause unexpected behavior.",
+            )
+
         return encoded + X
diff --git a/tests/unit_tests/test_encoders/test_sinusoidal.py b/tests/unit_tests/test_encoders/test_sinusoidal.py
index bebdb63..d7ece76 100644
--- a/tests/unit_tests/test_encoders/test_sinusoidal.py
+++ b/tests/unit_tests/test_encoders/test_sinusoidal.py
@@ -1,5 +1,7 @@
 """Test the encoders."""
+import warnings
+
 
 import numpy as np
 import pytest
 import torch
@@ -96,3 +98,83 @@ def test_both_sinusoid():
     assert isinstance(enc.int_encoder, FloatEncoder)
     assert isinstance(enc.int_encoder.sin_term, torch.nn.Parameter)
     assert isinstance(enc.mz_encoder.sin_term, torch.nn.Parameter)
+
+
+def test_float_encoder_dtype_conversion():
+    """Test FloatEncoder buffers stay at float32 after dtype conversion."""
+    enc_bf16 = FloatEncoder(8, 0.1, 10).bfloat16()
+    X = torch.tensor([[0.0, 0.1, 10, 0.256]])
+
+    # In the case of static wavelengths, sin/cos_term will always be float32
+    assert enc_bf16.sin_term.dtype == torch.float32
+    assert enc_bf16.cos_term.dtype == torch.float32
+    assert enc_bf16._dtype_tracker.dtype == torch.bfloat16
+    Y = enc_bf16(X.bfloat16())
+    assert Y.dtype == torch.bfloat16
+
+
+def test_float_encoder_learnable_dtype_conversion():
+    """Test learnable parameters follow dtype conversion."""
+    enc_bf16 = FloatEncoder(8, 0.1, 10, learnable_wavelengths=True).bfloat16()
+    X = torch.tensor([[0.0, 0.1, 10, 0.256]])
+
+    # In the case of learnable wavelengths, sin/cos_term will be cast
+    assert enc_bf16.sin_term.dtype == torch.bfloat16
+    assert enc_bf16.cos_term.dtype == torch.bfloat16
+    assert enc_bf16._dtype_tracker.dtype == torch.bfloat16
+    Y = enc_bf16(X.bfloat16())
+    assert Y.dtype == torch.bfloat16
+
+
+def test_float_encoder_precision_warning():
+    """Test warning is issued for lower precision inputs."""
+    enc = FloatEncoder(8, 0.1, 10)
+    X = torch.tensor([[0.0, 0.1, 10, 0.256]])
+
+    # Should warn for bfloat16
+    with pytest.warns(UserWarning, match="lower precision than float32"):
+        enc(X.bfloat16())
+
+    # Should warn for float16
+    with pytest.warns(UserWarning, match="lower precision than float32"):
+        enc(X.half())
+
+    # Should not warn for float32
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        enc(X.float())
+
+
+def test_positional_encoder_dtype_mismatch_warning():
+    """Test PositionalEncoder warns on dtype mismatch."""
+    enc_bf16 = PositionalEncoder(8, 1, 8).bfloat16()
+    X = torch.zeros(2, 5, 8)
+
+    assert enc_bf16.sin_term.dtype == torch.float32
+
+    with pytest.warns(UserWarning, match="does not match model dtype"):
+        enc_bf16(X.float())
+
+    with warnings.catch_warnings():
+        warnings.simplefilter("error")
+        enc_bf16(X.bfloat16())
+
+
+def test_float_encoder_numerical_stability():
+    """Test that internal computation uses float32 for numerical stability."""
+    enc = FloatEncoder(8, 0.001, 10000)
+
+    # Use moderate values to test stability
+    X = torch.tensor([[10.0, 50.0, 100.0]])
+
+    Y_f32 = enc(X.float())
+
+    enc_bf16 = enc.bfloat16()
+    Y_bf16_f32_input = enc_bf16(X.float())
+
+    torch.testing.assert_close(
+        Y_f32,
+        Y_bf16_f32_input.float(),
+        rtol=1e-3,
+        atol=1e-3,
+    )
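
Usage note (not part of the patch): a minimal sketch of the behavior this change
provides, assuming FloatEncoder is importable from depthcharge.encoders and that
d_model is its only required argument.

    import torch

    from depthcharge.encoders import FloatEncoder  # assumed import path

    # Convert the encoder to bfloat16, e.g. for mixed-precision inference.
    # With static wavelengths, the sin/cos buffers are rebuilt in float32.
    encoder = FloatEncoder(d_model=128).bfloat16()

    # Keep m/z values in float32; lower-precision inputs trigger a UserWarning.
    mz = torch.tensor([[147.1128, 1042.5081, 2468.9574]], dtype=torch.float32)

    emb = encoder(mz)
    print(emb.dtype)  # torch.bfloat16: the output is cast to the model's dtype

The buffers stay in float32 so that large m/z values keep their precision inside
the sinusoidal computation, and only the final output is cast down to the model
dtype tracked by the _dtype_tracker buffer.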