1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
yiyixuxu
2025-03-20 11:40:13 +01:00
parent 398ca0c938
commit 8070495df1
5 changed files with 279 additions and 229 deletions

View File

@@ -302,8 +302,6 @@ def main(args):
scheduler_config = {
"num_train_timesteps": 1000,
"prediction_type": "trigflow",
"max_timesteps": 1.57080,
"intermediate_timesteps": 1.3,
"sigma_data": 0.5,
}
scheduler = SCMScheduler(**scheduler_config)

View File

@@ -359,6 +359,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
norm_eps: float = 1e-6,
interpolation_scale: Optional[int] = None,
guidance_embeds: bool = False,
guidance_embeds_scale: float = 0.1,
qk_norm: Optional[str] = None,
) -> None:
super().__init__()

View File

@@ -248,6 +248,65 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
"""
self.vae.disable_tiling()
def _get_gemma_prompt_embeds(
self,
prompt: Union[str, List[str]],
device: torch.device,
dtype: torch.dtype,
clean_caption: bool = False,
max_sequence_length: int = 300,
complex_human_instruction: Optional[List[str]] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
the prompt.
"""
prompt = [prompt] if isinstance(prompt, str) else prompt
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.padding_side = "right"
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
# prepare complex human instruction
if not complex_human_instruction:
max_length_all = max_sequence_length
else:
chi_prompt = "\n".join(complex_human_instruction)
prompt = [chi_prompt + p for p in prompt]
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length_all,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
return prompt_embeds, prompt_attention_mask
def encode_prompt(
self,
prompt: Union[str, List[str]],
@@ -296,6 +355,13 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if device is None:
device = self._execution_device
if self.transformer is not None:
dtype = self.transformer.dtype
elif self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
@@ -320,43 +386,18 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
select_index = [0] + list(range(-max_length + 1, 0))
if prompt_embeds is None:
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
# prepare complex human instruction
if not complex_human_instruction:
max_length_all = max_length
else:
chi_prompt = "\n".join(complex_human_instruction)
prompt = [chi_prompt + p for p in prompt]
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
max_length_all = num_chi_prompt_tokens + max_length - 2
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length_all,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt=prompt,
device=device,
dtype=dtype,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
complex_human_instruction=complex_human_instruction,
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0][:, select_index]
prompt_embeds = prompt_embeds[:, select_index]
prompt_attention_mask = prompt_attention_mask[:, select_index]
if self.transformer is not None:
dtype = self.transformer.dtype
elif self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -366,25 +407,15 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt=negative_prompt,
device=device,
dtype=dtype,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
complex_human_instruction=False,
)
negative_prompt_attention_mask = uncond_input.attention_mask
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method

View File

