mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
fix get_execusion blocks with loopsequential
This commit is contained in:
@@ -1033,16 +1033,17 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
||||
def _traverse_trigger_blocks(self, trigger_inputs):
|
||||
# Convert trigger_inputs to a set for easier manipulation
|
||||
active_triggers = set(trigger_inputs)
|
||||
|
||||
def fn_recursive_traverse(block, block_name, active_triggers):
|
||||
result_blocks = OrderedDict()
|
||||
|
||||
# sequential or PipelineBlock
|
||||
# sequential(include loopsequential) or PipelineBlock
|
||||
if not hasattr(block, 'block_trigger_inputs'):
|
||||
if hasattr(block, 'blocks'):
|
||||
# sequential
|
||||
for block_name, block in block.blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(block, block_name, active_triggers)
|
||||
# sequential or LoopSequentialPipelineBlocks (keep traversing)
|
||||
for sub_block_name, sub_block in block.blocks.items():
|
||||
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
|
||||
blocks_to_update = fn_recursive_traverse(sub_block, sub_block_name, active_triggers)
|
||||
blocks_to_update = {f"{block_name}.{k}": v for k,v in blocks_to_update.items()}
|
||||
result_blocks.update(blocks_to_update)
|
||||
else:
|
||||
# PipelineBlock
|
||||
@@ -1069,13 +1070,14 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
||||
matching_trigger = None
|
||||
|
||||
if this_block is not None:
|
||||
# sequential/auto
|
||||
# sequential/auto (keep traversing)
|
||||
if hasattr(this_block, 'blocks'):
|
||||
result_blocks.update(fn_recursive_traverse(this_block, block_name, active_triggers))
|
||||
else:
|
||||
# PipelineBlock
|
||||
result_blocks[block_name] = this_block
|
||||
# Add this block's output names to active triggers if defined
|
||||
# YiYi TODO: do we need outputs here? can it just be intermediate_outputs? can we get rid of outputs attribute?
|
||||
if hasattr(this_block, 'outputs'):
|
||||
active_triggers.update(out.name for out in this_block.outputs)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user