
Z-Image text sequence length issue #12893

@dxqb

Description


Describe the bug

I think there might be an issue with how the sequence length of cap_feats (the text encoder output) is calculated and how it is then masked accordingly.

I'm going to use code links from before the omni commit because that version is easier to read, but the issue seems to exist both before and after the omni commit.

cap_item_seqlens = [len(_) for _ in cap_feats]
cap_max_item_seqlen = max(cap_item_seqlens)
cap_feats = torch.cat(cap_feats, dim=0)
cap_feats = self.cap_embedder(cap_feats)
cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
cap_freqs_cis = list(
    self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
)
cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)
# Clarify the length matches to satisfy Dynamo due to "Symbolic Shape Inference" to avoid compilation errors
cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
for i, seq_len in enumerate(cap_item_seqlens):
    cap_attn_mask[i, :seq_len] = 1

In these lines, the sequence length of each cap_feats sample is stored in cap_item_seqlens and then used to build the attention mask (line 611 in the linked source).
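To illustrate the intended behaviour, here is a minimal standalone sketch (toy tensors and shapes, not the real Z-Image inputs) of what this masking produces when cap_item_seqlens still reflects the true, unpadded per-sample lengths:

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Two toy "caption" feature tensors with different true lengths (3 and 5 tokens, dim 8).
    cap_feats = [torch.randn(3, 8), torch.randn(5, 8)]

    cap_item_seqlens = [len(f) for f in cap_feats]   # [3, 5]
    cap_max_item_seqlen = max(cap_item_seqlens)      # 5

    cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)  # shape (2, 5, 8)

    cap_attn_mask = torch.zeros((len(cap_item_seqlens), cap_max_item_seqlen), dtype=torch.bool)
    for i, seq_len in enumerate(cap_item_seqlens):
        cap_attn_mask[i, :seq_len] = 1

    print(cap_attn_mask)
    # tensor([[ True,  True,  True, False, False],
    #         [ True,  True,  True,  True,  True]])

Here the padded positions of the shorter caption are correctly masked out.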

The problem is that cap_feats has already been overwritten earlier, in these lines:

(
    x,
    cap_feats,
    x_size,
    x_pos_ids,
    cap_pos_ids,
    x_inner_pad_mask,
    cap_inner_pad_mask,
) = self.patchify_and_embed(x, cap_feats, patch_size, f_patch_size)

cap_feats is therefore no longer what the caller passed in; it has already been padded by patchify_and_embed. As a result, every entry in cap_item_seqlens is identical to cap_max_item_seqlen.

This leaves the padding text tokens unmasked during attention, which probably wasn't the intention here.
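Here is the same toy sketch, but mimicking what effectively happens in the transformer: because patchify_and_embed has already padded cap_feats to a common length, the measured lengths are all equal and the mask degenerates to all True (again toy tensors, just to illustrate the mechanism):

    import torch
    from torch.nn.utils.rnn import pad_sequence

    # Same two toy captions with true lengths 3 and 5.
    cap_feats = [torch.randn(3, 8), torch.randn(5, 8)]

    # Stand-in for what patchify_and_embed has already done: pad to a common length.
    cap_feats = list(pad_sequence(cap_feats, batch_first=True, padding_value=0.0))

    cap_item_seqlens = [len(f) for f in cap_feats]   # [5, 5] -- the true length 3 is lost
    cap_max_item_seqlen = max(cap_item_seqlens)      # 5

    cap_attn_mask = torch.zeros((len(cap_item_seqlens), cap_max_item_seqlen), dtype=torch.bool)
    for i, seq_len in enumerate(cap_item_seqlens):
        cap_attn_mask[i, :seq_len] = 1

    print(cap_item_seqlens)  # [5, 5]
    print(cap_attn_mask)     # every entry is True, so the padding tokens are attended to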

Reproduction

Put a breakpoint in this line:

All cap_item_seqlens entries are identical and cap_attn_mask is all True, even when texts of different lengths were passed to the transformer.
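If it helps, a quick sanity check that could be pasted at that breakpoint (using the variable names from the snippet quoted above) is the following; given the behaviour described here, both assertions currently fail:

    # With prompts of different lengths one would expect differing per-sample lengths
    # and at least some False entries in the attention mask.
    assert len(set(cap_item_seqlens)) > 1, "all per-sample caption lengths are identical"
    assert not cap_attn_mask.all(), "cap_attn_mask is all True, so padding is never masked"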

Logs

System Info

Who can help?

@JerryWu-code @RuoyiDu @yiyixuxu
