diff --git a/src/diffusers/models/transformers/transformer_z_image.py b/src/diffusers/models/transformers/transformer_z_image.py
index 4f2d56ea8f..097672e0f7 100644
--- a/src/diffusers/models/transformers/transformer_z_image.py
+++ b/src/diffusers/models/transformers/transformer_z_image.py
@@ -482,21 +482,23 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
             ).flatten(0, 2)
             all_cap_pos_ids.append(cap_padded_pos_ids)
             # pad mask
-            all_cap_pad_mask.append(
-                torch.cat(
-                    [
-                        torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
-                        torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
-                    ],
-                    dim=0,
-                )
+            cap_pad_mask = torch.cat(
+                [
+                    torch.zeros((cap_ori_len,), dtype=torch.bool, device=device),
+                    torch.ones((cap_padding_len,), dtype=torch.bool, device=device),
+                ],
+                dim=0,
             )
+            all_cap_pad_mask.append(
+                cap_pad_mask if cap_padding_len > 0 else torch.zeros((cap_ori_len,), dtype=torch.bool, device=device)
+            )
+
             # padded feature
             cap_padded_feat = torch.cat(
                 [cap_feat, cap_feat[-1:].repeat(cap_padding_len, 1)],
                 dim=0,
             )
-            all_cap_feats_out.append(cap_padded_feat)
+            all_cap_feats_out.append(cap_padded_feat if cap_padding_len > 0 else cap_feat)

         ### Process Image
         C, F, H, W = image.size()
@@ -515,30 +517,35 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
                 start=(cap_ori_len + cap_padding_len + 1, 0, 0),
                 device=device,
             ).flatten(0, 2)
-            image_padding_pos_ids = (
-                self.create_coordinate_grid(
-                    size=(1, 1, 1),
-                    start=(0, 0, 0),
-                    device=device,
-                )
-                .flatten(0, 2)
-                .repeat(image_padding_len, 1)
+            image_padded_pos_ids = torch.cat(
+                [
+                    image_ori_pos_ids,
+                    self.create_coordinate_grid(size=(1, 1, 1), start=(0, 0, 0), device=device)
+                    .flatten(0, 2)
+                    .repeat(image_padding_len, 1),
+                ],
+                dim=0,
             )
-            image_padded_pos_ids = torch.cat([image_ori_pos_ids, image_padding_pos_ids], dim=0)
-            all_image_pos_ids.append(image_padded_pos_ids)
+            all_image_pos_ids.append(image_padded_pos_ids if image_padding_len > 0 else image_ori_pos_ids)
             # pad mask
+            image_pad_mask = torch.cat(
+                [
+                    torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
+                    torch.ones((image_padding_len,), dtype=torch.bool, device=device),
+                ],
+                dim=0,
+            )
             all_image_pad_mask.append(
-                torch.cat(
-                    [
-                        torch.zeros((image_ori_len,), dtype=torch.bool, device=device),
-                        torch.ones((image_padding_len,), dtype=torch.bool, device=device),
-                    ],
-                    dim=0,
-                )
+                image_pad_mask
+                if image_padding_len > 0
+                else torch.zeros((image_ori_len,), dtype=torch.bool, device=device)
             )
             # padded feature
-            image_padded_feat = torch.cat([image, image[-1:].repeat(image_padding_len, 1)], dim=0)
-            all_image_out.append(image_padded_feat)
+            image_padded_feat = torch.cat(
+                [image, image[-1:].repeat(image_padding_len, 1)],
+                dim=0,
+            )
+            all_image_out.append(image_padded_feat if image_padding_len > 0 else image)

         return (
             all_image_out,
@@ -588,10 +595,13 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr
         adaln_input = t.type_as(x)
         x[torch.cat(x_inner_pad_mask)] = self.x_pad_token
         x = list(x.split(x_item_seqlens, dim=0))
-        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
+        x_freqs_cis = list(self.rope_embedder(torch.cat(x_pos_ids, dim=0)).split([len(_) for _ in x_pos_ids], dim=0))
         x = pad_sequence(x, batch_first=True, padding_value=0.0)
         x_freqs_cis = pad_sequence(x_freqs_cis, batch_first=True, padding_value=0.0)

+        # Make the sequence lengths match explicitly to satisfy Dynamo's symbolic shape inference and avoid compilation errors
+        x_freqs_cis = x_freqs_cis[:, : x.shape[1]]
+
         x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(x_item_seqlens):
             x_attn_mask[i, :seq_len] = 1
@@ -605,17 +615,21 @@ class ZImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOr

         # cap embed & refine
         cap_item_seqlens = [len(_) for _ in cap_feats]
-        assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
         cap_max_item_seqlen = max(cap_item_seqlens)

         cap_feats = torch.cat(cap_feats, dim=0)
         cap_feats = self.cap_embedder(cap_feats)
         cap_feats[torch.cat(cap_inner_pad_mask)] = self.cap_pad_token
         cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
-        cap_freqs_cis = list(self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
+        cap_freqs_cis = list(
+            self.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split([len(_) for _ in cap_pos_ids], dim=0)
+        )
         cap_feats = pad_sequence(cap_feats, batch_first=True, padding_value=0.0)
         cap_freqs_cis = pad_sequence(cap_freqs_cis, batch_first=True, padding_value=0.0)

+        # Make the sequence lengths match explicitly to satisfy Dynamo's symbolic shape inference and avoid compilation errors
+        cap_freqs_cis = cap_freqs_cis[:, : cap_feats.shape[1]]
+
         cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
         for i, seq_len in enumerate(cap_item_seqlens):
             cap_attn_mask[i, :seq_len] = 1
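
A note on the `if cap_padding_len > 0` / `if image_padding_len > 0` guards introduced above: concatenating a zero-length pad is a no-op at runtime, but it still materializes size-0 tensors, which torch.compile tends to specialize on (0/1 sizes are treated as special cases by Dynamo's dynamic shapes). The guarded version hands back the original tensor untouched whenever no padding is needed. A minimal sketch of the same pattern, using a hypothetical `pad_to_multiple` helper that is not part of this diff:

```python
import torch

def pad_to_multiple(feat: torch.Tensor, multiple: int) -> torch.Tensor:
    # Pad the sequence dim up to the next multiple by repeating the last row,
    # mirroring the cap/image padding in the diff above.
    pad_len = (-feat.shape[0]) % multiple
    padded = torch.cat([feat, feat[-1:].repeat(pad_len, 1)], dim=0)
    # Guard: when pad_len == 0, return the input unchanged instead of a
    # cat with an empty tensor, so compiled graphs never see a size-0 pad.
    return padded if pad_len > 0 else feat

print(pad_to_multiple(torch.randn(30, 8), 32).shape)  # torch.Size([32, 8])
print(pad_to_multiple(torch.randn(32, 8), 32).shape)  # torch.Size([32, 8])
```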
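As for the two `freqs_cis` slices: `pad_sequence` pads each list to the max length of that particular list, so `x` and `x_freqs_cis` (likewise `cap_feats` and `cap_freqs_cis`) end up with sequence dimensions that are equal at runtime but look like independent symbolic sizes to Dynamo. Slicing the RoPE table to `x.shape[1]` ties the two sizes together explicitly. A hedged, self-contained sketch of the idea, with hypothetical names (`pad_pair`) and an assumed elementwise combination in place of the model's attention path:

```python
import torch
from torch.nn.utils.rnn import pad_sequence

def pad_pair(x_list, freqs_list):
    # Each pad_sequence call pads to the max length of its own list.
    x = pad_sequence(x_list, batch_first=True, padding_value=0.0)
    freqs = pad_sequence(freqs_list, batch_first=True, padding_value=0.0)
    # The two sequence dims are equal at runtime, but Dynamo cannot prove it;
    # the slice makes freqs' length definitionally equal to x's.
    freqs = freqs[:, : x.shape[1]]
    return x + freqs  # stand-in for applying RoPE to x

seqlens = [3, 5]
x_list = [torch.randn(n, 8) for n in seqlens]
freqs_list = [torch.randn(n, 8) for n in seqlens]
out = torch.compile(pad_pair)(x_list, freqs_list)
print(out.shape)  # torch.Size([2, 5, 8])
```

The same reasoning presumably motivates splitting by `[len(_) for _ in x_pos_ids]` rather than by the precomputed `x_item_seqlens`: the split sizes are then derived from the tensors actually being split, which keeps the symbolic shapes consistent under compilation.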