Conversation

@kashif (Contributor) commented Nov 23, 2025

What does this PR do?

  • Removes the redundant txt_seq_lens plumbing from all QwenImage pipelines and modular steps; the transformer now infers text length from encoder inputs/masks and validates optional overrides.
  • Builds a lightweight, broadcastable attention mask from encoder_hidden_states_mask inside the double-stream attention, avoiding full seq_len² masks while keeping padding tokens masked (see the sketch after this list).
  • Adjusts QwenImage Transformer/ControlNet RoPE to take a single text length and documents the fallback behavior.
  • Adds regression tests to ensure short txt_seq_lens values and encoder masks are handled safely.
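
For illustration, a minimal sketch of the masking idea described above, assuming a key-padding mask and a fixed number of image tokens (the helper name build_joint_attention_mask, the argument img_len, and the text-then-image ordering are illustrative, not the actual implementation):

```python
import torch


def build_joint_attention_mask(encoder_hidden_states_mask: torch.Tensor, img_len: int) -> torch.Tensor:
    """Build a broadcastable key-padding mask for joint (text + image) attention.

    encoder_hidden_states_mask: (batch, txt_len), 1/True for real tokens, 0/False for padding.
    Returns a boolean mask of shape (batch, 1, 1, txt_len + img_len) that broadcasts over
    heads and query positions instead of materializing a full (seq_len, seq_len) mask.
    """
    batch_size, _ = encoder_hidden_states_mask.shape
    txt_mask = encoder_hidden_states_mask.bool()
    # Image tokens are never padded, so they are always attendable.
    img_mask = torch.ones(batch_size, img_len, dtype=torch.bool, device=txt_mask.device)
    joint_mask = torch.cat([txt_mask, img_mask], dim=1)  # (batch, txt_len + img_len); order illustrative
    return joint_mask[:, None, None, :]  # broadcast over heads and query positions
```

A mask of this shape can be passed as the boolean attn_mask of torch.nn.functional.scaled_dot_product_attention, where True marks key positions that may be attended to.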

Fixes #12344

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@kashif kashif requested a review from sayakpaul November 23, 2025 18:03
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul sayakpaul requested a review from yiyixuxu November 24, 2025 01:57
@dxqb commented Nov 29, 2025

just a few comments, not a full review:

  • there is some overlap with Fix qwen encoder hidden states mask #12655
  • this code has the same issue mentioned in Fix qwen encoder hidden states mask #12655: it expects boolean semantics in a FloatTensor, but float attention masks are interpreted as additive biases rather than keep/drop flags (see the sketch after this list)
  • Could you clarify what the purpose of this PR is?
    If the purpose is to remove the txt_seq_lens parameters and infer the sequence lengths from the attention mask: why is txt_seq_lens still a parameter of the transformer model?
    If the purpose is to pass sequence lengths down to the attention dispatch (see Qwen Image: txt_seq_lens is redundant and not used #12344 (comment)), then the sequence length of each batch sample must be inferred from the mask and passed to the transformer blocks, not just the maximum sequence length across the batch that is used for RoPE.
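
On the boolean-vs-float point: torch.nn.functional.scaled_dot_product_attention treats a boolean attn_mask as keep/drop flags, but a floating-point attn_mask as an additive bias on the attention scores, so a 1.0/0.0 float mask masks nothing. A small self-contained illustration (shapes chosen arbitrarily):

```python
import torch
import torch.nn.functional as F

q = k = v = torch.randn(1, 1, 4, 8)  # (batch, heads, seq, head_dim)

keep = torch.tensor([True, True, False, False])  # last two tokens are padding
bool_mask = keep[None, None, None, :]            # broadcastable key-padding mask

# Boolean mask: True = attend, False = masked out.
out_bool = F.scaled_dot_product_attention(q, k, v, attn_mask=bool_mask)

# A 1.0/0.0 float mask is NOT equivalent: float masks are added to the attention
# scores, so 0.0 means "no change" and the padding tokens are still attended.
naive_float_mask = keep.float()[None, None, None, :]
out_naive = F.scaled_dot_product_attention(q, k, v, attn_mask=naive_float_mask)

# The correct float form is additive: 0.0 for keep, -inf for masked positions.
additive_mask = torch.zeros_like(naive_float_mask).masked_fill(~bool_mask, float("-inf"))
out_additive = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)

print(torch.allclose(out_bool, out_additive, atol=1e-5))  # True: matches the boolean mask
print(torch.allclose(out_bool, out_naive, atol=1e-5))     # False: padding tokens leaked in
```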

        raise ValueError(f"`txt_seq_lens` must have length {batch_size}, but got {len(txt_seq_lens)} instead.")
    text_seq_len = max(text_seq_len, max(txt_seq_lens))
elif encoder_hidden_states_mask is not None:
    text_seq_len = max(text_seq_len, int(encoder_hidden_states_mask.sum(dim=1).max().item()))
This only works if the attention mask has the form [True, True, True, ..., False, False, False]. While that is the most common pattern for text attention masks, it is not guaranteed.

If the mask is [True, False, True, False, True, False], self.pos_embed receives an incorrect sequence length.
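
The mismatch is easy to reproduce with the pattern above, assuming the length is derived from mask.sum(dim=1) as in the snippet being reviewed:

```python
import torch

mask = torch.tensor([[True, False, True, False, True, False]])

# Counting True entries gives 3 valid tokens ...
num_valid = int(mask.sum(dim=1).max().item())  # 3

# ... but the right-most attended token sits at index 4, so position embeddings
# must cover 5 positions for this (admittedly unusual) mask pattern.
positions = torch.arange(mask.shape[1])
required_len = int((positions * mask).max().item()) + 1  # 5

print(num_valid, required_len)  # 3 5
```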

@kashif (Contributor, Author) commented Nov 29, 2025

thanks @dxqb, the idea was to remove txt_seq_lens altogether and work with any mask pattern

@yiyixuxu (Collaborator) left a comment


thanks a ton for the PR! @kashif
I left one question, let me know!

- Remove seq_lens parameter from dispatch_attention_fn
- Update varlen backends to extract seqlens from masks
- Update QwenImage to pass 2D joint_attention_mask
- Fix native backend to handle 2D boolean masks
- Fix sage_varlen seqlens_q to match seqlens_k for self-attention

Note: sage_varlen still producing black images, needs further investigation
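
For readers unfamiliar with the varlen convention mentioned in the commit message above: variable-length attention kernels (e.g. the flash-attn varlen path) consume cumulative sequence lengths rather than a padded mask. A minimal sketch of deriving them from a 2D boolean mask, assuming left-aligned (contiguous) valid tokens per sample; this is not the PR's exact code:

```python
import torch


def cu_seqlens_from_mask(mask: torch.Tensor) -> tuple[torch.Tensor, int]:
    """Convert a (batch, seq_len) boolean mask into cumulative sequence lengths.

    Varlen kernels expect `cu_seqlens` of shape (batch + 1,), int32, where
    cu_seqlens[i + 1] - cu_seqlens[i] is the number of valid tokens in sample i,
    plus the maximum per-sample length. Assumes valid tokens are left-aligned.
    """
    seqlens = mask.sum(dim=1, dtype=torch.int32)  # (batch,)
    cu_seqlens = torch.zeros(mask.shape[0] + 1, dtype=torch.int32, device=mask.device)
    cu_seqlens[1:] = torch.cumsum(seqlens, dim=0)
    return cu_seqlens, int(seqlens.max().item())


mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]], dtype=torch.bool)
print(cu_seqlens_from_mask(mask))  # (tensor([0, 3, 5], dtype=torch.int32), 3)
```
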
@sayakpaul (Member) commented
#12655 provides some benchmarks on the speed, as well. Possible to provide them here too? @kashif

@kashif (Contributor, Author) commented Dec 8, 2025

some benchmarks with various backends (code and chart attached below):
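
As context, a rough sketch of the kind of per-backend timing loop such a benchmark could use; the set_attention_backend helper and the backend names are assumptions about recent diffusers, and the actual numbers come from the attached script and chart:

```python
import time

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16).to("cuda")
prompt = "a cat wearing a tiny wizard hat"

for backend in ["native", "flash", "flash_varlen", "sage", "sage_varlen"]:  # names assumed
    try:
        # Assumed helper on recent diffusers ModelMixin; skip backends that are unavailable.
        pipe.transformer.set_attention_backend(backend)
    except Exception as err:
        print(f"{backend}: unavailable ({err})")
        continue
    # Warm-up run, then a timed run.
    pipe(prompt, num_inference_steps=4)
    torch.cuda.synchronize()
    start = time.perf_counter()
    pipe(prompt, num_inference_steps=30)
    torch.cuda.synchronize()
    print(f"{backend}: {time.perf_counter() - start:.2f}s")
```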

code: benchmark_backends_qwen.py (attached script)

[image: backend_benchmark_complete]
