Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Resolve bf16 error as mentioned in this [issue](https://github.com/huggingface/diffusers/issues/4139#issuecomment-1639977304) (#4214)
* resolve bf16 error
* resolve bf16 error
* resolve bf16 error
* resolve bf16 error
* resolve bf16 error
* resolve bf16 error
* resolve bf16 error
This commit is contained in:
@@ -277,4 +277,4 @@ To save even more memory, pass the `--set_grads_to_none` argument to the script.
|
||||
More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
|
||||
|
||||
## Experimental results
|
||||
You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail.
|
||||
You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail. We also released a more extensive dataset of 101 concepts for evaluating model customization methods. For more details please refer to our [dataset webpage](https://www.cs.cmu.edu/~custom-diffusion/dataset.html).
|
||||
@@ -57,7 +57,7 @@ def retrieve(class_prompt, class_data_dir, num_class_images):
|
||||
images = class_images[count]
|
||||
count += 1
|
||||
try:
|
||||
img = requests.get(images["url"])
|
||||
img = requests.get(images["url"], timeout=30)
|
||||
if img.status_code == 200:
|
||||
_ = Image.open(BytesIO(img.content))
|
||||
with open(f"{class_data_dir}/images/{total}.jpg", "wb") as f:
|
||||
|
||||
@@ -1210,6 +1210,7 @@ def main(args):
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
tokenizer=tokenizer,
|
||||
revision=args.revision,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
|
||||
@@ -667,9 +667,9 @@ class CustomDiffusionAttnProcessor(nn.Module):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
if self.train_q_out:
|
||||
query = self.to_q_custom_diffusion(hidden_states)
|
||||
query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
crossattn = False
|
||||
@@ -680,8 +680,10 @@ class CustomDiffusionAttnProcessor(nn.Module):
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
if self.train_kv:
|
||||
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
||||
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
||||
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
|
||||
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
|
||||
key = key.to(attn.to_q.weight.dtype)
|
||||
value = value.to(attn.to_q.weight.dtype)
|
||||
else:
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
@@ -1392,9 +1394,9 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if self.train_q_out:
|
||||
query = self.to_q_custom_diffusion(hidden_states)
|
||||
query = self.to_q_custom_diffusion(hidden_states).to(attn.to_q.weight.dtype)
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
query = attn.to_q(hidden_states.to(attn.to_q.weight.dtype))
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
crossattn = False
|
||||
@@ -1405,8 +1407,10 @@ class CustomDiffusionXFormersAttnProcessor(nn.Module):
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
if self.train_kv:
|
||||
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
||||
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
||||
key = self.to_k_custom_diffusion(encoder_hidden_states.to(self.to_k_custom_diffusion.weight.dtype))
|
||||
value = self.to_v_custom_diffusion(encoder_hidden_states.to(self.to_v_custom_diffusion.weight.dtype))
|
||||
key = key.to(attn.to_q.weight.dtype)
|
||||
value = value.to(attn.to_q.weight.dtype)
|
||||
else:
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user