Skip to content
Merged
84 changes: 84 additions & 0 deletions dwave/plugins/torch/nn/functional.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Functional interface."""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
from dwave.plugins.torch.nn.modules.kernels import Kernel

import torch

__all__ = ["maximum_mean_discrepancy_loss"]



def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor:
    r"""Estimates the squared maximum mean discrepancy (MMD) given two samples ``x`` and ``y``.

    The `squared MMD <https://dl.acm.org/doi/abs/10.5555/2188385.2188410>`_ is defined as

    .. math::
        MMD^2(X, Y) = \|E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)] \|^2,

    where :math:`\varphi` is a feature map associated with the kernel function
    :math:`k(x, y) = \langle \varphi(x), \varphi(y) \rangle`, and :math:`p` and :math:`q` are the
    distributions of the samples. It follows that, in terms of the kernel function, the squared MMD
    can be computed as

    .. math::
        E_{x, x'\sim p}[k(x, x')] + E_{y, y'\sim q}[k(y, y')] - 2E_{x\sim p, y\sim q}[k(x, y)].

    If :math:`p = q`, then :math:`MMD^2(X, Y) = 0`. This motivates the squared MMD as a loss
    function for minimizing the distance between the model distribution and data distribution.

    The within-sample expectations are estimated from pairs of distinct samples only (the
    diagonal of each within-sample kernel block is excluded), so the estimate can be negative.

    For more information, see
    Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012).
    A kernel two-sample test. The journal of machine learning research, 13(1), 723-773.

    Args:
        x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor of samples from distribution p.
        y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor of samples from distribution q.
        kernel (Kernel): A kernel function object. It is called once as ``kernel(xy, xy)`` on
            the concatenation of ``x`` and ``y`` and must return a (n_x + n_y, n_x + n_y) tensor.

    Raises:
        ValueError: If the sample size of ``x`` or ``y`` is less than two.
        ValueError: If shape of ``x`` and ``y`` mismatch (excluding batch size)

    Returns:
        torch.Tensor: The squared maximum mean discrepancy estimate.
    """
    num_x = x.shape[0]
    num_y = y.shape[0]
    if num_x < 2 or num_y < 2:
        raise ValueError(
            "Sample size of ``x`` and ``y`` must be at least two. "
            f"Got, respectively, {x.shape} and {y.shape}."
        )
    if x.shape[1:] != y.shape[1:]:
        raise ValueError(
            "Input dimensions must match. You are trying to compute "
            f"the kernel between tensors of shape {x.shape} and {y.shape}."
        )
    # Evaluate the kernel once on the pooled sample so that the three blocks below come from
    # the same kernel evaluation (and share any parameter the kernel derives from its inputs).
    pooled = torch.cat([x, y], dim=0)
    kernel_matrix = kernel(pooled, pooled)
    kernel_xx = kernel_matrix[:num_x, :num_x]
    kernel_yy = kernel_matrix[num_x:, num_x:]
    kernel_xy = kernel_matrix[:num_x, num_x:]
    # Within-sample terms: drop the diagonal k(z, z) and average over ordered pairs of
    # distinct samples.
    term_xx = (kernel_xx.sum() - kernel_xx.trace()) / (num_x * (num_x - 1))
    term_yy = (kernel_yy.sum() - kernel_yy.trace()) / (num_y * (num_y - 1))
    # Cross term: plain average over all (x_i, y_j) pairs.
    term_xy = kernel_xy.sum() / (num_x * num_y)
    return term_xx + term_yy - 2 * term_xy
151 changes: 151 additions & 0 deletions dwave/plugins/torch/nn/modules/kernels.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Kernel functions."""

from abc import ABC, abstractmethod

import torch
import torch.nn as nn

from dwave.plugins.torch.nn.modules.utils import store_config

__all__ = ["Kernel", "GaussianKernel"]


