Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00.
@@ -8,7 +8,7 @@ from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention
 from ..embeddings import TimestepEmbedding, Timesteps
@@ -686,46 +686,108 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         x = torch.cat(x_arr, dim=0)
         return x
 
-    def patchify(self, x, max_seq, img_sizes=None):
-        pz2 = self.config.patch_size * self.config.patch_size
-        if isinstance(x, torch.Tensor):
-            B, C = x.shape[0], x.shape[1]
-            device = x.device
-            dtype = x.dtype
-        else:
-            B, C = len(x), x[0].shape[0]
-            device = x[0].device
-            dtype = x[0].dtype
-        x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
-
-        if img_sizes is not None:
-            for i, img_size in enumerate(img_sizes):
-                x_masks[i, 0 : img_size[0] * img_size[1]] = 1
-            B, C, S, _ = x.shape
-            x = x.permute(0, 2, 3, 1).reshape(B, S, pz2 * C)
-        elif isinstance(x, torch.Tensor):
-            B, C, Hp1, Wp2 = x.shape
-            pH, pW = Hp1 // self.config.patch_size, Wp2 // self.config.patch_size
-            x = x.reshape(B, C, pH, self.config.patch_size, pW, self.config.patch_size)
-            x = x.permute(0, 2, 4, 3, 5, 1)
-            x = x.reshape(B, pH * pW, self.config.patch_size * self.config.patch_size * C)
-            img_sizes = [[pH, pW]] * B
-            x_masks = None
-        else:
-            raise NotImplementedError
-        return x, x_masks, img_sizes
+    def patchify(self, hidden_states):
+        batch_size, channels, height, width = hidden_states.shape
+        patch_size = self.config.patch_size
+        patch_height, patch_width = height // patch_size, width // patch_size
+        device = hidden_states.device
+        dtype = hidden_states.dtype
+
+        # create img_sizes
+        img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
+        img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
+
+        # create hidden_states_masks
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)
+            hidden_states_masks[:, : patch_height * patch_width] = 1.0
+        else:
+            hidden_states_masks = None
+
+        # create img_ids
+        img_ids = torch.zeros(patch_height, patch_width, 3, device=device)
+        row_indices = torch.arange(patch_height, device=device)[:, None]
+        col_indices = torch.arange(patch_width, device=device)[None, :]
+        img_ids[..., 1] = img_ids[..., 1] + row_indices
+        img_ids[..., 2] = img_ids[..., 2] + col_indices
+        img_ids = img_ids.reshape(patch_height * patch_width, -1)
+
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
+            img_ids_pad[: patch_height * patch_width, :] = img_ids
+            img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1)
+        else:
+            img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1)
+
+        # patchify hidden_states
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            out = torch.zeros(
+                (batch_size, channels, self.max_seq, patch_size * patch_size),
+                dtype=dtype,
+                device=device,
+            )
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height * patch_width, patch_size * patch_size
+            )
+            out[:, :, 0 : patch_height * patch_width] = hidden_states
+            hidden_states = out
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
+                batch_size, self.max_seq, patch_size * patch_size * channels
+            )
+
+        else:
+            # Handle square latents
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
+            hidden_states = hidden_states.reshape(
+                batch_size, patch_height * patch_width, patch_size * patch_size * channels
+            )
+
+        return hidden_states, hidden_states_masks, img_sizes, img_ids
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         timesteps: torch.LongTensor = None,
-        encoder_hidden_states: torch.Tensor = None,
+        encoder_hidden_states_t5: torch.Tensor = None,
+        encoder_hidden_states_llama3: torch.Tensor = None,
         pooled_embeds: torch.Tensor = None,
-        img_sizes: Optional[List[Tuple[int, int]]] = None,
         img_ids: Optional[torch.Tensor] = None,
+        img_sizes: Optional[List[Tuple[int, int]]] = None,
+        hidden_states_masks: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        **kwargs,
     ):
+        encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+
+        if encoder_hidden_states is not None:
+            deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
+            deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
+            encoder_hidden_states_t5 = encoder_hidden_states[0]
+            encoder_hidden_states_llama3 = encoder_hidden_states[1]
+
+        if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
+            deprecation_message = (
+                "Passing `img_ids` and `img_sizes` with unpatchified `hidden_states` is deprecated and will be ignored."
+            )
+            deprecate("img_ids", "0.34.0", deprecation_message)
+
+        if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
+            raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
+        elif hidden_states_masks is not None and hidden_states.ndim != 3:
+            raise ValueError(
+                "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensor with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
+            )
+
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
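The kwargs-based fallback above follows a common back-compat pattern; a minimal standalone sketch (hypothetical names, not the diffusers implementation, and using the standard warnings module instead of diffusers' deprecate helper):

import warnings

def forward_compat(new_a=None, new_b=None, **kwargs):
    # Accept the old tuple-style argument and unpack it into the new keywords.
    old = kwargs.get("old_arg", None)
    if old is not None:
        warnings.warn("`old_arg` is deprecated; pass `new_a` and `new_b` instead.", FutureWarning)
        new_a, new_b = old[0], old[1]
    return new_a, new_b

assert forward_compat(old_arg=("t5", "llama3")) == ("t5", "llama3")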
@@ -745,42 +807,19 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         batch_size = hidden_states.shape[0]
         hidden_states_type = hidden_states.dtype
 
-        if hidden_states.shape[-2] != hidden_states.shape[-1]:
-            B, C, H, W = hidden_states.shape
-            patch_size = self.config.patch_size
-            pH, pW = H // patch_size, W // patch_size
-            out = torch.zeros(
-                (B, C, self.max_seq, patch_size * patch_size),
-                dtype=hidden_states.dtype,
-                device=hidden_states.device,
-            )
-            hidden_states = hidden_states.reshape(B, C, pH, patch_size, pW, patch_size)
-            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
-            hidden_states = hidden_states.reshape(B, C, pH * pW, patch_size * patch_size)
-            out[:, :, 0 : pH * pW] = hidden_states
-            hidden_states = out
+        # Patchify the input
+        if hidden_states_masks is None:
+            hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
+
+        # Embed the hidden states
+        hidden_states = self.x_embedder(hidden_states)
 
         # 0. time
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
         temb = timesteps + p_embedder
 
-        hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
-        if hidden_states_masks is None:
-            pH, pW = img_sizes[0]
-            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
-            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
-            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
-            img_ids = (
-                img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2])
-                .unsqueeze(0)
-                .repeat(batch_size, 1, 1)
-            )
-        hidden_states = self.x_embedder(hidden_states)
-
-        T5_encoder_hidden_states = encoder_hidden_states[0]
-        encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
+        encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]
 
         if self.caption_projection is not None:
             new_encoder_hidden_states = []
@@ -789,9 +828,9 @@ class HiDreamImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                 enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
                 new_encoder_hidden_states.append(enc_hidden_state)
             encoder_hidden_states = new_encoder_hidden_states
-            T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
-            T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-            encoder_hidden_states.append(T5_encoder_hidden_states)
+            encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
+            encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
+            encoder_hidden_states.append(encoder_hidden_states_t5)
 
         txt_ids = torch.zeros(
             batch_size,
@@ -15,7 +15,7 @@ from transformers import (
 from ...image_processor import VaeImageProcessor
 from ...models import AutoencoderKL, HiDreamImageTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler, UniPCMultistepScheduler
-from ...utils import is_torch_xla_available, logging, replace_example_docstring
+from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
 from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline
 from .pipeline_output import HiDreamImagePipelineOutput
@@ -38,9 +38,6 @@ EXAMPLE_DOC_STRING = """
         >>> from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
         >>> from diffusers import UniPCMultistepScheduler, HiDreamImagePipeline
 
->>> scheduler = UniPCMultistepScheduler(
-...     flow_shift=3.0, prediction_type="flow_prediction", use_flow_sigmas=True
-... )
 
         >>> tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
         >>> text_encoder_4 = LlamaForCausalLM.from_pretrained(
@@ -52,7 +49,6 @@ EXAMPLE_DOC_STRING = """
 
         >>> pipe = HiDreamImagePipeline.from_pretrained(
         ...     "HiDream-ai/HiDream-I1-Full",
-...     scheduler=scheduler,
         ...     tokenizer_4=tokenizer_4,
         ...     text_encoder_4=text_encoder_4,
         ...     torch_dtype=torch.bfloat16,
@@ -148,7 +144,7 @@ def retrieve_timesteps(
 
 class HiDreamImagePipeline(DiffusionPipeline):
     model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->text_encoder_4->transformer->vae"
-    _callback_tensor_inputs = ["latents", "prompt_embeds"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds_t5", "prompt_embeds_llama3", "pooled_prompt_embeds"]
 
     def __init__(
         self,
@@ -309,10 +305,10 @@ class HiDreamImagePipeline(DiffusionPipeline):
 
     def encode_prompt(
         self,
-        prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
-        prompt_3: Union[str, List[str]],
-        prompt_4: Union[str, List[str]],
+        prompt: Optional[Union[str, List[str]]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        prompt_3: Optional[Union[str, List[str]]] = None,
+        prompt_4: Optional[Union[str, List[str]]] = None,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
         num_images_per_prompt: int = 1,
@@ -321,8 +317,10 @@ class HiDreamImagePipeline(DiffusionPipeline):
         negative_prompt_2: Optional[Union[str, List[str]]] = None,
         negative_prompt_3: Optional[Union[str, List[str]]] = None,
         negative_prompt_4: Optional[Union[str, List[str]]] = None,
-        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds_t5: Optional[List[torch.FloatTensor]] = None,
+        prompt_embeds_llama3: Optional[List[torch.FloatTensor]] = None,
+        negative_prompt_embeds_t5: Optional[List[torch.FloatTensor]] = None,
+        negative_prompt_embeds_llama3: Optional[List[torch.FloatTensor]] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         max_sequence_length: int = 128,
@@ -332,120 +330,177 @@ class HiDreamImagePipeline(DiffusionPipeline):
         if prompt is not None:
             batch_size = len(prompt)
         else:
-            batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
+            batch_size = pooled_prompt_embeds.shape[0]
 
-        prompt_embeds, pooled_prompt_embeds = self._encode_prompt(
-            prompt=prompt,
-            prompt_2=prompt_2,
-            prompt_3=prompt_3,
-            prompt_4=prompt_4,
-            device=device,
-            dtype=dtype,
-            num_images_per_prompt=num_images_per_prompt,
-            prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            max_sequence_length=max_sequence_length,
-        )
-
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            negative_prompt = negative_prompt or ""
-            negative_prompt_2 = negative_prompt_2 or negative_prompt
-            negative_prompt_3 = negative_prompt_3 or negative_prompt
-            negative_prompt_4 = negative_prompt_4 or negative_prompt
-
-            # normalize str to list
-            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
-            negative_prompt_2 = (
-                batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
-            )
-            negative_prompt_3 = (
-                batch_size * [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
-            )
-            negative_prompt_4 = (
-                batch_size * [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
-            )
-
-            if prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-
-            negative_prompt_embeds, negative_pooled_prompt_embeds = self._encode_prompt(
-                prompt=negative_prompt,
-                prompt_2=negative_prompt_2,
-                prompt_3=negative_prompt_3,
-                prompt_4=negative_prompt_4,
-                device=device,
-                dtype=dtype,
-                num_images_per_prompt=num_images_per_prompt,
-                prompt_embeds=negative_prompt_embeds,
-                pooled_prompt_embeds=negative_pooled_prompt_embeds,
-                max_sequence_length=max_sequence_length,
-            )
-        return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
-
-    def _encode_prompt(
-        self,
-        prompt: Union[str, List[str]],
-        prompt_2: Union[str, List[str]],
-        prompt_3: Union[str, List[str]],
-        prompt_4: Union[str, List[str]],
-        device: Optional[torch.device] = None,
-        dtype: Optional[torch.dtype] = None,
-        num_images_per_prompt: int = 1,
-        prompt_embeds: Optional[List[torch.FloatTensor]] = None,
-        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
-        max_sequence_length: int = 128,
-    ):
         device = device or self._execution_device
-        if prompt is not None:
-            batch_size = len(prompt)
-        else:
-            batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
 
         if pooled_prompt_embeds is None:
             pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
                 self.tokenizer, self.text_encoder, prompt, max_sequence_length, device, dtype
             )
+
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_prompt = negative_prompt or ""
+            negative_prompt = [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
+
+            if len(negative_prompt) > 1 and len(negative_prompt) != batch_size:
+                raise ValueError(f"negative_prompt must be of length 1 or {batch_size}")
+
+            negative_pooled_prompt_embeds_1 = self._get_clip_prompt_embeds(
+                self.tokenizer, self.text_encoder, negative_prompt, max_sequence_length, device, dtype
+            )
+
+            if negative_pooled_prompt_embeds_1.shape[0] == 1 and batch_size > 1:
+                negative_pooled_prompt_embeds_1 = negative_pooled_prompt_embeds_1.repeat(batch_size, 1)
+
+        if pooled_prompt_embeds is None:
+            prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+            if len(prompt_2) > 1 and len(prompt_2) != batch_size:
+                raise ValueError(f"prompt_2 must be of length 1 or {batch_size}")
+
             pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
                 self.tokenizer_2, self.text_encoder_2, prompt_2, max_sequence_length, device, dtype
             )
+
+            if pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1:
+                pooled_prompt_embeds_2 = pooled_prompt_embeds_2.repeat(batch_size, 1)
+
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_prompt_2 = negative_prompt_2 or negative_prompt
+            negative_prompt_2 = [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
+
+            if len(negative_prompt_2) > 1 and len(negative_prompt_2) != batch_size:
+                raise ValueError(f"negative_prompt_2 must be of length 1 or {batch_size}")
+
+            negative_pooled_prompt_embeds_2 = self._get_clip_prompt_embeds(
+                self.tokenizer_2, self.text_encoder_2, negative_prompt_2, max_sequence_length, device, dtype
+            )
+
+            if negative_pooled_prompt_embeds_2.shape[0] == 1 and batch_size > 1:
+                negative_pooled_prompt_embeds_2 = negative_pooled_prompt_embeds_2.repeat(batch_size, 1)
+
+        if pooled_prompt_embeds is None:
             pooled_prompt_embeds = torch.cat([pooled_prompt_embeds_1, pooled_prompt_embeds_2], dim=-1)
 
-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
-        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+        if do_classifier_free_guidance and negative_pooled_prompt_embeds is None:
+            negative_pooled_prompt_embeds = torch.cat(
+                [negative_pooled_prompt_embeds_1, negative_pooled_prompt_embeds_2], dim=-1
+            )
 
-        if prompt_embeds is None:
+        if prompt_embeds_t5 is None:
+            prompt_3 = prompt_3 or prompt
+            prompt_3 = [prompt_3] if isinstance(prompt_3, str) else prompt_3
+
+            if len(prompt_3) > 1 and len(prompt_3) != batch_size:
+                raise ValueError(f"prompt_3 must be of length 1 or {batch_size}")
+
+            prompt_embeds_t5 = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
+
+            if prompt_embeds_t5.shape[0] == 1 and batch_size > 1:
+                prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1)
+
+        if do_classifier_free_guidance and negative_prompt_embeds_t5 is None:
+            negative_prompt_3 = negative_prompt_3 or negative_prompt
+            negative_prompt_3 = [negative_prompt_3] if isinstance(negative_prompt_3, str) else negative_prompt_3
+
+            if len(negative_prompt_3) > 1 and len(negative_prompt_3) != batch_size:
+                raise ValueError(f"negative_prompt_3 must be of length 1 or {batch_size}")
+
+            negative_prompt_embeds_t5 = self._get_t5_prompt_embeds(
+                negative_prompt_3, max_sequence_length, device, dtype
+            )
+
+            if negative_prompt_embeds_t5.shape[0] == 1 and batch_size > 1:
+                negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1)
+
+        if prompt_embeds_llama3 is None:
+            prompt_4 = prompt_4 or prompt
+            prompt_4 = [prompt_4] if isinstance(prompt_4, str) else prompt_4
+
-            t5_prompt_embeds = self._get_t5_prompt_embeds(prompt_3, max_sequence_length, device, dtype)
-            llama3_prompt_embeds = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype)
+            if len(prompt_4) > 1 and len(prompt_4) != batch_size:
+                raise ValueError(f"prompt_4 must be of length 1 or {batch_size}")
 
-            _, seq_len, _ = t5_prompt_embeds.shape
-            t5_prompt_embeds = t5_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            t5_prompt_embeds = t5_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            prompt_embeds_llama3 = self._get_llama3_prompt_embeds(prompt_4, max_sequence_length, device, dtype)
 
-            _, _, seq_len, dim = llama3_prompt_embeds.shape
-            llama3_prompt_embeds = llama3_prompt_embeds.repeat(1, 1, num_images_per_prompt, 1)
-            llama3_prompt_embeds = llama3_prompt_embeds.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
+            if prompt_embeds_llama3.shape[0] == 1 and batch_size > 1:
+                prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
 
-            prompt_embeds = [t5_prompt_embeds, llama3_prompt_embeds]
+        if do_classifier_free_guidance and negative_prompt_embeds_llama3 is None:
+            negative_prompt_4 = negative_prompt_4 or negative_prompt
+            negative_prompt_4 = [negative_prompt_4] if isinstance(negative_prompt_4, str) else negative_prompt_4
 
-        return prompt_embeds, pooled_prompt_embeds
+            if len(negative_prompt_4) > 1 and len(negative_prompt_4) != batch_size:
+                raise ValueError(f"negative_prompt_4 must be of length 1 or {batch_size}")
+
+            negative_prompt_embeds_llama3 = self._get_llama3_prompt_embeds(
+                negative_prompt_4, max_sequence_length, device, dtype
+            )
+
+            if negative_prompt_embeds_llama3.shape[0] == 1 and batch_size > 1:
+                negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
+
+        # duplicate pooled_prompt_embeds for each generation per prompt
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+        pooled_prompt_embeds = pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+        # duplicate t5_prompt_embeds for batch_size and num_images_per_prompt
+        bs_embed, seq_len, _ = prompt_embeds_t5.shape
+        if bs_embed == 1 and batch_size > 1:
+            prompt_embeds_t5 = prompt_embeds_t5.repeat(batch_size, 1, 1)
+        elif bs_embed > 1 and bs_embed != batch_size:
+            raise ValueError(f"cannot duplicate prompt_embeds_t5 of batch size {bs_embed}")
+        prompt_embeds_t5 = prompt_embeds_t5.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds_t5 = prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        # duplicate llama3_prompt_embeds for batch_size and num_images_per_prompt
+        _, bs_embed, seq_len, dim = prompt_embeds_llama3.shape
+        if bs_embed == 1 and batch_size > 1:
+            prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
+        elif bs_embed > 1 and bs_embed != batch_size:
+            raise ValueError(f"cannot duplicate prompt_embeds_llama3 of batch size {bs_embed}")
+        prompt_embeds_llama3 = prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1)
+        prompt_embeds_llama3 = prompt_embeds_llama3.view(-1, batch_size * num_images_per_prompt, seq_len, dim)
+
+        if do_classifier_free_guidance:
+            # duplicate negative_pooled_prompt_embeds for batch_size and num_images_per_prompt
+            bs_embed, seq_len = negative_pooled_prompt_embeds.shape
+            if bs_embed == 1 and batch_size > 1:
+                negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(batch_size, 1)
+            elif bs_embed > 1 and bs_embed != batch_size:
+                raise ValueError(f"cannot duplicate negative_pooled_prompt_embeds of batch size {bs_embed}")
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt)
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+            # duplicate negative_t5_prompt_embeds for batch_size and num_images_per_prompt
+            bs_embed, seq_len, _ = negative_prompt_embeds_t5.shape
+            if bs_embed == 1 and batch_size > 1:
+                negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(batch_size, 1, 1)
+            elif bs_embed > 1 and bs_embed != batch_size:
+                raise ValueError(f"cannot duplicate negative_prompt_embeds_t5 of batch size {bs_embed}")
+            negative_prompt_embeds_t5 = negative_prompt_embeds_t5.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds_t5 = negative_prompt_embeds_t5.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+            # duplicate negative_prompt_embeds_llama3 for batch_size and num_images_per_prompt
+            _, bs_embed, seq_len, dim = negative_prompt_embeds_llama3.shape
+            if bs_embed == 1 and batch_size > 1:
+                negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, batch_size, 1, 1)
+            elif bs_embed > 1 and bs_embed != batch_size:
+                raise ValueError(f"cannot duplicate negative_prompt_embeds_llama3 of batch size {bs_embed}")
+            negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.repeat(1, 1, num_images_per_prompt, 1)
+            negative_prompt_embeds_llama3 = negative_prompt_embeds_llama3.view(
+                -1, batch_size * num_images_per_prompt, seq_len, dim
+            )
+
+        return (
+            prompt_embeds_t5,
+            negative_prompt_embeds_t5,
+            prompt_embeds_llama3,
+            negative_prompt_embeds_llama3,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        )
 
     def enable_vae_slicing(self):
         r"""
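Each `repeat`/`view` block above applies the same broadcasting rule per tensor rank; a standalone sketch of the 3D (T5) case with made-up sizes (not pipeline code):

import torch

def expand_embeds(embeds: torch.Tensor, batch_size: int, num_images_per_prompt: int) -> torch.Tensor:
    # (bs_embed, seq_len, dim) -> (batch_size * num_images_per_prompt, seq_len, dim)
    bs_embed, seq_len, _ = embeds.shape
    if bs_embed == 1 and batch_size > 1:
        embeds = embeds.repeat(batch_size, 1, 1)  # broadcast a single prompt across the batch
    elif bs_embed > 1 and bs_embed != batch_size:
        raise ValueError(f"cannot duplicate embeds of batch size {bs_embed}")
    embeds = embeds.repeat(1, num_images_per_prompt, 1)  # tile along seq, then fold back into batch
    return embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

print(expand_embeds(torch.randn(1, 128, 4096), batch_size=2, num_images_per_prompt=3).shape)
# torch.Size([6, 128, 4096])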
@@ -476,6 +531,115 @@ class HiDreamImagePipeline(DiffusionPipeline):
         """
         self.vae.disable_tiling()
 
+    def check_inputs(
+        self,
+        prompt,
+        prompt_2,
+        prompt_3,
+        prompt_4,
+        negative_prompt=None,
+        negative_prompt_2=None,
+        negative_prompt_3=None,
+        negative_prompt_4=None,
+        prompt_embeds_t5=None,
+        prompt_embeds_llama3=None,
+        negative_prompt_embeds_t5=None,
+        negative_prompt_embeds_llama3=None,
+        pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
+        callback_on_step_end_tensor_inputs=None,
+    ):
+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
+        if prompt is not None and pooled_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and pooled_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `pooled_prompt_embeds`: {pooled_prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_3 is not None and prompt_embeds_t5 is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_3`: {prompt_3} and `prompt_embeds_t5`: {prompt_embeds_t5}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_4 is not None and prompt_embeds_llama3 is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_4`: {prompt_4} and `prompt_embeds_llama3`: {prompt_embeds_llama3}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "Provide either `prompt` or `pooled_prompt_embeds`. Cannot leave both `prompt` and `pooled_prompt_embeds` undefined."
+            )
+        elif prompt is None and prompt_embeds_t5 is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds_t5`. Cannot leave both `prompt` and `prompt_embeds_t5` undefined."
+            )
+        elif prompt is None and prompt_embeds_llama3 is None:
+            raise ValueError(
+                "Provide either `prompt` or `prompt_embeds_llama3`. Cannot leave both `prompt` and `prompt_embeds_llama3` undefined."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        elif prompt_3 is not None and (not isinstance(prompt_3, str) and not isinstance(prompt_3, list)):
+            raise ValueError(f"`prompt_3` has to be of type `str` or `list` but is {type(prompt_3)}")
+        elif prompt_4 is not None and (not isinstance(prompt_4, str) and not isinstance(prompt_4, list)):
+            raise ValueError(f"`prompt_4` has to be of type `str` or `list` but is {type(prompt_4)}")
+
+        if negative_prompt is not None and negative_pooled_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_pooled_prompt_embeds`:"
+                f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_pooled_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_pooled_prompt_embeds`:"
+                f" {negative_pooled_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_3 is not None and negative_prompt_embeds_t5 is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_3`: {negative_prompt_3} and `negative_prompt_embeds_t5`:"
+                f" {negative_prompt_embeds_t5}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_4 is not None and negative_prompt_embeds_llama3 is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_4`: {negative_prompt_4} and `negative_prompt_embeds_llama3`:"
+                f" {negative_prompt_embeds_llama3}. Please make sure to only forward one of the two."
+            )
+
+        if pooled_prompt_embeds is not None and negative_pooled_prompt_embeds is not None:
+            if pooled_prompt_embeds.shape != negative_pooled_prompt_embeds.shape:
+                raise ValueError(
+                    "`pooled_prompt_embeds` and `negative_pooled_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `pooled_prompt_embeds` {pooled_prompt_embeds.shape} != `negative_pooled_prompt_embeds`"
+                    f" {negative_pooled_prompt_embeds.shape}."
+                )
+        if prompt_embeds_t5 is not None and negative_prompt_embeds_t5 is not None:
+            if prompt_embeds_t5.shape != negative_prompt_embeds_t5.shape:
+                raise ValueError(
+                    "`prompt_embeds_t5` and `negative_prompt_embeds_t5` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds_t5` {prompt_embeds_t5.shape} != `negative_prompt_embeds_t5`"
+                    f" {negative_prompt_embeds_t5.shape}."
+                )
+        if prompt_embeds_llama3 is not None and negative_prompt_embeds_llama3 is not None:
+            if prompt_embeds_llama3.shape != negative_prompt_embeds_llama3.shape:
+                raise ValueError(
+                    "`prompt_embeds_llama3` and `negative_prompt_embeds_llama3` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds_llama3` {prompt_embeds_llama3.shape} != `negative_prompt_embeds_llama3`"
+                    f" {negative_prompt_embeds_llama3.shape}."
+                )
+
     def prepare_latents(
         self,
         batch_size,
@@ -542,8 +706,10 @@ class HiDreamImagePipeline(DiffusionPipeline):
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds_t5: Optional[torch.FloatTensor] = None,
+        prompt_embeds_llama3: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds_llama3: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
@@ -552,6 +718,7 @@ class HiDreamImagePipeline(DiffusionPipeline):
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 128,
+        **kwargs,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -649,6 +816,22 @@ class HiDreamImagePipeline(DiffusionPipeline):
             [`~pipelines.hidream_image.HiDreamImagePipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
             returning a tuple, the first element is a list with the generated images.
         """
 
+        prompt_embeds = kwargs.get("prompt_embeds", None)
+        negative_prompt_embeds = kwargs.get("negative_prompt_embeds", None)
+
+        if prompt_embeds is not None:
+            deprecation_message = "The `prompt_embeds` argument is deprecated. Please use `prompt_embeds_t5` and `prompt_embeds_llama3` instead."
+            deprecate("prompt_embeds", "0.34.0", deprecation_message)
+            prompt_embeds_t5 = prompt_embeds[0]
+            prompt_embeds_llama3 = prompt_embeds[1]
+
+        if negative_prompt_embeds is not None:
+            deprecation_message = "The `negative_prompt_embeds` argument is deprecated. Please use `negative_prompt_embeds_t5` and `negative_prompt_embeds_llama3` instead."
+            deprecate("negative_prompt_embeds", "0.34.0", deprecation_message)
+            negative_prompt_embeds_t5 = negative_prompt_embeds[0]
+            negative_prompt_embeds_llama3 = negative_prompt_embeds[1]
+
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor
@@ -658,6 +841,25 @@ class HiDreamImagePipeline(DiffusionPipeline):
             scale = math.sqrt(scale)
         width, height = int(width * scale // division * division), int(height * scale // division * division)
 
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            prompt,
+            prompt_2,
+            prompt_3,
+            prompt_4,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            negative_prompt_3=negative_prompt_3,
+            negative_prompt_4=negative_prompt_4,
+            prompt_embeds_t5=prompt_embeds_t5,
+            prompt_embeds_llama3=prompt_embeds_llama3,
+            negative_prompt_embeds_t5=negative_prompt_embeds_t5,
+            negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
+        )
+
         self._guidance_scale = guidance_scale
         self._attention_kwargs = attention_kwargs
         self._interrupt = False
@@ -667,17 +869,18 @@ class HiDreamImagePipeline(DiffusionPipeline):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
             batch_size = len(prompt)
-        elif prompt_embeds is not None:
-            batch_size = prompt_embeds[0].shape[0] if isinstance(prompt_embeds, list) else prompt_embeds.shape[0]
-        else:
-            batch_size = 1
+        elif pooled_prompt_embeds is not None:
+            batch_size = pooled_prompt_embeds.shape[0]
 
         device = self._execution_device
 
         # 3. Encode prompt
         lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
         (
-            prompt_embeds,
-            negative_prompt_embeds,
+            prompt_embeds_t5,
+            negative_prompt_embeds_t5,
+            prompt_embeds_llama3,
+            negative_prompt_embeds_llama3,
             pooled_prompt_embeds,
             negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
@@ -690,8 +893,10 @@ class HiDreamImagePipeline(DiffusionPipeline):
             negative_prompt_3=negative_prompt_3,
             negative_prompt_4=negative_prompt_4,
             do_classifier_free_guidance=self.do_classifier_free_guidance,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_embeds_t5=prompt_embeds_t5,
+            prompt_embeds_llama3=prompt_embeds_llama3,
+            negative_prompt_embeds_t5=negative_prompt_embeds_t5,
+            negative_prompt_embeds_llama3=negative_prompt_embeds_llama3,
             pooled_prompt_embeds=pooled_prompt_embeds,
             negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             device=device,
@@ -701,13 +906,8 @@ class HiDreamImagePipeline(DiffusionPipeline):
         )
 
         if self.do_classifier_free_guidance:
-            prompt_embeds_arr = []
-            for n, p in zip(negative_prompt_embeds, prompt_embeds):
-                if len(n.shape) == 3:
-                    prompt_embeds_arr.append(torch.cat([n, p], dim=0))
-                else:
-                    prompt_embeds_arr.append(torch.cat([n, p], dim=1))
-            prompt_embeds = prompt_embeds_arr
+            prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5], dim=0)
+            prompt_embeds_llama3 = torch.cat([negative_prompt_embeds_llama3, prompt_embeds_llama3], dim=1)
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
 
         # 4. Prepare latent variables
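The split between dim=0 and dim=1 reflects the two layouts: the T5 embeddings are batch-first, while the Llama3 embeddings carry a leading per-layer axis, so batch sits at index 1. A standalone sketch with toy sizes (assumed shapes, for illustration only):

import torch

batch, seq, dim, num_layers = 2, 8, 16, 4
prompt_t5 = torch.randn(batch, seq, dim)                 # (B, S, D): batch first
prompt_llama = torch.randn(num_layers, batch, seq, dim)  # (L, B, S, D): batch second
neg_t5 = torch.randn_like(prompt_t5)
neg_llama = torch.randn_like(prompt_llama)

cfg_t5 = torch.cat([neg_t5, prompt_t5], dim=0)           # (2B, S, D)
cfg_llama = torch.cat([neg_llama, prompt_llama], dim=1)  # (L, 2B, S, D)
print(cfg_t5.shape, cfg_llama.shape)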
@@ -723,26 +923,6 @@ class HiDreamImagePipeline(DiffusionPipeline):
             latents,
         )
 
-        if latents.shape[-2] != latents.shape[-1]:
-            B, C, H, W = latents.shape
-            pH, pW = H // self.transformer.config.patch_size, W // self.transformer.config.patch_size
-
-            img_sizes = torch.tensor([pH, pW], dtype=torch.int64).reshape(-1)
-            img_ids = torch.zeros(pH, pW, 3)
-            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH)[:, None]
-            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW)[None, :]
-            img_ids = img_ids.reshape(pH * pW, -1)
-            img_ids_pad = torch.zeros(self.transformer.max_seq, 3)
-            img_ids_pad[: pH * pW, :] = img_ids
-
-            img_sizes = img_sizes.unsqueeze(0).to(latents.device)
-            img_ids = img_ids_pad.unsqueeze(0).to(latents.device)
-            if self.do_classifier_free_guidance:
-                img_sizes = img_sizes.repeat(2 * B, 1)
-                img_ids = img_ids.repeat(2 * B, 1, 1)
-        else:
-            img_sizes = img_ids = None
-
         # 5. Prepare timesteps
         mu = calculate_shift(self.transformer.max_seq)
         scheduler_kwargs = {"mu": mu}
@@ -774,10 +954,9 @@ class HiDreamImagePipeline(DiffusionPipeline):
                 noise_pred = self.transformer(
                     hidden_states=latent_model_input,
                     timesteps=timestep,
-                    encoder_hidden_states=prompt_embeds,
+                    encoder_hidden_states_t5=prompt_embeds_t5,
+                    encoder_hidden_states_llama3=prompt_embeds_llama3,
                     pooled_embeds=pooled_prompt_embeds,
-                    img_sizes=img_sizes,
-                    img_ids=img_ids,
                     return_dict=False,
                 )[0]
                 noise_pred = -noise_pred
@@ -803,8 +982,9 @@ class HiDreamImagePipeline(DiffusionPipeline):
                     callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
 
                     latents = callback_outputs.pop("latents", latents)
-                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    prompt_embeds_t5 = callback_outputs.pop("prompt_embeds_t5", prompt_embeds_t5)
+                    prompt_embeds_llama3 = callback_outputs.pop("prompt_embeds_llama3", prompt_embeds_llama3)
+                    pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
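A hedged sketch of a step callback wired to the renamed tensor inputs (the callback body is arbitrary, just to show the plumbing; `pipe` is assumed to be a loaded HiDreamImagePipeline):

def on_step_end(pipe, step_index, timestep, callback_kwargs):
    # Tensors must be returned under the same keys so the `pop` calls above pick them up.
    embeds_t5 = callback_kwargs["prompt_embeds_t5"]
    if step_index == 0:
        print("t5 embeds:", tuple(embeds_t5.shape))
    return {"prompt_embeds_t5": embeds_t5}

# image = pipe(prompt="...", callback_on_step_end=on_step_end,
#              callback_on_step_end_tensor_inputs=["prompt_embeds_t5"]).images[0]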
@@ -43,7 +43,7 @@ enable_full_determinism()
 
 class HiDreamImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     pipeline_class = HiDreamImagePipeline
-    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
+    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "prompt_embeds", "negative_prompt_embeds"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
     image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
     image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
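Excluding `prompt_embeds` and `negative_prompt_embeds` from the tested params matches the change above: the pipeline now accepts them only as deprecated kwargs, with the per-encoder `prompt_embeds_t5`/`prompt_embeds_llama3` variants as the supported interface.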