
update dreambooth lora to work with IF stage II (#3560)

Author: Will Berman
Date: 2023-05-31 09:39:03 -07:00
Committed by: GitHub
Parent: abb89da4de
Commit: f751b8844e
4 changed files with 59 additions and 8 deletions

examples/dreambooth/train_dreambooth_lora.py

@@ -60,6 +60,7 @@ from diffusers.models.attention_processor import (
 from diffusers.optimization import get_scheduler
 from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
+from diffusers.utils.torch_utils import randn_tensor
 # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
@@ -425,6 +426,19 @@ def parse_args(input_args=None):
         required=False,
         help="Whether to use attention mask for the text encoder",
     )
+    parser.add_argument(
+        "--validation_images",
+        required=False,
+        default=None,
+        nargs="+",
+        help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
+    )
+    parser.add_argument(
+        "--class_labels_conditioning",
+        required=False,
+        default=None,
+        help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
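
Note: for illustration only, a minimal sketch (not part of the commit) of how the two new flags surface through the script's parse_args; every path, prompt, and model id below is a placeholder, and the other flags shown are the script's pre-existing required arguments:

    # Hypothetical invocation of parse_args() from train_dreambooth_lora.py.
    # Only --class_labels_conditioning and --validation_images are new in
    # this commit; all values are placeholders.
    args = parse_args(
        [
            "--pretrained_model_name_or_path", "DeepFloyd/IF-II-L-v1.0",
            "--instance_data_dir", "./instance_images",
            "--instance_prompt", "a photo of sks dog",
            "--class_labels_conditioning", "timesteps",
            "--validation_prompt", "a photo of sks dog",
            "--validation_images", "./val/0.png", "./val/1.png",
        ]
    )
    assert args.class_labels_conditioning == "timesteps"
    assert args.validation_images == ["./val/0.png", "./val/1.png"]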
@@ -1121,7 +1135,7 @@ def main(args):
                 # Sample noise that we'll add to the latents
                 noise = torch.randn_like(model_input)
-                bsz = model_input.shape[0]
+                bsz, channels, height, width = model_input.shape
                 # Sample a random timestep for each image
                 timesteps = torch.randint(
                     0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
@@ -1143,8 +1157,24 @@ def main(args):
                     text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
                 )
 
+                if unet.config.in_channels > channels:
+                    needed_additional_channels = unet.config.in_channels - channels
+                    additional_latents = randn_tensor(
+                        (bsz, needed_additional_channels, height, width),
+                        device=noisy_model_input.device,
+                        dtype=noisy_model_input.dtype,
+                    )
+                    noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
+
+                if args.class_labels_conditioning == "timesteps":
+                    class_labels = timesteps
+                else:
+                    class_labels = None
+
                 # Predict the noise residual
-                model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
+                model_pred = unet(
+                    noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
+                ).sample
 
                 # if model predicts variance, throw away the prediction. we will only train on the
                 # simplified training objective. This means that all schedulers using the fine tuned
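
Note: the padding exists because the IF stage II UNet takes the noisy image concatenated channel-wise with a low-resolution conditioning image, so unet.config.in_channels exceeds the channel count of model_input; the `timesteps` option mirrors how the stage II pipelines pass a noise level through class_labels. A standalone sketch of the shape logic, with in_channels=6 assumed and torch.randn standing in for randn_tensor:

    import torch

    # Assumed shapes: RGB model input, IF stage II UNet with in_channels=6.
    bsz, channels, height, width = 2, 3, 64, 64
    in_channels = 6

    noisy_model_input = torch.randn(bsz, channels, height, width)
    if in_channels > channels:
        needed = in_channels - channels
        # torch.randn stands in for diffusers' randn_tensor helper here
        additional_latents = torch.randn(
            bsz, needed, height, width,
            dtype=noisy_model_input.dtype, device=noisy_model_input.device,
        )
        noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)

    assert noisy_model_input.shape == (bsz, in_channels, height, width)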
@@ -1248,9 +1278,18 @@ def main(args):
                     }
                 else:
                     pipeline_args = {"prompt": args.validation_prompt}
-                images = [
-                    pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)
-                ]
+
+                if args.validation_images is None:
+                    images = [
+                        pipeline(**pipeline_args, generator=generator).images[0]
+                        for _ in range(args.num_validation_images)
+                    ]
+                else:
+                    images = []
+                    for image in args.validation_images:
+                        image = Image.open(image)
+                        image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+                        images.append(image)
 
                 for tracker in accelerator.trackers:
                     if tracker.name == "tensorboard":
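
Note: the new branch lets validation exercise pipelines whose __call__ takes an initial image (superresolution, image variation). In isolation one iteration reduces to roughly this sketch, with a hypothetical path and whatever pipeline and generator the script constructed:

    from PIL import Image

    # Hypothetical: one validation sample for an image-conditioned pipeline.
    low_res = Image.open("validation/low_res_input.png")
    image = pipeline(
        prompt="a photo of sks dog", image=low_res, generator=generator
    ).images[0]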

src/diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py

@@ -10,6 +10,7 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
 
+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -112,7 +113,7 @@ EXAMPLE_DOC_STRING = """
 """
 
-class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline):
+class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel
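
Note: mixing in LoraLoaderMixin (repeated for the two pipelines below) is what lets the stage II pipelines load the LoRA attention weights saved by the training script. A hedged usage sketch; the weights directory is a placeholder for the script's --output_dir:

    import torch
    from diffusers import IFImg2ImgSuperResolutionPipeline

    # Load the public IF stage II checkpoint, then the DreamBooth LoRA
    # weights ("path/to/lora" is a placeholder).
    pipe = IFImg2ImgSuperResolutionPipeline.from_pretrained(
        "DeepFloyd/IF-II-L-v1.0", variant="fp16", torch_dtype=torch.float16
    )
    pipe.load_lora_weights("path/to/lora")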
@@ -1047,6 +1048,9 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
 
+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
                     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
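
Note: per the training-script comment above, LoRA fine-tuning optimizes only the simplified objective and discards the variance prediction, so fine-tuned stage II models must run with a fixed-variance scheduler; this guard (added identically to all three pipelines in this commit) strips the predicted-variance channels before scheduler.step. A standalone sketch:

    import torch

    # A variance-predicting UNet emits 2*C channels: [noise, predicted_variance].
    C = 3
    intermediate_images = torch.randn(1, C, 256, 256)
    noise_pred = torch.randn(1, 2 * C, 256, 256)

    variance_type = "fixed_small"  # stand-in for self.scheduler.config.variance_type
    if variance_type not in ["learned", "learned_range"]:
        # schedulers without learned variance expect only the C noise channels
        noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)

    assert noise_pred.shape[1] == C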

src/diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py

@@ -10,6 +10,7 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
 
+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -114,7 +115,7 @@ EXAMPLE_DOC_STRING = """
 """
 
-class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
+class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel
@@ -1154,6 +1155,9 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
 
+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 prev_intermediate_images = intermediate_images

src/diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py

@@ -10,6 +10,7 @@ import torch
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
 
+from ...loaders import LoraLoaderMixin
 from ...models import UNet2DConditionModel
 from ...schedulers import DDPMScheduler
 from ...utils import (
@@ -70,7 +71,7 @@ EXAMPLE_DOC_STRING = """
 """
 
-class IFSuperResolutionPipeline(DiffusionPipeline):
+class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
     tokenizer: T5Tokenizer
     text_encoder: T5EncoderModel
@@ -903,6 +904,9 @@ class IFSuperResolutionPipeline(DiffusionPipeline):
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
                     noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
 
+                if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
+                    noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 intermediate_images = self.scheduler.step(
                     noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False