Skip to content

Commit d11b2d7

Browse files
committed
Use jax.scipy.sparse.linalg.eigs instead of scipy version
1 parent a188055 commit d11b2d7

File tree

1 file changed

+20
-8
lines changed

1 file changed

+20
-8
lines changed

varipeps/ctmrg/routine.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import numpy as np
55
from scipy.sparse.linalg import LinearOperator, eigs
66

7+
import jax
78
import jax.numpy as jnp
9+
import jax.scipy as jsp
810
from jax import jit, custom_vjp, vjp, tree_util
911
from jax.lax import cond, while_loop
1012
import jax.debug as jdebug
@@ -1048,14 +1050,24 @@ def cond_func(carry):
10481050
)
10491051
else:
10501052
if config.ad_custom_fixed_point_method is Grad_Fixed_Point_Method.EIGEN_SOLVER:
1051-
env_fixed_point, arnoldi_worked = jax.pure_callback(
1052-
_ctmrg_rev_arnoldi,
1053-
jax.eval_shape(lambda x: (x, True), new_unitcell_bar),
1054-
vjp(
1055-
lambda u: do_absorption_step(peps_tensors, u, config, state),
1056-
new_unitcell,
1057-
)[1],
1058-
new_unitcell_bar,
1053+
def f_arnoldi(x):
1054+
w = vjp_env((x[0], jnp.array(0, dtype=jnp.float64)))[0]
1055+
w = jax.tree.map(lambda v1, v2: v1 + x[1] * v2, w, new_unitcell_bar)
1056+
return (w, x[1])
1057+
1058+
eigval, eigvec = jsp.sparse.linalg.eigs(
1059+
f_arnoldi, 1, (new_unitcell_bar, 1.0)
1060+
)
1061+
1062+
print_debug("Eigval: {}", eigval)
1063+
1064+
env_fixed_point = jax.tree.map(lambda v: jnp.real(v[..., 0]), eigvec[0])
1065+
1066+
env_fixed_point, arnoldi_worked = cond(
1067+
jnp.real(eigvec[1][0]) >= 1e-10,
1068+
lambda x: (jax.tree.map(lambda v: v / jnp.real(eigvec[1][0]), x), True),
1069+
lambda x: (x, False),
1070+
env_fixed_point,
10591071
)
10601072
else:
10611073
env_fixed_point = new_unitcell_bar

0 commit comments

Comments
 (0)