From 74666ca5e36ba9ec830eca4dd6c8c97a9e5351ae Mon Sep 17 00:00:00 2001 From: kchern Date: Fri, 14 Nov 2025 21:54:00 +0000 Subject: [PATCH] Add block-spin update sampler --- dwave/plugins/torch/nn/functional.py | 42 ++- dwave/plugins/torch/samplers/__init__.py | 15 + .../torch/samplers/block_spin_sampler.py | 332 ++++++++++++++++++ dwave/plugins/torch/tensor.py | 47 +++ dwave/plugins/torch/utils.py | 2 +- .../block-spin-sampler-b62ba4c83880c729.yaml | 10 + tests/test_block_sampler.py | 274 +++++++++++++++ tests/test_functional.py | 24 +- tests/test_tensor.py | 27 ++ 9 files changed, 768 insertions(+), 5 deletions(-) create mode 100755 dwave/plugins/torch/samplers/__init__.py create mode 100644 dwave/plugins/torch/samplers/block_spin_sampler.py create mode 100755 dwave/plugins/torch/tensor.py create mode 100755 releasenotes/notes/block-spin-sampler-b62ba4c83880c729.yaml create mode 100755 tests/test_block_sampler.py create mode 100755 tests/test_tensor.py diff --git a/dwave/plugins/torch/nn/functional.py b/dwave/plugins/torch/nn/functional.py index 8327f7a..8b26d49 100755 --- a/dwave/plugins/torch/nn/functional.py +++ b/dwave/plugins/torch/nn/functional.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. """Functional interface.""" - from __future__ import annotations from typing import TYPE_CHECKING @@ -22,8 +21,7 @@ import torch -__all__ = ["maximum_mean_discrepancy_loss"] - +__all__ = ["maximum_mean_discrepancy_loss", "bit2spin_soft", "spin2bit_soft"] def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kernel) -> torch.Tensor: @@ -82,3 +80,41 @@ def maximum_mean_discrepancy_loss(x: torch.Tensor, y: torch.Tensor, kernel: Kern yy = (kernel_yy.sum() - kernel_yy.trace()) / (num_y * (num_y - 1)) xy = kernel_xy.sum() / (num_x * num_y) return xx + yy - 2 * xy + + +def bit2spin_soft(b: torch.Tensor) -> torch.Tensor: + """Maps input :math:`b` to :math:`2b-1`. 
+ + The mapping does not require :math:`b` to be binary, only that it is in the interval :math:`[0, 1]`. + + Args: + b (torch.Tensor): Input tensor of values in :math:`[0, 1]`. + + Raises: + ValueError: If not all ``b`` values are in :math:`[0, 1]`. + + Returns: + torch.Tensor: A tensor with values :math:`2b-1`. + """ + if not ((b >= 0) & (b <= 1)).all(): + raise ValueError(f"Not all inputs are in [0, 1]: {b}") + return b * 2 - 1 + + +def spin2bit_soft(s: torch.Tensor) -> torch.Tensor: + """Maps input :math:`s` to :math:`(s+1)/2`. + + The mapping does not require :math:`s` to be spin-valued, only that it is in the interval :math:`[-1, 1]`. + + Args: + s (torch.Tensor): Input tensor of values in :math:`[-1, 1]`. + + Raises: + ValueError: If not all ``s`` values are in :math:`[-1, 1]`. + + Returns: + torch.Tensor: A tensor with values :math:`(s+1)/2`. + """ + if (s.abs() > 1).any(): + raise ValueError(f"Not all inputs are in [-1, 1]: {s}") + return (s + 1) / 2 diff --git a/dwave/plugins/torch/samplers/__init__.py b/dwave/plugins/torch/samplers/__init__.py new file mode 100755 index 0000000..932b865 --- /dev/null +++ b/dwave/plugins/torch/samplers/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from dwave.plugins.torch.samplers.block_spin_sampler import * diff --git a/dwave/plugins/torch/samplers/block_spin_sampler.py b/dwave/plugins/torch/samplers/block_spin_sampler.py new file mode 100644 index 0000000..cadb256 --- /dev/null +++ b/dwave/plugins/torch/samplers/block_spin_sampler.py @@ -0,0 +1,332 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Iterable +from typing import TYPE_CHECKING, Callable, Hashable, Literal + +import torch +from torch import nn + +if TYPE_CHECKING: + from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM + from torch._prims_common import DeviceLikeType + +from dwave.plugins.torch.nn.functional import bit2spin_soft +from dwave.plugins.torch.tensor import randspin + +__all__ = ["BlockSampler"] + + +class BlockSampler: + """A block-spin update sampler for graph-restricted Boltzmann machines. + + Due to the sparse definition of GRBMs, some tedious indexing tricks are required to + efficiently sample in blocks of spins. Ideally, an adjacency list can be used, however, + adjacencies are ragged, making vectorization inapplicable. + + Block-Gibbs and Block-Metropolis obey detailed balance and are ergodic methods at finite nonzero + temperature which, at fixed parameters, converge upon Boltzmann distributions. 
Block-Metropolis + allows higher acceptance rates for proposals (faster single-step mixing), but is non-ergodic in + the limit of zero or infinite temperature. Decorrelation from an initial condition can be slower. + Block-Gibbs represents best practice for independent sampling. + + Args: + grbm (GRBM): The Graph-Restricted Boltzmann Machine to sample from. + crayon (Callable[[Hashable], Hashable]): A colouring function that maps a single + node of the ``grbm`` to its colour. + num_chains (int): Number of Markov chains to run in parallel. + initial_states (torch.Tensor | None): A tensor of +/-1 values of shape + (``num_chains``, ``grbm.n_nodes``) representing the initial states of the Markov chains. + If None, initial states will be uniformly randomized with number of chains equal to + ``num_chains``. Defaults to None. + schedule (Iterable[float]): The inverse temperature schedule. + proposal_acceptance_criteria (Literal["Gibbs", "Metropolis"]): The proposal acceptance + criterion used to accept or reject states in the Markov chain. Defaults to "Gibbs". + seed (int | None): Random seed. Defaults to None. + + Raises: + ValueError: If the proposal acceptance criterion is not one of + "Gibbs" or "Metropolis".
+ """ + + def __init__(self, grbm: GRBM, crayon: Callable[[Hashable], Hashable], num_chains: int, + schedule: Iterable[float], + proposal_acceptance_criteria: Literal["Gibbs", "Metropolis"] = "Gibbs", + initial_states: torch.Tensor | None = None, + seed: int | None = None): + super().__init__() + + if num_chains < 1: + raise ValueError("Number of reads should be a positive integer.") + + self._proposal_acceptance_criteria = proposal_acceptance_criteria.title() + if self._proposal_acceptance_criteria not in {"Gibbs", "Metropolis"}: + raise ValueError( + 'Proposal acceptance criterion should be one of "Gibbs" or "Metropolis"' + ) + + self._grbm: GRBM = grbm + self._crayon: Callable[[Hashable], Hashable] = crayon + if not self._valid_crayon(): + raise ValueError( + "crayon is not a valid colouring of grbm. " + + "At least one edge has vertices of the same colour." + ) + + self._partition = self._get_partition() + self._padded_adjacencies, self._padded_adjacencies_weight = self._get_adjacencies() + + self._rng = torch.Generator() + if seed is not None: + self._rng = self._rng.manual_seed(seed) + + initial_states = self._prepare_initial_states(num_chains, initial_states, self._rng) + self._schedule = nn.Parameter(torch.tensor(list(schedule)), requires_grad=False) + self._x = nn.Parameter(initial_states.float(), requires_grad=False) + self._zeros = nn.Parameter(torch.zeros((num_chains, 1)), requires_grad=False) + + def to(self, device: DeviceLikeType) -> BlockSampler: + """Moves sampler components to the target device. + + If the device is "meta", then the random number generator (RNG) + will not be modified at all. For all other devices, all attributes used for performing + block-spin updates will be moved to the target device. Importantly, the RNG's device is + relayed by the following procedure: + 1. Draw a random integer between 0 (inclusive) and 2**60 (exclusive) with the current + generator as a new seed ``s``. + 2. Create a new generator on the target device. + 3. 
Set the new generator's seed as ``s``. + + Developer-note: Not sure the above constitutes a good practice, but I am not aware of any + obvious solution for moving generators across devices. + + Args: + device (DeviceLikeType): The target device. + """ + self._x = self._x.to(device) + self._zeros = self._zeros.to(device) + self._schedule = self._schedule.to(device) + self._partition = self._partition.to(device) + self._padded_adjacencies = self._padded_adjacencies.to(device) + self._padded_adjacencies_weight = self._padded_adjacencies_weight.to(device) + if device != "meta": + rng = torch.Generator(device) + rng.manual_seed(torch.randint(0, 2**60, (1,), generator=self._rng).item()) + self._rng = rng + return self + + def _prepare_initial_states( + self, num_chains: int, initial_states: torch.Tensor | None = None, + generator: torch.Generator | None = None + ) -> torch.Tensor: + """Convert initial states to tensor or sample uniformly random spins as initial states. + + Args: + num_chains (int): Number of initial states. + initial_states (torch.Tensor | None): A tensor of shape + (``num_chains``, ``self._grbm.n_nodes``) representing the initial states of the + sampler's Markov chains. If None, then initial states are sampled uniformly from + +/-1 values. Defaults to None. + generator (torch.Generator | None): A random number generator. + + Raises: + ValueError: If the shape of the initial states does not match the expected + (``num_chains``, ``self._grbm.n_nodes``). + ValueError: If the provided initial states have nonspin-valued entries. + + Returns: + torch.Tensor: The initial states of the sampler's Markov chain.
+ """ + if initial_states is None: + initial_states = randspin((num_chains, self._grbm.n_nodes), generator=generator) + + if initial_states.shape != (num_chains, self._grbm.n_nodes): + raise ValueError( + "Initial states should be of shape ``num_chains, grbm.n_nodes`` " + f"{(num_chains, self._grbm.n_nodes)}, but got {tuple(initial_states.shape)} instead." + ) + + if not set(initial_states.unique().tolist()).issubset({-1, 1}): + raise ValueError("Initial states contain nonspin values.") + + return initial_states + + def _valid_crayon(self) -> bool: + """Determines whether ``crayon`` is a valid colouring of the graph-restricted Boltzmann machine. + + Returns: + bool: True if the colouring is valid and False otherwise. + """ + for u, v in self._grbm.edges: + if self._crayon(u) == self._crayon(v): + return False + return True + + def _get_partition(self) -> nn.ParameterList: + """Computes the vertex partition induced by the colouring function. + + Returns: + nn.ParameterList: The partition induced by the colouring. + """ + partition = defaultdict(list) + for node in self._grbm.nodes: + idx = self._grbm.node_to_idx[node] + c = self._crayon(node) + partition[c].append(idx) + partition = nn.ParameterList([ + nn.Parameter(torch.tensor(partition[k], requires_grad=False), requires_grad=False) + for k in sorted(partition) + ]) + return partition + + def _get_adjacencies(self) -> tuple[torch.Tensor, torch.Tensor]: + """Create two adjacency matrices, one for neighbouring indices and another for the + corresponding edge weights' indices. + + The issue begins with the adjacency lists being ragged. To address this, we pad adjacencies + with ``-1`` values. The exact values do not matter, as the way these adjacencies will be used + is by padding an input state with 0s, so when accessing ``-1``, the output will be masked out. + + For example, consider the returned adjacency matrices ``padded_adjacencies`` and + ``padded_adjacencies_weight``. 
+ + In the first adjacency matrix, ``padded_adjacencies[0]`` is a + ``torch.Tensor`` consisting of indices of neighbouring vertices of vertex ``0``. Values of + ``-1`` in this tensor indicate no neighbour. + + In the second adjacency matrix, ``padded_adjacencies_weight[0]`` is a ``torch.Tensor`` + consisting of edge-weight indices corresponding to edges of vertex ``0``. + Similarly, ``-1`` values in this tensor indicate no neighbour. + + Returns: + tuple[torch.Tensor, torch.Tensor]: The first output is a padded adjacency + matrix, the second output is an adjacency matrix of edge weight indices. + """ + max_degree = 0 + if self._grbm.n_edges: + max_degree = torch.unique(torch.cat([self._grbm.edge_idx_i, self._grbm.edge_idx_j]), + return_counts=True)[1].max().item() + adjacency = nn.Parameter( + -torch.ones(self._grbm.n_nodes, max_degree, dtype=int), requires_grad=False + ) + adjacency_weight = nn.Parameter( + -torch.ones(self._grbm.n_nodes, max_degree, dtype=int), requires_grad=False + ) + + adjacency_dict = defaultdict(list) + edge_to_idx = dict() + for idx, (u, v) in enumerate( + zip(self._grbm.edge_idx_i.tolist(), + self._grbm.edge_idx_j.tolist())): + adjacency_dict[v].append(u) + adjacency_dict[u].append(v) + edge_to_idx[u, v] = idx + edge_to_idx[v, u] = idx + for u in self._grbm.idx_to_node: + neighbours = adjacency_dict[u] + adj_weight_idxs = [edge_to_idx[u, v] for v in neighbours] + num_neighbours = len(neighbours) + adjacency[u][:num_neighbours] = torch.tensor(neighbours) + adjacency_weight[u][:num_neighbours] = torch.tensor(adj_weight_idxs) + return adjacency, adjacency_weight + + @torch.no_grad + def _compute_effective_field(self, block) -> torch.Tensor: + """Computes the effective field for all vertices in ``block``. + + Args: + block (nn.ParameterList): A list of integers (indices) corresponding to the vertices of + a colour. + + Returns: + torch.Tensor: The effective fields of each vertex in ``block``.
+ """ + xnbr = torch.hstack([self._x, self._zeros])[:, self._padded_adjacencies[block]] + h = self._grbm.linear[block] + J = self._grbm.quadratic[self._padded_adjacencies_weight[block]] + return (xnbr * J.unsqueeze(0)).sum(2) + h + + @torch.no_grad + def _metropolis_update(self, beta: float, block: nn.ParameterList, + effective_field: torch.Tensor) -> None: + """Performs a Metropolis update in-place. + + Args: + beta (float): The inverse temperature to sample at. + block (nn.ParameterList): A list of integers (indices) corresponding to the vertices of + a colour. + effective_field (torch.Tensor): Effective fields of each spin corresponding to indices + of the block. + """ + delta = -2 * self._x[:, block] * effective_field + prob = (-delta * beta).exp().clip(0, 1) + + # if the delta field is negative, then flipping the spin will improve the energy + prob[delta <= 0] = 1 + flip = -bit2spin_soft(prob.bernoulli(generator=self._rng)) + self._x[:, block] = flip * self._x[:, block] + + @torch.no_grad + def _gibbs_update(self, beta: torch.Tensor, block: torch.nn.ParameterList, effective_field: torch.Tensor) -> None: + """Performs a Gibbs update in-place. + + Args: + beta (torch.Tensor): The (scalar) inverse temperature to sample at. + block (nn.ParameterList): A list of integers (indices) corresponding to the vertices of + a colour. + effective_field (torch.Tensor): Effective fields of each spin corresponding to indices + of the block. + """ + prob = 1 / (1 + torch.exp(2 * beta * effective_field)) + spins = bit2spin_soft(prob.bernoulli(generator=self._rng)) + self._x[:, block] = spins + + @torch.no_grad + def _step(self, beta: torch.Tensor) -> None: + """Performs a block-spin update in-place. + + Args: + beta (torch.Tensor): Inverse temperature to sample at. 
+ """ + for block in self._partition: + effective_field = self._compute_effective_field(block) + if self._proposal_acceptance_criteria == "Metropolis": + self._metropolis_update(beta, block, effective_field) + elif self._proposal_acceptance_criteria == "Gibbs": + self._gibbs_update(beta, block, effective_field) + else: + # NOTE: This line should never be reached because acceptance proposal criterion + # should've been checked on instantiation + raise ValueError(f"Invalid proposal acceptance criterion.") + + @torch.no_grad + def sample(self, x: torch.Tensor | None = None) -> torch.Tensor: + """Performs block updates. + + Args: + x (torch.Tensor): A tensor of shape (``batch_size``, ``dim``) or (``batch_size``, ``n_nodes``) + interpreted as a batch of partially-observed spins. Entries marked with ``torch.nan`` will + be sampled; entries with +/-1 values will remain constant. + + Returns: + torch.Tensor: A tensor of shape (batch_size, dim) of +/-1 values sampled from the model. + """ + if x is not None: + raise NotImplementedError("Support for conditional sampling has not been implemented.") + for beta in self._schedule: + self._step(beta) + return self._x diff --git a/dwave/plugins/torch/tensor.py b/dwave/plugins/torch/tensor.py new file mode 100755 index 0000000..b2e860a --- /dev/null +++ b/dwave/plugins/torch/tensor.py @@ -0,0 +1,47 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +import torch + +if TYPE_CHECKING: + from torch import Generator + from torch._prims_common import DeviceLikeType + from torch.types import _bool, _dtype, _size + +__all__ = ["randspin"] + + +def randspin(size: _size, **kwargs) -> torch.Tensor: + """Wrapper for ``torch.randint`` restricted to spin outputs (+/-1 values). + + Args: + size (torch.types._size): Shape of the output tensor. + **kwargs: Keyword arguments of ``torch.randint``. + + Raises: + ValueError: If ``low`` is supplied as a keyword argument. + ValueError: If ``high`` is supplied as a keyword argument. + + Returns: + torch.Tensor: A tensor of +/-1 values. + """ + if "low" in kwargs: + raise ValueError("Invalid keyword argument `low`.") + if "high" in kwargs: + raise ValueError("Invalid keyword argument `high`.") + b = torch.randint(0, 2, size, **kwargs) + return 2 * b - 1 diff --git a/dwave/plugins/torch/utils.py b/dwave/plugins/torch/utils.py index fd142db..b3e147e 100755 --- a/dwave/plugins/torch/utils.py +++ b/dwave/plugins/torch/utils.py @@ -13,7 +13,7 @@ # limitations under the License. from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional import torch diff --git a/releasenotes/notes/block-spin-sampler-b62ba4c83880c729.yaml b/releasenotes/notes/block-spin-sampler-b62ba4c83880c729.yaml new file mode 100755 index 0000000..14d2de6 --- /dev/null +++ b/releasenotes/notes/block-spin-sampler-b62ba4c83880c729.yaml @@ -0,0 +1,10 @@ +--- +features: + - | + Add ``BlockSampler`` for performing block Gibbs (or Metropolis) sampling of + graph-restricted Boltzmann Machines. + - | + Add functions for converting spins to bits and bits to spins. + - | + Add ``randspin`` for generating random spins. 
+ diff --git a/tests/test_block_sampler.py b/tests/test_block_sampler.py new file mode 100755 index 0000000..0df7d43 --- /dev/null +++ b/tests/test_block_sampler.py @@ -0,0 +1,274 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import dwave_networkx as dnx +import networkx as nx +import torch +from parameterized import parameterized + +from dwave.plugins.torch.models.boltzmann_machine import GraphRestrictedBoltzmannMachine as GRBM +from dwave.plugins.torch.samplers.block_spin_sampler import BlockSampler + + +class TestBlockSampler(unittest.TestCase): + ZEPHYR = dnx.zephyr_graph(1, coordinates=True) + GRBM_ZEPHYR = GRBM(ZEPHYR.nodes, ZEPHYR.edges) + CRAYON_ZEPHYR = dnx.zephyr_four_color + + BIPARTITE = nx.complete_bipartite_graph(5, 3) + GRBM_BIPARTITE = GRBM(BIPARTITE.nodes, BIPARTITE.edges) + def CRAYON_BIPARTITE(b): return b < 5 + + GRBM_SINGLE = GRBM([0], []) + def CRAYON_SINGLE(s): 0 + + GRBM_CRAYON_TEST_CASES = [(GRBM_ZEPHYR, CRAYON_ZEPHYR), + (GRBM_BIPARTITE, CRAYON_BIPARTITE), + (GRBM_SINGLE, CRAYON_SINGLE)] + + def setUp(self) -> None: + self.crayon_veqa = lambda v: v == "a" + return super().setUp() + + @parameterized.expand(GRBM_CRAYON_TEST_CASES) + def test_sample(self, grbm, crayon): + for pac in "Metropolis", "Gibbs": + schedule = [0.0, 1.0, 2.0] + bss1 = BlockSampler(grbm, crayon, 10, schedule, pac, seed=1) + bss1.sample() + + bss2 = BlockSampler(grbm, crayon, 10, [1.0], pac, seed=1) + for beta in schedule: 
+ bss2._step(beta) + + self.assertListEqual(bss1._x.tolist(), bss2._x.tolist()) + + def test_device(self): + grbm = GRBM(list("ab"), [["a", "b"]]) + + crayon = self.crayon_veqa + sample_size = 1_000_000 + bss = BlockSampler(grbm, crayon, sample_size, [1.0], "Gibbs", seed=2) + bss.to('meta') + self.assertEqual("cpu", bss._grbm.linear.device.type) + self.assertEqual("cpu", bss._grbm.quadratic.device.type) + self.assertEqual("meta", bss._x.device.type) + self.assertEqual("meta", bss._padded_adjacencies.device.type) + self.assertEqual("meta", bss._padded_adjacencies_weight.device.type) + self.assertEqual("meta", bss._zeros.device.type) + self.assertEqual("meta", bss._schedule.device.type) + self.assertEqual("meta", bss._partition[0].device.type) + self.assertEqual("meta", bss._partition[1].device.type) + # NOTE: "meta" device is not supported for torch.Generator + self.assertEqual("cpu", bss._rng.device.type) + + def test_gibbs_update(self): + grbm = GRBM(list("ab"), [["a", "b"]]) + + crayon = self.crayon_veqa + sample_size = 1_000_000 + bss = BlockSampler(grbm, crayon, sample_size, [1.0], "Gibbs", seed=2) + bss._x.data[:] = 1 + zero = torch.tensor(0.0) + ones = torch.ones((sample_size, 1)) + bss._gibbs_update(0.0, bss._partition[0], ones*zero) + torch.testing.assert_close(torch.tensor(0.5), bss._x.mean(), atol=1e-3, rtol=1e-3) + bss._gibbs_update(0.0, bss._partition[1], ones*zero) + torch.testing.assert_close(torch.tensor(0.0), bss._x.mean(), atol=1e-3, rtol=1e-3) + + effective_field = torch.tensor(1.2) + bss._gibbs_update(1.0, bss._partition[0], effective_field*ones) + bss._gibbs_update(1.0, bss._partition[1], effective_field*ones) + torch.testing.assert_close( + torch.tanh(-effective_field), + bss._x.mean(), + atol=1e-3, rtol=1e-3) + + def test_initial_states_respected(self): + grbm = GRBM(list("ab"), [["a", "b"]]) + + crayon = self.crayon_veqa + initial_states = torch.tensor([[-1, 1], [1, 1], [-1, -1], [1, 1], [-1, 1], [-1, 1], [1, 1]]) + + bss = BlockSampler(grbm, 
crayon, len(initial_states), [1.0], "Metropolis", + initial_states, 2) + self.assertListEqual(bss._x.tolist(), initial_states.tolist()) + + def test_metropolis_update_average(self): + grbm = GRBM(list("ab"), [["a", "b"]]) + + crayon = self.crayon_veqa + sample_size = 1_000_000 + bss = BlockSampler(grbm, crayon, sample_size, [1.0], "Metropolis", seed=2) + bss._x.data[:] = 1 + ones = torch.ones((sample_size, 1)) + effective_field = torch.tensor(1.2) + for i in range(10): + bss._metropolis_update(1.0, bss._partition[0], effective_field*ones) + bss._metropolis_update(1.0, bss._partition[1], effective_field*ones) + torch.testing.assert_close( + torch.tanh(-effective_field), + bss._x.mean(), + atol=1e-3, rtol=1e-3) + + def test_metropolis_update_oscillates(self): + grbm = GRBM(list("ab"), [["a", "b"]]) + + crayon = self.crayon_veqa + sample_size = 1_00 + bss = BlockSampler(grbm, crayon, sample_size, [1.0], "Metropolis", seed=2) + bss._x.data[:] = 1 + zero_effective_field = torch.zeros((sample_size, 1)) + bss._metropolis_update(0.0, bss._partition[0], zero_effective_field) + self.assertTrue((bss._x[:, 1] == -1).all()) + bss._metropolis_update(0.0, bss._partition[1], zero_effective_field) + self.assertTrue((bss._x == -1).all()) + + def test_effective_field(self): + # Create a triangle graph with an additional dangling vertex + # a + # / | \ + # b--c d + self.nodes = list("abcd") + self.edges = [["a", "b"], ["a", "c"], ["a", "d"], ["b", "c"]] + + # Manually set the parameter weights for testing + dtype = torch.float32 + grbm = GRBM(self.nodes, self.edges) + grbm._linear.data = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=dtype) + grbm._quadratic.data = torch.tensor([1.1, 2.2, 3.3, 6.6], dtype=dtype) + + def crayon(v): + if v == "a": + return 0 + if v == "b": + return 1 + if v == "c": + return 2 + if v == "d": + return 1 + bss = BlockSampler(grbm, crayon, 3, [1.0], seed=3) + bss._x.data[:] = torch.tensor([[1, 1, -1, -1], + [-1, -1, 1, -1], + [1, 1, 1, -1]]) + # effective field 
for a + effective_field_a = bss._compute_effective_field(bss._partition[0]) + torch.testing.assert_close( + effective_field_a, + torch.tensor([[0.0 + 1.1 - 2.2 - 3.3], + [0.0 - 1.1 + 2.2 - 3.3], + [0.0 + 1.1 + 2.2 - 3.3]]) + ) + # effective field for b, d + effective_field_bd = bss._compute_effective_field(bss._partition[1]) + torch.testing.assert_close(effective_field_bd, + torch.tensor([[1.0 + 1.1 - 6.6, 3.0 + 3.3], + [1.0 - 1.1 + 6.6, 3.0 - 3.3], + [1.0 + 1.1 + 6.6, 3.0 + 3.3]])) + # effective field for c + effective_field_c = bss._compute_effective_field(bss._partition[2]) + torch.testing.assert_close(effective_field_c, + torch.tensor([[2.0 + 2.2 + 6.6], + [2.0 - 2.2 - 6.6], + [2.0 + 2.2 + 6.6]])) + + def test_get_adjacencies(self): + # Create a triangle graph with an additional dangling vertex + # a + # / | \ + # b--c d + self.nodes = list("abcd") + self.edges = [["a", "b"], ["a", "c"], ["a", "d"], ["b", "c"]] + + # Manually set the parameter weights for testing + dtype = torch.float32 + grbm = GRBM(self.nodes, self.edges) + grbm._linear.data = torch.tensor([0.0, 1.0, 2.0, 3.0], dtype=dtype) + grbm._quadratic.data = torch.tensor([1.1, 2.2, 3.3, 6.6], dtype=dtype) + + def crayon(v): + if v == "a": + return 0 + if v == "b": + return 1 + if v == "c": + return 2 + if v == "d": + return 1 + bss = BlockSampler(grbm, crayon, 10, [1.0], seed=4) + padded_adj, padded_adj_weights = bss._get_adjacencies() + + # First, check the neighbour indices are correct + # a has neighbours b, c, d in that order, so 1, 2, 3 + self.assertListEqual(padded_adj[0].tolist(), [1, 2, 3]) + # b has neighbours a, c, in that order, so 0, 2, and padded -1 + self.assertListEqual(padded_adj[1].tolist(), [0, 2, -1]) + # c has neighbours a, b, in that order, so 0, 1, and padded -1 + self.assertListEqual(padded_adj[2].tolist(), [0, 1, -1]) + # d has neighbour a, so 0, and two padded -1 + self.assertListEqual(padded_adj[3].tolist(), [0, -1, -1]) + + # Next, check weights are correct + # a has edges 0,
1, 2 + self.assertListEqual(padded_adj_weights[0].tolist(), [0, 1, 2]) + # b has edges 0, 3, + self.assertListEqual(padded_adj_weights[1].tolist(), [0, 3, -1]) + # c has edges 1, 3, + self.assertListEqual(padded_adj_weights[2].tolist(), [1, 3, -1]) + # d has edges 2 + self.assertListEqual(padded_adj_weights[3].tolist(), [2, -1, -1]) + + @parameterized.expand(GRBM_CRAYON_TEST_CASES) + def test_get_partition(self, grbm: GRBM, crayon): + bss = BlockSampler(grbm, crayon, 10, [1.0], seed=5) + # Check every block is indeed coloured correctly + for block in bss._partition: + self.assertEqual(1, len({crayon(grbm.idx_to_node[bidx]) for bidx in block.tolist()})) + # Check every node has been included + self.assertSetEqual({idx for block in bss._partition for idx in block.tolist()}, + {bss._grbm.node_to_idx[node] for node in bss._grbm.nodes}) + + def test_invalid_crayon(self): + grbm = GRBM([0, 1], [(0, 1)]) + def crayon(n): return 1 + self.assertRaisesRegex(ValueError, "not a valid colouring", BlockSampler, grbm, crayon, 10, [1.0]) + + def test_invalid_proposal(self): + grbm = GRBM([0, 1], [(0, 1)]) + def crayon(n): return 1 + self.assertRaisesRegex(ValueError, "Proposal acceptance criterion should be one of", BlockSampler, + grbm, crayon, 10, [1.0], "abc") + + def test_prepare_initial_states(self): + grbm = GRBM([0, 1, 2], [(0, 1)]) + def crayon(n): return n + bss = BlockSampler(grbm, crayon, 1, [1.0],) + + with self.subTest("Nonspin initial states."): + self.assertRaisesRegex(ValueError, "contain nonspin values", bss._prepare_initial_states, + initial_states=torch.tensor([[0, 1, -1]]), num_chains=1) + + with self.subTest("Testing initial states with incorrect shape."): + self.assertRaisesRegex(ValueError, "Initial states should be of shape", bss._prepare_initial_states, + num_chains=10, initial_states=torch.tensor([[-1, 1, 1, 1, -1]])) + + @parameterized.expand(GRBM_CRAYON_TEST_CASES) + def test_invalid_num_reads(self, grbm, crayon): + self.assertRaisesRegex(ValueError,
"should be a positive integer", BlockSampler, grbm, crayon, 0, [1.0]) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_functional.py b/tests/test_functional.py index 17f81f4..da3e770 100755 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -14,8 +14,11 @@ import unittest import torch +from parameterized import parameterized +from dwave.plugins.torch.nn.functional import bit2spin_soft from dwave.plugins.torch.nn.functional import maximum_mean_discrepancy_loss as mmd_loss +from dwave.plugins.torch.nn.functional import spin2bit_soft from dwave.plugins.torch.nn.modules.kernels import Kernel @@ -51,7 +54,9 @@ def test_mmd_loss_dim_mismatch(self): x = torch.tensor([[1], [4]], dtype=torch.float32) y = torch.tensor([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]) - self.assertRaisesRegex(ValueError, "Input dimensions must match. You are trying to compute ", mmd_loss, x, y, None) + self.assertRaisesRegex(ValueError, + "Input dimensions must match. You are trying to compute ", + mmd_loss, x, y, None) def test_mmd_loss_arange(self): x = torch.tensor([[1.0], [4.0], [5.0]]) @@ -75,5 +80,22 @@ def _kernel(self, x, y): self.assertEqual(-25, mmd_loss(x, y, Constant())) +class TestFunctional(unittest.TestCase): + + def test_spin2bit_soft(self): + self.assertListEqual(spin2bit_soft(torch.tensor([-1.0, 1.0, 0.5])).tolist(), [0, 1, 0.75]) + + @parameterized.expand([([-1.1, 1.0],), ([-0.5, 1.1],)]) + def test_spin2bit_raises(self, input): + self.assertRaises(ValueError, spin2bit_soft, torch.tensor(input)) + + def test_bit2spin_soft(self): + self.assertListEqual(bit2spin_soft(torch.tensor([0.0, 1.0, 0.5])).tolist(), [-1, 1, 0]) + + @parameterized.expand([([-0.1, 1.0],), ([0.1, 1.1],)]) + def test_bit2spin_soft_raises(self, input): + self.assertRaises(ValueError, bit2spin_soft, torch.tensor(input)) + + if __name__ == "__main__": unittest.main() diff --git a/tests/test_tensor.py b/tests/test_tensor.py new file mode 100755 index 0000000..67022a4 --- /dev/null +++ 
b/tests/test_tensor.py @@ -0,0 +1,27 @@ +# Copyright 2025 D-Wave +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from dwave.plugins.torch.tensor import randspin + + +class TestTensor(unittest.TestCase): + + def test_rands(self): + self.assertSetEqual({-1, 1}, set(randspin((2000,)).unique().tolist())) + + +if __name__ == "__main__": + unittest.main()