1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
DN6
2025-10-23 15:36:27 +05:30
parent 8f1b207ffd
commit 0229976ab5

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import numpy as np
@@ -250,3 +251,43 @@ class WanVACEPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
inputs = self.get_dummy_inputs(torch_device)
video = pipe(**inputs).frames[0]
assert video.shape == (17, 3, 16, 16)
def test_save_load_optional_components(self, expected_max_difference=1e-4):
optional_component = ["transformer", "image_encoder", "image_processor"]
components = self.get_dummy_components()
for component in optional_component:
components[component] = None
components["boundary_ratio"] = 1.0
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output = pipe(**inputs)[0]
with tempfile.TemporaryDirectory() as tmpdir:
pipe.save_pretrained(tmpdir, safe_serialization=False)
pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
for component in pipe_loaded.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe_loaded.to(torch_device)
pipe_loaded.set_progress_bar_config(disable=None)
for component in optional_component:
assert getattr(pipe_loaded, component) is None, f"`{component}` did not stay set to None after loading."
inputs = self.get_dummy_inputs(generator_device)
torch.manual_seed(0)
output_loaded = pipe_loaded(**inputs)[0]
max_diff = np.abs(output.detach().cpu().numpy() - output_loaded.detach().cpu().numpy()).max()
assert max_diff < expected_max_difference, "Outputs exceed expecpted maximum difference"