mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Dtype] Align dtype casting behavior with Transformers and Accelerate (#1725)
* [Dtype] Align automatic dtype * up * up * fix * re-add accelerate
This commit is contained in:
committed by
GitHub
parent
debc74f442
commit
f2e521c499
4
.github/workflows/nightly_tests.yml
vendored
4
.github/workflows/nightly_tests.yml
vendored
@@ -62,6 +62,7 @@ jobs:
|
||||
run: |
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install -U git+https://github.com/huggingface/transformers
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -134,6 +135,7 @@ jobs:
|
||||
${CONDA_RUN} python -m pip install --upgrade pip
|
||||
${CONDA_RUN} python -m pip install -e .[quality,test]
|
||||
${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
- name: Environment
|
||||
shell: arch -arch arm64 bash {0}
|
||||
@@ -157,4 +159,4 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: torch_mps_test_reports
|
||||
path: reports
|
||||
path: reports
|
||||
|
||||
2
.github/workflows/pr_tests.yml
vendored
2
.github/workflows/pr_tests.yml
vendored
@@ -60,6 +60,7 @@ jobs:
|
||||
apt-get update && apt-get install libsndfile1-dev -y
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install -U git+https://github.com/huggingface/transformers
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -126,6 +127,7 @@ jobs:
|
||||
${CONDA_RUN} python -m pip install --upgrade pip
|
||||
${CONDA_RUN} python -m pip install -e .[quality,test]
|
||||
${CONDA_RUN} python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cpu
|
||||
${CONDA_RUN} python -m pip install git+https://github.com/huggingface/accelerate
|
||||
${CONDA_RUN} python -m pip install -U git+https://github.com/huggingface/transformers
|
||||
|
||||
- name: Environment
|
||||
|
||||
4
.github/workflows/push_tests.yml
vendored
4
.github/workflows/push_tests.yml
vendored
@@ -62,6 +62,7 @@ jobs:
|
||||
run: |
|
||||
python -m pip install -e .[quality,test]
|
||||
python -m pip install -U git+https://github.com/huggingface/transformers
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
|
||||
- name: Environment
|
||||
run: |
|
||||
@@ -130,6 +131,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install -e .[quality,test,training]
|
||||
python -m pip install git+https://github.com/huggingface/accelerate
|
||||
python -m pip install -U git+https://github.com/huggingface/transformers
|
||||
|
||||
- name: Environment
|
||||
@@ -151,4 +153,4 @@ jobs:
|
||||
uses: actions/upload-artifact@v2
|
||||
with:
|
||||
name: examples_test_reports
|
||||
path: reports
|
||||
path: reports
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
@@ -489,11 +490,15 @@ class ModelMixin(torch.nn.Module):
|
||||
state_dict = load_state_dict(model_file)
|
||||
# move the parms from meta device to cpu
|
||||
for param_name, param in state_dict.items():
|
||||
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
||||
accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
|
||||
if accepts_dtype:
|
||||
set_module_tensor_to_device(model, param_name, param_device, value=param, dtype=torch_dtype)
|
||||
else:
|
||||
set_module_tensor_to_device(model, param_name, param_device, value=param)
|
||||
else: # else let accelerate handle loading and dispatching.
|
||||
# Load weights and dispatch according to the device_map
|
||||
# by deafult the device_map is None and the weights are loaded on the CPU
|
||||
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map)
|
||||
accelerate.load_checkpoint_and_dispatch(model, model_file, device_map, dtype=torch_dtype)
|
||||
|
||||
loading_info = {
|
||||
"missing_keys": [],
|
||||
@@ -519,20 +524,6 @@ class ModelMixin(torch.nn.Module):
|
||||
model = cls.from_config(config, **unused_kwargs)
|
||||
|
||||
state_dict = load_state_dict(model_file)
|
||||
dtype = set(v.dtype for v in state_dict.values())
|
||||
|
||||
if len(dtype) > 1 and torch.float32 not in dtype:
|
||||
raise ValueError(
|
||||
f"The weights of the model file {model_file} have a mixture of incompatible dtypes {dtype}. Please"
|
||||
f" make sure that {model_file} weights have only one dtype."
|
||||
)
|
||||
elif len(dtype) > 1 and torch.float32 in dtype:
|
||||
dtype = torch.float32
|
||||
else:
|
||||
dtype = dtype.pop()
|
||||
|
||||
# move model to correct dtype
|
||||
model = model.to(dtype)
|
||||
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
model,
|
||||
|
||||
@@ -70,9 +70,9 @@ class ModelTesterMixin:
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.to(dtype)
|
||||
model.save_pretrained(tmpdirname)
|
||||
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True)
|
||||
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=True, torch_dtype=dtype)
|
||||
assert new_model.dtype == dtype
|
||||
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False)
|
||||
new_model = self.model_class.from_pretrained(tmpdirname, low_cpu_mem_usage=False, torch_dtype=dtype)
|
||||
assert new_model.dtype == dtype
|
||||
|
||||
def test_determinism(self):
|
||||
|
||||
Reference in New Issue
Block a user