mixed inference
@@ -313,7 +313,7 @@ def create_vae_diffusers_config(vae_params):
     return config
 
 
-def create_diffusers_schedular(original_config):
+def create_diffusers_scheduler(original_config):
     schedular = DDIMScheduler(
         num_train_timesteps=original_config.model.params.timesteps,
         beta_start=original_config.model.params.linear_start,
@@ -323,16 +323,6 @@ def create_diffusers_schedular(original_config):
     return schedular
 
 
-def create_ldm_bert_config(original_config):
-    bert_params = original_config.model.parms.cond_stage_config.params
-    config = LDMBertConfig(
-        d_model=bert_params.n_embed,
-        encoder_layers=bert_params.n_layer,
-        encoder_ffn_dim=bert_params.n_embed * 4,
-    )
-    return config
-
-
 def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False):
     """
     Takes a state dict and a config, and returns a converted checkpoint.
@@ -503,7 +493,7 @@ def convert_vd_unet_checkpoint(checkpoint, config, unet_key, extract_ema=False):
     return new_checkpoint
 
 
-def convert_ldm_vae_checkpoint(checkpoint, config):
+def convert_vd_vae_checkpoint(checkpoint, config):
     # extract state dict for VAE
     vae_state_dict = {}
     keys = list(checkpoint.keys())
@@ -608,72 +598,6 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     return new_checkpoint
 
 
-def convert_ldm_bert_checkpoint(checkpoint, config):
-    def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
-        hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
-        hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
-        hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
-
-        hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
-        hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
-
-    def _copy_linear(hf_linear, pt_linear):
-        hf_linear.weight = pt_linear.weight
-        hf_linear.bias = pt_linear.bias
-
-    def _copy_layer(hf_layer, pt_layer):
-        # copy layer norms
-        _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
-        _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
-
-        # copy attn
-        _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
-
-        # copy MLP
-        pt_mlp = pt_layer[1][1]
-        _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
-        _copy_linear(hf_layer.fc2, pt_mlp.net[2])
-
-    def _copy_layers(hf_layers, pt_layers):
-        for i, hf_layer in enumerate(hf_layers):
-            if i != 0:
-                i += i
-            pt_layer = pt_layers[i : i + 2]
-            _copy_layer(hf_layer, pt_layer)
-
-    hf_model = LDMBertModel(config).eval()
-
-    # copy embeds
-    hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
-    hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
-
-    # copy layer norm
-    _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
-
-    # copy hidden layers
-    _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
-
-    _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
-
-    return hf_model
-
-
-def convert_ldm_clip_checkpoint(checkpoint):
-    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
-
-    keys = list(checkpoint.keys())
-
-    text_model_dict = {}
-
-    for key in keys:
-        if key.startswith("cond_stage_model.transformer"):
-            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
-
-    text_model.load_state_dict(text_model_dict)
-
-    return text_model
-
-
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
 
@@ -758,18 +682,18 @@ if __name__ == "__main__":
     image_unet = UNet2DConditionModel(**image_unet_config)
     image_unet.load_state_dict(converted_image_unet_checkpoint)
 
-    # text_unet_config = create_unet_diffusers_config(TEXT_UNET_CONFIG)
-    # converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
-    #     checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
-    # )
-    # text_unet = UNet2DConditionModel(**text_unet_config)
-    # text_unet.load_state_dict(converted_text_unet_checkpoint)
+    text_unet_config = create_unet_diffusers_config(TEXT_UNET_CONFIG)
+    converted_text_unet_checkpoint = convert_vd_unet_checkpoint(
+        checkpoint, text_unet_config, unet_key="model.diffusion_model.unet_text.", extract_ema=args.extract_ema
+    )
+    text_unet = UNet2DConditionModel(**text_unet_config)
+    text_unet.load_state_dict(converted_text_unet_checkpoint, strict=False)
 
     # Convert the VAE model.
     if args.vae_checkpoint_path is not None:
         vae_config = create_vae_diffusers_config(AUTOENCODER_CONFIG)
         checkpoint = torch.load(args.vae_checkpoint_path)
-        converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
+        converted_vae_checkpoint = convert_vd_vae_checkpoint(checkpoint, vae_config)
 
         vae = AutoencoderKL(**vae_config)
         vae.load_state_dict(converted_vae_checkpoint)
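Since the text UNet state dict is loaded with strict=False, it can be worth surfacing which parameters were actually matched. A minimal sketch of that check, assuming nothing beyond torch (the helper name is hypothetical):

from torch import nn


def load_and_report(model: nn.Module, state_dict: dict) -> None:
    # strict=False returns the keys it could not match instead of raising
    result = model.load_state_dict(state_dict, strict=False)
    if result.missing_keys:
        print(f"missing keys (left at their initialized values): {result.missing_keys}")
    if result.unexpected_keys:
        print(f"unexpected keys (ignored): {result.unexpected_keys}")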
@@ -786,7 +710,7 @@ if __name__ == "__main__":
         text_encoder=text_encoder,
         image_encoder=image_encoder,
         image_unet=image_unet,
-        # text_unet=text_unet,
+        text_unet=text_unet,
         vae=vae,
     )
     pipe.save_pretrained(args.dump_path)
@@ -21,7 +21,9 @@ import torch.utils.checkpoint
 
 from transformers import CLIPProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
 
-from ...models import AutoencoderKL, UNet2DConditionModel, UNet2DModel, VQModel
+from ...models import AutoencoderKL, UNet2DConditionModel, VQModel
+from ...models.unet_2d_condition import UNet2DConditionOutput
+from ...models.attention import Transformer2DModel
 from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
 
@@ -59,6 +61,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
         text_encoder: CLIPTextModel,
         image_encoder: CLIPVisionModel,
         image_unet: UNet2DConditionModel,
+        text_unet: UNet2DConditionModel,
         vae: Union[VQModel, AutoencoderKL],
         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
     ):
@@ -72,6 +75,18 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             vae=vae,
             scheduler=scheduler,
         )
+        self.image_transformer_blocks = {
+            name: module for name, module in image_unet.named_modules() if isinstance(module, Transformer2DModel)
+        }
+        self.text_transformer_blocks = {
+            name: module for name, module in text_unet.named_modules() if isinstance(module, Transformer2DModel)
+        }
+
+        # text2img by default: share the text UNet's transformer blocks with the image UNet
+        for full_name, module in image_unet.named_modules():
+            if isinstance(module, Transformer2DModel):
+                parent_name, name = full_name.rsplit(".", 1)
+                setattr(image_unet.get_submodule(parent_name), name, self.text_transformer_blocks[full_name])
 
     def _encode_prompt(self, prompt, do_classifier_free_guidance):
         r"""
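The loop above swaps modules in place by their dotted path. A small, self-contained sketch of the same mechanism on toy modules (class and attribute names here are made up for illustration):

from torch import nn


class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([nn.Linear(4, 4), nn.Linear(4, 4)])


a, b = Tiny(), Tiny()

# collect replacement modules keyed by the dotted paths that named_modules() reports
targets = {name: m for name, m in b.named_modules() if isinstance(m, nn.Linear)}

# graft each of b's linears into a at the same path
for full_name, m in list(a.named_modules()):
    if isinstance(m, nn.Linear):
        parent_name, name = full_name.rsplit(".", 1)
        setattr(a.get_submodule(parent_name), name, targets[full_name])

assert a.blocks[0] is b.blocks[0]  # a now runs b's layers at these positions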
@@ -85,8 +100,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
         """
 
         def _normalize_embeddings(encoder_output):
-            embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)
-            embeds_pooled = encoder_output.text_embeds
+            embeds = self.text_encoder.text_projection(encoder_output.last_hidden_state)  # sum == 19677.4570
+            embeds_pooled = encoder_output.text_embeds  # sum == 260.2655
             embeds = embeds / torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
             return embeds
 
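The normalization divides every token embedding by a single per-sample scalar, the L2 norm of the pooled text embedding. A quick sketch of the tensor shapes involved (the sizes are assumptions chosen to look CLIP-like, not read from the model):

import torch

batch, seq_len, dim = 2, 77, 768
embeds = torch.randn(batch, seq_len, dim)   # projected last_hidden_state
embeds_pooled = torch.randn(batch, dim)     # pooled text embedding

# (batch, 1, 1): one norm per sample, broadcast over sequence and feature dims
scale = torch.norm(embeds_pooled.unsqueeze(1), dim=-1, keepdim=True)
normalized = embeds / scale
print(normalized.shape)  # torch.Size([2, 77, 768])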
@@ -171,9 +186,8 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
 
         latents = torch.randn(
             (batch_size, self.image_unet.in_channels, height // 8, width // 8),
-            generator=generator,
+            generator=generator, device=self.device
         )
-        latents = latents.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
 
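Note that sampling directly on the device ties the latents to that device's RNG stream: a CPU generator and a CUDA generator seeded identically produce different noise, so results are reproducible per device rather than across devices. A short illustrative sketch:

import torch

shape = (1, 4, 64, 64)

# CPU generator: sampled on the host, then moved to the target device afterwards
latents_cpu = torch.randn(shape, generator=torch.manual_seed(0))

# device generator: sampled in place on the GPU, using a different RNG stream
if torch.cuda.is_available():
    gen = torch.Generator(device="cuda").manual_seed(0)
    latents_gpu = torch.randn(shape, generator=gen, device="cuda")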
@@ -185,6 +199,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             extra_kwargs["eta"] = eta
 
         for t in self.progress_bar(self.scheduler.timesteps):
+            t += 1
             if not do_classifier_free_guidance:
                 latents_input = latents
             else:
@@ -213,3 +228,115 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
             return (image,)
 
         return ImagePipelineOutput(images=image)
+
+
+class VDMixedModelWrapper(nn.Module):
+    def __init__(self, image_unet: UNet2DConditionModel, text_unet: UNet2DConditionModel):
+        super().__init__()
+        self.image_unet = image_unet
+        self.text_unet = text_unet
+        self.time_embedding = self.image_unet.time_embedding
+        self.time_proj = self.image_unet.time_proj
+
+    def embed_timesteps(self, timesteps, sample):
+        if not torch.is_tensor(timesteps):
+            timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device)
+        elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
+            timesteps = timesteps[None].to(sample.device)
+        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+        timesteps = timesteps.expand(sample.shape[0])
+        t_emb = self.time_proj(timesteps)
+        # timesteps does not contain any weights and will always return f32 tensors
+        # but time_embedding might actually be running in fp16, so we need to cast here.
+        # there might be better ways to encapsulate this.
+        t_emb = t_emb.to(dtype=sample.dtype)
+        emb = self.time_embedding(t_emb)
+        return emb
+
+    def forward(self, sample, timestep, encoder_hidden_states, latents_type="image", condition_type="text", return_dict: bool = True):
+        default_overall_up_factor = 2 ** self.image_unet.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            forward_upsample_size = True
+
+        # 1. time
+        emb = self.embed_timesteps(timestep, sample)
+
+        # 2. pre-process: route the latents through the UNet that matches their modality
+        if latents_type == "image":
+            unet = self.image_unet
+        elif latents_type == "text":
+            unet = self.text_unet
+        else:
+            raise ValueError(f"latents_type {latents_type} not supported")
+        sample = unet.conv_in(sample)
+
+        # 3. down
+        down_block_res_samples = (sample,)
+        for downsample_block in unet.down_blocks:
+            if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
+                sample, res_samples = downsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
+                )
+            else:
+                sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+            down_block_res_samples += res_samples
+
+        # 4. mid
+        sample = unet.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
+
+        # 5. up
+        for i, upsample_block in enumerate(unet.up_blocks):
+            is_final_block = i == len(unet.up_blocks) - 1
+
+            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
+            if hasattr(upsample_block, "attentions") and upsample_block.attentions is not None:
+                sample = upsample_block(
+                    hidden_states=sample,
+                    temb=emb,
+                    res_hidden_states_tuple=res_samples,
+                    encoder_hidden_states=encoder_hidden_states,
+                    upsample_size=upsample_size,
+                )
+            else:
+                sample = upsample_block(
+                    hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+                )
+
+        # 6. post-process
+        sample = unet.conv_norm_out(sample)
+        sample = unet.conv_act(sample)
+        sample = unet.conv_out(sample)
+
+        if not return_dict:
+            return (sample,)
+
+        return UNet2DConditionOutput(sample=sample)
+
+    def mixed_forward(self, image_module, text_module, hidden_state, timesteps_emb, condition, latents_type="image", condition_type="text"):
+        for ilayer, tlayer in zip(image_module, text_module):
+            if isinstance(ilayer, Transformer2DModel) and condition_type == "image":
+                hidden_state = ilayer(hidden_state, condition)
+            elif isinstance(ilayer, Transformer2DModel) and condition_type == "text":
+                hidden_state = tlayer(hidden_state, condition)
+            elif latents_type == "image":
+                hidden_state = ilayer(hidden_state)
+            elif latents_type == "text":
+                hidden_state = tlayer(hidden_state)
+            else:
+                raise ValueError(f"latents_type {latents_type} and condition_type {condition_type} not supported")
+        return hidden_state
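mixed_forward captures the core idea of mixed inference: walk two parallel stacks of layers and, at each cross-attention block, pick the stack that matches the conditioning modality, while every other layer follows the latents modality. A toy sketch of that routing with plain modules (Attn here is a hypothetical stand-in, not a diffusers class):

import torch
from torch import nn


class Attn(nn.Module):
    # stand-in for a cross-attention block that consumes a condition
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, condition):
        return self.proj(x) + condition.mean(dim=1, keepdim=True)


image_stack = nn.ModuleList([nn.Linear(8, 8), Attn(8)])
text_stack = nn.ModuleList([nn.Linear(8, 8), Attn(8)])


def mixed(hidden, condition, latents_type="image", condition_type="text"):
    for ilayer, tlayer in zip(image_stack, text_stack):
        if isinstance(ilayer, Attn):
            # attention layers follow the condition modality
            layer = ilayer if condition_type == "image" else tlayer
            hidden = layer(hidden, condition)
        else:
            # all other layers follow the latents modality
            layer = ilayer if latents_type == "image" else tlayer
            hidden = layer(hidden)
    return hidden


out = mixed(torch.randn(2, 4, 8), torch.randn(2, 7, 8))
print(out.shape)  # torch.Size([2, 4, 8])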
@@ -37,10 +37,10 @@ class VersatileDiffusionPipelineIntegrationTests(unittest.TestCase):
     def test_inference_text2img(self):
         pipe = VersatileDiffusionPipeline.from_pretrained("scripts/vd-diffusers")
         pipe.to(torch_device)
-        pipe.set_progress_bar_config(disable=None)
+        # pipe.set_progress_bar_config(disable=None)
 
-        prompt = "A painting of a squirrel eating a burger"
-        generator = torch.manual_seed(0)
+        prompt = "a dream of a village in china, by Caspar David Friedrich, matte painting trending on artstation HQ"
+        generator = torch.Generator(device=torch_device).manual_seed(0)
         image = pipe(
             [prompt], generator=generator, guidance_scale=7.5, num_inference_steps=50, output_type="numpy"
         ).images