Skip to content

Commit e301cd8

Browse files
committed
Fix calculation of fixed point gradient for linear and eigensolver if complex tensors
1 parent d11b2d7 commit e301cd8

File tree

1 file changed

+116
-65
lines changed

1 file changed

+116
-65
lines changed

varipeps/ctmrg/routine.py

Lines changed: 116 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -951,51 +951,6 @@ def _ctmrg_rev_while_body(carry):
951951
return vjp_env, initial_bar, bar_fixed_point, converged, count, config, state
952952

953953

954-
def _ctmrg_rev_arnoldi(vjp_operator, initial_v):
955-
v, v_treedev = jax.tree.flatten(initial_v)
956-
v_flat = np.concatenate([i.reshape(-1) for i in v])
957-
958-
def matvec(vec):
959-
new_vec = [None] * len(v)
960-
for i in range(len(v)):
961-
i_start = sum(j.size for j in v[:i])
962-
elems = slice(i_start, i_start + v[i].size)
963-
new_vec[i] = vec[elems].astype(v[i].dtype).reshape(v[i].shape)
964-
new_vec = jax.tree.unflatten(v_treedev, new_vec)
965-
966-
new_vec = vjp_operator((new_vec, jnp.array(0, dtype=jnp.float64)))[0]
967-
968-
new_vec, _ = jax.tree.flatten(new_vec)
969-
new_vec = np.concatenate([i.reshape(-1) for i in new_vec])
970-
971-
return np.append(new_vec + vec[-1] * v_flat, vec[-1])
972-
973-
lin_op = LinearOperator(
974-
(v_flat.shape[0] + 1, v_flat.shape[0] + 1),
975-
matvec=matvec,
976-
)
977-
978-
_, vec = eigs(
979-
lin_op, k=1, v0=np.append(v_flat, np.array(1, dtype=v_flat.dtype)), which="LM"
980-
)
981-
982-
vec = vec.reshape(-1)
983-
984-
if np.abs(vec[-1]) >= 1e-10:
985-
vec /= vec[-1]
986-
987-
result = [None] * len(v)
988-
for i in range(len(v)):
989-
i_start = sum(j.size for j in v[:i])
990-
elems = slice(i_start, i_start + v[i].size)
991-
result[i] = vec[elems].astype(v[i].dtype).reshape(v[i].shape)
992-
993-
if np.abs(vec[-1]) < 1e-10:
994-
return jax.tree.unflatten(v_treedev, result), False
995-
996-
return jax.tree.unflatten(v_treedev, result), True
997-
998-
999954
@jit
1000955
def _ctmrg_rev_workhorse(peps_tensors, new_unitcell, new_unitcell_bar, config, state):
1001956
if new_unitcell.is_triangular_peps():
@@ -1049,34 +1004,110 @@ def cond_func(carry):
10491004
(vjp_env, new_unitcell_bar, new_unitcell_bar, False, 0, config, state),
10501005
)
10511006
else:
1007+
real = jax.dtypes.result_type(
1008+
*jax.tree.leaves(new_unitcell_bar)
1009+
) == jax.dtypes.canonicalize_dtype(jnp.float64)
10521010
if config.ad_custom_fixed_point_method is Grad_Fixed_Point_Method.EIGEN_SOLVER:
1011+
10531012
def f_arnoldi(x):
1054-
w = vjp_env((x[0], jnp.array(0, dtype=jnp.float64)))[0]
1013+
w = x[0]
1014+
if not real:
1015+
w = jax.tree.map(lambda x, y: x + 1j * y, w[0], w[1])
1016+
1017+
w = vjp_env((w, jnp.array(0, dtype=jnp.float64)))[0]
10551018
w = jax.tree.map(lambda v1, v2: v1 + x[1] * v2, w, new_unitcell_bar)
1056-
return (w, x[1])
10571019

1058-
eigval, eigvec = jsp.sparse.linalg.eigs(
1059-
f_arnoldi, 1, (new_unitcell_bar, 1.0)
1060-
)
1020+
if not real:
1021+
w = (
1022+
jax.tree.map(lambda x: jnp.real(x), w),
1023+
jax.tree.map(lambda x: jnp.imag(x), w),
1024+
)
10611025

1062-
print_debug("Eigval: {}", eigval)
1026+
return (w, x[1])
10631027

1064-
env_fixed_point = jax.tree.map(lambda v: jnp.real(v[..., 0]), eigvec[0])
1028+
if real:
1029+
eigval, eigvec = jsp.sparse.linalg.eigs(
1030+
f_arnoldi, 1, (new_unitcell_bar, 1.0)
1031+
)
1032+
else:
1033+
eigval, eigvec = jsp.sparse.linalg.eigs(
1034+
f_arnoldi,
1035+
1,
1036+
(
1037+
(
1038+
jax.tree.map(lambda x: jnp.real(x), new_unitcell_bar),
1039+
jax.tree.map(lambda x: jnp.imag(x), new_unitcell_bar),
1040+
),
1041+
1.0,
1042+
),
1043+
)
10651044

