import gc
import tempfile
from io import BytesIO

import requests
import torch
from huggingface_hub import hf_hub_download, snapshot_download

from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
from diffusers.models.attention_processor import AttnProcessor

from ..testing_utils import (
    backend_empty_cache,
    nightly,
    numpy_cosine_similarity_distance,
    require_torch_accelerator,
    torch_device,
)


def download_single_file_checkpoint(repo_id, filename, tmpdir):
    path = hf_hub_download(repo_id, filename=filename, local_dir=tmpdir)
    return path


def download_original_config(config_url, tmpdir):
    original_config_file = BytesIO(requests.get(config_url).content)
    path = f"{tmpdir}/config.yaml"
    with open(path, "wb") as f:
        f.write(original_config_file.read())

    return path


def download_diffusers_config(repo_id, tmpdir):
    path = snapshot_download(
        repo_id,
        ignore_patterns=[
            "**/*.ckpt",
            "*.ckpt",
            "**/*.bin",
            "*.bin",
            "**/*.pt",
            "*.pt",
            "**/*.safetensors",
            "*.safetensors",
        ],
        allow_patterns=["**/*.json", "*.json", "*.txt", "**/*.txt"],
        local_dir=tmpdir,
    )
    return path


@nightly
@require_torch_accelerator
class SingleFileModelTesterMixin:
    def setup_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def teardown_method(self):
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_model_config(self):
        pretrained_kwargs = {}
        single_file_kwargs = {}

        if hasattr(self, "subfolder") and self.subfolder:
            pretrained_kwargs["subfolder"] = self.subfolder

        if hasattr(self, "torch_dtype") and self.torch_dtype:
            pretrained_kwargs["torch_dtype"] = self.torch_dtype
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert model.config[param_name] == param_value, (
                f"{param_name} differs between pretrained loading and single file loading"
            )

    def test_single_file_model_parameters(self):
        pretrained_kwargs = {}
        single_file_kwargs = {}

        if hasattr(self, "subfolder") and self.subfolder:
            pretrained_kwargs["subfolder"] = self.subfolder

        if hasattr(self, "torch_dtype") and self.torch_dtype:
            pretrained_kwargs["torch_dtype"] = self.torch_dtype
            single_file_kwargs["torch_dtype"] = self.torch_dtype

        model = self.model_class.from_pretrained(self.repo_id, **pretrained_kwargs)
        model_single_file = self.model_class.from_single_file(self.ckpt_path, **single_file_kwargs)

        state_dict = model.state_dict()
        state_dict_single_file = model_single_file.state_dict()

        assert set(state_dict.keys()) == set(state_dict_single_file.keys()), (
            "Model parameters keys differ between pretrained and single file loading"
        )

        for key in state_dict.keys():
            param = state_dict[key]
            param_single_file = state_dict_single_file[key]

            assert param.shape == param_single_file.shape, (
                f"Parameter shape mismatch for {key}: "
                f"pretrained {param.shape} vs single file {param_single_file.shape}"
            )

            assert torch.allclose(param, param_single_file, rtol=1e-5, atol=1e-5), (
                f"Parameter values differ for {key}: "
                f"max difference {torch.max(torch.abs(param - param_single_file)).item()}"
            )

    def test_checkpoint_altered_keys_loading(self):
        # Test loading with checkpoints that have altered keys
        if not hasattr(self, "alternate_keys_ckpt_paths") or not self.alternate_keys_ckpt_paths:
            return

        for ckpt_path in self.alternate_keys_ckpt_paths:
            backend_empty_cache(torch_device)

            single_file_kwargs = {}
            if hasattr(self, "torch_dtype") and self.torch_dtype:
                single_file_kwargs["torch_dtype"] = self.torch_dtype

            model = self.model_class.from_single_file(ckpt_path, **single_file_kwargs)

            del model
            gc.collect()
            backend_empty_cache(torch_device)


class SDSingleFileTesterMixin:
    single_file_kwargs = {}

    def _compare_component_configs(self, pipe, single_file_pipe):
        for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
            if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
                continue
            assert pipe.text_encoder.config.to_dict()[param_name] == param_value

        PARAMS_TO_IGNORE = [
            "torch_dtype",
            "_name_or_path",
            "architectures",
            "_use_default_values",
            "_diffusers_version",
        ]
        for component_name, component in single_file_pipe.components.items():
            if component_name in single_file_pipe._optional_components:
                continue

            # skip testing transformer based components here
            # skip text encoders / safety checkers since they have already been tested
            if component_name in ["text_encoder", "tokenizer", "safety_checker", "feature_extractor"]:
                continue

            assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
            assert isinstance(component, pipe.components[component_name].__class__), (
                f"single file {component.__class__.__name__} and pretrained "
                f"{pipe.components[component_name].__class__.__name__} are not the same"
            )

            for param_name, param_value in component.config.items():
                if param_name in PARAMS_TO_IGNORE:
                    continue

                # Some pretrained configs will set upcast attention to None
                # In single file loading it defaults to the value in the class __init__ which is False
                if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
                    pipe.components[component_name].config[param_name] = param_value

                assert pipe.components[component_name].config[param_name] == param_value, (
                    f"single file {param_name}: {param_value} differs from pretrained "
                    f"{pipe.components[component_name].config[param_name]}"
                )

    def test_single_file_components(self, pipe=None, single_file_pipe=None):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_local_files_only(self, pipe=None, single_file_pipe=None):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path, safety_checker=None, local_files_only=True
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
        # Not possible to infer this value when original config is provided
        # we just pass it in here otherwise this test will fail
        upcast_attention = pipe.unet.config.upcast_attention
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path,
            original_config=self.original_config,
            safety_checker=None,
            upcast_attention=upcast_attention,
        )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        # Not possible to infer this value when original config is provided
        # we just pass it in here otherwise this test will fail
        upcast_attention = pipe.unet.config.upcast_attention

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
            local_original_config = download_original_config(self.original_config, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path,
                original_config=local_original_config,
                safety_checker=None,
                upcast_attention=upcast_attention,
                local_files_only=True,
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
        sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, safety_checker=None, **self.single_file_kwargs)
        sf_pipe.unet.set_attn_processor(AttnProcessor())
        sf_pipe.enable_model_cpu_offload(device=torch_device)

        inputs = self.get_inputs(torch_device)
        image_single_file = sf_pipe(**inputs).images[0]

        pipe = self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
        pipe.unet.set_attn_processor(AttnProcessor())
        pipe.enable_model_cpu_offload(device=torch_device)

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images[0]

        max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())

        assert max_diff < expected_max_diff, f"{image.flatten()} != {image_single_file.flatten()}"

    def test_single_file_components_with_diffusers_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, config=self.repo_id, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_diffusers_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
            local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path, config=local_diffusers_config, safety_checker=None, local_files_only=True
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_setting_pipeline_dtype_to_fp16(
        self,
        single_file_pipe=None,
    ):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, torch_dtype=torch.float16
        )

        for component_name, component in single_file_pipe.components.items():
            if not isinstance(component, torch.nn.Module):
                continue
            assert component.dtype == torch.float16


class SDXLSingleFileTesterMixin:
    def _compare_component_configs(self, pipe, single_file_pipe):
        # Skip testing the text_encoder for Refiner Pipelines
        if pipe.text_encoder:
            for param_name, param_value in single_file_pipe.text_encoder.config.to_dict().items():
                if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
                    continue
                assert pipe.text_encoder.config.to_dict()[param_name] == param_value

        for param_name, param_value in single_file_pipe.text_encoder_2.config.to_dict().items():
            if param_name in ["torch_dtype", "architectures", "_name_or_path"]:
                continue
            assert pipe.text_encoder_2.config.to_dict()[param_name] == param_value

        PARAMS_TO_IGNORE = [
            "torch_dtype",
            "_name_or_path",
            "architectures",
            "_use_default_values",
            "_diffusers_version",
        ]
        for component_name, component in single_file_pipe.components.items():
            if component_name in single_file_pipe._optional_components:
                continue

            # skip text encoders since they have already been tested
            if component_name in ["text_encoder", "text_encoder_2", "tokenizer", "tokenizer_2"]:
                continue

            # skip safety checker if it is not present in the pipeline
            if component_name in ["safety_checker", "feature_extractor"]:
                continue

            assert component_name in pipe.components, f"single file {component_name} not found in pretrained pipeline"
            assert isinstance(component, pipe.components[component_name].__class__), (
                f"single file {component.__class__.__name__} and pretrained "
                f"{pipe.components[component_name].__class__.__name__} are not the same"
            )

            for param_name, param_value in component.config.items():
                if param_name in PARAMS_TO_IGNORE:
                    continue

                # Some pretrained configs will set upcast attention to None
                # In single file loading it defaults to the value in the class __init__ which is False
                if param_name == "upcast_attention" and pipe.components[component_name].config[param_name] is None:
                    pipe.components[component_name].config[param_name] = param_value

                assert pipe.components[component_name].config[param_name] == param_value, (
                    f"single file {param_name}: {param_value} differs from pretrained "
                    f"{pipe.components[component_name].config[param_name]}"
                )

    def test_single_file_components(self, pipe=None, single_file_pipe=None):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(
            pipe,
            single_file_pipe,
        )

    def test_single_file_components_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path, safety_checker=None, local_files_only=True
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_original_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)
        # Not possible to infer this value when original config is provided
        # we just pass it in here otherwise this test will fail
        upcast_attention = pipe.unet.config.upcast_attention
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path,
            original_config=self.original_config,
            safety_checker=None,
            upcast_attention=upcast_attention,
        )

        self._compare_component_configs(
            pipe,
            single_file_pipe,
        )

    def test_single_file_components_with_original_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        # Not possible to infer this value when original config is provided
        # we just pass it in here otherwise this test will fail
        upcast_attention = pipe.unet.config.upcast_attention

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
            local_original_config = download_original_config(self.original_config, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path,
                original_config=local_original_config,
                upcast_attention=upcast_attention,
                safety_checker=None,
                local_files_only=True,
            )

        self._compare_component_configs(
            pipe,
            single_file_pipe,
        )

    def test_single_file_components_with_diffusers_config(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, config=self.repo_id, safety_checker=None
        )
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_components_with_diffusers_config_local_files_only(
        self,
        pipe=None,
        single_file_pipe=None,
    ):
        pipe = pipe or self.pipeline_class.from_pretrained(self.repo_id, safety_checker=None)

        with tempfile.TemporaryDirectory() as tmpdir:
            repo_id, weight_name = _extract_repo_id_and_weights_name(self.ckpt_path)
            local_ckpt_path = download_single_file_checkpoint(repo_id, weight_name, tmpdir)
            local_diffusers_config = download_diffusers_config(self.repo_id, tmpdir)

            single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
                local_ckpt_path, config=local_diffusers_config, safety_checker=None, local_files_only=True
            )

        self._compare_component_configs(pipe, single_file_pipe)

    def test_single_file_format_inference_is_same_as_pretrained(self, expected_max_diff=1e-4):
        sf_pipe = self.pipeline_class.from_single_file(self.ckpt_path, torch_dtype=torch.float16, safety_checker=None)
        sf_pipe.unet.set_default_attn_processor()
        sf_pipe.enable_model_cpu_offload(device=torch_device)

        inputs = self.get_inputs(torch_device)
        image_single_file = sf_pipe(**inputs).images[0]

        pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16, safety_checker=None)
        pipe.unet.set_default_attn_processor()
        pipe.enable_model_cpu_offload(device=torch_device)

        inputs = self.get_inputs(torch_device)
        image = pipe(**inputs).images[0]

        max_diff = numpy_cosine_similarity_distance(image.flatten(), image_single_file.flatten())

        assert max_diff < expected_max_diff

    def test_single_file_setting_pipeline_dtype_to_fp16(
        self,
        single_file_pipe=None,
    ):
        single_file_pipe = single_file_pipe or self.pipeline_class.from_single_file(
            self.ckpt_path, torch_dtype=torch.float16
        )

        for component_name, component in single_file_pipe.components.items():
            if not isinstance(component, torch.nn.Module):
                continue
            assert component.dtype == torch.float16
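

# --- Example usage (illustrative sketch, not part of the shared utilities) ---
# A concrete test module is expected to subclass one of the mixins above and fill in the
# class-level attributes the mixin methods read: `model_class` / `pipeline_class`, `repo_id`,
# `ckpt_path`, and optionally `subfolder`, `torch_dtype`, `original_config`, and
# `alternate_keys_ckpt_paths`, plus a `get_inputs` helper for the pipeline mixins.
# The repo id, checkpoint URL, and config URL below are assumed placeholder values chosen for
# illustration only; real tests should point at the repo and single-file checkpoint of the model
# actually under test. The class names deliberately avoid the `Test` prefix so pytest's default
# collection skips these sketches.
from diffusers import StableDiffusionPipeline, UNet2DConditionModel  # assumed imports for the sketch


class ExampleUNetSingleFile(SingleFileModelTesterMixin):
    model_class = UNet2DConditionModel
    repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"  # placeholder Hub repo with diffusers-format weights
    subfolder = "unet"  # subfolder of `repo_id` holding the model
    ckpt_path = "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"  # placeholder single-file checkpoint
    torch_dtype = torch.float16


class ExampleSDPipelineSingleFile(SDSingleFileTesterMixin):
    pipeline_class = StableDiffusionPipeline
    repo_id = "stable-diffusion-v1-5/stable-diffusion-v1-5"  # placeholder
    ckpt_path = "https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5/blob/main/v1-5-pruned-emaonly.safetensors"  # placeholder
    original_config = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"  # placeholder original-format config

    def get_inputs(self, device, seed=0):
        # Minimal deterministic inputs; the inference-parity test only needs reproducible outputs.
        generator = torch.Generator(device="cpu").manual_seed(seed)
        return {
            "prompt": "a photo of an astronaut riding a horse",
            "generator": generator,
            "num_inference_steps": 2,
            "output_type": "np",
        }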