class Kernel(ABC, nn.Module):
    """Base class for kernels.

    `Kernels <https://en.wikipedia.org/wiki/Kernel_method>`_ are functions that compute a similarity
    measure between data points. Any ``Kernel`` subclass must implement the ``_kernel`` method,
    which receives two multi-dimensional tensors with shapes (n_x, f1, f2, ...) and
    (n_y, f1, f2, ...), where n_x and n_y are the numbers of items and f1, f2, ... are feature
    dimensions, and returns a tensor of shape (n_x, n_y) containing the pairwise kernel values.

    Calling the module (``forward``) validates the inputs and then delegates to ``_kernel``.
    """

    @abstractmethod
    def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Perform a pairwise kernel evaluation over samples.

        Computes the kernel matrix for inputs of shape (nx, f1, f2, ..., fk) and
        (ny, f1, f2, ..., fk), whose shape is (nx, ny)
        containing the pairwise kernel values.

        Args:
            x (torch.Tensor): A (nx, f1, f2, ..., fk) tensor.
            y (torch.Tensor): A (ny, f1, f2, ..., fk) tensor.

        Returns:
            torch.Tensor: A (nx, ny) tensor.
        """

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Validates the inputs and computes the kernel for all pairs between ``x`` and ``y``.

        In general, ``x`` and ``y`` are (n_x, f1, f2, ..., fk)- and (n_y, f1, f2, ..., fk)-shaped
        tensors, and the output is the (n_x, n_y)-shaped tensor returned by ``_kernel``, whose
        entry (i, j) is the kernel evaluated at ``x[i]`` and ``y[j]``.

        Args:
            x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor.
            y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor.

        Raises:
            ValueError: If shape of ``x`` and ``y`` mismatch (excluding batch size)
            ValueError: If the sample size of ``x`` or ``y`` is less than two.

        Returns:
            torch.Tensor: A (n_x, n_y) tensor.
        """
        # Feature dimensions must agree; only the leading (sample) dimension may differ.
        if x.shape[1:] != y.shape[1:]:
            raise ValueError(
                "Input dimensions must match. You are trying to compute "
                f"the kernel between tensors of shape {x.shape} and {y.shape}."
            )
        if x.shape[0] < 2 or y.shape[0] < 2:
            raise ValueError(
                "Sample size of ``x`` and ``y`` must be at least two. "
                f"Got, respectively, {x.shape} and {y.shape}."
            )
        return self._kernel(x, y)


class GaussianKernel(Kernel):
    r"""The Gaussian kernel.

    For two data points x and y this implementation evaluates
    :math:`k(x, y) = \sum_i \exp(-\|x - y\|_2 / \sigma_i)`, where :math:`\|x - y\|_2` is the
    Euclidean distance between the flattened data points and the :math:`\sigma_i` are bandwidth
    parameters.

    NOTE(review): the textbook Gaussian kernel is :math:`\exp(-\|x - y\|^2 / (2\sigma^2))`, i.e.,
    it uses the *squared* distance. The code uses the unsquared distance and no factor of two.
    Confirm which form is intended before changing it, since that alters numerical results.

    This implementation considers aggregating multiple kernels with different
    bandwidths. The bandwidths are determined by multiplying a base bandwidth with a set of
    multipliers. The base bandwidth can be provided directly or estimated from the data using the
    average distance between samples.

    Args:
        n_kernels (int): Number of kernel bandwidths to use.
        factor (int | float): Multiplicative factor to generate bandwidths. The bandwidths are
            computed as :math:`\sigma_i = \sigma \cdot factor^{i - n\_kernels // 2}` for
            :math:`i` in ``[0, n_kernels - 1]``. Defaults to 2.0.
        bandwidth (float | None): Base bandwidth parameter. If ``None``, the bandwidth is computed
            from the data (without gradients). Defaults to ``None``.
    """

    @store_config
    def __init__(
        self, n_kernels: int, factor: int | float = 2.0, bandwidth: float | None = None
    ):
        super().__init__()
        # Use floating-point exponents. With an integer ``factor`` and integer exponents the
        # power would be evaluated in integer arithmetic, in which the negative powers
        # (multipliers below one) cannot be represented.
        exponents = torch.arange(n_kernels, dtype=torch.get_default_dtype()) - n_kernels // 2
        factors = factor**exponents
        # Registered as a buffer so the multipliers follow the module across devices/dtypes.
        self.register_buffer("factors", factors)
        self.bandwidth = bandwidth

    @torch.no_grad()
    def _get_bandwidth(self, distance_matrix: torch.Tensor) -> torch.Tensor | float:
        """Return the base bandwidth, estimating it from the data if none was configured.

        If a bandwidth was provided during initialization, it is returned unchanged and
        ``distance_matrix`` is not used. Otherwise the base bandwidth is the sum of all entries
        of ``distance_matrix`` divided by ``n**2 - n``, where ``n`` is the number of rows. For a
        square matrix with a zero diagonal (distances of a sample set to itself) this is the
        average distance between distinct samples.
        See https://arxiv.org/abs/1707.07269 for more details about the motivation behind taking
        the average distance as the bandwidth.

        NOTE(review): for a non-square matrix (``x`` and ``y`` with different sample sizes) the
        divisor still only uses the number of rows -- confirm this is intended.

        Args:
            distance_matrix (torch.Tensor): A tensor of pairwise L2 distances between samples;
                expected to be (n, n) when the bandwidth is estimated from it.

        Returns:
            torch.Tensor | float: The base bandwidth parameter.
        """
        if self.bandwidth is None:
            num_samples = distance_matrix.shape[0]
            # n**2 - n is the number of off-diagonal entries of an (n, n) matrix.
            return distance_matrix.sum() / (num_samples**2 - num_samples)
        return self.bandwidth

    def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        r"""Compute the kernel matrix between ``x`` and ``y``.

        .. math::
            k(x, y) = \sum_{i=1}^{n\_kernels} \exp(-\|x - y\|_2 / \sigma_i),

        where :math:`\sigma_i` are the bandwidths (base bandwidth times each multiplier).

        Args:
            x (torch.Tensor): A (nx, f1, f2, ..., fk) tensor.
            y (torch.Tensor): A (ny, f1, f2, ..., fk) tensor.

        Returns:
            torch.Tensor: A (nx, ny) tensor representing the kernel matrix.
        """
        # Flatten all feature dimensions so that each sample is a single vector.
        distance_matrix = torch.cdist(x.flatten(1), y.flatten(1), p=2)
        # The base bandwidth is derived from detached distances, so it carries no gradient.
        bandwidth = self._get_bandwidth(distance_matrix.detach()) * self.factors
        # Broadcast to (n_kernels, nx, ny), then sum the per-bandwidth kernels.
        return torch.exp(-distance_matrix.unsqueeze(0) / bandwidth.reshape(-1, 1, 1)).sum(dim=0)
