mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -1238,37 +1238,6 @@ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
|
||||
return x
|
||||
|
||||
|
||||
class FluxPosEmbed(nn.Module):
|
||||
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
cos_out = []
|
||||
sin_out = []
|
||||
pos = ids.float()
|
||||
is_mps = ids.device.type == "mps"
|
||||
is_npu = ids.device.type == "npu"
|
||||
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
||||
for i in range(n_axes):
|
||||
cos, sin = get_1d_rotary_pos_embed(
|
||||
self.axes_dim[i],
|
||||
pos[:, i],
|
||||
theta=self.theta,
|
||||
repeat_interleave_real=True,
|
||||
use_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
cos_out.append(cos)
|
||||
sin_out.append(sin)
|
||||
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
||||
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -2619,3 +2588,13 @@ class MultiIPAdapterImageProjection(nn.Module):
|
||||
projected_image_embeds.append(image_embed)
|
||||
|
||||
return projected_image_embeds
|
||||
|
||||
|
||||
class FluxPosEmbed(nn.Module):
|
||||
def __new__(cls, *args, **kwargs):
|
||||
deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
|
||||
deprecate("FluxPosEmbed", "1.0.0", deprecation_message)
|
||||
|
||||
from .transformers.transformer_flux import FluxPosEmbed
|
||||
|
||||
return FluxPosEmbed(*args, **kwargs)
|
||||
|
||||
@@ -30,8 +30,8 @@ from ..cache_utils import CacheMixin
|
||||
from ..embeddings import (
|
||||
CombinedTimestepGuidanceTextProjEmbeddings,
|
||||
CombinedTimestepTextProjEmbeddings,
|
||||
FluxPosEmbed,
|
||||
apply_rotary_emb,
|
||||
get_1d_rotary_pos_embed,
|
||||
)
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
@@ -510,6 +510,37 @@ class FluxTransformerBlock(nn.Module):
|
||||
return encoder_hidden_states, hidden_states
|
||||
|
||||
|
||||
class FluxPosEmbed(nn.Module):
|
||||
# modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
|
||||
def __init__(self, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
cos_out = []
|
||||
sin_out = []
|
||||
pos = ids.float()
|
||||
is_mps = ids.device.type == "mps"
|
||||
is_npu = ids.device.type == "npu"
|
||||
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
||||
for i in range(n_axes):
|
||||
cos, sin = get_1d_rotary_pos_embed(
|
||||
self.axes_dim[i],
|
||||
pos[:, i],
|
||||
theta=self.theta,
|
||||
repeat_interleave_real=True,
|
||||
use_real=True,
|
||||
freqs_dtype=freqs_dtype,
|
||||
)
|
||||
cos_out.append(cos)
|
||||
sin_out.append(sin)
|
||||
freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
|
||||
freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
|
||||
return freqs_cos, freqs_sin
|
||||
|
||||
|
||||
class FluxTransformer2DModel(
|
||||
ModelMixin,
|
||||
ConfigMixin,
|
||||
|
||||
Reference in New Issue
Block a user