From dccf39f01ed5d22d3435e612121b27f2820b0f66 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?0x=E5=90=8D=E7=84=A1=E3=81=97?= Date: Tue, 15 Oct 2024 17:18:13 +0530 Subject: [PATCH] Dreambooth lora flux bug 3dtensor to 2dtensor (#9653) * fixed issue #9350, Tensor is deprecated * ran make style --- examples/dreambooth/train_dreambooth_lora_flux.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index fcc11386ab..11cba745cc 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -985,7 +985,6 @@ def encode_prompt( text_input_ids_list=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) dtype = text_encoders[0].dtype pooled_prompt_embeds = _encode_prompt_with_clip( @@ -1007,8 +1006,7 @@ def encode_prompt( text_input_ids=text_input_ids_list[1] if text_input_ids_list else None, ) - text_ids = torch.zeros(batch_size, prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - text_ids = text_ids.repeat(num_images_per_prompt, 1, 1) + text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) return prompt_embeds, pooled_prompt_embeds, text_ids