Add maximum mean discrepancy and radial basis #35
Conversation
@VolodyaCO what's the motivation for the following changes?
Addressed in meeting. For posterity:
I will review ASAP.
kevinchern left a comment
Thanks Vlad!! Nicely implemented and documented.
Let's add the MaximumMeanDiscrepancy as a module and this should be good to go.
thisac left a comment
Thanks @VolodyaCO and @kevinchern! Tests fail due to Python 3.9 tests being run (removing 3.9 support in #49, which should fix it).
Unit tests should be expanded. There should at least be test classes for kernels (RBF) and more unit tests for the mmd function. Otherwise, looks good. Just a few, mostly minor, comments.
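For reference, the squared MMD between samples X and Y under a kernel k is typically estimated as mean k(x, x′) + mean k(y, y′) − 2 · mean k(x, y). A minimal sketch of the kind of behaviour such unit tests could pin down; the PR's actual `mmd_loss` and `RBFKernel` signatures are not shown in this thread, so the names and defaults below are illustrative:

```python
import torch

torch.manual_seed(0)

def rbf_kernel(x, y, sigma=1.0):
    # Gram matrix of the Gaussian RBF kernel: k(a, b) = exp(-||a - b||^2 / (2 sigma^2))
    return torch.exp(-torch.cdist(x, y) ** 2 / (2 * sigma**2))

def mmd_loss(x, y, sigma=1.0):
    # Biased (V-statistic) estimator of squared MMD.
    return (
        rbf_kernel(x, x, sigma).mean()
        + rbf_kernel(y, y, sigma).mean()
        - 2 * rbf_kernel(x, y, sigma).mean()
    )

# Properties a test class could assert:
x = torch.randn(64, 3)
assert mmd_loss(x, x).abs().item() < 1e-6                        # identical samples -> zero
assert mmd_loss(x, x + 5.0).item() > mmd_loss(x, x + 0.1).item()  # grows with separation
```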
```python
__all__ = ["Kernel", "RBFKernel", "mmd_loss"]


class Kernel(nn.Module):
```
Should Kernels (and RBF) be in a kernels.py instead of mmd.py?
For now, all kernels are for computing MMD losses. I don't know if that'll change in the future.
If we expect it might change, it seems better to put kernels somewhere else. Alternatively remove Kernel and RBFKernel from __all__ since they're only (?) used within this module.
It seems to me like it's a more general concept and thus, if ever used outside of the mmd module, should be e.g., in torch.model.kernels or torch.kernels even if they're currently only used for calculating MMD losses. Even just putting them in a kernels.py, separating them from the mmd_loss() function, makes more sense to me.
Good point, Theo. I am in favour of organizing them in torch.kernels for the reasons you mentioned, i.e., a kernel is a more general object that isn't limited to applications in MMD.
In this case, I will move kernels to torch.kernels. Should we have torch.kernels and torch.functional.kernels? To store the function and the modules separately?
I can add this^
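For what it's worth, the proposed split would mirror PyTorch's own `torch.nn` / `torch.nn.functional` convention: a stateless function holds the math, and a thin `nn.Module` wrapper stores the hyperparameters. A hypothetical sketch (module paths and names illustrative, not the PR's actual layout):

```python
import torch
from torch import nn

# functional/kernels.py: stateless math
def rbf_kernel(x, y, sigma=1.0):
    # Gaussian RBF Gram matrix between two batches of points
    return torch.exp(-torch.cdist(x, y) ** 2 / (2 * sigma**2))

# kernels.py: module wrapper that stores the hyperparameter
class RBFKernel(nn.Module):
    def __init__(self, sigma=1.0):
        super().__init__()
        self.sigma = sigma

    def forward(self, x, y):
        return rbf_kernel(x, y, self.sigma)

k = RBFKernel(sigma=2.0)
gram = k(torch.randn(4, 3), torch.randn(5, 3))
assert gram.shape == (4, 5)
```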
```python
soft = logits
result = hard - soft.detach() + soft
# Now we need to repeat the result n_samples times along a new dimension
return repeat(result, "b ... -> b n ...", n=n_samples)
```
Do we absolutely need repeat here? Seems a bit cumbersome to add einops as a test dependency just for this test. 🤔
No, but it's more readable than the PyTorch-only version, which requires unsqueezing, inferring the number of feature dimensions, and then repeating.
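For context, the two versions being compared might look like this; the tensor shapes and `n_samples` value are illustrative:

```python
import torch

result = torch.randn(8, 4, 4)  # (batch, *feature_dims); shapes illustrative
n_samples = 5

# einops version discussed above (shown as a comment to avoid the dependency here):
#   repeat(result, "b ... -> b n ...", n=n_samples)

# PyTorch-only equivalent: unsqueeze a new dim after the batch dim, then
# expand it, inferring the feature dimensions from the tensor's shape.
# Note: expand() returns a broadcasted view without copying data.
expanded = result.unsqueeze(1).expand(-1, n_samples, *result.shape[1:])
assert expanded.shape == (8, 5, 4, 4)
```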
Chiming in to say einops is generally useful, e.g., #39.
Co-authored-by: Vladimir Vargas Calderón <vvargasc@dwavesys.com>
Just did another pass. Main changes are:
Food for thought (will add an issue): the Kernel base class does not enforce that a kernel is PSD.
Edit RE PSD guarantee: Had a brief exchange with Vlad and figured the onus is on the developer to correctly define a kernel.
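On the PSD point: a base class can't cheaply enforce positive semi-definiteness in general, but a unit test can spot-check it, since a valid kernel's Gram matrix must have nonnegative eigenvalues (up to numerical noise). A sketch of such a check, using the RBF kernel for illustration:

```python
import torch

torch.manual_seed(0)
x = torch.randn(32, 3, dtype=torch.float64)

# Gram matrix of the Gaussian RBF kernel with sigma = 1
gram = torch.exp(-torch.cdist(x, x) ** 2 / 2)

# A PSD matrix is symmetric with only nonnegative eigenvalues.
eigvals = torch.linalg.eigvalsh(gram)
assert torch.allclose(gram, gram.T)
assert eigvals.min().item() > -1e-8
```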
VolodyaCO left a comment
Looks good to merge.
thisac left a comment
Overall, looks good! Just a few comments/suggestions.
Co-authored-by: Theodor Isacsson <theodor@isacsson.ca>
TODOs:
- store_config