@abdela47

This PR adds deterministic unit tests for pseudo_kl_divergence_loss.

The tests cover both documented spin shapes and verify gradient behavior.
They isolate the statistical structure of the loss using deterministic dummy Boltzmann
machines and do not rely on samplers or quantum hardware.

Closes #56.
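As an illustration of the testing approach described above (not the project's actual code), here is a minimal sketch of a sampler-free test pattern: a deterministic "dummy Boltzmann machine" whose log-probabilities are fixed constants, so a KL-style quantity is exactly reproducible across runs. All names (`dummy_log_prob`, `approx_kl`) are hypothetical stand-ins, not the real API.

```python
import math

def dummy_log_prob(spins):
    """Deterministic stand-in for a Boltzmann machine: a fixed linear
    energy, so log-probabilities are reproducible without any sampler
    or quantum hardware (hypothetical helper for illustration)."""
    energy = -0.1 * sum(spins)   # fixed, sampler-free "energy"
    return energy                # treated as an unnormalized log-probability

def approx_kl(p_logps, q_logps):
    """Toy KL-like surrogate over matched configurations:
    mean of (log p - log q). Illustrative only."""
    return sum(lp - lq for lp, lq in zip(p_logps, q_logps)) / len(p_logps)

# Hand-picked spin configurations make the expected value exact.
configs = [[1, 1, -1], [-1, 1, 1]]
p_logps = [dummy_log_prob(c) for c in configs]
q_logps = [0.0 for _ in configs]          # uniform reference model
loss = approx_kl(p_logps, q_logps)        # deterministic: exactly -0.1
```

Because every quantity is fixed by hand, the test can assert an exact expected loss rather than a statistical tolerance, which is the point of isolating the loss from samplers.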

@kevinchern kevinchern self-requested a review December 22, 2025 01:22
Collaborator

@kevinchern kevinchern left a comment

Hi Ahmed @abdela47. Thank you for the pull request.
The tests are well-reasoned, modular, and nicely documented.
I did a quick first pass and added some minor requests.
Separately, you may find this contribution guide helpful for our conventions and best practices.

Comment on lines 88 to 91
logits = torch.zeros(batch, n_spins)

# spins: (batch_size, n_samples, n_spins)
spins = torch.ones(batch, n_samples, n_spins)
Collaborator

Could you explain the rationale for using zero-valued logits and all-ones spins in this test versus the nonzero values used in the 2D test?

Author

Zero logits are used in the 3D shape test to keep the entropy term simple and stable (p = 0.5), so the test focuses purely on documented shape support. Nonzero values are covered in the 2D numerical-correctness test.
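The p = 0.5 claim can be checked directly. Assuming the loss interprets each logit as a Bernoulli logit (an assumption about the implementation, not confirmed here), a zero logit gives sigmoid(0) = 0.5, and the per-spin entropy is exactly log 2, the most stable point of the entropy curve:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

# A zero logit gives probability exactly 0.5 under the (assumed)
# Bernoulli-logit interpretation.
p = sigmoid(0.0)

# Binary entropy at p = 0.5 is log 2, the maximum of the entropy
# curve, where the gradient of entropy w.r.t. p vanishes.
entropy = -(p * math.log(p) + (1 - p) * math.log(1 - p))
```

Because entropy is flat at p = 0.5, small numerical perturbations in the logits barely move the entropy term, which is why zero logits keep the shape test stable.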

abdela47 and others added 3 commits December 23, 2025 17:11
Co-authored-by: Kevin Chern <32395608+kevinchern@users.noreply.github.com>

Development

Successfully merging this pull request may close these issues.

Missing unit tests for pseudo_kl_divergence loss function
