Collaborator

@kevinchern kevinchern commented Nov 12, 2025

TODOs:

  • document
  • test
  • release note
  • wrap with store_config

@kevinchern kevinchern marked this pull request as draft November 17, 2025 21:25
@VolodyaCO VolodyaCO marked this pull request as ready for review November 28, 2025 16:28
Collaborator Author

kevinchern commented Nov 29, 2025

@VolodyaCO what's the motivation for the following changes?

  1. Removal of MMD as a module, and
  2. The addition of gradient-tracking in the get_bandwidth function.

Collaborator Author

kevinchern commented Dec 1, 2025

@VolodyaCO what's the motivation for the following changes?

1. Removal of MMD as a module, and

2. The addition of gradient-tracking in the `get_bandwidth` function.

Addressed in meeting. For posterity:

  • The module was removed for consistency with pseudo_kl_divergence, which is implemented as a function.
  • get_bandwidth was accidentally removed (but the L2 matrix is detached in the kernel function).
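
For readers who were not in the meeting, the detachment mentioned above might look roughly like the sketch below. It is only an illustration under assumptions: `median_bandwidth` and `rbf_kernel` are hypothetical helpers, and the median heuristic is assumed rather than taken from the PR.

```python
import torch


def median_bandwidth(sq_dist: torch.Tensor) -> torch.Tensor:
    """Hypothetical bandwidth helper; the median heuristic is an assumption."""
    # detach() cuts the autograd graph, so no gradient flows through the bandwidth.
    return sq_dist.detach().median()


def rbf_kernel(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Illustrative RBF kernel on inputs of shape (samples, features)."""
    sq_dist = torch.cdist(x, y) ** 2
    bandwidth = median_bandwidth(sq_dist)
    # Gradients still reach x and y through sq_dist in the exponent.
    return torch.exp(-sq_dist / (2 * bandwidth))


torch.manual_seed(0)
x = torch.randn(5, 3, requires_grad=True)
y = torch.randn(4, 3)
k = rbf_kernel(x, y)
k.sum().backward()
```

With this arrangement the bandwidth behaves like a constant during backpropagation, while the kernel values themselves remain differentiable with respect to the samples.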

I will review ASAP

Collaborator Author

@kevinchern kevinchern left a comment

Thanks Vlad!! Nicely implemented and documented.
Let's add the MaximumMeanDiscrepancy as a module and this should be good to go.

Contributor

@thisac thisac left a comment

Thanks @VolodyaCO and @kevinchern! Tests fail because the Python 3.9 tests are still being run (3.9 support is being removed in #49, which should fix it).

Unit tests should be expanded. There should at least be test classes for kernels (RBF) and more unit tests for the mmd function. Otherwise, looks good. Just a few, mostly minor, comments.

__all__ = ["Kernel", "RBFKernel", "mmd_loss"]


class Kernel(nn.Module):

Contributor

Should Kernels (and RBF) be in a kernels.py instead of mmd.py?

Collaborator

For now, all kernels are for computing MMD losses. I don't know if that'll change in the future.

Contributor

If we expect it might change, it seems better to put kernels somewhere else. Alternatively remove Kernel and RBFKernel from __all__ since they're only (?) used within this module.

It seems to me like it's a more general concept and thus, if ever used outside of the mmd module, should be e.g., in torch.model.kernels or torch.kernels even if they're currently only used for calculating MMD losses. Even just putting them in a kernels.py, separating them from the mmd_loss() function, makes more sense to me.

Collaborator Author

@kevinchern kevinchern Dec 3, 2025

Good point Theo. I am in favour of organizing them in torch.kernels for reasons you mentioned, i.e., a kernel is a more general object that isn't limited to applications in MMD.

Collaborator

In this case, I will move kernels to torch.kernels. Should we have both torch.kernels and torch.functional.kernels, to keep the functions and the modules separate?
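
To make the function/module split concrete, here is a minimal sketch of how the two layers could relate. The names `gaussian_kernel` and `GaussianKernel`, the signatures, and the fixed bandwidth are all hypothetical and chosen only for illustration; they are not the PR's API.

```python
import torch
from torch import nn


def gaussian_kernel(x: torch.Tensor, y: torch.Tensor, bandwidth: float = 1.0) -> torch.Tensor:
    """Stateless (functional) form: everything is passed in as arguments."""
    sq_dist = torch.cdist(x, y) ** 2
    return torch.exp(-sq_dist / (2 * bandwidth**2))


class GaussianKernel(nn.Module):
    """Module form: stores the configuration and delegates to the function."""

    def __init__(self, bandwidth: float = 1.0):
        super().__init__()
        self.bandwidth = bandwidth

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return gaussian_kernel(x, y, self.bandwidth)


torch.manual_seed(0)
x = torch.randn(4, 2)
from_module = GaussianKernel(bandwidth=0.5)(x, x)
from_function = gaussian_kernel(x, x, bandwidth=0.5)
```

This mirrors the split PyTorch itself uses between `torch.nn` and `torch.nn.functional`: the module holds configuration, the function holds the computation.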

Collaborator Author

I can add this^

soft = logits
result = hard - soft.detach() + soft
# Now we need to repeat the result n_samples times along a new dimension
return repeat(result, "b ... -> b n ...", n=n_samples)

Contributor

Do we absolutely need repeat here? Seems a bit cumbersome to add einops as a test dependency just for this test. 🤔

Collaborator

No, but it's more readable than the pytorch-only version, which requires unsqueezing, inferring the number of feature dimensions and then repeating.
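
For comparison, a PyTorch-only version along the lines described (unsqueeze, infer the number of feature dimensions, then repeat) could be sketched as follows. The helper name `repeat_along_new_dim` is hypothetical and is not part of the PR.

```python
import torch


def repeat_along_new_dim(result: torch.Tensor, n_samples: int) -> torch.Tensor:
    """Torch-only counterpart of repeat(result, "b ... -> b n ...", n=n_samples)."""
    # Insert the new sample axis right after the batch axis: (b, ...) -> (b, 1, ...).
    expanded = result.unsqueeze(1)
    # Tensor.repeat needs one factor per axis: 1 for the batch axis,
    # n_samples for the new axis, and 1 for every remaining feature axis.
    n_feature_dims = result.dim() - 1
    return expanded.repeat(1, n_samples, *([1] * n_feature_dims))


x = torch.arange(6.0).reshape(2, 3)
out = repeat_along_new_dim(x, 4)
print(out.shape)  # torch.Size([2, 4, 3])
```

The einops pattern states the target layout in a single string, whereas this version has to count the feature axes by hand, which is the readability trade-off under discussion.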

Collaborator Author

@kevinchern kevinchern Dec 2, 2025

chiming in to say einops is generally useful, e.g., #39

@kevinchern kevinchern force-pushed the feature/mmd branch 4 times, most recently from 964390c to 948e93e Compare December 16, 2025 01:07
Co-authored-by: Vladimir Vargas Calderón <vvargasc@dwavesys.com>
Collaborator Author

kevinchern commented Dec 16, 2025

Just did another pass. Main changes are:

  1. Refactor MMD into kernels, functional, and loss modules.
  2. Rename acronyms to be consistent with package standards (I'm torn between adhering to our no-acronym standard and following the conventions used in PyTorch, i.e., PyTorch uses MSELoss, but we define MaximumMeanDiscrepancyLoss)
  3. Fixed bug with RBF (need to flatten before torch.cdist)
  4. Wrote unit tests
  5. _kernel(x) changed to _kernel(x, y).
  6. Added specific errors: class SampleSizeError(ValueError): ... and class DimensionMismatchError(ValueError): ...
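
As an illustration of items 3 and 6, the sketch below flattens each sample to a vector before calling `torch.cdist`, which treats leading dimensions as batch dimensions, and raises a `ValueError` subclass on a shape mismatch. The helper `pairwise_sq_dist` is hypothetical, and the condition under which `DimensionMismatchError` is raised here is an assumption; only the class name comes from the list above.

```python
import torch


class DimensionMismatchError(ValueError):
    """Name from the PR; the raising condition below is assumed for illustration."""


def pairwise_sq_dist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Squared Euclidean distances between samples of shape (n, ...) and (m, ...)."""
    if x.shape[1:] != y.shape[1:]:
        raise DimensionMismatchError(
            f"Feature shapes differ: {tuple(x.shape[1:])} vs {tuple(y.shape[1:])}"
        )
    # Collapse all feature axes so each sample is a single vector:
    # (n, ...) -> (n, d) and (m, ...) -> (m, d).
    x_flat = x.flatten(start_dim=1)
    y_flat = y.flatten(start_dim=1)
    return torch.cdist(x_flat, y_flat) ** 2


torch.manual_seed(0)
x = torch.randn(5, 2, 3)
y = torch.randn(4, 2, 3)
d = pairwise_sq_dist(x, y)
print(d.shape)  # torch.Size([5, 4])
```

Because the error subclasses `ValueError`, existing callers that catch `ValueError` keep working while new callers can catch the more specific type.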

Food for thought (will add an issue): the Kernel base class does not enforce that a kernel is positive semidefinite (PSD).

Edit RE PSD guarantee: Had a brief exchange with Vlad and concluded that the onus is on the developer to define a kernel correctly.
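
A developer who wants a sanity check on a custom kernel could test the Gram matrix on a sample, as in the sketch below. The helper `is_psd_on_sample` and both example kernels are hypothetical and not part of the PR. Note that passing on one sample does not prove the kernel is PSD; only a failure is conclusive.

```python
import torch


def is_psd_on_sample(kernel, x: torch.Tensor, tol: float = 1e-8) -> bool:
    """Check that the Gram matrix kernel(x, x) has no eigenvalue below -tol."""
    gram = kernel(x, x)
    # Symmetrize to remove round-off asymmetry before the symmetric eigensolver.
    gram = 0.5 * (gram + gram.T)
    return bool(torch.linalg.eigvalsh(gram).min() >= -tol)


def rbf(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Gaussian RBF with unit bandwidth, a standard PSD kernel.
    return torch.exp(-torch.cdist(x, y) ** 2 / 2)


def negative_distance(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Not PSD: the Gram matrix has zero trace but is nonzero,
    # so at least one eigenvalue must be negative.
    return -torch.cdist(x, y)


torch.manual_seed(0)
x = torch.randn(8, 3, dtype=torch.float64)
rbf_ok = is_psd_on_sample(rbf, x)
neg_ok = is_psd_on_sample(negative_distance, x)
```

A check like this fits naturally in a unit test for each concrete kernel rather than in the base class, which is consistent with leaving the guarantee to the developer.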

@kevinchern kevinchern requested a review from thisac December 16, 2025 23:28
Collaborator

@VolodyaCO VolodyaCO left a comment

Looks good to merge

Contributor

@thisac thisac left a comment

Overall, looks good! Just a few comments/suggestions.

Co-Authored-By: Theodor Isacsson <theodor@isacsson.ca>
@kevinchern kevinchern requested a review from thisac January 5, 2026 18:11
Co-authored-by: Theodor Isacsson <theodor@isacsson.ca>
@kevinchern kevinchern requested a review from thisac January 6, 2026 18:23
@kevinchern kevinchern merged commit fb756c2 into dwavesystems:main Jan 7, 2026
3 of 15 checks passed