1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Model CPU offload fix for BLIPDiffusion (#5174)

cpu offload fix for blip diffusion
This commit is contained in:
Dhruv Nair
2023-09-25 17:07:32 +05:30
committed by GitHub
parent 22b19d578e
commit 92f15f5bd4
2 changed files with 29 additions and 12 deletions

View File

@@ -98,6 +98,8 @@ class BlipDiffusionPipeline(DiffusionPipeline):
Position of the context token in the text encoder.
"""
model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
def __init__(
self,
tokenizer: CLIPTokenizer,
@@ -155,7 +157,9 @@ class BlipDiffusionPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
def encode_prompt(self, query_embeds, prompt):
def encode_prompt(self, query_embeds, prompt, device=None):
device = device or self._execution_device
# embeddings for prompt, with query_embeds as context
max_len = self.text_encoder.text_model.config.max_position_embeddings
max_len -= self.qformer.config.num_query_tokens
@@ -166,7 +170,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
truncation=True,
max_length=max_len,
return_tensors="pt",
).to(self.device)
).to(device)
batch_size = query_embeds.shape[0]
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -249,11 +253,12 @@ class BlipDiffusionPipeline(DiffusionPipeline):
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
device = self._execution_device
reference_image = self.image_processor.preprocess(
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
)["pixel_values"]
reference_image = reference_image.to(self.device)
reference_image = reference_image.to(device)
if isinstance(prompt, str):
prompt = [prompt]
@@ -271,7 +276,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
prompt_reps=prompt_reps,
)
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
text_embeddings = self.encode_prompt(query_embeds, prompt)
text_embeddings = self.encode_prompt(query_embeds, prompt, device)
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
max_length = self.text_encoder.text_model.config.max_position_embeddings
@@ -283,7 +288,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(self.device),
input_ids=uncond_input.input_ids.to(device),
ctx_embeddings=None,
)[0]
# For classifier free guidance, we need to do two forward passes.
@@ -300,7 +305,7 @@ class BlipDiffusionPipeline(DiffusionPipeline):
generator=generator,
latents=latents,
dtype=self.unet.dtype,
device=self.device,
device=device,
)
# set timesteps
extra_set_kwargs = {}
@@ -330,9 +335,13 @@ class BlipDiffusionPipeline(DiffusionPipeline):
t,
latents,
)["prev_sample"]
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)

View File

@@ -107,6 +107,8 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
Position of the context token in the text encoder.
"""
model_cpu_offload_seq = "qformer->text_encoder->unet->vae"
def __init__(
self,
tokenizer: CLIPTokenizer,
@@ -166,7 +168,9 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
latents = latents * self.scheduler.init_noise_sigma
return latents
def encode_prompt(self, query_embeds, prompt):
def encode_prompt(self, query_embeds, prompt, device=None):
device = device or self._execution_device
# embeddings for prompt, with query_embeds as context
max_len = self.text_encoder.text_model.config.max_position_embeddings
max_len -= self.qformer.config.num_query_tokens
@@ -177,7 +181,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
truncation=True,
max_length=max_len,
return_tensors="pt",
).to(self.device)
).to(device)
batch_size = query_embeds.shape[0]
ctx_begin_pos = [self.config.ctx_begin_pos] * batch_size
@@ -297,11 +301,12 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
Returns:
[`~pipelines.ImagePipelineOutput`] or `tuple`
"""
device = self._execution_device
reference_image = self.image_processor.preprocess(
reference_image, image_mean=self.config.mean, image_std=self.config.std, return_tensors="pt"
)["pixel_values"]
reference_image = reference_image.to(self.device)
reference_image = reference_image.to(device)
if isinstance(prompt, str):
prompt = [prompt]
@@ -319,7 +324,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
prompt_reps=prompt_reps,
)
query_embeds = self.get_query_embeddings(reference_image, source_subject_category)
text_embeddings = self.encode_prompt(query_embeds, prompt)
text_embeddings = self.encode_prompt(query_embeds, prompt, device)
# 3. unconditional embedding
do_classifier_free_guidance = guidance_scale > 1.0
if do_classifier_free_guidance:
@@ -332,7 +337,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
return_tensors="pt",
)
uncond_embeddings = self.text_encoder(
input_ids=uncond_input.input_ids.to(self.device),
input_ids=uncond_input.input_ids.to(device),
ctx_embeddings=None,
)[0]
# For classifier free guidance, we need to do two forward passes.
@@ -348,7 +353,7 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
generator=generator,
latents=latents,
dtype=self.unet.dtype,
device=self.device,
device=device,
)
# set timesteps
extra_set_kwargs = {}
@@ -399,6 +404,9 @@ class BlipDiffusionControlNetPipeline(DiffusionPipeline):
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)