@@ -248,17 +248,73 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
"""
self.vae.disable_tiling()
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline._get_gemma_prompt_embeds
def _get_gemma_prompt_embeds(
self,
prompt: Union[str, List[str]],
device: torch.device,
dtype: torch.dtype,
clean_caption: bool = False,
max_sequence_length: int = 300,
complex_human_instruction: Optional[List[str]] = None,
):
r"""
Encodes the prompt into text encoder hidden states.
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
device: (`torch.device`, *optional*):
torch device to place the resulting embeddings on
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
the prompt.
"""
prompt = [prompt] if isinstance(prompt, str) else prompt
if getattr(self, "tokenizer", None) is not None:
self.tokenizer.padding_side = "right"
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
# prepare complex human instruction
if not complex_human_instruction:
max_length_all = max_sequence_length
else:
chi_prompt = "\n".join(complex_human_instruction)
prompt = [chi_prompt + p for p in prompt]
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length_all,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
return prompt_embeds, prompt_attention_mask
def encode_prompt(
self,
prompt: Union[str, List[str]],
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
clean_caption: bool = False,
max_sequence_length: int = 300,
complex_human_instruction: Optional[List[str]] = None,
@@ -270,12 +326,7 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
Args:
prompt (`str` or `List[str]`, *optional*):
prompt to be encoded
negative_prompt (`str` or `List[str]`, *optional*):
The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds`
instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For
PixArt-Alpha, this should be "".
do_classifier_free_guidance (`bool`, *optional*, defaults to `True`):
whether to use classifier free guidance or not
num_images_per_prompt (`int`, *optional*, defaults to 1):
number of images that should be generated per prompt
device: (`torch.device`, *optional*):
@@ -283,8 +334,6 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.Tensor`, *optional*):
Pre-generated negative text embeddings. For Sana, it's should be the embeddings of the "" string.
clean_caption (`bool`, defaults to `False`):
If `True`, the function will preprocess and clean the provided caption before encoding.
max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
@@ -296,6 +345,13 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if device is None:
device = self._execution_device
if self.transformer is not None:
dtype = self.transformer.dtype
elif self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
# set lora scale so that monkey patched LoRA
# function of text encoder can correctly access it
if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
@@ -320,43 +376,18 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
select_index = [0] + list(range(-max_length + 1, 0))
if prompt_embeds is None:
prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
# prepare complex human instruction
if not complex_human_instruction:
max_length_all = max_length
else:
chi_prompt = "\n".join(complex_human_instruction)
prompt = [chi_prompt + p for p in prompt]
num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
max_length_all = num_chi_prompt_tokens + max_length - 2
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=max_length_all,
truncation=True,
add_special_tokens=True,
return_tensors="pt",
prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
prompt=prompt,
device=device,
dtype=dtype,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
complex_human_instruction=complex_human_instruction,
)
text_input_ids = text_inputs.input_ids
prompt_attention_mask = text_inputs.attention_mask
prompt_attention_mask = prompt_attention_mask.to(device)
prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
prompt_embeds = prompt_embeds[0][:, select_index]
prompt_embeds = prompt_embeds[:, select_index]
prompt_attention_mask = prompt_attention_mask[:, select_index]
if self.transformer is not None:
dtype = self.transformer.dtype
elif self.text_encoder is not None:
dtype = self.text_encoder.dtype
else:
dtype = None
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
bs_embed, seq_len, _ = prompt_embeds.shape
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -364,49 +395,13 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
# get unconditional embeddings for classifier free guidance
if do_classifier_free_guidance and negative_prompt_embeds is None:
uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
max_length = prompt_embeds.shape[1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_attention_mask=True,
add_special_tokens=True,
return_tensors="pt",
)
negative_prompt_attention_mask = uncond_input.attention_mask
negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
negative_prompt_embeds = self.text_encoder(
uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
)
negative_prompt_embeds = negative_prompt_embeds[0]
if do_classifier_free_guidance:
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = negative_prompt_embeds.shape[1]
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
else:
negative_prompt_embeds = None
negative_prompt_attention_mask = None
if self.text_encoder is not None:
if isinstance(self, SanaLoraLoaderMixin) and USE_PEFT_BACKEND:
# Retrieve the original scale by scaling back the LoRA layers
unscale_lora_layers(self.text_encoder, lora_scale)
return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
return prompt_embeds, prompt_attention_mask
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
@@ -431,12 +426,13 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
prompt,
height,
width,
num_inference_steps,
timesteps,
max_timesteps,
intermediate_timesteps,
callback_on_step_end_tensor_inputs=None,
negative_prompt=None,
prompt_embeds=None,
negative_prompt_embeds=None,
prompt_attention_mask=None,
negative_prompt_attention_mask=None,
):
if height % 32 != 0 or width % 32 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 32 but are {height} and {width}.")
@@ -460,37 +456,21 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if negative_prompt is not None and negative_prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
)
if prompt_embeds is not None and prompt_attention_mask is None:
raise ValueError("Must provide `prompt_attention_mask` when specifying `prompt_embeds`.")
if negative_prompt_embeds is not None and negative_prompt_attention_mask is None:
raise ValueError("Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`.")
if timesteps is not None and len(timesteps) != num_inference_steps + 1:
raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
if timesteps is not None and max_timesteps is not None:
raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
if timesteps is None and max_timesteps is None:
raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
if intermediate_timesteps is not None and num_inference_steps != 2:
raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
if prompt_embeds is not None and negative_prompt_embeds is not None:
if prompt_embeds.shape != negative_prompt_embeds.shape:
raise ValueError(
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
f" {negative_prompt_embeds.shape}."
)
if prompt_attention_mask.shape != negative_prompt_attention_mask.shape:
raise ValueError(
"`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but"
f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`"
f" {negative_prompt_attention_mask.shape}."
)
# Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing
def _text_preprocessing(self, text, clean_caption=False):
@@ -632,6 +612,7 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
return caption.strip()
# Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents
def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
if latents is not None:
return latents.to(device=device, dtype=dtype)
@@ -659,10 +640,6 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
def attention_kwargs(self):
return self._attention_kwargs
@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0
@property
def num_timesteps(self):
return self._num_timesteps
@@ -676,10 +653,10 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
num_inference_steps: int = 20,
num_inference_steps: int = 2,
timesteps: List[int] = None,
sigmas: List[float] = None,
max_timesteps: float = 1.57080,
intermediate_timesteps: float = 1.3,
guidance_scale: float = 4.5,
num_images_per_prompt: Optional[int] = 1,
height: int = 1024,
@@ -689,8 +666,6 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.Tensor] = None,
prompt_attention_mask: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
clean_caption: bool = False,
@@ -724,14 +699,14 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
num_inference_steps (`int`, *optional*, defaults to 20):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
max_timesteps (`float`, *optional*, defaults to 1.57080):
The maximum timestep value used in the SCM scheduler.
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
sigmas (`List[float]`, *optional*):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 4.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -822,15 +797,16 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
self.check_inputs(
prompt,
height,
width,
callback_on_step_end_tensor_inputs,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
prompt_attention_mask,
negative_prompt_attention_mask,
prompt=prompt,
height=height,
width=width,
num_inference_steps=num_inference_steps,
timesteps=timesteps,
max_timesteps=max_timesteps,
intermediate_timesteps=intermediate_timesteps,
callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
prompt_embeds=prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
)
self._guidance_scale = guidance_scale
@@ -852,29 +828,24 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
(
prompt_embeds,
prompt_attention_mask,
_,
_,
) = self.encode_prompt(
prompt,
False,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
clean_caption=clean_caption,
max_sequence_length=max_sequence_length,
complex_human_instruction=complex_human_instruction,
lora_scale=lora_scale,
)
# prompt_embeds = torch.load("/raid/yiyi/Sana-Sprint-diffusers/y.pt").to(device, dtype=prompt_embeds.dtype)
# 4. Prepare timesteps
timesteps, num_inference_steps = retrieve_timesteps(
self.scheduler, num_inference_steps, device, timesteps, sigmas
self.scheduler, num_inference_steps, device, timesteps, sigmas=None, max_timesteps=max_timesteps, intermediate_timesteps=intermediate_timesteps
)
if hasattr(self.scheduler, "set_begin_index"):
self.scheduler.set_begin_index(0)
# 5. Prepare latents.
latent_channels = self.transformer.config.in_channels
@@ -889,14 +860,11 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
latents,
)
# latents = torch.load("/raid/yiyi/Sana-Sprint-diffusers/latents.pt").to(device, dtype=latents.dtype)
latents = latents * self.scheduler.config.sigma_data
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0]).to(prompt_embeds.dtype)
# YiYi TODO: cfg_embed_scale = 0.1 (refactor this out)
guidance = guidance * 0.1
guidance = guidance * self.transformer.config.guidance_embeds_scale
# 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
@@ -915,11 +883,8 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latents.shape[0]).to(prompt_embeds.dtype)
# YiYi TODO: self.scheduler.scale_model_input?
latents_model_input = latents / self.scheduler.config.sigma_data
# YiYi TODO: refator this out
scm_timestep = torch.sin(timestep) / (torch.cos(timestep) + torch.sin(timestep))
latent_model_input = latents_model_input * torch.sqrt(scm_timestep**2 + (1 - scm_timestep) ** 2)
latent_model_input = latent_model_input.to(prompt_embeds.dtype)
@@ -935,17 +900,15 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
attention_kwargs=self.attention_kwargs,
)[0]
# YiYi TODO: refator this out
noise_pred = (
(1 - 2 * scm_timestep) * latent_model_input
+ (1 - 2 * scm_timestep + 2 * scm_timestep**2) * noise_pred
) / torch.sqrt(scm_timestep**2 + (1 - scm_timestep) ** 2)
# YiYi TODO: check if this can be refatored into scheduler
noise_pred = noise_pred.float() * self.scheduler.config.sigma_data
# compute previous image: x_t -> x_t-1
latents, denoised = self.scheduler.step(
noise_pred, i, timestep, latents, **extra_step_kwargs, return_dict=False
noise_pred, timestep, latents, **extra_step_kwargs, return_dict=False
)
if callback_on_step_end is not None:
@@ -965,7 +928,6 @@ class SanaSprintPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
if XLA_AVAILABLE:
xm.mark_step()
# YiYi TODO: refator this out
latents = denoised / self.scheduler.config.sigma_data
if output_type == "latent":
image = latents

View File

@@ -75,8 +75,6 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
self,
num_train_timesteps: int = 1000,
prediction_type: str = "trigflow",
max_timesteps: float = 1.57080,
intermediate_timesteps: Optional[float] = 1.3,
sigma_data: float = 0.5,
):
"""
@@ -87,10 +85,6 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
The number of diffusion steps to train the model.
prediction_type (`str`, defaults to `trigflow`):
Prediction type of the scheduler function. Currently only supports "trigflow".
max_timesteps (`float`, defaults to 1.57080):
The maximum timestep value used in the diffusion process.
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
The intermediate timestep value used when num_inference_steps=2.
sigma_data (`float`, defaults to 0.5):
The standard deviation of the noise added during multi-step inference.
"""
@@ -101,11 +95,35 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
self._step_index = None
self._begin_index = None
@property
def step_index(self):
return self._step_index
@property
def begin_index(self):
return self._begin_index
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
"""
Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
Args:
begin_index (`int`):
The begin index for the scheduler.
"""
self._begin_index = begin_index
def set_timesteps(
self,
num_inference_steps: int,
timesteps: torch.Tensor = None,
device: Union[str, torch.device] = None,
max_timesteps: float = 1.57080,
intermediate_timesteps: float = 1.3,
):
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -113,6 +131,12 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
Args:
num_inference_steps (`int`):
The number of diffusion steps used when generating samples with a pre-trained model.
timesteps (`torch.Tensor`, *optional*):
Custom timesteps to use for the denoising process.
max_timesteps (`float`, defaults to 1.57080):
The maximum timestep value used in the SCM scheduler.
intermediate_timesteps (`float`, *optional*, defaults to 1.3):
The intermediate timestep value used in SCM scheduler (only used when num_inference_steps=2).
"""
if num_inference_steps > self.config.num_train_timesteps:
raise ValueError(
@@ -121,39 +145,68 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
f" maximal {self.config.num_train_timesteps} timesteps."
)
if timesteps is not None and len(timesteps) != num_inference_steps + 1:
raise ValueError("If providing custom timesteps, `timesteps` must be of length `num_inference_steps + 1`.")
if timesteps is not None and max_timesteps is not None:
raise ValueError("If providing custom timesteps, `max_timesteps` should not be provided.")
if timesteps is None and max_timesteps is None:
raise ValueError("Should provide either `timesteps` or `max_timesteps`.")
if intermediate_timesteps is not None and num_inference_steps != 2:
raise ValueError("Intermediate timesteps for SCM is not supported when num_inference_steps != 2.")
self.num_inference_steps = num_inference_steps
if timesteps is not None and len(timesteps) == num_inference_steps + 1:
if timesteps is not None:
if isinstance(timesteps, list):
self.timesteps = torch.tensor(timesteps, device=device).float()
elif isinstance(timesteps, torch.Tensor):
self.timesteps = timesteps.to(device).float()
else:
raise ValueError(f"Unsupported timesteps type: {type(timesteps)}")
elif self.config.intermediate_timesteps and num_inference_steps == 2:
elif intermediate_timesteps is not None:
self.timesteps = torch.tensor(
[self.config.max_timesteps, self.config.intermediate_timesteps, 0], device=device
[max_timesteps, intermediate_timesteps, 0], device=device
).float()
elif self.config.intermediate_timesteps:
self.timesteps = torch.linspace(
self.config.max_timesteps, 0, num_inference_steps + 1, device=device
).float()
logger.warning(
f"Intermediate timesteps for SCM is not supported when num_inference_steps != 2. "
f"Reset timesteps to {self.timesteps} default max_timesteps"
)
else:
# max_timesteps=arctan(80/0.5)=1.56454 is the default from sCM paper, we choose a different value here
self.timesteps = torch.linspace(
self.config.max_timesteps, 0, num_inference_steps + 1, device=device
max_timesteps, 0, num_inference_steps + 1, device=device
).float()
print(f"Set timesteps: {self.timesteps}")
self._step_index = None
self._begin_index = None
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
if self.begin_index is None:
if isinstance(timestep, torch.Tensor):
timestep = timestep.to(self.timesteps.device)
self._step_index = self.index_for_timestep(timestep)
else:
self._step_index = self._begin_index
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
def index_for_timestep(self, timestep, schedule_timesteps=None):
if schedule_timesteps is None:
schedule_timesteps = self.timesteps
indices = (schedule_timesteps == timestep).nonzero()
# The sigma index that is taken for the **very** first `step`
# is always the second index (or the last index if there is only 1)
# This way we can ensure we don't accidentally skip a sigma in
# case we start in the middle of the denoising schedule (e.g. for image-to-image)
pos = 1 if len(indices) > 1 else 0
return indices[pos].item()
def step(
self,
model_output: torch.FloatTensor,
timeindex: int,
timestep: float,
sample: torch.FloatTensor,
generator: torch.Generator = None,
@@ -183,10 +236,13 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
raise ValueError(
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
)
if self.step_index is None:
self._init_step_index(timestep)
# 2. compute alphas, betas
t = self.timesteps[timeindex + 1]
s = self.timesteps[timeindex]
t = self.timesteps[self.step_index + 1]
s = self.timesteps[self.step_index]
# 4. Different Parameterization:
parameterization = self.config.prediction_type
@@ -206,6 +262,8 @@ class SCMScheduler(SchedulerMixin, ConfigMixin):
prev_sample = torch.cos(t) * pred_x0 + torch.sin(t) * noise
else:
prev_sample = pred_x0
self._step_index += 1
if not return_dict:
return (prev_sample, pred_x0)