From 7c5fef81e0aecff65c041a9dfb23aff22bf64f4b Mon Sep 17 00:00:00 2001 From: Nathan Lambert Date: Mon, 14 Nov 2022 13:48:48 -0800 Subject: [PATCH 1/6] Add UNet 1d for RL model for planning + colab (#105) * re-add RL model code * match model forward api * add register_to_config, pass training tests * fix tests, update forward outputs * remove unused code, some comments * add to docs * remove extra embedding code * unify time embedding * remove conv1d output sequential * remove sequential from conv1dblock * style and deleting duplicated code * clean files * remove unused variables * clean variables * add 1d resnet block structure for downsample * rename as unet1d * fix renaming * rename files * add get_block(...) api * unify args for model1d like model2d * minor cleaning * fix docs * improve 1d resnet blocks * fix tests, remove permuts * fix style * add output activation * rename flax blocks file * Add Value Function and corresponding example script to Diffuser implementation (#884) * valuefunction code * start example scripts * missing imports * bug fixes and placeholder example script * add value function scheduler * load value function from hub and get best actions in example * very close to working example * larger batch size for planning * more tests * merge unet1d changes * wandb for debugging, use newer models * success! * turns out we just need more diffusion steps * run on modal * merge and code cleanup * use same api for rl model * fix variance type * wrong normalization function * add tests * style * style and quality * edits based on comments * style and quality * remove unused var * hack unet1d into a value function * add pipeline * fix arg order * add pipeline to core library * community pipeline * fix couple shape bugs * style * Apply suggestions from code review Co-authored-by: Nathan Lambert * update post merge of scripts * add mdiblock / outblock architecture * Pipeline cleanup (#947) * valuefunction code * start example scripts * missing imports * bug fixes and placeholder example script * add value function scheduler * load value function from hub and get best actions in example * very close to working example * larger batch size for planning * more tests * merge unet1d changes * wandb for debugging, use newer models * success! * turns out we just need more diffusion steps * run on modal * merge and code cleanup * use same api for rl model * fix variance type * wrong normalization function * add tests * style * style and quality * edits based on comments * style and quality * remove unused var * hack unet1d into a value function * add pipeline * fix arg order * add pipeline to core library * community pipeline * fix couple shape bugs * style * Apply suggestions from code review * clean up comments * convert older script to using pipeline and add readme * rename scripts * style, update tests * delete unet rl model file * remove imports in src Co-authored-by: Nathan Lambert * Update src/diffusers/models/unet_1d_blocks.py * Update tests/test_models_unet.py * RL Cleanup v2 (#965) * valuefunction code * start example scripts * missing imports * bug fixes and placeholder example script * add value function scheduler * load value function from hub and get best actions in example * very close to working example * larger batch size for planning * more tests * merge unet1d changes * wandb for debugging, use newer models * success! 
* turns out we just need more diffusion steps * run on modal * merge and code cleanup * use same api for rl model * fix variance type * wrong normalization function * add tests * style * style and quality * edits based on comments * style and quality * remove unused var * hack unet1d into a value function * add pipeline * fix arg order * add pipeline to core library * community pipeline * fix couple shape bugs * style * Apply suggestions from code review * clean up comments * convert older script to using pipeline and add readme * rename scripts * style, update tests * delete unet rl model file * remove imports in src * add specific vf block and update tests * style * Update tests/test_models_unet.py Co-authored-by: Nathan Lambert * fix quality in tests * fix quality style, split test file * fix checks / tests * make timesteps closer to main * unify block API * unify forward api * delete lines in examples * style * examples style * all tests pass * make style * make dance_diff test pass * Refactoring RL PR (#1200) * init file changes * add import utils * finish cleaning files, imports * remove import flags * clean examples * fix imports, tests for merge * update readmes * hotfix for tests * quality * fix some tests * change defaults * more mps test fixes * unet1d defaults * do not default import experimental * defaults for tests * fix tests * fix-copies * fix * changes per Patrik's comments (#1285) * changes per Patrik's comments * update conversion script * fix renaming * skip more mps tests * last test fix * Update examples/rl/README.md Co-authored-by: Ben Glickenhaus --- .gitignore | 4 +- docs/source/api/models.mdx | 9 +- examples/README.md | 2 +- examples/rl/README.md | 19 + examples/rl/run_diffuser_gen_trajectories.py | 57 +++ examples/rl/run_diffuser_locomotion.py | 57 +++ .../convert_models_diffuser_to_diffusers.py | 100 +++++ src/diffusers/experimental/README.md | 5 + src/diffusers/experimental/__init__.py | 1 + src/diffusers/experimental/rl/__init__.py | 1 + .../experimental/rl/value_guided_sampling.py | 129 +++++++ src/diffusers/models/embeddings.py | 13 +- src/diffusers/models/resnet.py | 138 ++++++- src/diffusers/models/unet_1d.py | 115 ++++-- src/diffusers/models/unet_1d_blocks.py | 346 ++++++++++++++++-- src/diffusers/schedulers/scheduling_ddpm.py | 6 +- tests/models/test_models_unet_1d.py | 235 +++++++++++- .../dance_diffusion/test_dance_diffusion.py | 4 + 18 files changed, 1176 insertions(+), 65 deletions(-) create mode 100644 examples/rl/README.md create mode 100644 examples/rl/run_diffuser_gen_trajectories.py create mode 100644 examples/rl/run_diffuser_locomotion.py create mode 100644 scripts/convert_models_diffuser_to_diffusers.py create mode 100644 src/diffusers/experimental/README.md create mode 100644 src/diffusers/experimental/__init__.py create mode 100644 src/diffusers/experimental/rl/__init__.py create mode 100644 src/diffusers/experimental/rl/value_guided_sampling.py diff --git a/.gitignore b/.gitignore index cf81834636..f018a111ea 100644 --- a/.gitignore +++ b/.gitignore @@ -163,4 +163,6 @@ tags *.lock # DS_Store (MacOS) -.DS_Store \ No newline at end of file +.DS_Store +# RL pipelines may produce mp4 outputs +*.mp4 \ No newline at end of file diff --git a/docs/source/api/models.mdx b/docs/source/api/models.mdx index 2e1e8798a7..7c1faa8474 100644 --- a/docs/source/api/models.mdx +++ b/docs/source/api/models.mdx @@ -22,12 +22,15 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module ## UNet2DOutput [[autodoc]] 
models.unet_2d.UNet2DOutput -## UNet1DModel -[[autodoc]] UNet1DModel - ## UNet2DModel [[autodoc]] UNet2DModel +## UNet1DOutput +[[autodoc]] models.unet_1d.UNet1DOutput + +## UNet1DModel +[[autodoc]] UNet1DModel + ## UNet2DConditionOutput [[autodoc]] models.unet_2d_condition.UNet2DConditionOutput diff --git a/examples/README.md b/examples/README.md index 29872a7a16..06ce06b9e3 100644 --- a/examples/README.md +++ b/examples/README.md @@ -42,7 +42,7 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie | [**Text-to-Image fine-tuning**](./text_to_image) | โœ… | โœ… | | [**Textual Inversion**](./textual_inversion) | โœ… | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_textual_inversion_training.ipynb) | [**Dreambooth**](./dreambooth) | โœ… | - | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb) - +| [**Reinforcement Learning for Control**](https://github.com/huggingface/diffusers/blob/main/examples/rl/run_diffusers_locomotion.py) | - | - | coming soon. ## Community diff --git a/examples/rl/README.md b/examples/rl/README.md new file mode 100644 index 0000000000..d68f2bf780 --- /dev/null +++ b/examples/rl/README.md @@ -0,0 +1,19 @@ +# Overview + +These examples show how to run (Diffuser)[https://arxiv.org/abs/2205.09991] in Diffusers. +There are four scripts, +1. `run_diffuser_locomotion.py` to sample actions and run them in the environment, +2. and `run_diffuser_gen_trajectories.py` to just sample actions from the pre-trained diffusion model. + +You will need some RL specific requirements to run the examples: + +``` +pip install -f https://download.pytorch.org/whl/torch_stable.html \ + free-mujoco-py \ + einops \ + gym==0.24.1 \ + protobuf==3.20.1 \ + git+https://github.com/rail-berkeley/d4rl.git \ + mediapy \ + Pillow==9.0.0 +``` diff --git a/examples/rl/run_diffuser_gen_trajectories.py b/examples/rl/run_diffuser_gen_trajectories.py new file mode 100644 index 0000000000..5bb068cc9f --- /dev/null +++ b/examples/rl/run_diffuser_gen_trajectories.py @@ -0,0 +1,57 @@ +import d4rl # noqa +import gym +import tqdm +from diffusers.experimental import ValueGuidedRLPipeline + + +config = dict( + n_samples=64, + horizon=32, + num_inference_steps=20, + n_guide_steps=0, + scale_grad_by_std=True, + scale=0.1, + eta=0.0, + t_grad_cutoff=2, + device="cpu", +) + + +if __name__ == "__main__": + env_name = "hopper-medium-v2" + env = gym.make(env_name) + + pipeline = ValueGuidedRLPipeline.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", + env=env, + ) + + env.seed(0) + obs = env.reset() + total_reward = 0 + total_score = 0 + T = 1000 + rollout = [obs.copy()] + try: + for t in tqdm.tqdm(range(T)): + # Call the policy + denorm_actions = pipeline(obs, planning_horizon=32) + + # execute action in environment + next_observation, reward, terminal, _ = env.step(denorm_actions) + score = env.get_normalized_score(total_reward) + # update return + total_reward += reward + total_score += score + print( + f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" + f" {total_score}" + ) + # save observations for rendering + rollout.append(next_observation.copy()) + + obs = next_observation + except KeyboardInterrupt: + pass + + print(f"Total reward: {total_reward}") diff --git 
a/examples/rl/run_diffuser_locomotion.py b/examples/rl/run_diffuser_locomotion.py new file mode 100644 index 0000000000..e89181610b --- /dev/null +++ b/examples/rl/run_diffuser_locomotion.py @@ -0,0 +1,57 @@ +import d4rl # noqa +import gym +import tqdm +from diffusers.experimental import ValueGuidedRLPipeline + + +config = dict( + n_samples=64, + horizon=32, + num_inference_steps=20, + n_guide_steps=2, + scale_grad_by_std=True, + scale=0.1, + eta=0.0, + t_grad_cutoff=2, + device="cpu", +) + + +if __name__ == "__main__": + env_name = "hopper-medium-v2" + env = gym.make(env_name) + + pipeline = ValueGuidedRLPipeline.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", + env=env, + ) + + env.seed(0) + obs = env.reset() + total_reward = 0 + total_score = 0 + T = 1000 + rollout = [obs.copy()] + try: + for t in tqdm.tqdm(range(T)): + # call the policy + denorm_actions = pipeline(obs, planning_horizon=32) + + # execute action in environment + next_observation, reward, terminal, _ = env.step(denorm_actions) + score = env.get_normalized_score(total_reward) + # update return + total_reward += reward + total_score += score + print( + f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}, Score: {score}, Total Score:" + f" {total_score}" + ) + # save observations for rendering + rollout.append(next_observation.copy()) + + obs = next_observation + except KeyboardInterrupt: + pass + + print(f"Total reward: {total_reward}") diff --git a/scripts/convert_models_diffuser_to_diffusers.py b/scripts/convert_models_diffuser_to_diffusers.py new file mode 100644 index 0000000000..9475f7da93 --- /dev/null +++ b/scripts/convert_models_diffuser_to_diffusers.py @@ -0,0 +1,100 @@ +import json +import os + +import torch + +from diffusers import UNet1DModel + + +os.makedirs("hub/hopper-medium-v2/unet/hor32", exist_ok=True) +os.makedirs("hub/hopper-medium-v2/unet/hor128", exist_ok=True) + +os.makedirs("hub/hopper-medium-v2/value_function", exist_ok=True) + + +def unet(hor): + if hor == 128: + down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D") + block_out_channels = (32, 128, 256) + up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D") + + elif hor == 32: + down_block_types = ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D") + block_out_channels = (32, 64, 128, 256) + up_block_types = ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D") + model = torch.load(f"/Users/bglickenhaus/Documents/diffuser/temporal_unet-hopper-mediumv2-hor{hor}.torch") + state_dict = model.state_dict() + config = dict( + down_block_types=down_block_types, + block_out_channels=block_out_channels, + up_block_types=up_block_types, + layers_per_block=1, + use_timestep_embedding=True, + out_block_type="OutConv1DBlock", + norm_num_groups=8, + downsample_each_block=False, + in_channels=14, + out_channels=14, + extra_in_channels=0, + time_embedding_type="positional", + flip_sin_to_cos=False, + freq_shift=1, + sample_size=65536, + mid_block_type="MidResTemporalBlock1D", + act_fn="mish", + ) + hf_value_function = UNet1DModel(**config) + print(f"length of state dict: {len(state_dict.keys())}") + print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") + mapping = dict((k, hfk) for k, hfk in zip(model.state_dict().keys(), hf_value_function.state_dict().keys())) + for k, v in mapping.items(): + state_dict[v] = state_dict.pop(k) + hf_value_function.load_state_dict(state_dict) + + torch.save(hf_value_function.state_dict(), 
f"hub/hopper-medium-v2/unet/hor{hor}/diffusion_pytorch_model.bin") + with open(f"hub/hopper-medium-v2/unet/hor{hor}/config.json", "w") as f: + json.dump(config, f) + + +def value_function(): + config = dict( + in_channels=14, + down_block_types=("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + up_block_types=(), + out_block_type="ValueFunction", + mid_block_type="ValueFunctionMidBlock1D", + block_out_channels=(32, 64, 128, 256), + layers_per_block=1, + downsample_each_block=True, + sample_size=65536, + out_channels=14, + extra_in_channels=0, + time_embedding_type="positional", + use_timestep_embedding=True, + flip_sin_to_cos=False, + freq_shift=1, + norm_num_groups=8, + act_fn="mish", + ) + + model = torch.load("/Users/bglickenhaus/Documents/diffuser/value_function-hopper-mediumv2-hor32.torch") + state_dict = model + hf_value_function = UNet1DModel(**config) + print(f"length of state dict: {len(state_dict.keys())}") + print(f"length of value function dict: {len(hf_value_function.state_dict().keys())}") + + mapping = dict((k, hfk) for k, hfk in zip(state_dict.keys(), hf_value_function.state_dict().keys())) + for k, v in mapping.items(): + state_dict[v] = state_dict.pop(k) + + hf_value_function.load_state_dict(state_dict) + + torch.save(hf_value_function.state_dict(), "hub/hopper-medium-v2/value_function/diffusion_pytorch_model.bin") + with open("hub/hopper-medium-v2/value_function/config.json", "w") as f: + json.dump(config, f) + + +if __name__ == "__main__": + unet(32) + # unet(128) + value_function() diff --git a/src/diffusers/experimental/README.md b/src/diffusers/experimental/README.md new file mode 100644 index 0000000000..81a9de81c7 --- /dev/null +++ b/src/diffusers/experimental/README.md @@ -0,0 +1,5 @@ +# ๐Ÿงจ Diffusers Experimental + +We are adding experimental code to support novel applications and usages of the Diffusers library. +Currently, the following experiments are supported: +* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model. \ No newline at end of file diff --git a/src/diffusers/experimental/__init__.py b/src/diffusers/experimental/__init__.py new file mode 100644 index 0000000000..ebc8155403 --- /dev/null +++ b/src/diffusers/experimental/__init__.py @@ -0,0 +1 @@ +from .rl import ValueGuidedRLPipeline diff --git a/src/diffusers/experimental/rl/__init__.py b/src/diffusers/experimental/rl/__init__.py new file mode 100644 index 0000000000..7b338d3173 --- /dev/null +++ b/src/diffusers/experimental/rl/__init__.py @@ -0,0 +1 @@ +from .value_guided_sampling import ValueGuidedRLPipeline diff --git a/src/diffusers/experimental/rl/value_guided_sampling.py b/src/diffusers/experimental/rl/value_guided_sampling.py new file mode 100644 index 0000000000..8d5062e3d4 --- /dev/null +++ b/src/diffusers/experimental/rl/value_guided_sampling.py @@ -0,0 +1,129 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import numpy as np +import torch + +import tqdm + +from ...models.unet_1d import UNet1DModel +from ...pipeline_utils import DiffusionPipeline +from ...utils.dummy_pt_objects import DDPMScheduler + + +class ValueGuidedRLPipeline(DiffusionPipeline): + def __init__( + self, + value_function: UNet1DModel, + unet: UNet1DModel, + scheduler: DDPMScheduler, + env, + ): + super().__init__() + self.value_function = value_function + self.unet = unet + self.scheduler = scheduler + self.env = env + self.data = env.get_dataset() + self.means = dict() + for key in self.data.keys(): + try: + self.means[key] = self.data[key].mean() + except: + pass + self.stds = dict() + for key in self.data.keys(): + try: + self.stds[key] = self.data[key].std() + except: + pass + self.state_dim = env.observation_space.shape[0] + self.action_dim = env.action_space.shape[0] + + def normalize(self, x_in, key): + return (x_in - self.means[key]) / self.stds[key] + + def de_normalize(self, x_in, key): + return x_in * self.stds[key] + self.means[key] + + def to_torch(self, x_in): + if type(x_in) is dict: + return {k: self.to_torch(v) for k, v in x_in.items()} + elif torch.is_tensor(x_in): + return x_in.to(self.unet.device) + return torch.tensor(x_in, device=self.unet.device) + + def reset_x0(self, x_in, cond, act_dim): + for key, val in cond.items(): + x_in[:, key, act_dim:] = val.clone() + return x_in + + def run_diffusion(self, x, conditions, n_guide_steps, scale): + batch_size = x.shape[0] + y = None + for i in tqdm.tqdm(self.scheduler.timesteps): + # create batch of timesteps to pass into model + timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long) + for _ in range(n_guide_steps): + with torch.enable_grad(): + x.requires_grad_() + y = self.value_function(x.permute(0, 2, 1), timesteps).sample + grad = torch.autograd.grad([y.sum()], [x])[0] + + posterior_variance = self.scheduler._get_variance(i) + model_std = torch.exp(0.5 * posterior_variance) + grad = model_std * grad + grad[timesteps < 2] = 0 + x = x.detach() + x = x + scale * grad + x = self.reset_x0(x, conditions, self.action_dim) + prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1) + x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"] + + # apply conditions to the trajectory + x = self.reset_x0(x, conditions, self.action_dim) + x = self.to_torch(x) + return x, y + + def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1): + # normalize the observations and create batch dimension + obs = self.normalize(obs, "observations") + obs = obs[None].repeat(batch_size, axis=0) + + conditions = {0: self.to_torch(obs)} + shape = (batch_size, planning_horizon, self.state_dim + self.action_dim) + + # generate initial noise and apply our conditions (to make the trajectories start at current state) + x1 = torch.randn(shape, device=self.unet.device) + x = self.reset_x0(x1, conditions, self.action_dim) + x = self.to_torch(x) + + # run the diffusion process + x, y = self.run_diffusion(x, conditions, n_guide_steps, scale) + + # sort output trajectories by value + sorted_idx = y.argsort(0, descending=True).squeeze() + sorted_values = x[sorted_idx] + actions = sorted_values[:, :, : self.action_dim] + actions = actions.detach().cpu().numpy() + denorm_actions = self.de_normalize(actions, key="actions") + + # select the action with the highest value + if y is not None: + selected_index = 0 + else: + # if we didn't run value guiding, select a random action + selected_index = 
np.random.randint(0, batch_size) + denorm_actions = denorm_actions[selected_index, 0] + return denorm_actions diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index b09d43fc2e..0221d891f1 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -62,14 +62,21 @@ def get_timestep_embedding( class TimestepEmbedding(nn.Module): - def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"): + def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None): super().__init__() - self.linear_1 = nn.Linear(channel, time_embed_dim) + self.linear_1 = nn.Linear(in_channels, time_embed_dim) self.act = None if act_fn == "silu": self.act = nn.SiLU() - self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim) + elif act_fn == "mish": + self.act = nn.Mish() + + if out_dim is not None: + time_embed_dim_out = out_dim + else: + time_embed_dim_out = time_embed_dim + self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out) def forward(self, sample): sample = self.linear_1(sample) diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py index 7bb5416adf..52d056ae96 100644 --- a/src/diffusers/models/resnet.py +++ b/src/diffusers/models/resnet.py @@ -5,6 +5,75 @@ import torch.nn as nn import torch.nn.functional as F +class Upsample1D(nn.Module): + """ + An upsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + use_conv_transpose: + out_channels: + """ + + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + self.conv = None + if use_conv_transpose: + self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1) + elif use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.use_conv_transpose: + return self.conv(x) + + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + + if self.use_conv: + x = self.conv(x) + + return x + + +class Downsample1D(nn.Module): + """ + A downsampling layer with an optional convolution. + + Parameters: + channels: channels in the inputs and outputs. + use_conv: a bool determining if a convolution is applied. + out_channels: + padding: + """ + + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + assert self.channels == self.out_channels + self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.conv(x) + + class Upsample2D(nn.Module): """ An upsampling layer with an optional convolution. @@ -12,7 +81,8 @@ class Upsample2D(nn.Module): Parameters: channels: channels in the inputs and outputs. use_conv: a bool determining if a convolution is applied. - dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions. 
+ use_conv_transpose: + out_channels: """ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): @@ -80,7 +150,8 @@ class Downsample2D(nn.Module): Parameters: channels: channels in the inputs and outputs. use_conv: a bool determining if a convolution is applied. - dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions. + out_channels: + padding: """ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): @@ -415,6 +486,69 @@ class Mish(torch.nn.Module): return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) +# unet_rl.py +def rearrange_dims(tensor): + if len(tensor.shape) == 2: + return tensor[:, :, None] + if len(tensor.shape) == 3: + return tensor[:, :, None, :] + elif len(tensor.shape) == 4: + return tensor[:, :, 0, :] + else: + raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.") + + +class Conv1dBlock(nn.Module): + """ + Conv1d --> GroupNorm --> Mish + """ + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8): + super().__init__() + + self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2) + self.group_norm = nn.GroupNorm(n_groups, out_channels) + self.mish = nn.Mish() + + def forward(self, x): + x = self.conv1d(x) + x = rearrange_dims(x) + x = self.group_norm(x) + x = rearrange_dims(x) + x = self.mish(x) + return x + + +# unet_rl.py +class ResidualTemporalBlock1D(nn.Module): + def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5): + super().__init__() + self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size) + self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size) + + self.time_emb_act = nn.Mish() + self.time_emb = nn.Linear(embed_dim, out_channels) + + self.residual_conv = ( + nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity() + ) + + def forward(self, x, t): + """ + Args: + x : [ batch_size x inp_channels x horizon ] + t : [ batch_size x embed_dim ] + + returns: + out : [ batch_size x out_channels x horizon ] + """ + t = self.time_emb_act(t) + t = self.time_emb(t) + out = self.conv_in(x) + rearrange_dims(t) + out = self.conv_out(out) + return out + self.residual_conv(x) + + def upsample_2d(hidden_states, kernel=None, factor=2, gain=1): r"""Upsample2D a batch of 2D images with the given filter. Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given diff --git a/src/diffusers/models/unet_1d.py b/src/diffusers/models/unet_1d.py index cc0685deb9..29d1d707f5 100644 --- a/src/diffusers/models/unet_1d.py +++ b/src/diffusers/models/unet_1d.py @@ -1,3 +1,17 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ from dataclasses import dataclass from typing import Optional, Tuple, Union @@ -8,7 +22,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps -from .unet_1d_blocks import get_down_block, get_mid_block, get_up_block +from .unet_1d_blocks import get_down_block, get_mid_block, get_out_block, get_up_block @dataclass @@ -30,11 +44,11 @@ class UNet1DModel(ModelMixin, ConfigMixin): implements for all the model (such as downloading or saving, etc.) Parameters: - sample_size (`int`, *optionl*): Default length of sample. Should be adaptable at runtime. + sample_size (`int`, *optional*): Default length of sample. Should be adaptable at runtime. in_channels (`int`, *optional*, defaults to 2): Number of channels in the input sample. out_channels (`int`, *optional*, defaults to 2): Number of channels in the output. time_embedding_type (`str`, *optional*, defaults to `"fourier"`): Type of time embedding to use. - freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. + freq_shift (`float`, *optional*, defaults to 0.0): Frequency shift for fourier time embedding. flip_sin_to_cos (`bool`, *optional*, defaults to : obj:`False`): Whether to flip sin to cos for fourier time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to : @@ -43,6 +57,13 @@ class UNet1DModel(ModelMixin, ConfigMixin): obj:`("UpBlock1D", "UpBlock1DNoSkip", "AttnUpBlock1D")`): Tuple of upsample block types. block_out_channels (`Tuple[int]`, *optional*, defaults to : obj:`(32, 32, 64)`): Tuple of block output channels. + mid_block_type (`str`, *optional*, defaults to "UNetMidBlock1D"): block type for middle of UNet. + out_block_type (`str`, *optional*, defaults to `None`): optional output processing of UNet. + act_fn (`str`, *optional*, defaults to None): optional activitation function in UNet blocks. + norm_num_groups (`int`, *optional*, defaults to 8): group norm member count in UNet blocks. + layers_per_block (`int`, *optional*, defaults to 1): added number of layers in a UNet block. + downsample_each_block (`int`, *optional*, defaults to False: + experimental feature for using a UNet without upsampling. 
""" @register_to_config @@ -54,16 +75,20 @@ class UNet1DModel(ModelMixin, ConfigMixin): out_channels: int = 2, extra_in_channels: int = 0, time_embedding_type: str = "fourier", - freq_shift: int = 0, flip_sin_to_cos: bool = True, use_timestep_embedding: bool = False, + freq_shift: float = 0.0, down_block_types: Tuple[str] = ("DownBlock1DNoSkip", "DownBlock1D", "AttnDownBlock1D"), - mid_block_type: str = "UNetMidBlock1D", up_block_types: Tuple[str] = ("AttnUpBlock1D", "UpBlock1D", "UpBlock1DNoSkip"), + mid_block_type: Tuple[str] = "UNetMidBlock1D", + out_block_type: str = None, block_out_channels: Tuple[int] = (32, 32, 64), + act_fn: str = None, + norm_num_groups: int = 8, + layers_per_block: int = 1, + downsample_each_block: bool = False, ): super().__init__() - self.sample_size = sample_size # time @@ -73,12 +98,19 @@ class UNet1DModel(ModelMixin, ConfigMixin): ) timestep_input_dim = 2 * block_out_channels[0] elif time_embedding_type == "positional": - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + self.time_proj = Timesteps( + block_out_channels[0], flip_sin_to_cos=flip_sin_to_cos, downscale_freq_shift=freq_shift + ) timestep_input_dim = block_out_channels[0] if use_timestep_embedding: time_embed_dim = block_out_channels[0] * 4 - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + self.time_mlp = TimestepEmbedding( + in_channels=timestep_input_dim, + time_embed_dim=time_embed_dim, + act_fn=act_fn, + out_dim=block_out_channels[0], + ) self.down_blocks = nn.ModuleList([]) self.mid_block = None @@ -94,38 +126,66 @@ class UNet1DModel(ModelMixin, ConfigMixin): if i == 0: input_channel += extra_in_channels + is_final_block = i == len(block_out_channels) - 1 + down_block = get_down_block( down_block_type, + num_layers=layers_per_block, in_channels=input_channel, out_channels=output_channel, + temb_channels=block_out_channels[0], + add_downsample=not is_final_block or downsample_each_block, ) self.down_blocks.append(down_block) # mid self.mid_block = get_mid_block( - mid_block_type=mid_block_type, - mid_channels=block_out_channels[-1], + mid_block_type, in_channels=block_out_channels[-1], - out_channels=None, + mid_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + embed_dim=block_out_channels[0], + num_layers=layers_per_block, + add_downsample=downsample_each_block, ) # up reversed_block_out_channels = list(reversed(block_out_channels)) output_channel = reversed_block_out_channels[0] + if out_block_type is None: + final_upsample_channels = out_channels + else: + final_upsample_channels = block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else out_channels + output_channel = ( + reversed_block_out_channels[i + 1] if i < len(up_block_types) - 1 else final_upsample_channels + ) + + is_final_block = i == len(block_out_channels) - 1 up_block = get_up_block( up_block_type, + num_layers=layers_per_block, in_channels=prev_output_channel, out_channels=output_channel, + temb_channels=block_out_channels[0], + add_upsample=not is_final_block, ) self.up_blocks.append(up_block) prev_output_channel = output_channel - # TODO(PVP, Nathan) placeholder for RL application to be merged shortly - # Totally fine to add another layer with a if statement - no need for nn.Identity here + # out + num_groups_out = norm_num_groups if norm_num_groups is not None else min(block_out_channels[0] // 4, 
32) + self.out_block = get_out_block( + out_block_type=out_block_type, + num_groups_out=num_groups_out, + embed_dim=block_out_channels[0], + out_channels=out_channels, + act_fn=act_fn, + fc_dim=block_out_channels[-1] // 4, + ) def forward( self, @@ -144,12 +204,20 @@ class UNet1DModel(ModelMixin, ConfigMixin): [`~models.unet_1d.UNet1DOutput`] or `tuple`: [`~models.unet_1d.UNet1DOutput`] if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. """ - # 1. time - if len(timestep.shape) == 0: - timestep = timestep[None] - timestep_embed = self.time_proj(timestep)[..., None] - timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + timesteps = torch.tensor([timesteps], dtype=torch.long, device=sample.device) + elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + timestep_embed = self.time_proj(timesteps) + if self.config.use_timestep_embedding: + timestep_embed = self.time_mlp(timestep_embed) + else: + timestep_embed = timestep_embed[..., None] + timestep_embed = timestep_embed.repeat([1, 1, sample.shape[2]]).to(sample.dtype) # 2. down down_block_res_samples = () @@ -158,13 +226,18 @@ class UNet1DModel(ModelMixin, ConfigMixin): down_block_res_samples += res_samples # 3. mid - sample = self.mid_block(sample) + if self.mid_block: + sample = self.mid_block(sample, timestep_embed) # 4. up for i, upsample_block in enumerate(self.up_blocks): res_samples = down_block_res_samples[-1:] down_block_res_samples = down_block_res_samples[:-1] - sample = upsample_block(sample, res_samples) + sample = upsample_block(sample, res_hidden_states_tuple=res_samples, temb=timestep_embed) + + # 5. 
post-process + if self.out_block: + sample = self.out_block(sample, timestep_embed) if not return_dict: return (sample,) diff --git a/src/diffusers/models/unet_1d_blocks.py b/src/diffusers/models/unet_1d_blocks.py index 9009071d1e..fc758ebbb0 100644 --- a/src/diffusers/models/unet_1d_blocks.py +++ b/src/diffusers/models/unet_1d_blocks.py @@ -17,6 +17,256 @@ import torch import torch.nn.functional as F from torch import nn +from .resnet import Downsample1D, ResidualTemporalBlock1D, Upsample1D, rearrange_dims + + +class DownResnetBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + num_layers=1, + conv_shortcut=False, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_downsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.add_downsample = add_downsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True, padding=1) + + def forward(self, hidden_states, temb=None): + output_states = () + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.downsample is not None: + hidden_states = self.downsample(hidden_states) + + return hidden_states, output_states + + +class UpResnetBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels=None, + num_layers=1, + temb_channels=32, + groups=32, + groups_out=None, + non_linearity=None, + time_embedding_norm="default", + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.time_embedding_norm = time_embedding_norm + self.add_upsample = add_upsample + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(2 * in_channels, out_channels, embed_dim=temb_channels)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=temb_channels)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None + + self.upsample = None + if add_upsample: + self.upsample = 
Upsample1D(out_channels, use_conv_transpose=True) + + def forward(self, hidden_states, res_hidden_states_tuple=None, temb=None): + if res_hidden_states_tuple is not None: + res_hidden_states = res_hidden_states_tuple[-1] + hidden_states = torch.cat((hidden_states, res_hidden_states), dim=1) + + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + if self.nonlinearity is not None: + hidden_states = self.nonlinearity(hidden_states) + + if self.upsample is not None: + hidden_states = self.upsample(hidden_states) + + return hidden_states + + +class ValueFunctionMidBlock1D(nn.Module): + def __init__(self, in_channels, out_channels, embed_dim): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.embed_dim = embed_dim + + self.res1 = ResidualTemporalBlock1D(in_channels, in_channels // 2, embed_dim=embed_dim) + self.down1 = Downsample1D(out_channels // 2, use_conv=True) + self.res2 = ResidualTemporalBlock1D(in_channels // 2, in_channels // 4, embed_dim=embed_dim) + self.down2 = Downsample1D(out_channels // 4, use_conv=True) + + def forward(self, x, temb=None): + x = self.res1(x, temb) + x = self.down1(x) + x = self.res2(x, temb) + x = self.down2(x) + return x + + +class MidResTemporalBlock1D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + embed_dim, + num_layers: int = 1, + add_downsample: bool = False, + add_upsample: bool = False, + non_linearity=None, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.add_downsample = add_downsample + + # there will always be at least one resnet + resnets = [ResidualTemporalBlock1D(in_channels, out_channels, embed_dim=embed_dim)] + + for _ in range(num_layers): + resnets.append(ResidualTemporalBlock1D(out_channels, out_channels, embed_dim=embed_dim)) + + self.resnets = nn.ModuleList(resnets) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = nn.Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + else: + self.nonlinearity = None + + self.upsample = None + if add_upsample: + self.upsample = Downsample1D(out_channels, use_conv=True) + + self.downsample = None + if add_downsample: + self.downsample = Downsample1D(out_channels, use_conv=True) + + if self.upsample and self.downsample: + raise ValueError("Block cannot downsample and upsample") + + def forward(self, hidden_states, temb): + hidden_states = self.resnets[0](hidden_states, temb) + for resnet in self.resnets[1:]: + hidden_states = resnet(hidden_states, temb) + + if self.upsample: + hidden_states = self.upsample(hidden_states) + if self.downsample: + self.downsample = self.downsample(hidden_states) + + return hidden_states + + +class OutConv1DBlock(nn.Module): + def __init__(self, num_groups_out, out_channels, embed_dim, act_fn): + super().__init__() + self.final_conv1d_1 = nn.Conv1d(embed_dim, embed_dim, 5, padding=2) + self.final_conv1d_gn = nn.GroupNorm(num_groups_out, embed_dim) + if act_fn == "silu": + self.final_conv1d_act = nn.SiLU() + if act_fn == "mish": + self.final_conv1d_act = nn.Mish() + self.final_conv1d_2 = nn.Conv1d(embed_dim, out_channels, 1) + + def forward(self, hidden_states, temb=None): + hidden_states = self.final_conv1d_1(hidden_states) + hidden_states = rearrange_dims(hidden_states) + hidden_states = self.final_conv1d_gn(hidden_states) + hidden_states = rearrange_dims(hidden_states) + 
hidden_states = self.final_conv1d_act(hidden_states) + hidden_states = self.final_conv1d_2(hidden_states) + return hidden_states + + +class OutValueFunctionBlock(nn.Module): + def __init__(self, fc_dim, embed_dim): + super().__init__() + self.final_block = nn.ModuleList( + [ + nn.Linear(fc_dim + embed_dim, fc_dim // 2), + nn.Mish(), + nn.Linear(fc_dim // 2, 1), + ] + ) + + def forward(self, hidden_states, temb): + hidden_states = hidden_states.view(hidden_states.shape[0], -1) + hidden_states = torch.cat((hidden_states, temb), dim=-1) + for layer in self.final_block: + hidden_states = layer(hidden_states) + + return hidden_states + _kernels = { "linear": [1 / 8, 3 / 8, 3 / 8, 1 / 8], @@ -62,7 +312,7 @@ class Upsample1d(nn.Module): self.pad = kernel_1d.shape[0] // 2 - 1 self.register_buffer("kernel", kernel_1d) - def forward(self, hidden_states): + def forward(self, hidden_states, temb=None): hidden_states = F.pad(hidden_states, ((self.pad + 1) // 2,) * 2, self.pad_mode) weight = hidden_states.new_zeros([hidden_states.shape[1], hidden_states.shape[1], self.kernel.shape[0]]) indices = torch.arange(hidden_states.shape[1], device=hidden_states.device) @@ -162,32 +412,6 @@ class ResConvBlock(nn.Module): return output -def get_down_block(down_block_type, out_channels, in_channels): - if down_block_type == "DownBlock1D": - return DownBlock1D(out_channels=out_channels, in_channels=in_channels) - elif down_block_type == "AttnDownBlock1D": - return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) - elif down_block_type == "DownBlock1DNoSkip": - return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) - raise ValueError(f"{down_block_type} does not exist.") - - -def get_up_block(up_block_type, in_channels, out_channels): - if up_block_type == "UpBlock1D": - return UpBlock1D(in_channels=in_channels, out_channels=out_channels) - elif up_block_type == "AttnUpBlock1D": - return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) - elif up_block_type == "UpBlock1DNoSkip": - return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) - raise ValueError(f"{up_block_type} does not exist.") - - -def get_mid_block(mid_block_type, in_channels, mid_channels, out_channels): - if mid_block_type == "UNetMidBlock1D": - return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) - raise ValueError(f"{mid_block_type} does not exist.") - - class UNetMidBlock1D(nn.Module): def __init__(self, mid_channels, in_channels, out_channels=None): super().__init__() @@ -217,7 +441,7 @@ class UNetMidBlock1D(nn.Module): self.attentions = nn.ModuleList(attentions) self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states): + def forward(self, hidden_states, temb=None): hidden_states = self.down(hidden_states) for attn, resnet in zip(self.attentions, self.resnets): hidden_states = resnet(hidden_states) @@ -322,7 +546,7 @@ class AttnUpBlock1D(nn.Module): self.resnets = nn.ModuleList(resnets) self.up = Upsample1d(kernel="cubic") - def forward(self, hidden_states, res_hidden_states_tuple): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): res_hidden_states = res_hidden_states_tuple[-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) @@ -349,7 +573,7 @@ class UpBlock1D(nn.Module): self.resnets = nn.ModuleList(resnets) self.up = Upsample1d(kernel="cubic") - def forward(self, hidden_states, res_hidden_states_tuple): + def forward(self, hidden_states, res_hidden_states_tuple, 
temb=None): res_hidden_states = res_hidden_states_tuple[-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) @@ -374,7 +598,7 @@ class UpBlock1DNoSkip(nn.Module): self.resnets = nn.ModuleList(resnets) - def forward(self, hidden_states, res_hidden_states_tuple): + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): res_hidden_states = res_hidden_states_tuple[-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) @@ -382,3 +606,63 @@ class UpBlock1DNoSkip(nn.Module): hidden_states = resnet(hidden_states) return hidden_states + + +def get_down_block(down_block_type, num_layers, in_channels, out_channels, temb_channels, add_downsample): + if down_block_type == "DownResnetBlock1D": + return DownResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "DownBlock1D": + return DownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "AttnDownBlock1D": + return AttnDownBlock1D(out_channels=out_channels, in_channels=in_channels) + elif down_block_type == "DownBlock1DNoSkip": + return DownBlock1DNoSkip(out_channels=out_channels, in_channels=in_channels) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block(up_block_type, num_layers, in_channels, out_channels, temb_channels, add_upsample): + if up_block_type == "UpResnetBlock1D": + return UpResnetBlock1D( + in_channels=in_channels, + num_layers=num_layers, + out_channels=out_channels, + temb_channels=temb_channels, + add_upsample=add_upsample, + ) + elif up_block_type == "UpBlock1D": + return UpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "AttnUpBlock1D": + return AttnUpBlock1D(in_channels=in_channels, out_channels=out_channels) + elif up_block_type == "UpBlock1DNoSkip": + return UpBlock1DNoSkip(in_channels=in_channels, out_channels=out_channels) + raise ValueError(f"{up_block_type} does not exist.") + + +def get_mid_block(mid_block_type, num_layers, in_channels, mid_channels, out_channels, embed_dim, add_downsample): + if mid_block_type == "MidResTemporalBlock1D": + return MidResTemporalBlock1D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + embed_dim=embed_dim, + add_downsample=add_downsample, + ) + elif mid_block_type == "ValueFunctionMidBlock1D": + return ValueFunctionMidBlock1D(in_channels=in_channels, out_channels=out_channels, embed_dim=embed_dim) + elif mid_block_type == "UNetMidBlock1D": + return UNetMidBlock1D(in_channels=in_channels, mid_channels=mid_channels, out_channels=out_channels) + raise ValueError(f"{mid_block_type} does not exist.") + + +def get_out_block(*, out_block_type, num_groups_out, embed_dim, out_channels, act_fn, fc_dim): + if out_block_type == "OutConv1DBlock": + return OutConv1DBlock(num_groups_out, out_channels, embed_dim, act_fn) + elif out_block_type == "ValueFunction": + return OutValueFunctionBlock(fc_dim, embed_dim) + return None diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index a19d91879c..c3e373d2bd 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -204,6 +204,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): # for rl-diffuser https://arxiv.org/abs/2205.09991 elif variance_type == "fixed_small_log": variance = torch.log(torch.clamp(variance, min=1e-20)) + variance = 
torch.exp(0.5 * variance) elif variance_type == "fixed_large": variance = self.betas[t] elif variance_type == "fixed_large_log": @@ -301,7 +302,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): variance_noise = torch.randn( model_output.shape, generator=generator, device=device, dtype=model_output.dtype ) - variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise + if self.variance_type == "fixed_small_log": + variance = self._get_variance(t, predicted_variance=predicted_variance) * variance_noise + else: + variance = (self._get_variance(t, predicted_variance=predicted_variance) ** 0.5) * variance_noise pred_prev_sample = pred_prev_sample + variance diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index c274ce4192..41c4fdecfa 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -18,13 +18,120 @@ import unittest import torch from diffusers import UNet1DModel -from diffusers.utils import slow, torch_device +from diffusers.utils import floats_tensor, slow, torch_device + +from ..test_modeling_common import ModelTesterMixin torch.backends.cuda.matmul.allow_tf32 = False -class UnetModel1DTests(unittest.TestCase): +class UNet1DModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet1DModel + + @property + def dummy_input(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 14, 16) + + @property + def output_shape(self): + return (4, 14, 16) + + def test_ema_training(self): + pass + + def test_training(self): + pass + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_determinism(self): + super().test_determinism() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_outputs_equivalence(self): + super().test_outputs_equivalence() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_pretrained_save_pretrained(self): + super().test_from_pretrained_save_pretrained() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_model_from_config(self): + super().test_model_from_config() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_output(self): + super().test_output() + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64, 128, 256), + "in_channels": 14, + "out_channels": 14, + "time_embedding_type": "positional", + "use_timestep_embedding": True, + "flip_sin_to_cos": False, + "freq_shift": 1.0, + "out_block_type": "OutConv1DBlock", + "mid_block_type": "MidResTemporalBlock1D", + "down_block_types": ("DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"), + "up_block_types": ("UpResnetBlock1D", "UpResnetBlock1D", "UpResnetBlock1D"), + "act_fn": "mish", + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_pretrained_hub(self): + model, loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="unet" + ) + self.assertIsNotNone(model) + self.assertEqual(len(loading_info["missing_keys"]), 0) + + 
model.to(torch_device) + image = model(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_output_pretrained(self): + model = UNet1DModel.from_pretrained("bglick13/hopper-medium-v2-value-function-hor32", subfolder="unet") + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = model.in_channels + seq_len = 16 + noise = torch.randn((1, seq_len, num_features)).permute( + 0, 2, 1 + ) # match original, we can update values and remove + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = model(noise, time_step).sample.permute(0, 2, 1) + + output_slice = output[0, -3:, -3:].flatten() + # fmt: off + expected_output_slice = torch.tensor([-2.137172, 1.1426016, 0.3688687, -0.766922, 0.7303146, 0.11038864, -0.4760633, 0.13270172, 0.02591348]) + # fmt: on + self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) + + def test_forward_with_norm_groups(self): + # Not implemented yet for this UNet + pass + @slow def test_unet_1d_maestro(self): model_id = "harmonai/maestro-150k" @@ -43,3 +150,127 @@ class UnetModel1DTests(unittest.TestCase): assert (output_sum - 224.0896).abs() < 4e-2 assert (output_max - 0.0607).abs() < 4e-4 + + +class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet1DModel + + @property + def dummy_input(self): + batch_size = 4 + num_features = 14 + seq_len = 16 + + noise = floats_tensor((batch_size, num_features, seq_len)).to(torch_device) + time_step = torch.tensor([10] * batch_size).to(torch_device) + + return {"sample": noise, "timestep": time_step} + + @property + def input_shape(self): + return (4, 14, 16) + + @property + def output_shape(self): + return (4, 14, 1) + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_determinism(self): + super().test_determinism() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_outputs_equivalence(self): + super().test_outputs_equivalence() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_pretrained_save_pretrained(self): + super().test_from_pretrained_save_pretrained() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_model_from_config(self): + super().test_model_from_config() + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_output(self): + # UNetRL is a value-function is different output shape + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + model.eval() + + with torch.no_grad(): + output = model(**inputs_dict) + + if isinstance(output, dict): + output = output.sample + + self.assertIsNotNone(output) + expected_shape = torch.Size((inputs_dict["sample"].shape[0], 1)) + self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") + + def test_ema_training(self): + pass + + def test_training(self): + pass + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "in_channels": 14, + "out_channels": 14, + "down_block_types": ["DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D", "DownResnetBlock1D"], + "up_block_types": [], + "out_block_type": "ValueFunction", + "mid_block_type": "ValueFunctionMidBlock1D", + "block_out_channels": [32, 64, 128, 256], + "layers_per_block": 1, + 
"downsample_each_block": True, + "use_timestep_embedding": True, + "freq_shift": 1.0, + "flip_sin_to_cos": False, + "time_embedding_type": "positional", + "act_fn": "mish", + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_from_pretrained_hub(self): + value_function, vf_loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" + ) + self.assertIsNotNone(value_function) + self.assertEqual(len(vf_loading_info["missing_keys"]), 0) + + value_function.to(torch_device) + image = value_function(**self.dummy_input) + + assert image is not None, "Make sure output is not None" + + @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") + def test_output_pretrained(self): + value_function, vf_loading_info = UNet1DModel.from_pretrained( + "bglick13/hopper-medium-v2-value-function-hor32", output_loading_info=True, subfolder="value_function" + ) + torch.manual_seed(0) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(0) + + num_features = value_function.in_channels + seq_len = 14 + noise = torch.randn((1, seq_len, num_features)).permute( + 0, 2, 1 + ) # match original, we can update values and remove + time_step = torch.full((num_features,), 0) + + with torch.no_grad(): + output = value_function(noise, time_step).sample + + # fmt: off + expected_output_slice = torch.tensor([165.25] * seq_len) + # fmt: on + self.assertTrue(torch.allclose(output, expected_output_slice, rtol=1e-3)) + + def test_forward_with_norm_groups(self): + # Not implemented yet for this UNet + pass diff --git a/tests/pipelines/dance_diffusion/test_dance_diffusion.py b/tests/pipelines/dance_diffusion/test_dance_diffusion.py index 72e67e4479..a63ef84c63 100644 --- a/tests/pipelines/dance_diffusion/test_dance_diffusion.py +++ b/tests/pipelines/dance_diffusion/test_dance_diffusion.py @@ -44,6 +44,10 @@ class PipelineFastTests(unittest.TestCase): sample_rate=16_000, in_channels=2, out_channels=2, + flip_sin_to_cos=True, + use_timestep_embedding=False, + time_embedding_type="fourier", + mid_block_type="UNetMidBlock1D", down_block_types=["DownBlock1DNoSkip"] + ["DownBlock1D"] + ["AttnDownBlock1D"], up_block_types=["AttnUpBlock1D"] + ["UpBlock1D"] + ["UpBlock1DNoSkip"], ) From 57525bb41879d42fa8a0379b0c800011e5802277 Mon Sep 17 00:00:00 2001 From: Joshua Lochner Date: Mon, 14 Nov 2022 23:54:09 +0200 Subject: [PATCH 2/6] Fix documentation typo for `UNet2DModel` and `UNet2DConditionModel` (#1275) * Fix documentation typo * Fix other typo --- src/diffusers/models/unet_2d.py | 2 +- src/diffusers/models/unet_2d_condition.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 641c253c86..0432405760 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -51,7 +51,7 @@ class UNet2DModel(ModelMixin, ConfigMixin): time_embedding_type (`str`, *optional*, defaults to `"positional"`): Type of time embedding to use. freq_shift (`int`, *optional*, defaults to 0): Frequency shift for fourier time embedding. flip_sin_to_cos (`bool`, *optional*, defaults to : - obj:`False`): Whether to flip sin to cos for fourier time embedding. + obj:`True`): Whether to flip sin to cos for fourier time embedding. 
down_block_types (`Tuple[str]`, *optional*, defaults to : obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block types. diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index 7f7f3ecd44..becae75683 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -60,7 +60,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample. out_channels (`int`, *optional*, defaults to 4): The number of channels in the output. center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): Whether to flip the sin to cos in the time embedding. freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): From 07f9e56d51e994d14e1e0d3330f1ee193df44f0d Mon Sep 17 00:00:00 2001 From: Nan Liu <33531451+nanliu1@users.noreply.github.com> Date: Tue, 15 Nov 2022 03:19:06 -0600 Subject: [PATCH 3/6] add source link to composable diffusion model (#1293) --- examples/community/README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/community/README.md b/examples/community/README.md index fd6fff79c5..5535937dca 100644 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -15,7 +15,7 @@ If a community doesn't work as expected, please open an issue and ping the autho | Long Prompt Weighting Stable Diffusion | **One** Stable Diffusion Pipeline without tokens length limit, and support parsing weighting in prompt. | [Long Prompt Weighting Stable Diffusion](#long-prompt-weighting-stable-diffusion) | - | [SkyTNT](https://github.com/SkyTNT) | | Speech to Image | Using automatic-speech-recognition to transcribe text and Stable Diffusion to generate images | [Speech to Image](#speech-to-image) | - | [Mikail Duzenli](https://github.com/MikailINTech) | Wild Card Stable Diffusion | Stable Diffusion Pipeline that supports prompts that contain wildcard terms (indicated by surrounding double underscores), with values instantiated randomly from a corresponding txt file or a dictionary of possible values | [Wildcard Stable Diffusion](#wildcard-stable-diffusion) | - | [Shyam Sudhakaran](https://github.com/shyamsn97) | -| Composable Stable Diffusion| Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) | +| [Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) | Stable Diffusion Pipeline that supports prompts that contain "|" in prompts (as an AND condition) and weights (separated by "|" as well) to positively / negatively weight prompts. | [Composable Stable Diffusion](#composable-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) | | Seed Resizing Stable Diffusion| Stable Diffusion Pipeline that supports resizing an image and retaining the concepts of the 512 by 512 generation. 
| [Seed Resizing](#seed-resizing) | - | [Mark Rich](https://github.com/MarkRich) | | Imagic Stable Diffusion | Stable Diffusion Pipeline that enables writing a text prompt to edit an existing image| [Imagic Stable Diffusion](#imagic-stable-diffusion) | - | [Mark Rich](https://github.com/MarkRich) | | Multilingual Stable Diffusion| Stable Diffusion Pipeline that supports prompts in 50 different languages. | [Multilingual Stable Diffusion](#multilingual-stable-diffusion-pipeline) | - | [Juan Carlos Piรฑeros](https://github.com/juancopi81) | @@ -345,6 +345,8 @@ out = pipe( ### Composable Stable diffusion +[Composable Stable Diffusion](https://energy-based-model.github.io/Compositional-Visual-Generation-with-Composable-Diffusion-Models/) proposes conjunction and negation (negative prompts) operators for compositional generation with conditional diffusion models. + ```python import torch as th import numpy as np From 610e2a6fd98126b5a2fc9bf223eae2cd7de5b032 Mon Sep 17 00:00:00 2001 From: Dhruv Naik Date: Tue, 15 Nov 2022 04:19:35 -0500 Subject: [PATCH 4/6] Fix incorrect link to Stable Diffusion notebook (#1291) Update README.md --- src/diffusers/pipelines/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/README.md b/src/diffusers/pipelines/README.md index 2941660fa2..6ff40d3549 100644 --- a/src/diffusers/pipelines/README.md +++ b/src/diffusers/pipelines/README.md @@ -40,7 +40,7 @@ available a colab notebook to directly try them out. | [pndm](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | *Unconditional Image Generation* | | [score_sde_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | | [score_sde_vp](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | *Unconditional Image Generation* | -| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb) +| [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-to-Image Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb) | [stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Image-to-Image Text-Guided Generation* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb) | 
[stable_diffusion](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | *Text-Guided Image Inpainting* | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | *Unconditional Image Generation* |

From db1cb0b1a233cc6f4029261a67e503b322f31cd0 Mon Sep 17 00:00:00 2001
From: Glenn 'devalias' Grant
Date: Tue, 15 Nov 2022 22:53:54 +1100
Subject: [PATCH 5/6] [dreambooth] link to bitsandbytes readme for installation (#1229)

* add 'conda install cudatoolkit' to dreambooth 'training on 16GB' example

fixes https://github.com/huggingface/diffusers/issues/1207

* Apply suggestions from code review

Co-authored-by: Suraj Patil
---
 examples/dreambooth/README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index 3c9d04abc2..2339e2979d 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -92,7 +92,7 @@ accelerate launch train_dreambooth.py \
 With the help of gradient checkpointing and the 8-bit optimizer from bitsandbytes it's possible to run DreamBooth training on a 16GB GPU.
 
-Install `bitsandbytes` with `pip install bitsandbytes`
+To install `bitsandbytes`, please refer to this [readme](https://github.com/TimDettmers/bitsandbytes#requirements--installation).
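
For reference, the removed line above and the commit message suggest that on a typical CUDA setup the installation amounts to something like the following sketch; defer to the linked bitsandbytes readme for the steps that match your environment:

```bash
# Example only - follow the bitsandbytes readme for your CUDA version
conda install cudatoolkit
pip install bitsandbytes
```
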
```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" From a0520193e15951655ee2c08c24bfdca716f6f64c Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 15 Nov 2022 18:15:13 +0100 Subject: [PATCH 6/6] Add Scheduler.from_pretrained and better scheduler changing (#1286) * add conversion script for vae * uP * uP * more changes * push * up * finish again * up * up * up * up * finish * up * uP * up * Apply suggestions from code review Co-authored-by: Pedro Cuenca Co-authored-by: Anton Lozhkov Co-authored-by: Suraj Patil * up * up Co-authored-by: Pedro Cuenca Co-authored-by: Anton Lozhkov Co-authored-by: Suraj Patil --- README.md | 10 +- docs/source/_toctree.yml | 2 + docs/source/api/configuration.mdx | 4 +- docs/source/api/pipelines/cycle_diffusion.mdx | 2 +- docs/source/api/pipelines/repaint.mdx | 2 +- .../source/api/pipelines/stable_diffusion.mdx | 12 +- docs/source/quicktour.mdx | 25 +- docs/source/using-diffusers/loading.mdx | 26 +- docs/source/using-diffusers/schedulers.mdx | 262 ++++++++++++++++++ src/diffusers/configuration_utils.py | 222 ++++++++++----- src/diffusers/pipeline_flax_utils.py | 10 +- src/diffusers/pipeline_utils.py | 18 +- src/diffusers/pipelines/ddpm/pipeline_ddpm.py | 2 +- .../pipelines/stable_diffusion/README.md | 6 +- src/diffusers/schedulers/scheduling_ddim.py | 15 +- .../schedulers/scheduling_ddim_flax.py | 13 +- src/diffusers/schedulers/scheduling_ddpm.py | 17 +- .../schedulers/scheduling_ddpm_flax.py | 15 +- .../scheduling_dpmsolver_multistep.py | 14 +- .../scheduling_dpmsolver_multistep_flax.py | 13 +- .../scheduling_euler_ancestral_discrete.py | 15 +- .../schedulers/scheduling_euler_discrete.py | 15 +- src/diffusers/schedulers/scheduling_ipndm.py | 4 +- .../schedulers/scheduling_karras_ve.py | 4 +- .../schedulers/scheduling_karras_ve_flax.py | 4 +- .../schedulers/scheduling_lms_discrete.py | 15 +- .../scheduling_lms_discrete_flax.py | 13 +- src/diffusers/schedulers/scheduling_pndm.py | 14 +- .../schedulers/scheduling_pndm_flax.py | 13 +- .../schedulers/scheduling_repaint.py | 4 +- src/diffusers/schedulers/scheduling_sde_ve.py | 4 +- .../schedulers/scheduling_sde_ve_flax.py | 4 +- src/diffusers/schedulers/scheduling_sde_vp.py | 4 +- src/diffusers/schedulers/scheduling_utils.py | 111 ++++++++ .../schedulers/scheduling_utils_flax.py | 121 +++++++- .../schedulers/scheduling_vq_diffusion.py | 4 +- src/diffusers/utils/__init__.py | 10 + tests/models/test_models_unet_1d.py | 8 +- tests/pipelines/ddim/test_ddim.py | 2 +- tests/pipelines/ddpm/test_ddpm.py | 2 +- tests/pipelines/repaint/test_repaint.py | 2 +- .../score_sde_ve/test_score_sde_ve.py | 2 +- .../stable_diffusion/test_cycle_diffusion.py | 4 +- .../test_onnx_stable_diffusion.py | 4 +- .../test_onnx_stable_diffusion_img2img.py | 2 +- .../test_onnx_stable_diffusion_inpaint.py | 2 +- .../stable_diffusion/test_stable_diffusion.py | 4 +- .../test_stable_diffusion_img2img.py | 6 +- .../test_stable_diffusion_inpaint.py | 4 +- .../test_stable_diffusion_inpaint_legacy.py | 2 +- tests/test_config.py | 131 +-------- tests/test_modeling_common.py | 6 +- tests/test_pipelines.py | 82 +++++- tests/test_scheduler.py | 228 ++++++++++++++- tests/test_scheduler_flax.py | 16 +- 55 files changed, 1149 insertions(+), 407 deletions(-) create mode 100644 docs/source/using-diffusers/schedulers.mdx diff --git a/README.md b/README.md index 64cbd15aab..4a944d0459 100644 --- a/README.md +++ b/README.md @@ -152,15 +152,7 @@ it before the pipeline and pass it to `from_pretrained`. 
```python from diffusers import LMSDiscreteScheduler -lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler") - -pipe = StableDiffusionPipeline.from_pretrained( - "runwayml/stable-diffusion-v1-5", - revision="fp16", - torch_dtype=torch.float16, - scheduler=lms, -) -pipe = pipe.to("cuda") +pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) prompt = "a photo of an astronaut riding a horse on mars" image = pipe(prompt).images[0] diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index d8efb5eee3..0e8dec8167 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -10,6 +10,8 @@ - sections: - local: using-diffusers/loading title: "Loading Pipelines, Models, and Schedulers" + - local: using-diffusers/schedulers + title: "Using different Schedulers" - local: using-diffusers/configuration title: "Configuring Pipelines, Models, and Schedulers" - local: using-diffusers/custom_pipeline_overview diff --git a/docs/source/api/configuration.mdx b/docs/source/api/configuration.mdx index 45176f55b0..423c31f462 100644 --- a/docs/source/api/configuration.mdx +++ b/docs/source/api/configuration.mdx @@ -15,9 +15,9 @@ specific language governing permissions and limitations under the License. In Diffusers, schedulers of type [`schedulers.scheduling_utils.SchedulerMixin`], and models of type [`ModelMixin`] inherit from [`ConfigMixin`] which conveniently takes care of storing all parameters that are passed to the respective `__init__` methods in a JSON-configuration file. -TODO(PVP) - add example and better info here - ## ConfigMixin + [[autodoc]] ConfigMixin + - load_config - from_config - save_config diff --git a/docs/source/api/pipelines/cycle_diffusion.mdx b/docs/source/api/pipelines/cycle_diffusion.mdx index 50d2a5c87e..8eecd3d624 100644 --- a/docs/source/api/pipelines/cycle_diffusion.mdx +++ b/docs/source/api/pipelines/cycle_diffusion.mdx @@ -39,7 +39,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler # load the pipeline # make sure you're logged in with `huggingface-cli login` model_id_or_path = "CompVis/stable-diffusion-v1-4" -scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler") +scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler") pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda") # let's download an initial image diff --git a/docs/source/api/pipelines/repaint.mdx b/docs/source/api/pipelines/repaint.mdx index 0b7de8a457..ce262daffa 100644 --- a/docs/source/api/pipelines/repaint.mdx +++ b/docs/source/api/pipelines/repaint.mdx @@ -54,7 +54,7 @@ original_image = download_image(img_url).resize((256, 256)) mask_image = download_image(mask_url).resize((256, 256)) # Load the RePaint scheduler and pipeline based on a pretrained DDPM model -scheduler = RePaintScheduler.from_config("google/ddpm-ema-celebahq-256") +scheduler = RePaintScheduler.from_pretrained("google/ddpm-ema-celebahq-256") pipe = RePaintPipeline.from_pretrained("google/ddpm-ema-celebahq-256", scheduler=scheduler) pipe = pipe.to("cuda") diff --git a/docs/source/api/pipelines/stable_diffusion.mdx b/docs/source/api/pipelines/stable_diffusion.mdx index 26d6a210ad..1d22024a53 100644 --- a/docs/source/api/pipelines/stable_diffusion.mdx +++ b/docs/source/api/pipelines/stable_diffusion.mdx @@ -34,13 +34,17 @@ For more details about how Stable Diffusion works and how it differs from the ba ### How to load and use different schedulers. 
The stable diffusion pipeline uses [`PNDMScheduler`] scheduler by default. But `diffusers` provides many other schedulers that can be used with the stable diffusion pipeline such as [`DDIMScheduler`], [`LMSDiscreteScheduler`], [`EulerDiscreteScheduler`], [`EulerAncestralDiscreteScheduler`] etc. -To use a different scheduler, you can pass the `scheduler` argument to `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: +To use a different scheduler, you can either change it via the [`ConfigMixin.from_config`] method or pass the `scheduler` argument to the `from_pretrained` method of the pipeline. For example, to use the [`EulerDiscreteScheduler`], you can do the following: ```python -from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler +>>> from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler -euler_scheduler = EulerDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler") -pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler) +>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4") +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +>>> # or +>>> euler_scheduler = EulerDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") +>>> pipeline = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", scheduler=euler_scheduler) ``` diff --git a/docs/source/quicktour.mdx b/docs/source/quicktour.mdx index 463780a072..a50b476c3d 100644 --- a/docs/source/quicktour.mdx +++ b/docs/source/quicktour.mdx @@ -41,7 +41,7 @@ In this guide though, you'll use [`DiffusionPipeline`] for text-to-image generat ```python >>> from diffusers import DiffusionPipeline ->>> generator = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") +>>> pipeline = DiffusionPipeline.from_pretrained("CompVis/ldm-text2im-large-256") ``` The [`DiffusionPipeline`] downloads and caches all modeling, tokenization, and scheduling components. @@ -49,13 +49,13 @@ Because the model consists of roughly 1.4 billion parameters, we strongly recomm You can move the generator object to GPU, just like you would in PyTorch. ```python ->>> generator.to("cuda") +>>> pipeline.to("cuda") ``` -Now you can use the `generator` on your text prompt: +Now you can use the `pipeline` on your text prompt: ```python ->>> image = generator("An image of a squirrel in Picasso style").images[0] +>>> image = pipeline("An image of a squirrel in Picasso style").images[0] ``` The output is by default wrapped into a [PIL Image object](https://pillow.readthedocs.io/en/stable/reference/Image.html?highlight=image#the-image-class). @@ -82,7 +82,7 @@ just like we did before only that now you need to pass your `AUTH_TOKEN`: ```python >>> from diffusers import DiffusionPipeline ->>> generator = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN) +>>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN) ``` If you do not pass your authentication token you will see that the diffusion system will not be correctly @@ -102,7 +102,7 @@ token. 
Assuming that `"./stable-diffusion-v1-5"` is the local path to the cloned you can also load the pipeline as follows: ```python ->>> generator = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5") +>>> pipeline = DiffusionPipeline.from_pretrained("./stable-diffusion-v1-5") ``` Running the pipeline is then identical to the code above as it's the same model architecture. @@ -115,19 +115,20 @@ Running the pipeline is then identical to the code above as it's the same model Diffusion systems can be used with multiple different [schedulers](./api/schedulers) each with their pros and cons. By default, Stable Diffusion runs with [`PNDMScheduler`], but it's very simple to -use a different scheduler. *E.g.* if you would instead like to use the [`LMSDiscreteScheduler`] scheduler, +use a different scheduler. *E.g.* if you would instead like to use the [`EulerDiscreteScheduler`] scheduler, you could use it as follows: ```python ->>> from diffusers import LMSDiscreteScheduler +>>> from diffusers import EulerDiscreteScheduler ->>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler") +>>> pipeline = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", use_auth_token=AUTH_TOKEN) ->>> generator = StableDiffusionPipeline.from_pretrained( -... "runwayml/stable-diffusion-v1-5", scheduler=scheduler, use_auth_token=AUTH_TOKEN -... ) +>>> # change scheduler to Euler +>>> pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) ``` +For more in-detail information on how to change between schedulers, please refer to the [Using Schedulers](./using-diffusers/schedulers) guide. + [Stability AI's](https://stability.ai/) Stable Diffusion model is an impressive image generation model and can do much more than just generating images from text. We have dedicated a whole documentation page, just for Stable Diffusion [here](./conceptual/stable_diffusion). diff --git a/docs/source/using-diffusers/loading.mdx b/docs/source/using-diffusers/loading.mdx index 2cb980ea61..c97ad5c5d0 100644 --- a/docs/source/using-diffusers/loading.mdx +++ b/docs/source/using-diffusers/loading.mdx @@ -19,7 +19,7 @@ In the following we explain in-detail how to easily load: - *Complete Diffusion Pipelines* via the [`DiffusionPipeline.from_pretrained`] - *Diffusion Models* via [`ModelMixin.from_pretrained`] -- *Schedulers* via [`ConfigMixin.from_config`] +- *Schedulers* via [`SchedulerMixin.from_pretrained`] ## Loading pipelines @@ -137,15 +137,15 @@ from diffusers import DiffusionPipeline, EulerDiscreteScheduler, DPMSolverMultis repo_id = "runwayml/stable-diffusion-v1-5" -scheduler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler") +scheduler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") # or -# scheduler = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler") +# scheduler = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler") stable_diffusion = DiffusionPipeline.from_pretrained(repo_id, scheduler=scheduler) ``` Three things are worth paying attention to here. 
-- First, the scheduler is loaded with [`ConfigMixin.from_config`] since it only depends on a configuration file and not any parameterized weights +- First, the scheduler is loaded with [`SchedulerMixin.from_pretrained`] - Second, the scheduler is loaded with a function argument, called `subfolder="scheduler"` as the configuration of stable diffusion's scheduling is defined in a [subfolder of the official pipeline repository](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main/scheduler) - Third, the scheduler instance can simply be passed with the `scheduler` keyword argument to [`DiffusionPipeline.from_pretrained`]. This works because the [`StableDiffusionPipeline`] defines its scheduler with the `scheduler` attribute. It's not possible to use a different name, such as `sampler=scheduler` since `sampler` is not a defined keyword for [`StableDiffusionPipeline.__init__`] @@ -337,8 +337,8 @@ model = UNet2DModel.from_pretrained(repo_id) ## Loading schedulers -Schedulers cannot be loaded via a `from_pretrained` method, but instead rely on [`ConfigMixin.from_config`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file. -Therefore the loading method was given a different name here. +Schedulers rely on [`SchedulerMixin.from_pretrained`]. Schedulers are **not parameterized** or **trained**, but instead purely defined by a configuration file. +For consistency, we use the same method name as we do for models or pipelines, but no weights are loaded in this case. In constrast to pipelines or models, loading schedulers does not consume any significant amount of memory and the same configuration file can often be used for a variety of different schedulers. For example, all of: @@ -367,13 +367,13 @@ from diffusers import ( repo_id = "runwayml/stable-diffusion-v1-5" -ddpm = DDPMScheduler.from_config(repo_id, subfolder="scheduler") -ddim = DDIMScheduler.from_config(repo_id, subfolder="scheduler") -pndm = PNDMScheduler.from_config(repo_id, subfolder="scheduler") -lms = LMSDiscreteScheduler.from_config(repo_id, subfolder="scheduler") -euler_anc = EulerAncestralDiscreteScheduler.from_config(repo_id, subfolder="scheduler") -euler = EulerDiscreteScheduler.from_config(repo_id, subfolder="scheduler") -dpm = DPMSolverMultistepScheduler.from_config(repo_id, subfolder="scheduler") +ddpm = DDPMScheduler.from_pretrained(repo_id, subfolder="scheduler") +ddim = DDIMScheduler.from_pretrained(repo_id, subfolder="scheduler") +pndm = PNDMScheduler.from_pretrained(repo_id, subfolder="scheduler") +lms = LMSDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") +euler_anc = EulerAncestralDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") +euler = EulerDiscreteScheduler.from_pretrained(repo_id, subfolder="scheduler") +dpm = DPMSolverMultistepScheduler.from_pretrained(repo_id, subfolder="scheduler") # replace `dpm` with any of `ddpm`, `ddim`, `pndm`, `lms`, `euler`, `euler_anc` pipeline = StableDiffusionPipeline.from_pretrained(repo_id, scheduler=dpm) diff --git a/docs/source/using-diffusers/schedulers.mdx b/docs/source/using-diffusers/schedulers.mdx new file mode 100644 index 0000000000..87ff789747 --- /dev/null +++ b/docs/source/using-diffusers/schedulers.mdx @@ -0,0 +1,262 @@ + + +# Schedulers + +Diffusion pipelines are inherently a collection of diffusion models and schedulers that are partly independent from each other. This means that one is able to switch out parts of the pipeline to better customize +a pipeline to one's use case. 
The best example of this is the [Schedulers](../api/schedulers.mdx).
+
+Whereas diffusion models usually simply define the forward pass from noise to a less noisy sample,
+schedulers define the whole denoising process, *i.e.*:
+- How many denoising steps?
+- Stochastic or deterministic?
+- What algorithm to use to find the denoised sample?
+
+They can be quite complex and often define a trade-off between **denoising speed** and **denoising quality**.
+It is extremely difficult to measure quantitatively which scheduler works best for a given diffusion pipeline, so it is often recommended to simply try them out and see which works best.
+
+The following paragraphs show how to do so with the 🧨 Diffusers library.
+
+## Load pipeline
+
+Let's start by loading the stable diffusion pipeline.
+Remember that you have to be a registered user on the 🤗 Hugging Face Hub, and have "click-accepted" the [license](https://huggingface.co/runwayml/stable-diffusion-v1-5) in order to use stable diffusion.
+
+```python
+from huggingface_hub import login
+from diffusers import DiffusionPipeline
+import torch
+
+# first we need to login with our access token
+login()
+
+# Now we can download the pipeline
+pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
+```
+
+Next, we move it to GPU:
+
+```python
+pipeline.to("cuda")
+```
+
+## Access the scheduler
+
+The scheduler is always one of the components of the pipeline and is usually called `"scheduler"`.
+So it can be accessed via the `"scheduler"` property.
+
+```python
+pipeline.scheduler
+```
+
+**Output**:
+```
+PNDMScheduler {
+  "_class_name": "PNDMScheduler",
+  "_diffusers_version": "0.8.0.dev0",
+  "beta_end": 0.012,
+  "beta_schedule": "scaled_linear",
+  "beta_start": 0.00085,
+  "clip_sample": false,
+  "num_train_timesteps": 1000,
+  "set_alpha_to_one": false,
+  "skip_prk_steps": true,
+  "steps_offset": 1,
+  "trained_betas": null
+}
+```
+
+We can see that the scheduler is of type [`PNDMScheduler`].
+Cool, now let's compare its performance to that of other schedulers.
+First, we define a prompt on which we will test all the different schedulers:
+
+```python
+prompt = "A photograph of an astronaut riding a horse on Mars, high resolution, high definition."
+```
+
+Next, we create a generator with a fixed random seed so that the results are reproducible, and run the pipeline:
+
+```python
+generator = torch.Generator(device="cuda").manual_seed(8)
+image = pipeline(prompt, generator=generator).images[0]
+image
+```
+
+
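
If you want to keep this first sample around for a side-by-side comparison with the schedulers tried below, you can write it to disk; a minimal sketch (the filename is only illustrative):

```python
# `image` is a PIL image, so it can be saved directly.
# Keep the default (PNDM) result for later comparison.
image.save("astronaut_pndm.png")
```
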

+ + +## Changing the scheduler + +Now we show how easy it is to change the scheduler of a pipeline. Every scheduler has a property [`SchedulerMixin.compatibles`] +which defines all compatible schedulers. You can take a look at all available, compatible schedulers for the Stable Diffusion pipeline as follows. + +```python +pipeline.scheduler.compatibles +``` + +**Output**: +``` +[diffusers.schedulers.scheduling_lms_discrete.LMSDiscreteScheduler, + diffusers.schedulers.scheduling_ddim.DDIMScheduler, + diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler, + diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler, + diffusers.schedulers.scheduling_pndm.PNDMScheduler, + diffusers.schedulers.scheduling_ddpm.DDPMScheduler, + diffusers.schedulers.scheduling_euler_ancestral_discrete.EulerAncestralDiscreteScheduler] +``` + +Cool, lots of schedulers to look at. Feel free to have a look at their respective class definitions: + +- [`LMSDiscreteScheduler`], +- [`DDIMScheduler`], +- [`DPMSolverMultistepScheduler`], +- [`EulerDiscreteScheduler`], +- [`PNDMScheduler`], +- [`DDPMScheduler`], +- [`EulerAncestralDiscreteScheduler`]. + +We will now compare the input prompt with all other schedulers. To change the scheduler of the pipeline you can make use of the +convenient [`ConfigMixin.config`] property in combination with the [`ConfigMixin.from_config`] function. + +```python +pipeline.scheduler.config +``` + +returns a dictionary of the configuration of the scheduler: + +**Output**: +``` +FrozenDict([('num_train_timesteps', 1000), + ('beta_start', 0.00085), + ('beta_end', 0.012), + ('beta_schedule', 'scaled_linear'), + ('trained_betas', None), + ('skip_prk_steps', True), + ('set_alpha_to_one', False), + ('steps_offset', 1), + ('_class_name', 'PNDMScheduler'), + ('_diffusers_version', '0.8.0.dev0'), + ('clip_sample', False)]) +``` + +This configuration can then be used to instantiate a scheduler +of a different class that is compatible with the pipeline. Here, +we change the scheduler to the [`DDIMScheduler`]. + +```python +from diffusers import DDIMScheduler + +pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) +``` + +Cool, now we can run the pipeline again to compare the generation quality. + +```python +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator).images[0] +image +``` + +

+ + +## Compare schedulers + +So far we have tried running the stable diffusion pipeline with two schedulers: [`PNDMScheduler`] and [`DDIMScheduler`]. +A number of better schedulers have been released that can be run with much fewer steps, let's compare them here: + +[`LMSDiscreteScheduler`] usually leads to better results: + +```python +from diffusers import LMSDiscreteScheduler + +pipeline.scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) + +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator).images[0] +image +``` + +

+ + +[`EulerDiscreteScheduler`] and [`EulerAncestralDiscreteScheduler`] can generate high quality results with as little as 30 steps. + +```python +from diffusers import EulerDiscreteScheduler + +pipeline.scheduler = EulerDiscreteScheduler.from_config(pipeline.scheduler.config) + +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0] +image +``` + +

+ + +and: + +```python +from diffusers import EulerAncestralDiscreteScheduler + +pipeline.scheduler = EulerAncestralDiscreteScheduler.from_config(pipeline.scheduler.config) + +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator, num_inference_steps=30).images[0] +image +``` + +

+ + +At the time of writing this doc [`DPMSolverMultistepScheduler`] gives arguably the best speed/quality trade-off and can be run with as little +as 20 steps. + +```python +from diffusers import DPMSolverMultistepScheduler + +pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + +generator = torch.Generator(device="cuda").manual_seed(8) +image = pipeline(prompt, generator=generator, num_inference_steps=20).images[0] +image +``` + +

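
To compare several of the schedulers above systematically, the scheduler swaps shown in this guide can be wrapped in a small loop. Below is a minimal sketch that assumes the `pipeline` and `prompt` defined earlier; the output filenames are only illustrative:

```python
import torch

from diffusers import (
    DDIMScheduler,
    DPMSolverMultistepScheduler,
    EulerAncestralDiscreteScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    PNDMScheduler,
)

# Re-run the same prompt and seed with each compatible scheduler so the
# resulting images can be compared side by side.
schedulers = {
    "pndm": PNDMScheduler,
    "ddim": DDIMScheduler,
    "lms": LMSDiscreteScheduler,
    "euler": EulerDiscreteScheduler,
    "euler_ancestral": EulerAncestralDiscreteScheduler,
    "dpm": DPMSolverMultistepScheduler,
}

for name, scheduler_cls in schedulers.items():
    # Reuse the current scheduler's config so the diffusion hyperparameters stay identical.
    pipeline.scheduler = scheduler_cls.from_config(pipeline.scheduler.config)
    generator = torch.Generator(device="cuda").manual_seed(8)
    # num_inference_steps can be lowered for the faster schedulers, as shown above.
    image = pipeline(prompt, generator=generator).images[0]
    image.save(f"astronaut_{name}.png")
```
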

+ +As you can see most images look very similar and are arguably of very similar quality. It often really depends on the specific use case which scheduler to choose. A good approach is always to run multiple different +schedulers to compare results. diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index fc6ac9b5b9..c4819ddc2e 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -29,7 +29,7 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R from requests import HTTPError from . import __version__ -from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, logging +from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging logger = logging.get_logger(__name__) @@ -37,6 +37,38 @@ logger = logging.get_logger(__name__) _re_configuration_file = re.compile(r"config\.(.*)\.json") +class FrozenDict(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + for key, value in self.items(): + setattr(self, key, value) + + self.__frozen = True + + def __delitem__(self, *args, **kwargs): + raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __setattr__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setattr__(name, value) + + def __setitem__(self, name, value): + if hasattr(self, "__frozen") and self.__frozen: + raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") + super().__setitem__(name, value) + + class ConfigMixin: r""" Base class for all configuration classes. Stores all configuration parameters under `self.config` Also handles all @@ -49,13 +81,12 @@ class ConfigMixin: [`~ConfigMixin.save_config`] (should be overridden by parent class). - **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be overridden by parent class). - - **_compatible_classes** (`List[str]`) -- A list of classes that are compatible with the parent class, so that - `from_config` can be used from a class different than the one used to save the config (should be overridden - by parent class). + - **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by parent + class). """ config_name = None ignore_for_config = [] - _compatible_classes = [] + has_compatibles = False def register_to_config(self, **kwargs): if self.config_name is None: @@ -104,9 +135,98 @@ class ConfigMixin: logger.info(f"Configuration saved in {output_config_file}") @classmethod - def from_config(cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs): + def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs): r""" - Instantiate a Python class from a pre-defined JSON-file. 
+ Instantiate a Python class from a config dictionary + + Parameters: + config (`Dict[str, Any]`): + A config dictionary from which the Python class will be instantiated. Make sure to only load + configuration files of compatible classes. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + + kwargs (remaining dictionary of keyword arguments, *optional*): + Can be used to update the configuration object (after it being loaded) and initiate the Python class. + `**kwargs` will be directly passed to the underlying scheduler/model's `__init__` method and eventually + overwrite same named arguments of `config`. + + Examples: + + ```python + >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler + + >>> # Download scheduler from huggingface.co and cache. + >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32") + + >>> # Instantiate DDIM scheduler class with same config as DDPM + >>> scheduler = DDIMScheduler.from_config(scheduler.config) + + >>> # Instantiate PNDM scheduler class with same config as DDPM + >>> scheduler = PNDMScheduler.from_config(scheduler.config) + ``` + """ + # <===== TO BE REMOVED WITH DEPRECATION + # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated + if "pretrained_model_name_or_path" in kwargs: + config = kwargs.pop("pretrained_model_name_or_path") + + if config is None: + raise ValueError("Please make sure to provide a config as the first positional argument.") + # ======> + + if not isinstance(config, dict): + deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`." + if "Scheduler" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead." + " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will" + " be removed in v1.0.0." + ) + elif "Model" in cls.__name__: + deprecation_message += ( + f"If you were trying to load a model, please use {cls}.load_config(...) followed by" + f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary" + " instead. This functionality will be removed in v1.0.0." + ) + deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False) + config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs) + + init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs) + + # Allow dtype to be specified on initialization + if "dtype" in unused_kwargs: + init_dict["dtype"] = unused_kwargs.pop("dtype") + + # Return model and optionally state and/or unused_kwargs + model = cls(**init_dict) + + # make sure to also save config parameters that might be used for compatible classes + model.register_to_config(**hidden_dict) + + # add hidden kwargs of compatible classes to unused_kwargs + unused_kwargs = {**unused_kwargs, **hidden_dict} + + if return_unused_kwargs: + return (model, unused_kwargs) + else: + return model + + @classmethod + def get_config_dict(cls, *args, **kwargs): + deprecation_message = ( + f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. 
This function will be" + " removed in version v1.0.0" + ) + deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False) + return cls.load_config(*args, **kwargs) + + @classmethod + def load_config( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + r""" + Instantiate a Python class from a config dictionary Parameters: pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): @@ -120,10 +240,6 @@ class ConfigMixin: cache_dir (`Union[str, os.PathLike]`, *optional*): Path to a directory in which a downloaded pretrained model configuration should be cached if the standard cache should not be used. - ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`): - Whether or not to raise an error if some of the weights from the checkpoint do not have the same size - as the weights of the model (if for instance, you are instantiating a model with 10 labels from a - checkpoint with 3 labels). force_download (`bool`, *optional*, defaults to `False`): Whether or not to force the (re-)download of the model weights and configuration files, overriding the cached versions if they exist. @@ -161,33 +277,7 @@ class ConfigMixin: use this method in a firewalled environment. - """ - config_dict = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) - init_dict, unused_kwargs = cls.extract_init_dict(config_dict, **kwargs) - - # Allow dtype to be specified on initialization - if "dtype" in unused_kwargs: - init_dict["dtype"] = unused_kwargs.pop("dtype") - - # Return model and optionally state and/or unused_kwargs - model = cls(**init_dict) - return_tuple = (model,) - - # Flax schedulers have a state, so return it. - if cls.__name__.startswith("Flax") and hasattr(model, "create_state") and getattr(model, "has_state", False): - state = model.create_state() - return_tuple += (state,) - - if return_unused_kwargs: - return return_tuple + (unused_kwargs,) - else: - return return_tuple if len(return_tuple) > 1 else model - - @classmethod - def get_config_dict( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) @@ -283,6 +373,9 @@ class ConfigMixin: except (json.JSONDecodeError, UnicodeDecodeError): raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.") + if return_unused_kwargs: + return config_dict, kwargs + return config_dict @staticmethod @@ -291,6 +384,9 @@ class ConfigMixin: @classmethod def extract_init_dict(cls, config_dict, **kwargs): + # 0. Copy origin config dict + original_dict = {k: v for k, v in config_dict.items()} + # 1. 
Retrieve expected config attributes from __init__ signature expected_keys = cls._get_init_keys(cls) expected_keys.remove("self") @@ -310,10 +406,11 @@ class ConfigMixin: # load diffusers library to import compatible and original scheduler diffusers_library = importlib.import_module(__name__.split(".")[0]) - # remove attributes from compatible classes that orig cannot expect - compatible_classes = [getattr(diffusers_library, c, None) for c in cls._compatible_classes] - # filter out None potentially undefined dummy classes - compatible_classes = [c for c in compatible_classes if c is not None] + if cls.has_compatibles: + compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)] + else: + compatible_classes = [] + expected_keys_comp_cls = set() for c in compatible_classes: expected_keys_c = cls._get_init_keys(c) @@ -364,7 +461,10 @@ class ConfigMixin: # 6. Define unused keyword arguments unused_kwargs = {**config_dict, **kwargs} - return init_dict, unused_kwargs + # 7. Define "hidden" config parameters that were saved for compatible classes + hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict and not k.startswith("_")} + + return init_dict, unused_kwargs, hidden_config_dict @classmethod def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): @@ -377,6 +477,12 @@ class ConfigMixin: @property def config(self) -> Dict[str, Any]: + """ + Returns the config of the class as a frozen dictionary + + Returns: + `Dict[str, Any]`: Config of the class. + """ return self._internal_dict def to_json_string(self) -> str: @@ -401,38 +507,6 @@ class ConfigMixin: writer.write(self.to_json_string()) -class FrozenDict(OrderedDict): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - for key, value in self.items(): - setattr(self, key, value) - - self.__frozen = True - - def __delitem__(self, *args, **kwargs): - raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") - - def setdefault(self, *args, **kwargs): - raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") - - def pop(self, *args, **kwargs): - raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") - - def update(self, *args, **kwargs): - raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") - - def __setattr__(self, name, value): - if hasattr(self, "__frozen") and self.__frozen: - raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") - super().__setattr__(name, value) - - def __setitem__(self, name, value): - if hasattr(self, "__frozen") and self.__frozen: - raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.") - super().__setitem__(name, value) - - def register_to_config(init): r""" Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are diff --git a/src/diffusers/pipeline_flax_utils.py b/src/diffusers/pipeline_flax_utils.py index 4c34e64f78..54bb028139 100644 --- a/src/diffusers/pipeline_flax_utils.py +++ b/src/diffusers/pipeline_flax_utils.py @@ -47,7 +47,7 @@ logger = logging.get_logger(__name__) LOADABLE_CLASSES = { "diffusers": { "FlaxModelMixin": ["save_pretrained", "from_pretrained"], - "FlaxSchedulerMixin": ["save_config", "from_config"], + "FlaxSchedulerMixin": ["save_pretrained", "from_pretrained"], "FlaxDiffusionPipeline": ["save_pretrained", "from_pretrained"], }, "transformers": { @@ 
-280,7 +280,7 @@ class FlaxDiffusionPipeline(ConfigMixin): >>> from diffusers import FlaxDPMSolverMultistepScheduler >>> model_id = "runwayml/stable-diffusion-v1-5" - >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_config( + >>> sched, sched_state = FlaxDPMSolverMultistepScheduler.from_pretrained( ... model_id, ... subfolder="scheduler", ... ) @@ -303,7 +303,7 @@ class FlaxDiffusionPipeline(ConfigMixin): # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): - config_dict = cls.get_config_dict( + config_dict = cls.load_config( pretrained_model_name_or_path, cache_dir=cache_dir, resume_download=resume_download, @@ -349,7 +349,7 @@ class FlaxDiffusionPipeline(ConfigMixin): else: cached_folder = pretrained_model_name_or_path - config_dict = cls.get_config_dict(cached_folder) + config_dict = cls.load_config(cached_folder) # 2. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it @@ -370,7 +370,7 @@ class FlaxDiffusionPipeline(ConfigMixin): expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} - init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) + init_dict, _, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_kwargs = {} diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index a194f3eb34..b4b1b9dd0c 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -65,7 +65,7 @@ logger = logging.get_logger(__name__) LOADABLE_CLASSES = { "diffusers": { "ModelMixin": ["save_pretrained", "from_pretrained"], - "SchedulerMixin": ["save_config", "from_config"], + "SchedulerMixin": ["save_pretrained", "from_pretrained"], "DiffusionPipeline": ["save_pretrained", "from_pretrained"], "OnnxRuntimeModel": ["save_pretrained", "from_pretrained"], }, @@ -207,7 +207,7 @@ class DiffusionPipeline(ConfigMixin): if torch_device is None: return self - module_names, _ = self.extract_init_dict(dict(self.config)) + module_names, _, _ = self.extract_init_dict(dict(self.config)) for name in module_names.keys(): module = getattr(self, name) if isinstance(module, torch.nn.Module): @@ -228,7 +228,7 @@ class DiffusionPipeline(ConfigMixin): Returns: `torch.device`: The torch device on which the pipeline is located. """ - module_names, _ = self.extract_init_dict(dict(self.config)) + module_names, _, _ = self.extract_init_dict(dict(self.config)) for name in module_names.keys(): module = getattr(self, name) if isinstance(module, torch.nn.Module): @@ -377,11 +377,11 @@ class DiffusionPipeline(ConfigMixin): >>> # of the documentation](https://huggingface.co/docs/hub/security-tokens) >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") - >>> # Download pipeline, but overwrite scheduler + >>> # Use a different scheduler >>> from diffusers import LMSDiscreteScheduler - >>> scheduler = LMSDiscreteScheduler.from_config("runwayml/stable-diffusion-v1-5", subfolder="scheduler") - >>> pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler) + >>> scheduler = LMSDiscreteScheduler.from_config(pipeline.scheduler.config) + >>> pipeline.scheduler = scheduler ``` """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) @@ -428,7 +428,7 @@ class DiffusionPipeline(ConfigMixin): # 1. 
Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): - config_dict = cls.get_config_dict( + config_dict = cls.load_config( pretrained_model_name_or_path, cache_dir=cache_dir, resume_download=resume_download, @@ -474,7 +474,7 @@ class DiffusionPipeline(ConfigMixin): else: cached_folder = pretrained_model_name_or_path - config_dict = cls.get_config_dict(cached_folder) + config_dict = cls.load_config(cached_folder) # 2. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it @@ -513,7 +513,7 @@ class DiffusionPipeline(ConfigMixin): expected_modules = set(inspect.signature(pipeline_class.__init__).parameters.keys()) - set(["self"]) passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs} - init_dict, unused_kwargs = pipeline_class.extract_init_dict(config_dict, **kwargs) + init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) if len(unused_kwargs) > 0: logger.warning(f"Keyword arguments {unused_kwargs} not recognized.") diff --git a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py index b7194664f4..c937a23003 100644 --- a/src/diffusers/pipelines/ddpm/pipeline_ddpm.py +++ b/src/diffusers/pipelines/ddpm/pipeline_ddpm.py @@ -71,7 +71,7 @@ class DDPMPipeline(DiffusionPipeline): """ message = ( "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" - " DDPMScheduler.from_config(, predict_epsilon=True)`." + " DDPMScheduler.from_pretrained(, predict_epsilon=True)`." ) predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) diff --git a/src/diffusers/pipelines/stable_diffusion/README.md b/src/diffusers/pipelines/stable_diffusion/README.md index a76e4c6682..bc30be4a7b 100644 --- a/src/diffusers/pipelines/stable_diffusion/README.md +++ b/src/diffusers/pipelines/stable_diffusion/README.md @@ -72,7 +72,7 @@ image.save("astronaut_rides_horse.png") # make sure you're logged in with `huggingface-cli login` from diffusers import StableDiffusionPipeline, DDIMScheduler -scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler") +scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") pipe = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", @@ -91,7 +91,7 @@ image.save("astronaut_rides_horse.png") # make sure you're logged in with `huggingface-cli login` from diffusers import StableDiffusionPipeline, LMSDiscreteScheduler -lms = LMSDiscreteScheduler.from_config("CompVis/stable-diffusion-v1-4", subfolder="scheduler") +lms = LMSDiscreteScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler") pipe = StableDiffusionPipeline.from_pretrained( "runwayml/stable-diffusion-v1-5", @@ -120,7 +120,7 @@ from diffusers import CycleDiffusionPipeline, DDIMScheduler # load the pipeline # make sure you're logged in with `huggingface-cli login` model_id_or_path = "CompVis/stable-diffusion-v1-4" -scheduler = DDIMScheduler.from_config(model_id_or_path, subfolder="scheduler") +scheduler = DDIMScheduler.from_pretrained(model_id_or_path, subfolder="scheduler") pipe = CycleDiffusionPipeline.from_pretrained(model_id_or_path, scheduler=scheduler).to("cuda") # let's download an initial image diff --git a/src/diffusers/schedulers/scheduling_ddim.py 
b/src/diffusers/schedulers/scheduling_ddim.py index 75cef635d0..1326b503ed 100644 --- a/src/diffusers/schedulers/scheduling_ddim.py +++ b/src/diffusers/schedulers/scheduling_ddim.py @@ -23,7 +23,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput from .scheduling_utils import SchedulerMixin @@ -82,8 +82,8 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2010.02502 @@ -109,14 +109,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin): """ - _compatible_classes = [ - "PNDMScheduler", - "DDPMScheduler", - "LMSDiscreteScheduler", - "EulerDiscreteScheduler", - "EulerAncestralDiscreteScheduler", - "DPMSolverMultistepScheduler", - ] + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_ddim_flax.py b/src/diffusers/schedulers/scheduling_ddim_flax.py index 590e3aac2e..ceef96a4a9 100644 --- a/src/diffusers/schedulers/scheduling_ddim_flax.py +++ b/src/diffusers/schedulers/scheduling_ddim_flax.py @@ -23,7 +23,12 @@ import flax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: @@ -79,8 +84,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2010.02502 @@ -105,6 +110,8 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin): stable diffusion. 
""" + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + @property def has_state(self): return True diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py index c3e373d2bd..299a06f4eb 100644 --- a/src/diffusers/schedulers/scheduling_ddpm.py +++ b/src/diffusers/schedulers/scheduling_ddpm.py @@ -22,7 +22,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config -from ..utils import BaseOutput, deprecate +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate from .scheduling_utils import SchedulerMixin @@ -80,8 +80,8 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2006.11239 @@ -104,14 +104,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): """ - _compatible_classes = [ - "DDIMScheduler", - "PNDMScheduler", - "LMSDiscreteScheduler", - "EulerDiscreteScheduler", - "EulerAncestralDiscreteScheduler", - "DPMSolverMultistepScheduler", - ] + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() @register_to_config def __init__( @@ -249,7 +242,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin): """ message = ( "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" - " DDPMScheduler.from_config(, predict_epsilon=True)`." + " DDPMScheduler.from_pretrained(, predict_epsilon=True)`." ) predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon: diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index f1b04a0417..480cbda73c 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -24,7 +24,12 @@ from jax import random from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config from ..utils import deprecate -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> jnp.ndarray: @@ -79,8 +84,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. 
For more details, see the original paper: https://arxiv.org/abs/2006.11239 @@ -103,6 +108,8 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): """ + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + @property def has_state(self): return True @@ -221,7 +228,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): """ message = ( "Please make sure to instantiate your scheduler with `predict_epsilon` instead. E.g. `scheduler =" - " DDPMScheduler.from_config(, predict_epsilon=True)`." + " DDPMScheduler.from_pretrained(, predict_epsilon=True)`." ) predict_epsilon = deprecate("predict_epsilon", "0.10.0", message, take_from=kwargs) if predict_epsilon is not None and predict_epsilon != self.config.predict_epsilon: diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py index d166354809..472b24637d 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py @@ -21,6 +21,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS from .scheduling_utils import SchedulerMixin, SchedulerOutput @@ -71,8 +72,8 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. @@ -116,14 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): """ - _compatible_classes = [ - "DDIMScheduler", - "DDPMScheduler", - "PNDMScheduler", - "LMSDiscreteScheduler", - "EulerDiscreteScheduler", - "EulerAncestralDiscreteScheduler", - ] + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py index c9a6d1cd5c..d6fa383534 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py @@ -23,7 +23,12 @@ import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: @@ -96,8 +101,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2206.00927 and https://arxiv.org/abs/2211.01095 @@ -143,6 +148,8 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin): """ + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + @property def has_state(self): return True diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index 621b5c17c0..f3abf017d9 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -19,7 +19,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging from .scheduling_utils import SchedulerMixin @@ -52,8 +52,8 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. @@ -67,14 +67,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): """ - _compatible_classes = [ - "DDIMScheduler", - "DDPMScheduler", - "LMSDiscreteScheduler", - "PNDMScheduler", - "EulerDiscreteScheduler", - "DPMSolverMultistepScheduler", - ] + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 2f9e938474..d9991bc3a0 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -19,7 +19,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput, logging +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging from .scheduling_utils import SchedulerMixin @@ -53,8 +53,8 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. 
Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. @@ -68,14 +68,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): """ - _compatible_classes = [ - "DDIMScheduler", - "DDPMScheduler", - "LMSDiscreteScheduler", - "PNDMScheduler", - "EulerAncestralDiscreteScheduler", - "DPMSolverMultistepScheduler", - ] + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_ipndm.py b/src/diffusers/schedulers/scheduling_ipndm.py index fb413a2805..e5495713a8 100644 --- a/src/diffusers/schedulers/scheduling_ipndm.py +++ b/src/diffusers/schedulers/scheduling_ipndm.py @@ -28,8 +28,8 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2202.09778 diff --git a/src/diffusers/schedulers/scheduling_karras_ve.py b/src/diffusers/schedulers/scheduling_karras_ve.py index 743f2e061c..b2eb332aed 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve.py +++ b/src/diffusers/schedulers/scheduling_karras_ve.py @@ -56,8 +56,8 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. The grid search values used to find the diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index 78ab007954..c4e612c3cc 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -67,8 +67,8 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details on the parameters, see the original paper's Appendix E.: "Elucidating the Design Space of Diffusion-Based Generative Models." https://arxiv.org/abs/2206.00364. 
The grid search values used to find the diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py index 373c373ee0..8a9aedb41b 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete.py @@ -21,7 +21,7 @@ import torch from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from ..utils import BaseOutput +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput from .scheduling_utils import SchedulerMixin @@ -52,8 +52,8 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. @@ -67,14 +67,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin): """ - _compatible_classes = [ - "DDIMScheduler", - "DDPMScheduler", - "PNDMScheduler", - "EulerDiscreteScheduler", - "EulerAncestralDiscreteScheduler", - "DPMSolverMultistepScheduler", - ] + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py index 20982d38aa..21f25f72fa 100644 --- a/src/diffusers/schedulers/scheduling_lms_discrete_flax.py +++ b/src/diffusers/schedulers/scheduling_lms_discrete_flax.py @@ -20,7 +20,12 @@ import jax.numpy as jnp from scipy import integrate from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) @flax.struct.dataclass @@ -49,8 +54,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. @@ -63,6 +68,8 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin): option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. 
""" + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + @property def has_state(self): return True diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py index eec18af8d3..8bf0a59582 100644 --- a/src/diffusers/schedulers/scheduling_pndm.py +++ b/src/diffusers/schedulers/scheduling_pndm.py @@ -21,6 +21,7 @@ import numpy as np import torch from ..configuration_utils import ConfigMixin, register_to_config +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS from .scheduling_utils import SchedulerMixin, SchedulerOutput @@ -60,8 +61,8 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2202.09778 @@ -88,14 +89,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin): """ - _compatible_classes = [ - "DDIMScheduler", - "DDPMScheduler", - "LMSDiscreteScheduler", - "EulerDiscreteScheduler", - "EulerAncestralDiscreteScheduler", - "DPMSolverMultistepScheduler", - ] + _compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() @register_to_config def __init__( diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index 357ecfe046..298e62de20 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -23,7 +23,12 @@ import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config -from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left +from .scheduling_utils_flax import ( + _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, + FlaxSchedulerMixin, + FlaxSchedulerOutput, + broadcast_to_shape_from_left, +) def betas_for_alpha_bar(num_diffusion_timesteps: int, max_beta=0.999) -> jnp.ndarray: @@ -87,8 +92,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/abs/2202.09778 @@ -114,6 +119,8 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin): stable diffusion. 
""" + _compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy() + @property def has_state(self): return True diff --git a/src/diffusers/schedulers/scheduling_repaint.py b/src/diffusers/schedulers/scheduling_repaint.py index 1751f41c66..55625c1bfa 100644 --- a/src/diffusers/schedulers/scheduling_repaint.py +++ b/src/diffusers/schedulers/scheduling_repaint.py @@ -77,8 +77,8 @@ class RePaintScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more details, see the original paper: https://arxiv.org/pdf/2201.09865.pdf diff --git a/src/diffusers/schedulers/scheduling_sde_ve.py b/src/diffusers/schedulers/scheduling_sde_ve.py index d31adbc3c6..1d436ab0cb 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve.py +++ b/src/diffusers/schedulers/scheduling_sde_ve.py @@ -50,8 +50,8 @@ class ScoreSdeVeScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index d3eadede61..d1f762bc90 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -64,8 +64,8 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. Args: num_train_timesteps (`int`): number of diffusion steps used to train the model. diff --git a/src/diffusers/schedulers/scheduling_sde_vp.py b/src/diffusers/schedulers/scheduling_sde_vp.py index a37a159a87..537d6f7e2a 100644 --- a/src/diffusers/schedulers/scheduling_sde_vp.py +++ b/src/diffusers/schedulers/scheduling_sde_vp.py @@ -29,8 +29,8 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. 
- [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. For more information, see the original paper: https://arxiv.org/abs/2011.13456 diff --git a/src/diffusers/schedulers/scheduling_utils.py b/src/diffusers/schedulers/scheduling_utils.py index 29bf982f6a..90ab674e38 100644 --- a/src/diffusers/schedulers/scheduling_utils.py +++ b/src/diffusers/schedulers/scheduling_utils.py @@ -11,7 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib +import os from dataclasses import dataclass +from typing import Any, Dict, Optional, Union import torch @@ -38,6 +41,114 @@ class SchedulerOutput(BaseOutput): class SchedulerMixin: """ Mixin containing common functions for the schedulers. + + Class attributes: + - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that + `from_config` can be used from a class different than the one used to save the config (should be overridden + by parent class). """ config_name = SCHEDULER_CONFIG_NAME + _compatibles = [] + has_compatibles = True + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Dict[str, Any] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + r""" + Instantiate a Scheduler class from a pre-defined JSON configuration file inside a directory or Hub repo. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing the schedluer configurations saved using + [`~SchedulerMixin.save_pretrained`], e.g., `./my_model_directory/`. + subfolder (`str`, *optional*): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. 
+ local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + """ + config, kwargs = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, + subfolder=subfolder, + return_unused_kwargs=True, + **kwargs, + ) + return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~SchedulerMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes diff --git a/src/diffusers/schedulers/scheduling_utils_flax.py b/src/diffusers/schedulers/scheduling_utils_flax.py index e545cfe247..b3024ca450 100644 --- a/src/diffusers/schedulers/scheduling_utils_flax.py +++ b/src/diffusers/schedulers/scheduling_utils_flax.py @@ -11,15 +11,18 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib +import os from dataclasses import dataclass -from typing import Tuple +from typing import Any, Dict, Optional, Tuple, Union import jax.numpy as jnp -from ..utils import BaseOutput +from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput SCHEDULER_CONFIG_NAME = "scheduler_config.json" +_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS] @dataclass @@ -39,9 +42,123 @@ class FlaxSchedulerOutput(BaseOutput): class FlaxSchedulerMixin: """ Mixin containing common functions for the schedulers. 
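The block above is the new PyTorch-side `SchedulerMixin` surface: `from_pretrained` is a thin wrapper around `load_config` followed by `from_config`, `save_pretrained` delegates to `save_config`, and the `compatibles` property resolves the `_compatibles` names against whatever classes the installed `diffusers` module exposes. A short sketch of how the pieces fit together (the local directory name is illustrative):

from diffusers import DDPMScheduler

# Loads scheduler_config.json from the repo root, as in the updated tests.
scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")

# Round trip through the new save/load pair; `save_pretrained` writes the
# config via `save_config`, `from_pretrained` reads it back via `load_config`.
scheduler.save_pretrained("./ddpm-scheduler")
reloaded = DDPMScheduler.from_pretrained("./ddpm-scheduler")

# `_compatibles` (the shared stable-diffusion list plus the class itself) is
# mapped to importable classes by `_get_compatibles`.
print([cls.__name__ for cls in reloaded.compatibles])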
+ + Class attributes: + - **_compatibles** (`List[str]`) -- A list of classes that are compatible with the parent class, so that + `from_config` can be used from a class different than the one used to save the config (should be overridden + by parent class). """ config_name = SCHEDULER_CONFIG_NAME + _compatibles = [] + has_compatibles = True + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Dict[str, Any] = None, + subfolder: Optional[str] = None, + return_unused_kwargs=False, + **kwargs, + ): + r""" + Instantiate a Scheduler class from a pre-defined JSON-file. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a model repo on huggingface.co. Valid model ids should have an + organization name, like `google/ddpm-celebahq-256`. + - A path to a *directory* containing model weights saved using [`~SchedulerMixin.save_pretrained`], + e.g., `./my_model_directory/`. + subfolder (`str`, *optional*): + In case the relevant files are located inside a subfolder of the model repo (either remote in + huggingface.co or downloaded locally), you can specify the folder name here. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + Whether kwargs that are not consumed by the Python class should be returned or not. + + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force the (re-)download of the model weights and configuration files, overriding the + cached versions if they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received files. Will attempt to resume the download if such a + file exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. + output_loading_info(`bool`, *optional*, defaults to `False`): + Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. + local_files_only(`bool`, *optional*, defaults to `False`): + Whether or not to only look at local files (i.e., do not try to download the model). + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `transformers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated + models](https://huggingface.co/docs/hub/models-gated#gated-models). + + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. 
+ + + + """ + config, kwargs = cls.load_config( + pretrained_model_name_or_path=pretrained_model_name_or_path, return_unused_kwargs=True, **kwargs + ) + scheduler, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **kwargs) + + if hasattr(scheduler, "create_state") and getattr(scheduler, "has_state", False): + state = scheduler.create_state() + + if return_unused_kwargs: + return scheduler, state, unused_kwargs + + return scheduler, state + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a scheduler configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~FlaxSchedulerMixin.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + """ + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def compatibles(self): + """ + Returns all schedulers that are compatible with this scheduler + + Returns: + `List[SchedulerMixin]`: List of compatible schedulers + """ + return self._get_compatibles() + + @classmethod + def _get_compatibles(cls): + compatible_classes_str = list(set([cls.__name__] + cls._compatibles)) + diffusers_library = importlib.import_module(__name__.split(".")[0]) + compatible_classes = [ + getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c) + ] + return compatible_classes def broadcast_to_shape_from_left(x: jnp.ndarray, shape: Tuple[int]) -> jnp.ndarray: diff --git a/src/diffusers/schedulers/scheduling_vq_diffusion.py b/src/diffusers/schedulers/scheduling_vq_diffusion.py index dbe320d998..91c46e6554 100644 --- a/src/diffusers/schedulers/scheduling_vq_diffusion.py +++ b/src/diffusers/schedulers/scheduling_vq_diffusion.py @@ -112,8 +112,8 @@ class VQDiffusionScheduler(SchedulerMixin, ConfigMixin): [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. 
For more details, see the original paper: https://arxiv.org/abs/2111.14822 diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index a00e1f4dcd..a80d249498 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -72,3 +72,13 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co" DIFFUSERS_CACHE = default_cache_path DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules" HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) + +_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [ + "DDIMScheduler", + "DDPMScheduler", + "PNDMScheduler", + "LMSDiscreteScheduler", + "EulerDiscreteScheduler", + "EulerAncestralDiscreteScheduler", + "DPMSolverMultistepScheduler", +] diff --git a/tests/models/test_models_unet_1d.py b/tests/models/test_models_unet_1d.py index 41c4fdecfa..089d935651 100644 --- a/tests/models/test_models_unet_1d.py +++ b/tests/models/test_models_unet_1d.py @@ -67,8 +67,8 @@ class UNet1DModelTests(ModelTesterMixin, unittest.TestCase): super().test_from_pretrained_save_pretrained() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") - def test_model_from_config(self): - super().test_model_from_config() + def test_model_from_pretrained(self): + super().test_model_from_pretrained() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_output(self): @@ -187,8 +187,8 @@ class UNetRLModelTests(ModelTesterMixin, unittest.TestCase): super().test_from_pretrained_save_pretrained() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") - def test_model_from_config(self): - super().test_model_from_config() + def test_model_from_pretrained(self): + super().test_model_from_pretrained() @unittest.skipIf(torch_device == "mps", "mish op not supported in MPS") def test_output(self): diff --git a/tests/pipelines/ddim/test_ddim.py b/tests/pipelines/ddim/test_ddim.py index 81c49912be..2d03383599 100644 --- a/tests/pipelines/ddim/test_ddim.py +++ b/tests/pipelines/ddim/test_ddim.py @@ -75,7 +75,7 @@ class DDIMPipelineIntegrationTests(unittest.TestCase): model_id = "google/ddpm-ema-bedroom-256" unet = UNet2DModel.from_pretrained(model_id) - scheduler = DDIMScheduler.from_config(model_id) + scheduler = DDIMScheduler.from_pretrained(model_id) ddpm = DDIMPipeline(unet=unet, scheduler=scheduler) ddpm.to(torch_device) diff --git a/tests/pipelines/ddpm/test_ddpm.py b/tests/pipelines/ddpm/test_ddpm.py index 14bc094697..ef293109bf 100644 --- a/tests/pipelines/ddpm/test_ddpm.py +++ b/tests/pipelines/ddpm/test_ddpm.py @@ -106,7 +106,7 @@ class DDPMPipelineIntegrationTests(unittest.TestCase): model_id = "google/ddpm-cifar10-32" unet = UNet2DModel.from_pretrained(model_id) - scheduler = DDPMScheduler.from_config(model_id) + scheduler = DDPMScheduler.from_pretrained(model_id) ddpm = DDPMPipeline(unet=unet, scheduler=scheduler) ddpm.to(torch_device) diff --git a/tests/pipelines/repaint/test_repaint.py b/tests/pipelines/repaint/test_repaint.py index 23544dfd24..3ab0efc875 100644 --- a/tests/pipelines/repaint/test_repaint.py +++ b/tests/pipelines/repaint/test_repaint.py @@ -44,7 +44,7 @@ class RepaintPipelineIntegrationTests(unittest.TestCase): model_id = "google/ddpm-ema-celebahq-256" unet = UNet2DModel.from_pretrained(model_id) - scheduler = RePaintScheduler.from_config(model_id) + scheduler = RePaintScheduler.from_pretrained(model_id) repaint = RePaintPipeline(unet=unet, scheduler=scheduler).to(torch_device) diff --git a/tests/pipelines/score_sde_ve/test_score_sde_ve.py 
b/tests/pipelines/score_sde_ve/test_score_sde_ve.py index 55dcc1cea1..9cdf3f0191 100644 --- a/tests/pipelines/score_sde_ve/test_score_sde_ve.py +++ b/tests/pipelines/score_sde_ve/test_score_sde_ve.py @@ -74,7 +74,7 @@ class ScoreSdeVePipelineIntegrationTests(unittest.TestCase): model_id = "google/ncsnpp-church-256" model = UNet2DModel.from_pretrained(model_id) - scheduler = ScoreSdeVeScheduler.from_config(model_id) + scheduler = ScoreSdeVeScheduler.from_pretrained(model_id) sde_ve = ScoreSdeVePipeline(unet=model, scheduler=scheduler) sde_ve.to(torch_device) diff --git a/tests/pipelines/stable_diffusion/test_cycle_diffusion.py b/tests/pipelines/stable_diffusion/test_cycle_diffusion.py index de918c7e5c..7a32b74096 100644 --- a/tests/pipelines/stable_diffusion/test_cycle_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_cycle_diffusion.py @@ -281,7 +281,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase): init_image = init_image.resize((512, 512)) model_id = "CompVis/stable-diffusion-v1-4" - scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler") + scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = CycleDiffusionPipeline.from_pretrained( model_id, scheduler=scheduler, safety_checker=None, torch_dtype=torch.float16, revision="fp16" ) @@ -322,7 +322,7 @@ class CycleDiffusionPipelineIntegrationTests(unittest.TestCase): init_image = init_image.resize((512, 512)) model_id = "CompVis/stable-diffusion-v1-4" - scheduler = DDIMScheduler.from_config(model_id, subfolder="scheduler") + scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = CycleDiffusionPipeline.from_pretrained(model_id, scheduler=scheduler, safety_checker=None) pipe.to(torch_device) diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py index a1946e39f9..a2b48d27e6 100644 --- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion.py @@ -75,7 +75,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 def test_inference_ddim(self): - ddim_scheduler = DDIMScheduler.from_config( + ddim_scheduler = DDIMScheduler.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" ) sd_pipe = OnnxStableDiffusionPipeline.from_pretrained( @@ -98,7 +98,7 @@ class OnnxStableDiffusionPipelineIntegrationTests(unittest.TestCase): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 def test_inference_k_lms(self): - lms_scheduler = LMSDiscreteScheduler.from_config( + lms_scheduler = LMSDiscreteScheduler.from_pretrained( "runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" ) sd_pipe = OnnxStableDiffusionPipeline.from_pretrained( diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py index 61831c64c0..91e4412425 100644 --- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_img2img.py @@ -93,7 +93,7 @@ class OnnxStableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): "/img2img/sketch-mountains-input.jpg" ) init_image = init_image.resize((768, 512)) - lms_scheduler = LMSDiscreteScheduler.from_config( + lms_scheduler = LMSDiscreteScheduler.from_pretrained( 
"runwayml/stable-diffusion-v1-5", subfolder="scheduler", revision="onnx" ) pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained( diff --git a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py index 4ba8e273b4..507375bddb 100644 --- a/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_onnx_stable_diffusion_inpaint.py @@ -97,7 +97,7 @@ class OnnxStableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main" "/in_paint/overture-creations-5sI6fQgYIuo_mask.png" ) - lms_scheduler = LMSDiscreteScheduler.from_config( + lms_scheduler = LMSDiscreteScheduler.from_pretrained( "runwayml/stable-diffusion-inpainting", subfolder="scheduler", revision="onnx" ) pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained( diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py index 87d238c869..17a293e605 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py @@ -703,7 +703,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 def test_stable_diffusion_fast_ddim(self): - scheduler = DDIMScheduler.from_config("CompVis/stable-diffusion-v1-1", subfolder="scheduler") + scheduler = DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-1", subfolder="scheduler") sd_pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-1", scheduler=scheduler) sd_pipe = sd_pipe.to(torch_device) @@ -726,7 +726,7 @@ class StableDiffusionPipelineIntegrationTests(unittest.TestCase): model_id = "CompVis/stable-diffusion-v1-1" pipe = StableDiffusionPipeline.from_pretrained(model_id).to(torch_device) pipe.set_progress_bar_config(disable=None) - scheduler = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") + scheduler = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") pipe.scheduler = scheduler prompt = "a photograph of an astronaut riding a horse" diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index 3c0fa8aa81..d86b259eae 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -520,7 +520,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ) model_id = "CompVis/stable-diffusion-v1-4" - lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") + lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = StableDiffusionImg2ImgPipeline.from_pretrained( model_id, scheduler=lms, @@ -557,7 +557,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): ) model_id = "CompVis/stable-diffusion-v1-4" - ddim = DDIMScheduler.from_config(model_id, subfolder="scheduler") + ddim = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = StableDiffusionImg2ImgPipeline.from_pretrained( model_id, scheduler=ddim, @@ -649,7 +649,7 @@ class StableDiffusionImg2ImgPipelineIntegrationTests(unittest.TestCase): init_image = init_image.resize((768, 512)) model_id = "CompVis/stable-diffusion-v1-4" - lms = 
LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") + lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = StableDiffusionImg2ImgPipeline.from_pretrained( model_id, scheduler=lms, safety_checker=None, device_map="auto", revision="fp16", torch_dtype=torch.float16 ) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 8d269c38f9..ce231a1a46 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -400,7 +400,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): ) model_id = "runwayml/stable-diffusion-inpainting" - pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler") + pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = StableDiffusionInpaintPipeline.from_pretrained(model_id, safety_checker=None, scheduler=pndm) pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) @@ -437,7 +437,7 @@ class StableDiffusionInpaintPipelineIntegrationTests(unittest.TestCase): ) model_id = "runwayml/stable-diffusion-inpainting" - pndm = PNDMScheduler.from_config(model_id, subfolder="scheduler") + pndm = PNDMScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = StableDiffusionInpaintPipeline.from_pretrained( model_id, safety_checker=None, diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py index 4b535dc9df..94106b6ba8 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint_legacy.py @@ -401,7 +401,7 @@ class StableDiffusionInpaintLegacyPipelineIntegrationTests(unittest.TestCase): ) model_id = "CompVis/stable-diffusion-v1-4" - lms = LMSDiscreteScheduler.from_config(model_id, subfolder="scheduler") + lms = LMSDiscreteScheduler.from_pretrained(model_id, subfolder="scheduler") pipe = StableDiffusionInpaintPipeline.from_pretrained( model_id, scheduler=lms, diff --git a/tests/test_config.py b/tests/test_config.py index 8ae8e1d9e1..0875930e37 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -13,12 +13,9 @@ # See the License for the specific language governing permissions and # limitations under the License. 
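The `tests/test_config.py` changes here track the same split on the `ConfigMixin` side: `load_config` fetches a plain config dict from a path or repo, and `from_config` now instantiates from that dict rather than from a path. In sketch form, using the tiny test checkpoint exercised further down in this file:

from diffusers import DDIMScheduler

config = DDIMScheduler.load_config(
    "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler"
)
scheduler = DDIMScheduler.from_config(config)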
-import json -import os import tempfile import unittest -import diffusers from diffusers import ( DDIMScheduler, DDPMScheduler, @@ -81,7 +78,7 @@ class SampleObject3(ConfigMixin): class ConfigTester(unittest.TestCase): def test_load_not_from_mixin(self): with self.assertRaises(ValueError): - ConfigMixin.from_config("dummy_path") + ConfigMixin.load_config("dummy_path") def test_register_to_config(self): obj = SampleObject() @@ -131,7 +128,7 @@ class ConfigTester(unittest.TestCase): with tempfile.TemporaryDirectory() as tmpdirname: obj.save_config(tmpdirname) - new_obj = SampleObject.from_config(tmpdirname) + new_obj = SampleObject.from_config(SampleObject.load_config(tmpdirname)) new_config = new_obj.config # unfreeze configs @@ -142,117 +139,13 @@ class ConfigTester(unittest.TestCase): assert new_config.pop("c") == [2, 5] # saved & loaded as list because of json assert config == new_config - def test_save_load_from_different_config(self): - obj = SampleObject() - - # mock add obj class to `diffusers` - setattr(diffusers, "SampleObject", SampleObject) - logger = logging.get_logger("diffusers.configuration_utils") - - with tempfile.TemporaryDirectory() as tmpdirname: - obj.save_config(tmpdirname) - with CaptureLogger(logger) as cap_logger_1: - new_obj_1 = SampleObject2.from_config(tmpdirname) - - # now save a config parameter that is not expected - with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f: - data = json.load(f) - data["unexpected"] = True - - with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f: - json.dump(data, f) - - with CaptureLogger(logger) as cap_logger_2: - new_obj_2 = SampleObject.from_config(tmpdirname) - - with CaptureLogger(logger) as cap_logger_3: - new_obj_3 = SampleObject2.from_config(tmpdirname) - - assert new_obj_1.__class__ == SampleObject2 - assert new_obj_2.__class__ == SampleObject - assert new_obj_3.__class__ == SampleObject2 - - assert cap_logger_1.out == "" - assert ( - cap_logger_2.out - == "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will" - " be ignored. Please verify your config.json configuration file.\n" - ) - assert cap_logger_2.out.replace("SampleObject", "SampleObject2") == cap_logger_3.out - - def test_save_load_compatible_schedulers(self): - SampleObject2._compatible_classes = ["SampleObject"] - SampleObject._compatible_classes = ["SampleObject2"] - - obj = SampleObject() - - # mock add obj class to `diffusers` - setattr(diffusers, "SampleObject", SampleObject) - setattr(diffusers, "SampleObject2", SampleObject2) - logger = logging.get_logger("diffusers.configuration_utils") - - with tempfile.TemporaryDirectory() as tmpdirname: - obj.save_config(tmpdirname) - - # now save a config parameter that is expected by another class, but not origin class - with open(os.path.join(tmpdirname, SampleObject.config_name), "r") as f: - data = json.load(f) - data["f"] = [0, 0] - data["unexpected"] = True - - with open(os.path.join(tmpdirname, SampleObject.config_name), "w") as f: - json.dump(data, f) - - with CaptureLogger(logger) as cap_logger: - new_obj = SampleObject.from_config(tmpdirname) - - assert new_obj.__class__ == SampleObject - - assert ( - cap_logger.out - == "The config attributes {'unexpected': True} were passed to SampleObject, but are not expected and will" - " be ignored. 
Please verify your config.json configuration file.\n" - ) - - def test_save_load_from_different_config_comp_schedulers(self): - SampleObject3._compatible_classes = ["SampleObject", "SampleObject2"] - SampleObject2._compatible_classes = ["SampleObject", "SampleObject3"] - SampleObject._compatible_classes = ["SampleObject2", "SampleObject3"] - - obj = SampleObject() - - # mock add obj class to `diffusers` - setattr(diffusers, "SampleObject", SampleObject) - setattr(diffusers, "SampleObject2", SampleObject2) - setattr(diffusers, "SampleObject3", SampleObject3) - logger = logging.get_logger("diffusers.configuration_utils") - logger.setLevel(diffusers.logging.INFO) - - with tempfile.TemporaryDirectory() as tmpdirname: - obj.save_config(tmpdirname) - - with CaptureLogger(logger) as cap_logger_1: - new_obj_1 = SampleObject.from_config(tmpdirname) - - with CaptureLogger(logger) as cap_logger_2: - new_obj_2 = SampleObject2.from_config(tmpdirname) - - with CaptureLogger(logger) as cap_logger_3: - new_obj_3 = SampleObject3.from_config(tmpdirname) - - assert new_obj_1.__class__ == SampleObject - assert new_obj_2.__class__ == SampleObject2 - assert new_obj_3.__class__ == SampleObject3 - - assert cap_logger_1.out == "" - assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n" - assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n" - def test_load_ddim_from_pndm(self): logger = logging.get_logger("diffusers.configuration_utils") with CaptureLogger(logger) as cap_logger: - ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") + ddim = DDIMScheduler.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" + ) assert ddim.__class__ == DDIMScheduler # no warning should be thrown @@ -262,7 +155,7 @@ class ConfigTester(unittest.TestCase): logger = logging.get_logger("diffusers.configuration_utils") with CaptureLogger(logger) as cap_logger: - euler = EulerDiscreteScheduler.from_config( + euler = EulerDiscreteScheduler.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" ) @@ -274,7 +167,7 @@ class ConfigTester(unittest.TestCase): logger = logging.get_logger("diffusers.configuration_utils") with CaptureLogger(logger) as cap_logger: - euler = EulerAncestralDiscreteScheduler.from_config( + euler = EulerAncestralDiscreteScheduler.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" ) @@ -286,7 +179,9 @@ class ConfigTester(unittest.TestCase): logger = logging.get_logger("diffusers.configuration_utils") with CaptureLogger(logger) as cap_logger: - pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") + pndm = PNDMScheduler.from_pretrained( + "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" + ) assert pndm.__class__ == PNDMScheduler # no warning should be thrown @@ -296,7 +191,7 @@ class ConfigTester(unittest.TestCase): logger = logging.get_logger("diffusers.configuration_utils") with CaptureLogger(logger) as cap_logger: - ddpm = DDPMScheduler.from_config( + ddpm = DDPMScheduler.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler", predict_epsilon=False, @@ -304,7 +199,7 @@ class ConfigTester(unittest.TestCase): ) with CaptureLogger(logger) as cap_logger_2: - ddpm_2 = DDPMScheduler.from_config("google/ddpm-celebahq-256", beta_start=88) + ddpm_2 
= DDPMScheduler.from_pretrained("google/ddpm-celebahq-256", beta_start=88) assert ddpm.__class__ == DDPMScheduler assert ddpm.config.predict_epsilon is False @@ -319,7 +214,7 @@ class ConfigTester(unittest.TestCase): logger = logging.get_logger("diffusers.configuration_utils") with CaptureLogger(logger) as cap_logger: - dpm = DPMSolverMultistepScheduler.from_config( + dpm = DPMSolverMultistepScheduler.from_pretrained( "hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler" ) diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index eabe6ada9f..49bb4f6deb 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -130,7 +130,7 @@ class ModelTesterMixin: expected_arg_names = ["sample", "timestep"] self.assertListEqual(arg_names[:2], expected_arg_names) - def test_model_from_config(self): + def test_model_from_pretrained(self): init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() model = self.model_class(**init_dict) @@ -140,8 +140,8 @@ class ModelTesterMixin: # test if the model can be loaded from the config # and has all the expected shape with tempfile.TemporaryDirectory() as tmpdirname: - model.save_config(tmpdirname) - new_model = self.model_class.from_config(tmpdirname) + model.save_pretrained(tmpdirname) + new_model = self.model_class.from_pretrained(tmpdirname) new_model.to(torch_device) new_model.eval() diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py index 4559d713ed..c77b000292 100644 --- a/tests/test_pipelines.py +++ b/tests/test_pipelines.py @@ -29,6 +29,10 @@ from diffusers import ( DDIMScheduler, DDPMPipeline, DDPMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, PNDMScheduler, StableDiffusionImg2ImgPipeline, StableDiffusionInpaintPipelineLegacy, @@ -398,6 +402,82 @@ class PipelineFastTests(unittest.TestCase): assert image_img2img.shape == (1, 32, 32, 3) assert image_text2img.shape == (1, 128, 128, 3) + def test_set_scheduler(self): + unet = self.dummy_cond_unet + scheduler = PNDMScheduler(skip_prk_steps=True) + vae = self.dummy_vae + bert = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + sd = StableDiffusionPipeline( + unet=unet, + scheduler=scheduler, + vae=vae, + text_encoder=bert, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=self.dummy_extractor, + ) + + sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, DDIMScheduler) + sd.scheduler = DDPMScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, DDPMScheduler) + sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, PNDMScheduler) + sd.scheduler = LMSDiscreteScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, LMSDiscreteScheduler) + sd.scheduler = EulerDiscreteScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, EulerDiscreteScheduler) + sd.scheduler = EulerAncestralDiscreteScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, EulerAncestralDiscreteScheduler) + sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config) + assert isinstance(sd.scheduler, DPMSolverMultistepScheduler) + + def test_set_scheduler_consistency(self): + unet = self.dummy_cond_unet + pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler") + ddim = 
diff --git a/tests/test_pipelines.py b/tests/test_pipelines.py
index 4559d713ed..c77b000292 100644
--- a/tests/test_pipelines.py
+++ b/tests/test_pipelines.py
@@ -29,6 +29,10 @@ from diffusers import (
     DDIMScheduler,
     DDPMPipeline,
     DDPMScheduler,
+    DPMSolverMultistepScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    LMSDiscreteScheduler,
     PNDMScheduler,
     StableDiffusionImg2ImgPipeline,
     StableDiffusionInpaintPipelineLegacy,
@@ -398,6 +402,82 @@ class PipelineFastTests(unittest.TestCase):
         assert image_img2img.shape == (1, 32, 32, 3)
         assert image_text2img.shape == (1, 128, 128, 3)
 
+    def test_set_scheduler(self):
+        unet = self.dummy_cond_unet
+        scheduler = PNDMScheduler(skip_prk_steps=True)
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        sd = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=scheduler,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+
+        sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
+        assert isinstance(sd.scheduler, DDIMScheduler)
+        sd.scheduler = DDPMScheduler.from_config(sd.scheduler.config)
+        assert isinstance(sd.scheduler, DDPMScheduler)
+        sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
+        assert isinstance(sd.scheduler, PNDMScheduler)
+        sd.scheduler = LMSDiscreteScheduler.from_config(sd.scheduler.config)
+        assert isinstance(sd.scheduler, LMSDiscreteScheduler)
+        sd.scheduler = EulerDiscreteScheduler.from_config(sd.scheduler.config)
+        assert isinstance(sd.scheduler, EulerDiscreteScheduler)
+        sd.scheduler = EulerAncestralDiscreteScheduler.from_config(sd.scheduler.config)
+        assert isinstance(sd.scheduler, EulerAncestralDiscreteScheduler)
+        sd.scheduler = DPMSolverMultistepScheduler.from_config(sd.scheduler.config)
+        assert isinstance(sd.scheduler, DPMSolverMultistepScheduler)
+
+    def test_set_scheduler_consistency(self):
+        unet = self.dummy_cond_unet
+        pndm = PNDMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
+        ddim = DDIMScheduler.from_config("hf-internal-testing/tiny-stable-diffusion-torch", subfolder="scheduler")
+        vae = self.dummy_vae
+        bert = self.dummy_text_encoder
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        sd = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=pndm,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+
+        pndm_config = sd.scheduler.config
+        sd.scheduler = DDPMScheduler.from_config(pndm_config)
+        sd.scheduler = PNDMScheduler.from_config(sd.scheduler.config)
+        pndm_config_2 = sd.scheduler.config
+        pndm_config_2 = {k: v for k, v in pndm_config_2.items() if k in pndm_config}
+
+        assert dict(pndm_config) == dict(pndm_config_2)
+
+        sd = StableDiffusionPipeline(
+            unet=unet,
+            scheduler=ddim,
+            vae=vae,
+            text_encoder=bert,
+            tokenizer=tokenizer,
+            safety_checker=None,
+            feature_extractor=self.dummy_extractor,
+        )
+
+        ddim_config = sd.scheduler.config
+        sd.scheduler = LMSDiscreteScheduler.from_config(ddim_config)
+        sd.scheduler = DDIMScheduler.from_config(sd.scheduler.config)
+        ddim_config_2 = sd.scheduler.config
+        ddim_config_2 = {k: v for k, v in ddim_config_2.items() if k in ddim_config}
+
+        assert dict(ddim_config) == dict(ddim_config_2)
+
 
 @slow
 class PipelineSlowTests(unittest.TestCase):
@@ -519,7 +599,7 @@ class PipelineSlowTests(unittest.TestCase):
 
     def test_output_format(self):
         model_path = "google/ddpm-cifar10-32"
-        scheduler = DDIMScheduler.from_config(model_path)
+        scheduler = DDIMScheduler.from_pretrained(model_path)
         pipe = DDIMPipeline.from_pretrained(model_path, scheduler=scheduler)
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
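The new test_set_scheduler and test_set_scheduler_consistency cases above exercise one idiom: build a different scheduler class from the current scheduler's config and assign it back to the pipeline, so the beta schedule, timestep count and related values carry over. A minimal sketch of that idiom, assuming Hub access to the hf-internal-testing/tiny-stable-diffusion-torch checkpoint used throughout these tests (plus torch and transformers installed):

from diffusers import DDIMScheduler, DPMSolverMultistepScheduler, StableDiffusionPipeline

# Load the tiny test pipeline; any Stable Diffusion checkpoint works the same way.
pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")

# from_config reuses the existing scheduler config, so the swap stays consistent
# with the values the pipeline was configured with.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
assert isinstance(pipe.scheduler, DDIMScheduler)

pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
assert isinstance(pipe.scheduler, DPMSolverMultistepScheduler)

Keys the new class does not recognize are ignored with a warning, and keys it needs but does not find fall back to its defaults, which is what the consistency test checks by round-tripping PNDM -> DDPM -> PNDM and DDIM -> LMS -> DDIM.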
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index a9770f0a54..9c9abd0973 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+import json
+import os
 import tempfile
 import unittest
 from typing import Dict, List, Tuple
@@ -21,6 +23,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 
+import diffusers
 from diffusers import (
     DDIMScheduler,
     DDPMScheduler,
@@ -32,13 +35,180 @@ from diffusers import (
     PNDMScheduler,
     ScoreSdeVeScheduler,
     VQDiffusionScheduler,
+    logging,
 )
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.schedulers.scheduling_utils import SchedulerMixin
 from diffusers.utils import deprecate, torch_device
+from diffusers.utils.testing_utils import CaptureLogger
 
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
+class SchedulerObject(SchedulerMixin, ConfigMixin):
+    config_name = "config.json"
+
+    @register_to_config
+    def __init__(
+        self,
+        a=2,
+        b=5,
+        c=(2, 5),
+        d="for diffusion",
+        e=[1, 3],
+    ):
+        pass
+
+
+class SchedulerObject2(SchedulerMixin, ConfigMixin):
+    config_name = "config.json"
+
+    @register_to_config
+    def __init__(
+        self,
+        a=2,
+        b=5,
+        c=(2, 5),
+        d="for diffusion",
+        f=[1, 3],
+    ):
+        pass
+
+
+class SchedulerObject3(SchedulerMixin, ConfigMixin):
+    config_name = "config.json"
+
+    @register_to_config
+    def __init__(
+        self,
+        a=2,
+        b=5,
+        c=(2, 5),
+        d="for diffusion",
+        e=[1, 3],
+        f=[1, 3],
+    ):
+        pass
+
+
+class SchedulerBaseTests(unittest.TestCase):
+    def test_save_load_from_different_config(self):
+        obj = SchedulerObject()
+
+        # mock add obj class to `diffusers`
+        setattr(diffusers, "SchedulerObject", SchedulerObject)
+        logger = logging.get_logger("diffusers.configuration_utils")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            obj.save_config(tmpdirname)
+            with CaptureLogger(logger) as cap_logger_1:
+                config = SchedulerObject2.load_config(tmpdirname)
+                new_obj_1 = SchedulerObject2.from_config(config)
+
+            # now save a config parameter that is not expected
+            with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
+                data = json.load(f)
+                data["unexpected"] = True
+
+            with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
+                json.dump(data, f)
+
+            with CaptureLogger(logger) as cap_logger_2:
+                config = SchedulerObject.load_config(tmpdirname)
+                new_obj_2 = SchedulerObject.from_config(config)
+
+            with CaptureLogger(logger) as cap_logger_3:
+                config = SchedulerObject2.load_config(tmpdirname)
+                new_obj_3 = SchedulerObject2.from_config(config)
+
+        assert new_obj_1.__class__ == SchedulerObject2
+        assert new_obj_2.__class__ == SchedulerObject
+        assert new_obj_3.__class__ == SchedulerObject2
+
+        assert cap_logger_1.out == ""
+        assert (
+            cap_logger_2.out
+            == "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
+            " will"
+            " be ignored. Please verify your config.json configuration file.\n"
+        )
+        assert cap_logger_2.out.replace("SchedulerObject", "SchedulerObject2") == cap_logger_3.out
+
+    def test_save_load_compatible_schedulers(self):
+        SchedulerObject2._compatibles = ["SchedulerObject"]
+        SchedulerObject._compatibles = ["SchedulerObject2"]
+
+        obj = SchedulerObject()
+
+        # mock add obj class to `diffusers`
+        setattr(diffusers, "SchedulerObject", SchedulerObject)
+        setattr(diffusers, "SchedulerObject2", SchedulerObject2)
+        logger = logging.get_logger("diffusers.configuration_utils")
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            obj.save_config(tmpdirname)
+
+            # now save a config parameter that is expected by another class, but not origin class
+            with open(os.path.join(tmpdirname, SchedulerObject.config_name), "r") as f:
+                data = json.load(f)
+                data["f"] = [0, 0]
+                data["unexpected"] = True
+
+            with open(os.path.join(tmpdirname, SchedulerObject.config_name), "w") as f:
+                json.dump(data, f)
+
+            with CaptureLogger(logger) as cap_logger:
+                config = SchedulerObject.load_config(tmpdirname)
+                new_obj = SchedulerObject.from_config(config)
+
+        assert new_obj.__class__ == SchedulerObject
+
+        assert (
+            cap_logger.out
+            == "The config attributes {'unexpected': True} were passed to SchedulerObject, but are not expected and"
+            " will"
+            " be ignored. Please verify your config.json configuration file.\n"
+        )
+
+    def test_save_load_from_different_config_comp_schedulers(self):
+        SchedulerObject3._compatibles = ["SchedulerObject", "SchedulerObject2"]
+        SchedulerObject2._compatibles = ["SchedulerObject", "SchedulerObject3"]
+        SchedulerObject._compatibles = ["SchedulerObject2", "SchedulerObject3"]
+
+        obj = SchedulerObject()
+
+        # mock add obj class to `diffusers`
+        setattr(diffusers, "SchedulerObject", SchedulerObject)
+        setattr(diffusers, "SchedulerObject2", SchedulerObject2)
+        setattr(diffusers, "SchedulerObject3", SchedulerObject3)
+        logger = logging.get_logger("diffusers.configuration_utils")
+        logger.setLevel(diffusers.logging.INFO)
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            obj.save_config(tmpdirname)
+
+            with CaptureLogger(logger) as cap_logger_1:
+                config = SchedulerObject.load_config(tmpdirname)
+                new_obj_1 = SchedulerObject.from_config(config)
+
+            with CaptureLogger(logger) as cap_logger_2:
+                config = SchedulerObject2.load_config(tmpdirname)
+                new_obj_2 = SchedulerObject2.from_config(config)
+
+            with CaptureLogger(logger) as cap_logger_3:
+                config = SchedulerObject3.load_config(tmpdirname)
+                new_obj_3 = SchedulerObject3.from_config(config)
+
+        assert new_obj_1.__class__ == SchedulerObject
+        assert new_obj_2.__class__ == SchedulerObject2
+        assert new_obj_3.__class__ == SchedulerObject3
+
+        assert cap_logger_1.out == ""
+        assert cap_logger_2.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
+        assert cap_logger_3.out == "{'f'} was not found in config. Values will be initialized to default values.\n"
+
+
 class SchedulerCommonTest(unittest.TestCase):
     scheduler_classes = ()
     forward_default_kwargs = ()
 
@@ -102,7 +272,7 @@ class SchedulerCommonTest(unittest.TestCase):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
 
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             scheduler.set_timesteps(num_inference_steps)
@@ -145,7 +315,7 @@ class SchedulerCommonTest(unittest.TestCase):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
 
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             scheduler.set_timesteps(num_inference_steps)
@@ -187,7 +357,7 @@ class SchedulerCommonTest(unittest.TestCase):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
 
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             scheduler.set_timesteps(num_inference_steps)
@@ -205,6 +375,42 @@ class SchedulerCommonTest(unittest.TestCase):
 
             assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"
 
+    def test_compatibles(self):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+
+            scheduler = scheduler_class(**scheduler_config)
+
+            assert all(c is not None for c in scheduler.compatibles)
+
+            for comp_scheduler_cls in scheduler.compatibles:
+                comp_scheduler = comp_scheduler_cls.from_config(scheduler.config)
+                assert comp_scheduler is not None
+
+            new_scheduler = scheduler_class.from_config(comp_scheduler.config)
+
+            new_scheduler_config = {k: v for k, v in new_scheduler.config.items() if k in scheduler.config}
+            scheduler_diff = {k: v for k, v in new_scheduler.config.items() if k not in scheduler.config}
+
+            # make sure that configs are essentially identical
+            assert new_scheduler_config == dict(scheduler.config)
+
+            # make sure that only differences are for configs that are not in init
+            init_keys = inspect.signature(scheduler_class.__init__).parameters.keys()
+            assert set(scheduler_diff.keys()).intersection(set(init_keys)) == set()
+
+    def test_from_pretrained(self):
+        for scheduler_class in self.scheduler_classes:
+            scheduler_config = self.get_scheduler_config()
+
+            scheduler = scheduler_class(**scheduler_config)
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_pretrained(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+
+            assert scheduler.config == new_scheduler.config
+
     def test_step_shape(self):
         kwargs = dict(self.forward_default_kwargs)
 
@@ -616,7 +822,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
             new_scheduler.set_timesteps(num_inference_steps)
             # copy over dummy past residuals
             new_scheduler.model_outputs = dummy_past_residuals[: new_scheduler.config.solver_order]
@@ -648,7 +854,7 @@ class DPMSolverMultistepSchedulerTest(SchedulerCommonTest):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
 
         # copy over dummy past residuals
         new_scheduler.set_timesteps(num_inference_steps)
@@ -790,7 +996,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
             new_scheduler.set_timesteps(num_inference_steps)
             # copy over dummy past residuals
             new_scheduler.ets = dummy_past_residuals[:]
@@ -825,7 +1031,7 @@ class PNDMSchedulerTest(SchedulerCommonTest):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
 
         # copy over dummy past residuals
         new_scheduler.set_timesteps(num_inference_steps)
@@ -1043,7 +1249,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
 
         output = scheduler.step_pred(
             residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
@@ -1074,7 +1280,7 @@ class ScoreSdeVeSchedulerTest(unittest.TestCase):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
 
         output = scheduler.step_pred(
             residual, time_step, sample, generator=torch.manual_seed(0), **kwargs
@@ -1470,7 +1676,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
             new_scheduler.set_timesteps(num_inference_steps)
             # copy over dummy past residuals
             new_scheduler.ets = dummy_past_residuals[:]
@@ -1508,7 +1714,7 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler = scheduler_class.from_config(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
 
         # copy over dummy past residuals
         new_scheduler.set_timesteps(num_inference_steps)
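Taken together, the scheduler test changes above pin down two behaviours: save_pretrained/from_pretrained round-trips a scheduler through its config.json unchanged, and every class listed in scheduler.compatibles can be rebuilt from that same config. A condensed sketch of both, using DDPMScheduler as an arbitrary concrete choice (any SchedulerMixin subclass should behave the same way):

import tempfile

from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)

# Round trip through a local directory: the restored config must match exactly.
with tempfile.TemporaryDirectory() as tmpdirname:
    scheduler.save_pretrained(tmpdirname)
    restored = DDPMScheduler.from_pretrained(tmpdirname)
assert scheduler.config == restored.config

# Every compatible scheduler class can be instantiated from the same config;
# unknown keys are ignored and missing ones fall back to that class's defaults.
for compatible_cls in scheduler.compatibles:
    swapped = compatible_cls.from_config(scheduler.config)
    print(type(swapped).__name__)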
diff --git a/tests/test_scheduler_flax.py b/tests/test_scheduler_flax.py
index 7928939f2d..0fa0e1b495 100644
--- a/tests/test_scheduler_flax.py
+++ b/tests/test_scheduler_flax.py
@@ -83,7 +83,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
 
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             state = scheduler.set_timesteps(state, num_inference_steps)
@@ -112,7 +112,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
 
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             state = scheduler.set_timesteps(state, num_inference_steps)
@@ -140,7 +140,7 @@ class FlaxSchedulerCommonTest(unittest.TestCase):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
 
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             state = scheduler.set_timesteps(state, num_inference_steps)
@@ -373,7 +373,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
 
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             state = scheduler.set_timesteps(state, num_inference_steps)
@@ -401,7 +401,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
 
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             state = scheduler.set_timesteps(state, num_inference_steps)
@@ -430,7 +430,7 @@ class FlaxDDIMSchedulerTest(FlaxSchedulerCommonTest):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
 
         if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
             state = scheduler.set_timesteps(state, num_inference_steps)
@@ -633,7 +633,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
             new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)
             # copy over dummy past residuals
             new_state = new_state.replace(ets=dummy_past_residuals[:])
@@ -720,7 +720,7 @@ class FlaxPNDMSchedulerTest(FlaxSchedulerCommonTest):
 
         with tempfile.TemporaryDirectory() as tmpdirname:
             scheduler.save_config(tmpdirname)
-            new_scheduler, new_state = scheduler_class.from_config(tmpdirname)
+            new_scheduler, new_state = scheduler_class.from_pretrained(tmpdirname)
 
         # copy over dummy past residuals
         new_state = new_scheduler.set_timesteps(new_state, num_inference_steps, shape=sample.shape)