mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix custom diffusion tests (#4996)
This commit is contained in:
@@ -785,8 +785,8 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
|
||||
torch.manual_seed(0)
|
||||
new_model = self.model_class(**init_dict)
|
||||
new_model.to(torch_device)
|
||||
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
new_model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
new_sample = new_model(**inputs_dict).sample
|
||||
|
||||
Reference in New Issue
Block a user