1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Qwen-Image] adding validation for guidance_scale, true_cfg_scale and negative_prompt (#12223)

* up
This commit is contained in:
YiYi Xu
2025-08-27 01:04:33 -10:00
committed by GitHub
parent 552c127c05
commit 865ba102b3
5 changed files with 180 additions and 70 deletions

View File

@@ -435,7 +435,7 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 1.0,
guidance_scale: Optional[float] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -462,7 +462,12 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
not greater than `1`).
true_cfg_scale (`float`, *optional*, defaults to 1.0):
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to
generate images that are closely linked to the text `prompt`, usually at the expense of lower image
quality.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -474,17 +479,16 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
This parameter in the pipeline is there to support future guidance-distilled models when they come up.
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
enable classifier-free guidance computations.
guidance_scale (`float`, *optional*, defaults to None):
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
where the guidance scale is applied during inference through noise prediction rescaling, guidance
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
enable classifier-free guidance computations).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -564,6 +568,16 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
)
elif true_cfg_scale <= 1 and has_neg_prompt:
logger.warning(
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
@@ -618,10 +632,17 @@ class QwenImagePipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
if self.transformer.config.guidance_embeds and guidance_scale is None:
raise ValueError("guidance_scale is required for guidance-distilled model.")
elif self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
logger.warning(
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
)
guidance = None
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
guidance = None
if self.attention_kwargs is None:

View File

@@ -535,7 +535,7 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 1.0,
guidance_scale: Optional[float] = None,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
control_image: PipelineImageInput = None,
@@ -566,7 +566,12 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
not greater than `1`).
true_cfg_scale (`float`, *optional*, defaults to 1.0):
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to
generate images that are closely linked to the text `prompt`, usually at the expense of lower image
quality.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -578,12 +583,16 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
guidance_scale (`float`, *optional*, defaults to None):
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
where the guidance scale is applied during inference through noise prediction rescaling, guidance
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
enable classifier-free guidance computations).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -674,6 +683,16 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
)
elif true_cfg_scale <= 1 and has_neg_prompt:
logger.warning(
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
@@ -822,10 +841,17 @@ class QwenImageControlNetPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
controlnet_keep.append(keeps[0] if isinstance(self.controlnet, QwenImageControlNetModel) else keeps)
# handle guidance
if self.transformer.config.guidance_embeds:
if self.transformer.config.guidance_embeds and guidance_scale is None:
raise ValueError("guidance_scale is required for guidance-distilled model.")
elif self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
logger.warning(
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
)
guidance = None
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
guidance = None
if self.attention_kwargs is None:

View File

@@ -532,7 +532,7 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
width: Optional[int] = None,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 1.0,
guidance_scale: Optional[float] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -559,7 +559,12 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `true_cfg_scale` is
not greater than `1`).
true_cfg_scale (`float`, *optional*, defaults to 1.0):
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
true_cfg_scale (`float`, *optional*, defaults to 1.0): Guidance scale as defined in [Classifier-Free
Diffusion Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of
equation 2. of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is
enabled by setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale
encourages to generate images that are closely linked to the text `prompt`, usually at the expense of
lower image quality.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -571,17 +576,16 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
This parameter in the pipeline is there to support future guidance-distilled models when they come up.
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
enable classifier-free guidance computations.
guidance_scale (`float`, *optional*, defaults to None):
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
where the guidance scale is applied during inference through noise prediction rescaling, guidance
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
enable classifier-free guidance computations).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -672,6 +676,16 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
)
elif true_cfg_scale <= 1 and has_neg_prompt:
logger.warning(
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
image=prompt_image,
@@ -734,10 +748,17 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
if self.transformer.config.guidance_embeds and guidance_scale is None:
raise ValueError("guidance_scale is required for guidance-distilled model.")
elif self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
logger.warning(
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
)
guidance = None
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
guidance = None
if self.attention_kwargs is None:

View File

@@ -511,7 +511,7 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
strength: float = 0.6,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 1.0,
guidance_scale: Optional[float] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -544,7 +544,12 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
true_cfg_scale (`float`, *optional*, defaults to 1.0):
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to
generate images that are closely linked to the text `prompt`, usually at the expense of lower image
quality.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image. This is set to 1024 by default for the best results.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -562,17 +567,16 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
This parameter in the pipeline is there to support future guidance-distilled models when they come up.
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
enable classifier-free guidance computations.
guidance_scale (`float`, *optional*, defaults to None):
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
where the guidance scale is applied during inference through noise prediction rescaling, guidance
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
enable classifier-free guidance computations).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -657,6 +661,16 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
)
elif true_cfg_scale <= 1 and has_neg_prompt:
logger.warning(
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
@@ -721,10 +735,17 @@ class QwenImageImg2ImgPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
if self.transformer.config.guidance_embeds and guidance_scale is None:
raise ValueError("guidance_scale is required for guidance-distilled model.")
elif self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
logger.warning(
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
)
guidance = None
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
guidance = None
if self.attention_kwargs is None:

View File

@@ -624,7 +624,7 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
strength: float = 0.6,
num_inference_steps: int = 50,
sigmas: Optional[List[float]] = None,
guidance_scale: float = 1.0,
guidance_scale: Optional[float] = None,
num_images_per_prompt: int = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
@@ -657,7 +657,12 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
true_cfg_scale (`float`, *optional*, defaults to 1.0):
When > 1.0 and a provided `negative_prompt`, enables true classifier-free guidance.
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `true_cfg_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Classifier-free guidance is enabled by
setting `true_cfg_scale > 1` and a provided `negative_prompt`. Higher guidance scale encourages to
generate images that are closely linked to the text `prompt`, usually at the expense of lower image
quality.
mask_image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
@@ -692,17 +697,16 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
will be used.
guidance_scale (`float`, *optional*, defaults to 3.5):
Guidance scale as defined in [Classifier-Free Diffusion
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
the text `prompt`, usually at the expense of lower image quality.
This parameter in the pipeline is there to support future guidance-distilled models when they come up.
Note that passing `guidance_scale` to the pipeline is ineffective. To enable classifier-free guidance,
please pass `true_cfg_scale` and `negative_prompt` (even an empty negative prompt like " ") should
enable classifier-free guidance computations.
guidance_scale (`float`, *optional*, defaults to None):
A guidance scale value for guidance distilled models. Unlike the traditional classifier-free guidance
where the guidance scale is applied during inference through noise prediction rescaling, guidance
distilled models take the guidance scale directly as an input parameter during forward pass. Guidance
scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images
that are closely linked to the text `prompt`, usually at the expense of lower image quality. This
parameter in the pipeline is there to support future guidance-distilled models when they come up. It is
ignored when not using guidance distilled models. To enable traditional classifier-free guidance,
please pass `true_cfg_scale > 1.0` and `negative_prompt` (even an empty negative prompt like " " should
enable classifier-free guidance computations).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -801,6 +805,16 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
has_neg_prompt = negative_prompt is not None or (
negative_prompt_embeds is not None and negative_prompt_embeds_mask is not None
)
if true_cfg_scale > 1 and not has_neg_prompt:
logger.warning(
f"true_cfg_scale is passed as {true_cfg_scale}, but classifier-free guidance is not enabled since no negative_prompt is provided."
)
elif true_cfg_scale <= 1 and has_neg_prompt:
logger.warning(
" negative_prompt is passed but classifier-free guidance is not enabled since true_cfg_scale <= 1"
)
do_true_cfg = true_cfg_scale > 1 and has_neg_prompt
prompt_embeds, prompt_embeds_mask = self.encode_prompt(
prompt=prompt,
@@ -890,10 +904,17 @@ class QwenImageInpaintPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
self._num_timesteps = len(timesteps)
# handle guidance
if self.transformer.config.guidance_embeds:
if self.transformer.config.guidance_embeds and guidance_scale is None:
raise ValueError("guidance_scale is required for guidance-distilled model.")
elif self.transformer.config.guidance_embeds:
guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
guidance = guidance.expand(latents.shape[0])
else:
elif not self.transformer.config.guidance_embeds and guidance_scale is not None:
logger.warning(
f"guidance_scale is passed as {guidance_scale}, but ignored since the model is not guidance-distilled."
)
guidance = None
elif not self.transformer.config.guidance_embeds and guidance_scale is None:
guidance = None
if self.attention_kwargs is None: