1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Allow saving None pipeline components (#1118)

* Allow saving `None` pipeline components

* support flax as well

* style
This commit is contained in:
Anton Lozhkov
2022-11-03 15:41:33 +01:00
committed by GitHub
parent 0edf9ca082
commit 4a38166afe
3 changed files with 29 additions and 0 deletions

View File

@@ -161,6 +161,10 @@ class FlaxDiffusionPipeline(ConfigMixin):
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
if sub_model is None:
# edge case for saving a pipeline with safety_checker=None
continue
model_cls = sub_model.__class__
save_method_name = None
@@ -367,6 +371,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
if class_name is None:
# edge case for when the pipeline was saved with safety_checker=None
init_kwargs[name] = None
continue
is_pipeline_module = hasattr(pipelines, library_name)
loaded_sub_model = None
sub_model_should_be_defined = True

View File

@@ -176,6 +176,10 @@ class DiffusionPipeline(ConfigMixin):
for pipeline_component_name in model_index_dict.keys():
sub_model = getattr(self, pipeline_component_name)
if sub_model is None:
# edge case for saving a pipeline with safety_checker=None
continue
model_cls = sub_model.__class__
save_method_name = None
@@ -477,6 +481,11 @@ class DiffusionPipeline(ConfigMixin):
# 3. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
if class_name is None:
# edge case for when the pipeline was saved with safety_checker=None
init_kwargs[name] = None
continue
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
if class_name.startswith("Flax"):
class_name = class_name[4:]

View File

@@ -15,6 +15,7 @@
import gc
import random
import tempfile
import time
import unittest
@@ -318,6 +319,16 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
image = pipe("example prompt", num_inference_steps=2).images[0]
assert image is not None
# check that there's no error when saving a pipeline with one of the models being None
with tempfile.TemporaryDirectory() as tmpdirname:
pipe.save_pretrained(tmpdirname)
pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)
# sanity check that the pipeline still works
assert pipe.safety_checker is None
image = pipe("example prompt", num_inference_steps=2).images[0]
assert image is not None
def test_stable_diffusion_k_lms(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
unet = self.dummy_cond_unet