From 6f74ef550d04248b3ff3cbcbb5f5a2add6c56aa0 Mon Sep 17 00:00:00 2001
From: hlky
Date: Mon, 24 Feb 2025 08:07:54 +0000
Subject: [PATCH] Fix `torch_dtype` in Kolors text encoder with `transformers`
 v4.49 (#10816)

* Fix `torch_dtype` in Kolors text encoder with `transformers` v4.49

* Default torch_dtype and warning
---
 examples/community/checkpoint_merger.py       |  6 +++++-
 src/diffusers/loaders/single_file.py          |  8 +++++++-
 src/diffusers/loaders/single_file_model.py    |  8 +++++++-
 src/diffusers/models/modeling_utils.py        |  8 +++++++-
 src/diffusers/pipelines/pipeline_utils.py     | 10 ++++++++--
 tests/pipelines/kolors/test_kolors.py         |  4 +++-
 tests/pipelines/kolors/test_kolors_img2img.py |  4 +++-
 tests/pipelines/pag/test_pag_kolors.py        |  4 +++-
 8 files changed, 43 insertions(+), 9 deletions(-)

diff --git a/examples/community/checkpoint_merger.py b/examples/community/checkpoint_merger.py
index 6ba4b8c6e8..f23e8a207e 100644
--- a/examples/community/checkpoint_merger.py
+++ b/examples/community/checkpoint_merger.py
@@ -92,9 +92,13 @@ class CheckpointMergerPipeline(DiffusionPipeline):
         token = kwargs.pop("token", None)
         variant = kwargs.pop("variant", None)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         device_map = kwargs.pop("device_map", None)
 
+        if not isinstance(torch_dtype, torch.dtype):
+            print(f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.")
+            torch_dtype = torch.float32
+
         alpha = kwargs.pop("alpha", 0.5)
         interp = kwargs.pop("interp", None)
 
diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py
index c87d2a7cf8..fdfbb923ba 100644
--- a/src/diffusers/loaders/single_file.py
+++ b/src/diffusers/loaders/single_file.py
@@ -360,11 +360,17 @@ class FromSingleFileMixin:
         cache_dir = kwargs.pop("cache_dir", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         disable_mmap = kwargs.pop("disable_mmap", False)
 
         is_legacy_loading = False
 
+        if not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
+
         # We shouldn't allow configuring individual models components through a Pipeline creation method
         # These model kwargs should be deprecated
         scaling_factor = kwargs.get("scaling_factor", None)
diff --git a/src/diffusers/loaders/single_file_model.py b/src/diffusers/loaders/single_file_model.py
index b6eaffbc8c..e6b0508334 100644
--- a/src/diffusers/loaders/single_file_model.py
+++ b/src/diffusers/loaders/single_file_model.py
@@ -240,11 +240,17 @@ class FromOriginalModelMixin:
         subfolder = kwargs.pop("subfolder", None)
         revision = kwargs.pop("revision", None)
         config_revision = kwargs.pop("config_revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         quantization_config = kwargs.pop("quantization_config", None)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
 
+        if not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
+
         if isinstance(pretrained_model_link_or_path_or_dict, dict):
             checkpoint = pretrained_model_link_or_path_or_dict
         else:
diff --git a/src/diffusers/models/modeling_utils.py b/src/diffusers/models/modeling_utils.py
index e7f306da6b..4fbbd78667 100644
--- a/src/diffusers/models/modeling_utils.py
+++ b/src/diffusers/models/modeling_utils.py
@@ -866,7 +866,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         subfolder = kwargs.pop("subfolder", None)
         device_map = kwargs.pop("device_map", None)
         max_memory = kwargs.pop("max_memory", None)
@@ -879,6 +879,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
 
+        if not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
+
         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 90a05e97f6..e112947c8d 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -685,7 +685,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         from_flax = kwargs.pop("from_flax", False)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         custom_pipeline = kwargs.pop("custom_pipeline", None)
         custom_revision = kwargs.pop("custom_revision", None)
         provider = kwargs.pop("provider", None)
@@ -702,6 +702,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
 
+        if not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
+
         if low_cpu_mem_usage and not is_accelerate_available():
             low_cpu_mem_usage = False
             logger.warning(
@@ -1826,7 +1832,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         """
 
         original_config = dict(pipeline.config)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
 
         # derive the pipeline class to instantiate
         custom_pipeline = kwargs.pop("custom_pipeline", None)
diff --git a/tests/pipelines/kolors/test_kolors.py b/tests/pipelines/kolors/test_kolors.py
index cf0b392ddc..edeb588414 100644
--- a/tests/pipelines/kolors/test_kolors.py
+++ b/tests/pipelines/kolors/test_kolors.py
@@ -89,7 +89,9 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             sample_size=128,
         )
         torch.manual_seed(0)
-        text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+        text_encoder = ChatGLMModel.from_pretrained(
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
 
         components = {
diff --git a/tests/pipelines/kolors/test_kolors_img2img.py b/tests/pipelines/kolors/test_kolors_img2img.py
index 025bcf2fac..9c43e0920e 100644
--- a/tests/pipelines/kolors/test_kolors_img2img.py
+++ b/tests/pipelines/kolors/test_kolors_img2img.py
@@ -93,7 +93,9 @@ class KolorsPipelineImg2ImgFastTests(PipelineTesterMixin, unittest.TestCase):
             sample_size=128,
         )
         torch.manual_seed(0)
-        text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+        text_encoder = ChatGLMModel.from_pretrained(
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
 
         components = {
diff --git a/tests/pipelines/pag/test_pag_kolors.py b/tests/pipelines/pag/test_pag_kolors.py
index 9a5764e24f..f6d7331b1a 100644
--- a/tests/pipelines/pag/test_pag_kolors.py
+++ b/tests/pipelines/pag/test_pag_kolors.py
@@ -98,7 +98,9 @@ class KolorsPAGPipelineFastTests(
             sample_size=128,
         )
         torch.manual_seed(0)
-        text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+        text_encoder = ChatGLMModel.from_pretrained(
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
 
         components = {
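A minimal usage sketch of the behavior this patch targets, assuming the public Kwai-Kolors/Kolors-diffusers checkpoint and its fp16 variant (neither appears in the diff above); any DiffusionPipeline checkpoint exercises the same torch_dtype code path:

    import torch
    from diffusers import KolorsPipeline

    # A real `torch.dtype` is forwarded unchanged to every component,
    # including the ChatGLM text encoder that transformers v4.49 validates.
    pipe = KolorsPipeline.from_pretrained(
        "Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16, variant="fp16"
    )

    # A non-dtype value (e.g. the string "float16") no longer propagates to
    # the text encoder: with this patch it is replaced by `torch.float32`
    # after a warning instead of failing inside transformers.
    pipe = KolorsPipeline.from_pretrained("Kwai-Kolors/Kolors-diffusers", torch_dtype="float16")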