mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
fix: error on device for lpw_stable_diffusion_xl pipeline if pipe.enable_sequential_cpu_offload() enabled (#5885)
fix: set device for pipe.enable_sequential_cpu_offload()
This commit is contained in:
committed by
GitHub
parent
d72a24b790
commit
20f0cbc88f
@@ -250,6 +250,7 @@ def get_weighted_text_embeddings_sdxl(
|
||||
neg_prompt: str = "",
|
||||
neg_prompt_2: str = None,
|
||||
num_images_per_prompt: int = 1,
|
||||
device: Optional[torch.device] = None,
|
||||
):
|
||||
"""
|
||||
This function can process long prompt with weights, no length limitation
|
||||
@@ -262,10 +263,13 @@ def get_weighted_text_embeddings_sdxl(
|
||||
neg_prompt (str)
|
||||
neg_prompt_2 (str)
|
||||
num_images_per_prompt (int)
|
||||
device (torch.device)
|
||||
Returns:
|
||||
prompt_embeds (torch.Tensor)
|
||||
neg_prompt_embeds (torch.Tensor)
|
||||
"""
|
||||
device = device or pipe._execution_device
|
||||
|
||||
if prompt_2:
|
||||
prompt = f"{prompt} {prompt_2}"
|
||||
|
||||
@@ -330,17 +334,17 @@ def get_weighted_text_embeddings_sdxl(
|
||||
# get prompt embeddings one by one is not working.
|
||||
for i in range(len(prompt_token_groups)):
|
||||
# get positive prompt embeddings with weights
|
||||
token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
|
||||
weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
|
||||
token_tensor = torch.tensor([prompt_token_groups[i]], dtype=torch.long, device=device)
|
||||
weight_tensor = torch.tensor(prompt_weight_groups[i], dtype=torch.float16, device=device)
|
||||
|
||||
token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
|
||||
token_tensor_2 = torch.tensor([prompt_token_groups_2[i]], dtype=torch.long, device=device)
|
||||
|
||||
# use first text encoder
|
||||
prompt_embeds_1 = pipe.text_encoder(token_tensor.to(pipe.device), output_hidden_states=True)
|
||||
prompt_embeds_1 = pipe.text_encoder(token_tensor.to(device), output_hidden_states=True)
|
||||
prompt_embeds_1_hidden_states = prompt_embeds_1.hidden_states[-2]
|
||||
|
||||
# use second text encoder
|
||||
prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(pipe.device), output_hidden_states=True)
|
||||
prompt_embeds_2 = pipe.text_encoder_2(token_tensor_2.to(device), output_hidden_states=True)
|
||||
prompt_embeds_2_hidden_states = prompt_embeds_2.hidden_states[-2]
|
||||
pooled_prompt_embeds = prompt_embeds_2[0]
|
||||
|
||||
@@ -357,16 +361,16 @@ def get_weighted_text_embeddings_sdxl(
|
||||
embeds.append(token_embedding)
|
||||
|
||||
# get negative prompt embeddings with weights
|
||||
neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=pipe.device)
|
||||
neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=pipe.device)
|
||||
neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=pipe.device)
|
||||
neg_token_tensor = torch.tensor([neg_prompt_token_groups[i]], dtype=torch.long, device=device)
|
||||
neg_token_tensor_2 = torch.tensor([neg_prompt_token_groups_2[i]], dtype=torch.long, device=device)
|
||||
neg_weight_tensor = torch.tensor(neg_prompt_weight_groups[i], dtype=torch.float16, device=device)
|
||||
|
||||
# use first text encoder
|
||||
neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(pipe.device), output_hidden_states=True)
|
||||
neg_prompt_embeds_1 = pipe.text_encoder(neg_token_tensor.to(device), output_hidden_states=True)
|
||||
neg_prompt_embeds_1_hidden_states = neg_prompt_embeds_1.hidden_states[-2]
|
||||
|
||||
# use second text encoder
|
||||
neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(pipe.device), output_hidden_states=True)
|
||||
neg_prompt_embeds_2 = pipe.text_encoder_2(neg_token_tensor_2.to(device), output_hidden_states=True)
|
||||
neg_prompt_embeds_2_hidden_states = neg_prompt_embeds_2.hidden_states[-2]
|
||||
negative_pooled_prompt_embeds = neg_prompt_embeds_2[0]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user