mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix a bug in from_pretrained when load optional components (#4745)
* fix --------- Co-authored-by: yiyixuxu <yixu310@gmail,com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
@@ -1012,8 +1012,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
|
||||
# define init kwargs
|
||||
init_kwargs = {k: init_dict.pop(k) for k in optional_kwargs if k in init_dict}
|
||||
# define init kwargs and make sure that optional component modules are filtered out
|
||||
init_kwargs = {
|
||||
k: init_dict.pop(k)
|
||||
for k in optional_kwargs
|
||||
if k in init_dict and k not in pipeline_class._optional_components
|
||||
}
|
||||
init_kwargs = {**init_kwargs, **passed_pipe_kwargs}
|
||||
|
||||
# remove `null` components
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import gc
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -315,6 +316,52 @@ class StableDiffusionUpscalePipelineFastTests(unittest.TestCase):
|
||||
expected_height_width = low_res_image.size[0] * 4
|
||||
assert image.shape == (1, expected_height_width, expected_height_width, 3)
|
||||
|
||||
def test_stable_diffusion_upscale_from_save_pretrained(self):
|
||||
pipes = []
|
||||
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
low_res_scheduler = DDPMScheduler()
|
||||
scheduler = DDIMScheduler(prediction_type="v_prediction")
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
|
||||
# make sure here that pndm scheduler skips prk
|
||||
sd_pipe = StableDiffusionUpscalePipeline(
|
||||
unet=self.dummy_cond_unet_upscale,
|
||||
low_res_scheduler=low_res_scheduler,
|
||||
scheduler=scheduler,
|
||||
vae=self.dummy_vae,
|
||||
text_encoder=self.dummy_text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
max_noise_level=350,
|
||||
)
|
||||
sd_pipe = sd_pipe.to(device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
sd_pipe.save_pretrained(tmpdirname)
|
||||
sd_pipe = StableDiffusionUpscalePipeline.from_pretrained(tmpdirname).to(device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
prompt = "A painting of a squirrel eating a burger"
|
||||
image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0]
|
||||
low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64))
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image = pipe(
|
||||
[prompt],
|
||||
image=low_res_image,
|
||||
generator=generator,
|
||||
guidance_scale=6.0,
|
||||
noise_level=20,
|
||||
num_inference_steps=2,
|
||||
output_type="np",
|
||||
).images
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
@@ -689,3 +690,25 @@ class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTest
|
||||
|
||||
# ensure the results are not equal
|
||||
assert np.abs(image_slice_1.flatten() - image_slice_3.flatten()).max() > 1e-4
|
||||
|
||||
def test_stable_diffusion_xl_save_from_pretrained(self):
|
||||
pipes = []
|
||||
components = self.get_dummy_components()
|
||||
sd_pipe = StableDiffusionXLPipeline(**components).to(torch_device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
sd_pipe.save_pretrained(tmpdirname)
|
||||
sd_pipe = StableDiffusionXLPipeline.from_pretrained(tmpdirname).to(torch_device)
|
||||
pipes.append(sd_pipe)
|
||||
|
||||
image_slices = []
|
||||
for pipe in pipes:
|
||||
pipe.unet.set_default_attn_processor()
|
||||
|
||||
inputs = self.get_dummy_inputs(torch_device)
|
||||
image = pipe(**inputs).images
|
||||
|
||||
image_slices.append(image[0, -3:, -3:, -1].flatten())
|
||||
|
||||
assert np.abs(image_slices[0] - image_slices[1]).max() < 1e-3
|
||||
|
||||
Reference in New Issue
Block a user