diff --git a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
index 3ca456c6f4..3d8fd36708 100644
--- a/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
+++ b/src/diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py
@@ -98,6 +98,8 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             Position of the context token in the text encoder.
     """

+    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
+
     def __init__(
         self,
         tokenizer: CLIPTokenizer,
@@ -155,7 +157,9 @@ class BlipDiffusionPipeline(DiffusionPipeline):
         latents = latents * self.scheduler.init_noise_sigma
         return latents

-    def encode_prompt(self, query_embeds, prompt):
+    def encode_prompt(self, query_embeds, prompt, device=None):
+        device = device or self._execution_device
+
         # embeddings for prompt, with query_embeds as context
         max_len = self.text_encoder.text_model.config.max_position_embeddings
         max_len -= self.qformer.config.num_query_tokens
@@ -166,7 +170,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             truncation=True,
             max_length=max_len,
             return_tensors="pt",
-        ).to(self.device)
+        ).to(device)

         batch_size = query_embeds.shape[0]
         ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -249,11 +253,12 @@ class BlipDiffusionPipeline(DiffusionPipeline):
         Returns:
             [`~pipelines.ImagePipelineOutput`] or `tuple`
         """
+        device = self._execution_device
         reference_image = self.image_processor.preprocess(
             reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
         )["pixel_values"]
-        reference_image = reference_image.to(self.device)
+        reference_image = reference_image.to(device)

         if isinstance(prompt, str):
             prompt = [prompt]
@@ -271,7 +276,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             prompt_reps=prompt_reps,
         )
         query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
-        text_embeddings = self.encode_prompt(query_embeds, prompt)
+        text_embeddings = self.encode_prompt(query_embeds, prompt, device)
         do_classifier_free_guidance = guidance_scale > 1.0
         if do_classifier_free_guidance:
             max_length = self.text_encoder.text_model.config.max_position_embeddings
@@ -283,7 +288,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
                 return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(
-                input_ids=uncond_input.input_ids.to(self.device),
+                input_ids=uncond_input.input_ids.to(device),
                 ctx_embeddings=None,
             )[0]
             # For classifier free guidance, we need to do two forward passes.
@@ -300,7 +305,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             generator=generator,
             latents=latents,
             dtype=self.unet.dtype,
-            device=self.device,
+            device=device,
         )
         # set timesteps
         extra_set_kwargs = {}
@@ -330,9 +335,13 @@ class BlipDiffusionPipeline(DiffusionPipeline):
                 t,
                 latents,
             )["prev_sample"]
+
         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         image = self.image_processor.postprocess(image, output_type=output_type)

+        # Offload all models
+        self.maybe_free_model_hooks()
+
         if not return_dict:
             return (image,)

diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
index 4d0eb142e8..cdf0a1dcaa 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py
@@ -107,6 +107,8 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             Position of the context token in the text encoder.
     """

+    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
+
     def __init__(
         self,
         tokenizer: CLIPTokenizer,
@@ -166,7 +168,9 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         latents = latents * self.scheduler.init_noise_sigma
         return latents

-    def encode_prompt(self, query_embeds, prompt):
+    def encode_prompt(self, query_embeds, prompt, device=None):
+        device = device or self._execution_device
+
         # embeddings for prompt, with query_embeds as context
         max_len = self.text_encoder.text_model.config.max_position_embeddings
         max_len -= self.qformer.config.num_query_tokens
@@ -177,7 +181,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             truncation=True,
             max_length=max_len,
             return_tensors="pt",
-        ).to(self.device)
+        ).to(device)

         batch_size = query_embeds.shape[0]
         ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -297,11 +301,12 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         Returns:
             [`~pipelines.ImagePipelineOutput`] or `tuple`
         """
+        device = self._execution_device
         reference_image = self.image_processor.preprocess(
             reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
         )["pixel_values"]
-        reference_image = reference_image.to(self.device)
+        reference_image = reference_image.to(device)

         if isinstance(prompt, str):
             prompt = [prompt]
@@ -319,7 +324,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             prompt_reps=prompt_reps,
         )
         query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
-        text_embeddings = self.encode_prompt(query_embeds, prompt)
+        text_embeddings = self.encode_prompt(query_embeds, prompt, device)
         # 3. unconditional embedding
         do_classifier_free_guidance = guidance_scale > 1.0
         if do_classifier_free_guidance:
             max_length = self.text_encoder.text_model.config.max_position_embeddings
@@ -332,7 +337,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
                 return_tensors="pt",
             )
             uncond_embeddings = self.text_encoder(
-                input_ids=uncond_input.input_ids.to(self.device),
+                input_ids=uncond_input.input_ids.to(device),
                 ctx_embeddings=None,
             )[0]
             # For classifier free guidance, we need to do two forward passes.
@@ -348,7 +353,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             generator=generator,
             latents=latents,
             dtype=self.unet.dtype,
-            device=self.device,
+            device=device,
         )
         # set timesteps
         extra_set_kwargs = {}
@@ -399,6 +404,9 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         image = self.image_processor.postprocess(image, output_type=output_type)

+        # Offload all models
+        self.maybe_free_model_hooks()
+
         if not return_dict:
             return (image,)
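For reference, a minimal usage sketch of what these changes enable, not part of the diff itself: with `model_cpu_offload_seq` defined and `device` resolved via `self._execution_device`, the pipeline can be run with `enable_model_cpu_offload()` instead of `.to("cuda")`. The checkpoint id is the one used in the diffusers docs for BLIP Diffusion; the image path and prompt are placeholders.

```python
import torch

from diffusers import BlipDiffusionPipeline
from diffusers.utils import load_image

pipe = BlipDiffusionPipeline.from_pretrained(
    "Salesforce/blipdiffusion", torch_dtype=torch.float16
)
# Sequential model CPU offload driven by model_cpu_offload_seq;
# replaces pipe.to("cuda") and moves sub-models to the GPU only while they run.
pipe.enable_model_cpu_offload()

reference_image = load_image("dog.png")  # placeholder path to the source-subject image

image = pipe(
    "swimming underwater",  # prompt
    reference_image,        # reference image of the source subject
    "dog",                  # source_subject_category
    "dog",                  # target_subject_category
    guidance_scale=7.5,
    num_inference_steps=25,
).images[0]
```

Because `encode_prompt` and the tensor `.to(...)` calls now target `_execution_device` rather than `self.device`, inputs land on the accelerator even while the corresponding sub-model is still offloaded to the CPU, and `maybe_free_model_hooks()` releases the hooks after each call.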