4 changes: 4 additions & 0 deletions src/diffusers/loaders/single_file_model.py
@@ -148,6 +148,10 @@
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"WanAnimateTransformer3DModel": {
"checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
"default_subfolder": "transformer",
},
"AutoencoderKLWan": {
"checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
"default_subfolder": "vae",
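With this entry in SINGLE_FILE_LOADABLE_CLASSES, the Animate transformer can be loaded from an original single-file checkpoint the same way as the other Wan transformers. A minimal sketch, assuming WanAnimateTransformer3DModel is exported from diffusers like the other Wan model classes; the checkpoint path is a placeholder:

import torch
from diffusers import WanAnimateTransformer3DModel

transformer = WanAnimateTransformer3DModel.from_single_file(
    "path/to/Wan2.2-Animate-14B.safetensors",  # hypothetical local checkpoint
    torch_dtype=torch.bfloat16,
)
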
83 changes: 77 additions & 6 deletions src/diffusers/loaders/single_file_utils.py
@@ -130,6 +130,7 @@
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
"wan_vae": "decoder.middle.0.residual.0.gamma",
"wan_vace": "vace_blocks.0.after_proj.bias",
"wan_animate": "motion_encoder.dec.direction.weight",
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
"cosmos-1.0": [
"net.x_embedder.proj.1.weight",
@@ -208,6 +209,7 @@
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
"wan-animate-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.2-Animate-14B-Diffusers"},
"wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
"wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
@@ -747,6 +749,9 @@ def infer_diffusers_model_type(checkpoint):
elif checkpoint[target_key].shape[0] == 5120:
model_type = "wan-vace-14B"

if CHECKPOINT_KEY_NAMES["wan_animate"] in checkpoint:
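# the motion-encoder key is specific to the Animate variant, so its presence alone
# is enough to select the Wan2.2-Animate-14B config here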
model_type = "wan-animate-14B"

elif checkpoint[target_key].shape[0] == 1536:
model_type = "wan-t2v-1.3B"
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
@@ -3127,13 +3132,64 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):


def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
def generate_motion_encoder_mappings():
mappings = {
"motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
"motion_encoder.enc.net_app.convs.0.0.weight": "motion_encoder.conv_in.weight",
"motion_encoder.enc.net_app.convs.0.1.bias": "motion_encoder.conv_in.act_fn.bias",
"motion_encoder.enc.net_app.convs.8.weight": "motion_encoder.conv_out.weight",
"motion_encoder.enc.fc": "motion_encoder.motion_network",
}

for i in range(7):
conv_idx = i + 1
mappings.update(
{
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.0.weight": f"motion_encoder.res_blocks.{i}.conv1.weight",
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.1.bias": f"motion_encoder.res_blocks.{i}.conv1.act_fn.bias",
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.1.weight": f"motion_encoder.res_blocks.{i}.conv2.weight",
f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.2.bias": f"motion_encoder.res_blocks.{i}.conv2.act_fn.bias",
f"motion_encoder.enc.net_app.convs.{conv_idx}.skip.1.weight": f"motion_encoder.res_blocks.{i}.conv_skip.weight",
}
)

return mappings

def generate_face_adapter_mappings():
return {
"face_adapter.fuser_blocks": "face_adapter",
".k_norm.": ".norm_k.",
".q_norm.": ".norm_q.",
".linear1_q.": ".to_q.",
".linear2.": ".to_out.",
"conv1_local.conv": "conv1_local",
"conv2.conv": "conv2",
"conv3.conv": "conv3",
}

def split_tensor_handler(key, state_dict, split_pattern, target_keys):
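# split a fused projection tensor (e.g. the face adapter's linear1_kv) in half along dim 0
# and store the two halves under the given target key names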
tensor = state_dict.pop(key)
split_idx = tensor.shape[0] // 2

new_key_1 = key.replace(split_pattern, target_keys[0])
new_key_2 = key.replace(split_pattern, target_keys[1])

state_dict[new_key_1] = tensor[:split_idx]
state_dict[new_key_2] = tensor[split_idx:]

def reshape_bias_handler(key, state_dict):
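# the original motion-encoder conv biases are stored as (1, C, 1, 1) tensors; keep only the (C,) channel vector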
if "motion_encoder.enc.net_app.convs." in key and ".bias" in key:
state_dict[key] = state_dict[key][0, :, 0, 0]

converted_state_dict = {}

# Strip model.diffusion_model prefix
keys = list(checkpoint.keys())
for k in keys:
if "model.diffusion_model." in k:
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)

# Base transformer mappings
TRANSFORMER_KEYS_RENAME_DICT = {
"time_embedding.0": "condition_embedder.time_embedder.linear_1",
"time_embedding.2": "condition_embedder.time_embedder.linear_2",
@@ -3155,28 +3211,43 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
"ffn.0": "ffn.net.0.proj",
"ffn.2": "ffn.net.2",
# Hack to swap the layer names
# The original model calls the norms in the following order: norm1, norm3, norm2
# We convert it to: norm1, norm2, norm3
"norm2": "norm__placeholder",
"norm3": "norm2",
"norm__placeholder": "norm3",
# For the I2V model
# I2V model
"img_emb.proj.0": "condition_embedder.image_embedder.norm1",
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
# For the VACE model
# VACE model
"before_proj": "proj_in",
"after_proj": "proj_out",
}

SPECIAL_KEYS_HANDLERS = {}
if any("face_adapter" in k for k in checkpoint.keys()):
TRANSFORMER_KEYS_RENAME_DICT.update(generate_face_adapter_mappings())
SPECIAL_KEYS_HANDLERS[".linear1_kv."] = (split_tensor_handler, [".to_k.", ".to_v."])

if any("motion_encoder" in k for k in checkpoint.keys()):
TRANSFORMER_KEYS_RENAME_DICT.update(generate_motion_encoder_mappings())

for key in list(checkpoint.keys()):
new_key = key[:]
reshape_bias_handler(key, checkpoint)

for key in list(checkpoint.keys()):
new_key = key
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
new_key = new_key.replace(replace_key, rename_key)

converted_state_dict[new_key] = checkpoint.pop(key)

for key in list(converted_state_dict.keys()):
for pattern, (handler_fn, target_keys) in SPECIAL_KEYS_HANDLERS.items():
if pattern not in key:
continue
handler_fn(key, converted_state_dict, pattern, target_keys)
break

return converted_state_dict
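
To illustrate what the new Animate-specific handlers do, here is a minimal sketch on a toy checkpoint (tensor shapes are made up, not the real 14B dimensions; convert_wan_transformer_to_diffusers is a private helper, imported here only for illustration):

import torch
from diffusers.loaders.single_file_utils import convert_wan_transformer_to_diffusers

checkpoint = {
    # fused K/V projection in a face adapter block -> split into to_k / to_v
    "face_adapter.fuser_blocks.0.linear1_kv.weight": torch.randn(8, 4),
    # motion-encoder conv bias stored as (1, C, 1, 1) -> flattened to (C,)
    "motion_encoder.enc.net_app.convs.0.1.bias": torch.randn(1, 4, 1, 1),
}

converted = convert_wan_transformer_to_diffusers(checkpoint)
# converted now contains:
#   "face_adapter.0.to_k.weight"          with shape (4, 4)
#   "face_adapter.0.to_v.weight"          with shape (4, 4)
#   "motion_encoder.conv_in.act_fn.bias"  with shape (4,)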


12 changes: 6 additions & 6 deletions src/diffusers/models/transformers/transformer_wan_animate.py
@@ -166,9 +166,11 @@ def forward(self, x: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
# NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates
# set to 1, which should be equivalent to a 2D convolution
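# (with up=down=1, upfirdn2d reduces to pad-then-filter, i.e. the grouped/depthwise conv2d below)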
expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1)
x = x.to(expanded_kernel.dtype)
x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels)

# Main Conv2D with scaling
x = x.to(self.weight.dtype)
x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)

# Activation with fused bias, if using
@@ -338,8 +340,7 @@ def forward(self, face_image: torch.Tensor, channel_dim: int = 1) -> torch.Tensor:
weight = self.motion_synthesis_weight + 1e-8
# Upcast the QR orthogonalization operation to FP32
original_motion_dtype = motion_feat.dtype
motion_feat = motion_feat.to(torch.float32)
weight = weight.to(torch.float32)
motion_feat = motion_feat.to(weight.dtype)

Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device)

@@ -801,10 +802,9 @@ def forward(
if timestep_seq_len is not None:
timestep = timestep.unflatten(0, (-1, timestep_seq_len))

time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
timestep = timestep.to(time_embedder_dtype)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep = timestep.to(encoder_hidden_states.dtype)
Collaborator review comment:

My mistake here @samadwar. The time embedder layer is kept in FP32, so casting the timestep directly like this won't work.

What we can do instead is check the dtype of the timestep embedder's linear layer weight and only cast when it is a floating-point dtype:

        timestep = self.timesteps_proj(timestep)
        if timestep_seq_len is not None:
            timestep = timestep.unflatten(0, (-1, timestep_seq_len))

        if self.time_embedder.linear_1.weight.dtype.is_floating_point:
            time_embedder_dtype = self.time_embedder.linear_1.weight.dtype
        else:
            time_embedder_dtype = encoder_hidden_states.dtype

        temb = self.time_embedder(timestep.to(time_embedder_dtype)).type_as(encoder_hidden_states)
        timestep_proj = self.time_proj(self.act_fn(temb))


temb = self.time_embedder(timestep)
timestep_proj = self.time_proj(self.act_fn(temb))

encoder_hidden_states = self.text_embedder(encoder_hidden_states)