mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add post init for safty checker (#12794)
* add post init for safty checker Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * check transformers version before post init Signed-off-by: jiqing-feng <jiqing.feng@intel.com> * Apply style fixes --------- Signed-off-by: jiqing-feng <jiqing.feng@intel.com> Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
This commit is contained in:
@@ -17,7 +17,7 @@ import torch
|
||||
import torch.nn as nn
|
||||
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel
|
||||
|
||||
from ...utils import logging
|
||||
from ...utils import is_transformers_version, logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -46,6 +46,9 @@ class StableDiffusionSafetyChecker(PreTrainedModel):
|
||||
|
||||
self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
|
||||
self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False)
|
||||
# Model requires post_init after transformers v4.57.3
|
||||
if is_transformers_version(">", "4.57.3"):
|
||||
self.post_init()
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, clip_input, images):
|
||||
|
||||
Reference in New Issue
Block a user