1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Apply style fixes

This commit is contained in:
github-actions[bot]
2025-04-08 04:46:02 +00:00
parent ed33194780
commit 65a3bf5f81
3 changed files with 33 additions and 10 deletions

View File

@@ -160,7 +160,12 @@ class AuraFlowSingleTransformerBlock(nn.Module):
self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
self.ff = AuraFlowFeedForward(dim, dim * 4)
def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_kwargs: Optional[Dict[str, Any]] = None):
def forward(
self,
hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
):
residual = hidden_states
attention_kwargs = attention_kwargs or {}
@@ -224,7 +229,11 @@ class AuraFlowJointTransformerBlock(nn.Module):
self.ff_context = AuraFlowFeedForward(dim, dim * 4)
def forward(
self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_kwargs: Optional[Dict[str, Any]] = None,
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
temb: torch.FloatTensor,
attention_kwargs: Optional[Dict[str, Any]] = None,
):
residual = hidden_states
residual_context = encoder_hidden_states
@@ -238,7 +247,9 @@ class AuraFlowJointTransformerBlock(nn.Module):
# Attention.
attn_output, context_attn_output = self.attn(
hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **attention_kwargs,
hidden_states=norm_hidden_states,
encoder_hidden_states=norm_encoder_hidden_states,
**attention_kwargs,
)
# Process attention outputs for the `hidden_states`.
@@ -492,7 +503,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
else:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, attention_kwargs=attention_kwargs,
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=temb,
attention_kwargs=attention_kwargs,
)
# Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
@@ -509,7 +523,9 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
)
else:
combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs)
combined_hidden_states = block(
hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
)
hidden_states = combined_hidden_states[:, encoder_seq_len:]

View File

@@ -564,9 +564,7 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
batch_size = prompt_embeds.shape[0]
device = self._execution_device
lora_scale = (
self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
)
lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`

View File

@@ -2190,7 +2190,14 @@ class PeftLoraLoaderMixinTests:
@property
def supports_text_encoder_lora(self):
return len({"text_encoder", "text_encoder_2", "text_encoder_3"}.intersection(self.pipeline_class._lora_loadable_modules)) != 0
return (
len(
{"text_encoder", "text_encoder_2", "text_encoder_3"}.intersection(
self.pipeline_class._lora_loadable_modules
)
)
!= 0
)
def test_layerwise_casting_inference_denoiser(self):
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
@@ -2249,7 +2256,9 @@ class PeftLoraLoaderMixinTests:
pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0]
pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
pipe_float8_e4m3_bf16 = initialize_pipeline(
storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]
@require_peft_version_greater("0.14.0")