@@ -951,51 +951,6 @@ def _ctmrg_rev_while_body(carry):
951951 return vjp_env , initial_bar , bar_fixed_point , converged , count , config , state
952952
953953
954- def _ctmrg_rev_arnoldi (vjp_operator , initial_v ):
955- v , v_treedev = jax .tree .flatten (initial_v )
956- v_flat = np .concatenate ([i .reshape (- 1 ) for i in v ])
957-
958- def matvec (vec ):
959- new_vec = [None ] * len (v )
960- for i in range (len (v )):
961- i_start = sum (j .size for j in v [:i ])
962- elems = slice (i_start , i_start + v [i ].size )
963- new_vec [i ] = vec [elems ].astype (v [i ].dtype ).reshape (v [i ].shape )
964- new_vec = jax .tree .unflatten (v_treedev , new_vec )
965-
966- new_vec = vjp_operator ((new_vec , jnp .array (0 , dtype = jnp .float64 )))[0 ]
967-
968- new_vec , _ = jax .tree .flatten (new_vec )
969- new_vec = np .concatenate ([i .reshape (- 1 ) for i in new_vec ])
970-
971- return np .append (new_vec + vec [- 1 ] * v_flat , vec [- 1 ])
972-
973- lin_op = LinearOperator (
974- (v_flat .shape [0 ] + 1 , v_flat .shape [0 ] + 1 ),
975- matvec = matvec ,
976- )
977-
978- _ , vec = eigs (
979- lin_op , k = 1 , v0 = np .append (v_flat , np .array (1 , dtype = v_flat .dtype )), which = "LM"
980- )
981-
982- vec = vec .reshape (- 1 )
983-
984- if np .abs (vec [- 1 ]) >= 1e-10 :
985- vec /= vec [- 1 ]
986-
987- result = [None ] * len (v )
988- for i in range (len (v )):
989- i_start = sum (j .size for j in v [:i ])
990- elems = slice (i_start , i_start + v [i ].size )
991- result [i ] = vec [elems ].astype (v [i ].dtype ).reshape (v [i ].shape )
992-
993- if np .abs (vec [- 1 ]) < 1e-10 :
994- return jax .tree .unflatten (v_treedev , result ), False
995-
996- return jax .tree .unflatten (v_treedev , result ), True
997-
998-
999954@jit
1000955def _ctmrg_rev_workhorse (peps_tensors , new_unitcell , new_unitcell_bar , config , state ):
1001956 if new_unitcell .is_triangular_peps ():
@@ -1049,34 +1004,110 @@ def cond_func(carry):
10491004 (vjp_env , new_unitcell_bar , new_unitcell_bar , False , 0 , config , state ),
10501005 )
10511006 else :
1007+ real = jax .dtypes .result_type (
1008+ * jax .tree .leaves (new_unitcell_bar )
1009+ ) == jax .dtypes .canonicalize_dtype (jnp .float64 )
10521010 if config .ad_custom_fixed_point_method is Grad_Fixed_Point_Method .EIGEN_SOLVER :
1011+
10531012 def f_arnoldi (x ):
1054- w = vjp_env ((x [0 ], jnp .array (0 , dtype = jnp .float64 )))[0 ]
1013+ w = x [0 ]
1014+ if not real :
1015+ w = jax .tree .map (lambda x , y : x + 1j * y , w [0 ], w [1 ])
1016+
1017+ w = vjp_env ((w , jnp .array (0 , dtype = jnp .float64 )))[0 ]
10551018 w = jax .tree .map (lambda v1 , v2 : v1 + x [1 ] * v2 , w , new_unitcell_bar )
1056- return (w , x [1 ])
10571019
1058- eigval , eigvec = jsp .sparse .linalg .eigs (
1059- f_arnoldi , 1 , (new_unitcell_bar , 1.0 )
1060- )
1020+ if not real :
1021+ w = (
1022+ jax .tree .map (lambda x : jnp .real (x ), w ),
1023+ jax .tree .map (lambda x : jnp .imag (x ), w ),
1024+ )
10611025
1062- print_debug ( "Eigval: {}" , eigval )
1026+ return ( w , x [ 1 ] )
10631027
1064- env_fixed_point = jax .tree .map (lambda v : jnp .real (v [..., 0 ]), eigvec [0 ])
1028+ if real :
1029+ eigval , eigvec = jsp .sparse .linalg .eigs (
1030+ f_arnoldi , 1 , (new_unitcell_bar , 1.0 )
1031+ )
1032+ else :
1033+ eigval , eigvec = jsp .sparse .linalg .eigs (
1034+ f_arnoldi ,
1035+ 1 ,
1036+ (
1037+ (
1038+ jax .tree .map (lambda x : jnp .real (x ), new_unitcell_bar ),
1039+ jax .tree .map (lambda x : jnp .imag (x ), new_unitcell_bar ),
1040+ ),
1041+ 1.0 ,
1042+ ),
1043+ )
10651044
1066- env_fixed_point , arnoldi_worked = cond (
1067- jnp .real (eigvec [1 ][0 ]) >= 1e-10 ,
1068- lambda x : (jax .tree .map (lambda v : v / jnp .real (eigvec [1 ][0 ]), x ), True ),
1069- lambda x : (x , False ),
1070- env_fixed_point ,
1045+ converged = cond (
1046+ jnp .logical_and (
1047+ jnp .abs (jnp .real (eigval [0 ]))
1048+ < (1 + 1e-2 * config .ad_custom_convergence_eps ),
1049+ jnp .abs (jnp .imag (eigval [0 ]))
1050+ < 1e-2 * config .ad_custom_convergence_eps ,
1051+ ),
1052+ lambda : True ,
1053+ lambda : False ,
10711054 )
1055+
1056+ if config .ad_custom_verbose_output :
1057+ debug_print (
1058+ "AD: Converged: {}, Eigval: {}, Eigvec[1]: {}" ,
1059+ converged ,
1060+ eigval [0 ],
1061+ eigvec [1 ][0 ],
1062+ )
1063+
1064+ if real :
1065+ env_fixed_point = jax .tree .map (lambda v : jnp .real (v [..., 0 ]), eigvec [0 ])
1066+ env_fixed_point , arnoldi_worked = cond (
1067+ jnp .logical_and (
1068+ converged ,
1069+ jnp .abs (eigvec [1 ][0 ])
1070+ >= 1e-2 * config .ad_custom_convergence_eps ,
1071+ ),
1072+ lambda x : (
1073+ jax .tree .map (lambda v : v / jnp .real (eigvec [1 ][0 ]), x ),
1074+ True ,
1075+ ),
1076+ lambda x : (x , False ),
1077+ env_fixed_point ,
1078+ )
1079+ else :
1080+ env_fixed_point = jax .tree .map (
1081+ lambda v , w : v [..., 0 ] + 1j * w [..., 0 ], eigvec [0 ][0 ], eigvec [0 ][1 ]
1082+ )
1083+ env_fixed_point , arnoldi_worked = cond (
1084+ jnp .logical_and (
1085+ converged ,
1086+ jnp .abs (eigvec [1 ][0 ])
1087+ >= 1e-2 * config .ad_custom_convergence_eps ,
1088+ ),
1089+ lambda x : (
1090+ jax .tree .map (lambda v : v / jnp .real (eigvec [1 ][0 ]), x ),
1091+ True ,
1092+ ),
1093+ lambda x : (x , False ),
1094+ env_fixed_point ,
1095+ )
10721096 else :
10731097 env_fixed_point = new_unitcell_bar
10741098 arnoldi_worked = False
1099+ converged = True
10751100
10761101 end_count = 0
10771102
10781103 def run_gmres (v , e ):
1104+ if config .ad_custom_verbose_output :
1105+ debug_print ("AD: Computing gradient with GMRES" )
1106+
10791107 def f_gmres (w ):
1108+ if not real :
1109+ w = jax .tree .map (lambda x , y : x + 1j * y , w [0 ], w [1 ])
1110+
10801111 new_w = vjp_env ((w , jnp .array (0 , dtype = jnp .float64 )))[0 ]
10811112
10821113 new_w = new_w .replace_unique_tensors (
@@ -1090,27 +1121,47 @@ def f_gmres(w):
10901121 ]
10911122 )
10921123
1124+ if not real :
1125+ new_w = (
1126+ jax .tree .map (lambda x : jnp .real (x ), new_w ),
1127+ jax .tree .map (lambda x : jnp .imag (x ), new_w ),
1128+ )
1129+
10931130 return new_w
10941131
10951132 is_gpu = jax .default_backend () == "gpu"
10961133
1134+ if real :
1135+ v0 = new_unitcell_bar
1136+ else :
1137+ v0 = (
1138+ jax .tree .map (lambda x : jnp .real (x ), new_unitcell_bar ),
1139+ jax .tree .map (lambda x : jnp .imag (x ), new_unitcell_bar ),
1140+ )
1141+
10971142 v , e = jax .scipy .sparse .linalg .gmres (
10981143 f_gmres ,
1099- new_unitcell_bar ,
1100- new_unitcell_bar ,
1144+ v0 ,
1145+ v0 ,
11011146 solve_method = "batched" if is_gpu else "incremental" ,
11021147 atol = config .ad_custom_convergence_eps ,
1103- maxiter = config .ad_custom_max_steps ,
1148+ # maxiter=config.ad_custom_max_steps,
11041149 )
11051150
1151+ if not real :
1152+ v = jax .tree .map (lambda x , y : x + 1j * y , v [0 ], v [1 ])
1153+
11061154 return v , e
11071155
1108- env_fixed_point , end_count = jax .lax .cond (
1109- arnoldi_worked , lambda x , e : (x , e ), run_gmres , env_fixed_point , end_count
1156+ env_fixed_point , end_count , converged = jax .lax .cond (
1157+ jnp .logical_and (converged , jnp .logical_not (arnoldi_worked )),
1158+ lambda x , ec , c : (* run_gmres (x , ec ), True ),
1159+ lambda x , ec , c : (x , ec , c ),
1160+ env_fixed_point ,
1161+ end_count ,
1162+ converged ,
11101163 )
11111164
1112- converged = True
1113-
11141165 (t_bar ,) = vjp_peps_tensors ((env_fixed_point , jnp .array (0 , dtype = jnp .float64 )))
11151166
11161167 return t_bar , converged , end_count
@@ -1135,7 +1186,7 @@ def calc_ctmrg_env_rev(
11351186
11361187 varipeps_global_state .ctmrg_effective_truncation_eps = None
11371188
1138- if end_count == varipeps_config . ad_custom_max_steps and not converged :
1189+ if not converged :
11391190 raise CTMRGGradientNotConvergedError
11401191
11411192 empty_t = [t .zeros_like_self () for t in input_unitcell .get_unique_tensors ()]
0 commit comments