mirror of https://github.com/huggingface/diffusers.git
fix norm not training in train_control_lora_flux.py (#11832)
@@ -837,11 +837,6 @@ def main(args):
     assert torch.all(flux_transformer.x_embedder.weight[:, initial_input_channels:].data == 0)
     flux_transformer.register_to_config(in_channels=initial_input_channels * 2, out_channels=initial_input_channels)

-    if args.train_norm_layers:
-        for name, param in flux_transformer.named_parameters():
-            if any(k in name for k in NORM_LAYER_PREFIXES):
-                param.requires_grad = True
-
     if args.lora_layers is not None:
         if args.lora_layers != "all-linear":
             target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
@@ -879,6 +874,11 @@ def main(args):
     )
     flux_transformer.add_adapter(transformer_lora_config)

+    if args.train_norm_layers:
+        for name, param in flux_transformer.named_parameters():
+            if any(k in name for k in NORM_LAYER_PREFIXES):
+                param.requires_grad = True
+
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
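Why the block moved (not stated in the diff itself, but the property of PEFT adapter injection that this fix works around): injecting a LoRA adapter marks every non-adapter parameter as non-trainable, so a requires_grad = True flag set on the norm layers before add_adapter is silently reverted. Below is a minimal sketch of that ordering effect using a toy module and PEFT's inject_adapter_in_model (the low-level call that add_adapter relies on); TinyBlock and the reduced NORM_LAYER_PREFIXES tuple are illustrative stand-ins, not code from the training script.

import torch.nn as nn
from peft import LoraConfig, inject_adapter_in_model

NORM_LAYER_PREFIXES = ("norm",)  # stand-in for the script's constant

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(16, 16)  # will receive the LoRA adapter
        self.norm = nn.LayerNorm(16)   # the layer we want to keep trainable

model = TinyBlock()

# Old order: enable norm training first, then inject the adapter.
for name, param in model.named_parameters():
    if any(k in name for k in NORM_LAYER_PREFIXES):
        param.requires_grad = True

inject_adapter_in_model(LoraConfig(target_modules=["proj"]), model)
print(model.norm.weight.requires_grad)  # False: adapter injection froze the norm layer again

# New order (what the commit does): re-enable norm training *after* the adapter is added.
for name, param in model.named_parameters():
    if any(k in name for k in NORM_LAYER_PREFIXES):
        param.requires_grad = True
print(model.norm.weight.requires_grad)  # True: the norm layer now trains alongside the LoRA weights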