diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 142aac94b9..4151e070c5 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -183,7 +183,7 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin):
             self.to("cpu", silence_dtype_warnings=True)
             torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
-        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.text_encoder_2, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
     def enable_model_cpu_offload(self, gpu_id=0):
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index f699e23310..b2fa059205 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -191,7 +191,7 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin):
             self.to("cpu", silence_dtype_warnings=True)
             torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
 
-        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.text_encoder_2, self.vae]:
             cpu_offload(cpu_offloaded_model, device)
 
     def enable_model_cpu_offload(self, gpu_id=0):
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index e97bdb352b..a9b8003efc 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -699,12 +699,16 @@ class PipelineTesterMixin:
 
         inputs = self.get_dummy_inputs(torch_device)
         output_without_offload = pipe(**inputs)[0]
-        output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
+        output_without_offload = (
+            output_without_offload.cpu() if torch.is_tensor(output_without_offload) else output_without_offload
+        )
 
         pipe.enable_xformers_memory_efficient_attention()
         inputs = self.get_dummy_inputs(torch_device)
         output_with_offload = pipe(**inputs)[0]
-        output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload
+        output_with_offload = (
+            output_with_offload.cpu() if torch.is_tensor(output_with_offload) else output_without_offload
+        )
 
         if test_max_difference:
             max_diff = np.abs(output_with_offload - output_without_offload).max()
diff --git a/tests/pipelines/text_to_video/test_text_to_video.py b/tests/pipelines/text_to_video/test_text_to_video.py
index 8b4bae2275..f391568d10 100644
--- a/tests/pipelines/text_to_video/test_text_to_video.py
+++ b/tests/pipelines/text_to_video/test_text_to_video.py
@@ -26,7 +26,7 @@ from diffusers import (
     TextToVideoSDPipeline,
     UNet3DConditionModel,
 )
-from diffusers.utils import load_numpy, skip_mps, slow
+from diffusers.utils import is_xformers_available, load_numpy, skip_mps, slow, torch_device
 from diffusers.utils.testing_utils import enable_full_determinism
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
@@ -143,6 +143,13 @@ class TextToVideoSDPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
     def test_attention_slicing_forward_pass(self):
         self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False, expected_max_diff=3e-3)
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_attention_forwardGenerator_pass(self):
+        self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False, expected_max_diff=1e-2)
+
     # (todo): sayakpaul
     @unittest.skip(reason="Batching needs to be properly figured out first for this pipeline.")
     def test_inference_batch_consistent(self):
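
With this patch, `enable_sequential_cpu_offload()` on the SDXL pipelines offloads `text_encoder_2` along with the unet, first text encoder, and VAE. A minimal usage sketch (the checkpoint id and prompt are illustrative, not part of the patch):

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Illustrative SDXL checkpoint; any checkpoint with two text encoders applies.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)

# After this patch, unet, text_encoder, text_encoder_2, and vae are all kept on
# CPU and moved to the GPU only while they are actually used.
pipe.enable_sequential_cpu_offload()

image = pipe("a photo of an astronaut riding a horse").images[0]
```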
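The change in `test_pipelines_common.py` accounts for `Tensor.cpu()` returning a new tensor rather than moving the tensor in place, so the result must be reassigned. A standalone sketch of the difference (variable name is illustrative):

```python
import torch

t = torch.ones(2, device="cuda") if torch.cuda.is_available() else torch.ones(2)

t.cpu()      # returns a CPU copy that is immediately discarded; `t` is unchanged
t = t.cpu()  # reassigning keeps the CPU copy, which is what the test now does
```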