Mirror of https://github.com/huggingface/diffusers.git
@@ -204,7 +204,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
         resnet_skip_time_act: bool = False,
-        resnet_out_scale_factor: int = 1.0,
+        resnet_out_scale_factor: float = 1.0,
         time_embedding_type: str = "positional",
         time_embedding_dim: Optional[int] = None,
         time_embedding_act_fn: Optional[str] = None,
@@ -217,7 +217,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
         cross_attention_norm: Optional[str] = None,
-        addition_embed_type_num_heads=64,
+        addition_embed_type_num_heads: int = 64,
     ):
         super().__init__()
 
@@ -485,9 +485,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         up_block_types: Tuple[str],
         only_cross_attention: Union[bool, Tuple[bool]],
         block_out_channels: Tuple[int],
-        layers_per_block: [int, Tuple[int]],
+        layers_per_block: Union[int, Tuple[int]],
         cross_attention_dim: Union[int, Tuple[int]],
-        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]],
+        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
         reverse_transformer_layers_per_block: bool,
         attention_head_dim: int,
         num_attention_heads: Optional[Union[int, Tuple[int]]],
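For context on the annotations above: `layers_per_block` and `transformer_layers_per_block` accept either a single value applied to every block or a per-block tuple (and, for the transformer layers, a tuple of tuples with one entry per layer). A minimal construction sketch; the concrete values are illustrative assumptions, not part of this diff:

    from diffusers import UNet2DConditionModel

    # Values below are illustrative only.
    unet = UNet2DConditionModel(
        block_out_channels=(320, 640, 1280, 1280),
        layers_per_block=2,               # a single int applied to every block, or e.g. (2, 2, 2, 2)
        transformer_layers_per_block=1,   # likewise an int, a tuple, or a tuple of tuples
        cross_attention_dim=768,
    )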
@@ -762,7 +762,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
 
         self.set_attn_processor(processor)
 
-    def set_attention_slice(self, slice_size):
+    def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
         r"""
         Enable sliced attention computation.
 
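A usage sketch for the annotated signature, assuming an already constructed `unet`; the exact memory/speed trade-off of each setting is described in the method's docstring:

    # "auto" and "max" are the string options accepted by slice_size.
    unet.set_attention_slice("auto")

    # An integer (or a per-layer list of integers) controls the slicing explicitly.
    unet.set_attention_slice(4)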
@@ -831,7 +831,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
-    def enable_freeu(self, s1, s2, b1, b2):
+    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
 
         The suffixes after the scaling factors represent the stage blocks where they are being applied.
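Only type hints change here; calling the method is unchanged. A sketch with factor values that are assumptions (model-specific values are suggested in the FreeU paper and the diffusers docs), where s1/s2 scale skip-connection features and b1/b2 scale backbone features in stages 1 and 2:

    # Illustrative factors, not prescribed by this diff.
    unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)

    # The effect can be reverted with the matching disable method.
    unet.disable_freeu()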
@@ -953,7 +953,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         return class_emb
 
     def get_aug_embed(
-        self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict
+        self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
     ) -> Optional[torch.Tensor]:
         aug_emb = None
         if self.config.addition_embed_type == "text":
@@ -1004,7 +1004,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
             aug_emb = self.add_embedding(image_embs, hint)
         return aug_emb
 
-    def process_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor, added_cond_kwargs) -> torch.Tensor:
+    def process_encoder_hidden_states(
+        self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+    ) -> torch.Tensor:
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
         elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
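`added_cond_kwargs` is the dictionary callers pass through the UNet's forward pass. A hedged sketch of what it typically contains for an SDXL-style model (addition_embed_type "text_time"); the tensor shapes and dimensions are assumptions for illustration:

    import torch

    # Shapes are illustrative; real tensors come from the text encoders and the pipeline.
    added_cond_kwargs = {
        "text_embeds": torch.randn(1, 1280),  # pooled text embedding
        "time_ids": torch.randn(1, 6),        # original size, crop coords, target size
    }
    noise_pred = unet(
        sample=torch.randn(1, 4, 128, 128),
        timestep=10,
        encoder_hidden_states=torch.randn(1, 77, 2048),
        added_cond_kwargs=added_cond_kwargs,
    ).sample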