Apply style fixes
@@ -160,7 +160,12 @@ class AuraFlowSingleTransformerBlock(nn.Module):
         self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
         self.ff = AuraFlowFeedForward(dim, dim * 4)

-    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_kwargs: Optional[Dict[str, Any]] = None):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
         residual = hidden_states
         attention_kwargs = attention_kwargs or {}
@@ -224,7 +229,11 @@ class AuraFlowJointTransformerBlock(nn.Module):
         self.ff_context = AuraFlowFeedForward(dim, dim * 4)

     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_kwargs: Optional[Dict[str, Any]] = None,
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         residual = hidden_states
         residual_context = encoder_hidden_states
@@ -238,7 +247,9 @@ class AuraFlowJointTransformerBlock(nn.Module):

         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **attention_kwargs,
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **attention_kwargs,
         )

         # Process attention outputs for the `hidden_states`.
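The two hunks above follow the usual optional-kwargs threading pattern: the block defaults `attention_kwargs` to an empty dict and then splats it into the attention call, so a LoRA `scale` (or any other processor argument) can flow through unchanged. A minimal, self-contained sketch of that pattern, using toy stand-in modules rather than the actual AuraFlow classes:

# Minimal sketch of the optional attention_kwargs pattern shown above.
# ToyAttention and ToyBlock are illustrative stand-ins, not diffusers classes.
from typing import Any, Dict, Optional

import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
        # `scale` is the kind of extra argument a LoRA-aware processor would accept.
        return hidden_states * scale


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = ToyAttention()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        # Default to an empty dict so `**attention_kwargs` is always valid.
        attention_kwargs = attention_kwargs or {}
        return self.attn(hidden_states=hidden_states, **attention_kwargs)


block = ToyBlock()
x = torch.randn(1, 4, 8)
print(block(x).shape)                                   # no extra kwargs
print(block(x, attention_kwargs={"scale": 0.5}).shape)  # forwarded to the attention call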
@@ -492,7 +503,10 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From

             else:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, attention_kwargs=attention_kwargs,
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    attention_kwargs=attention_kwargs,
                 )

         # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
@@ -509,7 +523,9 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
                 )

             else:
-                combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs)
+                combined_hidden_states = block(
+                    hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
+                )

         hidden_states = combined_hidden_states[:, encoder_seq_len:]
@@ -564,9 +564,7 @@ class AuraFlowPipeline(DiffusionPipeline, AuraFlowLoraLoaderMixin):
             batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        lora_scale = (
-            self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None
-        )
+        lora_scale = self.attention_kwargs.get("scale", None) if self.attention_kwargs is not None else None

         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
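Because `lora_scale` is read from `attention_kwargs`, callers can scale a loaded LoRA per invocation. A hedged usage sketch, assuming `AuraFlowPipeline.__call__` exposes the `attention_kwargs` argument that the hunk above reads from, and using a hypothetical LoRA repository id:

# Usage sketch; the LoRA repo id below is a placeholder.
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("your-namespace/your-auraflow-lora")  # hypothetical LoRA repo id

image = pipe(
    prompt="a watercolor painting of a lighthouse at dusk",
    attention_kwargs={"scale": 0.7},  # read back as `lora_scale` inside the pipeline
).images[0]
image.save("lighthouse.png")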
@@ -2190,7 +2190,14 @@ class PeftLoraLoaderMixinTests:

     @property
     def supports_text_encoder_lora(self):
-        return len({"text_encoder", "text_encoder_2", "text_encoder_3"}.intersection(self.pipeline_class._lora_loadable_modules)) != 0
+        return (
+            len(
+                {"text_encoder", "text_encoder_2", "text_encoder_3"}.intersection(
+                    self.pipeline_class._lora_loadable_modules
+                )
+            )
+            != 0
+        )

     def test_layerwise_casting_inference_denoiser(self):
         from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
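The reformatted return expression is equivalent to a truthiness check on a set intersection; a small sketch of the same logic outside the test mixin, for comparison only:

# Equivalent check written with the `&` set operator.
TEXT_ENCODER_MODULES = {"text_encoder", "text_encoder_2", "text_encoder_3"}


def supports_text_encoder_lora(lora_loadable_modules) -> bool:
    # A non-empty intersection means at least one text encoder can receive LoRA weights.
    return bool(TEXT_ENCODER_MODULES & set(lora_loadable_modules))


print(supports_text_encoder_lora(["transformer", "text_encoder"]))  # True
print(supports_text_encoder_lora(["transformer"]))                  # False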
@@ -2249,7 +2256,9 @@ class PeftLoraLoaderMixinTests:
         pipe_float8_e4m3_fp32 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.float32)
         pipe_float8_e4m3_fp32(**inputs, generator=torch.manual_seed(0))[0]

-        pipe_float8_e4m3_bf16 = initialize_pipeline(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
+        pipe_float8_e4m3_bf16 = initialize_pipeline(
+            storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
+        )
         pipe_float8_e4m3_bf16(**inputs, generator=torch.manual_seed(0))[0]

     @require_peft_version_greater("0.14.0")
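The float8-storage / bfloat16-compute combination exercised by this test corresponds to diffusers' layerwise casting feature on the denoiser. A hedged sketch of enabling it directly on a model, assuming the `enable_layerwise_casting` helper accepts the `storage_dtype`/`compute_dtype` arguments mirrored by the test above:

# Sketch: keep transformer weights in float8 storage, upcast to bfloat16 for compute.
# Assumes ModelMixin.enable_layerwise_casting(storage_dtype=..., compute_dtype=...).
import torch
from diffusers import AuraFlowPipeline

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.bfloat16).to("cuda")
pipe.transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)

image = pipe(prompt="a photo of a corgi wearing sunglasses").images[0]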