Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
* propagate fixes from https://github.com/huggingface/diffusers/pull/11873/ to flux script

* propagate fixes from https://github.com/huggingface/diffusers/pull/11873/ to flux script

* propagate fixes from https://github.com/huggingface/diffusers/pull/11873/ to flux script

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
@@ -1399,6 +1399,7 @@ def main(args):
                 torch_dtype = torch.float16
             elif args.prior_generation_precision == "bf16":
                 torch_dtype = torch.bfloat16
             pipeline = FluxPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 torch_dtype=torch_dtype,
@@ -1419,7 +1420,8 @@ def main(args):
             for example in tqdm(
                 sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
             ):
-                images = pipeline(example["prompt"]).images
+                with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
+                    images = pipeline(prompt=example["prompt"]).images

                 for i, image in enumerate(images):
                     hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
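The hunk above wraps prior-preservation class-image generation in torch.autocast so inference runs in the same fp16/bf16 dtype the pipeline was loaded with. A minimal sketch of the pattern, assuming a CUDA device and using a placeholder model id and prompt list (the training script takes both from its command-line arguments):

    import torch
    from diffusers import FluxPipeline

    # Placeholder values; the training script reads these from argparse.
    model_id = "black-forest-labs/FLUX.1-dev"
    torch_dtype = torch.bfloat16
    class_prompts = ["a photo of a dog"] * 4

    pipeline = FluxPipeline.from_pretrained(model_id, torch_dtype=torch_dtype)
    pipeline.to("cuda")

    images = []
    for prompt in class_prompts:
        # Run inference under autocast so intermediate ops stay in the pipeline dtype.
        with torch.autocast(device_type="cuda", dtype=torch_dtype):
            images.extend(pipeline(prompt=prompt).images)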
@@ -1131,6 +1131,7 @@ def main(args):
                 torch_dtype = torch.float16
             elif args.prior_generation_precision == "bf16":
                 torch_dtype = torch.bfloat16
             pipeline = FluxPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 torch_dtype=torch_dtype,
@@ -1151,7 +1152,8 @@ def main(args):
             for example in tqdm(
                 sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
             ):
-                images = pipeline(example["prompt"]).images
+                with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
+                    images = pipeline(prompt=example["prompt"]).images

                 for i, image in enumerate(images):
                     hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
@@ -1159,8 +1161,7 @@ def main(args):
                     image.save(image_filename)

             del pipeline
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
+            free_memory()

     # Handle the repository creation
     if accelerator.is_main_process:
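The cleanup hunk drops the CUDA-only torch.cuda.empty_cache() branch in favor of free_memory() from diffusers.training_utils, which runs garbage collection and then clears the cache of whichever accelerator backend is active. A small sketch of the idiom, with a throwaway tensor standing in for the deleted pipeline:

    import torch
    from diffusers.training_utils import free_memory

    if torch.cuda.is_available():
        scratch = torch.randn(4096, 4096, device="cuda")  # stand-in allocation
        del scratch
        # gc.collect() plus the backend-appropriate empty_cache() (CUDA, MPS, XPU, ...).
        free_memory()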
@@ -1728,6 +1729,10 @@ def main(args):
             device=accelerator.device,
             prompt=args.instance_prompt,
         )
+    else:
+        prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
+            prompts, text_encoders, tokenizers
+        )

     # Convert images to latent space
     if args.cache_latents:
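The four lines added in the last hunk route the else branch through the script's compute_text_embeddings helper so the batch of prompts also gets precomputed embeddings. As a rough, hedged equivalent using the public FluxPipeline API instead of that script-internal helper (model id and prompts are placeholders):

    import torch
    from diffusers import FluxPipeline

    model_id = "black-forest-labs/FLUX.1-dev"  # placeholder
    prompts = ["a photo of sks dog sitting", "a photo of sks dog running"]  # placeholder

    pipeline = FluxPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
    pipeline.to("cuda")

    with torch.no_grad():
        # Runs the CLIP and T5 text encoders and returns the embeddings the
        # Flux transformer consumes during training.
        prompt_embeds, pooled_prompt_embeds, text_ids = pipeline.encode_prompt(
            prompt=prompts, prompt_2=None, max_sequence_length=512
        )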