mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
update dreambooth lora to work with IF stage II (#3560)
This commit is contained in:
@@ -60,6 +60,7 @@ from diffusers.models.attention_processor import (
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
@@ -425,6 +426,19 @@ def parse_args(input_args=None):
|
||||
required=False,
|
||||
help="Whether to use attention mask for the text encoder",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_images",
|
||||
required=False,
|
||||
default=None,
|
||||
nargs="+",
|
||||
help="Optional set of images to use for validation. Used when the target pipeline takes an initial image as input such as when training image variation or superresolution.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--class_labels_conditioning",
|
||||
required=False,
|
||||
default=None,
|
||||
help="The optional `class_label` conditioning to pass to the unet, available values are `timesteps`.",
|
||||
)
|
||||
|
||||
if input_args is not None:
|
||||
args = parser.parse_args(input_args)
|
||||
@@ -1121,7 +1135,7 @@ def main(args):
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(model_input)
|
||||
bsz = model_input.shape[0]
|
||||
bsz, channels, height, width = model_input.shape
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
|
||||
@@ -1143,8 +1157,24 @@ def main(args):
|
||||
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
|
||||
)
|
||||
|
||||
if unet.config.in_channels > channels:
|
||||
needed_additional_channels = unet.config.in_channels - channels
|
||||
additional_latents = randn_tensor(
|
||||
(bsz, needed_additional_channels, height, width),
|
||||
device=noisy_model_input.device,
|
||||
dtype=noisy_model_input.dtype,
|
||||
)
|
||||
noisy_model_input = torch.cat([additional_latents, noisy_model_input], dim=1)
|
||||
|
||||
if args.class_labels_conditioning == "timesteps":
|
||||
class_labels = timesteps
|
||||
else:
|
||||
class_labels = None
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample
|
||||
model_pred = unet(
|
||||
noisy_model_input, timesteps, encoder_hidden_states, class_labels=class_labels
|
||||
).sample
|
||||
|
||||
# if model predicts variance, throw away the prediction. we will only train on the
|
||||
# simplified training objective. This means that all schedulers using the fine tuned
|
||||
@@ -1248,9 +1278,18 @@ def main(args):
|
||||
}
|
||||
else:
|
||||
pipeline_args = {"prompt": args.validation_prompt}
|
||||
images = [
|
||||
pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)
|
||||
]
|
||||
|
||||
if args.validation_images is None:
|
||||
images = [
|
||||
pipeline(**pipeline_args, generator=generator).images[0]
|
||||
for _ in range(args.num_validation_images)
|
||||
]
|
||||
else:
|
||||
images = []
|
||||
for image in args.validation_images:
|
||||
image = Image.open(image)
|
||||
image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
|
||||
images.append(image)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
|
||||
@@ -10,6 +10,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...loaders import LoraLoaderMixin
|
||||
from ...models import UNet2DConditionModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
@@ -112,7 +113,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline):
|
||||
class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
tokenizer: T5Tokenizer
|
||||
text_encoder: T5EncoderModel
|
||||
|
||||
@@ -1047,6 +1048,9 @@ class IFImg2ImgSuperResolutionPipeline(DiffusionPipeline):
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
||||
|
||||
if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
|
||||
noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
intermediate_images = self.scheduler.step(
|
||||
noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
|
||||
|
||||
@@ -10,6 +10,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...loaders import LoraLoaderMixin
|
||||
from ...models import UNet2DConditionModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
@@ -114,7 +115,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
|
||||
class IFInpaintingSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
tokenizer: T5Tokenizer
|
||||
text_encoder: T5EncoderModel
|
||||
|
||||
@@ -1154,6 +1155,9 @@ class IFInpaintingSuperResolutionPipeline(DiffusionPipeline):
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
||||
|
||||
if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
|
||||
noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
prev_intermediate_images = intermediate_images
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPImageProcessor, T5EncoderModel, T5Tokenizer
|
||||
|
||||
from ...loaders import LoraLoaderMixin
|
||||
from ...models import UNet2DConditionModel
|
||||
from ...schedulers import DDPMScheduler
|
||||
from ...utils import (
|
||||
@@ -70,7 +71,7 @@ EXAMPLE_DOC_STRING = """
|
||||
"""
|
||||
|
||||
|
||||
class IFSuperResolutionPipeline(DiffusionPipeline):
|
||||
class IFSuperResolutionPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
tokenizer: T5Tokenizer
|
||||
text_encoder: T5EncoderModel
|
||||
|
||||
@@ -903,6 +904,9 @@ class IFSuperResolutionPipeline(DiffusionPipeline):
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
||||
|
||||
if self.scheduler.config.variance_type not in ["learned", "learned_range"]:
|
||||
noise_pred, _ = noise_pred.split(intermediate_images.shape[1], dim=1)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
intermediate_images = self.scheduler.step(
|
||||
noise_pred, t, intermediate_images, **extra_step_kwargs, return_dict=False
|
||||
|
||||
Reference in New Issue
Block a user