1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

allow input_fields as input & update message

This commit is contained in:
yiyixuxu
2025-05-08 11:25:31 +02:00
parent f552773572
commit 16b6583fa8
8 changed files with 51 additions and 23 deletions

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
@@ -73,14 +73,18 @@ class AdaptiveProjectedGuidance(BaseGuidance):
self.use_original_formulation = use_original_formulation
self.momentum_buffer = None
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self._step == 0:
if self.adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
import torch
@@ -120,11 +120,15 @@ class AutoGuidance(BaseGuidance):
registry = HookRegistry.check_if_exists_or_initialize(denoiser)
registry.remove_hook(name, recurse=True)
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
@@ -75,11 +75,15 @@ class ClassifierFreeGuidance(BaseGuidance):
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
@@ -73,11 +73,15 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance):
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches

View File

@@ -174,7 +174,7 @@ class BaseGuidance:
from ..pipelines.modular_pipeline import BlockState
if input_fields is None:
raise ValueError("Input fields have not been set. Please call `set_input_fields` before preparing inputs.")
raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.")
data_batch = {}
for key, value in input_fields.items():
try:
@@ -186,7 +186,7 @@ class BaseGuidance:
# We've already checked that value is a string or a tuple of strings with length 2
pass
except AttributeError:
raise ValueError(f"Expected `data` to have attribute(s) {value}, but it does not. Please check the input data.")
logger.warning(f"`data` does not have attribute(s) {value}, skipping.")
data_batch[cls._identifier_key] = identifier
return BlockState(**data_batch)

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
import torch
@@ -156,7 +156,11 @@ class SkipLayerGuidance(BaseGuidance):
for hook_name in self._skip_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -168,7 +172,7 @@ class SkipLayerGuidance(BaseGuidance):
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import List, Optional, Union, TYPE_CHECKING
from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple
import torch
@@ -149,7 +149,11 @@ class SmoothedEnergyGuidance(BaseGuidance):
for hook_name in self._seg_layer_hook_names:
registry.remove_hook(hook_name, recurse=True)
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
if self.num_conditions == 1:
tuple_indices = [0]
input_predictions = ["pred_cond"]
@@ -161,7 +165,7 @@ class SmoothedEnergyGuidance(BaseGuidance):
input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i])
data_batches.append(data_batch)
return data_batches

View File

@@ -13,7 +13,7 @@
# limitations under the License.
import math
from typing import Optional, List, TYPE_CHECKING
from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple
import torch
@@ -62,11 +62,15 @@ class TangentialClassifierFreeGuidance(BaseGuidance):
self.guidance_rescale = guidance_rescale
self.use_original_formulation = use_original_formulation
def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]:
if input_fields is None:
input_fields = self._input_fields
tuple_indices = [0] if self.num_conditions == 1 else [0, 1]
data_batches = []
for i in range(self.num_conditions):
data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i])
data_batches.append(data_batch)
return data_batches