diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index d9504e95c5..f4bb8267aa 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -553,13 +553,15 @@ class StableDiffusionInstructPix2PixPipeline( else: attention_mask = None - prompt_embeds = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, - ) + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + else: + prompt_embeds_dtype = self.unet.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method @@ -615,7 +617,7 @@ class StableDiffusionInstructPix2PixPipeline( # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)