38 changes: 38 additions & 0 deletions src/ras/model_inference/flux_inference.py
@@ -0,0 +1,38 @@
import argparse
import torch
import sys
import time

sys.path.append('/workspace/RAS')
sys.path.append('/workspace/RAS/src')
from diffusers import FluxPipeline
from ras.utils.flux.update_pipeline_flux import update_flux_pipeline
from ras.utils import ras_manager
from ras.utils.ras_argparser import parse_args

def flux_inf(args):
pipeline = FluxPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
pipeline = update_flux_pipeline(pipeline)
pipeline.transformer.to(memory_format=torch.channels_last)
pipeline.vae.to(memory_format=torch.channels_last)
print("Vae", pipeline.vae_scale_factor)
generator = torch.Generator("cuda").manual_seed(args.seed) if args.seed is not None else None
numsteps = args.num_inference_steps
start = time.time()
image = pipeline(
generator=generator,
num_inference_steps=numsteps,
prompt=args.prompt,
height=args.height,
width=args.width,
).images[0]
print(f"Pipeline time {time.time()-start}")
image.save(args.output)


if __name__ == "__main__":
args = parse_args()
ras_manager.MANAGER.set_parameters(args)
flux_inf(args)
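For reference, the channels_last conversion applied to the transformer and VAE above only changes the memory layout (strides) of 4-D tensors, not their shape or values; it typically benefits convolution-heavy modules such as the VAE. A minimal standalone illustration, independent of the pipeline code:

import torch

# Same values, different physical layout: channels become the innermost dimension.
x = torch.randn(1, 16, 64, 64)                      # NCHW, contiguous
y = x.to(memory_format=torch.channels_last)
print(x.stride())                                   # (65536, 4096, 64, 1)
print(y.stride())                                   # (65536, 1, 1024, 16)
print(torch.equal(x, y))                            # True: only the layout changed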
4 changes: 4 additions & 0 deletions src/ras/model_inference/lumina_next_t2i_inference.py
@@ -1,5 +1,9 @@
import argparse
import torch
import sys

sys.path.append('/workspace/RAS')
sys.path.append('/workspace/RAS/src')
from diffusers import LuminaText2ImgPipeline
from ras.utils.lumina_next_t2i.update_pipeline_lumina import update_lumina_pipeline
from ras.utils import ras_manager
38 changes: 38 additions & 0 deletions src/ras/model_inference/wan_inference.py
@@ -0,0 +1,38 @@
import argparse
import torch
from diffusers import AutoencoderKLWan, WanPipeline
from diffusers.utils import export_to_video
import sys

sys.path.append('/workspace/RAS')
sys.path.append('/workspace/RAS/src')
from ras.utils.wan.update_pipeline_wan import update_wan_pipeline
from ras.utils import ras_manager
from ras.utils.ras_argparser import parse_args

def wan_inf(args):
# Available models: Wan-AI/Wan2.1-T2V-14B-Diffusers, Wan-AI/Wan2.1-T2V-1.3B-Diffusers
model_id = "Wan-AI/Wan2.1-T2V-14B-Diffusers"
vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
pipeline = WanPipeline.from_pretrained(model_id, vae=vae, torch_dtype=torch.bfloat16)
pipeline.to("cuda")
pipeline = update_wan_pipeline(pipeline)
pipeline.enable_sequential_cpu_offload()
generator = torch.Generator("cuda").manual_seed(args.seed) if args.seed is not None else None
numsteps = args.num_inference_steps
video = pipeline(
generator=generator,
prompt=args.prompt,
negative_prompt=args.negative_prompt,
height=args.height,
width=args.width,
num_frames=args.num_frames,
guidance_scale=5.0,
num_inference_steps=numsteps
).frames[0]
export_to_video(video, args.output, fps=15)

if __name__ == "__main__":
args = parse_args()
ras_manager.MANAGER.set_parameters(args)
wan_inf(args)
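Both inference scripts use the same seeding pattern: a per-call torch.Generator makes the sampled noise reproducible without touching the global RNG state. A minimal standalone illustration (using CPU so it runs without a GPU; the scripts pass "cuda"):

import torch

gen_a = torch.Generator("cpu").manual_seed(123)
gen_b = torch.Generator("cpu").manual_seed(123)
# Two generators seeded identically produce identical noise draws.
assert torch.equal(torch.randn(4, generator=gen_a), torch.randn(4, generator=gen_b))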
289 changes: 287 additions & 2 deletions src/ras/modules/attention_processor.py
@@ -15,7 +15,7 @@
import math
import torch
from ..utils import ras_manager
from typing import Callable, List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union, Dict, Any
from torch import nn
from diffusers.models.attention_processor import Attention
import torch.nn.functional as F
@@ -44,7 +44,6 @@ def __call__(
base_sequence_length: Optional[int] = None,
) -> torch.Tensor:
from diffusers.models.embeddings import apply_rotary_emb

is_self_attention = True if hidden_states.shape == encoder_hidden_states.shape else False

input_ndim = hidden_states.ndim
@@ -343,3 +342,289 @@ def __call__(
return hidden_states, encoder_hidden_states
else:
return hidden_states


class RASFluxAttnProcessor2_0:
"""Attention processor used typically in processing the SD3-like self-attention projections."""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("RASFluxAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
        if ras_manager.MANAGER.sample_ratio < 1.0:
            self.k_cache = None
            self.v_cache = None

    # When no encoder states are passed (single-stream blocks), hidden_states is the
    # concatenation [encoder_hidden_states, hidden_states], so indices into the RAS cache
    # must be offset by the text sequence length (512).
def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
) -> torch.FloatTensor:
batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
is_self_attention = encoder_hidden_states is None
# `sample` projections.
        query = attn.to_q(hidden_states)  # in single-stream blocks hidden_states is [encoder_hidden_states, hidden_states], i.e. image-token indices are offset by 512
k_fuse_linear = ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step and \
self.k_cache is not None and ras_manager.MANAGER.enable_index_fusion
v_fuse_linear = k_fuse_linear

if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step:
if is_self_attention:
other_patchified_index = 512 + ras_manager.MANAGER.other_patchified_index
padding_indices = torch.arange(512, device=other_patchified_index.device)
other_patchified_index = torch.cat([padding_indices, other_patchified_index])
else:
other_patchified_index = ras_manager.MANAGER.other_patchified_index
                # print(other_patchified_index.shape)

if k_fuse_linear:
from .fused_kernels_sd3 import _partially_linear
_partially_linear(
hidden_states,
attn.to_k.weight,
attn.to_k.bias,
other_patchified_index,
self.k_cache.view(batch_size, self.k_cache.shape[1], -1)
)
else:
key = attn.to_k(hidden_states)
if v_fuse_linear:
_partially_linear(
hidden_states,
attn.to_v.weight,
attn.to_v.bias,
other_patchified_index,
self.v_cache.view(batch_size, self.v_cache.shape[1], -1)
)
else:
value = attn.to_v(hidden_states)

if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.current_step == 0:
self.k_cache = None
self.v_cache = None

if ras_manager.MANAGER.sample_ratio < 1.0 and (ras_manager.MANAGER.current_step == ras_manager.MANAGER.scheduler_start_step - 1 or ras_manager.MANAGER.current_step in ras_manager.MANAGER.error_reset_steps):
self.k_cache = key
self.v_cache = value

if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step:
if not ras_manager.MANAGER.enable_index_fusion:
self.k_cache[:, other_patchified_index] = key
self.v_cache[:, other_patchified_index] = value
key = self.k_cache
value = self.v_cache

if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.current_step > ras_manager.MANAGER.scheduler_end_step:
self.k_cache = None
self.v_cache = None

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

# the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
if encoder_hidden_states is not None:
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)
encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
batch_size, -1, attn.heads, head_dim
).transpose(1, 2)

if attn.norm_added_q is not None:
encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
if attn.norm_added_k is not None:
encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)

# attention
query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)

        # TODO: add kernels that apply RoPE only to the sampled token indices.
if image_rotary_emb is not None:
from diffusers.models.embeddings import apply_rotary_emb
if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step:
query = apply_rotary_emb(query, ras_manager.MANAGER.image_rotary_emb_skip)
key = apply_rotary_emb(key, image_rotary_emb)
else:
query = apply_rotary_emb(query, image_rotary_emb)
key = apply_rotary_emb(key, image_rotary_emb)

if not ras_manager.MANAGER.replace_with_flash_attn:
hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
else:
from flash_attn import flash_attn_func
hidden_states = flash_attn_func(
query, key, value, dropout_p=0.0, causal=False
)
hidden_states = hidden_states.view(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

if encoder_hidden_states is not None:
# print(f"Before, {hidden_states.shape}")
encoder_hidden_states, hidden_states = (
hidden_states[:, : encoder_hidden_states.shape[1]],
hidden_states[:, encoder_hidden_states.shape[1] :],
)
# print(f"After, {hidden_states.shape}")

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

encoder_hidden_states = attn.to_add_out(encoder_hidden_states)

return hidden_states, encoder_hidden_states
else:
return hidden_states

class RASWanAttnProcessor2_0:
def __init__(self):
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("WanAttnProcessor2_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
if ras_manager.MANAGER.sample_ratio < 1.0:
self.k_cache = None
self.v_cache = None

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
rotary_emb: Optional[torch.Tensor] = None,
) -> torch.Tensor:
        batch_size = hidden_states.shape[0]  # needed for the fused-kernel cache views below
        encoder_hidden_states_img = None
        is_self_attention = False
if attn.add_k_proj is not None:
# 512 is the context length of the text encoder, hardcoded for now
image_context_length = encoder_hidden_states.shape[1] - 512
encoder_hidden_states_img = encoder_hidden_states[:, :image_context_length]
encoder_hidden_states = encoder_hidden_states[:, image_context_length:]
if encoder_hidden_states is None:
is_self_attention = True
encoder_hidden_states = hidden_states

query = attn.to_q(hidden_states)
        # In self-attention during RAS steps, hidden_states (and hence key/value) covers only the
        # sampled tokens, so the full-length projections must come from k_cache/v_cache.
v_fuse_linear = ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step and \
is_self_attention and self.v_cache is not None \
and ras_manager.MANAGER.enable_index_fusion
k_fuse_linear = v_fuse_linear and rotary_emb is None

if v_fuse_linear:
from .fused_kernels_lumina import _partially_linear
_partially_linear(
encoder_hidden_states,
attn.to_v.weight,
attn.to_v.bias,
ras_manager.MANAGER.other_patchified_index,
self.v_cache.view(batch_size, self.v_cache.shape[1], -1)
)
else:
value = attn.to_v(encoder_hidden_states)

if k_fuse_linear:
_partially_linear(
encoder_hidden_states,
attn.to_k.weight,
attn.to_k.bias,
ras_manager.MANAGER.other_patchified_index,
self.k_cache.view(batch_size, self.k_cache.shape[1], -1)
)
else:
key = attn.to_k(encoder_hidden_states)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)

if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.current_step == 0 and is_self_attention:
self.k_cache = None
self.v_cache = None

if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.current_step > ras_manager.MANAGER.scheduler_end_step and is_self_attention:
self.k_cache = None
self.v_cache = None

if rotary_emb is not None:

def apply_rotary_emb(hidden_states: torch.Tensor, freqs: torch.Tensor):
x_rotated = torch.view_as_complex(hidden_states.to(torch.float64).unflatten(3, (-1, 2)))
x_out = torch.view_as_real(x_rotated * freqs).flatten(3, 4)
return x_out.type_as(hidden_states)

if ras_manager.MANAGER.sample_ratio < 1.0 and ras_manager.MANAGER.is_RAS_step:
query = apply_rotary_emb(query, ras_manager.MANAGER.image_rotary_emb_skip)
key = apply_rotary_emb(key, ras_manager.MANAGER.image_rotary_emb_skip)
else:
query = apply_rotary_emb(query, rotary_emb)
key = apply_rotary_emb(key, rotary_emb)

if ras_manager.MANAGER.sample_ratio < 1.0 and (ras_manager.MANAGER.current_step == ras_manager.MANAGER.scheduler_start_step - 1 or ras_manager.MANAGER.current_step in ras_manager.MANAGER.error_reset_steps) and is_self_attention:
self.k_cache = key
self.v_cache = value

if ras_manager.MANAGER.sample_ratio < 1.0 and is_self_attention and ras_manager.MANAGER.is_RAS_step:
if not ras_manager.MANAGER.enable_index_fusion:
self.k_cache[:, :, ras_manager.MANAGER.other_patchified_index, :] = key
self.v_cache[:, :, ras_manager.MANAGER.other_patchified_index, :] = value
key = self.k_cache
value = self.v_cache

# I2V task
hidden_states_img = None
if encoder_hidden_states_img is not None:
key_img = attn.add_k_proj(encoder_hidden_states_img)
key_img = attn.norm_added_k(key_img)
value_img = attn.add_v_proj(encoder_hidden_states_img)

key_img = key_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)
value_img = value_img.unflatten(2, (attn.heads, -1)).transpose(1, 2)

hidden_states_img = F.scaled_dot_product_attention(
query, key_img, value_img, attn_mask=None, dropout_p=0.0, is_causal=False
)
hidden_states_img = hidden_states_img.transpose(1, 2).flatten(2, 3)
hidden_states_img = hidden_states_img.type_as(query)

hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
hidden_states = hidden_states.type_as(query)

if hidden_states_img is not None:
hidden_states = hidden_states + hidden_states_img

hidden_states = attn.to_out[0](hidden_states)
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
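Both processors share the same caching idea: at the reset steps the full key/value projections are stored, and at RAS steps only the rows for the actively sampled tokens are recomputed and scattered back into the cache before attention runs over the full sequence. Below is a minimal standalone sketch of that index-scatter update; the 512 text-token offset follows the Flux case above, and the small shapes and names here are illustrative only:

import torch

B, S_img, S_txt, C = 1, 16, 4, 8              # tiny stand-ins for batch, image tokens, text tokens, channels
k_cache = torch.randn(B, S_txt + S_img, C)    # full-length key cache, filled at a reset step

# At a RAS step only a subset of image tokens is recomputed.
sampled_image_index = torch.tensor([2, 5, 11])           # indices among the image tokens
other_patchified_index = S_txt + sampled_image_index     # offset by the text length (512 in Flux)
# Text tokens are always recomputed, so their indices are prepended, as in the Flux processor.
update_index = torch.cat([torch.arange(S_txt), other_patchified_index])

key_active = torch.randn(B, update_index.numel(), C)     # projections for text + sampled image tokens only
k_cache[:, update_index] = key_active                    # scatter the fresh rows into the cache
key = k_cache                                            # attention then runs over the full sequence
print(key.shape)                                         # torch.Size([1, 20, 8])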