Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-29 07:22:12 +03:00
add: janky draft of sdxl dreambooth lora.
examples/dreambooth/train_dreambooth_sd_xl_lora.py: new file, 1401 lines (file diff suppressed because it is too large)
@@ -131,7 +131,7 @@ if __name__ == "__main__":
         type=str,
         default=None,
         required=False,
-        help="Set to a path, hub id to an already converted vae to not convert it again."
+        help="Set to a path, hub id to an already converted vae to not convert it again.",
     )
     args = parser.parse_args()

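Note: the hunk above only adds a trailing comma flagged by the formatter. For context, a minimal sketch of how such an argument is typically declared with argparse; the flag name --pretrained_vae_model_name_or_path is an assumption for illustration, it is not visible in this hunk.

# Minimal argparse sketch (assumed flag name; only type/default/required/help appear in the hunk).
import argparse

parser = argparse.ArgumentParser(description="SDXL DreamBooth LoRA training (sketch)")
parser.add_argument(
    "--pretrained_vae_model_name_or_path",  # hypothetical name, not shown in the diff
    type=str,
    default=None,
    required=False,
    help="Set to a path, hub id to an already converted vae to not convert it again.",
)
args = parser.parse_args()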
@@ -160,8 +160,8 @@ else:
         StableDiffusionPix2PixZeroPipeline,
         StableDiffusionSAGPipeline,
         StableDiffusionUpscalePipeline,
-        StableDiffusionXLPipeline,
         StableDiffusionXLImg2ImgPipeline,
+        StableDiffusionXLPipeline,
         StableUnCLIPImg2ImgPipeline,
         StableUnCLIPPipeline,
         TextToVideoSDPipeline,
@@ -89,7 +89,7 @@ else:
         StableUnCLIPPipeline,
     )
     from .stable_diffusion_safe import StableDiffusionPipelineSafe
-    from .stable_diffusion_xl import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
+    from .stable_diffusion_xl import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
     from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline
     from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
     from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder
@@ -257,7 +257,11 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
         resolution //= 2

     if unet_params.transformer_depth is not None:
-        num_transformer_blocks = unet_params.transformer_depth if isinstance(unet_params.transformer_depth, int) else list(unet_params.transformer_depth)
+        num_transformer_blocks = (
+            unet_params.transformer_depth
+            if isinstance(unet_params.transformer_depth, int)
+            else list(unet_params.transformer_depth)
+        )
     else:
         num_transformer_blocks = 1

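The change above is purely a line-wrapping fix; the normalization it performs (an int depth applied everywhere versus a per-block list of transformer depths) can be sketched in isolation roughly as follows.

# Sketch of the normalization the wrapped expression performs.
# transformer_depth may be None, an int (same depth everywhere), or a per-block sequence.
def normalize_transformer_depth(transformer_depth):
    if transformer_depth is None:
        return 1
    if isinstance(transformer_depth, int):
        return transformer_depth
    return list(transformer_depth)

print(normalize_transformer_depth(None))        # 1
print(normalize_transformer_depth(2))           # 2
print(normalize_transformer_depth((1, 2, 10)))  # [1, 2, 10], e.g. SDXL-style per-block depths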
@@ -270,7 +274,7 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
     if use_linear_projection:
         # stable diffusion 2-base-512 and 2-768
         if head_dim is None:
-            head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
+            head_dim_mult = unet_params.model_channels // unet_params.num_head_channels
             head_dim = [head_dim_mult * c for c in list(unet_params.channel_mult)]

     class_embed_type = None
@@ -280,7 +284,9 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
     context_dim = None

     if unet_params.context_dim is not None:
-        context_dim = unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
+        context_dim = (
+            unet_params.context_dim if isinstance(unet_params.context_dim, int) else unet_params.context_dim[0]
+        )

     if "num_classes" in unet_params:
         if unet_params.num_classes == "sequential":
@@ -775,7 +781,7 @@ def convert_ldm_clip_checkpoint(checkpoint, local_files_only=False):
     for key in keys:
         for prefix in remove_prefixes:
             if key.startswith(prefix):
-                text_model_dict[key[len(prefix + "."):]] = checkpoint[key]
+                text_model_dict[key[len(prefix + ".") :]] = checkpoint[key]

     text_model.load_state_dict(text_model_dict)

@@ -787,7 +793,7 @@ textenc_conversion_lst = [
     ("token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
     ("ln_final.weight", "text_model.final_layer_norm.weight"),
     ("ln_final.bias", "text_model.final_layer_norm.bias"),
-    ("text_projection", "text_projection.weight")
+    ("text_projection", "text_projection.weight"),
 ]
 textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}

@@ -876,7 +882,9 @@ def convert_paint_by_example_checkpoint(checkpoint):

 def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
     # text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
-    text_model = CLIPTextModelWithProjection.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280)
+    text_model = CLIPTextModelWithProjection.from_pretrained(
+        "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
+    )

     keys = list(checkpoint.keys())

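This hunk only re-wraps the from_pretrained call. As a rough usage sketch of the second SDXL text encoder it loads, assuming the laion/CLIP-ViT-bigG-14-laion2B-39B-b160k hub repo ships transformers-format text-encoder and tokenizer files (which the conversion code here relies on):

# Sketch: exercise the OpenCLIP bigG text encoder used as SDXL's second text encoder.
# Assumes the hub repo provides transformers-format weights and tokenizer files; the download is large.
import torch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("laion/CLIP-ViT-bigG-14-laion2B-39B-b160k")
text_encoder = CLIPTextModelWithProjection.from_pretrained(
    "laion/CLIP-ViT-bigG-14-laion2B-39B-b160k", projection_dim=1280
)

tokens = tokenizer("a photo of a dog", padding="max_length", truncation=True, return_tensors="pt")
with torch.no_grad():
    out = text_encoder(tokens.input_ids, output_hidden_states=True)

print(out.text_embeds.shape)        # pooled, projected embedding: (1, 1280)
print(out.hidden_states[-2].shape)  # penultimate hidden state: (1, 77, 1280)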
@@ -892,13 +900,13 @@ def convert_open_clip_checkpoint(checkpoint, prefix="cond_stage_model.model."):
     for key in keys:
         # if "resblocks.23" in key:  # Diffusers drops the final layer and only uses the penultimate layer
         #     continue
-        if key[len(prefix):] in textenc_conversion_map:
+        if key[len(prefix) :] in textenc_conversion_map:
             if key.endswith("text_projection"):
                 value = checkpoint[key].T
             else:
                 value = checkpoint[key]

-            text_model_dict[textenc_conversion_map[key[len(prefix):]]] = value
+            text_model_dict[textenc_conversion_map[key[len(prefix) :]]] = value

         if key.startswith(prefix + "transformer."):
             new_key = key[len(prefix + "transformer.") :]
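The slice key[len(prefix) :] simply strips the open_clip prefix before the lookup in textenc_conversion_map; a self-contained sketch of that remapping step, using a toy two-entry map and placeholder values instead of real tensors:

# Sketch of the prefix-strip + rename step (toy map and placeholder values, not the full conversion table).
prefix = "cond_stage_model.model."
textenc_conversion_map = {
    "ln_final.weight": "text_model.final_layer_norm.weight",
    "text_projection": "text_projection.weight",
}

checkpoint = {
    "cond_stage_model.model.ln_final.weight": "tensor-A",
    "cond_stage_model.model.text_projection": "tensor-B",
}

text_model_dict = {}
for key, value in checkpoint.items():
    stripped = key[len(prefix):]
    if stripped in textenc_conversion_map:
        # "text_projection" would additionally be transposed (checkpoint[key].T) in the real code.
        text_model_dict[textenc_conversion_map[stripped]] = value

print(text_model_dict)
# {'text_model.final_layer_norm.weight': 'tensor-A', 'text_projection.weight': 'tensor-B'}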
@@ -1122,10 +1130,10 @@ def download_from_original_stable_diffusion_ckpt(
         PaintByExamplePipeline,
         StableDiffusionControlNetPipeline,
         StableDiffusionPipeline,
+        StableDiffusionXLImg2ImgPipeline,
+        StableDiffusionXLPipeline,
         StableUnCLIPImg2ImgPipeline,
         StableUnCLIPPipeline,
-        StableDiffusionXLPipeline,
-        StableDiffusionXLImg2ImgPipeline,
     )

     if pipeline_class is None:
@@ -24,7 +24,12 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 from ...image_processor import VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
-from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, LoRAAttnProcessor2_0
+from ...models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
 from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
 from ...utils import deprecate, is_accelerate_available, is_accelerate_version, logging, randn_tensor
 from ..pipeline_utils import DiffusionPipeline
@@ -13,20 +13,22 @@
 # limitations under the License.

 import inspect
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
 from ...schedulers import KarrasDiffusionSchedulers
-from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, LoRAAttnProcessor2_0
 from ...utils import (
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,
@@ -35,7 +37,6 @@ from ...utils import (
 )
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionXLPipelineOutput
-from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -305,7 +306,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline):

         # Define tokenizers and text encoders
         tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
-        text_encoders = [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        text_encoders = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )

         if prompt_embeds is None:
             # textual inversion: procecss multi-vector tokens if necessary
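For context on the tokenizers/text_encoders pairing above: SDXL runs the prompt through both text encoders and, roughly, concatenates their per-token hidden states along the feature dimension. A shape-only sketch with dummy tensors; the 768 and 1280 widths correspond to the CLIP ViT-L and OpenCLIP bigG encoders:

# Shape-only sketch of how the two SDXL text encoders' outputs are combined (dummy tensors).
import torch

batch, seq_len = 2, 77
hidden_1 = torch.randn(batch, seq_len, 768)   # CLIP ViT-L branch (tokenizer / text_encoder)
hidden_2 = torch.randn(batch, seq_len, 1280)  # OpenCLIP bigG branch (tokenizer_2 / text_encoder_2)

# Each encoder contributes its (penultimate) hidden state; they are concatenated per token.
prompt_embeds = torch.cat([hidden_1, hidden_2], dim=-1)
print(prompt_embeds.shape)  # torch.Size([2, 77, 2048])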
@@ -327,14 +330,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
             if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                 text_input_ids, untruncated_ids
             ):
-                removed_text = tokenizer.batch_decode(
-                    untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
-                )
+                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                 logger.warning(
                     "The following part of your input was truncated because CLIP can only handle sequences up to"
                     f" {tokenizer.model_max_length} tokens: {removed_text}"
                 )

             prompt_embeds = text_encoder(
                 text_input_ids.to(device),
                 output_hidden_states=True,
@@ -405,7 +406,9 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, -1
+            )

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -417,9 +420,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline):

             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])


-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1)
-        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+            bs_embed * num_images_per_prompt, -1
+        )
+        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+            bs_embed * num_images_per_prompt, -1
+        )

         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
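The repeat(...).view(...) pattern re-wrapped above duplicates each pooled embedding once per requested image; a quick shape check with dummy values:

# Shape sketch for the pooled-embedding duplication (dummy values).
import torch

bs_embed, num_images_per_prompt = 2, 3
pooled_prompt_embeds = torch.randn(bs_embed, 1280)

pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
    bs_embed * num_images_per_prompt, -1
)
print(pooled_prompt_embeds.shape)  # torch.Size([6, 1280]); rows 0-2 are copies of prompt 0, rows 3-5 of prompt 1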
@@ -652,7 +658,12 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         text_encoder_lora_scale = (
             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
         )
-        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.encode_prompt(
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
@@ -753,11 +764,10 @@ class StableDiffusionXLPipeline(DiffusionPipeline):
         else:
             latents = latents.float()

-
         if not output_type == "latent":
             # CHECK there is problem here (PVP)
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-            #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+            # image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
             has_nsfw_concept = None
         else:
             image = latents
@@ -13,22 +13,24 @@
 # limitations under the License.

 import inspect
 import warnings
-from typing import Any, Callable, Dict, List, Optional, Union, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
 import PIL.Image

 import torch
-from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

 from ...configuration_utils import FrozenDict
 from ...image_processor import VaeImageProcessor
-from ...loaders import FromCkptMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
+from ...models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
 from ...schedulers import KarrasDiffusionSchedulers
-from ...models.attention_processor import AttnProcessor2_0, LoRAXFormersAttnProcessor, XFormersAttnProcessor, LoRAAttnProcessor2_0
 from ...utils import (
     deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,
@@ -37,7 +39,6 @@ from ...utils import (
 )
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionXLPipelineOutput
-from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -307,7 +308,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):

         # Define tokenizers and text encoders
         tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
-        text_encoders = [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        text_encoders = (
+            [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
+        )

         if prompt_embeds is None:
             # textual inversion: procecss multi-vector tokens if necessary
@@ -329,14 +332,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
             if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                 text_input_ids, untruncated_ids
             ):
-                removed_text = tokenizer.batch_decode(
-                    untruncated_ids[:, tokenizer.model_max_length - 1 : -1]
-                )
+                removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
                 logger.warning(
                     "The following part of your input was truncated because CLIP can only handle sequences up to"
                     f" {tokenizer.model_max_length} tokens: {removed_text}"
                 )

             prompt_embeds = text_encoder(
                 text_input_ids.to(device),
                 output_hidden_states=True,
@@ -410,7 +411,9 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
             negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+            negative_prompt_embeds = negative_prompt_embeds.view(
+                batch_size * num_images_per_prompt, seq_len, -1
+            )

             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
@@ -422,8 +425,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):

             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

-        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1)
-        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1)
+        pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+            bs_embed * num_images_per_prompt, -1
+        )
+        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+            bs_embed * num_images_per_prompt, -1
+        )

         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

@@ -521,7 +528,6 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
             init_latents = image

         else:
-
             # make sure the VAE is in float32 mode, as it overflows in float16
             image = image.float()
             self.vae.to(dtype=torch.float32)
@@ -561,7 +567,6 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):

         return latents

-
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -705,7 +710,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
         text_encoder_lora_scale = (
             cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
         )
-        prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds = self.encode_prompt(
+        (
+            prompt_embeds,
+            negative_prompt_embeds,
+            pooled_prompt_embeds,
+            negative_pooled_prompt_embeds,
+        ) = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
@@ -737,8 +747,12 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):

         if self.unet.add_embedding.linear_1.in_features == (1280 + 5 * 256):
             # refiner
-            add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + (aesthetic_score,))], dtype=torch.long)
-            neg_add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + (negative_aesthetic_score,))], dtype=torch.long)
+            add_time_ids = torch.tensor(
+                [list(original_size + crops_coords_top_left + (aesthetic_score,))], dtype=torch.long
+            )
+            neg_add_time_ids = torch.tensor(
+                [list(original_size + crops_coords_top_left + (negative_aesthetic_score,))], dtype=torch.long
+            )
         elif self.unet.add_embedding.linear_1.in_features == (1280 + 6 * 256):
             # SD-XL Base
             add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=torch.long)
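Background for the in_features check above: the refiner conditions on five extra values (original size, crop top-left, aesthetic score) and the base model on six (original size, crop top-left, target size); each value is embedded into 256 dimensions inside the UNet and concatenated with the 1280-dim pooled text embedding, hence 1280 + 5 * 256 versus 1280 + 6 * 256. A sketch of just the id construction, with example values assumed for illustration:

# Sketch of SDXL micro-conditioning ids (values only; the 256-dim embedding happens inside the UNet).
import torch

original_size = (1024, 1024)
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)
aesthetic_score = 6.0

# base model: 6 conditioning values -> add_embedding expects 1280 + 6 * 256 input features
base_add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)], dtype=torch.long)
print(base_add_time_ids)     # tensor([[1024, 1024,    0,    0, 1024, 1024]])

# refiner: 5 conditioning values -> add_embedding expects 1280 + 5 * 256 input features
refiner_add_time_ids = torch.tensor(
    [list(original_size + crops_coords_top_left + (aesthetic_score,))], dtype=torch.long
)
print(refiner_add_time_ids)  # tensor([[1024, 1024,    0,    0,    6]])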
@@ -811,7 +825,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline):
         if not output_type == "latent":
             # CHECK there is problem here (PVP)
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-            #image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
+            # image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
             has_nsfw_concept = None
         else:
             image = latents
@@ -183,7 +183,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
         step_index = (self.timesteps == timestep).nonzero().item()
         sigma = self.sigmas[step_index]

-        sample = sample / ((sigma **2 + 1) ** 0.5)
+        sample = sample / ((sigma**2 + 1) ** 0.5)

         self.is_scale_input_called = True
         return sample
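The fix above is whitespace-only; the underlying scaling divides the sample by sqrt(sigma**2 + 1) so the model input has roughly unit variance. A standalone numeric sketch, where the sigma value is just a typical SD-scale example:

# Numeric sketch of the Euler input scaling (standalone, not using the scheduler class).
import torch

sigma = 14.6146  # example: a typical largest sigma for SD-style noise schedules
sample = torch.randn(1, 4, 64, 64) * (sigma**2 + 1) ** 0.5  # latent with variance ~ sigma^2 + 1

scaled = sample / ((sigma**2 + 1) ** 0.5)
print(sample.std().item())  # ~14.65
print(scaled.std().item())  # ~1.0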
@@ -202,7 +202,9 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):

         # "linspace" and "leading" corresponds to annotation of Table 1. of https://arxiv.org/abs/2305.08891
         if self.config.timestep_spacing == "linspace":
-            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()
+            timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[
+                ::-1
+            ].copy()
         elif self.config.timestep_spacing == "leading":
             step_ratio = self.config.num_train_timesteps // self.num_inference_steps
             # creates integer timesteps by multiplying by ratio
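For reference, the two timestep spacings named in this hunk produce slightly different grids. A quick standalone comparison; the "leading" branch below is sketched from the context shown here (step_ratio multiplication, steps_offset omitted), not copied from the scheduler:

# Quick comparison of "linspace" vs "leading" timestep spacing (num_train_timesteps=1000, 10 steps).
import numpy as np

num_train_timesteps = 1000
num_inference_steps = 10

linspace = np.linspace(0, num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy()

step_ratio = num_train_timesteps // num_inference_steps
# assumed reconstruction of the "leading" branch: integer timesteps made by multiplying by the ratio
leading = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(float)

print(linspace)  # [999. 888. 777. ... 111.   0.]
print(leading)   # [900. 800. 700. ... 100.   0.]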