-
Notifications
You must be signed in to change notification settings - Fork 11
Add maximum mean discrepancy and radial basis #35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
36117c4
Add maximum mean discrepancy and radial basis
kevinchern 56b423f
Rename acronyms and fix first-line in docstring
kevinchern 63b15f8
Define kernel as function of two inputs
kevinchern e37dc2c
Refactor MMD into kernels, functional, and loss
kevinchern 506829d
Add unit tests for kernels
kevinchern cd815a0
Add tests for functional and loss add errors
kevinchern 901f333
Update release note
kevinchern 5d4d473
Rename RBF to GaussianKernel
kevinchern 4c0d233
Renme RBF to GaussianKernel
kevinchern 285b5c1
Remove custom errors and fix docstrings
kevinchern 261cd4d
Fix a docstring
kevinchern 9752758
Fix minor code aesthetics
kevinchern File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,84 @@ | ||
| # Copyright 2025 D-Wave | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Functional interface.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| if TYPE_CHECKING: | ||
| from dwave.plugins.torch.nn.modules.kernels import Kernel | ||
|
|
||
| import torch | ||
|
|
||
| __all__ = ["maximum_mean_discrepancy_loss"] | ||
|
|
||
|
|
||
|
|
||
| def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: | ||
| """Estimates the squared maximum mean discrepancy (MMD) given two samples ``x`` and ``y``. | ||
|
|
||
| The `squared MMD <https://dl.acm.org/doi/abs/10.5555/2188385.2188410>`_ is defined as | ||
|
|
||
| .. math:: | ||
| MMD^2(X, Y) = |E_{x\sim p}[\varphi(x)] - E_{y\sim q}[\varphi(y)] |^2, | ||
|
|
||
| where :math:`\varphi` is a feature map associated with the kernel function | ||
| :math:`k(x, y) = \langle \varphi(x), \varphi(y) \rangle`, and :math:`p` and :math:`q` are the | ||
| distributions of the samples. It follows that, in terms of the kernel function, the squared MMD | ||
| can be computed as | ||
|
|
||
| .. math:: | ||
| E_{x, x'\sim p}[k(x, x')] + E_{y, y'\sim q}[k(y, y')] - 2E_{x\sim p, y\sim q}[k(x, y)]. | ||
|
|
||
| If :math:`p = q`, then :math:`MMD^2(X, Y) = 0`. This motivates the squared MMD as a loss | ||
| function for minimizing the distance between the model distribution and data distribution. | ||
|
|
||
| For more information, see | ||
| Gretton, A., Borgwardt, K. M., Rasch, M. J., Schölkopf, B., & Smola, A. (2012). | ||
| A kernel two-sample test. The journal of machine learning research, 13(1), 723-773. | ||
|
|
||
| Args: | ||
| x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor of samples from distribution p. | ||
| y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor of samples from distribution q. | ||
| kernel (Kernel): A kernel function object. | ||
|
|
||
| Raises: | ||
| ValueError: If the sample size of ``x`` or ``y`` is less than two. | ||
| ValueError: If shape of ``x`` and ``y`` mismatch (excluding batch size) | ||
|
|
||
| Returns: | ||
| torch.Tensor: The squared maximum mean discrepancy estimate. | ||
| """ | ||
| num_x = x.shape[0] | ||
| num_y = y.shape[0] | ||
| if num_x < 2 or num_y < 2: | ||
| raise ValueError( | ||
| "Sample size of ``x`` and ``y`` must be at least two. " | ||
| f"Got, respectively, {x.shape} and {y.shape}." | ||
| ) | ||
| if x.shape[1:] != y.shape[1:]: | ||
| raise ValueError( | ||
| "Input dimensions must match. You are trying to compute " | ||
| f"the kernel between tensors of shape {x.shape} and {y.shape}." | ||
| ) | ||
| xy = torch.cat([x, y], dim=0) | ||
| kernel_matrix = kernel(xy, xy) | ||
| kernel_xx = kernel_matrix[:num_x, :num_x] | ||
| kernel_yy = kernel_matrix[num_x:, num_x:] | ||
| kernel_xy = kernel_matrix[:num_x, num_x:] | ||
| xx = (kernel_xx.sum() - kernel_xx.trace()) / (num_x * (num_x - 1)) | ||
| yy = (kernel_yy.sum() - kernel_yy.trace()) / (num_y * (num_y - 1)) | ||
| xy = kernel_xy.sum() / (num_x * num_y) | ||
| return xx + yy - 2 * xy | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,151 @@ | ||
| # Copyright 2025 D-Wave | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
| """Kernel functions.""" | ||
|
|
||
| from abc import ABC, abstractmethod | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| from dwave.plugins.torch.nn.modules.utils import store_config | ||
|
|
||
| __all__ = ["Kernel", "GaussianKernel"] | ||
|
|
||
|
|
||
| class Kernel(ABC, nn.Module): | ||
| """Base class for kernels. | ||
|
|
||
| `Kernels <https://en.wikipedia.org/wiki/Kernel_method>`_ are functions that compute a similarity | ||
| measure between data points. Any ``Kernel`` subclass must implement the ``_kernel`` method, | ||
| which computes the kernel matrix for a given input multi-dimensional tensor with shape | ||
| (n, f1, f2, ...), where n is the number of items and f1, f2, ... are feature dimensions, so that | ||
| the output is a tensor of shape (n, n) containing the pairwise kernel values. | ||
| """ | ||
|
|
||
| @abstractmethod | ||
kevinchern marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| """Perform a pairwise kernel evaluation over samples. | ||
|
|
||
| Computes the kernel matrix for inputs of shape (nx, f1, f2, ..., fk) and | ||
| (ny, f1, f2, ..., fk), whose shape is (nx, ny) | ||
| containing the pairwise kernel values. | ||
|
|
||
| Args: | ||
| x (torch.Tensor): A (nx, f1, f2, ..., fk) tensor. | ||
| y (torch.Tensor): A (ny, f1, f2, ..., fk) tensor. | ||
|
|
||
| Returns: | ||
| torch.Tensor: A (nx, ny) tensor. | ||
| """ | ||
|
|
||
| def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| """Computes kernels for all pairs between and within ``x`` and ``y``. | ||
|
|
||
| In general, ``x`` and ``y`` are (n_x, f1, f2, ..., fk)- and (n_y, f1, f2, ..., fk)-shaped | ||
| tensors, and the output is a (n_x + n_y, n_x + n_y)-shaped tensor containing pairwise kernel | ||
| evaluations. | ||
|
|
||
| Args: | ||
| x (torch.Tensor): A (n_x, f1, f2, ..., fk) tensor. | ||
| y (torch.Tensor): A (n_y, f1, f2, ..., fk) tensor. | ||
|
|
||
| Raises: | ||
| ValueError: If shape of ``x`` and ``y`` mismatch (excluding batch size) | ||
|
|
||
| Returns: | ||
| torch.Tensor: A (n_x + n_y, n_x + n_y) tensor. | ||
| """ | ||
| if x.shape[1:] != y.shape[1:]: | ||
| raise ValueError( | ||
| "Input dimensions must match. You are trying to compute " | ||
| f"the kernel between tensors of shape {x.shape} and {y.shape}." | ||
| ) | ||
| if x.shape[0] < 2 or y.shape[0] < 2: | ||
| raise ValueError( | ||
| "Sample size of ``x`` and ``y`` must be at least two. " | ||
| f"Got, respectively, {x.shape} and {y.shape}." | ||
| ) | ||
| return self._kernel(x, y) | ||
|
|
||
|
|
||
| class GaussianKernel(Kernel): | ||
| """The Gaussian kernel. | ||
|
|
||
| This kernel between two data points x and y is defined as | ||
| :math:`k(x, y) = exp(-||x-y||^2 / (2 * \sigma))`, where :math:`\sigma` is the bandwidth | ||
| parameter. | ||
|
|
||
| This implementation considers aggregating multiple Gaussian kernels with different | ||
| bandwidths. The bandwidths are determined by multiplying a base bandwidth with a set of | ||
| multipliers. The base bandwidth can be provided directly or estimated from the data using the | ||
| average distance between samples. | ||
|
|
||
| Args: | ||
| n_kernels (int): Number of kernel bandwidths to use. | ||
| factor (int | float): Multiplicative factor to generate bandwidths. The bandwidths are | ||
| computed as :math:`\sigma_i = \sigma * factor^{i - n\_kernels // 2}` for | ||
| :math:`i` in ``[0, n\_kernels - 1]``. Defaults to 2.0. | ||
| bandwidth (float | None): Base bandwidth parameter. If ``None``, the bandwidth is computed | ||
| from the data (without gradients). Defaults to ``None``. | ||
| """ | ||
|
|
||
| @store_config | ||
| def __init__( | ||
| self, n_kernels: int, factor: int | float = 2.0, bandwidth: float | None = None | ||
| ): | ||
| super().__init__() | ||
| factors = factor ** (torch.arange(n_kernels) - n_kernels // 2) | ||
| self.register_buffer("factors", factors) | ||
| self.bandwidth = bandwidth | ||
|
|
||
| @torch.no_grad() | ||
| def _get_bandwidth(self, distance_matrix: torch.Tensor) -> torch.Tensor | float: | ||
| """Heuristically determine a bandwidth parameter as the average distance between samples. | ||
|
|
||
| Computes the base bandwidth parameter as the average distance between samples if the | ||
| bandwidth is not provided during initialization. Otherwise, returns the provided bandwidth. | ||
| See https://arxiv.org/abs/1707.07269 for more details about the motivation behind taking | ||
| the average distance as the bandwidth. | ||
|
|
||
| Args: | ||
| distance_matrix (torch.Tensor): A (n, n) tensor representing the pairwise | ||
| L2 distances between samples. If it is ``None`` and the bandwidth is not provided, | ||
| an error will be raised. Defaults to ``None``. | ||
|
|
||
| Returns: | ||
| torch.Tensor | float: The base bandwidth parameter. | ||
| """ | ||
| if self.bandwidth is None: | ||
| num_samples = distance_matrix.shape[0] | ||
| return distance_matrix.sum() / (num_samples**2 - num_samples) | ||
| return self.bandwidth | ||
|
|
||
| def _kernel(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| """Compute the Gaussian kernel between ``x`` and ``y``. | ||
|
|
||
| .. math:: | ||
| k(x, y) = \sum_{i=1}^{num\_features} exp(-||x-y||^2 / (2 * \sigma_i)), | ||
|
|
||
| where :math:`\sigma_i` are the bandwidths. | ||
|
|
||
| Args: | ||
| x (torch.Tensor): A (nx, f1, f2, ..., fk) tensor. | ||
| y (torch.Tensor): A (ny, f1, f2, ..., fk) tensor. | ||
|
|
||
| Returns: | ||
| torch.Tensor: A (nx, ny) tensor representing the kernel matrix. | ||
| """ | ||
| distance_matrix = torch.cdist(x.flatten(1), y.flatten(1), p=2) | ||
| bandwidth = self._get_bandwidth(distance_matrix.detach()) * self.factors | ||
| return torch.exp(-distance_matrix.unsqueeze(0) / bandwidth.reshape(-1, 1, 1)).sum(dim=0) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,56 @@ | ||
| # Copyright 2025 D-Wave | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import TYPE_CHECKING | ||
|
|
||
| import torch | ||
| import torch.nn as nn | ||
|
|
||
| from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss | ||
| from dwave.plugins.torch.nn.modules.utils import store_config | ||
|
|
||
| if TYPE_CHECKING: | ||
| from dwave.plugins.torch.nn.modules.kernels import Kernel | ||
|
|
||
| __all__ = ["MaximumMeanDiscrepancyLoss"] | ||
|
|
||
|
|
||
| class MaximumMeanDiscrepancyLoss(nn.Module): | ||
| """An unbiased estimator for the squared maximum mean discrepancy (MMD) as a loss function. | ||
|
|
||
| This uses the ``dwave.plugins.torch.nn.functional.maximum_mean_discrepancy_loss`` function to | ||
| compute the loss. | ||
|
|
||
| Args: | ||
| kernel (Kernel): A kernel function object. | ||
| """ | ||
|
|
||
| @store_config | ||
| def __init__(self, kernel: Kernel) -> None: | ||
| super().__init__() | ||
| self.kernel = kernel | ||
|
|
||
| def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: | ||
| """Computes the MMD loss between two sets of samples ``x`` and ``y``. | ||
|
|
||
| Args: | ||
| x (torch.Tensor): A (n_x, f1, f2, ...) tensor of samples from distribution p. | ||
| y (torch.Tensor): A (n_y, f1, f2, ...) tensor of samples from distribution q. | ||
|
|
||
| Returns: | ||
| torch.Tensor: The computed MMD loss. | ||
| """ | ||
| return mmd_loss(x, y, self.kernel) |
10 changes: 10 additions & 0 deletions
10
releasenotes/notes/add-mmd-loss-function-3fa9e9a2cb452391.yaml
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| --- | ||
| features: | ||
| - | | ||
| Add a ``MaximumMeanDiscrepancyLoss`` in ``dwave.plugins.torch.nn.loss`` for estimating the | ||
| squared maximum mean discrepancy (MMD) for a given kernel and two samples. | ||
| Its functional counterpart ``maximum_mean_discrepancy_loss`` is in | ||
| ``dwave.plugins.torch.nn.functional``. | ||
| Kernels reside in ``dwave.plugins.torch.nn.modules.kernels``. This enables, for example, | ||
| training discrete autoencoders to match the distribution of a target distribution (e.g., prior). | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,4 @@ | ||
| coverage | ||
| codecov | ||
| parameterized | ||
| einops |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.