Fix schedule_shifted_power usage in 🪆Matryoshka Diffusion Models (#9723)
* [matryoshka.py] Add schedule_shifted_power attribute and update get_schedule_shifted method
@@ -4336,19 +4336,19 @@ The Abstract of the paper:
**64x64**
:-------------------------:
| <img src="https://github.com/user-attachments/assets/9e7bb2cd-45a0-4bd1-adb8-23e283baed39" width="222" height="222" alt="bird_64"> |
| <img src="https://github.com/user-attachments/assets/032738eb-c6cd-4fd9-b4d7-a7317b4b6528" width="222" height="222" alt="bird_64_64"> |

- `256×256, nesting_level=1`: 1.776 GiB. With `150` DDIM inference steps:

**64x64** | **256x256**
:-------------------------:|:-------------------------:
| <img src="https://github.com/user-attachments/assets/6b724c2e-5e6a-4b63-9b65-c1182cbb67e0" width="222" height="222" alt="64x64"> | <img src="https://github.com/user-attachments/assets/7dbab2ad-bf40-4a73-ab04-f178347cb7d5" width="222" height="222" alt="256x256"> |
| <img src="https://github.com/user-attachments/assets/21b9ad8b-eea6-4603-80a2-31180f391589" width="222" height="222" alt="bird_256_64"> | <img src="https://github.com/user-attachments/assets/fc411682-8a36-422c-9488-395b77d4406e" width="222" height="222" alt="bird_256_256"> |

- `1024×1024, nesting_level=2`: 1.792 GiB. As one can realize the cost of adding another layer is really negligible. With `250` DDIM inference steps:
- `1024×1024, nesting_level=2`: 1.792 GiB. As one can realize the cost of adding another layer is really negligible in this context! With `250` DDIM inference steps:

**64x64** | **256x256** | **1024x1024**
:-------------------------:|:-------------------------:|:-------------------------:
| <img src="https://github.com/user-attachments/assets/4a9454e4-e20a-4736-a196-270e2ae796c0" width="222" height="222" alt="64x64"> | <img src="https://github.com/user-attachments/assets/4a96555d-0fda-4303-82b1-a4d886f770b9" width="222" height="222" alt="256x256"> | <img src="https://github.com/user-attachments/assets/e0239b7a-ab73-4d45-8f3e-b4e6b4b50abe" width="222" height="222" alt="1024x1024"> |
| <img src="https://github.com/user-attachments/assets/febf4b98-3dee-4a8e-9946-fd42e1f232e6" width="222" height="222" alt="bird_1024_64"> | <img src="https://github.com/user-attachments/assets/c5f85b40-5d6d-4267-a92a-c89dff015b9b" width="222" height="222" alt="bird_1024_256"> | <img src="https://github.com/user-attachments/assets/ad66b913-4367-4cb9-889e-bc06f4d96148" width="222" height="222" alt="bird_1024_1024"> |

```py
from diffusers import DiffusionPipeline
@@ -4362,8 +4362,7 @@ pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-model
prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"
image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images
image = pipe(prompt, num_inference_steps=50).images
make_image_grid(image, rows=1, cols=len(image))

# pipe.change_nesting_level(<int>) # 0, 1, or 2
@@ -107,15 +107,16 @@ EXAMPLE_DOC_STRING = """
>>> # nesting_level=0 -> 64x64; nesting_level=1 -> 256x256 - 64x64; nesting_level=2 -> 1024x1024 - 256x256 - 64x64
>>> pipe = DiffusionPipeline.from_pretrained("tolgacangoz/matryoshka-diffusion-models",
>>> custom_pipeline="matryoshka").to("cuda")
... nesting_level=0,
... trust_remote_code=False, # One needs to give permission for this code to run
... ).to("cuda")

>>> prompt0 = "a blue jay stops on the top of a helmet of Japanese samurai, background with sakura tree"
>>> prompt = f"breathtaking {prompt0}. award-winning, professional, highly detailed"
>>> negative_prompt = "deformed, mutated, ugly, disfigured, blur, blurry, noise, noisy"
>>> image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=50).images
>>> image = pipe(prompt, num_inference_steps=50).images
>>> make_image_grid(image, rows=1, cols=len(image))

>>> pipe.change_nesting_level(<int>) # 0, 1, or 2
>>> # pipe.change_nesting_level(<int>) # 0, 1, or 2
>>> # 50+, 100+, and 250+ num_inference_steps are recommended for nesting levels 0, 1, and 2 respectively.
```
"""
@@ -420,6 +421,7 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))

        self.scales = None
        self.schedule_shifted_power = 1.0

    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
        """
@@ -532,6 +534,7 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
    def get_schedule_shifted(self, alpha_prod, scale_factor=None):
        if (scale_factor is not None) and (scale_factor > 1):  # rescale noise schedule
            scale_factor = scale_factor**self.schedule_shifted_power
            snr = alpha_prod / (1 - alpha_prod)
            scaled_snr = snr / scale_factor
            alpha_prod = 1 / (1 + 1 / scaled_snr)
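The new `schedule_shifted_power` exponent only changes how strongly the noise schedule is shifted between resolutions: the signal-to-noise ratio is divided by `scale_factor**schedule_shifted_power` instead of plain `scale_factor`. A minimal standalone sketch of the logic in this hunk (a free-standing copy for illustration, not the scheduler method itself):

```py
def get_schedule_shifted(alpha_prod, scale_factor=None, schedule_shifted_power=1.0):
    # Rescale the noise schedule for a higher-resolution level by dividing the SNR
    # by the (possibly powered) resolution scale factor.
    if (scale_factor is not None) and (scale_factor > 1):
        scale_factor = scale_factor**schedule_shifted_power
        snr = alpha_prod / (1 - alpha_prod)
        scaled_snr = snr / scale_factor
        alpha_prod = 1 / (1 + 1 / scaled_snr)
    return alpha_prod

# Example: alpha_prod = 0.9 at some timestep, with a 4x resolution ratio.
print(get_schedule_shifted(0.9, scale_factor=4, schedule_shifted_power=1.0))  # ~0.692
print(get_schedule_shifted(0.9, scale_factor=4, schedule_shifted_power=2.0))  # ~0.360
```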
@@ -639,17 +642,14 @@ class MatryoshkaDDIMScheduler(SchedulerMixin, ConfigMixin):
        # 4. Clip or threshold "predicted x_0"
        if self.config.thresholding:
            if len(model_output) > 1:
                pred_original_sample = [
                    self._threshold_sample(p_o_s * scale) / scale
                    for p_o_s, scale in zip(pred_original_sample, self.scales)
                ]
                pred_original_sample = [self._threshold_sample(p_o_s) for p_o_s in pred_original_sample]
            else:
                pred_original_sample = self._threshold_sample(pred_original_sample)
        elif self.config.clip_sample:
            if len(model_output) > 1:
                pred_original_sample = [
                    (p_o_s * scale).clamp(-self.config.clip_sample_range, self.config.clip_sample_range) / scale
                    for p_o_s, scale in zip(pred_original_sample, self.scales)
                    p_o_s.clamp(-self.config.clip_sample_range, self.config.clip_sample_range)
                    for p_o_s in pred_original_sample
                ]
            else:
                pred_original_sample = pred_original_sample.clamp(
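In this hunk the per-resolution predictions are now thresholded or clamped directly, without first multiplying by the level's scale and dividing afterwards. A minimal sketch of the corrected `clip_sample` branch for a list of per-level tensors (tensor shapes and the `clip_sample_range` value are made up for illustration):

```py
import torch

clip_sample_range = 1.0  # illustrative config value
pred_original_sample = [torch.randn(1, 3, 64, 64), torch.randn(1, 3, 256, 256)]

# Clamp each resolution's prediction directly, with no per-level scale factor.
pred_original_sample = [
    p_o_s.clamp(-clip_sample_range, clip_sample_range) for p_o_s in pred_original_sample
]
```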
@@ -3816,6 +3816,8 @@ class MatryoshkaPipeline(
        if hasattr(unet, "nest_ratio"):
            scheduler.scales = unet.nest_ratio + [1]
            if nesting_level == 2:
                scheduler.schedule_shifted_power = 2.0

        self.register_modules(
            text_encoder=text_encoder,
@@ -3842,12 +3844,14 @@ class MatryoshkaPipeline(
            ).to(self.device)
            self.config.nesting_level = 1
            self.scheduler.scales = self.unet.nest_ratio + [1]
            self.scheduler.schedule_shifted_power = 1.0
        elif nesting_level == 2:
            self.unet = NestedUNet2DConditionModel.from_pretrained(
                "tolgacangoz/matryoshka-diffusion-models", subfolder="unet/nesting_level_2"
            ).to(self.device)
            self.config.nesting_level = 2
            self.scheduler.scales = self.unet.nest_ratio + [1]
            self.scheduler.schedule_shifted_power = 2.0
        else:
            raise ValueError("Currently, nesting levels 0, 1, and 2 are supported.")
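Taken together, the two pipeline hunks keep `schedule_shifted_power` in sync with the nesting level (1.0 for levels 0 and 1, 2.0 for level 2), both at construction time and when switching levels later. A usage sketch, assuming the loading pattern shown in the example docstring above:

```py
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "tolgacangoz/matryoshka-diffusion-models",
    nesting_level=2,
    custom_pipeline="matryoshka",
    trust_remote_code=False,  # One needs to give permission for this code to run
).to("cuda")
print(pipe.scheduler.schedule_shifted_power)  # 2.0 for the 1024x1024 level

pipe.change_nesting_level(1)  # switch to the 256x256 UNet
print(pipe.scheduler.schedule_shifted_power)  # back to 1.0
```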
@@ -4627,8 +4631,8 @@ class MatryoshkaPipeline(
            image = latents

        if self.scheduler.scales is not None:
            for i, (img, scale) in enumerate(zip(image, self.scheduler.scales)):
                image[i] = self.image_processor.postprocess(img * scale, output_type=output_type)[0]
            for i, img in enumerate(image):
                image[i] = self.image_processor.postprocess(img, output_type=output_type)[0]
        else:
            image = self.image_processor.postprocess(image, output_type=output_type)
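The final hunk drops the per-level `img * scale` multiplication before postprocessing, so each resolution's output is decoded as-is. A small sketch of the updated loop, using `VaeImageProcessor` as a stand-in for the pipeline's own image processor:

```py
import torch
from diffusers.image_processor import VaeImageProcessor

image_processor = VaeImageProcessor()  # stand-in for the pipeline's self.image_processor
image = [torch.rand(1, 3, 64, 64) * 2 - 1, torch.rand(1, 3, 256, 256) * 2 - 1]

# Postprocess each resolution directly, without the per-level scale factor.
for i, img in enumerate(image):
    image[i] = image_processor.postprocess(img, output_type="pil")[0]
```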