1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

fix: error on device for lpw_stable_diffusion_xl pipeline if pipe.enable_sequential_cpu_offload() enabled (#5885)

fix: set device for pipe.enable_sequential_cpu_offload()
This commit is contained in:
Viktor Grygorchuk
2023-11-27 14:47:47 +02:00
committed by GitHub
parent d72a24b790
commit 20f0cbc88f

View File

@@ -250,6 +250,7 @@ def get_weighted_text_embeddings_sdxl(
neg_prompt: str = "",
neg_prompt_2: str = None,
num_images_per_prompt: int = 1,
device: Optional[torch.device] = None,
):
"""
This function can process long prompt with weights, no length limitation
@@ -262,10 +263,13 @@ def get_weighted_text_embeddings_sdxl(
neg_prompt (str)
neg_prompt_2 (str)
num_images_per_prompt (int)
device (torch.device)
Returns:
prompt_embeds (torch.Tensor)
neg_prompt_embeds (torch.Tensor)
"""
device = device or pipe._execution_device
if prompt_2:
prompt = f"{prompt} {prompt_2}"
@@ -330,17 +334,17 @@ def get_weighted_text_embeddings_sdxl(
# get prompt embeddings one by one is not working.
for i in range(len(prompt_token_groups)):
# get positive prompt embeddings with weights
token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=device)
weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=device)
token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=device)
# use first text encoder
prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True)
prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True)
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
# use second text encoder
prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True)
prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True)
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
pooled_prompt_embeds = prompt_embeds_2[0]
@@ -357,16 +361,16 @@ def get_weighted_text_embeddings_sdxl(
embeds.append(token_embedding)
# get negative prompt embeddings with weights
neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=device)
neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=device)
neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=device)
# use first text encoder
neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True)
neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(device), output_hidden_states=True)
neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
# use second text encoder
neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True)
neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(device), output_hidden_states=True)
neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]