1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
sayakpaul
2026-01-15 16:59:41 +05:30
parent c5e023fbe6
commit d0f279ce76
3 changed files with 10 additions and 1 deletions

View File

@@ -17,6 +17,9 @@ import logging
import os
import sys
import tempfile
import unittest
from diffusers.utils import is_transformers_version
sys.path.append("..")
@@ -30,6 +33,7 @@ stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
@unittest.skipIf(is_transformers_version(">=", "4.57.5"), "Size mismatch")
class CustomDiffusion(ExamplesTestsAccelerate):
def test_custom_diffusion(self):
with tempfile.TemporaryDirectory() as tmpdir:

View File

@@ -20,6 +20,8 @@ class MultilingualCLIP(PreTrainedModel):
self.LinearTransformation = torch.nn.Linear(
in_features=config.transformerDimensions, out_features=config.numDims
)
if hasattr(self, "post_init"):
self.post_init()
def forward(self, input_ids, attention_mask):
embs = self.transformer(input_ids=input_ids, attention_mask=attention_mask)[0]

View File

@@ -782,6 +782,9 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
self.prefix_encoder = PrefixEncoder(config)
self.dropout = torch.nn.Dropout(0.1)
if hasattr(self, "post_init"):
self.post_init()
def get_input_embeddings(self):
return self.embedding.word_embeddings
@@ -811,7 +814,7 @@ class ChatGLMModel(ChatGLMPreTrainedModel):
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
use_cache = use_cache if use_cache is not None else getattr(self.config, "use_cache", None)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
batch_size, seq_length = input_ids.shape