mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
@@ -141,6 +141,22 @@ def get_tests_dir(append_path=None):
|
||||
return tests_dir
|
||||
|
||||
|
||||
# Taken from the following PR:
|
||||
# https://github.com/huggingface/accelerate/pull/1964
|
||||
def str_to_bool(value) -> int:
|
||||
"""
|
||||
Converts a string representation of truth to `True` (1) or `False` (0).
|
||||
True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
|
||||
"""
|
||||
value = value.lower()
|
||||
if value in ("y", "yes", "t", "true", "on", "1"):
|
||||
return 1
|
||||
elif value in ("n", "no", "f", "false", "off", "0"):
|
||||
return 0
|
||||
else:
|
||||
raise ValueError(f"invalid truth value {value}")
|
||||
|
||||
|
||||
def parse_flag_from_env(key, default=False):
|
||||
try:
|
||||
value = os.environ[key]
|
||||
@@ -920,22 +936,6 @@ def backend_supports_training(device: str):
|
||||
return BACKEND_SUPPORTS_TRAINING[device]
|
||||
|
||||
|
||||
# Taken from the following PR:
|
||||
# https://github.com/huggingface/accelerate/pull/1964
|
||||
def str_to_bool(value) -> int:
|
||||
"""
|
||||
Converts a string representation of truth to `True` (1) or `False` (0).
|
||||
True values are `y`, `yes`, `t`, `true`, `on`, and `1`; False value are `n`, `no`, `f`, `false`, `off`, and `0`;
|
||||
"""
|
||||
value = value.lower()
|
||||
if value in ("y", "yes", "t", "true", "on", "1"):
|
||||
return 1
|
||||
elif value in ("n", "no", "f", "false", "off", "0"):
|
||||
return 0
|
||||
else:
|
||||
raise ValueError(f"invalid truth value {value}")
|
||||
|
||||
|
||||
# Guard for when Torch is not available
|
||||
if is_torch_available():
|
||||
# Update device function dict mapping
|
||||
|
||||
Reference in New Issue
Block a user