diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 8103c01643..5af95cba74 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -529,8 +529,6 @@ title: Kandinsky 2.2 - local: api/pipelines/kandinsky3 title: Kandinsky 3 - - local: api/pipelines/kandinsky5 - title: Kandinsky 5 - local: api/pipelines/kolors title: Kolors - local: api/pipelines/latent_consistency_models @@ -638,6 +636,8 @@ title: HunyuanVideo - local: api/pipelines/i2vgenxl title: I2VGen-XL + - local: api/pipelines/kandinsky5_video + title: Kandinsky 5.0 Video - local: api/pipelines/latte title: Latte - local: api/pipelines/ltx_video diff --git a/docs/source/en/api/pipelines/kandinsky5.md b/docs/source/en/api/pipelines/kandinsky5_video.md similarity index 90% rename from docs/source/en/api/pipelines/kandinsky5.md rename to docs/source/en/api/pipelines/kandinsky5_video.md index a98a0826b7..533db23e1c 100644 --- a/docs/source/en/api/pipelines/kandinsky5.md +++ b/docs/source/en/api/pipelines/kandinsky5_video.md @@ -7,9 +7,9 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o specific language governing permissions and limitations under the License. --> -# Kandinsky 5.0 +# Kandinsky 5.0 Video -Kandinsky 5.0 is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov +Kandinsky 5.0 Video is created by the Kandinsky team: Alexey Letunovskiy, Maria Kovaleva, Ivan Kirillov, Lev Novitskiy, Denis Koposov, Dmitrii Mikhailov, Anna Averchenkova, Andrey Shutkin, Julia Agafonova, Olga Kim, Anastasiia Kargapoltseva, Nikita Kiselev, Anna Dmitrienko, Anastasia Maltseva, Kirill Chernyshev, Ilia Vasiliev, Viacheslav Vasilev, Vladimir Polovnikov, Yury Kolabushin, Alexander Belykh, Mikhail Mamaev, Anastasia Aliaskina, Tatiana Nikulina, Polina Gavrilova, Vladimir Arkhipkin, Vladimir Korviakov, Nikolai Gerasimenko, Denis Parkhomenko, Denis Dimitrov Kandinsky 5.0 is a family of diffusion models for Video & Image generation. Kandinsky 5.0 T2V Lite is a lightweight video generation model (2B parameters) that ranks #1 among open-source models in its class. It outperforms larger models and offers the best understanding of Russian concepts in the open-source ecosystem. 
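For orientation, here is a minimal text-to-video sketch based on the snippets quoted in the renamed doc below. The model id, the `guidance_scale=1.0` requirement for distilled checkpoints, and the `export_to_video` call come from this diff; the generic `DiffusionPipeline` loader, the `torch.bfloat16` dtype, the prompt, and the `.frames[0]` accessor are assumptions rather than confirmed API details.

```python
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video

# Distilled checkpoint named in the docs below; loading through the generic
# DiffusionPipeline entry point and reading `.frames[0]` are assumptions.
model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers"
pipe = DiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16)
pipe = pipe.to("cuda")

# Distilled and nocfg checkpoints must run without CFG, i.e. guidance_scale=1.0.
output = pipe(prompt="A cat in a red hat rides a bicycle through autumn leaves", guidance_scale=1.0).frames[0]
export_to_video(output, "output.mp4", fps=24, quality=9)
```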
@@ -92,7 +92,7 @@ pipe = pipe.to("cuda") pipe.transformer.set_attention_backend( "flex" -) # <--- Sett attention bakend to Flex +) # <--- Set attention backend to Flex pipe.transformer.compile( mode="max-autotune-no-cudagraphs", dynamic=True @@ -115,7 +115,7 @@ export_to_video(output, "output.mp4", fps=24, quality=9) ``` ### Diffusion Distilled model -**⚠️ Warning!** all nocfg and diffusion distilled models should be infered wothout CFG (```guidance_scale=1.0```): +**⚠️ Warning!** all nocfg and diffusion distilled models should be inferred without CFG (```guidance_scale=1.0```): ```python model_id = "ai-forever/Kandinsky-5.0-T2V-Lite-distilled16steps-5s-Diffusers" diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md index 22f982509b..6f8276a632 100644 --- a/examples/unconditional_image_generation/README.md +++ b/examples/unconditional_image_generation/README.md @@ -104,6 +104,8 @@ To use your own dataset, there are 2 ways: - you can either provide your own folder as `--train_data_dir` - or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument. +If your dataset contains 16- or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time. + Below, we explain both in more detail. #### Provide the dataset as a folder diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 3ffeef1364..0cc96220b9 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -52,6 +52,24 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): return res.expand(broadcast_shape) +def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor: + """ + Ensure the tensor has exactly three channels (C, H, W) by repeating or truncating channels when needed. + """ + if tensor.ndim == 2: + tensor = tensor.unsqueeze(0) + channels = tensor.shape[0] + if channels == 3: + return tensor + if channels == 1: + return tensor.repeat(3, 1, 1) + if channels == 2: + return torch.cat([tensor, tensor[:1]], dim=0) + if channels > 3: + return tensor[:3] + raise ValueError(f"Unsupported number of channels: {channels}") + + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( @@ -260,6 +278,11 @@ def parse_args(): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) + parser.add_argument( + "--preserve_input_precision", + action="store_true", + help="Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.", + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -453,19 +476,41 @@ def main(args): # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder # Preprocessing the datasets and DataLoaders creation.
+ spatial_augmentations = [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + ] + augmentations = transforms.Compose( - [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), - transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + spatial_augmentations + + [ transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] ) + precision_augmentations = transforms.Compose( + [ + transforms.PILToTensor(), + transforms.Lambda(_ensure_three_channels), + transforms.ConvertImageDtype(torch.float32), + ] + + spatial_augmentations + + [transforms.Normalize([0.5], [0.5])] + ) + def transform_images(examples): - images = [augmentations(image.convert("RGB")) for image in examples["image"]] - return {"input": images} + processed = [] + for image in examples["image"]: + if not args.preserve_input_precision: + processed.append(augmentations(image.convert("RGB"))) + else: + precise_image = image + if precise_image.mode == "P": + precise_image = precise_image.convert("RGB") + processed.append(precision_augmentations(precise_image)) + return {"input": processed} logger.info(f"Dataset size: {len(dataset)}") diff --git a/src/diffusers/loaders/lora_conversion_utils.py b/src/diffusers/loaders/lora_conversion_utils.py index 099dbfc1d2..2807416f97 100644 --- a/src/diffusers/loaders/lora_conversion_utils.py +++ b/src/diffusers/loaders/lora_conversion_utils.py @@ -2213,6 +2213,10 @@ def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict): state_dict = {convert_key(k): v for k, v in state_dict.items()} + has_default = any("default." in k for k in state_dict) + if has_default: + state_dict = {k.replace("default.", ""): v for k, v in state_dict.items()} + converted_state_dict = {} all_keys = list(state_dict.keys()) down_key = ".lora_down.weight" diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index 2bb6c0ea02..25919a896a 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -4940,7 +4940,8 @@ class QwenImageLoraLoaderMixin(LoraBaseMixin): has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict) has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict) has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict) - if has_alphas_in_sd or has_lora_unet or has_diffusion_model: + has_default = any("default." 
in k for k in state_dict) + if has_alphas_in_sd or has_lora_unet or has_diffusion_model or has_default: state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict) out = (state_dict, metadata) if return_lora_metadata else state_dict diff --git a/src/diffusers/models/attention_dispatch.py b/src/diffusers/models/attention_dispatch.py index ab0d7102ee..c17a3d0ed6 100644 --- a/src/diffusers/models/attention_dispatch.py +++ b/src/diffusers/models/attention_dispatch.py @@ -649,6 +649,86 @@ def _( # ===== Helper functions to use attention backends with templated CP autograd functions ===== +def _native_attention_forward_op( + ctx: torch.autograd.function.FunctionCtx, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_mask: Optional[torch.Tensor] = None, + dropout_p: float = 0.0, + is_causal: bool = False, + scale: Optional[float] = None, + enable_gqa: bool = False, + return_lse: bool = False, + _save_ctx: bool = True, + _parallel_config: Optional["ParallelConfig"] = None, +): + # Native attention does not support return_lse + if return_lse: + raise ValueError("Native attention does not support return_lse=True") + + # used for backward pass + if _save_ctx: + ctx.save_for_backward(query, key, value) + ctx.attn_mask = attn_mask + ctx.dropout_p = dropout_p + ctx.is_causal = is_causal + ctx.scale = scale + ctx.enable_gqa = enable_gqa + + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + + return out + + +def _native_attention_backward_op( + ctx: torch.autograd.function.FunctionCtx, + grad_out: torch.Tensor, + *args, + **kwargs, +): + query, key, value = ctx.saved_tensors + + query.requires_grad_(True) + key.requires_grad_(True) + value.requires_grad_(True) + + # Recompute the forward pass in SDPA layout so autograd can recover the input gradients. + query_t, key_t, value_t = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query_t, + key=key_t, + value=value_t, + attn_mask=ctx.attn_mask, + dropout_p=ctx.dropout_p, + is_causal=ctx.is_causal, + scale=ctx.scale, + enable_gqa=ctx.enable_gqa, + ) + + grad_out_t = grad_out.permute(0, 2, 1, 3) + grad_query_t, grad_key_t, grad_value_t = torch.autograd.grad( + outputs=out, inputs=[query_t, key_t, value_t], grad_outputs=grad_out_t, retain_graph=False + ) + + grad_query = grad_query_t.permute(0, 2, 1, 3) + grad_key = grad_key_t.permute(0, 2, 1, 3) + grad_value = grad_value_t.permute(0, 2, 1, 3) + + return grad_query, grad_key, grad_value + + # https://github.com/pytorch/pytorch/blob/8904ba638726f8c9a5aff5977c4aa76c9d2edfa6/aten/src/ATen/native/native_functions.yaml#L14958 # forward declaration: # aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0., bool is_causal=False, bool return_debug_mask=False, *, float?
scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) @@ -1523,6 +1603,7 @@ def _native_flex_attention( @_AttentionBackendRegistry.register( AttentionBackendName.NATIVE, constraints=[_check_device, _check_shape], + supports_context_parallel=True, ) def _native_attention( query: torch.Tensor, @@ -1538,18 +1619,35 @@ def _native_attention( ) -> torch.Tensor: if return_lse: raise ValueError("Native attention backend does not support setting `return_lse=True`.") - query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) - out = torch.nn.functional.scaled_dot_product_attention( - query=query, - key=key, - value=value, - attn_mask=attn_mask, - dropout_p=dropout_p, - is_causal=is_causal, - scale=scale, - enable_gqa=enable_gqa, - ) - out = out.permute(0, 2, 1, 3) + if _parallel_config is None: + query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value)) + out = torch.nn.functional.scaled_dot_product_attention( + query=query, + key=key, + value=value, + attn_mask=attn_mask, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + enable_gqa=enable_gqa, + ) + out = out.permute(0, 2, 1, 3) + else: + out = _templated_context_parallel_attention( + query, + key, + value, + attn_mask, + dropout_p, + is_causal, + scale, + enable_gqa, + return_lse, + forward_op=_native_attention_forward_op, + backward_op=_native_attention_backward_op, + _parallel_config=_parallel_config, + ) + return out diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index b2ef5a29e0..1c0734cf35 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -358,6 +358,7 @@ def get_cached_module_file( proxies=proxies, local_files_only=local_files_only, local_dir=local_dir, + revision=revision, token=token, ) submodule = os.path.join("local", "--".join(pretrained_model_name_or_path.split("/")))
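Two quick sketches related to the changes above. First, the motivation for `--preserve_input_precision` in `train_unconditional.py`: an 8-bit RGB round trip collapses a 16-bit image to 256 intensity levels, while converting straight to `float32` keeps them all. The tensor below is only an illustrative stand-in for a decoded 16-bit grayscale image; as the README note says, the exact decoded dtype depends on the Pillow/torchvision versions in use.

```python
import torch

# Illustrative stand-in for a decoded 16-bit grayscale image (values 0..65535).
image_16bit = torch.arange(0, 2**16, dtype=torch.float32).reshape(1, 256, 256)

# Default path: round-trip through 8-bit before ToTensor() — only 256 levels survive.
as_uint8 = (image_16bit / 65535.0 * 255.0).round().to(torch.uint8)
print(as_uint8.unique().numel())  # 256

# --preserve_input_precision path: scale to [0, 1] in float32 without the 8-bit step.
as_float = image_16bit / 65535.0
print(as_float.unique().numel())  # 65536
```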
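Second, the recompute-style backward added for the native attention backend: the saved query/key/value are re-run through `scaled_dot_product_attention` under autograd, and `torch.autograd.grad` recovers the input gradients from the incoming `grad_out`. Below is a small self-contained check of that idea with toy shapes and no dropout; it mirrors the approach, not the exact `_native_attention_backward_op` signature.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Toy tensors in the (batch, seq, heads, dim) layout used by the dispatcher.
q, k, v = (torch.randn(1, 16, 2, 8, requires_grad=True) for _ in range(3))

def sdpa_bshd(q, k, v):
    # Permute to (batch, heads, seq, dim) for SDPA, then back.
    out = F.scaled_dot_product_attention(*(x.permute(0, 2, 1, 3) for x in (q, k, v)))
    return out.permute(0, 2, 1, 3)

# Reference gradients: differentiate through the forward pass directly.
out = sdpa_bshd(q, k, v)
grad_out = torch.randn_like(out)
ref = torch.autograd.grad(out, [q, k, v], grad_out)

# Recompute-style backward: detach the "saved" tensors, re-run the forward,
# then ask autograd for dQ/dK/dV given the incoming grad_out.
q2, k2, v2 = (x.detach().requires_grad_(True) for x in (q, k, v))
out2 = sdpa_bshd(q2, k2, v2)
rec = torch.autograd.grad(out2, [q2, k2, v2], grad_out)

assert all(torch.allclose(a, b, atol=1e-6) for a, b in zip(ref, rec))
```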