diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index 69cce72468..0d384b1647 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -56,6 +56,8 @@ def build_sub_model_components(
 
     if component_name == "unet":
         num_in_channels = kwargs.pop("num_in_channels", None)
+        upcast_attention = kwargs.pop("upcast_attention", None)
+
         unet_components = create_diffusers_unet_model_from_ldm(
             pipeline_class_name,
             original_config,
@@ -64,6 +66,7 @@ def build_sub_model_components(
             image_size=image_size,
             torch_dtype=torch_dtype,
             model_type=model_type,
+            upcast_attention=upcast_attention,
         )
 
         return unet_components
@@ -300,7 +303,9 @@ class FromSingleFileMixin:
                 continue
 
             init_kwargs.update(components)
 
-        additional_components = set_additional_components(class_name, original_config, model_type=model_type)
+        additional_components = set_additional_components(
+            class_name, original_config, checkpoint=checkpoint, model_type=model_type
+        )
         if additional_components:
             init_kwargs.update(additional_components)
diff --git a/src/diffusers/loaders/single_file_utils.py b/src/diffusers/loaders/single_file_utils.py
index ff8c6b64cb..cdaa0802a2 100644
--- a/src/diffusers/loaders/single_file_utils.py
+++ b/src/diffusers/loaders/single_file_utils.py
@@ -410,7 +410,7 @@ def fetch_original_config(pipeline_class_name, checkpoint, original_config_file=
     return original_config
 
 
-def infer_model_type(original_config, checkpoint=None, model_type=None):
+def infer_model_type(original_config, checkpoint, model_type=None):
     if model_type is not None:
         return model_type
 
@@ -1279,7 +1279,7 @@ def create_diffusers_unet_model_from_ldm(
    original_config,
    checkpoint,
    num_in_channels=None,
-    upcast_attention=False,
+    upcast_attention=None,
    extract_ema=False,
    image_size=None,
    torch_dtype=None,
@@ -1307,7 +1307,8 @@
         )
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["in_channels"] = num_in_channels
-    unet_config["upcast_attention"] = upcast_attention
+    if upcast_attention is not None:
+        unet_config["upcast_attention"] = upcast_attention
 
     diffusers_format_unet_checkpoint = convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=extract_ema)
     ctx = init_empty_weights if is_accelerate_available() else nullcontext
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
index 9718aede35..2c4c01c746 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
@@ -838,9 +838,11 @@ class StableDiffusionXLImg2ImgIntegrationTests(unittest.TestCase):
         for param_name, param_value in single_file_pipe.unet.config.items():
             if param_name in PARAMS_TO_IGNORE:
                 continue
+            if param_name == "upcast_attention" and pipe.unet.config[param_name] is None:
+                pipe.unet.config[param_name] = False
             assert (
                 pipe.unet.config[param_name] == param_value
-            ), f"{param_name} is differs between single file loading and pretrained loading"
+            ), f"{param_name} differs between single file loading and pretrained loading"
 
         for param_name, param_value in single_file_pipe.vae.config.items():
             if param_name in PARAMS_TO_IGNORE:
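
Usage note (illustrative, not part of the patch): with this change, `upcast_attention` can be passed through `from_single_file` the same way `num_in_channels` already is, and leaving it unset (`None`) preserves whatever value the derived UNet config carries instead of unconditionally overwriting it with the old default of `False`. A minimal sketch of the user-facing behavior, assuming a local SD 2.1-style checkpoint (the file path is hypothetical):

    from diffusers import StableDiffusionPipeline

    # Explicitly force fp32 attention: the kwarg is popped in
    # build_sub_model_components and forwarded to
    # create_diffusers_unet_model_from_ldm, which writes it into the UNet config.
    pipe = StableDiffusionPipeline.from_single_file(
        "v2-1_768-ema-pruned.safetensors",  # hypothetical local checkpoint path
        upcast_attention=True,
    )

    # Omitting the kwarg leaves upcast_attention as None, so the value inferred
    # from the original LDM config is kept rather than being reset to False.
    pipe = StableDiffusionPipeline.from_single_file("v2-1_768-ema-pruned.safetensors")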