diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py
index 4f268bfa01..262e57a3a0 100644
--- a/src/diffusers/models/embeddings.py
+++ b/src/diffusers/models/embeddings.py
@@ -1238,37 +1238,6 @@ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
     return x
 
 
-class FluxPosEmbed(nn.Module):
-    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
-    def __init__(self, theta: int, axes_dim: List[int]):
-        super().__init__()
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        n_axes = ids.shape[-1]
-        cos_out = []
-        sin_out = []
-        pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        is_npu = ids.device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-        for i in range(n_axes):
-            cos, sin = get_1d_rotary_pos_embed(
-                self.axes_dim[i],
-                pos[:, i],
-                theta=self.theta,
-                repeat_interleave_real=True,
-                use_real=True,
-                freqs_dtype=freqs_dtype,
-            )
-            cos_out.append(cos)
-            sin_out.append(sin)
-        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
-        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
-        return freqs_cos, freqs_sin
-
-
 class TimestepEmbedding(nn.Module):
     def __init__(
         self,
@@ -2619,3 +2588,13 @@ class MultiIPAdapterImageProjection(nn.Module):
             projected_image_embeds.append(image_embed)
 
         return projected_image_embeds
+
+
+class FluxPosEmbed(nn.Module):
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
+        deprecate("FluxPosEmbed", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxPosEmbed
+
+        return FluxPosEmbed(*args, **kwargs)
diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py
index 8682bcbb60..8218b2ae6e 100644
--- a/src/diffusers/models/transformers/transformer_flux.py
+++ b/src/diffusers/models/transformers/transformer_flux.py
@@ -30,8 +30,8 @@ from ..cache_utils import CacheMixin
 from ..embeddings import (
     CombinedTimestepGuidanceTextProjEmbeddings,
     CombinedTimestepTextProjEmbeddings,
-    FluxPosEmbed,
     apply_rotary_emb,
+    get_1d_rotary_pos_embed,
 )
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -510,6 +510,37 @@ class FluxTransformerBlock(nn.Module):
         return encoder_hidden_states, hidden_states
 
 
+class FluxPosEmbed(nn.Module):
+    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
+    def __init__(self, theta: int, axes_dim: List[int]):
+        super().__init__()
+        self.theta = theta
+        self.axes_dim = axes_dim
+
+    def forward(self, ids: torch.Tensor) -> torch.Tensor:
+        n_axes = ids.shape[-1]
+        cos_out = []
+        sin_out = []
+        pos = ids.float()
+        is_mps = ids.device.type == "mps"
+        is_npu = ids.device.type == "npu"
+        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
+        for i in range(n_axes):
+            cos, sin = get_1d_rotary_pos_embed(
+                self.axes_dim[i],
+                pos[:, i],
+                theta=self.theta,
+                repeat_interleave_real=True,
+                use_real=True,
+                freqs_dtype=freqs_dtype,
+            )
+            cos_out.append(cos)
+            sin_out.append(sin)
+        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
+        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
+        return freqs_cos, freqs_sin
+
+
 class FluxTransformer2DModel(
     ModelMixin,
     ConfigMixin,
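
The embeddings.py hunk keeps the old import path alive with a `__new__`-based
shim. Below is a minimal self-contained sketch of that pattern; the class names
here are illustrative, and diffusers' own `deprecate` helper additionally gates
the warning on the library version, which this sketch does not.

import warnings


class RelocatedClass:
    # Stands in for the class at its new home (transformer_flux.FluxPosEmbed).
    def __init__(self, theta: int):
        self.theta = theta


class OldPathClass:
    # Stands in for the stub left behind at the old import path.
    def __new__(cls, *args, **kwargs):
        warnings.warn(
            "Importing `OldPathClass` from its old module is deprecated; "
            "import it from its new module instead.",
            FutureWarning,
        )
        # Returning an instance of a *different* class means this stub's
        # __init__ is never run; callers get the relocated class directly.
        return RelocatedClass(*args, **kwargs)


pe = OldPathClass(theta=10000)
assert isinstance(pe, RelocatedClass)  # old name still constructs the new class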
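
For reference, a standalone sketch of what `FluxPosEmbed.forward` computes:
each column of the position ids gets its own 1D rotary frequency table, and
the per-axis cos/sin tables are concatenated along the feature dimension. This
mirrors the `use_real=True, repeat_interleave_real=True` path of
`get_1d_rotary_pos_embed` in spirit only, and the `axes_dim = [16, 56, 56]`
split is assumed for illustration, not read from this diff.

import torch


def rope_1d(dim: int, pos: torch.Tensor, theta: float = 10000.0):
    # One rotary table: angles[i, j] = pos[i] * theta**(-2j / dim), with each
    # cos/sin value repeated twice so the table is (seq_len, dim) wide.
    freqs = 1.0 / theta ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim)
    angles = torch.outer(pos.to(torch.float64), freqs)  # (seq_len, dim // 2)
    return (
        angles.cos().repeat_interleave(2, dim=-1).float(),  # (seq_len, dim)
        angles.sin().repeat_interleave(2, dim=-1).float(),
    )


ids = torch.zeros(8, 3)    # (seq_len, n_axes) position ids
axes_dim = [16, 56, 56]    # assumed per-axis split; sums to the head dim
tables = [rope_1d(d, ids[:, i]) for i, d in enumerate(axes_dim)]
freqs_cos = torch.cat([c for c, _ in tables], dim=-1)
freqs_sin = torch.cat([s for _, s in tables], dim=-1)
print(freqs_cos.shape)  # torch.Size([8, 128]) == (seq_len, sum(axes_dim))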