diff --git a/src/diffusers/pipelines/paint_by_example/image_encoder.py b/src/diffusers/pipelines/paint_by_example/image_encoder.py
index f79e79266e..e83f638c60 100644
--- a/src/diffusers/pipelines/paint_by_example/image_encoder.py
+++ b/src/diffusers/pipelines/paint_by_example/image_encoder.py
@@ -36,12 +36,15 @@ class PaintByExampleImageEncoder(CLIPPreTrainedModel):
         # uncondition for scaling
         self.uncond_vector = nn.Parameter(torch.randn((1, 1, self.proj_size)))
 
-    def forward(self, pixel_values):
+    def forward(self, pixel_values, return_uncond_vector=False):
         clip_output = self.model(pixel_values=pixel_values)
         latent_states = clip_output.pooler_output
         latent_states = self.mapper(latent_states[:, None])
         latent_states = self.final_layer_norm(latent_states)
         latent_states = self.proj_out(latent_states)
+        if return_uncond_vector:
+            return latent_states, self.uncond_vector
+
         return latent_states
 
 
diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
index 5b3ccba40c..1fb4521d2c 100644
--- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
+++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py
@@ -201,14 +201,11 @@ class PaintByExamplePipeline(DiffusionPipeline):
 
         device = torch.device(f"cuda:{gpu_id}")
 
-        for cpu_offloaded_model in [self.unet, self.vae]:
-            if cpu_offloaded_model is not None:
-                cpu_offload(cpu_offloaded_model, device)
+        for cpu_offloaded_model in [self.unet, self.vae, self.image_encoder]:
+            cpu_offload(cpu_offloaded_model, execution_device=device)
 
         if self.safety_checker is not None:
-            # TODO(Patrick) - there is currently a bug with cpu offload of nn.Parameter in accelerate
-            # fix by only offloading self.safety_checker for now
-            cpu_offload(self.safety_checker.vision_model, device)
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
 
     @property
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
@@ -367,7 +364,7 @@ class PaintByExamplePipeline(DiffusionPipeline):
             image = self.feature_extractor(images=image, return_tensors="pt").pixel_values
 
         image = image.to(device=device, dtype=dtype)
-        image_embeddings = self.image_encoder(image)
+        image_embeddings, uncond_embeddings = self.image_encoder(image, return_uncond_vector=True)
 
         # duplicate image embeddings for each generation per prompt, using mps friendly method
         bs_embed, seq_len, _ = image_embeddings.shape
@@ -375,7 +372,6 @@ class PaintByExamplePipeline(DiffusionPipeline):
         image_embeddings = image_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
 
         if do_classifier_free_guidance:
-            uncond_embeddings = self.image_encoder.uncond_vector
             uncond_embeddings = uncond_embeddings.repeat(1, image_embeddings.shape[0], 1)
             uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, 1, -1)
 
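
Usage note: a minimal standalone sketch of the new encoder call path introduced by this diff. The config values and dummy tensors below are illustrative assumptions for demonstration, not part of the change; a real encoder would come from `PaintByExamplePipeline.from_pretrained(...)`.

```python
import torch
from transformers import CLIPVisionConfig

from diffusers.pipelines.paint_by_example.image_encoder import PaintByExampleImageEncoder

# Illustrative, randomly initialized encoder (assumption: default CLIP vision
# config; a real checkpoint would be loaded via the pipeline instead).
config = CLIPVisionConfig()
encoder = PaintByExampleImageEncoder(config)

# Dummy batch of preprocessed example images: (batch, channels, height, width).
pixel_values = torch.randn(2, 3, config.image_size, config.image_size)

# New call path: a single forward pass returns both the conditional embeddings
# and the learned unconditional vector used for classifier-free guidance.
image_embeddings, uncond_embeddings = encoder(pixel_values, return_uncond_vector=True)

# The old call path is unchanged when the flag is left at its default.
image_embeddings_only = encoder(pixel_values)
```

Routing `uncond_vector` through `forward`, rather than reading `self.image_encoder.uncond_vector` from the pipeline as before, appears to be what makes the encoder safe to include in sequential CPU offload: accelerate's `cpu_offload` hooks only move weights onto the execution device around `forward` calls, so a parameter read directly from outside a forward pass would still live on the offload device.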