mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix torch_dtype in Kolors text encoder with transformers v4.49 (#10816)
* Fix `torch_dtype` in Kolors text encoder with `transformers` v4.49 * Default torch_dtype and warning
This commit is contained in:
@@ -92,9 +92,13 @@ class CheckpointMergerPipeline(DiffusionPipeline):
|
||||
token = kwargs.pop("token", None)
|
||||
variant = kwargs.pop("variant", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
print(f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`.")
|
||||
|
||||
alpha = kwargs.pop("alpha", 0.5)
|
||||
interp = kwargs.pop("interp", None)
|
||||
|
||||
|
||||
@@ -360,11 +360,17 @@ class FromSingleFileMixin:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
is_legacy_loading = False
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
# We shouldn't allow configuring individual models components through a Pipeline creation method
|
||||
# These model kwargs should be deprecated
|
||||
scaling_factor = kwargs.get("scaling_factor", None)
|
||||
|
||||
@@ -240,11 +240,17 @@ class FromOriginalModelMixin:
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
config_revision = kwargs.pop("config_revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
quantization_config = kwargs.pop("quantization_config", None)
|
||||
device = kwargs.pop("device", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
if isinstance(pretrained_model_link_or_path_or_dict, dict):
|
||||
checkpoint = pretrained_model_link_or_path_or_dict
|
||||
else:
|
||||
|
||||
@@ -866,7 +866,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
local_files_only = kwargs.pop("local_files_only", None)
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
device_map = kwargs.pop("device_map", None)
|
||||
max_memory = kwargs.pop("max_memory", None)
|
||||
@@ -879,6 +879,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
|
||||
dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
|
||||
disable_mmap = kwargs.pop("disable_mmap", False)
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
allow_pickle = False
|
||||
if use_safetensors is None:
|
||||
use_safetensors = True
|
||||
|
||||
@@ -685,7 +685,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
token = kwargs.pop("token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
from_flax = kwargs.pop("from_flax", False)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
custom_pipeline = kwargs.pop("custom_pipeline", None)
|
||||
custom_revision = kwargs.pop("custom_revision", None)
|
||||
provider = kwargs.pop("provider", None)
|
||||
@@ -702,6 +702,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
use_onnx = kwargs.pop("use_onnx", None)
|
||||
load_connected_pipeline = kwargs.pop("load_connected_pipeline", False)
|
||||
|
||||
if not isinstance(torch_dtype, torch.dtype):
|
||||
torch_dtype = torch.float32
|
||||
logger.warning(
|
||||
f"Passed `torch_dtype` {torch_dtype} is not a `torch.dtype`. Defaulting to `torch.float32`."
|
||||
)
|
||||
|
||||
if low_cpu_mem_usage and not is_accelerate_available():
|
||||
low_cpu_mem_usage = False
|
||||
logger.warning(
|
||||
@@ -1826,7 +1832,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
"""
|
||||
|
||||
original_config = dict(pipeline.config)
|
||||
torch_dtype = kwargs.pop("torch_dtype", None)
|
||||
torch_dtype = kwargs.pop("torch_dtype", torch.float32)
|
||||
|
||||
# derive the pipeline class to instantiate
|
||||
custom_pipeline = kwargs.pop("custom_pipeline", None)
|
||||
|
||||
@@ -89,7 +89,9 @@ class KolorsPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
sample_size=128,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
text_encoder = ChatGLMModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
|
||||
)
|
||||
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -93,7 +93,9 @@ class KolorsPipelineImg2ImgFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
sample_size=128,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
text_encoder = ChatGLMModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
|
||||
)
|
||||
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
|
||||
components = {
|
||||
|
||||
@@ -98,7 +98,9 @@ class KolorsPAGPipelineFastTests(
|
||||
sample_size=128,
|
||||
)
|
||||
torch.manual_seed(0)
|
||||
text_encoder = ChatGLMModel.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
text_encoder = ChatGLMModel.from_pretrained(
|
||||
"hf-internal-testing/tiny-random-chatglm3-6b", torch_dtype=torch.bfloat16
|
||||
)
|
||||
tokenizer = ChatGLMTokenizer.from_pretrained("hf-internal-testing/tiny-random-chatglm3-6b")
|
||||
|
||||
components = {
|
||||
|
||||
Reference in New Issue
Block a user