
Check correct model type is passed to from_pretrained (#10189)

* Check correct model type is passed to `from_pretrained`

* Flax, skip scheduler

* test_wrong_model

* Fix for scheduler

* Update tests/pipelines/test_pipelines.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* EnumMeta

* Flax

* scheduler in expected types

* make

* type object 'CLIPTokenizer' has no attribute '_PipelineFastTests__name'

* support union

* fix typing in kandinsky

* make

* add LCMScheduler

* 'LCMScheduler' object has no attribute 'sigmas'

* tests for wrong scheduler

* make

* update

* warning

* tests

* Update src/diffusers/pipelines/pipeline_utils.py

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>

* import FlaxSchedulerMixin

* skip scheduler

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
commit 0ed09a17bb (parent 2f7a417d1f)
Author: hlky
Date: 2024-12-19 09:24:52 +00:00
Committed by: GitHub
2 changed files with 32 additions and 0 deletions
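For context, a usage-level illustration of the behavior exercised by the new test below. This is a hedged sketch, not part of the diff, and it assumes the tiny Hub checkpoints referenced in the test are available: passing a component of the wrong class, such as a tokenizer where a text encoder is expected, makes `from_pretrained` fail with a ValueError whose message contains "is of type" and "but should be".

from diffusers import StableDiffusionPipeline
from transformers import CLIPTokenizer

# A tokenizer is deliberately passed for the `text_encoder` component.
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

try:
    StableDiffusionPipeline.from_pretrained(
        "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer
    )
except ValueError as e:
    # The message is expected to contain "is of type" and "but should be",
    # matching the assertions in test_wrong_model below.
    print(e)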

src/diffusers/pipelines/pipeline_utils.py

@@ -13,6 +13,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import fnmatch
 import importlib
 import inspect
@@ -811,6 +812,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         # in this case they are already instantiated in `kwargs`
         # extract them here
         expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
+        expected_types = pipeline_class._get_signature_types()
         passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
         passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
         init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
@@ -833,6 +835,26 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
         init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
+        for key in init_dict.keys():
+            if key not in passed_class_obj:
+                continue
+            if "scheduler" in key:
+                continue
+            class_obj = passed_class_obj[key]
+            _expected_class_types = []
+            for expected_type in expected_types[key]:
+                if isinstance(expected_type, enum.EnumMeta):
+                    _expected_class_types.extend(expected_type.__members__.keys())
+                else:
+                    _expected_class_types.append(expected_type.__name__)
+            _is_valid_type = class_obj.__class__.__name__ in _expected_class_types
+            if not _is_valid_type:
+                logger.warning(
+                    f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
+                )
         # Special case: safety_checker must be loaded separately when using `from_flax`
         if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
             raise NotImplementedError(
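A note on the enum.EnumMeta branch above: the commit message mentions EnumMeta and a Kandinsky typing fix, and the check handles the case where an entry returned by `_get_signature_types()` is an Enum rather than a plain class by comparing against the enum's member names. Below is a minimal, self-contained sketch of that collection step; the `ExampleComponents` enum and `collect_expected_class_names` helper are hypothetical illustrations, not diffusers APIs.

import enum

class ExampleComponents(enum.Enum):
    # Hypothetical enum whose member names stand in for allowed component class names.
    CLIPTextModel = "clip"
    T5EncoderModel = "t5"

def collect_expected_class_names(expected_types):
    names = []
    for expected_type in expected_types:
        if isinstance(expected_type, enum.EnumMeta):
            # Enum entries contribute their member names instead of the enum's own name.
            names.extend(expected_type.__members__.keys())
        else:
            names.append(expected_type.__name__)
    return names

print(collect_expected_class_names([ExampleComponents, str]))
# ['CLIPTextModel', 'T5EncoderModel', 'str']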

tests/pipelines/test_pipelines.py

@@ -1802,6 +1802,16 @@ class PipelineFastTests(unittest.TestCase):
         sd.maybe_free_model_hooks()
         assert sd._offload_gpu_id == 5
+    def test_wrong_model(self):
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        with self.assertRaises(ValueError) as error_context:
+            _ = StableDiffusionPipeline.from_pretrained(
+                "hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer
+            )
+        assert "is of type" in str(error_context.exception)
+        assert "but should be" in str(error_context.exception)
 @slow
 @require_torch_gpu
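Finally, the new check deliberately skips any component whose key contains "scheduler" (the `if "scheduler" in key: continue` in the first hunk), so the common pattern of swapping schedulers at load time is not flagged. A hedged sketch reusing the tiny test checkpoint, which is assumed to be available as in the test above:

from diffusers import LCMScheduler, StableDiffusionPipeline

# Passing a different scheduler class than the one the pipeline was saved with
# is still accepted; scheduler keys are excluded from the new type check.
pipe = StableDiffusionPipeline.from_pretrained(
    "hf-internal-testing/diffusers-stable-diffusion-tiny-all",
    scheduler=LCMScheduler(),
)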