mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Allow saving None pipeline components (#1118)
* Allow saving `None` pipeline components * support flax as well * style
This commit is contained in:
@@ -161,6 +161,10 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
|
||||
for pipeline_component_name in model_index_dict.keys():
|
||||
sub_model = getattr(self, pipeline_component_name)
|
||||
if sub_model is None:
|
||||
# edge case for saving a pipeline with safety_checker=None
|
||||
continue
|
||||
|
||||
model_cls = sub_model.__class__
|
||||
|
||||
save_method_name = None
|
||||
@@ -367,6 +371,11 @@ class FlaxDiffusionPipeline(ConfigMixin):
|
||||
|
||||
# 3. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
if class_name is None:
|
||||
# edge case for when the pipeline was saved with safety_checker=None
|
||||
init_kwargs[name] = None
|
||||
continue
|
||||
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
loaded_sub_model = None
|
||||
sub_model_should_be_defined = True
|
||||
|
||||
@@ -176,6 +176,10 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
for pipeline_component_name in model_index_dict.keys():
|
||||
sub_model = getattr(self, pipeline_component_name)
|
||||
if sub_model is None:
|
||||
# edge case for saving a pipeline with safety_checker=None
|
||||
continue
|
||||
|
||||
model_cls = sub_model.__class__
|
||||
|
||||
save_method_name = None
|
||||
@@ -477,6 +481,11 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
# 3. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
if class_name is None:
|
||||
# edge case for when the pipeline was saved with safety_checker=None
|
||||
init_kwargs[name] = None
|
||||
continue
|
||||
|
||||
# 3.1 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
||||
if class_name.startswith("Flax"):
|
||||
class_name = class_name[4:]
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import gc
|
||||
import random
|
||||
import tempfile
|
||||
import time
|
||||
import unittest
|
||||
|
||||
@@ -318,6 +319,16 @@ class StableDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
image = pipe("example prompt", num_inference_steps=2).images[0]
|
||||
assert image is not None
|
||||
|
||||
# check that there's no error when saving a pipeline with one of the models being None
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
pipe.save_pretrained(tmpdirname)
|
||||
pipe = StableDiffusionPipeline.from_pretrained(tmpdirname)
|
||||
|
||||
# sanity check that the pipeline still works
|
||||
assert pipe.safety_checker is None
|
||||
image = pipe("example prompt", num_inference_steps=2).images[0]
|
||||
assert image is not None
|
||||
|
||||
def test_stable_diffusion_k_lms(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
unet = self.dummy_cond_unet
|
||||
|
||||
Reference in New Issue
Block a user