56 changes: 56 additions & 0 deletions dwave/plugins/torch/nn/modules/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
# Copyright 2025 D-Wave
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.nn as nn

from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss
from dwave.plugins.torch.nn.modules.utils import store_config

if TYPE_CHECKING:
from dwave.plugins.torch.nn.modules.kernels import Kernel

__all__ = ["MaximumMeanDiscrepancyLoss"]


class MaximumMeanDiscrepancyLoss(nn.Module):
    """Loss module wrapping an unbiased estimator of the squared maximum mean discrepancy (MMD).

    The loss value is obtained by delegating to
    ``dwave.plugins.torch.nn.functional.maximum_mean_discrepancy_loss`` with the kernel stored
    on this module.

    Args:
        kernel (Kernel): A kernel function object.
    """

    @store_config
    def __init__(self, kernel: Kernel) -> None:
        super().__init__()
        # Kept on the instance so every ``forward`` call uses the same kernel.
        self.kernel = kernel

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Evaluates the MMD loss for the two sample sets ``x`` and ``y``.

        Args:
            x (torch.Tensor): A (n_x, f1, f2, ...) tensor of samples from distribution p.
            y (torch.Tensor): A (n_y, f1, f2, ...) tensor of samples from distribution q.

        Returns:
            torch.Tensor: The computed MMD loss.
        """
        loss = mmd_loss(x, y, self.kernel)
        return loss
10 changes: 10 additions & 0 deletions releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
---
features:
- |
Add a ``MaximumMeanDiscrepancyLoss`` in ``dwave.plugins.torch.nn.loss`` for estimating the
squared maximum mean discrepancy (MMD) for a given kernel and two samples.
Its functional counterpart ``maximum_mean_discrepancy_loss`` is in
``dwave.plugins.torch.nn.functional``.
Kernels reside in ``dwave.plugins.torch.nn.modules.kernels``. This enables, for example,
training discrete autoencoders to match a target distribution (e.g., a prior).

1 change: 1 addition & 0 deletions tests/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
coverage
codecov
parameterized
einops
Loading