1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

make inputs truly immutable, remove the output logic in sequential pipeline, and update so that intermediates_outputs are only new variables

This commit is contained in:
yiyixuxu
2025-05-13 01:52:51 +02:00
parent 522e827625
commit 5cde77f915

View File

@@ -17,6 +17,7 @@ import warnings
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple, Union, Optional, Type
from copy import deepcopy
import torch
@@ -109,7 +110,9 @@ class PipelineState:
self.intermediate_kwargs[kwargs_type].append(key)
def get_input(self, key: str, default: Any = None) -> Any:
return self.inputs.get(key, default)
value = self.inputs.get(key, default)
if value is not None:
return deepcopy(value)
def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
return {key: self.inputs.get(key, default) for key in keys}
@@ -483,6 +486,7 @@ class PipelineBlock(ModularPipelineMixin):
)
# YiYi TODO: input and inteermediate inputs with same name? should warn?
def get_block_state(self, state: PipelineState) -> dict:
"""Get all inputs and intermediates in one dictionary"""
data = {}
@@ -1032,14 +1036,21 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
@property
def intermediates_outputs(self) -> List[str]:
named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()]
named_outputs = []
for name, block in self.blocks.items():
inp_names = set([inp.name for inp in block.intermediates_inputs])
# so we only need to list new variables as intermediates_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce)
# filter out them here so they do not end up as intermediates_outputs
if name not in inp_names:
named_outputs.append((name, block.intermediates_outputs))
combined_outputs = combine_outputs(*named_outputs)
return combined_outputs
# YiYi TODO: I think we can remove the outputs property
@property
def outputs(self) -> List[str]:
return next(reversed(self.blocks.values())).intermediates_outputs
# return next(reversed(self.blocks.values())).intermediates_outputs
return self.intermediates_outputs
@torch.no_grad()
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
for block_name, block in self.blocks.items():