mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
make inputs truly immutable, remove the output logic in sequential pipeline, and update so that intermediates_outputs are only new variables
This commit is contained in:
@@ -17,6 +17,7 @@ import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional, Type
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
import torch
|
||||
@@ -109,7 +110,9 @@ class PipelineState:
|
||||
self.intermediate_kwargs[kwargs_type].append(key)
|
||||
|
||||
def get_input(self, key: str, default: Any = None) -> Any:
|
||||
return self.inputs.get(key, default)
|
||||
value = self.inputs.get(key, default)
|
||||
if value is not None:
|
||||
return deepcopy(value)
|
||||
|
||||
def get_inputs(self, keys: List[str], default: Any = None) -> Dict[str, Any]:
|
||||
return {key: self.inputs.get(key, default) for key in keys}
|
||||
@@ -483,6 +486,7 @@ class PipelineBlock(ModularPipelineMixin):
|
||||
)
|
||||
|
||||
|
||||
# YiYi TODO: input and inteermediate inputs with same name? should warn?
|
||||
def get_block_state(self, state: PipelineState) -> dict:
|
||||
"""Get all inputs and intermediates in one dictionary"""
|
||||
data = {}
|
||||
@@ -1032,14 +1036,21 @@ class SequentialPipelineBlocks(ModularPipelineMixin):
|
||||
|
||||
@property
|
||||
def intermediates_outputs(self) -> List[str]:
|
||||
named_outputs = [(name, block.intermediates_outputs) for name, block in self.blocks.items()]
|
||||
named_outputs = []
|
||||
for name, block in self.blocks.items():
|
||||
inp_names = set([inp.name for inp in block.intermediates_inputs])
|
||||
# so we only need to list new variables as intermediates_outputs, but if user wants to list these they modified it's still fine (a.k.a we don't enforce)
|
||||
# filter out them here so they do not end up as intermediates_outputs
|
||||
if name not in inp_names:
|
||||
named_outputs.append((name, block.intermediates_outputs))
|
||||
combined_outputs = combine_outputs(*named_outputs)
|
||||
return combined_outputs
|
||||
|
||||
# YiYi TODO: I think we can remove the outputs property
|
||||
@property
|
||||
def outputs(self) -> List[str]:
|
||||
return next(reversed(self.blocks.values())).intermediates_outputs
|
||||
|
||||
# return next(reversed(self.blocks.values())).intermediates_outputs
|
||||
return self.intermediates_outputs
|
||||
@torch.no_grad()
|
||||
def __call__(self, pipeline, state: PipelineState) -> PipelineState:
|
||||
for block_name, block in self.blocks.items():
|
||||
|
||||
Reference in New Issue
Block a user