|
4 | 4 | import numpy as np |
5 | 5 | from scipy.sparse.linalg import LinearOperator, eigs |
6 | 6 |
|
| 7 | +import jax |
7 | 8 | import jax.numpy as jnp |
| 9 | +import jax.scipy as jsp |
8 | 10 | from jax import jit, custom_vjp, vjp, tree_util |
9 | 11 | from jax.lax import cond, while_loop |
10 | 12 | import jax.debug as jdebug |
@@ -1048,14 +1050,24 @@ def cond_func(carry): |
1048 | 1050 | ) |
1049 | 1051 | else: |
1050 | 1052 | if config.ad_custom_fixed_point_method is Grad_Fixed_Point_Method.EIGEN_SOLVER: |
1051 | | - env_fixed_point, arnoldi_worked = jax.pure_callback( |
1052 | | - _ctmrg_rev_arnoldi, |
1053 | | - jax.eval_shape(lambda x: (x, True), new_unitcell_bar), |
1054 | | - vjp( |
1055 | | - lambda u: do_absorption_step(peps_tensors, u, config, state), |
1056 | | - new_unitcell, |
1057 | | - )[1], |
1058 | | - new_unitcell_bar, |
| 1053 | + def f_arnoldi(x): |
| 1054 | + w = vjp_env((x[0], jnp.array(0, dtype=jnp.float64)))[0] |
| 1055 | + w = jax.tree.map(lambda v1, v2: v1 + x[1] * v2, w, new_unitcell_bar) |
| 1056 | + return (w, x[1]) |
| 1057 | + |
| 1058 | + eigval, eigvec = jsp.sparse.linalg.eigs( |
| 1059 | + f_arnoldi, 1, (new_unitcell_bar, 1.0) |
| 1060 | + ) |
| 1061 | + |
| 1062 | + print_debug("Eigval: {}", eigval) |
| 1063 | + |
| 1064 | + env_fixed_point = jax.tree.map(lambda v: jnp.real(v[..., 0]), eigvec[0]) |
| 1065 | + |
| 1066 | + env_fixed_point, arnoldi_worked = cond( |
| 1067 | + jnp.real(eigvec[1][0]) >= 1e-10, |
| 1068 | + lambda x: (jax.tree.map(lambda v: v / jnp.real(eigvec[1][0]), x), True), |
| 1069 | + lambda x: (x, False), |
| 1070 | + env_fixed_point, |
1059 | 1071 | ) |
1060 | 1072 | else: |
1061 | 1073 | env_fixed_point = new_unitcell_bar |
|
0 commit comments