mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Check correct model type is passed to from_pretrained (#10189)
* Check correct model type is passed to `from_pretrained` * Flax, skip scheduler * test_wrong_model * Fix for scheduler * Update tests/pipelines/test_pipelines.py Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * EnumMeta * Flax * scheduler in expected types * make * type object 'CLIPTokenizer' has no attribute '_PipelineFastTests__name' * support union * fix typing in kandinsky * make * add LCMScheduler * 'LCMScheduler' object has no attribute 'sigmas' * tests for wrong scheduler * make * update * warning * tests * Update src/diffusers/pipelines/pipeline_utils.py Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com> * import FlaxSchedulerMixin * skip scheduler --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import enum
|
||||
import fnmatch
|
||||
import importlib
|
||||
import inspect
|
||||
@@ -811,6 +812,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
# in this case they are already instantiated in `kwargs`
|
||||
# extract them here
|
||||
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
||||
expected_types = pipeline_class._get_signature_types()
|
||||
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
||||
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
||||
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
||||
@@ -833,6 +835,26 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
||||
|
||||
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
||||
|
||||
for key in init_dict.keys():
|
||||
if key not in passed_class_obj:
|
||||
continue
|
||||
if "scheduler" in key:
|
||||
continue
|
||||
|
||||
class_obj = passed_class_obj[key]
|
||||
_expected_class_types = []
|
||||
for expected_type in expected_types[key]:
|
||||
if isinstance(expected_type, enum.EnumMeta):
|
||||
_expected_class_types.extend(expected_type.__members__.keys())
|
||||
else:
|
||||
_expected_class_types.append(expected_type.__name__)
|
||||
|
||||
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
|
||||
if not _is_valid_type:
|
||||
logger.warning(
|
||||
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
|
||||
)
|
||||
|
||||
# Special case: safety_checker must be loaded separately when using `from_flax`
|
||||
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
|
||||
raise NotImplementedError(
|
||||
|
||||
@@ -1802,6 +1802,16 @@ class PipelineFastTests(unittest.TestCase):
|
||||
sd.maybe_free_model_hooks()
|
||||
assert sd._offload_gpu_id == 5
|
||||
|
||||
def test_wrong_model(self):
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
with self.assertRaises(ValueError) as error_context:
|
||||
_ = StableDiffusionPipeline.from_pretrained(
|
||||
"hf-internal-testing/diffusers-stable-diffusion-tiny-all", text_encoder=tokenizer
|
||||
)
|
||||
|
||||
assert "is of type" in str(error_context.exception)
|
||||
assert "but should be" in str(error_context.exception)
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
|
||||
Reference in New Issue
Block a user