Mirror of https://github.com/huggingface/diffusers.git

Fix torch_dtype in Kolors text encoder with transformers v4.49 (#10816)

* Fix `torch_dtype` in Kolors text encoder with `transformers` v4.49

* Default torch_dtype and warning
Author: hlky
Date: 2025-02-24 08:07:54 +00:00
Committed by: GitHub
Parent: 9c7e205176
Commit: 6f74ef550d

8 changed files with 43 additions and 9 deletions
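
The loader files in this commit all apply the same two-part guard: `torch_dtype` now defaults to `torch.float32` instead of `None`, and any value that is not a `torch.dtype` is coerced to `torch.float32` with a warning. A minimal self-contained sketch of that pattern (the helper name `resolve_torch_dtype` is illustrative, not part of diffusers):

import logging

import torch

logger = logging.getLogger(__name__)


def resolve_torch_dtype(torch_dtype=torch.float32):
    # Mirror of the guard added in this commit: anything that is not an
    # actual torch.dtype (e.g. None or the string "float16") falls back
    # to float32 with a warning instead of failing later during casting.
    if not isinstance(torch_dtype, torch.dtype):
        logger.warning(
            f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
        )
        torch_dtype = torch.float32
    return torch_dtype


assert resolve_torch_dtype(torch.bfloat16) == torch.bfloat16
assert resolve_torch_dtype("float16") == torch.float32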

examples/community/checkpoint_merger.py

@@ -92,9 +92,13 @@ class CheckpointMergerPipeline(DiffusionPipeline):
         token = kwargs.pop("token", None)
         variant = kwargs.pop("variant", None)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         device_map = kwargs.pop("device_map", None)
+        if not isinstance(torch_dtype, torch.dtype):
+            print(f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.")
+            torch_dtype = torch.float32
         alpha = kwargs.pop("alpha", 0.5)
         interp = kwargs.pop("interp", None)

src/diffusers/loaders/single_file.py

@@ -360,11 +360,17 @@ class FromSingleFileMixin:
         cache_dir = kwargs.pop("cache_dir", None)
         local_files_only = kwargs.pop("local_files_only", False)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         disable_mmap = kwargs.pop("disable_mmap", False)
         is_legacy_loading = False
+        if not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
         # We shouldn't allow configuring individual models components through a Pipeline creation method
         # These model kwargs should be deprecated
         scaling_factor = kwargs.get("scaling_factor", None)
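
With this guard in `from_single_file`, a malformed dtype value degrades to a float32 load plus a warning rather than an exception deeper in the loader. A hedged usage sketch (the local checkpoint path is a placeholder, not part of this commit):

import torch

from diffusers import StableDiffusionPipeline

# Placeholder path; any single-file Stable Diffusion checkpoint works here.
ckpt = "./v1-5-pruned-emaonly.safetensors"

# Normal case: a real torch.dtype is respected.
pipe = StableDiffusionPipeline.from_single_file(ckpt, torch_dtype=torch.float16)

# Mistake case: a string is not a torch.dtype, so after this commit the
# pipeline loads in float32 and logs a warning instead of erroring.
pipe_fallback = StableDiffusionPipeline.from_single_file(ckpt, torch_dtype="float16")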

src/diffusers/loaders/single_file_model.py

@@ -240,11 +240,17 @@ class FromOriginalModelMixin:
         subfolder = kwargs.pop("subfolder", None)
         revision = kwargs.pop("revision", None)
         config_revision = kwargs.pop("config_revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         quantization_config = kwargs.pop("quantization_config", None)
         device = kwargs.pop("device", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
+        if not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
         if isinstance(pretrained_model_link_or_path_or_dict, dict):
             checkpoint = pretrained_model_link_or_path_or_dict
         else:

src/diffusers/models/modeling_utils.py

@@ -866,7 +866,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         local_files_only = kwargs.pop("local_files_only", None)
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         subfolder = kwargs.pop("subfolder", None)
         device_map = kwargs.pop("device_map", None)
         max_memory = kwargs.pop("max_memory", None)
@@ -879,6 +879,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)
+        if not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
         allow_pickle = False
         if use_safetensors is None:
             use_safetensors = True
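
The same fallback now covers standalone models loaded through `ModelMixin.from_pretrained`. A quick hedged check (the tiny test repo id below is an assumption for illustration; substitute any model repo):

import torch

from diffusers import UNet2DConditionModel

# Assumed checkpoint for illustration; any UNet2DConditionModel repo works.
unet = UNet2DConditionModel.from_pretrained(
    "hf-internal-testing/tiny-stable-diffusion-torch",
    subfolder="unet",
    torch_dtype="bf16",  # not a torch.dtype, so the loader warns and uses float32
)
assert unet.dtype == torch.float32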

src/diffusers/pipelines/pipeline_utils.py

@@ -685,7 +685,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         token = kwargs.pop("token", None)
         revision = kwargs.pop("revision", None)
         from_flax = kwargs.pop("from_flax", False)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         custom_pipeline = kwargs.pop("custom_pipeline", None)
         custom_revision = kwargs.pop("custom_revision", None)
         provider = kwargs.pop("provider", None)
@@ -702,6 +702,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         use_onnx = kwargs.pop("use_onnx", None)
         load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
+        if not isinstance(torch_dtype, torch.dtype):
+            logger.warning(
+                f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
+            )
+            torch_dtype = torch.float32
         if low_cpu_mem_usage and not is_accelerate_available():
             low_cpu_mem_usage = False
             logger.warning(
@@ -1826,7 +1832,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         """
         original_config = dict(pipeline.config)
-        torch_dtype = kwargs.pop("torch_dtype", None)
+        torch_dtype = kwargs.pop("torch_dtype", torch.float32)
         # derive the pipeline class to instantiate
         custom_pipeline = kwargs.pop("custom_pipeline", None)

tests/pipelines/kolors/test_kolors.py

@@ -89,7 +89,9 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
             sample_size=128,
         )
         torch.manual_seed(0)
-        text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+        text_encoder = ChatGLMModel.from_pretrained(
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
         components = {

tests/pipelines/kolors/test_kolors_img2img.py

@@ -93,7 +93,9 @@ class KolorsPipelineImg2ImgFastTests(PipelineTesterMixin, unittest.TestCase):
             sample_size=128,
         )
         torch.manual_seed(0)
-        text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+        text_encoder = ChatGLMModel.from_pretrained(
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
         components = {

tests/pipelines/pag/test_pag_kolors.py

@@ -98,7 +98,9 @@ class KolorsPAGPipelineFastTests(
             sample_size=128,
         )
         torch.manual_seed(0)
-        text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
+        text_encoder = ChatGLMModel.from_pretrained(
+            "hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
+        )
         tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
         components = {
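
The three test changes pin the ChatGLM text encoder to `torch.bfloat16` explicitly because, per the commit title, `transformers` v4.49 changed how a missing `torch_dtype` is resolved in `from_pretrained`. Users of Kolors should likewise pass the dtype explicitly; a hedged end-to-end sketch (repo id and variant taken from the public Kolors model card, assumed here):

import torch

from diffusers import KolorsPipeline

pipe = KolorsPipeline.from_pretrained(
    "Kwai-Kolors/Kolors-diffusers", torch_dtype=torch.float16, variant="fp16"
).to("cuda")
image = pipe("a photo of an astronaut riding a horse on mars").images[0]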