1066-
env_fixed_point, arnoldi_worked = cond(
1067-
jnp.real(eigvec[1][0]) >= 1e-10,
1068-
lambda x: (jax.tree.map(lambda v: v / jnp.real(eigvec[1][0]), x), True),
1069-
lambda x: (x, False),
1070-
env_fixed_point,
1045+
converged = cond(
1046+
jnp.logical_and(
1047+
jnp.abs(jnp.real(eigval[0]))
1048+
< (1 + 1e-2 * config.ad_custom_convergence_eps),
1049+
jnp.abs(jnp.imag(eigval[0]))
1050+
< 1e-2 * config.ad_custom_convergence_eps,
1051+
),
1052+
lambda: True,
1053+
lambda: False,
10711054
)
1055+
1056+
if config.ad_custom_verbose_output:
1057+
debug_print(
1058+
"AD: Converged: {}, Eigval: {}, Eigvec[1]: {}",
1059+
converged,
1060+
eigval[0],
1061+
eigvec[1][0],
1062+
)
1063+
1064+
if real:
1065+
env_fixed_point = jax.tree.map(lambda v: jnp.real(v[..., 0]), eigvec[0])
1066+
env_fixed_point, arnoldi_worked = cond(
1067+
jnp.logical_and(
1068+
converged,
1069+
jnp.abs(eigvec[1][0])
1070+
>= 1e-2 * config.ad_custom_convergence_eps,
1071+
),
1072+
lambda x: (
1073+
jax.tree.map(lambda v: v / jnp.real(eigvec[1][0]), x),
1074+
True,
1075+
),
1076+
lambda x: (x, False),
1077+
env_fixed_point,
1078+
)
1079+
else:
1080+
env_fixed_point = jax.tree.map(
1081+
lambda v, w: v[..., 0] + 1j * w[..., 0], eigvec[0][0], eigvec[0][1]
1082+
)
1083+
env_fixed_point, arnoldi_worked = cond(
1084+
jnp.logical_and(
1085+
converged,
1086+
jnp.abs(eigvec[1][0])
1087+
>= 1e-2 * config.ad_custom_convergence_eps,
1088+
),
1089+
lambda x: (
1090+
jax.tree.map(lambda v: v / jnp.real(eigvec[1][0]), x),
1091+
True,
1092+
),
1093+
lambda x: (x, False),
1094+
env_fixed_point,
1095+
)
10721096
else:
10731097
env_fixed_point = new_unitcell_bar
10741098
arnoldi_worked = False
1099+
converged = True
10751100

10761101
end_count = 0
10771102

10781103
def run_gmres(v, e):
1104+
if config.ad_custom_verbose_output:
1105+
debug_print("AD: Computing gradient with GMRES")
1106+
10791107
def f_gmres(w):
1108+
if not real:
1109+
w = jax.tree.map(lambda x, y: x + 1j * y, w[0], w[1])
1110+
10801111
new_w = vjp_env((w, jnp.array(0, dtype=jnp.float64)))[0]
10811112

10821113
new_w = new_w.replace_unique_tensors(
@@ -1090,27 +1121,47 @@ def f_gmres(w):
10901121
]
10911122
)
10921123

1124+
if not real:
1125+
new_w = (
1126+
jax.tree.map(lambda x: jnp.real(x), new_w),
1127+
jax.tree.map(lambda x: jnp.imag(x), new_w),
1128+
)
1129+
10931130
return new_w
10941131

10951132
is_gpu = jax.default_backend() == "gpu"
10961133

1134+
if real:
1135+
v0 = new_unitcell_bar
1136+
else:
1137+
v0 = (
1138+
jax.tree.map(lambda x: jnp.real(x), new_unitcell_bar),
1139+
jax.tree.map(lambda x: jnp.imag(x), new_unitcell_bar),
1140+
)
1141+
10971142
v, e = jax.scipy.sparse.linalg.gmres(
10981143
f_gmres,
1099-
new_unitcell_bar,
1100-
new_unitcell_bar,
1144+
v0,
1145+
v0,
11011146
solve_method="batched" if is_gpu else "incremental",
11021147
atol=config.ad_custom_convergence_eps,
1103-
maxiter=config.ad_custom_max_steps,
1148+
# maxiter=config.ad_custom_max_steps,
11041149
)
11051150

1151+
if not real:
1152+
v = jax.tree.map(lambda x, y: x + 1j * y, v[0], v[1])
1153+
11061154
return v, e
11071155

1108-
env_fixed_point, end_count = jax.lax.cond(
1109-
arnoldi_worked, lambda x, e: (x, e), run_gmres, env_fixed_point, end_count
1156+
env_fixed_point, end_count, converged = jax.lax.cond(
1157+
jnp.logical_and(converged, jnp.logical_not(arnoldi_worked)),
1158+
lambda x, ec, c: (*run_gmres(x, ec), True),
1159+
lambda x, ec, c: (x, ec, c),
1160+
env_fixed_point,
1161+
end_count,
1162+
converged,
11101163
)
11111164

1112-
converged = True
1113-
11141165
(t_bar,) = vjp_peps_tensors((env_fixed_point, jnp.array(0, dtype=jnp.float64)))
11151166

11161167
return t_bar, converged, end_count
@@ -1135,7 +1186,7 @@ def calc_ctmrg_env_rev(
11351186

11361187
varipeps_global_state.ctmrg_effective_truncation_eps = None
11371188

1138-
if end_count == varipeps_config.ad_custom_max_steps and not converged:
1189+
if not converged:
11391190
raise CTMRGGradientNotConvergedError
11401191

11411192
empty_t = [t.zeros_like_self() for t in input_unitcell.get_unique_tensors()]

0 commit comments

Comments
 (0)