
Add option to set dtype in pipeline.to() method (#2317)

add test_to_dtype to check pipe.to(fp16)
Author: 1lint
Date: 2023-03-21 09:21:23 -05:00
Committed by: GitHub
Parent: 1fcf279d74
Commit: b33bd91fae
2 changed files with 21 additions and 7 deletions
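
For context, this change lets callers set device and dtype in a single pipeline.to() call. A minimal usage sketch, assuming a Stable Diffusion checkpoint (the model id here is only illustrative):

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

    # Previously, dtype could only be set at load time (from_pretrained(..., torch_dtype=...))
    # or by casting each component by hand; .to() accepted a device only.
    # With this commit, device and dtype can be set together or independently:
    pipe = pipe.to(torch_device="cuda", torch_dtype=torch.float16)
    pipe = pipe.to(torch_dtype=torch.float16)  # dtype only; device is left unchanged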

src/diffusers/pipelines/pipeline_utils.py

@@ -512,8 +512,13 @@ class DiffusionPipeline(ConfigMixin):
             save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
 
-    def to(self, torch_device: Optional[Union[str, torch.device]] = None, silence_dtype_warnings: bool = False):
-        if torch_device is None:
+    def to(
+        self,
+        torch_device: Optional[Union[str, torch.device]] = None,
+        torch_dtype: Optional[torch.dtype] = None,
+        silence_dtype_warnings: bool = False,
+    ):
+        if torch_device is None and torch_dtype is None:
             return self
 
         # throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
@@ -550,6 +555,7 @@ class DiffusionPipeline(ConfigMixin):
         for name in module_names.keys():
             module = getattr(self, name)
             if isinstance(module, torch.nn.Module):
+                module.to(torch_device, torch_dtype)
                 if (
                     module.dtype == torch.float16
                     and str(torch_device) in ["cpu"]
@@ -563,7 +569,6 @@ class DiffusionPipeline(ConfigMixin):
                         " support for`float16` operations on this device in PyTorch. Please, remove the"
                         " `torch_dtype=torch.float16` argument, or use another device for inference."
                     )
-                module.to(torch_device)
         return self
 
     @property
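
Why the relocated module.to(torch_device, torch_dtype) call can run unconditionally: torch.nn.Module.to() treats None as "leave unchanged" for both the device and dtype arguments, so whichever argument the caller omitted becomes a no-op. A quick standalone check:

    import torch

    layer = torch.nn.Linear(4, 4)
    layer.to(None, torch.float16)  # dtype-only cast; device is untouched
    layer.to("cpu", None)          # device-only move; dtype is untouched
    print(layer.weight.dtype)      # torch.float16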

tests/test_pipelines_common.py

@@ -344,11 +344,8 @@ class PipelineTesterMixin:
         pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
 
-        for name, module in components.items():
-            if hasattr(module, "half"):
-                components[name] = module.half()
         pipe_fp16 = self.pipeline_class(**components)
-        pipe_fp16.to(torch_device)
+        pipe_fp16.to(torch_device, torch.float16)
         pipe_fp16.set_progress_bar_config(disable=None)
 
         output = pipe(**self.get_dummy_inputs(torch_device))[0]
@@ -447,6 +444,18 @@ class PipelineTesterMixin:
             output_cuda = pipe(**self.get_dummy_inputs("cuda"))[0]
             self.assertTrue(np.isnan(output_cuda).sum() == 0)
 
+    def test_to_dtype(self):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.set_progress_bar_config(disable=None)
+
+        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+        self.assertTrue(all(dtype == torch.float32 for dtype in model_dtypes))
+
+        pipe.to(torch_dtype=torch.float16)
+        model_dtypes = [component.dtype for component in components.values() if hasattr(component, "dtype")]
+        self.assertTrue(all(dtype == torch.float16 for dtype in model_dtypes))
+
     def test_attention_slicing_forward_pass(self):
         self._test_attention_slicing_forward_pass()
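
The hasattr(component, "dtype") filter in the new test matters because a pipeline's components dict mixes nn.Module models with schedulers and tokenizers that expose no dtype. Outside the test harness, the asserted behavior looks roughly like this, assuming a pipeline that exposes a unet component (names are illustrative):

    import torch
    from diffusers import DiffusionPipeline

    pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    assert pipe.unet.dtype == torch.float32  # components load in fp32 by default

    pipe.to(torch_dtype=torch.float16)       # casts every nn.Module component in one call
    assert pipe.unet.dtype == torch.float16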