
Refactor FIBO classes to BriaFibo naming convention

- Updated class names from FIBO to BriaFibo for consistency across the module.
- Renamed FIBOEmbedND, FIBOTimesteps, TextProjection, and TimestepProjEmbeddings to follow the new convention (summarized in the mapping below).
- Updated all references inside BriaFiboTransformer2DModel accordingly.
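
For reference, the rename amounts to the following one-to-one mapping (names taken directly from the diff below). This is an illustrative summary, not code from the commit, and may help when updating any external code that touches these internal classes:

# Illustrative summary of the renames in this commit (not part of the commit itself).
RENAMES = {
    "FIBOEmbedND": "BriaFiboEmbedND",
    "TextProjection": "BriaFiboTextProjection",
    "FIBOTimesteps": "BriaFiboTimesteps",
    "TimestepProjEmbeddings": "BriaFiboTimestepProjEmbeddings",
}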
Author: galbria
Date: 2025-10-28 08:17:05 +00:00
Parent: d0a6cb6ed1
Commit: 69e49599a1


@@ -214,7 +214,7 @@ class BriaFiboAttention(torch.nn.Module, AttentionModuleMixin):
         return self.processor(self, hidden_states, encoder_hidden_states, attention_mask, image_rotary_emb, **kwargs)


-class FIBOEmbedND(torch.nn.Module):
+class BriaFiboEmbedND(torch.nn.Module):
     # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
     def __init__(self, theta: int, axes_dim: List[int]):
         super().__init__()
@@ -297,7 +297,7 @@ class BriaFiboSingleTransformerBlock(nn.Module):
         return hidden_states


-class TextProjection(nn.Module):
+class BriaFiboTextProjection(nn.Module):
     def __init__(self, in_features, hidden_size):
         super().__init__()
         self.linear = nn.Linear(in_features=in_features, out_features=hidden_size, bias=False)
@@ -393,7 +393,7 @@ class BriaFiboTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states


-class FIBOTimesteps(nn.Module):
+class BriaFiboTimesteps(nn.Module):
     def __init__(
         self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1, time_theta=10000
     ):
@@ -416,11 +416,11 @@ class FIBOTimesteps(nn.Module):
         return t_emb


-class TimestepProjEmbeddings(nn.Module):
+class BriaFiboTimestepProjEmbeddings(nn.Module):
     def __init__(self, embedding_dim, time_theta):
         super().__init__()

-        self.time_proj = FIBOTimesteps(
+        self.time_proj = BriaFiboTimesteps(
             num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, time_theta=time_theta
         )
         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
@@ -469,12 +469,12 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
         self.out_channels = in_channels
         self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

-        self.pos_embed = FIBOEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)
+        self.pos_embed = BriaFiboEmbedND(theta=rope_theta, axes_dim=axes_dims_rope)

-        self.time_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)
+        self.time_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim, time_theta=time_theta)

         if guidance_embeds:
-            self.guidance_embed = TimestepProjEmbeddings(embedding_dim=self.inner_dim)
+            self.guidance_embed = BriaFiboTimestepProjEmbeddings(embedding_dim=self.inner_dim)

         self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
         self.x_embedder = torch.nn.Linear(self.config.in_channels, self.inner_dim)
@@ -507,7 +507,7 @@ class BriaFiboTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
         self.gradient_checkpointing = False

         caption_projection = [
-            TextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2)
+            BriaFiboTextProjection(in_features=text_encoder_dim, hidden_size=self.inner_dim // 2)
             for i in range(self.config.num_layers + self.config.num_single_layers)
         ]
         self.caption_projection = nn.ModuleList(caption_projection)
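
To make the renamed timestep path easier to follow, here is a minimal, self-contained sketch of what a sinusoidal-projection-plus-MLP embedder like BriaFiboTimestepProjEmbeddings computes. The names TimestepProjSketch and sinusoidal_timestep_embedding are hypothetical stand-ins, not the diffusers implementation, and the MLP layout is assumed rather than taken from the source.

# Minimal sketch (assumptions noted above): sinusoidal timestep projection
# followed by a small MLP, mirroring the time_proj -> timestep_embedder wiring
# visible in the hunk around line 416. Not the diffusers implementation.
import math

import torch
import torch.nn as nn


def sinusoidal_timestep_embedding(timesteps: torch.Tensor, num_channels: int = 256, theta: float = 10000.0) -> torch.Tensor:
    # Standard sin/cos features of shape (batch, num_channels).
    half = num_channels // 2
    freqs = torch.exp(-math.log(theta) * torch.arange(half, dtype=torch.float32) / half)
    args = timesteps.float()[:, None] * freqs[None, :]
    return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


class TimestepProjSketch(nn.Module):
    # Hypothetical stand-in for the renamed BriaFiboTimestepProjEmbeddings:
    # project the scalar timestep, then embed it into the transformer width.
    def __init__(self, embedding_dim: int, num_channels: int = 256):
        super().__init__()
        self.embedder = nn.Sequential(
            nn.Linear(num_channels, embedding_dim),
            nn.SiLU(),
            nn.Linear(embedding_dim, embedding_dim),
        )

    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        return self.embedder(sinusoidal_timestep_embedding(timesteps))


# Usage: TimestepProjSketch(embedding_dim=3072)(torch.tensor([1, 500])) -> shape (2, 3072)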