Hello,
I tried Dion2 on a parameter from a Conv1d layer. Even with flatten=True, I get this error:
```
[rank0]: M_sel = torch.index_select(M_work, dim=-1, index=K) # [I, k]
[rank0]: RuntimeError: Index is supposed to be an empty tensor or a vector
```
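The error is raised by `torch.index_select`, which requires its `index` argument to be a vector. The sketch below isolates that constraint on a tensor with the Conv1d weight layout `[out, in, kernel]`. It assumes, without having checked the Dion2 source, that an unflattened 3-D momentum tensor leads `topk_and_orthonormalize` to build a 2-D index `K` (one row of indices per leading slice); `index_select_accepts` is a hypothetical helper for illustration only.

```python
import torch


def index_select_accepts(index: torch.Tensor) -> bool:
    """Return True if torch.index_select accepts `index` on the last dim of a 3-D tensor."""
    # Hypothetical momentum buffer with Conv1d weight layout [out, in, kernel].
    M = torch.randn(8, 4, 3)
    try:
        torch.index_select(M, dim=-1, index=index)
    except (RuntimeError, IndexError):
        # The exact exception type and message differ between PyTorch versions.
        return False
    return True


print(index_select_accepts(torch.tensor([0, 2])))            # 1-D index
print(index_select_accepts(torch.tensor([[0, 2], [1, 2]])))  # 2-D index, e.g. batched top-k
```

Flattening the tensor to 2-D before the top-k selection, as in the change below, means the selection only ever needs a 1-D column index.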
The following change fixes the error, but I am not sure whether the implementation is correct:
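```python
import math
from typing import Callable

from torch import Tensor

# make_work_view, muon_update_newton_schulz and topk_and_orthonormalize are the
# existing helpers from the Dion2 source that this function lives alongside.


def fractional_orthonormalize_update(
    M_full: Tensor,
    fraction: float,
    ef_decay: Tensor,
    flatten: bool,
    epsilon: Tensor,
    newton_schulz_func: Callable,
) -> Tensor:
    # Flatten 3D+ tensors (e.g. Conv1d weights [out, in, kernel]) to 2D
    # [out, in * kernel]. For a contiguous tensor, flatten() returns a view,
    # so the in-place updates on M_work below still reach the momentum buffer.
    original_shape = M_full.shape
    flattened = flatten and M_full.ndim >= 3
    if flattened:
        M_full = M_full.flatten(start_dim=1)

    M_work, transposed = make_work_view(M_full)
    I, J = M_work.size(-2), M_work.size(-1)

    if fraction >= 1.0:
        # Full orthonormalization
        ortho_update = muon_update_newton_schulz(
            M_work, newton_schulz_func, flatten=flatten, epsilon=epsilon
        )
        M_work.mul_(ef_decay)
    else:
        # Fractional orthonormalization over k = ceil(fraction * J) columns
        k = int(math.ceil(fraction * J))
        ortho_update = topk_and_orthonormalize(
            M_work,
            ef_decay=ef_decay,
            k=k,
            flatten=flatten,
            epsilon=epsilon,
            newton_schulz_func=newton_schulz_func,
        )

    result = ortho_update.mT.contiguous() if transposed else ortho_update
    if flattened:
        # Restore the original parameter shape; unlike view(), reshape()
        # also works if the result is not contiguous.
        result = result.reshape(original_shape)
    return result
```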
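One way to sanity-check the flattening approach is to test the two PyTorch behaviours it relies on. First, `flatten(start_dim=1)` maps a Conv1d weight `[out, in, kernel]` to `[out, in * kernel]`, and reshaping back restores it exactly. Second, an in-place op on the flattened tensor (such as `M_work.mul_(ef_decay)`) only reaches the original momentum buffer when `flatten` returned a view, which holds for contiguous tensors but not for every layout. The helper names below are hypothetical, for illustration only.

```python
import torch


def flatten_roundtrip_ok(out_ch: int, in_ch: int, kernel: int) -> bool:
    """Check that flatten(start_dim=1) followed by reshape restores a Conv1d-shaped tensor."""
    W = torch.randn(out_ch, in_ch, kernel)  # Conv1d weight layout [out, in, kernel]
    W2d = W.flatten(start_dim=1)            # [out, in * kernel]
    return W2d.shape == (out_ch, in_ch * kernel) and torch.equal(W2d.reshape(W.shape), W)


def inplace_update_reaches_buffer(buf: torch.Tensor) -> bool:
    """Scale the flattened tensor in place and report whether `buf` itself changed."""
    before = buf.clone()
    flat = buf.flatten(start_dim=1)  # a view if possible, otherwise a copy
    flat.mul_(0.5)                   # stands in for an in-place error-feedback update
    return not torch.equal(buf, before)


print(flatten_roundtrip_ok(8, 4, 3))
# Contiguous buffer: flatten returns a view, so the update is visible in buf.
print(inplace_update_reaches_buffer(torch.ones(8, 4, 3)))
# Non-contiguous buffer of the same shape: flatten copies, so the update is lost.
print(inplace_update_reaches_buffer(torch.ones(3, 4, 8).permute(2, 1, 0)))
```

If the optimizer's momentum buffers are allocated contiguously (an assumption about Dion2; for example, `torch.zeros_like` on a contiguous parameter gives a contiguous tensor), the flattened `M_full` in the change above is a view and the error-feedback update is preserved.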
Are convolution layers supported by Dion2?
Thanks