Add option to set dtype in pipeline.to() method (#2317)
add test_to_dtype to check pipe.to(fp16)
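In short: DiffusionPipeline.to() now accepts a torch_dtype argument, so device placement and dtype conversion can happen in a single call. A minimal usage sketch (the checkpoint name is illustrative):

    import torch
    from diffusers import DiffusionPipeline

    # Illustrative checkpoint; any DiffusionPipeline works the same way.
    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # With this change, device and dtype can be set in one call:
    pipe = pipe.to("cuda", torch.float16)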
@@ -512,8 +512,13 @@ class DiffusionPipeline(ConfigMixin):
             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
 
-    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
-        if torch_device is None:
+    def to(
+        self,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+        silence_dtype_warnings: bool = False,
+    ):
+        if torch_device is None and torch_dtype is None:
             return self
 
         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
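The widened signature admits several call patterns; a quick sketch of the combinations (assuming pipe is any loaded DiffusionPipeline):

    pipe.to()                           # both arguments None: early return, no-op
    pipe.to("cuda")                     # device-only move, as before
    pipe.to("cuda", torch.float16)      # device and dtype in one call
    pipe.to(torch_dtype=torch.float16)  # dtype-only conversion, device untouched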
@@ -550,6 +555,7 @@ class DiffusionPipeline(ConfigMixin):
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
+                module.to(torch_device, torch_dtype)
                 if (
                     module.dtype == torch.float16
                     and str(torch_device) in ["cpu"]
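Forwarding both arguments unconditionally is safe because torch.nn.Module.to() treats None as "leave unchanged" for either the device or the dtype. A standalone illustration:

    import torch

    lin = torch.nn.Linear(2, 2)
    lin.to(None, torch.float16)  # dtype-only: device is left as-is
    assert lin.weight.dtype == torch.float16
    lin.to("cpu", None)          # device-only: dtype is left as-is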
@@ -563,7 +569,6 @@ class DiffusionPipeline(ConfigMixin):
                         " support for`float16` operations on this device in PyTorch. Please, remove the"
                         " `torch_dtype=torch.float16` argument, or use another device for inference."
                     )
-                module.to(torch_device)
         return self
 
     @property
@@ -344,11 +344,8 @@ class PipelineTesterMixin:
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
-        for name, module in components.items():
-            if hasattr(module, "half"):
-                components[name] = module.half()
         pipe_fp16 = self.pipeline_class(**components)
-        pipe_fp16.to(torch_device)
+        pipe_fp16.to(torch_device, torch.float16)
         pipe_fp16.set_progress_bar_config(disable=None)
 
         output = pipe(**self.get_dummy_inputs(torch_device))[0]
@@ -447,6 +444,18 @@ class PipelineTesterMixin:
         output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
         self.assertTrue(np.isnan(output_cuda).sum() == 0)
 
+    def test_to_dtype(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+        pipe.to(torch_dtype=torch.float16)
+        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
     def test_attention_slicing_forward_pass(self):
         self._test_attention_slicing_forward_pass()
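From a user's perspective, the new test asserts the behavior sketched below (hypothetical checkpoint; accessing the component as pipe.unet assumes a pipeline that exposes one):

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    assert pipe.unet.dtype == torch.float32   # components load in fp32 by default

    pipe.to(torch_dtype=torch.float16)        # dtype-only conversion, no device move
    assert pipe.unet.dtype == torch.float16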