Make mid block optional for flax UNet (#7083)
* make mid block optional for flax UNet
* make style
@@ -75,6 +75,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             The tuple of downsample blocks to use.
         up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
             The tuple of upsample blocks to use.
+        mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
             The tuple of output channels for each block.
         layers_per_block (`int`, *optional*, defaults to 2):
@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         "DownBlock2D",
     )
     up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
+    mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
     only_cross_attention: Union[bool, Tuple[bool]] = False
     block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
     layers_per_block: int = 2
@@ -252,16 +255,21 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         self.down_blocks = down_blocks

         # mid
-        self.mid_block = FlaxUNetMidBlock2DCrossAttn(
-            in_channels=block_out_channels[-1],
-            dropout=self.dropout,
-            num_attention_heads=num_attention_heads[-1],
-            transformer_layers_per_block=transformer_layers_per_block[-1],
-            use_linear_projection=self.use_linear_projection,
-            use_memory_efficient_attention=self.use_memory_efficient_attention,
-            split_head_dim=self.split_head_dim,
-            dtype=self.dtype,
-        )
+        if self.config.mid_block_type == "UNetMidBlock2DCrossAttn":
+            self.mid_block = FlaxUNetMidBlock2DCrossAttn(
+                in_channels=block_out_channels[-1],
+                dropout=self.dropout,
+                num_attention_heads=num_attention_heads[-1],
+                transformer_layers_per_block=transformer_layers_per_block[-1],
+                use_linear_projection=self.use_linear_projection,
+                use_memory_efficient_attention=self.use_memory_efficient_attention,
+                split_head_dim=self.split_head_dim,
+                dtype=self.dtype,
+            )
+        elif self.config.mid_block_type is None:
+            self.mid_block = None
+        else:
+            raise ValueError(f"Unexpected mid_block_type {self.config.mid_block_type}")

         # up
         up_blocks = []
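
The new `setup()` branch means a model configured with `mid_block_type=None` never creates the submodule, so its parameter tree has no `mid_block` entry. A minimal sketch, assuming a deliberately tiny, hypothetical config (only `mid_block_type=None` is the point here):

import jax
from diffusers import FlaxUNet2DConditionModel

# Hypothetical tiny config for illustration; real checkpoints use the defaults shown above.
unet = FlaxUNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(32, 64),
    layers_per_block=1,
    attention_head_dim=2,
    cross_attention_dim=64,
    mid_block_type=None,  # takes the new `elif` branch: self.mid_block stays None
)

params = unet.init_weights(jax.random.PRNGKey(0))
print("mid_block" in params)  # expected: False, since setup() never instantiated it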
@@ -412,7 +420,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             down_block_res_samples = new_down_block_res_samples

         # 4. mid
-        sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
+        if self.mid_block is not None:
+            sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)

         if mid_block_additional_residual is not None:
             sample += mid_block_additional_residual
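
Continuing the sketch above, the forward pass now simply skips the mid block when it is `None`; the `mid_block_additional_residual` addition that follows is unaffected, since it does not depend on the mid block existing. Shapes below match the hypothetical config, not a real checkpoint:

import jax.numpy as jnp

sample = jnp.zeros((1, 4, 32, 32), dtype=jnp.float32)               # (batch, channels, height, width)
timesteps = jnp.ones((1,), dtype=jnp.int32)
encoder_hidden_states = jnp.zeros((1, 77, 64), dtype=jnp.float32)   # cross-attention context, dim = cross_attention_dim

out = unet.apply({"params": params}, sample, timesteps, encoder_hidden_states)
print(out.sample.shape)  # expected: (1, 4, 32, 32), produced without any mid-block computation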