From 37d9fd1b26f747f81436891541a9d78bcc0ed388 Mon Sep 17 00:00:00 2001 From: Abhinay1997 Date: Sun, 20 Apr 2025 08:40:51 +0530 Subject: [PATCH 1/2] Add Flux support --- src/ras/model_inference/flux_inference.py | 38 +++ .../lumina_next_t2i_inference.py | 4 + src/ras/modules/attention_processor.py | 162 ++++++++++- src/ras/modules/flux/__init__.py | 0 src/ras/modules/flux/transformer_forward.py | 198 +++++++++++++ src/ras/schedulers/__init__.py | 1 + src/ras/schedulers/ras_flux_flow_matching.py | 274 ++++++++++++++++++ ...as_scheduling_flow_match_euler_discrete.py | 48 ++- src/ras/utils/flux/__init__.py | 0 src/ras/utils/flux/update_pipeline_flux.py | 13 + src/ras/utils/ras_manager.py | 4 +- 11 files changed, 733 insertions(+), 9 deletions(-) create mode 100644 src/ras/model_inference/flux_inference.py create mode 100644 src/ras/modules/flux/__init__.py create mode 100644 src/ras/modules/flux/transformer_forward.py create mode 100644 src/ras/schedulers/ras_flux_flow_matching.py create mode 100644 src/ras/utils/flux/__init__.py create mode 100644 src/ras/utils/flux/update_pipeline_flux.py diff --git a/src/ras/model_inference/flux_inference.py b/src/ras/model_inference/flux_inference.py new file mode 100644 index 0000000..566d942 --- /dev/null +++ b/src/ras/model_inference/flux_inference.py @@ -0,0 +1,38 @@ +import argparse +import torch +import sys +import time + +sys.path.append('/workspace/RAS') +sys.path.append('/workspace/RAS/src') +from diffusers import FluxPipeline +from ras.utils.flux.update_pipeline_flux import update_flux_pipeline +from ras.utils import ras_manager +from ras.utils.ras_argparser import parse_args + +def flux_inf(args): + pipeline = FluxPipeline.from_pretrained( + "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16 + ).to("cuda") + pipeline = update_flux_pipeline(pipeline) + pipeline.transformer.to(memory_format=torch.channels_last) + pipeline.vae.to(memory_format=torch.channels_last) + print("Vae", pipeline.vae_scale_factor) + generator = torch.Generator("cuda").manual_seed(args.seed) if args.seed is not None else None + numsteps = args.num_inference_steps + start = time.time() + image = pipeline( + generator=generator, + num_inference_steps=numsteps, + prompt=args.prompt, + height=args.height, + width=args.width, + ).images[0] + print(f"Pipeline time {time.time()-start}") + image.save(args.output) + + +if __name__ == "__main__": + args = parse_args() + ras_manager.MANAGER.set_parameters(args) + flux_inf(args) diff --git a/src/ras/model_inference/lumina_next_t2i_inference.py b/src/ras/model_inference/lumina_next_t2i_inference.py index b0fb6c8..33e6588 100644 --- a/src/ras/model_inference/lumina_next_t2i_inference.py +++ b/src/ras/model_inference/lumina_next_t2i_inference.py @@ -1,5 +1,9 @@ import argparse import torch +import sys + +sys.path.append('/workspace/RAS') +sys.path.append('/workspace/RAS/src') from diffusers import LuminaText2ImgPipeline from ras.utils.lumina_next_t2i.update_pipeline_lumina import update_lumina_pipeline from ras.utils import ras_manager diff --git a/src/ras/modules/attention_processor.py b/src/ras/modules/attention_processor.py index 545ea8e..d856620 100644 --- a/src/ras/modules/attention_processor.py +++ b/src/ras/modules/attention_processor.py @@ -44,7 +44,6 @@ def __call__( base_sequence_length: Optional[int] = None, ) -> torch.Tensor: from diffusers.models.embeddings import apply_rotary_emb - is_self_attention = True if hidden_states.shape == encoder_hidden_states.shape else False input_ndim = hidden_states.ndim @@ -343,3 
+342,164 @@ def __call__( return hidden_states, encoder_hidden_states else: return hidden_states + + +class RASFluxAttnProcessor2_0: + """Attention processor used typically in processing the SD3-like self-attention projections.""" + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("FluxAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") + """ + Case 1: + No encoder states passed => hidden_states = [encoder_hidden_states, hidden_states] + aka indices in ras cache need to be + len(encoder_hidden_states) aka 512 + """ + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: torch.FloatTensor = None, + attention_mask: Optional[torch.FloatTensor] = None, + image_rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + is_self_attention = encoder_hidden_states is None + # `sample` projections. + query = attn.to_q(hidden_states) #[encoder_hidden_states, hidden_states] i.e token indices are offset by 512 + k_fuse_linear = ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step and \ + self.k_cache is not None and ras_manager.MANAGER.enable_index_fusion + v_fuse_linear = k_fuse_linear + + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step: + if is_self_attention: + other_patchified_index = 512 + ras_manager.MANAGER.other_patchified_index + padding_indices = torch.arange(512, device=other_patchified_index.device) + other_patchified_index = torch.cat([padding_indices, other_patchified_index]) + else: + other_patchified_index = ras_manager.MANAGER.other_patchified_index + print(other_patchified_index.shape) + + if k_fuse_linear: + from .fused_kernels_sd3 import _partially_linear + _partially_linear( + hidden_states, + attn.to_k.weight, + attn.to_k.bias, + other_patchified_index, + self.k_cache.view(batch_size, self.k_cache.shape[1], -1) + ) + else: + key = attn.to_k(hidden_states) + if v_fuse_linear: + _partially_linear( + hidden_states, + attn.to_v.weight, + attn.to_v.bias, + other_patchified_index, + self.v_cache.view(batch_size, self.v_cache.shape[1], -1) + ) + else: + value = attn.to_v(hidden_states) + + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.current_step == 0: + self.k_cache = None + self.v_cache = None + + if ras_manager.MANAGER.sample_ratio < 1.0 and (ras_manager.MANAGER.current_step == ras_manager.MANAGER.scheduler_start_step - 1 or ras_manager.MANAGER.current_step in ras_manager.MANAGER.error_reset_steps): + self.k_cache = key + self.v_cache = value + + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step: + if not ras_manager.MANAGER.enable_index_fusion: + self.k_cache[:, other_patchified_index] = key + self.v_cache[:, other_patchified_index] = value + key = self.k_cache + value = self.v_cache + + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.current_step > ras_manager.MANAGER.scheduler_end_step: + self.k_cache = None + self.v_cache = None + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = 
attn.norm_k(key) + + # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states` + if encoder_hidden_states is not None: + # `context` projections. + encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states) + encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states) + + encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view( + batch_size, -1, attn.heads, head_dim + ).transpose(1, 2) + + if attn.norm_added_q is not None: + encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj) + if attn.norm_added_k is not None: + encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj) + + # attention + query = torch.cat([encoder_hidden_states_query_proj, query], dim=2) + key = torch.cat([encoder_hidden_states_key_proj, key], dim=2) + value = torch.cat([encoder_hidden_states_value_proj, value], dim=2) + + ##TODO: Create partial rope applying kernels. + if image_rotary_emb is not None: + from diffusers.models.embeddings import apply_rotary_emb + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step: + query = apply_rotary_emb(query, ras_manager.MANAGER.image_rotary_emb_skip) + key = apply_rotary_emb(key, image_rotary_emb) + else: + query = apply_rotary_emb(query, image_rotary_emb) + key = apply_rotary_emb(key, image_rotary_emb) + + if not ras_manager.MANAGER.replace_with_flash_attn: + hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False) + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + else: + from flash_attn import flash_attn_func + hidden_states = flash_attn_func( + query, key, value, dropout_p=0.0, causal=False + ) + hidden_states = hidden_states.view(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + if encoder_hidden_states is not None: + # print(f"Before, {hidden_states.shape}") + encoder_hidden_states, hidden_states = ( + hidden_states[:, : encoder_hidden_states.shape[1]], + hidden_states[:, encoder_hidden_states.shape[1] :], + ) + # print(f"After, {hidden_states.shape}") + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + encoder_hidden_states = attn.to_add_out(encoder_hidden_states) + + return hidden_states, encoder_hidden_states + else: + return hidden_states + + + diff --git a/src/ras/modules/flux/__init__.py b/src/ras/modules/flux/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ras/modules/flux/transformer_forward.py b/src/ras/modules/flux/transformer_forward.py new file mode 100644 index 0000000..8b67081 --- /dev/null +++ b/src/ras/modules/flux/transformer_forward.py @@ -0,0 +1,198 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from 
diffusers.utils.torch_utils import maybe_allow_in_graph +from ras.utils import ras_manager + + +def ras_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor = None, + pooled_projections: torch.Tensor = None, + timestep: torch.LongTensor = None, + img_ids: torch.Tensor = None, + txt_ids: torch.Tensor = None, + guidance: torch.Tensor = None, + joint_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_block_samples=None, + controlnet_single_block_samples=None, + return_dict: bool = True, + controlnet_blocks_repeat: bool = False, + ) -> Union[torch.Tensor, Transformer2DModelOutput]: + """ + The [`FluxTransformer2DModel`] forward method. + + Args: + hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`): + Input `hidden_states`. + encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + pooled_projections (`torch.Tensor` of shape `(batch_size, projection_dim)`): Embeddings projected + from the embeddings of input conditions. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + block_controlnet_hidden_states: (`list` of `torch.Tensor`): + A list of tensors that if specified are added to the residuals of transformer blocks. + joint_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + print(f'forward {hidden_states.shape} {encoder_hidden_states.shape}') + if joint_attention_kwargs is not None: + joint_attention_kwargs = joint_attention_kwargs.copy() + lora_scale = joint_attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." + ) + + hidden_states = self.x_embedder(hidden_states) + + timestep = timestep.to(hidden_states.dtype) * 1000 + if guidance is not None: + guidance = guidance.to(hidden_states.dtype) * 1000 + else: + guidance = None + + temb = ( + self.time_text_embed(timestep, pooled_projections) + if guidance is None + else self.time_text_embed(timestep, guidance, pooled_projections) + ) + encoder_hidden_states = self.context_embedder(encoder_hidden_states) + + if txt_ids.ndim == 3: + logger.warning( + "Passing `txt_ids` 3d torch.Tensor is deprecated." + "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + txt_ids = txt_ids[0] + if img_ids.ndim == 3: + logger.warning( + "Passing `img_ids` 3d torch.Tensor is deprecated." 
+ "Please remove the batch dimension and pass it as a 2d torch Tensor" + ) + img_ids = img_ids[0] + + ids = torch.cat((txt_ids, img_ids), dim=0) + image_rotary_emb = self.pos_embed(ids) + + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step: + print("Before", hidden_states.shape) + hidden_states = hidden_states[:, ras_manager.MANAGER.other_patchified_index] + print("After", hidden_states.shape) + ras_img_ids = img_ids[ras_manager.MANAGER.other_patchified_index, :] + ras_ids = torch.cat((txt_ids, ras_img_ids), dim=0) + ras_manager.MANAGER.image_rotary_emb_skip = self.pos_embed(ras_ids) + + if joint_attention_kwargs is not None and "ip_adapter_image_embeds" in joint_attention_kwargs: + ip_adapter_image_embeds = joint_attention_kwargs.pop("ip_adapter_image_embeds") + ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds) + joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states}) + + for index_block, block in enumerate(self.transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + encoder_hidden_states, hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + encoder_hidden_states, + temb, + image_rotary_emb, + ) + + else: + encoder_hidden_states, hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_block_samples is not None: + interval_control = len(self.transformer_blocks) / len(controlnet_block_samples) + interval_control = int(np.ceil(interval_control)) + # For Xlabs ControlNet. + if controlnet_blocks_repeat: + hidden_states = ( + hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)] + ) + else: + hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control] + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + + for index_block, block in enumerate(self.single_transformer_blocks): + if torch.is_grad_enabled() and self.gradient_checkpointing: + hidden_states = self._gradient_checkpointing_func( + block, + hidden_states, + temb, + image_rotary_emb, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + temb=temb, + image_rotary_emb=image_rotary_emb, + joint_attention_kwargs=joint_attention_kwargs, + ) + + # controlnet residual + if controlnet_single_block_samples is not None: + interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples) + interval_control = int(np.ceil(interval_control)) + hidden_states[:, encoder_hidden_states.shape[1] :, ...] = ( + hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + controlnet_single_block_samples[index_block // interval_control] + ) + + hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...] + + # unpatchify + patch_size = ras_manager.MANAGER.patch_size #flux config has patch_size = 1 which would be wrong here. 
+ hp = ras_manager.MANAGER.height // (patch_size * ras_manager.MANAGER.vae_size) + wp = ras_manager.MANAGER.width // (patch_size * ras_manager.MANAGER.vae_size) + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step: + final_hidden_states = torch.zeros( + (hidden_states.shape[0], hp * wp, hidden_states.shape[2]), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + final_hidden_states[:, ras_manager.MANAGER.other_patchified_index] = hidden_states + hidden_states = final_hidden_states + print('zero', {hidden_states.shape}) + + hidden_states = self.norm_out(hidden_states, temb) + output = self.proj_out(hidden_states) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) \ No newline at end of file diff --git a/src/ras/schedulers/__init__.py b/src/ras/schedulers/__init__.py index f7cead7..0ad1307 100644 --- a/src/ras/schedulers/__init__.py +++ b/src/ras/schedulers/__init__.py @@ -1,3 +1,4 @@ from typing import TYPE_CHECKING from .ras_scheduling_flow_match_euler_discrete import RASFlowMatchEulerDiscreteScheduler +from .ras_flux_flow_matching import RASFluxFlowMatchEulerDiscreteScheduler diff --git a/src/ras/schedulers/ras_flux_flow_matching.py b/src/ras/schedulers/ras_flux_flow_matching.py new file mode 100644 index 0000000..9179f69 --- /dev/null +++ b/src/ras/schedulers/ras_flux_flow_matching.py @@ -0,0 +1,274 @@ +# This file is a modified version of the original file from the HuggingFace/diffusers library. + +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from ras.utils import ras_manager + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class RASFlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class RASFluxFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): + """ + RAS Euler scheduler. + + This model inherits from ['FlowMatchEulerDiscreteScheduler']. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. 
+ + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + ): + super().__init__(num_train_timesteps=num_train_timesteps, + shift=shift, + use_dynamic_shifting=use_dynamic_shifting, + base_shift=base_shift, + max_shift=max_shift, + base_image_seq_len=base_image_seq_len, + max_image_seq_len=max_image_seq_len, + # invert_sigmas=invert_sigmas + ) + self.drop_cnt = None + + + def _init_ras_config(self, latents): + #drop cnt stored at token level + self.drop_cnt = torch.zeros((latents.shape[1]), device=latents.device) - len(self.sigmas) + + def extract_latents_index_from_patched_latents_index(self, indices, height, width): + """ + Maps patch indices to latent indices in a row-major ordered grid for unequal H and W. + + Args: + indices (torch.Tensor): 1D tensor of patch indices from the flattened patch grid. + height (int): Height of the latent grid (H). + width (int): Width of the latent grid (W). + + Returns: + torch.Tensor: Flattened tensor of latent indices corresponding to all positions in the patches. + """ + # Access patch size (assuming it's defined in a manager object) + ps = ras_manager.MANAGER.patch_size # e.g., 2 + + # Compute patch grid dimensions + ph = height // ps # Number of patches along height + pw = width // ps # Number of patches along width + + # Compute patch coordinates from indices (row-major order) + patch_y = indices // pw # Patch row index + patch_x = indices % pw # Patch column index + + # Compute base latent index for the top-left corner of each patch + base_index = (patch_y * ps) * width + (patch_x * ps) + + # Define offsets for the four corners in row-major order + offsets = torch.tensor([0, 1, width, width + 1], dtype=indices.dtype, device=indices.device) + + # Compute all latent indices by adding offsets to base index + # Shape: (num_indices, 4) -> flatten to (num_indices * 4) + flattened_indices = (base_index[:, None] + offsets[None, :]).flatten() + + return flattened_indices + + def ras_selection(self, sample, diff): + """ + Input diff is now of shape (B, H/p*W/p, C*p*p) instead of (B, C, H, W). + Metric calculation happens in (B, C, H, W) space, but indexing is in patch space. + For 1024x1024, VAE stride 8, p=2 (packing) -> (1, 64*64, 16*4). + Assumes B=1. 
+ """ + p = ras_manager.MANAGER.patch_size # e.g., p=2 + C = diff.shape[-1] // (p * p) # Recover original C (e.g., 16 from 16*4) + H = ras_manager.MANAGER.height // ras_manager.MANAGER.vae_size + W = ras_manager.MANAGER.width // ras_manager.MANAGER.vae_size + assert diff.shape == (1, (H // p) * (W // p), C * p * p), f"{diff.shape} != {(1, (H // p) * (W // p), C * p * p)}" + + # Convert diff to (B, C, H, W) for metric calculation + diff_unpatched = diff.view(1, H // p, W // p, C, p, p) # (1, H/p, W/p, C, p, p) + diff_unpatched = diff_unpatched.permute(0, 3, 1, 4, 2, 5) # (1, C, H/p, p, W/p, p) + diff_unpatched = diff_unpatched.reshape(1, C, H, W) # (1, C, H, W) + + # Calculate the metric for each patch + if ras_manager.MANAGER.metric == "std": + metric = torch.std(diff_unpatched, dim=1).view(H // p, p, W // p, p).transpose(1, 2).mean(-1).mean(-1).view(-1) + elif ras_manager.MANAGER.metric == "l2norm": + metric = torch.norm(diff_unpatched, p=2, dim=1).view(H // p, p, W // p, p).transpose(1, 2).mean(-1).mean(-1).view(-1) + else: + raise ValueError("Unknown metric") + + # Scale the metric with the drop count to avoid starvation + metric *= torch.exp(ras_manager.MANAGER.starvation_scale * self.drop_cnt) + current_skip_num = ras_manager.MANAGER.skip_token_num_list[self._step_index + 1] + assert ras_manager.MANAGER.high_ratio >= 0 and ras_manager.MANAGER.high_ratio <= 1, "High ratio should be in the range of [0, 1]" + + # Select indices in patch space + indices = torch.sort(metric, dim=0, descending=False).indices + low_bar = int(current_skip_num * (1 - ras_manager.MANAGER.high_ratio)) + high_bar = int(current_skip_num * ras_manager.MANAGER.high_ratio) + cached_patchified_indices = torch.cat([indices[:low_bar], indices[-high_bar:]]) + other_patchified_indices = indices[low_bar:-high_bar] + + # Update drop count + self.drop_cnt[cached_patchified_indices] += 1 + print(f"cached patchified indices {cached_patchified_indices}") + + # Convert patch indices to latent indices + # latent_cached_indices = self.extract_latents_index_from_patched_latents_index(cached_patchified_indices, H, W) + + return cached_patchified_indices, other_patchified_indices + + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[RASFlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. 
+ + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + if self.drop_cnt is None or self._step_index == 0: + self._init_ras_config(sample) + + if self._step_index == 0: + ras_manager.MANAGER.reset_cache() + + ## B H/p * W/p C * p * p + batch, token_count, latent_dim = sample.shape + # latent_dim, height, width = sample.shape[-3:]s + + assert ras_manager.MANAGER.sample_ratio > 0.0 and ras_manager.MANAGER.sample_ratio <= 1.0 + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step: + model_output[:, ras_manager.MANAGER.cached_index, :] = ras_manager.MANAGER.cached_scaled_noise + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + + diff = (sigma_next - sigma) * model_output + prev_sample = sample + diff + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_next_RAS_step: + cached_patchified_indices, other_patchified_indices = self.ras_selection(sample, diff) + ras_manager.MANAGER.cached_scaled_noise = model_output[:, cached_patchified_indices, :] + ras_manager.MANAGER.cached_index = cached_patchified_indices + ras_manager.MANAGER.other_patchified_index = other_patchified_indices + + # upon completion increase step index by one + self._step_index += 1 + ras_manager.MANAGER.increase_step() + if ras_manager.MANAGER.current_step >= ras_manager.MANAGER.num_steps: + ras_manager.MANAGER.reset_cache() + + if not return_dict: + return (prev_sample,) + + return RASFlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) \ No newline at end of file diff --git a/src/ras/schedulers/ras_scheduling_flow_match_euler_discrete.py b/src/ras/schedulers/ras_scheduling_flow_match_euler_discrete.py index 1a4e0a9..6f3a0f4 100644 --- a/src/ras/schedulers/ras_scheduling_flow_match_euler_discrete.py +++ b/src/ras/schedulers/ras_scheduling_flow_match_euler_discrete.py @@ -92,11 +92,46 @@ def __init__( def _init_ras_config(self, latents): self.drop_cnt = torch.zeros((latents.shape[-2] // ras_manager.MANAGER.patch_size * latents.shape[-1] // ras_manager.MANAGER.patch_size), device=latents.device) - len(self.sigmas) - def extract_latents_index_from_patched_latents_index(self, indices, height): - # # TODO add non-square case - # # TODO support PATCH_SIZE != 2 - flattened_indices = indices // (height // ras_manager.MANAGER.patch_size) * ras_manager.MANAGER.patch_size * height + indices % (height // ras_manager.MANAGER.patch_size) *ras_manager.MANAGER.patch_size - flattened_indices = (flattened_indices[:, None] + torch.tensor([0, height + 1, 1, height], dtype=indices.dtype, device=indices.device)[None, :]).flatten() + # def extract_latents_index_from_patched_latents_index(self, 
indices, height): + # # # TODO add non-square case + # # # TODO support PATCH_SIZE != 2 + # flattened_indices = indices // (height // ras_manager.MANAGER.patch_size) * ras_manager.MANAGER.patch_size * height + indices % (height // ras_manager.MANAGER.patch_size) *ras_manager.MANAGER.patch_size + # flattened_indices = (flattened_indices[:, None] + torch.tensor([0, height + 1, 1, height], dtype=indices.dtype, device=indices.device)[None, :]).flatten() + # return flattened_indices + + def extract_latents_index_from_patched_latents_index(self, indices, height, width): + """ + Maps patch indices to latent indices in a row-major ordered grid for unequal H and W. + + Args: + indices (torch.Tensor): 1D tensor of patch indices from the flattened patch grid. + height (int): Height of the latent grid (H). + width (int): Width of the latent grid (W). + + Returns: + torch.Tensor: Flattened tensor of latent indices corresponding to all positions in the patches. + """ + # Access patch size (assuming it's defined in a manager object) + ps = ras_manager.MANAGER.patch_size # e.g., 2 + + # Compute patch grid dimensions + ph = height // ps # Number of patches along height + pw = width // ps # Number of patches along width + + # Compute patch coordinates from indices (row-major order) + patch_y = indices // pw # Patch row index + patch_x = indices % pw # Patch column index + + # Compute base latent index for the top-left corner of each patch + base_index = (patch_y * ps) * width + (patch_x * ps) + + # Define offsets for the four corners in row-major order + offsets = torch.tensor([0, 1, width, width + 1], dtype=indices.dtype, device=indices.device) + + # Compute all latent indices by adding offsets to base index + # Shape: (num_indices, 4) -> flatten to (num_indices * 4) + flattened_indices = (base_index[:, None] + offsets[None, :]).flatten() + return flattened_indices def ras_selection(self, sample, diff, height, width): @@ -119,7 +154,8 @@ def ras_selection(self, sample, diff, height, width): cached_patchified_indices = torch.cat([indices[:low_bar], indices[-high_bar:]]) other_patchified_indices = indices[low_bar:-high_bar] self.drop_cnt[cached_patchified_indices] += 1 - latent_cached_indices = self.extract_latents_index_from_patched_latents_index(cached_patchified_indices, height) + # latent_cached_indices = self.extract_latents_index_from_patched_latents_index(cached_patchified_indices, height) + latent_cached_indices = self.extract_latents_index_from_patched_latents_index(cached_patchified_indices, height, width) return latent_cached_indices, other_patchified_indices diff --git a/src/ras/utils/flux/__init__.py b/src/ras/utils/flux/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ras/utils/flux/update_pipeline_flux.py b/src/ras/utils/flux/update_pipeline_flux.py new file mode 100644 index 0000000..1fe6b7f --- /dev/null +++ b/src/ras/utils/flux/update_pipeline_flux.py @@ -0,0 +1,13 @@ +from ...schedulers import RASFluxFlowMatchEulerDiscreteScheduler +from ...modules.attention_processor import RASFluxAttnProcessor2_0 +from ...modules.flux.transformer_forward import ras_forward + +def update_flux_pipeline(pipeline): + scheduler = RASFluxFlowMatchEulerDiscreteScheduler.from_config(pipeline.scheduler.config) + pipeline.scheduler = scheduler + pipeline.transformer.forward = ras_forward.__get__(pipeline.transformer, pipeline.transformer.__class__) + for block in pipeline.transformer.transformer_blocks: + block.attn.set_processor(RASFluxAttnProcessor2_0()) + for block in 
pipeline.transformer.single_transformer_blocks: + block.attn.set_processor(RASFluxAttnProcessor2_0()) + return pipeline \ No newline at end of file diff --git a/src/ras/utils/ras_manager.py b/src/ras/utils/ras_manager.py index 2c01585..bac4986 100644 --- a/src/ras/utils/ras_manager.py +++ b/src/ras/utils/ras_manager.py @@ -38,14 +38,14 @@ def set_parameters(self, args): self.skip_num_step = args.skip_num_step self.skip_num_step_length = args.skip_num_step_length self.height = args.height - self.weight = args.width + self.width = args.width self.high_ratio = args.high_ratio self.enable_index_fusion = args.enable_index_fusion self.generate_skip_token_list() def generate_skip_token_list(self): - avg_skip_token_num = int((1 - self.sample_ratio) * ((self.height // self.patch_size) // self.vae_size) * ((self.weight // self.patch_size) // self.vae_size)) + avg_skip_token_num = int((1 - self.sample_ratio) * ((self.height // self.patch_size) // self.vae_size) * ((self.width // self.patch_size) // self.vae_size)) if self.skip_num_step_length == 0: # static dropping self.skip_token_num_list = [avg_skip_token_num for i in range(self.num_steps)] for i in self.error_reset_steps: From 127466a17bf37e0481731936c4024e9e8114a32f Mon Sep 17 00:00:00 2001 From: Abhinay1997 Date: Wed, 23 Apr 2025 17:16:48 +0530 Subject: [PATCH 2/2] Add Wan support --- src/ras/model_inference/wan_inference.py | 38 +++ src/ras/modules/attention_processor.py | 127 ++++++++- src/ras/modules/wan/__init__.py | 0 src/ras/modules/wan/transformer_forward.py | 100 +++++++ src/ras/schedulers/__init__.py | 1 + src/ras/schedulers/ras_wan_flow_matching.py | 278 ++++++++++++++++++++ src/ras/utils/ras_argparser.py | 4 + src/ras/utils/ras_manager.py | 12 +- src/ras/utils/wan/__init__.py | 0 src/ras/utils/wan/update_pipeline_wan.py | 12 + 10 files changed, 570 insertions(+), 2 deletions(-) create mode 100644 src/ras/model_inference/wan_inference.py create mode 100644 src/ras/modules/wan/__init__.py create mode 100644 src/ras/modules/wan/transformer_forward.py create mode 100644 src/ras/schedulers/ras_wan_flow_matching.py create mode 100644 src/ras/utils/wan/__init__.py create mode 100644 src/ras/utils/wan/update_pipeline_wan.py diff --git a/src/ras/model_inference/wan_inference.py b/src/ras/model_inference/wan_inference.py new file mode 100644 index 0000000..fda20f0 --- /dev/null +++ b/src/ras/model_inference/wan_inference.py @@ -0,0 +1,38 @@ +import argparse +import torch +from diffusers import AutoencoderKLWan, WanPipeline +from diffusers.utils import export_to_video +import sys + +sys.path.append('/workspace/RAS') +sys.path.append('/workspace/RAS/src') +from ras.utils.wan.update_pipeline_wan import update_wan_pipeline +from ras.utils import ras_manager +from ras.utils.ras_argparser import parse_args + +def wan_inf(args): + # Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers + model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers" + vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32) + pipeline = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16) + pipeline.to("cuda") + pipeline = update_wan_pipeline(pipeline) + pipeline.enable_sequential_cpu_offload() + generator = torch.Generator("cuda").manual_seed(args.seed) if args.seed is not None else None + numsteps = args.num_inference_steps + video = pipeline( + generator=generator, + prompt=args.prompt, + negative_prompt=args.negative_prompt, + height=args.height, + width=args.width, + 
num_frames=args.num_frames, + guidance_scale=5.0, + num_inference_steps=numsteps + ).frames[0] + export_to_video(video, args.output, fps=15) + +if __name__ == "__main__": + args = parse_args() + ras_manager.MANAGER.set_parameters(args) + wan_inf(args) diff --git a/src/ras/modules/attention_processor.py b/src/ras/modules/attention_processor.py index d856620..56a50fa 100644 --- a/src/ras/modules/attention_processor.py +++ b/src/ras/modules/attention_processor.py @@ -15,7 +15,7 @@ import math import torch from ..utils import ras_manager -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union, Dict, Any from torch import nn from diffusers.models.attention_processor import Attention import torch.nn.functional as F @@ -501,5 +501,130 @@ def __call__( else: return hidden_states +class RASWanAttnProcessor2_0: + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.") + if ras_manager.MANAGER.sample_ratio < 1.0: + self.k_cache = None + self.v_cache = None + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + encoder_hidden_states_img = None + is_self_attention = False + if attn.add_k_proj is not None: + # 512 is the context length of the text encoder, hardcoded for now + image_context_length = encoder_hidden_states.shape[1] - 512 + encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length] + encoder_hidden_states = encoder_hidden_states[:, image_context_length:] + if encoder_hidden_states is None: + is_self_attention = True + encoder_hidden_states = hidden_states + + query = attn.to_q(hidden_states) + ## if self attention, key and value are going to be truncated as well. 
Need to use k_cache and v_cache for RAS_steps + v_fuse_linear = ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step and \ + is_self_attention and self.v_cache is not None \ + and ras_manager.MANAGER.enable_index_fusion + k_fuse_linear = v_fuse_linear and rotary_emb is None + + if v_fuse_linear: + from .fused_kernels_lumina import _partially_linear + _partially_linear( + encoder_hidden_states, + attn.to_v.weight, + attn.to_v.bias, + ras_manager.MANAGER.other_patchified_index, + self.v_cache.view(batch_size, self.v_cache.shape[1], -1) + ) + else: + value = attn.to_v(encoder_hidden_states) + if k_fuse_linear: + _partially_linear( + encoder_hidden_states, + attn.to_k.weight, + attn.to_k.bias, + ras_manager.MANAGER.other_patchified_index, + self.k_cache.view(batch_size, self.k_cache.shape[1], -1) + ) + else: + key = attn.to_k(encoder_hidden_states) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2) + key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.current_step == 0 and is_self_attention: + self.k_cache = None + self.v_cache = None + + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.current_step > ras_manager.MANAGER.scheduler_end_step and is_self_attention: + self.k_cache = None + self.v_cache = None + + if rotary_emb is not None: + + def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor): + x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2))) + x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4) + return x_out.type_as(hidden_states) + + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step: + query = apply_rotary_emb(query, ras_manager.MANAGER.image_rotary_emb_skip) + key = apply_rotary_emb(key, ras_manager.MANAGER.image_rotary_emb_skip) + else: + query = apply_rotary_emb(query, rotary_emb) + key = apply_rotary_emb(key, rotary_emb) + + if ras_manager.MANAGER.sample_ratio < 1.0 and (ras_manager.MANAGER.current_step == ras_manager.MANAGER.scheduler_start_step - 1 or ras_manager.MANAGER.current_step in ras_manager.MANAGER.error_reset_steps) and is_self_attention: + self.k_cache = key + self.v_cache = value + + if ras_manager.MANAGER.sample_ratio < 1.0 and is_self_attention and ras_manager.MANAGER.is_RAS_step: + if not ras_manager.MANAGER.enable_index_fusion: + self.k_cache[:, :, ras_manager.MANAGER.other_patchified_index, :] = key + self.v_cache[:, :, ras_manager.MANAGER.other_patchified_index, :] = value + key = self.k_cache + value = self.v_cache + + # I2V task + hidden_states_img = None + if encoder_hidden_states_img is not None: + key_img = attn.add_k_proj(encoder_hidden_states_img) + key_img = attn.norm_added_k(key_img) + value_img = attn.add_v_proj(encoder_hidden_states_img) + + key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) + value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2) + + hidden_states_img = F.scaled_dot_product_attention( + query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False + ) + hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3) + hidden_states_img = hidden_states_img.type_as(query) + + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + 
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3) + hidden_states = hidden_states.type_as(query) + + if hidden_states_img is not None: + hidden_states = hidden_states + hidden_states_img + + hidden_states = attn.to_out[0](hidden_states) + hidden_states = attn.to_out[1](hidden_states) + return hidden_states \ No newline at end of file diff --git a/src/ras/modules/wan/__init__.py b/src/ras/modules/wan/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ras/modules/wan/transformer_forward.py b/src/ras/modules/wan/transformer_forward.py new file mode 100644 index 0000000..6423934 --- /dev/null +++ b/src/ras/modules/wan/transformer_forward.py @@ -0,0 +1,100 @@ +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from diffusers.utils.torch_utils import maybe_allow_in_graph +from ras.utils import ras_manager + +@torch.no_grad() +def ras_forward( + self, + hidden_states: torch.Tensor, + timestep: torch.LongTensor, + encoder_hidden_states: torch.Tensor, + encoder_hidden_states_image: Optional[torch.Tensor] = None, + return_dict: bool = True, + attention_kwargs: Optional[Dict[str, Any]] = None, +) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + batch_size, num_channels, num_frames, height, width = hidden_states.shape + p_t, p_h, p_w = self.config.patch_size + post_patch_num_frames = num_frames // p_t + post_patch_height = height // p_h + post_patch_width = width // p_w + + rotary_emb = self.rope(hidden_states) + + hidden_states = self.patch_embedding(hidden_states) + hidden_states = hidden_states.flatten(2).transpose(1, 2) + + temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image + ) + timestep_proj = timestep_proj.unflatten(1, (6, -1)) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1) + + # 4. Transformer blocks + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step: + hidden_states = hidden_states[:, ras_manager.MANAGER.other_patchified_index] + ras_manager.MANAGER.image_rotary_emb_skip = rotary_emb[:, :, ras_manager.MANAGER.other_patchified_index, :] + + if torch.is_grad_enabled() and self.gradient_checkpointing: + for block in self.blocks: + hidden_states = self._gradient_checkpointing_func( + block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb + ) + else: + for block in self.blocks: + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + + # 5. 
Output norm, projection & unpatchify + shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1) + + # Move the shift and scale tensors to the same device as hidden_states. + # When using multi-GPU inference via accelerate these will be on the + # first device rather than the last device, which hidden_states ends up + # on. + shift = shift.to(hidden_states.device) + scale = scale.to(hidden_states.device) + + hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states) + hidden_states = self.proj_out(hidden_states) + if ras_manager.MANAGER.sample_ratio < 1.0: + if ras_manager.MANAGER.is_RAS_step: + final_hidden_states = torch.zeros((hidden_states.shape[0], (num_frames // p_t ) * (height // p_h)*(width // p_w), hidden_states.shape[2]), device=hidden_states.device, dtype=hidden_states.dtype) + final_hidden_states[:, ras_manager.MANAGER.other_patchified_index] = hidden_states + hidden_states = final_hidden_states + hidden_states = hidden_states.reshape( + batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + ) + hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6) + output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) \ No newline at end of file diff --git a/src/ras/schedulers/__init__.py b/src/ras/schedulers/__init__.py index 0ad1307..3390ed1 100644 --- a/src/ras/schedulers/__init__.py +++ b/src/ras/schedulers/__init__.py @@ -2,3 +2,4 @@ from .ras_scheduling_flow_match_euler_discrete import RASFlowMatchEulerDiscreteScheduler from .ras_flux_flow_matching import RASFluxFlowMatchEulerDiscreteScheduler +from .ras_wan_flow_matching import RASWanFlowMatchEulerDiscreteScheduler diff --git a/src/ras/schedulers/ras_wan_flow_matching.py b/src/ras/schedulers/ras_wan_flow_matching.py new file mode 100644 index 0000000..a65f4f9 --- /dev/null +++ b/src/ras/schedulers/ras_wan_flow_matching.py @@ -0,0 +1,278 @@ +# This file is a modified version of the original file from the HuggingFace/diffusers library. + +# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.schedulers import FlowMatchEulerDiscreteScheduler +from ras.utils import ras_manager + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class RASFlowMatchEulerDiscreteSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's `step` function output. 
+ + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + """ + + prev_sample: torch.FloatTensor + + +class RASWanFlowMatchEulerDiscreteScheduler(FlowMatchEulerDiscreteScheduler): + """ + RAS Euler scheduler. + + This model inherits from ['FlowMatchEulerDiscreteScheduler']. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + timestep_spacing (`str`, defaults to `"linspace"`): + The way the timesteps should be scaled. Refer to Table 2 of the [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://huggingface.co/papers/2305.08891) for more information. + shift (`float`, defaults to 1.0): + The shift value for the timestep schedule. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + shift: float = 1.0, + use_dynamic_shifting=False, + base_shift: Optional[float] = 0.5, + max_shift: Optional[float] = 1.15, + base_image_seq_len: Optional[int] = 256, + max_image_seq_len: Optional[int] = 4096, + invert_sigmas: bool = False, + ): + super().__init__(num_train_timesteps=num_train_timesteps, + shift=shift, + use_dynamic_shifting=use_dynamic_shifting, + base_shift=base_shift, + max_shift=max_shift, + base_image_seq_len=base_image_seq_len, + max_image_seq_len=max_image_seq_len, + # invert_sigmas=invert_sigmas + ) + self.drop_cnt = None + + + def _init_ras_config(self, latents): + #drop cnt stored at token level + b, c, t, h, w = latents.shape + token_count = (t // ras_manager.MANAGER.temporal_patch_size) * (h // ras_manager.MANAGER.patch_size) * (w // ras_manager.MANAGER.patch_size) + self.drop_cnt = torch.zeros((token_count), device=latents.device) - len(self.sigmas) + + def extract_latents_index_from_patched_latents_index(self, indices, frames, height, width): + """ + Maps patch indices to latent indices in a row-major ordered grid for unequal H and W. + + Args: + indices (torch.Tensor): 1D tensor of patch indices from the flattened patch grid. + height (int): Height of the latent grid (H). + width (int): Width of the latent grid (W). + + Returns: + torch.Tensor: Flattened tensor of latent indices corresponding to all positions in the patches. + """ + # Access patch size (assuming it's defined in a manager object) + ps = ras_manager.MANAGER.patch_size # e.g., 2 + pst = ras_manager.MANAGER.temporal_patch_size # temporal patch size + + # Compute patch grid dimensions + pt = frames // pst + ph = height // ps # Number of patches along height + pw = width // ps # Number of patches along width + + #TODO: Compute for frame coordinates as well. 
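+        # Worked example (spatial only, per the TODO above): with width=4 and ps=2, patch index 1 gives
+        # patch_y=0, patch_x=1, base_index=2, and latent indices [2, 3, 6, 7]; temporal offsets are not applied yet.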
+ # Compute patch coordinates from indices (row-major order) + patch_y = indices // pw # Patch row index + patch_x = indices % pw # Patch column index + + # Compute base latent index for the top-left corner of each patch + base_index = (patch_y * ps) * width + (patch_x * ps) + + # Define offsets for the four corners in row-major order + offsets = torch.tensor([0, 1, width, width + 1], dtype=indices.dtype, device=indices.device) + + # Compute all latent indices by adding offsets to base index + # Shape: (num_indices, 4) -> flatten to (num_indices * 4) + flattened_indices = (base_index[:, None] + offsets[None, :]).flatten() + + return flattened_indices + + def ras_selection(self, sample, diff, frames, height, width): + """ + diff shape is B C T H W + """ + diff = diff.squeeze(0).permute(1, 2, 3, 0) + # calculate the metric for each patch + if ras_manager.MANAGER.metric == "std": + #T H W C -> T/p1, p1, H/p2, p2, W/p2, p2 + metric = torch.std(diff, dim=-1).view( + frames // ras_manager.MANAGER.temporal_patch_size, + ras_manager.MANAGER.temporal_patch_size, + height // ras_manager.MANAGER.patch_size, + ras_manager.MANAGER.patch_size, + width // ras_manager.MANAGER.patch_size, + ras_manager.MANAGER.patch_size + ).permute(0, 2, 4, 1, 3, 5).mean(-1).mean(-1).mean(-1).view(-1) + print(f'metric shape {metric.shape}') + elif ras_manager.MANAGER.metric == "l2norm": + metric = torch.norm(diff, p=2, dim=-1).view( + frames // ras_manager.MANAGER.temporal_patch_size, + ras_manager.MANAGER.temporal_patch_size, + height // ras_manager.MANAGER.patch_size, + ras_manager.MANAGER.patch_size, + width // ras_manager.MANAGER.patch_size, + ras_manager.MANAGER.patch_size + ).permute(0, 2, 4, 1, 3, 5).mean(-1).mean(-1).mean(-1).view(-1) + else: + raise ValueError("Unknown metric") + + # scale the metric with the drop count to avoid starvation + metric *= torch.exp(ras_manager.MANAGER.starvation_scale * self.drop_cnt) + current_skip_num = ras_manager.MANAGER.skip_token_num_list[self._step_index + 1] + assert ras_manager.MANAGER.high_ratio >= 0 and ras_manager.MANAGER.high_ratio <= 1, "High ratio should be in the range of [0, 1]" + indices = torch.sort(metric, dim=0, descending=False).indices + low_bar = int(current_skip_num * (1 - ras_manager.MANAGER.high_ratio)) + high_bar = int(current_skip_num * ras_manager.MANAGER.high_ratio) + cached_patchified_indices = torch.cat([indices[:low_bar], indices[-high_bar:]]) + other_patchified_indices = indices[low_bar:-high_bar] + self.drop_cnt[cached_patchified_indices] += 1 + # latent_cached_indices = self.extract_latents_index_from_patched_latents_index(cached_patchified_indices, height) + latent_cached_indices = self.extract_latents_index_from_patched_latents_index(cached_patchified_indices, frames, height, width) + + return latent_cached_indices, other_patchified_indices + + + def step( + self, + model_output: torch.FloatTensor, + timestep: Union[float, torch.FloatTensor], + sample: torch.FloatTensor, + s_churn: float = 0.0, + s_tmin: float = 0.0, + s_tmax: float = float("inf"), + s_noise: float = 1.0, + generator: Optional[torch.Generator] = None, + return_dict: bool = True, + ) -> Union[RASFlowMatchEulerDiscreteSchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. 
+ timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + s_churn (`float`): + s_tmin (`float`): + s_tmax (`float`): + s_noise (`float`, defaults to 1.0): + Scaling factor for noise added to the sample. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or + tuple. + + Returns: + [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_euler_discrete.EulerDiscreteSchedulerOutput`] is + returned, otherwise a tuple is returned where the first element is the sample tensor. + """ + if ( + isinstance(timestep, int) + or isinstance(timestep, torch.IntTensor) + or isinstance(timestep, torch.LongTensor) + ): + raise ValueError( + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), + ) + + if self.step_index is None: + self._init_step_index(timestep) + + if self.drop_cnt is None or self._step_index == 0: + self._init_ras_config(sample) + + if self._step_index == 0: + ras_manager.MANAGER.reset_cache() + + ## ([1, 16, 21, 60, 106] + batch, latent_dim, frames, height, width = sample.shape + # latent_dim, height, width = sample.shape[-3:]s + + assert ras_manager.MANAGER.sample_ratio > 0.0 and ras_manager.MANAGER.sample_ratio <= 1.0 + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step: + #bcthw -> cthw -> c (thw) -> (thw) c + model_output.squeeze(0).view(latent_dim, -1)[:, ras_manager.MANAGER.cached_index] = ras_manager.MANAGER.cached_scaled_noise + model_output = model_output.transpose(0, 1).view(latent_dim, frames, height, width).unsqueeze(0) + + # Upcast to avoid precision issues when computing prev_sample + sample = sample.to(torch.float32) + + sigma = self.sigmas[self.step_index] + sigma_next = self.sigmas[self.step_index + 1] + + diff = (sigma_next - sigma) * model_output + prev_sample = sample + diff + # Cast sample back to model compatible dtype + prev_sample = prev_sample.to(model_output.dtype) + + if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_next_RAS_step: + latent_cached_indices, other_patchified_indices = self.ras_selection(sample, diff, frames, height, width) + ras_manager.MANAGER.cached_scaled_noise = model_output.squeeze(0).view(latent_dim, -1)[:, latent_cached_indices] + ras_manager.MANAGER.cached_index = latent_cached_indices + ras_manager.MANAGER.other_patchified_index = other_patchified_indices + + # upon completion increase step index by one + self._step_index += 1 + ras_manager.MANAGER.increase_step() + if ras_manager.MANAGER.current_step >= ras_manager.MANAGER.num_steps: + ras_manager.MANAGER.reset_cache() + + if not return_dict: + return (prev_sample,) + + return RASFlowMatchEulerDiscreteSchedulerOutput(prev_sample=prev_sample) \ No newline at end of file diff --git a/src/ras/utils/ras_argparser.py b/src/ras/utils/ras_argparser.py index 7407f92..ab745e8 100644 --- a/src/ras/utils/ras_argparser.py +++ b/src/ras/utils/ras_argparser.py @@ -22,4 +22,8 @@ def parse_args(): parser.add_argument("--skip_num_step_length", type=int, default=0, help="The interval to change the skip token number") parser.add_argument("--enable_index_fusion", 
action="store_true", help="Enable index fusion for RAS") + parser.add_argument("--num_frames", type=int, default=81, help="Num frames for video generation models.") + parser.add_argument("--temporal_patch_size", type=int, default=1, help="temporal patch size") + parser.add_argument("--is_video", action="store_true", help="Whether current model is a video model") + return parser.parse_args() diff --git a/src/ras/utils/ras_manager.py b/src/ras/utils/ras_manager.py index bac4986..5a4111d 100644 --- a/src/ras/utils/ras_manager.py +++ b/src/ras/utils/ras_manager.py @@ -1,6 +1,7 @@ class ras_manager: def __init__(self): self.patch_size = 2 + self.temporal_patch_size = 1 self.scheduler_start_step = 4 self.scheduler_end_step = 30 self.metric = "std" @@ -9,12 +10,14 @@ def __init__(self): self.sample_ratio = 0.5 self.starvation_scale = 0.1 self.vae_size = 8 + self.temporal_vae_size = 4 self.high_ratio = 1 self.num_steps = 30 self.current_step = 0 self.enable_index_fusion = False self.is_RAS_step = False self.is_next_RAS_step = False + self.is_video = False self.cached_index = None self.other_index = None @@ -27,6 +30,8 @@ def __init__(self): def set_parameters(self, args): self.patch_size = args.patch_size + self.temporal_patch_size = args.temporal_patch_size + self.is_video = args.is_video # self.scheduler_pattern = args.scheduler_pattern self.scheduler_start_step = args.scheduler_start_step self.scheduler_end_step = args.scheduler_end_step @@ -39,13 +44,18 @@ def set_parameters(self, args): self.skip_num_step_length = args.skip_num_step_length self.height = args.height self.width = args.width + self.num_frames = args.num_frames self.high_ratio = args.high_ratio self.enable_index_fusion = args.enable_index_fusion self.generate_skip_token_list() def generate_skip_token_list(self): - avg_skip_token_num = int((1 - self.sample_ratio) * ((self.height // self.patch_size) // self.vae_size) * ((self.width // self.patch_size) // self.vae_size)) + if self.is_video: + latent_frames = ((self.num_frames - 1) // self.temporal_vae_size) + 1 + avg_skip_token_num = int((1 - self.sample_ratio) * ((latent_frames // self.temporal_patch_size) * (self.height // self.patch_size) // self.vae_size) * ((self.width // self.patch_size) // self.vae_size)) + else: + avg_skip_token_num = int((1 - self.sample_ratio) * ((self.height // self.patch_size) // self.vae_size) * ((self.width // self.patch_size) // self.vae_size)) if self.skip_num_step_length == 0: # static dropping self.skip_token_num_list = [avg_skip_token_num for i in range(self.num_steps)] for i in self.error_reset_steps: diff --git a/src/ras/utils/wan/__init__.py b/src/ras/utils/wan/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ras/utils/wan/update_pipeline_wan.py b/src/ras/utils/wan/update_pipeline_wan.py new file mode 100644 index 0000000..c7362b0 --- /dev/null +++ b/src/ras/utils/wan/update_pipeline_wan.py @@ -0,0 +1,12 @@ +from ...schedulers import RASWanFlowMatchEulerDiscreteScheduler +from ...modules.attention_processor import RASWanAttnProcessor2_0 +from ...modules.wan.transformer_forward import ras_forward + +def update_wan_pipeline(pipeline): + scheduler = RASWanFlowMatchEulerDiscreteScheduler.from_config(pipeline.scheduler.config) + pipeline.scheduler = scheduler + pipeline.transformer.forward = ras_forward.__get__(pipeline.transformer, pipeline.transformer.__class__) + for block in pipeline.transformer.blocks: + block.attn1.set_processor(RASWanAttnProcessor2_0()) + block.attn2.set_processor(RASWanAttnProcessor2_0()) + return 
pipeline \ No newline at end of file
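A minimal usage sketch (not part of the patch) for the new Flux path, assuming `args` comes from `ras.utils.ras_argparser.parse_args()` exactly as in `flux_inference.py`; the Wan path is analogous via `update_wan_pipeline` and `WanPipeline`:

    import torch
    from diffusers import FluxPipeline
    from ras.utils import ras_manager
    from ras.utils.flux.update_pipeline_flux import update_flux_pipeline
    from ras.utils.ras_argparser import parse_args

    args = parse_args()                       # RAS knobs (sample ratio, scheduler window, ...) plus prompt/size
    ras_manager.MANAGER.set_parameters(args)  # must run first so the skip-token schedule is built

    pipeline = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    ).to("cuda")
    pipeline = update_flux_pipeline(pipeline)  # swaps in the RAS scheduler, attention processors and forward
    image = pipeline(
        prompt=args.prompt, height=args.height, width=args.width,
        num_inference_steps=args.num_inference_steps,
    ).images[0]
    image.save(args.output)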