diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 3eeff41dd1..5dcb903db4 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -17,6 +17,7 @@ import warnings from collections import OrderedDict from dataclasses import dataclass, field from typing import Any, Dict, List, Tuple, Union, Optional, Type +from copy import deepcopy import torch @@ -109,7 +110,9 @@ class PipelineState: self.intermediate_kwargs[kwargs_type].append(key) def get_input(self, key: str, default: Any = None) -> Any: - return self.inputs.get(key, default) + value = self.inputs.get(key, default) + if value is not None: + return deepcopy(value) def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]: return {key: self.inputs.get(key, default) for key in keys} @@ -483,6 +486,7 @@ class PipelineBlock(ModularPipelineMixin): ) + # YiYi TODO: input and inteermediate inputs with same name? should warn? def get_block_state(self, state: PipelineState) -> dict: """Get all inputs and intermediates in one dictionary""" data = {} @@ -1032,14 +1036,21 @@ class SequentialPipelineBlocks(ModularPipelineMixin): @property def intermediates_outputs(self) -> List[str]: - named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()] + named_outputs = [] + for name, block in self.blocks.items(): + inp_names = set([inp.name for inp in block.intermediates_inputs]) + # so we only need to list new variables as intermediates_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce) + # filter out them here so they do not end up as intermediates_outputs + if name not in inp_names: + named_outputs.append((name, block.intermediates_outputs)) combined_outputs = combine_outputs(*named_outputs) return combined_outputs + # YiYi TODO: I think we can remove the outputs property @property def outputs(self) -> List[str]: - return next(reversed(self.blocks.values())).intermediates_outputs - + # return next(reversed(self.blocks.values())).intermediates_outputs + return self.intermediates_outputs @torch.no_grad() def __call__(self, pipeline, state: PipelineState) -> PipelineState: for block_name, block in self.blocks.items():