1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Hi Dream] follow-up (#11296)

* add
This commit is contained in:
YiYi Xu
2025-04-17 01:17:44 -10:00
committed by GitHub
parent 29d2afbfe2
commit 056793295c
3 changed files with 427 additions and 208 deletions

View File

@@ -8,7 +8,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention
from ..embeddings import TimestepEmbedding, Timesteps
@@ -686,46 +686,108 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
x = torch.cat(x_arr, dim=0)
return x
def patchify(self, x, max_seq, img_sizes=None):
pz2 = self.config.patch_size * self.config.patch_size
if isinstance(x, torch.Tensor):
B, C = x.shape[0], x.shape[1]
device = x.device
dtype = x.dtype
else:
B, C = len(x), x[0].shape[0]
device = x[0].device
dtype = x[0].dtype
x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
def patchify(self, hidden_states):
batch_size, channels, height, width = hidden_states.shape
patch_size = self.config.patch_size
patch_height, patch_width = height // patch_size, width // patch_size
device = hidden_states.device
dtype = hidden_states.dtype
if img_sizes is not None:
for i, img_size in enumerate(img_sizes):
x_masks[i, 0 : img_size[0] * img_size[1]] = 1
B, C, S, _ = x.shape
x = x.permute(0, 2, 3, 1).reshape(B, S, pz2 * C)
elif isinstance(x, torch.Tensor):
B, C, Hp1, Wp2 = x.shape
pH, pW = Hp1 // self.config.patch_size, Wp2 // self.config.patch_size
x = x.reshape(B, C, pH, self.config.patch_size, pW, self.config.patch_size)
x = x.permute(0, 2, 4, 3, 5, 1)
x = x.reshape(B, pH * pW, self.config.patch_size * self.config.patch_size * C)
img_sizes = [[pH, pW]] * B
x_masks = None
# create img_sizes
img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
# create hidden_states_masks
if hidden_states.shape[-2] != hidden_states.shape[-1]:
hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)
hidden_states_masks[:, : patch_height * patch_width] = 1.0
else:
raise NotImplementedError
return x, x_masks, img_sizes
hidden_states_masks = None
# create img_ids
img_ids = torch.zeros(patch_height, patch_width, 3, device=device)
row_indices = torch.arange(patch_height, device=device)[:, None]
col_indices = torch.arange(patch_width, device=device)[None, :]
img_ids[..., 1] = img_ids[..., 1] + row_indices
img_ids[..., 2] = img_ids[..., 2] + col_indices
img_ids = img_ids.reshape(patch_height * patch_width, -1)
if hidden_states.shape[-2] != hidden_states.shape[-1]:
# Handle non-square latents
img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
img_ids_pad[: patch_height * patch_width, :] = img_ids
img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1)
else:
img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1)
# patchify hidden_states
if hidden_states.shape[-2] != hidden_states.shape[-1]:
# Handle non-square latents
out = torch.zeros(
(batch_size, channels, self.max_seq, patch_size * patch_size),
dtype=dtype,
device=device,
)
hidden_states = hidden_states.reshape(
batch_size, channels, patch_height, patch_size, patch_width, patch_size
)
hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
hidden_states = hidden_states.reshape(
batch_size, channels, patch_height * patch_width, patch_size * patch_size
)
out[:, :, 0 : patch_height * patch_width] = hidden_states
hidden_states = out
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
batch_size, self.max_seq, patch_size * patch_size * channels
)
else:
# Handle square latents
hidden_states = hidden_states.reshape(
batch_size, channels, patch_height, patch_size, patch_width, patch_size
)
hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
hidden_states = hidden_states.reshape(
batch_size, patch_height * patch_width, patch_size * patch_size * channels
)
return hidden_states, hidden_states_masks, img_sizes, img_ids
def forward(
self,
hidden_states: torch.Tensor,
timesteps: torch.LongTensor = None,
encoder_hidden_states: torch.Tensor = None,
encoder_hidden_states_t5: torch.Tensor = None,
encoder_hidden_states_llama3: torch.Tensor = None,
pooled_embeds: torch.Tensor = None,
img_sizes: Optional[List[Tuple[int, int]]] = None,
img_ids: Optional[torch.Tensor] = None,
img_sizes: Optional[List[Tuple[int, int]]] = None,
hidden_states_masks: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
**kwargs,
):
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
if encoder_hidden_states is not None:
deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
encoder_hidden_states_t5 = encoder_hidden_states[0]
encoder_hidden_states_llama3 = encoder_hidden_states[1]
if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
deprecation_message = (
"Passing `img_ids` and `img_sizes` with unpachified `hidden_states` is deprecated and will be ignored."
)
deprecate("img_ids", "0.34.0", deprecation_message)
if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
elif hidden_states_masks is not None and hidden_states.ndim != 3:
raise ValueError(
"if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensors with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
)
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -745,42 +807,19 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
batch_size = hidden_states.shape[0]
hidden_states_type = hidden_states.dtype
if hidden_states.shape[-2] != hidden_states.shape[-1]:
B, C, H, W = hidden_states.shape
patch_size = self.config.patch_size
pH, pW = H // patch_size, W // patch_size
out = torch.zeros(
(B, C, self.max_seq, patch_size * patch_size),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
hidden_states = hidden_states.reshape(B, C, pH, patch_size, pW, patch_size)
hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
hidden_states = hidden_states.reshape(B, C, pH * pW, patch_size * patch_size)
out[:, :, 0 : pH * pW] = hidden_states
hidden_states = out
# Patchify the input
if hidden_states_masks is None:
hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
# Embed the hidden states
hidden_states = self.x_embedder(hidden_states)
# 0. time
timesteps = self.t_embedder(timesteps, hidden_states_type)
p_embedder = self.p_embedder(pooled_embeds)
temb = timesteps + p_embedder
hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
if hidden_states_masks is None:
pH, pW = img_sizes[0]
img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
img_ids = (
img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2])
.unsqueeze(0)
.repeat(batch_size, 1, 1)
)
hidden_states = self.x_embedder(hidden_states)
T5_encoder_hidden_states = encoder_hidden_states[0]
encoder_hidden_states = encoder_hidden_states[-1]
encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]
if self.caption_projection is not None:
new_encoder_hidden_states = []
@@ -789,9 +828,9 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
new_encoder_hidden_states.append(enc_hidden_state)
encoder_hidden_states = new_encoder_hidden_states
T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states.append(T5_encoder_hidden_states)
encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states.append(encoder_hidden_states_t5)
txt_ids = torch.zeros(
batch_size,

View File

@@ -15,7 +15,7 @@ from transformers import (
from ...image_processor import VaeImageProcessor
from ...models import AutoencoderKL, HiDreamImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import HiDreamImagePipelineOutput
@@ -38,9 +38,6 @@ EXAMPLE_DOC_STRING = """
>>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
>>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
>>> scheduler = UniPCMultistepScheduler(
... flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True
... )
>>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
>>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
@@ -52,7 +49,6 @@ EXAMPLE_DOC_STRING = """
>>> pipe = HiDreamImagePipeline.from_pretrained(
... "HiDream-ai/HiDream-I1-Full",
... scheduler=scheduler,
... tokenizer_4=tokenizer_4,
... text_encoder_4=text_encoder_4,
... torch_dtype=torch.bfloat16,
@@ -148,7 +144,7 @@ def retrieve_timesteps(
class HiDreamImagePipeline(DiffusionPipeline):
model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
_callback_tensor_inputs = ["latents", "prompt_embeds_t5", "prompt_embeds_llama3", "pooled_prompt_embeds"]
def __init__(
self,
@@ -309,10 +305,10 @@ class HiDreamImagePipeline(DiffusionPipeline):
def encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_3: Union[str, List[str]],
prompt_4: Union[str, List[str]],
prompt: Optional[Union[str, List[str]]] = None,
prompt_2: Optional[Union[str, List[str]]] = None,
prompt_3: Optional[Union[str, List[str]]] = None,
prompt_4: Optional[Union[str, List[str]]] = None,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
num_images_per_prompt: int = 1,
@@ -321,8 +317,10 @@ class HiDreamImagePipeline(DiffusionPipeline):
negative_prompt_2: Optional[Union[str, List[str]]] = None,
negative_prompt_3: Optional[Union[str, List[str]]] = None,
negative_prompt_4: Optional[Union[str, List[str]]] = None,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds_t5: Optional[List[torch.FloatTensor]] = None,
prompt_embeds_llama3: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds_t5: Optional[List[torch.FloatTensor]] = None,
negative_prompt_embeds_llama3: Optional[List[torch.FloatTensor]] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 128,
@@ -332,120 +330,177 @@ class HiDreamImagePipeline(DiffusionPipeline):
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
batch_size = pooled_prompt_embeds.shape[0]
prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
prompt=prompt,
prompt_2=prompt_2,
prompt_3=prompt_3,
prompt_4=prompt_4,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
max_sequence_length=max_sequence_length,
)
if do_classifier_free_guidance and negative_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt_2 = negative_prompt_2 or negative_prompt
negative_prompt_3 = negative_prompt_3 or negative_prompt
negative_prompt_4 = negative_prompt_4 or negative_prompt
# normalize str to list
negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_2 = (
batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
)
negative_prompt_3 = (
batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
)
negative_prompt_4 = (
batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
)
if prompt is not None and type(prompt) is not type(negative_prompt):
raise TypeError(
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
f" {type(prompt)}."
)
elif batch_size != len(negative_prompt):
raise ValueError(
f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
" the batch size of `prompt`."
)
negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
prompt=negative_prompt,
prompt_2=negative_prompt_2,
prompt_3=negative_prompt_3,
prompt_4=negative_prompt_4,
device=device,
dtype=dtype,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=negative_pooled_prompt_embeds,
max_sequence_length=max_sequence_length,
)
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
def _encode_prompt(
self,
prompt: Union[str, List[str]],
prompt_2: Union[str, List[str]],
prompt_3: Union[str, List[str]],
prompt_4: Union[str, List[str]],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
num_images_per_prompt: int = 1,
prompt_embeds: Optional[List[torch.FloatTensor]] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
max_sequence_length: int = 128,
):
device = device or self._execution_device
if prompt is not None:
batch_size = len(prompt)
else:
batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
if pooled_prompt_embeds is None:
pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
)
if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
negative_prompt = negative_prompt or ""
negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
if len(negative_prompt) > 1 and len(negative_prompt) != batch_size:
raise ValueError(f"negative_prompt must be of length 1 or {batch_size}")
negative_pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
self.tokenizer, self.text_encoder, negative_prompt, max_sequence_length, device, dtype
)
if negative_pooled_prompt_embeds_1.shape[0] == 1 and batch_size > 1:
negative_pooled_prompt_embeds_1 = negative_pooled_prompt_embeds_1.repeat(batch_size, 1)
if pooled_prompt_embeds is None:
prompt_2 = prompt_2 or prompt
prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
)
if len(prompt_2) > 1 and len(prompt_2) != batch_size:
raise ValueError(f"prompt_2 must be of length 1 or {batch_size}")
pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype
)
if pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1:
pooled_prompt_embeds_2 = pooled_prompt_embeds_2.repeat(batch_size, 1)
if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
negative_prompt_2 = negative_prompt_2 or negative_prompt
negative_prompt_2 = [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
if len(negative_prompt_2) > 1 and len(negative_prompt_2) != batch_size:
raise ValueError(f"negative_prompt_2 must be of length 1 or {batch_size}")
negative_pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
self.tokenizer_2, self.text_encoder_2, negative_prompt_2, max_sequence_length, device, dtype
)
if negative_pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1:
negative_pooled_prompt_embeds_2 = negative_pooled_prompt_embeds_2.repeat(batch_size, 1)
if pooled_prompt_embeds is None:
pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
negative_pooled_prompt_embeds = torch.cat(
[negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2], dim=-1
)
if prompt_embeds is None:
if prompt_embeds_t5 is None:
prompt_3 = prompt_3 or prompt
prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
if len(prompt_3) > 1 and len(prompt_3) != batch_size:
raise ValueError(f"prompt_3 must be of length 1 or {batch_size}")
prompt_embeds_t5 = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
if prompt_embeds_t5.shape[0] == 1 and batch_size > 1:
prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1)
if do_classifier_free_guidance and negative_prompt_embeds_t5 is None:
negative_prompt_3 = negative_prompt_3 or negative_prompt
negative_prompt_3 = [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
if len(negative_prompt_3) > 1 and len(negative_prompt_3) != batch_size:
raise ValueError(f"negative_prompt_3 must be of length 1 or {batch_size}")
negative_prompt_embeds_t5 = self._get_t5_prompt_embeds(
negative_prompt_3, max_sequence_length, device, dtype
)
if negative_prompt_embeds_t5.shape[0] == 1 and batch_size > 1:
negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1)
if prompt_embeds_llama3 is None:
prompt_4 = prompt_4 or prompt
prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4
t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
llama3_prompt_embeds = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype)
if len(prompt_4) > 1 and len(prompt_4) != batch_size:
raise ValueError(f"prompt_4 must be of length 1 or {batch_size}")
_, seq_len, _ = t5_prompt_embeds.shape
t5_prompt_embeds = t5_prompt_embeds.repeat(1, num_images_per_prompt, 1)
t5_prompt_embeds = t5_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
prompt_embeds_llama3 = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype)
_, _, seq_len, dim = llama3_prompt_embeds.shape
llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
llama3_prompt_embeds = llama3_prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
if prompt_embeds_llama3.shape[0] == 1 and batch_size > 1:
prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
if do_classifier_free_guidance and negative_prompt_embeds_llama3 is None:
negative_prompt_4 = negative_prompt_4 or negative_prompt
negative_prompt_4 = [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
return prompt_embeds, pooled_prompt_embeds
if len(negative_prompt_4) > 1 and len(negative_prompt_4) != batch_size:
raise ValueError(f"negative_prompt_4 must be of length 1 or {batch_size}")
negative_prompt_embeds_llama3 = self._get_llama3_prompt_embeds(
negative_prompt_4, max_sequence_length, device, dtype
)
if negative_prompt_embeds_llama3.shape[0] == 1 and batch_size > 1:
negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
# duplicate pooled_prompt_embeds for each generation per prompt
pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
# duplicate t5_prompt_embeds for batch_size and num_images_per_prompt
bs_embed, seq_len, _ = prompt_embeds_t5.shape
if bs_embed == 1 and batch_size > 1:
prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1)
elif bs_embed > 1 and bs_embed != batch_size:
raise ValueError(f"cannot duplicate prompt_embeds_t5 of batch size {bs_embed}")
prompt_embeds_t5 = prompt_embeds_t5.repeat(1, num_images_per_prompt, 1)
prompt_embeds_t5 = prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1)
# duplicate llama3_prompt_embeds for batch_size and num_images_per_prompt
_, bs_embed, seq_len, dim = prompt_embeds_llama3.shape
if bs_embed == 1 and batch_size > 1:
prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
elif bs_embed > 1 and bs_embed != batch_size:
raise ValueError(f"cannot duplicate prompt_embeds_llama3 of batch size {bs_embed}")
prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1)
prompt_embeds_llama3 = prompt_embeds_llama3.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
if do_classifier_free_guidance:
# duplicate negative_pooled_prompt_embeds for batch_size and num_images_per_prompt
bs_embed, seq_len = negative_pooled_prompt_embeds.shape
if bs_embed == 1 and batch_size > 1:
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1)
elif bs_embed > 1 and bs_embed != batch_size:
raise ValueError(f"cannot duplicate negative_pooled_prompt_embeds of batch size {bs_embed}")
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
# duplicate negative_t5_prompt_embeds for batch_size and num_images_per_prompt
bs_embed, seq_len, _ = negative_prompt_embeds_t5.shape
if bs_embed == 1 and batch_size > 1:
negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1)
elif bs_embed > 1 and bs_embed != batch_size:
raise ValueError(f"cannot duplicate negative_prompt_embeds_t5 of batch size {bs_embed}")
negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds_t5 = negative_prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1)
# duplicate negative_prompt_embeds_llama3 for batch_size and num_images_per_prompt
_, bs_embed, seq_len, dim = negative_prompt_embeds_llama3.shape
if bs_embed == 1 and batch_size > 1:
negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
elif bs_embed > 1 and bs_embed != batch_size:
raise ValueError(f"cannot duplicate negative_prompt_embeds_llama3 of batch size {bs_embed}")
negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1)
negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.view(
-1, batch_size * num_images_per_prompt, seq_len, dim
)
return (
prompt_embeds_t5,
negative_prompt_embeds_t5,
prompt_embeds_llama3,
negative_prompt_embeds_llama3,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
)
def enable_vae_slicing(self):
r"""
@@ -476,6 +531,115 @@ class HiDreamImagePipeline(DiffusionPipeline):
"""
self.vae.disable_tiling()
def check_inputs(
self,
prompt,
prompt_2,
prompt_3,
prompt_4,
negative_prompt=None,
negative_prompt_2=None,
negative_prompt_3=None,
negative_prompt_4=None,
prompt_embeds_t5=None,
prompt_embeds_llama3=None,
negative_prompt_embeds_t5=None,
negative_prompt_embeds_llama3=None,
pooled_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
callback_on_step_end_tensor_inputs=None,
):
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
if prompt is not None and pooled_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_2 is not None and pooled_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt_2`: {prompt_2} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to"
" only forward one of the two."
)
elif prompt_3 is not None and prompt_embeds_t5 is not None:
raise ValueError(
f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds_t5`: {prompt_embeds_t5}. Please make sure to"
" only forward one of the two."
)
elif prompt_4 is not None and prompt_embeds_llama3 is not None:
raise ValueError(
f"Cannot forward both `prompt_4`: {prompt_4} and `prompt_embeds_llama3`: {prompt_embeds_llama3}. Please make sure to"
" only forward one of the two."
)
elif prompt is None and pooled_prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `pooled_prompt_embeds`. Cannot leave both `prompt` and `pooled_prompt_embeds` undefined."
)
elif prompt is None and prompt_embeds_t5 is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds_t5`. Cannot leave both `prompt` and `prompt_embeds_t5` undefined."
)
elif prompt is None and prompt_embeds_llama3 is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds_llama3`. Cannot leave both `prompt` and `prompt_embeds_llama3` undefined."
)
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
elif prompt_4 is not None and (not isinstance(prompt_4, str) and not isinstance(prompt_4, list)):
raise ValueError(f"`prompt_4` has to be of type `str` or `list` but is {type(prompt_4)}")
if negative_prompt is not None and negative_pooled_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_pooled_prompt_embeds`:"
f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two."
)
elif negative_prompt_2 is not None and negative_pooled_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_pooled_prompt_embeds`:"
f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two."
)
elif negative_prompt_3 is not None and negative_prompt_embeds_t5 is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds_t5`:"
f" {negative_prompt_embeds_t5}. Please make sure to only forward one of the two."
)
elif negative_prompt_4 is not None and negative_prompt_embeds_llama3 is not None:
raise ValueError(
f"Cannot forward both `negative_prompt_4`: {negative_prompt_4} and `negative_prompt_embeds_llama3`:"
f" {negative_prompt_embeds_llama3}. Please make sure to only forward one of the two."
)
if pooled_prompt_embeds is not None and negative_pooled_prompt_embeds is not None:
if pooled_prompt_embeds.shape != negative_pooled_prompt_embeds.shape:
raise ValueError(
"`pooled_prompt_embeds` and `negative_pooled_prompt_embeds` must have the same shape when passed directly, but"
f" got: `pooled_prompt_embeds` {pooled_prompt_embeds.shape} != `negative_pooled_prompt_embeds`"
f" {negative_pooled_prompt_embeds.shape}."
)
if prompt_embeds_t5 is not None and negative_prompt_embeds_t5 is not None:
if prompt_embeds_t5.shape != negative_prompt_embeds_t5.shape:
raise ValueError(
"`prompt_embeds_t5` and `negative_prompt_embeds_t5` must have the same shape when passed directly, but"
f" got: `prompt_embeds_t5` {prompt_embeds_t5.shape} != `negative_prompt_embeds_t5`"
f" {negative_prompt_embeds_t5.shape}."
)
if prompt_embeds_llama3 is not None and negative_prompt_embeds_llama3 is not None:
if prompt_embeds_llama3.shape != negative_prompt_embeds_llama3.shape:
raise ValueError(
"`prompt_embeds_llama3` and `negative_prompt_embeds_llama3` must have the same shape when passed directly, but"
f" got: `prompt_embeds_llama3` {prompt_embeds_llama3.shape} != `negative_prompt_embeds_llama3`"
f" {negative_prompt_embeds_llama3.shape}."
)
def prepare_latents(
self,
batch_size,
@@ -542,8 +706,10 @@ class HiDreamImagePipeline(DiffusionPipeline):
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_embeds_t5: Optional[torch.FloatTensor] = None,
prompt_embeds_llama3: Optional[torch.FloatTensor] = None,
negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None,
negative_prompt_embeds_llama3: Optional[torch.FloatTensor] = None,
pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
@@ -552,6 +718,7 @@ class HiDreamImagePipeline(DiffusionPipeline):
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 128,
**kwargs,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -649,6 +816,22 @@ class HiDreamImagePipeline(DiffusionPipeline):
[`~pipelines.hidream_image.HiDreamImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is a list with the generated. images.
"""
prompt_embeds = kwargs.get("prompt_embeds", None)
negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None)
if prompt_embeds is not None:
deprecation_message = "The `prompt_embeds` argument is deprecated. Please use `prompt_embeds_t5` and `prompt_embeds_llama3` instead."
deprecate("prompt_embeds", "0.34.0", deprecation_message)
prompt_embeds_t5 = prompt_embeds[0]
prompt_embeds_llama3 = prompt_embeds[1]
if negative_prompt_embeds is not None:
deprecation_message = "The `negative_prompt_embeds` argument is deprecated. Please use `negative_prompt_embeds_t5` and `negative_prompt_embeds_llama3` instead."
deprecate("negative_prompt_embeds", "0.34.0", deprecation_message)
negative_prompt_embeds_t5 = negative_prompt_embeds[0]
negative_prompt_embeds_llama3 = negative_prompt_embeds[1]
height = height or self.default_sample_size * self.vae_scale_factor
width = width or self.default_sample_size * self.vae_scale_factor
@@ -658,6 +841,25 @@ class HiDreamImagePipeline(DiffusionPipeline):
scale = math.sqrt(scale)
width, height = int(width * scale // division * division), int(height * scale // division * division)
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
prompt_2,
prompt_3,
prompt_4,
negative_prompt=negative_prompt,
negative_prompt_2=negative_prompt_2,
negative_prompt_3=negative_prompt_3,
negative_prompt_4=negative_prompt_4,
prompt_embeds_t5=prompt_embeds_t5,
prompt_embeds_llama3=prompt_embeds_llama3,
negative_prompt_embeds_t5=negative_prompt_embeds_t5,
negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
)
self._guidance_scale = guidance_scale
self._attention_kwargs = attention_kwargs
self._interrupt = False
@@ -667,17 +869,18 @@ class HiDreamImagePipeline(DiffusionPipeline):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
elif prompt_embeds is not None:
batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
else:
batch_size = 1
elif pooled_prompt_embeds is not None:
batch_size = pooled_prompt_embeds.shape[0]
device = self._execution_device
# 3. Encode prompt
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
(
prompt_embeds,
negative_prompt_embeds,
prompt_embeds_t5,
negative_prompt_embeds_t5,
prompt_embeds_llama3,
negative_prompt_embeds_llama3,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.encode_prompt(
@@ -690,8 +893,10 @@ class HiDreamImagePipeline(DiffusionPipeline):
negative_prompt_3=negative_prompt_3,
negative_prompt_4=negative_prompt_4,
do_classifier_free_guidance=self.do_classifier_free_guidance,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_embeds_t5=prompt_embeds_t5,
prompt_embeds_llama3=prompt_embeds_llama3,
negative_prompt_embeds_t5=negative_prompt_embeds_t5,
negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
device=device,
@@ -701,13 +906,8 @@ class HiDreamImagePipeline(DiffusionPipeline):
)
if self.do_classifier_free_guidance:
prompt_embeds_arr = []
for n, p in zip(negative_prompt_embeds, prompt_embeds):
if len(n.shape) == 3:
prompt_embeds_arr.append(torch.cat([n, p], dim=0))
else:
prompt_embeds_arr.append(torch.cat([n, p], dim=1))
prompt_embeds = prompt_embeds_arr
prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5], dim=0)
prompt_embeds_llama3 = torch.cat([negative_prompt_embeds_llama3, prompt_embeds_llama3], dim=1)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
# 4. Prepare latent variables
@@ -723,26 +923,6 @@ class HiDreamImagePipeline(DiffusionPipeline):
latents,
)
if latents.shape[-2] != latents.shape[-1]:
B, C, H, W = latents.shape
pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size
img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
img_ids = torch.zeros(pH, pW, 3)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
img_ids = img_ids.reshape(pH * pW, -1)
img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
img_ids_pad[: pH * pW, :] = img_ids
img_sizes = img_sizes.unsqueeze(0).to(latents.device)
img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
if self.do_classifier_free_guidance:
img_sizes = img_sizes.repeat(2 * B, 1)
img_ids = img_ids.repeat(2 * B, 1, 1)
else:
img_sizes = img_ids = None
# 5. Prepare timesteps
mu = calculate_shift(self.transformer.max_seq)
scheduler_kwargs = {"mu": mu}
@@ -774,10 +954,9 @@ class HiDreamImagePipeline(DiffusionPipeline):
noise_pred = self.transformer(
hidden_states=latent_model_input,
timesteps=timestep,
encoder_hidden_states=prompt_embeds,
encoder_hidden_states_t5=prompt_embeds_t5,
encoder_hidden_states_llama3=prompt_embeds_llama3,
pooled_embeds=pooled_prompt_embeds,
img_sizes=img_sizes,
img_ids=img_ids,
return_dict=False,
)[0]
noise_pred = -noise_pred
@@ -803,8 +982,9 @@ class HiDreamImagePipeline(DiffusionPipeline):
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
prompt_embeds_t5 = callback_outputs.pop("prompt_embeds_t5", prompt_embeds_t5)
prompt_embeds_llama3 = callback_outputs.pop("prompt_embeds_llama3", prompt_embeds_llama3)
pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):

View File

@@ -43,7 +43,7 @@ enable_full_determinism()
class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = HiDreamImagePipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "prompt_embeds", "negative_prompt_embeds"}
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS