
style etc

yiyixuxu
2026-01-22 03:14:15 +01:00
parent ea63cccb8c
commit fb6ec06a39
5 changed files with 19 additions and 16 deletions


@@ -44,7 +44,7 @@ class TestFluxModularPipelineFast(ModularPipelineTesterMixin):
     # should choose from the dict returned by `get_dummy_inputs`
     text_encoder_block_params = frozenset(["prompt", "max_sequence_length"])
     decode_block_params = frozenset(["output_type"])
-    vae_encoder_block_params = None # None if vae_encoder is not supported
+    vae_encoder_block_params = None  # None if vae_encoder is not supported

     def get_dummy_inputs(self, seed=0):
         generator = self.get_generator(seed)
@@ -74,7 +74,12 @@ class TestFluxImg2ImgModularPipelineFast(ModularPipelineTesterMixin):
     batch_params = frozenset(["prompt", "image"])
     # should choose from the dict returned by `get_dummy_inputs`
-    text_encoder_block_params = frozenset(["prompt", "max_sequence_length", ])
+    text_encoder_block_params = frozenset(
+        [
+            "prompt",
+            "max_sequence_length",
+        ]
+    )
     decode_block_params = frozenset(["output_type"])
     vae_encoder_block_params = frozenset(["image", "height", "width"])


@@ -39,7 +39,7 @@ class TestFlux2ModularPipelineFast(ModularPipelineTesterMixin):
     # should choose from the dict returned by `get_dummy_inputs`
     text_encoder_block_params = frozenset(["prompt", "max_sequence_length", "text_encoder_out_layers"])
     decode_block_params = frozenset(["output_type"])
-    vae_encoder_block_params = None
+    vae_encoder_block_params = None

     def get_dummy_inputs(self, seed=0):
         generator = self.get_generator(seed)


@@ -42,7 +42,7 @@ class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, ModularGuider
     # should choose from the dict returned by `get_dummy_inputs`
     text_encoder_block_params = frozenset(["prompt", "negative_prompt", "max_sequence_length"])
     decode_block_params = frozenset(["output_type"])
-    vae_encoder_block_params = None # None if vae_encoder is not supported
+    vae_encoder_block_params = None  # None if vae_encoder is not supported

     def get_dummy_inputs(self):
         generator = self.get_generator()


@@ -296,7 +296,7 @@ class TestSDXLModularPipelineFast(
     # should choose from the dict returned by `get_dummy_inputs`
     text_encoder_block_params = frozenset(["prompt"])
     decode_block_params = frozenset(["output_type"])
-    vae_encoder_block_params = None # None if vae_encoder is not supported
+    vae_encoder_block_params = None  # None if vae_encoder is not supported

     def get_dummy_inputs(self, seed=0):
         generator = self.get_generator(seed)


@@ -103,7 +103,7 @@ class ModularPipelineTesterMixin:
             " if should be a subset of the parameters returned by `get_dummy_inputs`"
             "See existing pipeline tests for reference."
         )

     def decode_block_params(self) -> frozenset:
         raise NotImplementedError(
             "You need to set the attribute `decode_block_params` in the child test class. "
@@ -111,7 +111,7 @@ class ModularPipelineTesterMixin:
             " if should be a subset of the parameters returned by `get_dummy_inputs`"
             "See existing pipeline tests for reference."
         )

     def vae_encoder_block_params(self) -> frozenset:
         raise NotImplementedError(
             "You need to set the attribute `vae_encoder_block_params` in the child test class. "
@@ -164,10 +164,7 @@ class ModularPipelineTesterMixin:
except Exception as e:
assert False, f"Failed to load pipeline from default repo: {e}"
def test_modular_inference(self):
# run the pipeline to get the base output for comparison
pipe = self.get_pipeline()
pipe.to(torch_device, torch.float32)
@@ -185,7 +182,7 @@ class ModularPipelineTesterMixin:
         assert "vae_encoder" in blocks.sub_blocks, "`vae_encoder` block is not present in the pipeline"

         # manually set the components in the sub_pipe
-        # a hack to workaround the fact the default pipeline properties are often incorrect for testing cases, 
+        # a hack to workaround the fact the default pipeline properties are often incorrect for testing cases,
         # #e.g. vae_scale_factor is ususally not 8 because vae is configured to be smaller for testing
         def manually_set_all_components(pipe: ModularPipeline, sub_pipe: ModularPipeline):
             for n, comp in pipe.components.items():
@@ -201,7 +198,7 @@ class ModularPipelineTesterMixin:
         denoise_node.load_components(torch_dtype=torch.float32)
         denoise_node.to(torch_device)
         manually_set_all_components(pipe, denoise_node)

         decoder_node = blocks.sub_blocks["decode"].init_pipeline(self.pretrained_model_name_or_path)
         decoder_node.load_components(torch_dtype=torch.float32)
         decoder_node.to(torch_device)
@@ -214,7 +211,7 @@ class ModularPipelineTesterMixin:
             manually_set_all_components(pipe, vae_encoder_node)
         else:
             vae_encoder_node = None

         # prepare inputs for each node
         inputs = self.get_dummy_inputs()
@@ -243,9 +240,10 @@ class ModularPipelineTesterMixin:
         # denoder node input should be "latents" and output should be "images"
         modular_output = decoder_node(**decoder_inputs, latents=latents).images
-        assert modular_output.shape == standard_output.shape, f"Modular output should have same shape as standard output {standard_output.shape}, but got {modular_output.shape}"
+        assert modular_output.shape == standard_output.shape, (
+            f"Modular output should have same shape as standard output {standard_output.shape}, but got {modular_output.shape}"
+        )

     def test_inference_batch_consistent(self, batch_sizes=[2], batch_generator=True):
         pipe = self.get_pipeline().to(torch_device)