From f67639b0bb54d3ccf7fc17157ba0b1e2e959ac5e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 8 Dec 2025 14:01:03 +0800 Subject: [PATCH] add post init for safty checker (#12794) * add post init for safty checker Signed-off-by: jiqing-feng * check transformers version before post init Signed-off-by: jiqing-feng * Apply style fixes --------- Signed-off-by: jiqing-feng Co-authored-by: github-actions[bot] --- src/diffusers/pipelines/stable_diffusion/safety_checker.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker.py b/src/diffusers/pipelines/stable_diffusion/safety_checker.py index 16aff10259..65daafe012 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker.py @@ -17,7 +17,7 @@ import torch import torch.nn as nn from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel -from ...utils import logging +from ...utils import is_transformers_version, logging logger = logging.get_logger(__name__) @@ -46,6 +46,9 @@ class StableDiffusionSafetyChecker(PreTrainedModel): self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False) self.special_care_embeds_weights = nn.Parameter(torch.ones(3), requires_grad=False) + # Model requires post_init after transformers v4.57.3 + if is_transformers_version(">", "4.57.3"): + self.post_init() @torch.no_grad() def forward(self, clip_input, images):