mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
up
This commit is contained in:
@@ -302,8 +302,6 @@ def main(args):
|
||||
scheduler_config = {
|
||||
"num_train_timesteps": 1000,
|
||||
"prediction_type": "trigflow",
|
||||
"max_timesteps": 1.57080,
|
||||
"intermediate_timesteps": 1.3,
|
||||
"sigma_data": 0.5,
|
||||
}
|
||||
scheduler = SCMScheduler(**scheduler_config)
|
||||
|
||||
@@ -359,6 +359,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
norm_eps: float = 1e-6,
|
||||
interpolation_scale: Optional[int] = None,
|
||||
guidance_embeds: bool = False,
|
||||
guidance_embeds_scale: float = 0.1,
|
||||
qk_norm: Optional[str] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
@@ -248,6 +248,65 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
clean_caption: bool = False,
|
||||
max_sequence_length: int = 300,
|
||||
complex_human_instruction: Optional[List[str]] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
clean_caption (`bool`, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
|
||||
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
|
||||
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
|
||||
the prompt.
|
||||
"""
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
self.tokenizer.padding_side = "right"
|
||||
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
|
||||
# prepare complex human instruction
|
||||
if not complex_human_instruction:
|
||||
max_length_all = max_sequence_length
|
||||
else:
|
||||
chi_prompt = "\n".join(complex_human_instruction)
|
||||
prompt = [chi_prompt + p for p in prompt]
|
||||
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
|
||||
max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length_all,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
||||
prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
@@ -296,6 +355,13 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
elif self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
|
||||
@@ -320,43 +386,18 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
select_index = [0] + list(range(-max_length + 1, 0))
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
|
||||
# prepare complex human instruction
|
||||
if not complex_human_instruction:
|
||||
max_length_all = max_length
|
||||
else:
|
||||
chi_prompt = "\n".join(complex_human_instruction)
|
||||
prompt = [chi_prompt + p for p in prompt]
|
||||
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
|
||||
max_length_all = num_chi_prompt_tokens + max_length - 2
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length_all,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
complex_human_instruction=complex_human_instruction,
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
||||
prompt_embeds = prompt_embeds[0][:, select_index]
|
||||
prompt_embeds = prompt_embeds[:, select_index]
|
||||
prompt_attention_mask = prompt_attention_mask[:, select_index]
|
||||
|
||||
if self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
elif self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -366,25 +407,15 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
|
||||
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
|
||||
negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
prompt=negative_prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
complex_human_instruction=False,
|
||||
)
|
||||
negative_prompt_attention_mask = uncond_input.attention_mask
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
|
||||
@@ -248,17 +248,73 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
"""
|
||||
self.vae.disable_tiling()
|
||||
|
||||
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
|
||||
def _get_gemma_prompt_embeds(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: torch.device,
|
||||
dtype: torch.dtype,
|
||||
clean_caption: bool = False,
|
||||
max_sequence_length: int = 300,
|
||||
complex_human_instruction: Optional[List[str]] = None,
|
||||
):
|
||||
r"""
|
||||
Encodes the prompt into text encoder hidden states.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
device: (`torch.device`, *optional*):
|
||||
torch device to place the resulting embeddings on
|
||||
clean_caption (`bool`, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
|
||||
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
|
||||
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
|
||||
the prompt.
|
||||
"""
|
||||
prompt = [prompt] if isinstance(prompt, str) else prompt
|
||||
|
||||
if getattr(self, "tokenizer", None) is not None:
|
||||
self.tokenizer.padding_side = "right"
|
||||
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
|
||||
# prepare complex human instruction
|
||||
if not complex_human_instruction:
|
||||
max_length_all = max_sequence_length
|
||||
else:
|
||||
chi_prompt = "\n".join(complex_human_instruction)
|
||||
prompt = [chi_prompt + p for p in prompt]
|
||||
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
|
||||
max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length_all,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
||||
prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: str = "",
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
clean_caption: bool = False,
|
||||
max_sequence_length: int = 300,
|
||||
complex_human_instruction: Optional[List[str]] = None,
|
||||
@@ -270,12 +326,7 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
prompt to be encoded
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
|
||||
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
|
||||
PixArt-Alpha, this should be "".
|
||||
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
|
||||
whether to use classifier free guidance or not
|
||||
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
number of images that should be generated per prompt
|
||||
device: (`torch.device`, *optional*):
|
||||
@@ -283,8 +334,6 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.Tensor`, *optional*):
|
||||
Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string.
|
||||
clean_caption (`bool`, defaults to `False`):
|
||||
If `True`, the function will preprocess and clean the provided caption before encoding.
|
||||
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
|
||||
@@ -296,6 +345,13 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
elif self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
# set lora scale so that monkey patched LoRA
|
||||
# function of text encoder can correctly access it
|
||||
if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
|
||||
@@ -320,43 +376,18 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
select_index = [0] + list(range(-max_length + 1, 0))
|
||||
|
||||
if prompt_embeds is None:
|
||||
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
|
||||
|
||||
# prepare complex human instruction
|
||||
if not complex_human_instruction:
|
||||
max_length_all = max_length
|
||||
else:
|
||||
chi_prompt = "\n".join(complex_human_instruction)
|
||||
prompt = [chi_prompt + p for p in prompt]
|
||||
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
|
||||
max_length_all = num_chi_prompt_tokens + max_length - 2
|
||||
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=max_length_all,
|
||||
truncation=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
|
||||
prompt=prompt,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
complex_human_instruction=complex_human_instruction,
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
prompt_attention_mask = text_inputs.attention_mask
|
||||
prompt_attention_mask = prompt_attention_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
|
||||
prompt_embeds = prompt_embeds[0][:, select_index]
|
||||
prompt_embeds = prompt_embeds[:, select_index]
|
||||
prompt_attention_mask = prompt_attention_mask[:, select_index]
|
||||
|
||||
if self.transformer is not None:
|
||||
dtype = self.transformer.dtype
|
||||
elif self.text_encoder is not None:
|
||||
dtype = self.text_encoder.dtype
|
||||
else:
|
||||
dtype = None
|
||||
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
@@ -364,49 +395,13 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
# get unconditional embeddings for classifier free guidance
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
||||
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
|
||||
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
||||
max_length = prompt_embeds.shape[1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
add_special_tokens=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
negative_prompt_attention_mask = uncond_input.attention_mask
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
|
||||
|
||||
negative_prompt_embeds = self.text_encoder(
|
||||
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
|
||||
)
|
||||
negative_prompt_embeds = negative_prompt_embeds[0]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
seq_len = negative_prompt_embeds.shape[1]
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
||||
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
else:
|
||||
negative_prompt_embeds = None
|
||||
negative_prompt_attention_mask = None
|
||||
|
||||
if self.text_encoder is not None:
|
||||
if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
|
||||
# Retrieve the original scale by scaling back the LoRA layers
|
||||
unscale_lora_layers(self.text_encoder, lora_scale)
|
||||
|
||||
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
|
||||
return prompt_embeds, prompt_attention_mask
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
@@ -431,12 +426,13 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
num_inference_steps,
|
||||
timesteps,
|
||||
max_timesteps,
|
||||
intermediate_timesteps,
|
||||
callback_on_step_end_tensor_inputs=None,
|
||||
negative_prompt=None,
|
||||
prompt_embeds=None,
|
||||
negative_prompt_embeds=None,
|
||||
prompt_attention_mask=None,
|
||||
negative_prompt_attention_mask=None,
|
||||
):
|
||||
if height % 32 != 0 or width % 32 != 0:
|
||||
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
|
||||
@@ -460,37 +456,21 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if negative_prompt is not None and negative_prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
|
||||
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt_embeds is not None and prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
|
||||
|
||||
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
|
||||
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
|
||||
if timesteps is not None and len(timesteps) != num_inference_steps + 1:
|
||||
raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
|
||||
|
||||
if timesteps is not None and max_timesteps is not None:
|
||||
raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
|
||||
|
||||
if timesteps is None and max_timesteps is None:
|
||||
raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
|
||||
|
||||
if intermediate_timesteps is not None and num_inference_steps != 2:
|
||||
raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
|
||||
|
||||
if prompt_embeds is not None and negative_prompt_embeds is not None:
|
||||
if prompt_embeds.shape != negative_prompt_embeds.shape:
|
||||
raise ValueError(
|
||||
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
|
||||
f" {negative_prompt_embeds.shape}."
|
||||
)
|
||||
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
|
||||
raise ValueError(
|
||||
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
|
||||
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
|
||||
f" {negative_prompt_attention_mask.shape}."
|
||||
)
|
||||
|
||||
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
|
||||
def _text_preprocessing(self, text, clean_caption=False):
|
||||
@@ -632,6 +612,7 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
|
||||
return caption.strip()
|
||||
|
||||
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents
|
||||
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
|
||||
if latents is not None:
|
||||
return latents.to(device=device, dtype=dtype)
|
||||
@@ -659,10 +640,6 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
def attention_kwargs(self):
|
||||
return self._attention_kwargs
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def num_timesteps(self):
|
||||
return self._num_timesteps
|
||||
@@ -676,10 +653,10 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: str = "",
|
||||
num_inference_steps: int = 20,
|
||||
num_inference_steps: int = 2,
|
||||
timesteps: List[int] = None,
|
||||
sigmas: List[float] = None,
|
||||
max_timesteps: float = 1.57080,
|
||||
intermediate_timesteps: float = 1.3,
|
||||
guidance_scale: float = 4.5,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
height: int = 1024,
|
||||
@@ -689,8 +666,6 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.Tensor] = None,
|
||||
prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.Tensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
clean_caption: bool = False,
|
||||
@@ -724,14 +699,14 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
num_inference_steps (`int`, *optional*, defaults to 20):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
max_timesteps (`float`, *optional*, defaults to 1.57080):
|
||||
The maximum timestep value used in the SCM scheduler.
|
||||
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
|
||||
The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
sigmas (`List[float]`, *optional*):
|
||||
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
@@ -822,15 +797,16 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
|
||||
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_attention_mask,
|
||||
prompt=prompt,
|
||||
height=height,
|
||||
width=width,
|
||||
num_inference_steps=num_inference_steps,
|
||||
timesteps=timesteps,
|
||||
max_timesteps=max_timesteps,
|
||||
intermediate_timesteps=intermediate_timesteps,
|
||||
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds=prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
)
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
@@ -852,29 +828,24 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
(
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
_,
|
||||
_,
|
||||
) = self.encode_prompt(
|
||||
prompt,
|
||||
False,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
clean_caption=clean_caption,
|
||||
max_sequence_length=max_sequence_length,
|
||||
complex_human_instruction=complex_human_instruction,
|
||||
lora_scale=lora_scale,
|
||||
)
|
||||
# prompt_embeds = torch.load("/raid/yiyi/Sana-Sprint-diffusers/y.pt").to(device, dtype=prompt_embeds.dtype)
|
||||
|
||||
# 4. Prepare timesteps
|
||||
timesteps, num_inference_steps = retrieve_timesteps(
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas
|
||||
self.scheduler, num_inference_steps, device, timesteps, sigmas=None, max_timesteps=max_timesteps, intermediate_timesteps=intermediate_timesteps
|
||||
)
|
||||
if hasattr(self.scheduler, "set_begin_index"):
|
||||
self.scheduler.set_begin_index(0)
|
||||
|
||||
# 5. Prepare latents.
|
||||
latent_channels = self.transformer.config.in_channels
|
||||
@@ -889,14 +860,11 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
latents,
|
||||
)
|
||||
|
||||
# latents = torch.load("/raid/yiyi/Sana-Sprint-diffusers/latents.pt").to(device, dtype=latents.dtype)
|
||||
|
||||
latents = latents * self.scheduler.config.sigma_data
|
||||
|
||||
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
|
||||
guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
|
||||
# YiYi TODO: cfg_embed_scale = 0.1 (refactor this out)
|
||||
guidance = guidance * 0.1
|
||||
guidance = guidance * self.transformer.config.guidance_embeds_scale
|
||||
|
||||
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
@@ -915,11 +883,8 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype)
|
||||
|
||||
# YiYi TODO: self.scheduler.scale_model_input?
|
||||
latents_model_input = latents / self.scheduler.config.sigma_data
|
||||
|
||||
# YiYi TODO: refator this out
|
||||
scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
|
||||
latent_model_input = latents_model_input * torch.sqrt(scm_timestep**2 + (1 - scm_timestep) ** 2)
|
||||
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
|
||||
@@ -935,17 +900,15 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
attention_kwargs=self.attention_kwargs,
|
||||
)[0]
|
||||
|
||||
# YiYi TODO: refator this out
|
||||
noise_pred = (
|
||||
(1 - 2 * scm_timestep) * latent_model_input
|
||||
+ (1 - 2 * scm_timestep + 2 * scm_timestep**2) * noise_pred
|
||||
) / torch.sqrt(scm_timestep**2 + (1 - scm_timestep) ** 2)
|
||||
# YiYi TODO: check if this can be refatored into scheduler
|
||||
noise_pred = noise_pred.float() * self.scheduler.config.sigma_data
|
||||
|
||||
# compute previous image: x_t -> x_t-1
|
||||
latents, denoised = self.scheduler.step(
|
||||
noise_pred, i, timestep, latents, **extra_step_kwargs, return_dict=False
|
||||
noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False
|
||||
)
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
@@ -965,7 +928,6 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
# YiYi TODO: refator this out
|
||||
latents = denoised / self.scheduler.config.sigma_data
|
||||
if output_type == "latent":
|
||||
image = latents
|
||||
|
||||
@@ -75,8 +75,6 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
num_train_timesteps: int = 1000,
|
||||
prediction_type: str = "trigflow",
|
||||
max_timesteps: float = 1.57080,
|
||||
intermediate_timesteps: Optional[float] = 1.3,
|
||||
sigma_data: float = 0.5,
|
||||
):
|
||||
"""
|
||||
@@ -87,10 +85,6 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
The number of diffusion steps to train the model.
|
||||
prediction_type (`str`, defaults to `trigflow`):
|
||||
Prediction type of the scheduler function. Currently only supports "trigflow".
|
||||
max_timesteps (`float`, defaults to 1.57080):
|
||||
The maximum timestep value used in the diffusion process.
|
||||
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
|
||||
The intermediate timestep value used when num_inference_steps=2.
|
||||
sigma_data (`float`, defaults to 0.5):
|
||||
The standard deviation of the noise added during multi-step inference.
|
||||
"""
|
||||
@@ -101,11 +95,35 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
@property
|
||||
def step_index(self):
|
||||
return self._step_index
|
||||
|
||||
@property
|
||||
def begin_index(self):
|
||||
return self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
|
||||
def set_begin_index(self, begin_index: int = 0):
|
||||
"""
|
||||
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
|
||||
|
||||
Args:
|
||||
begin_index (`int`):
|
||||
The begin index for the scheduler.
|
||||
"""
|
||||
self._begin_index = begin_index
|
||||
|
||||
def set_timesteps(
|
||||
self,
|
||||
num_inference_steps: int,
|
||||
timesteps: torch.Tensor = None,
|
||||
device: Union[str, torch.device] = None,
|
||||
max_timesteps: float = 1.57080,
|
||||
intermediate_timesteps: float = 1.3,
|
||||
):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
|
||||
@@ -113,6 +131,12 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
The number of diffusion steps used when generating samples with a pre-trained model.
|
||||
timesteps (`torch.Tensor`, *optional*):
|
||||
Custom timesteps to use for the denoising process.
|
||||
max_timesteps (`float`, defaults to 1.57080):
|
||||
The maximum timestep value used in the SCM scheduler.
|
||||
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
|
||||
The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
|
||||
"""
|
||||
if num_inference_steps > self.config.num_train_timesteps:
|
||||
raise ValueError(
|
||||
@@ -121,39 +145,68 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
f" maximal {self.config.num_train_timesteps} timesteps."
|
||||
)
|
||||
|
||||
if timesteps is not None and len(timesteps) != num_inference_steps + 1:
|
||||
raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
|
||||
|
||||
if timesteps is not None and max_timesteps is not None:
|
||||
raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
|
||||
|
||||
if timesteps is None and max_timesteps is None:
|
||||
raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
|
||||
|
||||
if intermediate_timesteps is not None and num_inference_steps != 2:
|
||||
raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
|
||||
|
||||
self.num_inference_steps = num_inference_steps
|
||||
|
||||
if timesteps is not None and len(timesteps) == num_inference_steps + 1:
|
||||
if timesteps is not None:
|
||||
if isinstance(timesteps, list):
|
||||
self.timesteps = torch.tensor(timesteps, device=device).float()
|
||||
elif isinstance(timesteps, torch.Tensor):
|
||||
self.timesteps = timesteps.to(device).float()
|
||||
else:
|
||||
raise ValueError(f"Unsupported timesteps type: {type(timesteps)}")
|
||||
elif self.config.intermediate_timesteps and num_inference_steps == 2:
|
||||
elif intermediate_timesteps is not None:
|
||||
self.timesteps = torch.tensor(
|
||||
[self.config.max_timesteps, self.config.intermediate_timesteps, 0], device=device
|
||||
[max_timesteps, intermediate_timesteps, 0], device=device
|
||||
).float()
|
||||
elif self.config.intermediate_timesteps:
|
||||
self.timesteps = torch.linspace(
|
||||
self.config.max_timesteps, 0, num_inference_steps + 1, device=device
|
||||
).float()
|
||||
logger.warning(
|
||||
f"Intermediate timesteps for SCM is not supported when num_inference_steps != 2. "
|
||||
f"Reset timesteps to {self.timesteps} default max_timesteps"
|
||||
)
|
||||
else:
|
||||
# max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
|
||||
self.timesteps = torch.linspace(
|
||||
self.config.max_timesteps, 0, num_inference_steps + 1, device=device
|
||||
max_timesteps, 0, num_inference_steps + 1, device=device
|
||||
).float()
|
||||
|
||||
print(f"Set timesteps: {self.timesteps}")
|
||||
|
||||
self._step_index = None
|
||||
self._begin_index = None
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
|
||||
def _init_step_index(self, timestep):
|
||||
if self.begin_index is None:
|
||||
if isinstance(timestep, torch.Tensor):
|
||||
timestep = timestep.to(self.timesteps.device)
|
||||
self._step_index = self.index_for_timestep(timestep)
|
||||
else:
|
||||
self._step_index = self._begin_index
|
||||
|
||||
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
|
||||
def index_for_timestep(self, timestep, schedule_timesteps=None):
|
||||
if schedule_timesteps is None:
|
||||
schedule_timesteps = self.timesteps
|
||||
|
||||
indices = (schedule_timesteps == timestep).nonzero()
|
||||
|
||||
# The sigma index that is taken for the **very** first `step`
|
||||
# is always the second index (or the last index if there is only 1)
|
||||
# This way we can ensure we don't accidentally skip a sigma in
|
||||
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
|
||||
pos = 1 if len(indices) > 1 else 0
|
||||
|
||||
return indices[pos].item()
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timeindex: int,
|
||||
timestep: float,
|
||||
sample: torch.FloatTensor,
|
||||
generator: torch.Generator = None,
|
||||
@@ -183,10 +236,13 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
raise ValueError(
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
|
||||
if self.step_index is None:
|
||||
self._init_step_index(timestep)
|
||||
|
||||
# 2. compute alphas, betas
|
||||
t = self.timesteps[timeindex + 1]
|
||||
s = self.timesteps[timeindex]
|
||||
t = self.timesteps[self.step_index + 1]
|
||||
s = self.timesteps[self.step_index]
|
||||
|
||||
# 4. Different Parameterization:
|
||||
parameterization = self.config.prediction_type
|
||||
@@ -206,6 +262,8 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
|
||||
prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise
|
||||
else:
|
||||
prev_sample = pred_x0
|
||||
|
||||
self._step_index += 1
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, pred_x0)
|
||||
|
||||
Reference in New Issue
Block a user