From 8a5fe45f1ecda2af4fa7a97ee59c0c6c82a9163d Mon Sep 17 00:00:00 2001 From: quattro Date: Wed, 26 Jun 2024 16:45:22 -0700 Subject: [PATCH 01/21] this branch isn't fully functional yet, but let's get this in place for students to take over --- src/traceax/__init__.py | 13 +- src/traceax/_estimators.py | 277 ++-------------------- src/traceax/_solution.py | 35 +++ src/traceax/_trace.py | 461 +++++++++++++++++++++++++++++++++++++ src/traceax/_utils.py | 83 +++++++ 5 files changed, 602 insertions(+), 267 deletions(-) create mode 100644 src/traceax/_solution.py create mode 100644 src/traceax/_trace.py create mode 100644 src/traceax/_utils.py diff --git a/src/traceax/__init__.py b/src/traceax/__init__.py index 1fed170..f16823e 100644 --- a/src/traceax/__init__.py +++ b/src/traceax/__init__.py @@ -15,11 +15,7 @@ from importlib.metadata import PackageNotFoundError, version # pragma: no cover from ._estimators import ( - AbstractTraceEstimator as AbstractTraceEstimator, - HutchinsonEstimator as HutchinsonEstimator, - HutchPlusPlusEstimator as HutchPlusPlusEstimator, - XNysTraceEstimator as XNysTraceEstimator, - XTraceEstimator as XTraceEstimator, + AbstractEstimator as AbstractEstimator, ) from ._samplers import ( AbstractSampler as AbstractSampler, @@ -27,6 +23,13 @@ RademacherSampler as RademacherSampler, SphereSampler as SphereSampler, ) +from ._trace import ( + HutchinsonEstimator as HutchinsonEstimator, + HutchPlusPlusEstimator as HutchPlusPlusEstimator, + trace as trace, + XNysTraceEstimator as XNysTraceEstimator, + XTraceEstimator as XTraceEstimator, +) try: diff --git a/src/traceax/_estimators.py b/src/traceax/_estimators.py index fcd7a67..6e74fc6 100644 --- a/src/traceax/_estimators.py +++ b/src/traceax/_estimators.py @@ -13,44 +13,32 @@ # limitations under the License. from abc import abstractmethod -from typing import Any +from typing import Any, Generic, TypeVar import equinox as eqx -import jax -import jax.numpy as jnp -import jax.scipy as jsp from equinox import AbstractVar -from jax.numpy.linalg import norm -from jaxtyping import Array, PRNGKeyArray -from lineax import AbstractLinearOperator, is_negative_semidefinite, is_positive_semidefinite +from jaxtyping import Array, PRNGKeyArray, PyTree +from lineax import AbstractLinearOperator -from ._samplers import AbstractSampler, RademacherSampler, SphereSampler +from ._samplers import AbstractSampler -def _check_shapes(operator: AbstractLinearOperator, k: int) -> tuple[int, int]: - n_in = operator.in_size() - n_out = operator.out_size() - if n_in != n_out: - raise ValueError(f"Trace estimation requires square linear operator. Found {(n_out, n_in)}.") +_EstimatorState = TypeVar("_EstimatorState") - if k < 1: - raise ValueError(f"Trace estimation requires positive number of matvecs. Found {k}.") - return n_in, k - - -def _get_scale(W: Array, D: Array, n: int, k: int) -> Array: - return (n - k + 1) / (n - norm(W, axis=0) ** 2 + jnp.abs(D) ** 2) - - -class AbstractTraceEstimator(eqx.Module, strict=True): +class AbstractEstimator(eqx.Module, Generic[_EstimatorState], strict=True): r"""Abstract base class for all trace estimators.""" sampler: AbstractVar[AbstractSampler] @abstractmethod - def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: + def init(self, key: PRNGKeyArray, operator: AbstractLinearOperator) -> _EstimatorState: + """ """ + ... + + @abstractmethod + def estimate(self, state: _EstimatorState, k: int) -> tuple[PyTree[Array], dict[str, Any]]: """Estimate the trace of `operator`. !!! Example @@ -59,7 +47,7 @@ def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) key = jax.random.PRNGKey(...) operator = lx.MatrixLinearOperator(...) hutch = tx.HutchinsonEstimator() - result = hutch.compute(key, operator, k=10) + result = hutch.estimate(key, operator, k=10) # or result = hutch(key, operator, k=10) ``` @@ -79,241 +67,6 @@ def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) """ ... - def __call__(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: + def __call__(self, state: _EstimatorState, k: int) -> tuple[PyTree[Array], dict[str, Any]]: """An alias for `estimate`.""" - return self.estimate(key, operator, k) - - -class HutchinsonEstimator(AbstractTraceEstimator): - r"""Girard-Hutchinson Trace Estimator: - - $\mathbb{E}[\omega^T \mathbf{A} \omega] = \text{trace}(\mathbf{A})$, - where $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. - - """ - - sampler: AbstractSampler = RademacherSampler() - - def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: - n, k = _check_shapes(operator, k) - # sample from proposed distribution - samples = self.sampler(key, n, k) - - # project to k-dim space - projected = jax.vmap(operator.mv, (1,), 1)(samples) - - # take the mean across estimates - trace_est = jnp.sum(projected * samples) / k - - return trace_est, {} - - -HutchinsonEstimator.__init__.__doc__ = r"""**Arguments:** - -- `sampler`: The sampling distribution for $\omega$. Default is [`traceax.RademacherSampler`][]. -""" - - -class HutchPlusPlusEstimator(AbstractTraceEstimator): - r"""Hutch++ Trace Estimator: - - Let $\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}$ be the a _low-rank approximation_ - to $\mathbf{A}$, where $\mathbf{Q}$ is the orthonormal basis of $\mathbf{A} \Omega$, for - $\Omega = [\omega_1, \dotsc, \omega_k]$. - - Hutch++ improves upon Girard-Hutchinson estimator by including the trace of the residuals. Namely, - Hutch++ estimates $\text{trace}(\mathbf{A})$ as - $\text{trace}(\hat{\mathbf{A}}) - \text{trace}(\mathbf{A} - \hat{\mathbf{A}})$. - - As with the Girard-Hutchinson estimator, it requires - $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. - - """ - - sampler: AbstractSampler = RademacherSampler() - - def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: - # generate an n, k matrix X - n, k = _check_shapes(operator, k) - m = k // 3 - - # some operators work fine with matrices in mv, some dont; this ensures they all do - mv = jax.vmap(operator.mv, (1,), 1) - - # split X into 2 Xs; X1 and X2, where X1 has shape 2m, where m = k/3 - samples = self.sampler(key, n, 2 * m) - X1 = samples[:, :m] - X2 = samples[:, m:] - - Y = mv(X1) - - # compute Q, _ = QR(Y) (orthogonal matrix) - Q, _ = jnp.linalg.qr(Y) - - # compute G = X2 - Q @ (Q.T @ X2) - G = X2 - Q @ (Q.T @ X2) - - # estimate trace = tr(Q.T @ A @ Q) + tr(G.T @ A @ G) / k - AQ = mv(Q) - AG = mv(G) - trace_est = jnp.sum(AQ * Q) + jnp.sum(AG * G) / (G.shape[1]) - - return trace_est, {} - - -HutchPlusPlusEstimator.__init__.__doc__ = r"""**Arguments:** - -- `sampler`: The sampling distribution for $\omega$. Default is [`traceax.RademacherSampler`][]. -""" - - -class XTraceEstimator(AbstractTraceEstimator): - r"""XTrace Trace Estimator: - - Let $\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}$ be the the _low-rank approximation_ - to $\mathbf{A}$, where $\mathbf{Q}$ is the orthonormal basis of $\mathbf{A} \Omega$, for - $\Omega = [\omega_1, \dotsc, \omega_k]$. - - XTrace improves upon Hutch++ estimator by enforcing *exchangeability* of sampled test-vectors, - to construct a symmetric estimation function with lower variance. - - Additionally, the *improved* XTrace algorithm (i.e. `improved = True`), ensures that test-vectors - are orthogonalized against the low rank approximation $\mathbf{Q}\mathbf{Q}^* \mathbf{A}$ and - renormalized. This improved XTrace approach may provide better empirical results compared with - the non-orthogonalized version. - - As with the Girard-Hutchinson estimator, it requires - $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. - - """ - - sampler: AbstractSampler = SphereSampler() - improved: bool = True - - def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: - n, k = _check_shapes(operator, k) - m = k // 2 - - # some operators work fine with matrices in mv, some dont; this ensures they all do - mv = jax.vmap(operator.mv, (1,), 1) - - samples = self.sampler(key, n, m) - Y = mv(samples) - Q, R = jnp.linalg.qr(Y) - - # solve and rescale - S = jnp.linalg.inv(R).T - s = norm(S, axis=0) - S = S / s - - # working variables - Z = mv(Q) - H = Q.T @ Z - W = Q.T @ samples - T = Z.T @ samples - HW = H @ W - - SW_d = jnp.sum(S * W, axis=0) - TW_d = jnp.sum(T * W, axis=0) - SHS_d = jnp.sum(S * (H @ S), axis=0) - WHW_d = jnp.sum(W * HW, axis=0) - - term1 = SW_d * jnp.sum((T - H.T @ W) * S, axis=0) - term2 = (jnp.abs(SW_d) ** 2) * SHS_d - term3 = jnp.conjugate(SW_d) * jnp.sum(S * (R - HW), axis=0) - - if self.improved: - scale = _get_scale(W, SW_d, n, k) - else: - scale = 1 - - estimates = jnp.trace(H) * jnp.ones(m) - SHS_d + (WHW_d - TW_d + term1 + term2 + term3) * scale - trace_est = jnp.mean(estimates) - std_err = jnp.std(estimates) / jnp.sqrt(m) - - return trace_est, {"std.err": std_err} - - -XTraceEstimator.__init__.__doc__ = r"""**Arguments:** - -- `sampler`: the sampling distribution for $\omega$. Default is [`traceax.SphereSampler`][]. -- `improved`: whether to use the *improved* XTrace estimator, which rescales predicted samples. - Default is `True` (see Notes). -""" - - -class XNysTraceEstimator(AbstractTraceEstimator): - r"""XNysTrace Trace Estimator: - - XNysTrace improves upon XTrace estimator when $\mathbf{A}$ is (negative-) positive-semidefinite, by - performing a [Nyström approximation](https://en.wikipedia.org/wiki/Low-rank_matrix_approximations#Nystr%C3%B6m_approximation), - rather than a randomized SVD (i.e., random projection followed by QR decomposition). - - Like, [`traceax.XTraceEstimator`][], the *improved* XNysTrace algorithm (i.e. `improved = True`), ensures - that test-vectors are orthogonalized against the low rank approximation and renormalized. - This improved XNysTrace approach may provide better empirical results compared with the non-orthogonalized version. - - As with the Girard-Hutchinson estimator, it requires - $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. - - """ - - sampler: AbstractSampler = SphereSampler() - improved: bool = True - - def estimate(self, key: PRNGKeyArray, operator: AbstractLinearOperator, k: int) -> tuple[Array, dict[str, Any]]: - is_nsd = is_negative_semidefinite(operator) - if not (is_positive_semidefinite(operator) | is_nsd): - raise ValueError("`XNysTraceEstimator` may only be used for positive or negative definite linear operators") - if is_nsd: - operator = -operator - - n, k = _check_shapes(operator, k) - - # some operators work fine with matrices in mv, some dont; this ensures they all do - mv = jax.vmap(operator.mv, (1,), 1) - - samples = self.sampler(key, n, k) - Y = mv(samples) - - # shift for numerical issues - nu = jnp.finfo(Y.dtype).eps * norm(Y, "fro") / jnp.sqrt(n) - Y = Y + samples * nu - Q, R = jnp.linalg.qr(Y) - - # compute and symmetrize H, then take cholesky factor - H = samples.T @ Y - C = jnp.linalg.cholesky(0.5 * (H + H.T)).T - B = jsp.linalg.solve_triangular(C.T, R.T, lower=True).T - - # if improved == True - Qs, Rs = jnp.linalg.qr(samples) - Ws = Qs.T @ samples - - # solve and rescale - if self.improved: - S = jnp.linalg.inv(Rs).T - s = norm(S, axis=0) - S = S / s - scale = _get_scale(Ws, jnp.sum(S * Ws, axis=0), n, k) - else: - scale = 1 - - W = Q.T @ samples - S = jsp.linalg.solve_triangular(C, B.T).T / jnp.sqrt(jnp.diag(jnp.linalg.inv(H))) - dSW = jnp.sum(S * W, axis=0) - - estimates = norm(B, "fro") ** 2 - norm(S, axis=0) ** 2 + (jnp.abs(dSW) ** 2) * scale - nu * n - trace_est = jnp.mean(estimates) - std_err = jnp.std(estimates) / jnp.sqrt(k) - trace_est = jnp.where(is_nsd, -trace_est, trace_est) - - return trace_est, {"std.err": std_err} - - -XNysTraceEstimator.__init__.__doc__ = r"""**Arguments:** - -- `sampler`: the sampling distribution for $\omega$. Default is [`traceax.SphereSampler`][]. -- `improved`: whether to use the *improved* XNysTrace estimator, which rescales predicted samples. - Default is `True` (see Notes). -""" + return self.estimate(state, k) diff --git a/src/traceax/_solution.py b/src/traceax/_solution.py new file mode 100644 index 0000000..e4a4e1c --- /dev/null +++ b/src/traceax/_solution.py @@ -0,0 +1,35 @@ +# Copyright (c) 2024 MancusoLab. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any + +import equinox as eqx + +from jaxtyping import ArrayLike, PyTree + + +class Solution(eqx.Module, strict=True): + """The solution to a stochastic estimation problem. + + **Attributes:** + + - `value`: The estimated value. + - `stats`: A dictionary containing statistics about the solution (e.g., standard error). + This may be empty if individual estimators cannot provide this information (i.e. `{}`) + - `state`: The internal state for the estimator. + """ + + value: PyTree[Any] + stats: dict[str, PyTree[ArrayLike]] + state: PyTree[Any] diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py new file mode 100644 index 0000000..99c88fe --- /dev/null +++ b/src/traceax/_trace.py @@ -0,0 +1,461 @@ +# Copyright (c) 2024 MancusoLab. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import functools as ft + +from typing import Any +from typing_extensions import TypeAlias + +import equinox as eqx +import equinox.internal as eqxi +import jax.lax as lax +import jax.numpy as jnp +import jax.random as rdm +import jax.scipy as jsp +import jax.tree_util as jtu + +from jax.interpreters import ad as ad +from jax.numpy.linalg import norm +from jaxtyping import Array, PRNGKeyArray, PyTree +from lineax import ( + AbstractLinearOperator, + DiagonalLinearOperator, + IdentityLinearOperator, + is_negative_semidefinite, + is_positive_semidefinite, +) + +from ._estimators import AbstractEstimator +from ._samplers import AbstractSampler, RademacherSampler, SphereSampler +from ._solution import Solution +from ._utils import ( + _assert_false, + _check_operator, + _clip_k, + _to_shapedarray, + _to_struct, + _vmap_mv, + sentinel, +) + + +def _get_scale(W: Array, D: Array, n: int, k: int) -> Array: + return (n - k + 1) / (n - norm(W, axis=0) ** 2 + jnp.abs(D) ** 2) + + +_BasicTraceState: TypeAlias = tuple[PRNGKeyArray, AbstractLinearOperator, int] +_PSDTraceState: TypeAlias = tuple[PRNGKeyArray, AbstractLinearOperator, int, bool] + + +class HutchinsonEstimator(AbstractEstimator[_BasicTraceState], strict=True): + r"""Girard-Hutchinson Trace Estimator: + + $\mathbb{E}[\omega^T \mathbf{A} \omega] = \text{trace}(\mathbf{A})$, + where $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. + + """ + + sampler: AbstractSampler = RademacherSampler() + + def init(self, key: PRNGKeyArray, operator: AbstractLinearOperator) -> _BasicTraceState: + n = _check_operator(operator) + return (key, operator, n) + + def estimate(self, state: _BasicTraceState, k: int) -> tuple[PyTree[Array], dict[str, Any]]: + key, operator, n = state + + k = _clip_k(k, n) + + # sample from proposed distribution + samples = self.sampler(key, n, k) + + # project to k-dim space + projected = _vmap_mv(operator)(samples) + + # take the mean across estimates + trace_est = jnp.sum(projected * samples) / k + + return trace_est, {} + + +HutchinsonEstimator.__init__.__doc__ = r"""**Arguments:** + +- `sampler`: The sampling distribution for $\omega$. Default is [`traceax.RademacherSampler`][]. +""" + + +class HutchPlusPlusEstimator(AbstractEstimator[_BasicTraceState], strict=True): + r"""Hutch++ Trace Estimator: + + Let $\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}$ be the a _low-rank approximation_ + to $\mathbf{A}$, where $\mathbf{Q}$ is the orthonormal basis of $\mathbf{A} \Omega$, for + $\Omega = [\omega_1, \dotsc, \omega_k]$. + + Hutch++ improves upon Girard-Hutchinson estimator by including the trace of the residuals. Namely, + Hutch++ estimates $\text{trace}(\mathbf{A})$ as + $\text{trace}(\hat{\mathbf{A}}) - \text{trace}(\mathbf{A} - \hat{\mathbf{A}})$. + + As with the Girard-Hutchinson estimator, it requires + $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. + + """ + + sampler: AbstractSampler = RademacherSampler() + + def init(self, key: PRNGKeyArray, operator: AbstractLinearOperator) -> _BasicTraceState: + n = _check_operator(operator) + return (key, operator, n) + + def estimate(self, state: _BasicTraceState, k: int) -> tuple[PyTree[Array], dict[str, Any]]: + key, operator, n = state + + k = _clip_k(k, n) + m = k // 3 + + # some operators work fine with matrices in mv, some dont; this ensures they all do + mv = _vmap_mv(operator) + + # split X into 2 Xs; X1 and X2, where X1 has shape 2m, where m = k/3 + samples = self.sampler(key, n, 2 * m) + X1 = samples[:, :m] + X2 = samples[:, m:] + + Y = mv(X1) + + # compute Q, _ = QR(Y) (orthogonal matrix) + Q, _ = jnp.linalg.qr(Y) + + # compute G = X2 - Q @ (Q.T @ X2) + G = X2 - Q @ (Q.T @ X2) + + # estimate trace = tr(Q.T @ A @ Q) + tr(G.T @ A @ G) / k + AQ = mv(Q) + AG = mv(G) + trace_est = jnp.sum(AQ * Q) + jnp.sum(AG * G) / (G.shape[1]) + + return trace_est, {} + + +HutchPlusPlusEstimator.__init__.__doc__ = r"""**Arguments:** + +- `sampler`: The sampling distribution for $\omega$. Default is [`traceax.RademacherSampler`][]. +""" + + +class XTraceEstimator(AbstractEstimator[_BasicTraceState], strict=True): + r"""XTrace Trace Estimator: + + Let $\hat{\mathbf{A}} := \mathbf{Q}\mathbf{Q}^* \mathbf{A}$ be the the _low-rank approximation_ + to $\mathbf{A}$, where $\mathbf{Q}$ is the orthonormal basis of $\mathbf{A} \Omega$, for + $\Omega = [\omega_1, \dotsc, \omega_k]$. + + XTrace improves upon Hutch++ estimator by enforcing *exchangeability* of sampled test-vectors, + to construct a symmetric estimation function with lower variance. + + Additionally, the *improved* XTrace algorithm (i.e. `improved = True`), ensures that test-vectors + are orthogonalized against the low rank approximation $\mathbf{Q}\mathbf{Q}^* \mathbf{A}$ and + renormalized. This improved XTrace approach may provide better empirical results compared with + the non-orthogonalized version. + + As with the Girard-Hutchinson estimator, it requires + $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. + + """ + + sampler: AbstractSampler = SphereSampler() + improved: bool = True + + def init(self, key: PRNGKeyArray, operator: AbstractLinearOperator) -> _BasicTraceState: + n = _check_operator(operator) + return (key, operator, n) + + def estimate(self, state: _BasicTraceState, k: int) -> tuple[PyTree[Array], dict[str, Any]]: + key, operator, n = state + + k = _clip_k(k, n) + m = k // 2 + + # some operators work fine with matrices in mv, some dont; this ensures they all do + mv = _vmap_mv(operator) + + samples = self.sampler(key, n, m) + Y = mv(samples) + Q, R = jnp.linalg.qr(Y) + + # solve and rescale + S = jnp.linalg.inv(R).T + s = norm(S, axis=0) + S = S / s + + # working variables + Z = mv(Q) + H = Q.T @ Z + W = Q.T @ samples + T = Z.T @ samples + HW = H @ W + + SW_d = jnp.sum(S * W, axis=0) + TW_d = jnp.sum(T * W, axis=0) + SHS_d = jnp.sum(S * (H @ S), axis=0) + WHW_d = jnp.sum(W * HW, axis=0) + + term1 = SW_d * jnp.sum((T - H.T @ W) * S, axis=0) + term2 = (jnp.abs(SW_d) ** 2) * SHS_d + term3 = jnp.conjugate(SW_d) * jnp.sum(S * (R - HW), axis=0) + + if self.improved: + scale = _get_scale(W, SW_d, n, k) + else: + scale = 1 + + estimates = jnp.trace(H) * jnp.ones(m) - SHS_d + (WHW_d - TW_d + term1 + term2 + term3) * scale + trace_est = jnp.mean(estimates) + std_err = jnp.std(estimates) / jnp.sqrt(m) + + return trace_est, {"std.err": std_err} + + +XTraceEstimator.__init__.__doc__ = r"""**Arguments:** + +- `sampler`: the sampling distribution for $\omega$. Default is [`traceax.SphereSampler`][]. +- `improved`: whether to use the *improved* XTrace estimator, which rescales predicted samples. + Default is `True` (see Notes). +""" + + +class XNysTraceEstimator(AbstractEstimator[_PSDTraceState], strict=True): + r"""XNysTrace Trace Estimator: + + XNysTrace improves upon XTrace estimator when $\mathbf{A}$ is (negative-) positive-semidefinite, by + performing a [Nyström approximation](https://en.wikipedia.org/wiki/Low-rank_matrix_approximations#Nystr%C3%B6m_approximation), + rather than a randomized SVD (i.e., random projection followed by QR decomposition). + + Like, [`traceax.XTraceEstimator`][], the *improved* XNysTrace algorithm (i.e. `improved = True`), ensures + that test-vectors are orthogonalized against the low rank approximation and renormalized. + This improved XNysTrace approach may provide better empirical results compared with the non-orthogonalized version. + + As with the Girard-Hutchinson estimator, it requires + $\mathbb{E}[\omega] = 0$ and $\mathbb{E}[\omega \omega^T] = \mathbf{I}$. + + """ + + sampler: AbstractSampler = SphereSampler() + improved: bool = True + + def init(self, key: PRNGKeyArray, operator: AbstractLinearOperator) -> _PSDTraceState: + n = _check_operator(operator) + is_nsd = is_negative_semidefinite(operator) + if not (is_positive_semidefinite(operator) | is_nsd): + raise ValueError("`XNysTraceEstimator` may only be used for positive or negative definite linear operators") + if is_nsd: + operator = -operator + + return (key, operator, n, is_nsd) + + def estimate(self, state: _PSDTraceState, k: int) -> tuple[PyTree[Array], dict[str, Any]]: + key, operator, n, is_nsd = state + + k = _clip_k(k, n) + + # some operators work fine with matrices in mv, some dont; this ensures they all do + mv = _vmap_mv(operator) + + samples = self.sampler(key, n, k) + Y = mv(samples) + + # shift for numerical issues + nu = jnp.finfo(Y.dtype).eps * norm(Y, "fro") / jnp.sqrt(n) + Y = Y + samples * nu + Q, R = jnp.linalg.qr(Y) + + # compute and symmetrize H, then take cholesky factor + H = samples.T @ Y + C = jnp.linalg.cholesky(0.5 * (H + H.T)).T + B = jsp.linalg.solve_triangular(C.T, R.T, lower=True).T + + # if improved == True + Qs, Rs = jnp.linalg.qr(samples) + Ws = Qs.T @ samples + + # solve and rescale + if self.improved: + S = jnp.linalg.inv(Rs).T + s = norm(S, axis=0) + S = S / s + scale = _get_scale(Ws, jnp.sum(S * Ws, axis=0), n, k) + else: + scale = 1 + + W = Q.T @ samples + S = jsp.linalg.solve_triangular(C, B.T).T / jnp.sqrt(jnp.diag(jnp.linalg.inv(H))) + dSW = jnp.sum(S * W, axis=0) + + estimates = norm(B, "fro") ** 2 - norm(S, axis=0) ** 2 + (jnp.abs(dSW) ** 2) * scale - nu * n + trace_est = jnp.mean(estimates) + std_err = jnp.std(estimates) / jnp.sqrt(k) + trace_est = jnp.where(is_nsd, -trace_est, trace_est) + + return trace_est, {"std.err": std_err} + + +XNysTraceEstimator.__init__.__doc__ = r"""**Arguments:** + +- `sampler`: the sampling distribution for $\omega$. Default is [`traceax.SphereSampler`][]. +- `improved`: whether to use the *improved* XNysTrace estimator, which rescales predicted samples. + Default is `True` (see Notes). +""" + + +def _estimate_trace_impl(key, operator, state, k, estimator, *, check_closure): + out = estimator.estimate(state, k) + if check_closure: + out = eqxi.nontraceable(out, name="`traceax.trace` with respect to a closed-over value") + result, stats = out + + return result, stats + + +_to_struct_tr = ft.partial(_to_struct, name="traceax.trace") + + +@eqxi.filter_primitive_def +def _estimate_trace_abstract_eval(key, operator, state, k, estimator): + key, state, k, estimator = jtu.tree_map(_to_struct_tr, (key, state, k, estimator)) + out = eqx.filter_eval_shape( + _estimate_trace_impl, + key, + operator, + state, + k, + estimator, + check_closure=False, + ) + out = jtu.tree_map(_to_shapedarray, out) + + return out + + +@eqxi.filter_primitive_jvp +def _estimate_trace_jvp(primals, tangents): + key, operator, state, k, estimator = primals + # t_operator := V + t_key, t_operator, t_state, t_k, t_estimator = tangents + jtu.tree_map(_assert_false, (t_key, t_state, t_k, t_estimator)) + + # primal problem of t = tr(A) + result, stats = eqxi.filter_primitive_bind(_estimate_trace_p, key, operator, state, k, estimator) + out = result, stats + + # inner prodct in matrix space => = tr(A @ B) + # d tr(A) / dA = I + # t' = = tr(I @ V) = tr(V) + # tangent problem => tr(V) + t_state = estimator.init(key, t_operator) + t_result, t_stats = eqxi.filter_primitive_bind(_estimate_trace_p, key, t_operator, t_state, k, estimator) + + t_out = ( + t_result, + jtu.tree_map(lambda _: None, stats), + ) + + return out, t_out + + +@eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore +def _estimate_trace_transpose(inputs, cts_out): + key, operator, state, k, estimator = inputs + cts_result, cts_stats = cts_out + # jtu.tree_map(_assert_defined, (key, operator, state, k, estimator), is_leaf=_is_undefined) + # op_t = cts_result * IdentityLinearOperator(operator.in_structure()) + op_t = cts_result * operator.transpose() + state_none = jtu.tree_map(lambda _: None, state) + k_none = None + estimator_none = jtu.tree_map(lambda _: None, estimator) + + return op_t, state_none, k_none, estimator_none + + +_estimate_trace_p = eqxi.create_vprim( + "trace", + eqxi.filter_primitive_def(ft.partial(_estimate_trace_impl, check_closure=False)), + _estimate_trace_abstract_eval, + _estimate_trace_jvp, + _estimate_trace_transpose, +) +_estimate_trace_p.def_impl( + eqxi.filter_primitive_def(ft.partial(_estimate_trace_impl, check_closure=True)), +) +eqxi.register_impl_finalisation(_estimate_trace_p) + + +# @eqx.filter_jit +def trace( + key: PRNGKeyArray, + operator: AbstractLinearOperator, + k: int, + estimator: AbstractEstimator = XTraceEstimator(), + *, + state: PyTree[Any] = sentinel, +) -> Solution: + r""" """ + if eqx.is_array(operator): + raise ValueError( + "`traceax.trace(..., operator=...)` should be an " + "`lineax.AbstractLinearOperator`, not a raw JAX array. If you are trying to pass " + "a matrix then this should be passed as " + "`lineax.MatrixLinearOperator(matrix)`." + ) + + in_size = operator.in_size() + out_size = operator.out_size() + if in_size != out_size: + raise ValueError( + "`traceax.trace(..., operator=...)` should be a square `lineax.AbstractLinearOperator`. " + f"Found shape {out_size}x{in_size}." + ) + + # if identity op, then just shortcircuit and return dimension size + if isinstance(operator, IdentityLinearOperator): + return Solution( + value=float(in_size), + stats={}, + state=state, + ) + # if diagonal op, then just shortcircuit and sum diagonal + if isinstance(operator, DiagonalLinearOperator): + return Solution( + value=jnp.sum(operator.diagonal), + stats={}, + state=state, + ) + + # set up state if necessary + if state == sentinel: + key, s_key = rdm.split(key) + state = estimator.init(s_key, operator) + dynamic_state, static_state = eqx.partition(state, eqx.is_array) + dynamic_state = lax.stop_gradient(dynamic_state) + state = eqx.combine(dynamic_state, static_state) + + # cannot differentiate through key, state, or estimator + key = eqxi.nondifferentiable(key, name="`trace(key, ...)`") + state = eqxi.nondifferentiable(state, name="`trace(..., state=...)`") + estimator = eqxi.nondifferentiable(estimator, name="`trace(..., estimator=...)`") + + # estimate trace and compute stats if any + result, stats = eqxi.filter_primitive_bind(_estimate_trace_p, key, operator, state, k, estimator) + + # cannot differentiate backwards through stats + stats = eqxi.nondifferentiable_backward(stats, name="_, stats = trace(...)") + + return Solution(value=result, stats=stats, state=state) diff --git a/src/traceax/_utils.py b/src/traceax/_utils.py new file mode 100644 index 0000000..ffae138 --- /dev/null +++ b/src/traceax/_utils.py @@ -0,0 +1,83 @@ +# Copyright (c) 2024 MancusoLab. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any + +import equinox.internal as eqxi +import jax +import jax.core + +from jax import vmap +from jax.interpreters import ad as ad +from lineax import AbstractLinearOperator + + +sentinel: Any = eqxi.doc_repr(object(), "sentinel") + + +def _check_operator(operator: AbstractLinearOperator) -> int: + n_in = operator.in_size() + n_out = operator.out_size() + if n_in != n_out: + raise ValueError(f"Estimation requires square linear operator. Found {(n_out, n_in)}.") + + return n_in + + +def _clip_k(k: int, n: int) -> int: + return min(max(k, 1), n) + + +def _vmap_mv(operator: AbstractLinearOperator): + return vmap(operator.mv, (1,), 1) + + +def _is_none(x): + return x is None + + +def _to_shapedarray(x): + if isinstance(x, jax.ShapeDtypeStruct): + return jax.core.ShapedArray(x.shape, x.dtype) + else: + return x + + +def _to_struct(x, name): + if isinstance(x, jax.core.ShapedArray): + return jax.ShapeDtypeStruct(x.shape, x.dtype) + elif isinstance(x, jax.core.AbstractValue): + raise NotImplementedError( + f"`{name}` only supports working with JAX arrays; not " f"other abstract values. Got abstract value {x}." + ) + else: + return x + + +def _assert_false(x): + assert False + + +def _is_undefined(x): + return isinstance(x, ad.UndefinedPrimal) + + +def _assert_defined(x): + assert not _is_undefined(x) + + +def _keep_undefined(v, ct): + if _is_undefined(v): + return ct + else: + return None From 35be0f427311ec6a55ca86ebec05730cf740f401 Mon Sep 17 00:00:00 2001 From: quattro Date: Wed, 26 Jun 2024 16:46:18 -0700 Subject: [PATCH 02/21] this branch isn't fully functional yet, but let's get this in place for students to take over --- README.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 1e77bf4..f3688c6 100644 --- a/README.md +++ b/README.md @@ -70,20 +70,20 @@ key, key1, key2, key3, key4 = rdm.split(key, 5) # Hutchinson estimator; default samples Rademacher {-1,+1} hutch = tx.HutchinsonEstimator() -print(hutch.estimate(key1, operator, k)) # (Array(3.6007538, dtype=float32), {}) +print(tx.trace(key1, operator, k, hutch)) # (Array(3.6007538, dtype=float32), {}) # Hutch++ estimator; default samples Rademacher {-1,+1} hpp = tx.HutchPlusPlusEstimator() -print(hpp.estimate(key2, operator, k)) # (Array(3.4094956, dtype=float32), {}) +print(tx.trace(key2, operator, k, hpp)) # (Array(3.4094956, dtype=float32), {}) # XTrace estimator; default samples uniformly on n-Sphere xt = tx.XTraceEstimator() -print(xt.estimate(key3, operator, k)) # (Array(3.3030486, dtype=float32), {'std.err': Array(0.01238528, dtype=float32)}) +print(tx.trace(key3, operator, k, xt)) # (Array(3.3030486, dtype=float32), {'std.err': Array(0.01238528, dtype=float32)}) # XNysTrace estimator; Improved performance for NSD/PSD trace estimates operator = lx.TaggedLinearOperator(operator, lx.positive_semidefinite_tag) nt = tx.XNysTraceEstimator() -print(nt.estimate(key4, operator, k)) # (Array(3.3314352, dtype=float32), {'std.err': Array(0.0006521, dtype=float32)}) +print(tx.trace(key4, operator, k, nt)) # (Array(3.3314352, dtype=float32), {'std.err': Array(0.0006521, dtype=float32)}) ``` ## Documentation From d71b5a4aae0f2b513aae93304f753af515de8595 Mon Sep 17 00:00:00 2001 From: quattro Date: Wed, 26 Jun 2024 16:50:17 -0700 Subject: [PATCH 03/21] this branch isn't fully functional yet, but let's get this in place for students to take over --- tests/test_trace.py | 38 +++++++++++++++++++++++++------------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/tests/test_trace.py b/tests/test_trace.py index 1eef35e..eaabe19 100644 --- a/tests/test_trace.py +++ b/tests/test_trace.py @@ -17,14 +17,17 @@ import jax.numpy as jnp import lineax as lx -import traceax as tr +import traceax as tx from .helpers import ( construct_matrix, ) -@pytest.mark.parametrize("estimator", (tr.HutchinsonEstimator(), tr.HutchPlusPlusEstimator(), tr.XTraceEstimator())) +@pytest.mark.parametrize( + "estimator", + (tx.HutchinsonEstimator(), tx.HutchPlusPlusEstimator(), tx.XTraceEstimator()), +) @pytest.mark.parametrize("k", (5, 10, 50)) @pytest.mark.parametrize( "tags", @@ -44,14 +47,17 @@ def test_matrix_linop(getkey, estimator, k, tags, size, dtype): k = min(k, size) matrix = construct_matrix(getkey, tags, size, dtype) operator = lx.MatrixLinearOperator(matrix, tags=tags) - result = estimator.estimate(getkey(), operator, k) + result = tx.trace(getkey(), operator, k, estimator) assert result is not None - assert result[0] is not None - assert jnp.isfinite(result[0]) + assert result.value is not None + assert jnp.isfinite(result.value) -@pytest.mark.parametrize("estimator", (tr.HutchinsonEstimator(), tr.HutchPlusPlusEstimator(), tr.XTraceEstimator())) +@pytest.mark.parametrize( + "estimator", + (tx.HutchinsonEstimator(), tx.HutchPlusPlusEstimator(), tx.XTraceEstimator()), +) @pytest.mark.parametrize("k", (5, 10, 50)) @pytest.mark.parametrize("size", (5, 50, 500)) @pytest.mark.parametrize("dtype", (jnp.float32, jnp.float64)) @@ -59,15 +65,21 @@ def test_diag_linop(getkey, estimator, k, size, dtype): k = min(k, size) matrix = construct_matrix(getkey, lx.diagonal_tag, size, dtype) operator = lx.DiagonalLinearOperator(jnp.diag(matrix)) - result = estimator.estimate(getkey(), operator, k) + result = tx.trace(getkey(), operator, k, estimator) assert result is not None - assert result[0] is not None - assert jnp.isfinite(result[0]) + assert result.value is not None + assert jnp.isfinite(result.value) @pytest.mark.parametrize( - "estimator", (tr.HutchinsonEstimator(), tr.HutchPlusPlusEstimator(), tr.XTraceEstimator(), tr.XNysTraceEstimator()) + "estimator", + ( + tx.HutchinsonEstimator(), + tx.HutchPlusPlusEstimator(), + tx.XTraceEstimator(), + tx.XNysTraceEstimator(), + ), ) @pytest.mark.parametrize("k", (5, 10, 50)) @pytest.mark.parametrize("tags", (lx.positive_semidefinite_tag, lx.negative_semidefinite_tag)) @@ -77,8 +89,8 @@ def test_nsd_psd_matrix_linop(getkey, estimator, k, tags, size, dtype): k = min(k, size) matrix = construct_matrix(getkey, tags, size, dtype) operator = lx.MatrixLinearOperator(matrix, tags=tags) - result = estimator.estimate(getkey(), operator, k) + result = tx.trace(getkey(), operator, k, estimator) assert result is not None - assert result[0] is not None - assert jnp.isfinite(result[0]) + assert result.value is not None + assert jnp.isfinite(result.value) From 3b58cd190425ec0b93218edcee4c2e27e5732043 Mon Sep 17 00:00:00 2001 From: quattro Date: Fri, 20 Sep 2024 12:58:17 -0700 Subject: [PATCH 04/21] checkpoint for handoff --- src/traceax/_estimators.py | 3 + src/traceax/_trace.py | 138 +++++++++++++++++++++++++++---------- 2 files changed, 103 insertions(+), 38 deletions(-) diff --git a/src/traceax/_estimators.py b/src/traceax/_estimators.py index 6e74fc6..35d5b9c 100644 --- a/src/traceax/_estimators.py +++ b/src/traceax/_estimators.py @@ -70,3 +70,6 @@ def estimate(self, state: _EstimatorState, k: int) -> tuple[PyTree[Array], dict[ def __call__(self, state: _EstimatorState, k: int) -> tuple[PyTree[Array], dict[str, Any]]: """An alias for `estimate`.""" return self.estimate(state, k) + + @abstractmethod + def transpose(self, state: _EstimatorState) -> _EstimatorState: ... diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py index 99c88fe..8dd2f45 100644 --- a/src/traceax/_trace.py +++ b/src/traceax/_trace.py @@ -18,22 +18,17 @@ import equinox as eqx import equinox.internal as eqxi +import jax import jax.lax as lax import jax.numpy as jnp import jax.random as rdm import jax.scipy as jsp import jax.tree_util as jtu +import lineax as lx -from jax.interpreters import ad as ad +from jax.interpreters import ad as ad, mlir as mlir from jax.numpy.linalg import norm from jaxtyping import Array, PRNGKeyArray, PyTree -from lineax import ( - AbstractLinearOperator, - DiagonalLinearOperator, - IdentityLinearOperator, - is_negative_semidefinite, - is_positive_semidefinite, -) from ._estimators import AbstractEstimator from ._samplers import AbstractSampler, RademacherSampler, SphereSampler @@ -53,8 +48,8 @@ def _get_scale(W: Array, D: Array, n: int, k: int) -> Array: return (n - k + 1) / (n - norm(W, axis=0) ** 2 + jnp.abs(D) ** 2) -_BasicTraceState: TypeAlias = tuple[PRNGKeyArray, AbstractLinearOperator, int] -_PSDTraceState: TypeAlias = tuple[PRNGKeyArray, AbstractLinearOperator, int, bool] +_BasicTraceState: TypeAlias = tuple[PRNGKeyArray, lx.AbstractLinearOperator, int] +_PSDTraceState: TypeAlias = tuple[PRNGKeyArray, lx.AbstractLinearOperator, int, bool] class HutchinsonEstimator(AbstractEstimator[_BasicTraceState], strict=True): @@ -67,7 +62,7 @@ class HutchinsonEstimator(AbstractEstimator[_BasicTraceState], strict=True): sampler: AbstractSampler = RademacherSampler() - def init(self, key: PRNGKeyArray, operator: AbstractLinearOperator) -> _BasicTraceState: + def init(self, key: PRNGKeyArray, operator: lx.AbstractLinearOperator) -> _BasicTraceState: n = _check_operator(operator) return (key, operator, n) @@ -87,6 +82,10 @@ def estimate(self, state: _BasicTraceState, k: int) -> tuple[PyTree[Array], dict return trace_est, {} + def transpose(self, state: _BasicTraceState) -> _BasicTraceState: + key, operator, n = state + return key, operator.transpose(), n + HutchinsonEstimator.__init__.__doc__ = r"""**Arguments:** @@ -112,7 +111,7 @@ class HutchPlusPlusEstimator(AbstractEstimator[_BasicTraceState], strict=True): sampler: AbstractSampler = RademacherSampler() - def init(self, key: PRNGKeyArray, operator: AbstractLinearOperator) -> _BasicTraceState: + def init(self, key: PRNGKeyArray, operator: lx.AbstractLinearOperator) -> _BasicTraceState: n = _check_operator(operator) return (key, operator, n) @@ -145,6 +144,10 @@ def estimate(self, state: _BasicTraceState, k: int) -> tuple[PyTree[Array], dict return trace_est, {} + def transpose(self, state: _BasicTraceState) -> _BasicTraceState: + key, operator, n = state + return key, operator.transpose(), n + HutchPlusPlusEstimator.__init__.__doc__ = r"""**Arguments:** @@ -175,7 +178,7 @@ class XTraceEstimator(AbstractEstimator[_BasicTraceState], strict=True): sampler: AbstractSampler = SphereSampler() improved: bool = True - def init(self, key: PRNGKeyArray, operator: AbstractLinearOperator) -> _BasicTraceState: + def init(self, key: PRNGKeyArray, operator: lx.AbstractLinearOperator) -> _BasicTraceState: n = _check_operator(operator) return (key, operator, n) @@ -224,6 +227,10 @@ def estimate(self, state: _BasicTraceState, k: int) -> tuple[PyTree[Array], dict return trace_est, {"std.err": std_err} + def transpose(self, state: _BasicTraceState) -> _BasicTraceState: + key, operator, n = state + return key, operator.transpose(), n + XTraceEstimator.__init__.__doc__ = r"""**Arguments:** @@ -252,10 +259,10 @@ class XNysTraceEstimator(AbstractEstimator[_PSDTraceState], strict=True): sampler: AbstractSampler = SphereSampler() improved: bool = True - def init(self, key: PRNGKeyArray, operator: AbstractLinearOperator) -> _PSDTraceState: + def init(self, key: PRNGKeyArray, operator: lx.AbstractLinearOperator) -> _PSDTraceState: n = _check_operator(operator) - is_nsd = is_negative_semidefinite(operator) - if not (is_positive_semidefinite(operator) | is_nsd): + is_nsd = lx.is_negative_semidefinite(operator) + if not (lx.is_positive_semidefinite(operator) | is_nsd): raise ValueError("`XNysTraceEstimator` may only be used for positive or negative definite linear operators") if is_nsd: operator = -operator @@ -307,6 +314,10 @@ def estimate(self, state: _PSDTraceState, k: int) -> tuple[PyTree[Array], dict[s return trace_est, {"std.err": std_err} + def transpose(self, state: _PSDTraceState) -> _PSDTraceState: + key, operator, n, is_nsd = state + return key, operator.transpose(), n, is_nsd + XNysTraceEstimator.__init__.__doc__ = r"""**Arguments:** @@ -351,17 +362,20 @@ def _estimate_trace_jvp(primals, tangents): # t_operator := V t_key, t_operator, t_state, t_k, t_estimator = tangents jtu.tree_map(_assert_false, (t_key, t_state, t_k, t_estimator)) + del t_key, t_state, t_k, t_estimator # primal problem of t = tr(A) result, stats = eqxi.filter_primitive_bind(_estimate_trace_p, key, operator, state, k, estimator) out = result, stats - # inner prodct in matrix space => = tr(A @ B) + # inner prodct in linear operator space => = tr(A @ B) # d tr(A) / dA = I # t' = = tr(I @ V) = tr(V) # tangent problem => tr(V) - t_state = estimator.init(key, t_operator) - t_result, t_stats = eqxi.filter_primitive_bind(_estimate_trace_p, key, t_operator, t_state, k, estimator) + # TODO: should we reuse key or split? both seem confusing options + key, t_key = rdm.split(key) + t_state = estimator.init(t_key, t_operator) + t_result, _ = eqxi.filter_primitive_bind(_estimate_trace_p, t_key, t_operator, t_state, k, estimator) t_out = ( t_result, @@ -371,37 +385,84 @@ def _estimate_trace_jvp(primals, tangents): return out, t_out +def _is_undefined(x): + return isinstance(x, ad.UndefinedPrimal) + + +def _assert_defined(x): + assert not _is_undefined(x) + + +def _remove_undefined_primal(x): + if _is_undefined(x): + return x.aval + else: + return + + +def _build_diagonal(ct_result: float, op: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator: + operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) + if isinstance(op, lx.MatrixLinearOperator): + in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) + diag = ct_result * jnp.ones(in_size) + return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags) + elif isinstance(op, lx.DiagonalLinearOperator): + in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) + diag = ct_result * jnp.ones(in_size) + return lx.DiagonalLinearOperator(diag) + elif isinstance(op, lx.MulLinearOperator): + inner_op = _build_diagonal(ct_result, op.operator) + scalar = op.scalar + return scalar * inner_op # type: ignore + else: + raise ValueError("Unsupported type!") + + @eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore def _estimate_trace_transpose(inputs, cts_out): - key, operator, state, k, estimator = inputs - cts_result, cts_stats = cts_out - # jtu.tree_map(_assert_defined, (key, operator, state, k, estimator), is_leaf=_is_undefined) - # op_t = cts_result * IdentityLinearOperator(operator.in_structure()) - op_t = cts_result * operator.transpose() - state_none = jtu.tree_map(lambda _: None, state) + # the jacobian, for the trace is just the identity matrix, i.e. J = I + # so J'v = I v = v + + # primal inputs; operator should have UndefinedPrimal leaves + key, operator, state, _, estimator = inputs + + # co-tangent of the trace approximation and the stats (None) + cts_result, _ = cts_out + + # the internals of the operator are UndefinedPrimal leaves so + # we need to rely on abstract values to pull structure info + op_t = _build_diagonal(cts_result, operator) + + key_none = jtu.tree_map(lambda _: None, key) + # state_none = jtu.tree_map(lambda _: None, state) + state_none = (None, op_t, None) k_none = None estimator_none = jtu.tree_map(lambda _: None, estimator) - return op_t, state_none, k_none, estimator_none + return key_none, op_t, state_none, k_none, estimator_none -_estimate_trace_p = eqxi.create_vprim( - "trace", - eqxi.filter_primitive_def(ft.partial(_estimate_trace_impl, check_closure=False)), +_noclosure_check_impl = (eqxi.filter_primitive_def(ft.partial(_estimate_trace_impl, check_closure=False)),) +_estimate_trace_p = jax.core.Primitive("trace") # type: ignore +_estimate_trace_p.multiple_results = True +_estimate_trace_p.def_impl(_noclosure_check_impl) +_estimate_trace_p.def_abstract_eval( _estimate_trace_abstract_eval, - _estimate_trace_jvp, - _estimate_trace_transpose, ) +ad.primitive_jvps[_estimate_trace_p] = _estimate_trace_jvp +ad.primitive_transposes[_estimate_trace_p] = _estimate_trace_transpose +mlir.register_lowering(_estimate_trace_p, mlir.lower_fun(_noclosure_check_impl, multiple_results=True)) # type: ignore + +# rebind here to allow closure checks _estimate_trace_p.def_impl( eqxi.filter_primitive_def(ft.partial(_estimate_trace_impl, check_closure=True)), ) -eqxi.register_impl_finalisation(_estimate_trace_p) # @eqx.filter_jit def trace( key: PRNGKeyArray, - operator: AbstractLinearOperator, + operator: lx.AbstractLinearOperator, k: int, estimator: AbstractEstimator = XTraceEstimator(), *, @@ -425,14 +486,14 @@ def trace( ) # if identity op, then just shortcircuit and return dimension size - if isinstance(operator, IdentityLinearOperator): + if isinstance(operator, lx.IdentityLinearOperator): return Solution( - value=float(in_size), + value=jnp.asarray(in_size, dtype=float), stats={}, state=state, ) # if diagonal op, then just shortcircuit and sum diagonal - if isinstance(operator, DiagonalLinearOperator): + if isinstance(operator, lx.DiagonalLinearOperator): return Solution( value=jnp.sum(operator.diagonal), stats={}, @@ -441,8 +502,9 @@ def trace( # set up state if necessary if state == sentinel: - key, s_key = rdm.split(key) - state = estimator.init(s_key, operator) + state = estimator.init(key, operator) + # we don't want to allow differntiate through trace-alg state, which likely contains the operator + # or by-products of the operator dynamic_state, static_state = eqx.partition(state, eqx.is_array) dynamic_state = lax.stop_gradient(dynamic_state) state = eqx.combine(dynamic_state, static_state) From 06a57defa73e79431fe14979a745418f01404aa9 Mon Sep 17 00:00:00 2001 From: nahid18 Date: Mon, 23 Sep 2024 09:52:34 -0700 Subject: [PATCH 05/21] test: tridiagonal linear operator --- tests/test_trace.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_trace.py b/tests/test_trace.py index eaabe19..d180a36 100644 --- a/tests/test_trace.py +++ b/tests/test_trace.py @@ -15,6 +15,7 @@ import pytest import jax.numpy as jnp +import jax.random as rdm import lineax as lx import traceax as tx @@ -94,3 +95,24 @@ def test_nsd_psd_matrix_linop(getkey, estimator, k, tags, size, dtype): assert result is not None assert result.value is not None assert jnp.isfinite(result.value) + + +@pytest.mark.parametrize( + "estimator", + (tx.HutchinsonEstimator(), tx.HutchPlusPlusEstimator(), tx.XTraceEstimator()), +) +@pytest.mark.parametrize("k", (5, 10, 50)) +@pytest.mark.parametrize("size", (5, 50, 500)) +@pytest.mark.parametrize("dtype", (jnp.float32, jnp.float64)) +def test_tridiagonal_linop(getkey, estimator, k, size, dtype): + k = min(k, size) + matrix = construct_matrix(getkey, lx.tridiagonal_tag, size, dtype) + main_diag = jnp.diag(matrix) + lower_diag = jnp.diag(matrix, k=-1) + upper_diag = jnp.diag(matrix, k=1) + operator = lx.TridiagonalLinearOperator(main_diag, lower_diag, upper_diag) + result = tx.trace(getkey(), operator, k, estimator) + + assert result is not None + assert result.value is not None + assert jnp.isfinite(result.value) From cba2d0d4c09a3a91f1652f3fd4f82dbe10cad09e Mon Sep 17 00:00:00 2001 From: nahid18 Date: Mon, 23 Sep 2024 09:58:48 -0700 Subject: [PATCH 06/21] chore: remove import --- tests/test_trace.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_trace.py b/tests/test_trace.py index d180a36..fff78ee 100644 --- a/tests/test_trace.py +++ b/tests/test_trace.py @@ -15,7 +15,6 @@ import pytest import jax.numpy as jnp -import jax.random as rdm import lineax as lx import traceax as tx From a691c30705e6085d12a92e829c64c10c07e61b71 Mon Sep 17 00:00:00 2001 From: nahid18 Date: Mon, 30 Sep 2024 13:51:16 -0700 Subject: [PATCH 07/21] test: identity linear operator --- tests/test_trace.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/tests/test_trace.py b/tests/test_trace.py index fff78ee..d0c2339 100644 --- a/tests/test_trace.py +++ b/tests/test_trace.py @@ -14,6 +14,7 @@ import pytest +import jax import jax.numpy as jnp import lineax as lx @@ -115,3 +116,21 @@ def test_tridiagonal_linop(getkey, estimator, k, size, dtype): assert result is not None assert result.value is not None assert jnp.isfinite(result.value) + + +@pytest.mark.parametrize( + "estimator", + (tx.HutchinsonEstimator(), tx.HutchPlusPlusEstimator(), tx.XTraceEstimator()), +) +@pytest.mark.parametrize("k", (5, 10, 50)) +@pytest.mark.parametrize("size", (5, 50, 500)) +@pytest.mark.parametrize("dtype", (jnp.float32, jnp.float64)) +def test_identity_linop(getkey, estimator, k, size, dtype): + k = min(k, size) + input_structure = jax.ShapeDtypeStruct((size,), dtype) + operator = lx.IdentityLinearOperator(input_structure) + result = tx.trace(getkey(), operator, k, estimator) + + assert result is not None + assert result.value is not None + assert jnp.isfinite(result.value) From d964667faf04d2f093bbf12657d4777619575f44 Mon Sep 17 00:00:00 2001 From: nahid18 Date: Mon, 30 Sep 2024 15:42:23 -0700 Subject: [PATCH 08/21] test: tagged linear operator --- tests/test_trace.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/tests/test_trace.py b/tests/test_trace.py index d0c2339..14dfd5f 100644 --- a/tests/test_trace.py +++ b/tests/test_trace.py @@ -134,3 +134,35 @@ def test_identity_linop(getkey, estimator, k, size, dtype): assert result is not None assert result.value is not None assert jnp.isfinite(result.value) + + +@pytest.mark.parametrize( + "estimator", + (tx.HutchinsonEstimator(), tx.HutchPlusPlusEstimator(), tx.XTraceEstimator()), +) +@pytest.mark.parametrize("k", (5, 10, 50)) +@pytest.mark.parametrize( + "tags", + ( + lx.diagonal_tag, + lx.symmetric_tag, + lx.lower_triangular_tag, + lx.upper_triangular_tag, + lx.tridiagonal_tag, + lx.unit_diagonal_tag, + lx.positive_semidefinite_tag, + lx.negative_semidefinite_tag, + ), +) +@pytest.mark.parametrize("size", (5, 50, 500)) +@pytest.mark.parametrize("dtype", (jnp.float32, jnp.float64)) +def test_tagged_linear_operator(getkey, estimator, k, tags, size, dtype): + k = min(k, size) + matrix = construct_matrix(getkey, tags, size, dtype) + operator = lx.MatrixLinearOperator(matrix, tags=tags) + tagged_operator = lx.TaggedLinearOperator(operator, tags=tags) + result = tx.trace(getkey(), tagged_operator, k, estimator) + + assert result is not None + assert result.value is not None + assert jnp.isfinite(result.value) \ No newline at end of file From 66a36d3810bbd7c39cdbce3b962a1b3abdc541ff Mon Sep 17 00:00:00 2001 From: nahid18 Date: Mon, 30 Sep 2024 15:49:01 -0700 Subject: [PATCH 09/21] chore: added trailing space --- tests/test_trace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_trace.py b/tests/test_trace.py index 14dfd5f..3adb1cb 100644 --- a/tests/test_trace.py +++ b/tests/test_trace.py @@ -165,4 +165,4 @@ def test_tagged_linear_operator(getkey, estimator, k, tags, size, dtype): assert result is not None assert result.value is not None - assert jnp.isfinite(result.value) \ No newline at end of file + assert jnp.isfinite(result.value) From 5775168b78887d47ceb79b8ef16c6bf1bc33b7b7 Mon Sep 17 00:00:00 2001 From: nahid18 Date: Wed, 2 Oct 2024 13:29:11 -0700 Subject: [PATCH 10/21] test: compound operators --- tests/test_trace.py | 75 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 72 insertions(+), 3 deletions(-) diff --git a/tests/test_trace.py b/tests/test_trace.py index 3adb1cb..1c81656 100644 --- a/tests/test_trace.py +++ b/tests/test_trace.py @@ -112,7 +112,7 @@ def test_tridiagonal_linop(getkey, estimator, k, size, dtype): upper_diag = jnp.diag(matrix, k=1) operator = lx.TridiagonalLinearOperator(main_diag, lower_diag, upper_diag) result = tx.trace(getkey(), operator, k, estimator) - + assert result is not None assert result.value is not None assert jnp.isfinite(result.value) @@ -130,7 +130,7 @@ def test_identity_linop(getkey, estimator, k, size, dtype): input_structure = jax.ShapeDtypeStruct((size,), dtype) operator = lx.IdentityLinearOperator(input_structure) result = tx.trace(getkey(), operator, k, estimator) - + assert result is not None assert result.value is not None assert jnp.isfinite(result.value) @@ -156,7 +156,7 @@ def test_identity_linop(getkey, estimator, k, size, dtype): ) @pytest.mark.parametrize("size", (5, 50, 500)) @pytest.mark.parametrize("dtype", (jnp.float32, jnp.float64)) -def test_tagged_linear_operator(getkey, estimator, k, tags, size, dtype): +def test_tagged_linop(getkey, estimator, k, tags, size, dtype): k = min(k, size) matrix = construct_matrix(getkey, tags, size, dtype) operator = lx.MatrixLinearOperator(matrix, tags=tags) @@ -166,3 +166,72 @@ def test_tagged_linear_operator(getkey, estimator, k, tags, size, dtype): assert result is not None assert result.value is not None assert jnp.isfinite(result.value) + + +@pytest.mark.parametrize( + "estimator", + (tx.HutchinsonEstimator(), tx.HutchPlusPlusEstimator(), tx.XTraceEstimator()), +) +@pytest.mark.parametrize( + "tags", + ( + None, + lx.diagonal_tag, + lx.symmetric_tag, + lx.lower_triangular_tag, + lx.upper_triangular_tag, + lx.tridiagonal_tag, + lx.unit_diagonal_tag, + lx.positive_semidefinite_tag, + lx.negative_semidefinite_tag, + ), +) +@pytest.mark.parametrize("k", (5, 10, 50)) +@pytest.mark.parametrize("size", (5, 50, 500)) +@pytest.mark.parametrize("dtype", (jnp.float32, jnp.float64)) +def test_compound_op(getkey, estimator, k, tags, size, dtype): + k = min(k, size) + matrix_a = construct_matrix(getkey, tags, size, dtype) + matrix_b = construct_matrix(getkey, tags, size, dtype) + op_a = lx.MatrixLinearOperator(matrix_a, tags=tags) + op_b = lx.MatrixLinearOperator(matrix_b, tags=tags) + + """AddLinearOperator""" + add_op = op_a + op_b + add_result = tx.trace(getkey(), add_op, k, estimator) + + assert add_result is not None + assert add_result.value is not None + assert jnp.isfinite(add_result.value) + + """ComposedLinearOperator""" + composed_op = op_a @ op_b + composed_result = tx.trace(getkey(), composed_op, k, estimator) + + assert composed_result is not None + assert composed_result.value is not None + assert jnp.isfinite(composed_result.value) + + """MulLinearOperator""" + scalar = 0.5 # random value + mul_op = scalar * op_a + mul_result = tx.trace(getkey(), mul_op, k, estimator) + + assert mul_result is not None + assert mul_result.value is not None + assert jnp.isfinite(mul_result.value) + + """NegLinearOperator""" + neg_result = tx.trace(getkey(), -op_a, k, estimator) + + assert neg_result is not None + assert neg_result.value is not None + assert jnp.isfinite(neg_result.value) + + """DivLinearOperator""" + denom = 0.5 # random value + div_result = tx.trace(getkey(), op_b/denom, k, estimator) + + assert div_result is not None + assert div_result.value is not None + assert jnp.isfinite(div_result.value) From 2b67f58096908427dff0d14ee8fe2bfca435941b Mon Sep 17 00:00:00 2001 From: nahid18 Date: Wed, 2 Oct 2024 13:39:28 -0700 Subject: [PATCH 11/21] test: tagged operator --- tests/test_trace.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_trace.py b/tests/test_trace.py index 1c81656..92398fc 100644 --- a/tests/test_trace.py +++ b/tests/test_trace.py @@ -167,6 +167,15 @@ def test_tagged_linop(getkey, estimator, k, tags, size, dtype): assert result.value is not None assert jnp.isfinite(result.value) + sym_operator = operator + operator.T + sym_operator = lx.TaggedLinearOperator(sym_operator, lx.symmetric_tag) + sym_result = tx.trace(getkey(), sym_operator, k, estimator) + + assert lx.is_symmetric(sym_operator) + assert sym_result is not None + assert sym_result.value is not None + assert jnp.isfinite(sym_result.value) + @pytest.mark.parametrize( "estimator", From f71250e71f9e19f1795d6976426f9b08df1be320 Mon Sep 17 00:00:00 2001 From: nahid18 Date: Thu, 3 Oct 2024 19:48:29 -0700 Subject: [PATCH 12/21] refactor build diagonal to a dispatcher --- src/traceax/_trace.py | 71 +++++++++++++++++++++++++++++++++---------- 1 file changed, 55 insertions(+), 16 deletions(-) diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py index 8dd2f45..6950b37 100644 --- a/src/traceax/_trace.py +++ b/src/traceax/_trace.py @@ -400,22 +400,60 @@ def _remove_undefined_primal(x): return -def _build_diagonal(ct_result: float, op: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator: +# def _build_diagonal(ct_result: float, op: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator: +# operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) +# if isinstance(op, lx.MatrixLinearOperator): +# in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) +# diag = ct_result * jnp.ones(in_size) +# return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags) +# elif isinstance(op, lx.DiagonalLinearOperator): +# in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) +# diag = ct_result * jnp.ones(in_size) +# return lx.DiagonalLinearOperator(diag) +# elif isinstance(op, lx.MulLinearOperator): +# inner_op = _build_diagonal(ct_result, op.operator) +# scalar = op.scalar +# return scalar * inner_op # type: ignore +# else: +# raise ValueError("Unsupported type!") + + +# replaces _build_diagonal +@ft.singledispatch +def _make_identity(op: lx.AbstractLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + raise ValueError("Unsupported type!") + + +@_make_identity.register +def _(op: lx.MatrixLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) - if isinstance(op, lx.MatrixLinearOperator): - in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) - diag = ct_result * jnp.ones(in_size) - return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags) - elif isinstance(op, lx.DiagonalLinearOperator): - in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) - diag = ct_result * jnp.ones(in_size) - return lx.DiagonalLinearOperator(diag) - elif isinstance(op, lx.MulLinearOperator): - inner_op = _build_diagonal(ct_result, op.operator) - scalar = op.scalar - return scalar * inner_op # type: ignore - else: - raise ValueError("Unsupported type!") + in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) + diag = ct_result * jnp.ones(in_size) + return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags) + + +@_make_identity.register +def _(op: lx.DiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) + in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) + diag = ct_result * jnp.ones(in_size) + return lx.DiagonalLinearOperator(diag) + + +@_make_identity.register +def _(op: lx.MulLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op = _make_identity(op.operator, ct_result) + scalar = op.scalar + return scalar * inner_op # type: ignore + + +@_make_identity.register +def _(op: lx.TridiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) + in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) + diag = ct_result * jnp.ones(in_size) + off_diag = jnp.zeros(in_size - 1) + return lx.TridiagonalLinearOperator(diag, off_diag, off_diag) @eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore @@ -431,7 +469,8 @@ def _estimate_trace_transpose(inputs, cts_out): # the internals of the operator are UndefinedPrimal leaves so # we need to rely on abstract values to pull structure info - op_t = _build_diagonal(cts_result, operator) + # op_t = _build_diagonal(cts_result, operator) + op_t = _make_identity(operator, cts_result) key_none = jtu.tree_map(lambda _: None, key) # state_none = jtu.tree_map(lambda _: None, state) From 4ac0977eaf599875309e6381c806827d671312d2 Mon Sep 17 00:00:00 2001 From: nahid18 Date: Sat, 5 Oct 2024 00:41:46 -0700 Subject: [PATCH 13/21] feat: add, neg, div, composed --- src/traceax/_trace.py | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py index 6950b37..2b73b2d 100644 --- a/src/traceax/_trace.py +++ b/src/traceax/_trace.py @@ -444,7 +444,7 @@ def _(op: lx.DiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOpera def _(op: lx.MulLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: inner_op = _make_identity(op.operator, ct_result) scalar = op.scalar - return scalar * inner_op # type: ignore + return lx.MulLinearOperator(inner_op, scalar) @_make_identity.register @@ -456,6 +456,33 @@ def _(op: lx.TridiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOp return lx.TridiagonalLinearOperator(diag, off_diag, off_diag) +@_make_identity.register +def _(op: lx.AddLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op1 = _make_identity(op.operator1, ct_result) + inner_op2 = _make_identity(op.operator2, ct_result) + return lx.AddLinearOperator(inner_op1, inner_op2) + + +@_make_identity.register +def _(op: lx.NegLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op = _make_identity(op.operator, ct_result) + return lx.NegLinearOperator(inner_op) + + +@_make_identity.register +def _(op: lx.DivLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op = _make_identity(op.operator, ct_result) + scalar = op.scalar + return lx.DivLinearOperator(inner_op, scalar) + + +@_make_identity.register +def _(op: lx.ComposedLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op1 = _make_identity(op.operator1, ct_result) + inner_op2 = _make_identity(op.operator2, ct_result) + return lx.ComposedLinearOperator(inner_op1, inner_op2) + + @eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore def _estimate_trace_transpose(inputs, cts_out): # the jacobian, for the trace is just the identity matrix, i.e. J = I From 445a1353a1b3a516bb974ec143139f450c77138d Mon Sep 17 00:00:00 2001 From: quattro Date: Tue, 8 Oct 2024 14:19:23 -0700 Subject: [PATCH 14/21] switch back to eqxi for primitive construction. simplified few ident ops --- src/traceax/_trace.py | 48 ++++++++++--------------------------------- 1 file changed, 11 insertions(+), 37 deletions(-) diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py index 2b73b2d..71cb14e 100644 --- a/src/traceax/_trace.py +++ b/src/traceax/_trace.py @@ -18,7 +18,6 @@ import equinox as eqx import equinox.internal as eqxi -import jax import jax.lax as lax import jax.numpy as jnp import jax.random as rdm @@ -376,7 +375,6 @@ def _estimate_trace_jvp(primals, tangents): key, t_key = rdm.split(key) t_state = estimator.init(t_key, t_operator) t_result, _ = eqxi.filter_primitive_bind(_estimate_trace_p, t_key, t_operator, t_state, k, estimator) - t_out = ( t_result, jtu.tree_map(lambda _: None, stats), @@ -400,25 +398,6 @@ def _remove_undefined_primal(x): return -# def _build_diagonal(ct_result: float, op: lx.AbstractLinearOperator) -> lx.AbstractLinearOperator: -# operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) -# if isinstance(op, lx.MatrixLinearOperator): -# in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) -# diag = ct_result * jnp.ones(in_size) -# return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags) -# elif isinstance(op, lx.DiagonalLinearOperator): -# in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) -# diag = ct_result * jnp.ones(in_size) -# return lx.DiagonalLinearOperator(diag) -# elif isinstance(op, lx.MulLinearOperator): -# inner_op = _build_diagonal(ct_result, op.operator) -# scalar = op.scalar -# return scalar * inner_op # type: ignore -# else: -# raise ValueError("Unsupported type!") - - -# replaces _build_diagonal @ft.singledispatch def _make_identity(op: lx.AbstractLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: raise ValueError("Unsupported type!") @@ -428,7 +407,7 @@ def _make_identity(op: lx.AbstractLinearOperator, ct_result: float) -> lx.Abstra def _(op: lx.MatrixLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) - diag = ct_result * jnp.ones(in_size) + diag = jnp.full(in_size, ct_result) return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags) @@ -436,14 +415,14 @@ def _(op: lx.MatrixLinearOperator, ct_result: float) -> lx.AbstractLinearOperato def _(op: lx.DiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) - diag = ct_result * jnp.ones(in_size) + diag = jnp.full(in_size, ct_result) return lx.DiagonalLinearOperator(diag) @_make_identity.register def _(op: lx.MulLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: inner_op = _make_identity(op.operator, ct_result) - scalar = op.scalar + scalar = jnp.array(1.0) return lx.MulLinearOperator(inner_op, scalar) @@ -451,7 +430,7 @@ def _(op: lx.MulLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: def _(op: lx.TridiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) - diag = ct_result * jnp.ones(in_size) + diag = jnp.full(in_size, ct_result) off_diag = jnp.zeros(in_size - 1) return lx.TridiagonalLinearOperator(diag, off_diag, off_diag) @@ -496,11 +475,9 @@ def _estimate_trace_transpose(inputs, cts_out): # the internals of the operator are UndefinedPrimal leaves so # we need to rely on abstract values to pull structure info - # op_t = _build_diagonal(cts_result, operator) op_t = _make_identity(operator, cts_result) key_none = jtu.tree_map(lambda _: None, key) - # state_none = jtu.tree_map(lambda _: None, state) state_none = (None, op_t, None) k_none = None estimator_none = jtu.tree_map(lambda _: None, estimator) @@ -508,24 +485,21 @@ def _estimate_trace_transpose(inputs, cts_out): return key_none, op_t, state_none, k_none, estimator_none -_noclosure_check_impl = (eqxi.filter_primitive_def(ft.partial(_estimate_trace_impl, check_closure=False)),) -_estimate_trace_p = jax.core.Primitive("trace") # type: ignore -_estimate_trace_p.multiple_results = True -_estimate_trace_p.def_impl(_noclosure_check_impl) -_estimate_trace_p.def_abstract_eval( +_estimate_trace_p = eqxi.create_vprim( + "trace", + eqxi.filter_primitive_def(ft.partial(_estimate_trace_impl, check_closure=False)), _estimate_trace_abstract_eval, + _estimate_trace_jvp, + _estimate_trace_transpose, ) -ad.primitive_jvps[_estimate_trace_p] = _estimate_trace_jvp -ad.primitive_transposes[_estimate_trace_p] = _estimate_trace_transpose -mlir.register_lowering(_estimate_trace_p, mlir.lower_fun(_noclosure_check_impl, multiple_results=True)) # type: ignore - # rebind here to allow closure checks _estimate_trace_p.def_impl( eqxi.filter_primitive_def(ft.partial(_estimate_trace_impl, check_closure=True)), ) +eqxi.register_impl_finalisation(_estimate_trace_p) -# @eqx.filter_jit +@eqx.filter_jit def trace( key: PRNGKeyArray, operator: lx.AbstractLinearOperator, From a391d9b2fc5f5614fdf97935a7d49001d84e9399 Mon Sep 17 00:00:00 2001 From: nahid18 Date: Wed, 9 Oct 2024 15:51:24 -0700 Subject: [PATCH 15/21] comment out filter jit --- src/traceax/_trace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py index 71cb14e..46be46b 100644 --- a/src/traceax/_trace.py +++ b/src/traceax/_trace.py @@ -499,7 +499,7 @@ def _estimate_trace_transpose(inputs, cts_out): eqxi.register_impl_finalisation(_estimate_trace_p) -@eqx.filter_jit +# @eqx.filter_jit def trace( key: PRNGKeyArray, operator: lx.AbstractLinearOperator, From 8a8fdacc01d212730698ec808cb59a674205b283 Mon Sep 17 00:00:00 2001 From: quattro Date: Wed, 9 Oct 2024 16:52:22 -0700 Subject: [PATCH 16/21] fixed dtype issues where compiled and run dtypes mismatch --- tests/test_trace.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_trace.py b/tests/test_trace.py index 92398fc..d2d6e21 100644 --- a/tests/test_trace.py +++ b/tests/test_trace.py @@ -222,9 +222,9 @@ def test_compound_op(getkey, estimator, k, tags, size, dtype): assert jnp.isfinite(composed_result.value) """MulLinearOperator""" - scalar = 0.5 # random value + scalar = jnp.asarray(0.5, dtype=dtype) # random value; make sure precision matches specified! mul_op = scalar * op_a - mul_result = tx.trace(getkey(), mul_op, k, estimator) + mul_result = tx.trace(getkey(), mul_op, k, estimator) # pyright: ignore assert mul_result is not None assert mul_result.value is not None @@ -238,8 +238,8 @@ def test_compound_op(getkey, estimator, k, tags, size, dtype): assert jnp.isfinite(neg_result.value) """DivLinearOperator""" - denom = 0.5 # random value - div_result = tx.trace(getkey(), op_b/denom, k, estimator) + denom = jnp.asarray(0.5, dtype=dtype) # random value; make sure precision matches specified! + div_result = tx.trace(getkey(), op_b / denom, k, estimator) assert div_result is not None assert div_result.value is not None From e65a14aa2768bdd36587beb3fe857d2f028634a4 Mon Sep 17 00:00:00 2001 From: nahid18 Date: Thu, 10 Oct 2024 18:45:03 -0700 Subject: [PATCH 17/21] tangent linear operator and materialise zeros --- src/traceax/_trace.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py index 46be46b..6390d6c 100644 --- a/src/traceax/_trace.py +++ b/src/traceax/_trace.py @@ -18,6 +18,7 @@ import equinox as eqx import equinox.internal as eqxi +import jax.debug as jdg import jax.lax as lax import jax.numpy as jnp import jax.random as rdm @@ -355,6 +356,10 @@ def _estimate_trace_abstract_eval(key, operator, state, k, estimator): return out +def _is_none(x): + return x is None + + @eqxi.filter_primitive_jvp def _estimate_trace_jvp(primals, tangents): key, operator, state, k, estimator = primals @@ -373,6 +378,10 @@ def _estimate_trace_jvp(primals, tangents): # tangent problem => tr(V) # TODO: should we reuse key or split? both seem confusing options key, t_key = rdm.split(key) + if any(t is not None for t in jtu.tree_leaves(t_operator, is_leaf=_is_none)): + t_operator = jtu.tree_map(eqxi.materialise_zeros, operator, t_operator, is_leaf=_is_none) + t_operator = lx.TangentLinearOperator(operator, t_operator) + t_state = estimator.init(t_key, t_operator) t_result, _ = eqxi.filter_primitive_bind(_estimate_trace_p, t_key, t_operator, t_state, k, estimator) t_out = ( @@ -395,7 +404,7 @@ def _remove_undefined_primal(x): if _is_undefined(x): return x.aval else: - return + return x @ft.singledispatch @@ -403,6 +412,13 @@ def _make_identity(op: lx.AbstractLinearOperator, ct_result: float) -> lx.Abstra raise ValueError("Unsupported type!") +@_make_identity.register +def _(op: lx.TangentLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + p_op = op.primal + t_op = op.tangent + return lx.TangentLinearOperator(p_op, _make_identity(t_op, ct_result)) + + @_make_identity.register def _(op: lx.MatrixLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) @@ -499,7 +515,7 @@ def _estimate_trace_transpose(inputs, cts_out): eqxi.register_impl_finalisation(_estimate_trace_p) -# @eqx.filter_jit +@eqx.filter_jit def trace( key: PRNGKeyArray, operator: lx.AbstractLinearOperator, From 803a2d97e105c4d2c0db307d2f56960ee369cb13 Mon Sep 17 00:00:00 2001 From: nahid18 Date: Thu, 10 Oct 2024 18:46:54 -0700 Subject: [PATCH 18/21] chore: remove import --- src/traceax/_trace.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py index 6390d6c..f6296aa 100644 --- a/src/traceax/_trace.py +++ b/src/traceax/_trace.py @@ -18,7 +18,6 @@ import equinox as eqx import equinox.internal as eqxi -import jax.debug as jdg import jax.lax as lax import jax.numpy as jnp import jax.random as rdm From dc5890b01a29528c8bb6ff52b60772c5459a2d83 Mon Sep 17 00:00:00 2001 From: nahid18 Date: Mon, 14 Oct 2024 15:05:47 -0700 Subject: [PATCH 19/21] chore: organize --- src/traceax/_trace.py | 86 ++++++++++++++++++------------------------- src/traceax/_utils.py | 7 ++++ 2 files changed, 42 insertions(+), 51 deletions(-) diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py index f6296aa..ca6fe9a 100644 --- a/src/traceax/_trace.py +++ b/src/traceax/_trace.py @@ -36,6 +36,9 @@ _assert_false, _check_operator, _clip_k, + _is_none, + _is_undefined, + _remove_undefined_primal, _to_shapedarray, _to_struct, _vmap_mv, @@ -355,57 +358,6 @@ def _estimate_trace_abstract_eval(key, operator, state, k, estimator): return out -def _is_none(x): - return x is None - - -@eqxi.filter_primitive_jvp -def _estimate_trace_jvp(primals, tangents): - key, operator, state, k, estimator = primals - # t_operator := V - t_key, t_operator, t_state, t_k, t_estimator = tangents - jtu.tree_map(_assert_false, (t_key, t_state, t_k, t_estimator)) - del t_key, t_state, t_k, t_estimator - - # primal problem of t = tr(A) - result, stats = eqxi.filter_primitive_bind(_estimate_trace_p, key, operator, state, k, estimator) - out = result, stats - - # inner prodct in linear operator space => = tr(A @ B) - # d tr(A) / dA = I - # t' = = tr(I @ V) = tr(V) - # tangent problem => tr(V) - # TODO: should we reuse key or split? both seem confusing options - key, t_key = rdm.split(key) - if any(t is not None for t in jtu.tree_leaves(t_operator, is_leaf=_is_none)): - t_operator = jtu.tree_map(eqxi.materialise_zeros, operator, t_operator, is_leaf=_is_none) - t_operator = lx.TangentLinearOperator(operator, t_operator) - - t_state = estimator.init(t_key, t_operator) - t_result, _ = eqxi.filter_primitive_bind(_estimate_trace_p, t_key, t_operator, t_state, k, estimator) - t_out = ( - t_result, - jtu.tree_map(lambda _: None, stats), - ) - - return out, t_out - - -def _is_undefined(x): - return isinstance(x, ad.UndefinedPrimal) - - -def _assert_defined(x): - assert not _is_undefined(x) - - -def _remove_undefined_primal(x): - if _is_undefined(x): - return x.aval - else: - return x - - @ft.singledispatch def _make_identity(op: lx.AbstractLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: raise ValueError("Unsupported type!") @@ -477,6 +429,38 @@ def _(op: lx.ComposedLinearOperator, ct_result: float) -> lx.AbstractLinearOpera return lx.ComposedLinearOperator(inner_op1, inner_op2) +@eqxi.filter_primitive_jvp +def _estimate_trace_jvp(primals, tangents): + key, operator, state, k, estimator = primals + # t_operator := V + t_key, t_operator, t_state, t_k, t_estimator = tangents + jtu.tree_map(_assert_false, (t_key, t_state, t_k, t_estimator)) + del t_key, t_state, t_k, t_estimator + + # primal problem of t = tr(A) + result, stats = eqxi.filter_primitive_bind(_estimate_trace_p, key, operator, state, k, estimator) + out = result, stats + + # inner prodct in linear operator space => = tr(A @ B) + # d tr(A) / dA = I + # t' = = tr(I @ V) = tr(V) + # tangent problem => tr(V) + # TODO: should we reuse key or split? both seem confusing options + key, t_key = rdm.split(key) + if any(t is not None for t in jtu.tree_leaves(t_operator, is_leaf=_is_none)): + t_operator = jtu.tree_map(eqxi.materialise_zeros, operator, t_operator, is_leaf=_is_none) + t_operator = lx.TangentLinearOperator(operator, t_operator) + + t_state = estimator.init(t_key, t_operator) + t_result, _ = eqxi.filter_primitive_bind(_estimate_trace_p, t_key, t_operator, t_state, k, estimator) + t_out = ( + t_result, + jtu.tree_map(lambda _: None, stats), + ) + + return out, t_out + + @eqxi.filter_primitive_transpose(materialise_zeros=True) # pyright: ignore def _estimate_trace_transpose(inputs, cts_out): # the jacobian, for the trace is just the identity matrix, i.e. J = I diff --git a/src/traceax/_utils.py b/src/traceax/_utils.py index ffae138..b5e50b3 100644 --- a/src/traceax/_utils.py +++ b/src/traceax/_utils.py @@ -81,3 +81,10 @@ def _keep_undefined(v, ct): return ct else: return None + + +def _remove_undefined_primal(x): + if _is_undefined(x): + return x.aval + else: + return x \ No newline at end of file From 7f32881997925f8814cccfa5a95a2d5e662aac2e Mon Sep 17 00:00:00 2001 From: nahid18 Date: Thu, 14 Nov 2024 13:22:56 -0800 Subject: [PATCH 20/21] chore: rearrange code --- src/traceax/_trace.py | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py index ca6fe9a..a396500 100644 --- a/src/traceax/_trace.py +++ b/src/traceax/_trace.py @@ -363,6 +363,24 @@ def _make_identity(op: lx.AbstractLinearOperator, ct_result: float) -> lx.Abstra raise ValueError("Unsupported type!") +@_make_identity.register +def _(op: lx.MatrixLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) + in_size, out_size = eqx.filter_eval_shape(lambda o: (o.in_size(), o.out_size()), operator_struct) + if in_size != out_size: + raise ValueError("`_make_identity` only supports square matrices.") + diag = jnp.full(in_size, ct_result) + out = lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags) + return out + + +@_make_identity.register +def _(op: lx.MulLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: + inner_op = _make_identity(op.operator, ct_result) + scalar = jnp.array(1.0) + return lx.MulLinearOperator(inner_op, scalar*ct_result) + + @_make_identity.register def _(op: lx.TangentLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: p_op = op.primal @@ -370,14 +388,6 @@ def _(op: lx.TangentLinearOperator, ct_result: float) -> lx.AbstractLinearOperat return lx.TangentLinearOperator(p_op, _make_identity(t_op, ct_result)) -@_make_identity.register -def _(op: lx.MatrixLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: - operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) - in_size = eqx.filter_eval_shape(lambda o: o.in_size(), operator_struct) - diag = jnp.full(in_size, ct_result) - return lx.MatrixLinearOperator(jnp.diag(diag), tags=operator_struct.tags) - - @_make_identity.register def _(op: lx.DiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) @@ -386,13 +396,6 @@ def _(op: lx.DiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOpera return lx.DiagonalLinearOperator(diag) -@_make_identity.register -def _(op: lx.MulLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: - inner_op = _make_identity(op.operator, ct_result) - scalar = jnp.array(1.0) - return lx.MulLinearOperator(inner_op, scalar) - - @_make_identity.register def _(op: lx.TridiagonalLinearOperator, ct_result: float) -> lx.AbstractLinearOperator: operator_struct = jtu.tree_map(_remove_undefined_primal, op, is_leaf=_is_undefined) @@ -447,12 +450,13 @@ def _estimate_trace_jvp(primals, tangents): # tangent problem => tr(V) # TODO: should we reuse key or split? both seem confusing options key, t_key = rdm.split(key) - if any(t is not None for t in jtu.tree_leaves(t_operator, is_leaf=_is_none)): - t_operator = jtu.tree_map(eqxi.materialise_zeros, operator, t_operator, is_leaf=_is_none) - t_operator = lx.TangentLinearOperator(operator, t_operator) + + t_operator = jtu.tree_map(eqxi.materialise_zeros, operator, t_operator, is_leaf=_is_none) + t_operator = lx.TangentLinearOperator(operator, t_operator) t_state = estimator.init(t_key, t_operator) t_result, _ = eqxi.filter_primitive_bind(_estimate_trace_p, t_key, t_operator, t_state, k, estimator) + # t_result = jnp.trace(t_operator.as_matrix()) t_out = ( t_result, jtu.tree_map(lambda _: None, stats), From ee1b3d7c1e467a23891de9d389a4c6bbcfabe8da Mon Sep 17 00:00:00 2001 From: nahid18 Date: Thu, 13 Mar 2025 02:58:35 -0700 Subject: [PATCH 21/21] fix: scale bug in XTrace --- src/traceax/_trace.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/traceax/_trace.py b/src/traceax/_trace.py index a396500..7e5688e 100644 --- a/src/traceax/_trace.py +++ b/src/traceax/_trace.py @@ -219,7 +219,7 @@ def estimate(self, state: _BasicTraceState, k: int) -> tuple[PyTree[Array], dict term3 = jnp.conjugate(SW_d) * jnp.sum(S * (R - HW), axis=0) if self.improved: - scale = _get_scale(W, SW_d, n, k) + scale = _get_scale(W, SW_d, n, m) else: scale = 1