mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix Tiling in ConsistencyDecoderVAE (#7290)
* Fix typos
* Add docstring to `decode` method in `ConsistencyDecoderVAE`
* Fix tiling
* Enable tiled VAE decoding with customizable tile sample size and overlap factor
* Revert "Enable tiled VAE decoding with customizable tile sample size and overlap factor"
This reverts commit 181049675e.
* Add VAE tiling test for `ConsistencyDecoderVAE`
---------
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -63,7 +63,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
|
||||
... ).to("cuda")
|
||||
|
||||
>>> pipe("horse", generator=torch.manual_seed(0)).images
|
||||
>>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]
|
||||
>>> image
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -72,6 +73,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
self,
|
||||
scaling_factor: float = 0.18215,
|
||||
latent_channels: int = 4,
|
||||
sample_size: int = 32,
|
||||
encoder_act_fn: str = "silu",
|
||||
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
|
||||
encoder_double_z: bool = True,
|
||||
@@ -153,6 +155,16 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
self.use_slicing = False
|
||||
self.use_tiling = False
|
||||
|
||||
# only relevant if vae tiling is enabled
|
||||
self.tile_sample_min_size = self.config.sample_size
|
||||
sample_size = (
|
||||
self.config.sample_size[0]
|
||||
if isinstance(self.config.sample_size, (list, tuple))
|
||||
else self.config.sample_size
|
||||
)
|
||||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
|
||||
def enable_tiling(self, use_tiling: bool = True):
|
||||
r"""
|
||||
@@ -272,7 +284,7 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
Args:
|
||||
x (`torch.FloatTensor`): Input batch of images.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a [`~models.consistecy_decoder_vae.ConsistencyDecoderOoutput`] instead of a plain
|
||||
Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
@@ -305,6 +317,19 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
return_dict: bool = True,
|
||||
num_inference_steps: int = 2,
|
||||
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
|
||||
"""
|
||||
Decodes the input latent vector `z` using the consistency decoder VAE model.
|
||||
|
||||
Args:
|
||||
z (torch.FloatTensor): The input latent vector.
|
||||
generator (Optional[torch.Generator]): The random number generator. Default is None.
|
||||
return_dict (bool): Whether to return the output as a dictionary. Default is True.
|
||||
num_inference_steps (int): The number of inference steps. Default is 2.
|
||||
|
||||
Returns:
|
||||
Union[DecoderOutput, Tuple[torch.FloatTensor]]: The decoded output.
|
||||
|
||||
"""
|
||||
z = (z * self.config.scaling_factor - self.means) / self.stds
|
||||
|
||||
scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
|
||||
@@ -345,7 +370,9 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
|
||||
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
|
||||
return b
|
||||
|
||||
def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
|
||||
def tiled_encode(
|
||||
self, x: torch.FloatTensor, return_dict: bool = True
|
||||
) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
|
||||
r"""Encode a batch of images using a tiled encoder.
|
||||
|
||||
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
|
||||
|
||||
@@ -1116,3 +1116,33 @@ class ConsistencyDecoderVAEIntegrationTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
assert torch_all_close(actual_output, expected_output, atol=5e-3)
|
||||
|
||||
def test_vae_tiling(self):
|
||||
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")
|
||||
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
|
||||
pipe.to(torch_device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
out_1 = pipe(
|
||||
"horse",
|
||||
num_inference_steps=2,
|
||||
output_type="pt",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
).images[0]
|
||||
|
||||
# make sure tiled vae decode yields the same result
|
||||
pipe.enable_vae_tiling()
|
||||
out_2 = pipe(
|
||||
"horse",
|
||||
num_inference_steps=2,
|
||||
output_type="pt",
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
).images[0]
|
||||
|
||||
assert torch_all_close(out_1, out_2, atol=5e-3)
|
||||
|
||||
# test that tiled decode works with various shapes
|
||||
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
|
||||
for shape in shapes:
|
||||
image = torch.zeros(shape, device=torch_device)
|
||||
pipe.vae.decode(image)
|
||||
|
||||
Reference in New Issue
Block a user