diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index 93b90f50b9..55ea27d2f8 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -1541,6 +1541,11 @@ def main(args): ) pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + if args.enable_slicing: + pipe.vae.enable_slicing() + if args.enable_tiling: + pipe.vae.enable_tiling() + # Load LoRA weights lora_scaling = args.lora_alpha / args.rank pipe.load_lora_weights(args.output_dir, adapter_name="cogvideox-lora")