mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

remove torch_dtype from usage scripts.

Commit 9a41549417 (parent b209a10a3c)
Author: sayakpaul
Date: 2024-02-07 12:50:26 +05:30
12 changed files with 14 additions and 14 deletions
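
For context, the scripts and tests below switch the pipeline-level `.to()` call from the `torch_dtype` keyword to the standard `dtype` keyword. A minimal sketch of the convention the diff moves to (the checkpoint id is only an illustrative example, not part of this commit):

import torch
from diffusers import DiffusionPipeline

# Any checkpoint works here; this id is only an example.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Before this commit the scripts called:
#   pipe.to(torch_dtype=torch.float16)
# After this commit they use the standard keyword:
pipe.to(dtype=torch.float16)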


@@ -576,6 +576,6 @@ if __name__ == "__main__":
     )
     if args.half:
-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
     pipe.save_pretrained(args.dump_path)


@@ -179,7 +179,7 @@ if __name__ == "__main__":
     )
     if args.half:
-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
     if args.controlnet:
         # only save the controlnet model


@@ -801,6 +801,6 @@ if __name__ == "__main__":
     )
     if args.half:
-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
     pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)


@@ -776,7 +776,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             [`DiffusionPipeline`]: The pipeline converted to specified `dtype` and/or `device`.
         """
         dtype = kwargs.pop("dtype", None)
-        device= kwargs.pop("device", None)
+        device = kwargs.pop("device", None)
         silence_dtype_warnings = kwargs.pop("silence_dtype_warnings", False)

         dtype_arg = None
@@ -851,12 +851,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
             if is_loaded_in_8bit and dtype is not None:
                 logger.warning(
-                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {torch_dtype} is not yet supported. Module is still in 8bit precision."
+                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not yet supported. Module is still in 8bit precision."
                 )

             if is_loaded_in_8bit and device is not None:
                 logger.warning(
-                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {torch_dtype} via `.to()` is not yet supported. Module is still on {module.device}."
+                    f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {dtype} via `.to()` is not yet supported. Module is still on {module.device}."
                 )
             else:
                 module.to(device, dtype)
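
The two message edits in this hunk are not purely cosmetic: within `to()` the value is bound as `dtype` (see the `@@ -776` hunk above), so if `torch_dtype` is no longer defined anywhere in the method, an f-string that still interpolates it would raise a NameError as soon as the 8-bit warning path is hit. A minimal sketch of that failure mode, using a hypothetical helper rather than the diffusers implementation:

import torch
import torch.nn as nn

def move_module(module: nn.Module, **kwargs):
    # Stand-in for the pipeline-level `to()`; only `dtype` and `device` are bound here.
    dtype = kwargs.pop("dtype", None)
    device = kwargs.pop("device", None)
    if getattr(module, "is_loaded_in_8bit", False) and dtype is not None:
        # Interpolating `torch_dtype` here instead of `dtype` would raise NameError,
        # because no variable of that name exists in this scope.
        print(f"{module.__class__.__name__} is in 8bit; skipping conversion to {dtype}.")
        return module
    return module.to(device, dtype)

linear = nn.Linear(4, 4)
linear.is_loaded_in_8bit = True  # pretend it came from an 8-bit loader
move_module(linear, dtype=torch.float16)  # prints the warning, leaves the module untouched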


@@ -218,7 +218,7 @@ class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))


@@ -224,7 +224,7 @@ class AnimateDiffVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))


@@ -483,7 +483,7 @@ class AudioLDM2PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))

         # Once we send to fp16, all params are in half-precision, including the logit scale
-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))


@@ -400,7 +400,7 @@ class MusicLDMPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes.values()))

         # Once we send to fp16, all params are in half-precision, including the logit scale
-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = {key: component.dtype for key, component in components.items() if hasattr(component, "dtype")}
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes.values()))


@@ -231,7 +231,7 @@ class PIAPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))


@@ -396,7 +396,7 @@ class StableVideoDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = [component.dtype for component in pipe.components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))


@@ -1623,7 +1623,7 @@ class PipelineFastTests(unittest.TestCase):
         sd1 = sd.to(torch.float16)
         sd2 = sd.to(None, torch.float16)
         sd3 = sd.to(dtype=torch.float16)
-        sd4 = sd.to(torch_dtype=torch.float16)
+        sd4 = sd.to(dtype=torch.float16)
         sd5 = sd.to(None, dtype=torch.float16)
         sd6 = sd.to(None, torch_dtype=torch.float16)


@@ -716,7 +716,7 @@ class PipelineTesterMixin:
         model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))

-        pipe.to(torch_dtype=torch.float16)
+        pipe.to(dtype=torch.float16)
         model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
         self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))