diff --git a/src/diffusers/guiders/adaptive_projected_guidance.py b/src/diffusers/guiders/adaptive_projected_guidance.py index 7da1cc59a3..83e93c15ff 100644 --- a/src/diffusers/guiders/adaptive_projected_guidance.py +++ b/src/diffusers/guiders/adaptive_projected_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch @@ -73,14 +73,18 @@ class AdaptiveProjectedGuidance(BaseGuidance): self.use_original_formulation = use_original_formulation self.momentum_buffer = None - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + if self._step == 0: if self.adaptive_projected_guidance_momentum is not None: self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum) tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/auto_guidance.py b/src/diffusers/guiders/auto_guidance.py index bfffb9f39c..8bb6083781 100644 --- a/src/diffusers/guiders/auto_guidance.py +++ b/src/diffusers/guiders/auto_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Union, TYPE_CHECKING +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple import torch @@ -120,11 +120,15 @@ class AutoGuidance(BaseGuidance): registry = HookRegistry.check_if_exists_or_initialize(denoiser) registry.remove_hook(name, recurse=True) - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/classifier_free_guidance.py b/src/diffusers/guiders/classifier_free_guidance.py index 429f845041..429392e3f9 100644 --- a/src/diffusers/guiders/classifier_free_guidance.py +++ b/src/diffusers/guiders/classifier_free_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch @@ -75,11 +75,15 @@ class ClassifierFreeGuidance(BaseGuidance): self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/classifier_free_zero_star_guidance.py b/src/diffusers/guiders/classifier_free_zero_star_guidance.py index 4c9839ee78..220a95e54a 100644 --- a/src/diffusers/guiders/classifier_free_zero_star_guidance.py +++ b/src/diffusers/guiders/classifier_free_zero_star_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch @@ -73,11 +73,15 @@ class ClassifierFreeZeroStarGuidance(BaseGuidance): self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/guider_utils.py b/src/diffusers/guiders/guider_utils.py index 7d005442e8..18c85f5794 100644 --- a/src/diffusers/guiders/guider_utils.py +++ b/src/diffusers/guiders/guider_utils.py @@ -174,7 +174,7 @@ class BaseGuidance: from ..pipelines.modular_pipeline import BlockState if input_fields is None: - raise ValueError("Input fields have not been set. Please call `set_input_fields` before preparing inputs.") + raise ValueError("Input fields cannot be None. Please pass `input_fields` to `prepare_inputs` or call `set_input_fields` before preparing inputs.") data_batch = {} for key, value in input_fields.items(): try: @@ -186,7 +186,7 @@ class BaseGuidance: # We've already checked that value is a string or a tuple of strings with length 2 pass except AttributeError: - raise ValueError(f"Expected `data` to have attribute(s) {value}, but it does not. Please check the input data.") + logger.warning(f"`data` does not have attribute(s) {value}, skipping.") data_batch[cls._identifier_key] = identifier return BlockState(**data_batch) diff --git a/src/diffusers/guiders/skip_layer_guidance.py b/src/diffusers/guiders/skip_layer_guidance.py index bdd9e4af81..56dae19036 100644 --- a/src/diffusers/guiders/skip_layer_guidance.py +++ b/src/diffusers/guiders/skip_layer_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Union, TYPE_CHECKING +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple import torch @@ -156,7 +156,11 @@ class SkipLayerGuidance(BaseGuidance): for hook_name in self._skip_layer_hook_names: registry.remove_hook(hook_name, recurse=True) - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] @@ -168,7 +172,7 @@ class SkipLayerGuidance(BaseGuidance): input_predictions = ["pred_cond", "pred_uncond", "pred_cond_skip"] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/smoothed_energy_guidance.py b/src/diffusers/guiders/smoothed_energy_guidance.py index 1c7ee45dc3..c215cb0afd 100644 --- a/src/diffusers/guiders/smoothed_energy_guidance.py +++ b/src/diffusers/guiders/smoothed_energy_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import List, Optional, Union, TYPE_CHECKING +from typing import List, Optional, Union, TYPE_CHECKING, Dict, Tuple import torch @@ -149,7 +149,11 @@ class SmoothedEnergyGuidance(BaseGuidance): for hook_name in self._seg_layer_hook_names: registry.remove_hook(hook_name, recurse=True) - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + if self.num_conditions == 1: tuple_indices = [0] input_predictions = ["pred_cond"] @@ -161,7 +165,7 @@ class SmoothedEnergyGuidance(BaseGuidance): input_predictions = ["pred_cond", "pred_uncond", "pred_cond_seg"] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], input_predictions[i]) data_batches.append(data_batch) return data_batches diff --git a/src/diffusers/guiders/tangential_classifier_free_guidance.py b/src/diffusers/guiders/tangential_classifier_free_guidance.py index 631f9a5f33..9fa8f94541 100644 --- a/src/diffusers/guiders/tangential_classifier_free_guidance.py +++ b/src/diffusers/guiders/tangential_classifier_free_guidance.py @@ -13,7 +13,7 @@ # limitations under the License. import math -from typing import Optional, List, TYPE_CHECKING +from typing import Optional, List, TYPE_CHECKING, Dict, Union, Tuple import torch @@ -62,11 +62,15 @@ class TangentialClassifierFreeGuidance(BaseGuidance): self.guidance_rescale = guidance_rescale self.use_original_formulation = use_original_formulation - def prepare_inputs(self, data: "BlockState") -> List["BlockState"]: + def prepare_inputs(self, data: "BlockState", input_fields: Optional[Dict[str, Union[str, Tuple[str, str]]]] = None) -> List["BlockState"]: + + if input_fields is None: + input_fields = self._input_fields + tuple_indices = [0] if self.num_conditions == 1 else [0, 1] data_batches = [] for i in range(self.num_conditions): - data_batch = self._prepare_batch(self._input_fields, data, tuple_indices[i], self._input_predictions[i]) + data_batch = self._prepare_batch(input_fields, data, tuple_indices[i], self._input_predictions[i]) data_batches.append(data_batch) return data_batches