diff --git a/scripts/convert_gligen_to_diffusers.py b/scripts/convert_gligen_to_diffusers.py index 30d789b606..83c1f928e4 100644 --- a/scripts/convert_gligen_to_diffusers.py +++ b/scripts/convert_gligen_to_diffusers.py @@ -576,6 +576,6 @@ if __name__ == "__main__": ) if args.half: - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) pipe.save_pretrained(args.dump_path) diff --git a/scripts/convert_original_stable_diffusion_to_diffusers.py b/scripts/convert_original_stable_diffusion_to_diffusers.py index 2ca70963d1..980446179c 100644 --- a/scripts/convert_original_stable_diffusion_to_diffusers.py +++ b/scripts/convert_original_stable_diffusion_to_diffusers.py @@ -179,7 +179,7 @@ if __name__ == "__main__": ) if args.half: - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) if args.controlnet: # only save the controlnet model diff --git a/scripts/convert_zero123_to_diffusers.py b/scripts/convert_zero123_to_diffusers.py index f016312b8b..3bb6f6c041 100644 --- a/scripts/convert_zero123_to_diffusers.py +++ b/scripts/convert_zero123_to_diffusers.py @@ -801,6 +801,6 @@ if __name__ == "__main__": ) if args.half: - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py index f06633a690..769fcd2e83 100644 --- a/src/diffusers/pipelines/pipeline_utils.py +++ b/src/diffusers/pipelines/pipeline_utils.py @@ -776,7 +776,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `dtype`. """ dtype = kwargs.pop("dtype", None) - device= kwargs.pop("device", None) + device = kwargs.pop("device", None) silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False) dtype_arg = None @@ -851,12 +851,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin): if is_loaded_in_8bit and dtype is not None: logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision." + f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision." ) if is_loaded_in_8bit and device is not None: logger.warning( - f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}." + f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}." ) else: module.to(device, dtype) diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py index 80a8fd19f5..525ca24bbd 100644 --- a/tests/pipelines/animatediff/test_animatediff.py +++ b/tests/pipelines/animatediff/test_animatediff.py @@ -218,7 +218,7 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase): model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py index 3226bdb3ca..767fc30b4e 100644 --- a/tests/pipelines/animatediff/test_animatediff_video2video.py +++ b/tests/pipelines/animatediff/test_animatediff_video2video.py @@ -224,7 +224,7 @@ class AnimateDiffVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.Tes model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) diff --git a/tests/pipelines/audioldm2/test_audioldm2.py b/tests/pipelines/audioldm2/test_audioldm2.py index 60ef86518e..e2655515bc 100644 --- a/tests/pipelines/audioldm2/test_audioldm2.py +++ b/tests/pipelines/audioldm2/test_audioldm2.py @@ -483,7 +483,7 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase): self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values())) # Once we send to fp16, all params are in half-precision, including the logit scale - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")} self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values())) diff --git a/tests/pipelines/musicldm/test_musicldm.py b/tests/pipelines/musicldm/test_musicldm.py index 4bf03569bb..fe78ab6acb 100644 --- a/tests/pipelines/musicldm/test_musicldm.py +++ b/tests/pipelines/musicldm/test_musicldm.py @@ -400,7 +400,7 @@ class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase): self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values())) # Once we send to fp16, all params are in half-precision, including the logit scale - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")} self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values())) diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py index eb76457abc..edd129560c 100644 --- a/tests/pipelines/pia/test_pia.py +++ b/tests/pipelines/pia/test_pia.py @@ -231,7 +231,7 @@ class PIAPipelineFastTests(PipelineTesterMixin, unittest.TestCase): model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) diff --git a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py index 871266fb9c..60c4112838 100644 --- a/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py +++ b/tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py @@ -396,7 +396,7 @@ class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCa model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes)) diff --git a/tests/pipelines/test_pipelines.py b/tests/pipelines/test_pipelines.py index 32ae81ddc2..bd9f42f185 100644 --- a/tests/pipelines/test_pipelines.py +++ b/tests/pipelines/test_pipelines.py @@ -1623,7 +1623,7 @@ class PipelineFastTests(unittest.TestCase): sd1 = sd.to(torch.float16) sd2 = sd.to(None, torch.float16) sd3 = sd.to(dtype=torch.float16) - sd4 = sd.to(torch_dtype=torch.float16) + sd4 = sd.to(dtype=torch.float16) sd5 = sd.to(None, dtype=torch.float16) sd6 = sd.to(None, torch_dtype=torch.float16) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index e3c8a4ef50..7f51847caf 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -716,7 +716,7 @@ class PipelineTesterMixin: model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes)) - pipe.to(torch_dtype=torch.float16) + pipe.to(dtype=torch.float16) model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")] self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))