1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Merge branch 'main' into hidream-torch-compile

This commit is contained in:
Sayak Paul
2025-05-14 20:26:44 +05:30
committed by GitHub
3 changed files with 42 additions and 4 deletions

View File

@@ -607,6 +607,39 @@ class FluxImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFile
return latents
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_slicing
def enable_vae_slicing(self):
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.vae.enable_slicing()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_slicing
def disable_vae_slicing(self):
r"""
Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_slicing()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.enable_vae_tiling
def enable_vae_tiling(self):
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
"""
self.vae.enable_tiling()
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.disable_vae_tiling
def disable_vae_tiling(self):
r"""
Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
computing decoding in one step.
"""
self.vae.disable_tiling()
def prepare_latents(
self,
image,

View File

@@ -53,7 +53,12 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, UNetTesterMixin
from ..test_modeling_common import (
LoraHotSwappingForModelTesterMixin,
ModelTesterMixin,
TorchCompileTesterMixin,
UNetTesterMixin,
)
if is_peft_available():
@@ -351,7 +356,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True):
class UNet2DConditionModelTests(
ModelTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
ModelTesterMixin, TorchCompileTesterMixin, LoraHotSwappingForModelTesterMixin, UNetTesterMixin, unittest.TestCase
):
model_class = UNet2DConditionModel
main_input_name = "sample"

View File

@@ -109,7 +109,7 @@ class LTXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
else:
generator = torch.Generator(device=device).manual_seed(seed)
image = torch.randn((1, 3, 32, 32), generator=generator, device=device)
image = torch.rand((1, 3, 32, 32), generator=generator, device=device)
inputs = {
"image": image,
@@ -142,7 +142,7 @@ class LTXImageToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
expected_video = torch.randn(9, 3, 32, 32)
max_diff = np.abs(generated_video - expected_video).max()
max_diff = torch.amax(torch.abs(generated_video - expected_video))
self.assertLessEqual(max_diff, 1e10)
def test_callback_inputs(self):