mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
make style & quality
This commit is contained in:
@@ -374,7 +374,9 @@ def convert_ltx2_connectors(original_state_dict: Dict[str, Any], version: str) -
|
||||
return connectors
|
||||
|
||||
|
||||
def get_ltx2_video_vae_config(version: str, timestep_conditioning: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
def get_ltx2_video_vae_config(
|
||||
version: str, timestep_conditioning: bool = False
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
|
||||
if version == "test":
|
||||
config = {
|
||||
"model_id": "diffusers-internal-dev/dummy-ltx2",
|
||||
@@ -452,7 +454,9 @@ def get_ltx2_video_vae_config(version: str, timestep_conditioning: bool = False)
|
||||
return config, rename_dict, special_keys_remap
|
||||
|
||||
|
||||
def convert_ltx2_video_vae(original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool) -> Dict[str, Any]:
|
||||
def convert_ltx2_video_vae(
|
||||
original_state_dict: Dict[str, Any], version: str, timestep_conditioning: bool
|
||||
) -> Dict[str, Any]:
|
||||
config, rename_dict, special_keys_remap = get_ltx2_video_vae_config(version, timestep_conditioning)
|
||||
diffusers_config = config["diffusers_config"]
|
||||
|
||||
@@ -719,7 +723,9 @@ def get_args():
|
||||
help="Latent upsampler filename",
|
||||
)
|
||||
|
||||
parser.add_argument("--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model")
|
||||
parser.add_argument(
|
||||
"--timestep_conditioning", action="store_true", help="Whether to add timestep condition to the video VAE model"
|
||||
)
|
||||
parser.add_argument("--vae", action="store_true", help="Whether to convert the video VAE model")
|
||||
parser.add_argument("--audio_vae", action="store_true", help="Whether to convert the audio VAE model")
|
||||
parser.add_argument("--dit", action="store_true", help="Whether to convert the DiT model")
|
||||
@@ -789,7 +795,9 @@ def main(args):
|
||||
original_vae_ckpt = load_hub_or_local_checkpoint(filename=args.vae_filename)
|
||||
elif combined_ckpt is not None:
|
||||
original_vae_ckpt = get_model_state_dict_from_combined_ckpt(combined_ckpt, args.vae_prefix)
|
||||
vae = convert_ltx2_video_vae(original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning)
|
||||
vae = convert_ltx2_video_vae(
|
||||
original_vae_ckpt, version=args.version, timestep_conditioning=args.timestep_conditioning
|
||||
)
|
||||
if not args.full_pipeline and not args.upsample_pipeline:
|
||||
vae.to(vae_dtype).save_pretrained(os.path.join(args.output_path, "vae"))
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
DISTILLED_SIGMA_VALUES = [1.0, 0.99375, 0.9875, 0.98125, 0.975, 0.909375, 0.725, 0.421875, 0.0]
|
||||
|
||||
# Reduced schedule for super-resolution stage 2 (subset of distilled values)
|
||||
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0]
|
||||
STAGE_2_DISTILLED_SIGMA_VALUES = [0.909375, 0.725, 0.421875, 0.0]
|
||||
|
||||
Reference in New Issue
Block a user