From 5652c43f8398473702bfdc30b55700cd378b1e87 Mon Sep 17 00:00:00 2001 From: nupurkmr9 Date: Mon, 24 Jul 2023 17:11:19 -0700 Subject: [PATCH] Resolve bf16 error as mentioned in this [issue](https://github.com/huggingface/diffusers/issues/4139#issuecomment-1639977304) (#4214) * resolve bf16 error * resolve bf16 error * resolve bf16 error * resolve bf16 error * resolve bf16 error * resolve bf16 error * resolve bf16 error --- examples/custom_diffusion/README.md | 2 +- examples/custom_diffusion/retrieve.py | 2 +- .../train_custom_diffusion.py | 1 + src/diffusers/models/attention_processor.py | 20 +++++++++++-------- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/examples/custom_diffusion/README.md b/examples/custom_diffusion/README.md index ecd972737b..9e3c387e3d 100644 --- a/examples/custom_diffusion/README.md +++ b/examples/custom_diffusion/README.md @@ -277,4 +277,4 @@ To save even more memory, pass the `--set_grads_to_none` argument to the script. More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html ## Experimental results -You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail. \ No newline at end of file +You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail. We also released a more extensive dataset of 101 concepts for evaluating model customization methods. For more details please refer to our [dataset webpage](https://www.cs.cmu.edu/~custom-diffusion/dataset.html). \ No newline at end of file diff --git a/examples/custom_diffusion/retrieve.py b/examples/custom_diffusion/retrieve.py index 7b7635c188..6f050c1522 100644 --- a/examples/custom_diffusion/retrieve.py +++ b/examples/custom_diffusion/retrieve.py @@ -57,7 +57,7 @@ def retrieve(class_prompt, class_data_dir, num_class_images): images = class_images[count] count += 1 try: - img = requests.get(images["url"]) + img = requests.get(images["url"], timeout=30) if img.status_code == 200: _ = Image.open(BytesIO(img.content)) with open(f"{class_data_dir}/images/{total}.jpg", "wb") as f: diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index 995dbdca48..e710921b89 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -1210,6 +1210,7 @@ def main(args): text_encoder=accelerator.unwrap_model(text_encoder), tokenizer=tokenizer, revision=args.revision, + torch_dtype=weight_dtype, ) pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) pipeline = pipeline.to(accelerator.device) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 69889337ed..8468c864fe 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -667,9 +667,9 @@ class CustomDiffusionAttnProcessor(nn.Module): batch_size, sequence_length, _ = hidden_states.shape attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if self.train_q_out: - query = self.to_q_custom_diffusion(hidden_states) + query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype) else: - query = attn.to_q(hidden_states) + query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype)) if encoder_hidden_states is None: crossattn = False @@ -680,8 +680,10 @@ class CustomDiffusionAttnProcessor(nn.Module): 
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) if self.train_kv: - key = self.to_k_custom_diffusion(encoder_hidden_states) - value = self.to_v_custom_diffusion(encoder_hidden_states) + key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype)) + value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype)) + key = key.to(attn.to_q.weight.dtype) + value = value.to(attn.to_q.weight.dtype) else: key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states) @@ -1392,9 +1394,9 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module): attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) if self.train_q_out: - query = self.to_q_custom_diffusion(hidden_states) + query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype) else: - query = attn.to_q(hidden_states) + query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype)) if encoder_hidden_states is None: crossattn = False @@ -1405,8 +1407,10 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module): encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) if self.train_kv: - key = self.to_k_custom_diffusion(encoder_hidden_states) - value = self.to_v_custom_diffusion(encoder_hidden_states) + key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype)) + value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype)) + key = key.to(attn.to_q.weight.dtype) + value = value.to(attn.to_q.weight.dtype) else: key = attn.to_k(encoder_hidden_states) value = attn.to_v(encoder_hidden_states)
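
Note: below is a minimal standalone sketch (not the diffusers code itself) of the casting pattern this patch applies, under the assumption that the trainable custom-diffusion projections are kept in fp32 while the frozen base attention weights run in bf16; the names `to_q`, `to_k_custom`, and the tensor shapes are illustrative only.

```python
import torch
import torch.nn as nn

hidden = torch.randn(2, 77, 320, dtype=torch.bfloat16)  # bf16 activations from the UNet

to_q = nn.Linear(320, 320).to(torch.bfloat16)  # frozen base projection in bf16
to_k_custom = nn.Linear(320, 320)              # trainable custom projection kept in fp32

# Without the casts, to_k_custom(hidden) raises a dtype-mismatch RuntimeError,
# which is the bf16 error the patch resolves.
key = to_k_custom(hidden.to(to_k_custom.weight.dtype))  # cast input to the layer's own dtype
key = key.to(to_q.weight.dtype)                         # cast output back to the base dtype

query = to_q(hidden.to(to_q.weight.dtype))
print(query.dtype, key.dtype)  # both torch.bfloat16
```

Casting the key/value outputs back to `attn.to_q.weight.dtype` keeps the subsequent attention matmuls in a single dtype, so the query produced by the frozen bf16 layers and the key/value produced by the fp32 custom layers can be combined without a further mismatch.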