Model CPU offload fix for BLIPDiffusion (#5174)
cpu offload fix for blip diffusion
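The commit makes `enable_model_cpu_offload()` usable with both BLIP-Diffusion pipelines: it registers a `model_cpu_offload_seq` and routes all tensor placement through the hook-aware `_execution_device` instead of the hard-coded `self.device`. A rough sketch of the usage this enables (the checkpoint id and dtype are illustrative, not part of the commit):

```python
import torch
from diffusers import BlipDiffusionPipeline

# Illustrative checkpoint; any BLIP-Diffusion checkpoint behaves the same.
pipe = BlipDiffusionPipeline.from_pretrained(
    "Salesforce/blipdiffusion", torch_dtype=torch.float16
)
# Moves qformer -> text_encoder -> unet -> vae onto the GPU one at a time.
# Before this fix, inputs were placed with self.device, which reports "cpu"
# once offload hooks are installed, causing device-mismatch errors.
pipe.enable_model_cpu_offload()
```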
@@ -98,6 +98,8 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             Position of the context token in the text encoder.
     """
 
+    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
+
     def __init__(
         self,
         tokenizer: CLIPTokenizer,
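The new `model_cpu_offload_seq` class attribute tells `DiffusionPipeline.enable_model_cpu_offload()` the order in which components should be shuttled to the accelerator. A minimal sketch of how such a sequence string can drive offload hooks, assuming `accelerate`'s `cpu_offload_with_hook` (the helper name here is hypothetical; diffusers' real implementation is more involved):

```python
import torch
from accelerate import cpu_offload_with_hook

def install_offload_hooks(pipe, device="cuda"):
    # Hypothetical helper: chain one hook per component, in sequence order,
    # so each model is offloaded back to CPU when its successor runs.
    hook = None
    for name in pipe.model_cpu_offload_seq.split("->"):
        component = getattr(pipe, name)
        _, hook = cpu_offload_with_hook(component, torch.device(device), prev_module_hook=hook)
    return hook
```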
@@ -155,7 +157,9 @@ class BlipDiffusionPipeline(DiffusionPipeline):
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    def encode_prompt(self, query_embeds, prompt):
+    def encode_prompt(self, query_embeds, prompt, device=None):
+        device = device or self._execution_device
+
         # embeddings for prompt, with query_embeds as context
         max_len = self.text_encoder.text_model.config.max_position_embeddings
         max_len -= self.qformer.config.num_query_tokens
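The `device = device or self._execution_device` fallback is the heart of the fix: with offload hooks installed, every module sits on the CPU between forward passes, so `self.device` reports `cpu` and tensors moved there never reach the GPU. `_execution_device` instead asks the hooks where computation will actually run. A simplified stand-in for that property (not diffusers' exact code):

```python
import torch

def execution_device(pipe):
    # Scan the unet's submodules for an accelerate hook that records the
    # device forward passes execute on; fall back to module placement.
    for module in pipe.unet.modules():
        hook = getattr(module, "_hf_hook", None)
        if hook is not None and getattr(hook, "execution_device", None) is not None:
            return torch.device(hook.execution_device)
    return pipe.device
```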
@@ -166,7 +170,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             truncation=True,
             max_length=max_len,
             return_tensors="pt",
-        ).to(self.device)
+        ).to(device)
 
         batch_size = query_embeds.shape[0]
         ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -249,11 +253,12 @@ class BlipDiffusionPipeline(DiffusionPipeline):
         Returns:
             [`~pipelines.ImagePipelineOutput`] or `tuple`
         """
+        device = self._execution_device
 
         reference_image = self.image_processor.preprocess(
             reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
         )["pixel_values"]
-        reference_image = reference_image.to(self.device)
+        reference_image = reference_image.to(device)
 
         if isinstance(prompt, str):
             prompt = [prompt]
@@ -271,7 +276,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             prompt_reps=prompt_reps,
         )
         query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
-        text_embeddings = self.encode_prompt(query_embeds, prompt)
+        text_embeddings = self.encode_prompt(query_embeds, prompt, device)
         do_classifier_free_guidance = guidance_scale > 1.0
         if do_classifier_free_guidance:
             max_length = self.text_encoder.text_model.config.max_position_embeddings
@@ -283,7 +288,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         uncond_embeddings = self.text_encoder(
-            input_ids=uncond_input.input_ids.to(self.device),
+            input_ids=uncond_input.input_ids.to(device),
             ctx_embeddings=None,
         )[0]
         # For classifier free guidance, we need to do two forward passes.
@@ -300,7 +305,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             generator=generator,
             latents=latents,
             dtype=self.unet.dtype,
-            device=self.device,
+            device=device,
         )
         # set timesteps
         extra_set_kwargs = {}
@@ -330,9 +335,13 @@ class BlipDiffusionPipeline(DiffusionPipeline):
             t,
             latents,
         )["prev_sample"]
 
         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
         if not return_dict:
             return (image,)
 
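Calling `maybe_free_model_hooks()` at the end of `__call__` returns every component to the CPU so a subsequent call starts from a clean offload state; it is a no-op when offloading was never enabled. Conceptually (a sketch, not the actual implementation):

```python
def free_model_hooks(pipe):
    # Hypothetical stand-in: offload hooks are only present after
    # enable_model_cpu_offload(); without them there is nothing to free.
    hooks = getattr(pipe, "_all_hooks", None) or []
    for hook in hooks:
        hook.offload()  # move the wrapped component back to CPU
```

The remaining hunks apply the identical changes to `BlipDiffusionControlNetPipeline`.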
@@ -107,6 +107,8 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             Position of the context token in the text encoder.
     """
 
+    model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
+
     def __init__(
         self,
         tokenizer: CLIPTokenizer,
@@ -166,7 +168,9 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         latents = latents * self.scheduler.init_noise_sigma
         return latents
 
-    def encode_prompt(self, query_embeds, prompt):
+    def encode_prompt(self, query_embeds, prompt, device=None):
+        device = device or self._execution_device
+
         # embeddings for prompt, with query_embeds as context
         max_len = self.text_encoder.text_model.config.max_position_embeddings
         max_len -= self.qformer.config.num_query_tokens
@@ -177,7 +181,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             truncation=True,
             max_length=max_len,
             return_tensors="pt",
-        ).to(self.device)
+        ).to(device)
 
         batch_size = query_embeds.shape[0]
         ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -297,11 +301,12 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         Returns:
             [`~pipelines.ImagePipelineOutput`] or `tuple`
         """
+        device = self._execution_device
 
         reference_image = self.image_processor.preprocess(
             reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
         )["pixel_values"]
-        reference_image = reference_image.to(self.device)
+        reference_image = reference_image.to(device)
 
         if isinstance(prompt, str):
             prompt = [prompt]
@@ -319,7 +324,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             prompt_reps=prompt_reps,
         )
         query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
-        text_embeddings = self.encode_prompt(query_embeds, prompt)
+        text_embeddings = self.encode_prompt(query_embeds, prompt, device)
         # 3. unconditional embedding
         do_classifier_free_guidance = guidance_scale > 1.0
         if do_classifier_free_guidance:
@@ -332,7 +337,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             return_tensors="pt",
         )
         uncond_embeddings = self.text_encoder(
-            input_ids=uncond_input.input_ids.to(self.device),
+            input_ids=uncond_input.input_ids.to(device),
             ctx_embeddings=None,
         )[0]
         # For classifier free guidance, we need to do two forward passes.
@@ -348,7 +353,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
             generator=generator,
             latents=latents,
             dtype=self.unet.dtype,
-            device=self.device,
+            device=device,
         )
         # set timesteps
         extra_set_kwargs = {}
@@ -399,6 +404,9 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
         image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
         image = self.image_processor.postprocess(image, output_type=output_type)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
 
         if not return_dict:
             return (image,)
 
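Since the ControlNet pipeline receives the same treatment, offloading works there as well. Illustrative usage (the checkpoint id is an example, not taken from this commit):

```python
import torch
from diffusers import BlipDiffusionControlNetPipeline

pipe = BlipDiffusionControlNetPipeline.from_pretrained(
    "Salesforce/blipdiffusion-controlnet", torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()  # honors model_cpu_offload_seq, as above
```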