diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 6ebed5e1df..8f33b50682 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -512,8 +512,13 @@ class DiffusionPipeline(ConfigMixin):
 
             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
 
-    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
-        if torch_device is None:
+    def to(
+        self,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+        silence_dtype_warnings: bool = False,
+    ):
+        if torch_device is None and torch_dtype is None:
             return self
 
         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
@@ -550,6 +555,7 @@ class DiffusionPipeline(ConfigMixin):
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
+                module.to(torch_device, torch_dtype)
                 if (
                     module.dtype == torch.float16
                     and str(torch_device) in ["cpu"]
@@ -563,7 +569,6 @@ class DiffusionPipeline(ConfigMixin):
                         " support for`float16` operations on this device in PyTorch. Please, remove the"
                         " `torch_dtype=torch.float16` argument, or use another device for inference."
                     )
-                module.to(torch_device)
         return self
 
     @property
diff --git a/tests/test_pipelines_common.py b/tests/test_pipelines_common.py
index 504acf1b97..1ab6baeb81 100644
--- a/tests/test_pipelines_common.py
+++ b/tests/test_pipelines_common.py
@@ -344,11 +344,8 @@ class PipelineTesterMixin:
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
-        for name, module in components.items():
-            if hasattr(module, "half"):
-                components[name] = module.half()
         pipe_fp16 = self.pipeline_class(**components)
-        pipe_fp16.to(torch_device)
+        pipe_fp16.to(torch_device, torch.float16)
         pipe_fp16.set_progress_bar_config(disable=None)
 
         output = pipe(**self.get_dummy_inputs(torch_device))[0]
@@ -447,6 +444,18 @@ class PipelineTesterMixin:
         output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
         self.assertTrue(np.isnan(output_cuda).sum() == 0)
 
+    def test_to_dtype(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+        pipe.to(torch_dtype=torch.float16)
+        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
     def test_attention_slicing_forward_pass(self):
         self._test_attention_slicing_forward_pass()
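
For reference, a minimal usage sketch of the new API (not part of the patch; the checkpoint id is illustrative, not taken from this diff):

```python
import torch
from diffusers import DiffusionPipeline

# Checkpoint id is an assumption for illustration only.
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Previously, `to()` only moved the pipeline between devices; dtype casting
# required calling `.half()` on each component, as the removed test code did.
pipe.to("cuda")

# With this change, device and dtype can be set in one call...
pipe.to("cuda", torch.float16)

# ...or the dtype can be changed on its own, mirroring the new `test_to_dtype`.
pipe.to(torch_dtype=torch.float16)
```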