mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -191,7 +191,6 @@ class FluxPipeline(
|
||||
transformer: FluxTransformer2DModel,
|
||||
image_encoder: CLIPVisionModelWithProjection = None,
|
||||
feature_extractor: CLIPImageProcessor = None,
|
||||
variant: str = "flux",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -214,17 +213,6 @@ class FluxPipeline(
|
||||
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
|
||||
)
|
||||
self.default_sample_size = 128
|
||||
if variant not in {"flux", "chroma"}:
|
||||
raise ValueError("`variant` must be `'flux' or `'chroma'`.")
|
||||
|
||||
self.variant = variant
|
||||
|
||||
def _get_chroma_attn_mask(self, length: torch.Tensor, max_sequence_length: int) -> torch.Tensor:
|
||||
attention_mask = torch.zeros((length.shape[0], max_sequence_length), dtype=torch.bool, device=length.device)
|
||||
for i, n_tokens in enumerate(length):
|
||||
n_tokens = torch.max(n_tokens + 1, max_sequence_length)
|
||||
attention_mask[i, :n_tokens] = True
|
||||
return attention_mask
|
||||
|
||||
def _get_t5_prompt_embeds(
|
||||
self,
|
||||
@@ -248,7 +236,7 @@ class FluxPipeline(
|
||||
padding="max_length",
|
||||
max_length=max_sequence_length,
|
||||
truncation=True,
|
||||
return_length=True,
|
||||
return_length=False,
|
||||
return_overflowing_tokens=False,
|
||||
return_tensors="pt",
|
||||
)
|
||||
@@ -262,10 +250,7 @@ class FluxPipeline(
|
||||
f" {max_sequence_length} tokens: {removed_text}"
|
||||
)
|
||||
|
||||
text_inputs.attention_mask[:, : text_inputs.length + 1] = 1.0
|
||||
prompt_embeds = self.text_encoder_2(
|
||||
text_input_ids.to(device), output_hidden_states=False, attention_mask=text_inputs.attention_mask.to(device)
|
||||
)[0]
|
||||
prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
|
||||
|
||||
dtype = self.text_encoder_2.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
@@ -702,11 +687,11 @@ class FluxPipeline(
|
||||
their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
|
||||
will be used.
|
||||
guidance_scale (`float`, *optional*, defaults to 3.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
Guidance scale as defined in [Classifier-Free Diffusion
|
||||
Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2.
|
||||
of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
|
||||
`guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
|
||||
the text `prompt`, usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
@@ -715,7 +700,7 @@ class FluxPipeline(
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will ge generated by sampling using the supplied random `generator`.
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
|
||||
Reference in New Issue
Block a user