diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 17b13db22a..532e134a61 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -996,14 +996,11 @@ def main(args): ) if unwrap_model(unet).dtype != torch.float32: - raise ValueError( - f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}" - ) + raise ValueError(f"Unet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}") if args.train_text_encoder and unwrap_model(text_encoder).dtype != torch.float32: raise ValueError( - f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." - f" {low_precision_error_string}" + f"Text encoder loaded as datatype {unwrap_model(text_encoder).dtype}." f" {low_precision_error_string}" ) # Enable TF32 for faster training on Ampere GPUs,