I started using dion2 with a model trained with fsdp2 (torch==2.9.1+cu129). It raises an error when it starts to save a distributed checkpoint:
883 [rank0]: Traceback (most recent call last):
…
892 [rank0]: File "/app/foundation-model/libraries/fm_utils/fm_utils/checkpointing/fsdp.py", line 194, in save
893 [rank0]: model_state, optim_state = get_state_dict(
894 [rank0]: ^^^^^^^^^^^^^^^
895 [rank0]: File "/app/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict.py", line 1197, in get_state_dict
896 [rank0]: optim_state_dict = _get_optim_state_dict(model, optimizers, info)
897 [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
898 [rank0]: File "/app/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
899 [rank0]: return func(*args, **kwargs)
900 [rank0]: ^^^^^^^^^^^^^^^^^^^^^
901 [rank0]: File "/app/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict.py", line 832, in _get_optim_state_dict
902 [rank0]: OptimizerStateType, _flatten_optim_state_dict(optim_state_dict)
903 [rank0]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
904 [rank0]: File "/app/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict.py", line 700, in _flatten_optim_state_dict
905 [rank0]: _raise_if_type_not_supported(v)
906 [rank0]: File "/app/.venv/lib/python3.12/site-packages/torch/distributed/checkpoint/state_dict.py", line 691, in _raise_if_type_not_supported
907 [rank0]: raise NotImplementedError(
908 [rank0]: NotImplementedError: Flattening optimizer state_dict only supports tensor, int, float states now. Type is <class 'NoneType'>.
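For reference, the check that fires here rejects any optimizer state value that is not a tensor, int, or float. A quick way to see which entries are the offenders is to scan the optimizer's state directly; the helper below is a hypothetical diagnostic sketch, not part of dion2 or torch.

```python
import torch

def find_unsupported_state_values(optimizer: torch.optim.Optimizer):
    """Hypothetical helper: list optimizer state entries whose values are
    not a tensor, int, or float, i.e. the types that torch>=2.9 refuses to
    flatten when building the distributed checkpoint state dict."""
    offenders = []
    for param, state in optimizer.state.items():
        for key, value in state.items():
            if not isinstance(value, (torch.Tensor, int, float)):
                offenders.append((param, key, type(value)))
    return offenders

# Example usage after a training step, before saving the checkpoint:
# for param, key, value_type in find_unsupported_state_values(optimizer):
#     print(f"state['{key}'] has unsupported type {value_type}")
```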
It seems that torch started checking the value types in the optimizer state in version 2.9. I have made a patch here: jiasen-aignx@c7fa0e2. It works fine for my model, but I am not sure if it is fully correct.
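For illustration only (this is a sketch of one possible workaround and not necessarily what the patch above does): drop the None-valued entries from the optimizer state right before calling get_state_dict, so the flattening step never sees them.

```python
import torch

def drop_none_state_entries(optimizer: torch.optim.Optimizer):
    """Sketch of a possible workaround (not necessarily what the linked
    patch does): remove None-valued entries from the optimizer state so
    that torch>=2.9 can flatten it into the checkpoint state dict."""
    for state in optimizer.state.values():
        for key in [k for k, v in state.items() if v is None]:
            del state[key]

# Hypothetical usage before saving:
# drop_none_state_entries(optimizer)
# model_state, optim_state = get_state_dict(model, optimizer)
```

Whether deleting these entries is safe depends on whether the optimizer expects them on the next step or when resuming from the checkpoint; replacing None with a sentinel value, or fixing the state initialization upstream in dion2, may be the safer option.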