1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Merge branch 'main' into requirements-custom-blocks

This commit is contained in:
Sayak Paul
2025-10-22 06:23:40 +05:30
committed by GitHub
34 changed files with 2805 additions and 290 deletions

View File

@@ -7,7 +7,7 @@ on:
env:
DIFFUSERS_IS_CI: yes
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8

View File

@@ -42,18 +42,39 @@ jobs:
CHANGED_FILES: ${{ steps.file_changes.outputs.all }}
run: |
echo "$CHANGED_FILES"
for FILE in $CHANGED_FILES; do
ALLOWED_IMAGES=(
diffusers-pytorch-cpu
diffusers-pytorch-cuda
diffusers-pytorch-xformers-cuda
diffusers-pytorch-minimum-cuda
diffusers-doc-builder
)
declare -A IMAGES_TO_BUILD=()
for FILE in $CHANGED_FILES; do
# skip anything that isn't still on disk
if [[ ! -f "$FILE" ]]; then
if [[ ! -e "$FILE" ]]; then
echo "Skipping removed file $FILE"
continue
fi
if [[ "$FILE" == docker/*Dockerfile ]]; then
DOCKER_PATH="${FILE%/Dockerfile}"
DOCKER_TAG=$(basename "$DOCKER_PATH")
echo "Building Docker image for $DOCKER_TAG"
docker build -t "$DOCKER_TAG" "$DOCKER_PATH"
fi
for IMAGE in "${ALLOWED_IMAGES[@]}"; do
if [[ "$FILE" == docker/${IMAGE}/* ]]; then
IMAGES_TO_BUILD["$IMAGE"]=1
fi
done
done
if [[ ${#IMAGES_TO_BUILD[@]} -eq 0 ]]; then
echo "No relevant Docker changes detected."
exit 0
fi
for IMAGE in "${!IMAGES_TO_BUILD[@]}"; do
DOCKER_PATH="docker/${IMAGE}"
echo "Building Docker image for $IMAGE"
docker build -t "$IMAGE" "$DOCKER_PATH"
done
if: steps.file_changes.outputs.all != ''

View File

@@ -7,7 +7,7 @@ on:
env:
DIFFUSERS_IS_CI: yes
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
PYTEST_TIMEOUT: 600

View File

@@ -26,7 +26,7 @@ concurrency:
env:
DIFFUSERS_IS_CI: yes
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60

View File

@@ -22,7 +22,7 @@ concurrency:
env:
DIFFUSERS_IS_CI: yes
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
OMP_NUM_THREADS: 4
MKL_NUM_THREADS: 4
PYTEST_TIMEOUT: 60

View File

@@ -24,7 +24,7 @@ env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
PYTEST_TIMEOUT: 600
PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run

View File

@@ -14,7 +14,7 @@ env:
DIFFUSERS_IS_CI: yes
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
PYTEST_TIMEOUT: 600
PIPELINE_USAGE_CUTOFF: 50000

View File

@@ -18,7 +18,7 @@ env:
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
PYTEST_TIMEOUT: 600
RUN_SLOW: no

View File

@@ -8,7 +8,7 @@ env:
HF_HOME: /mnt/cache
OMP_NUM_THREADS: 8
MKL_NUM_THREADS: 8
HF_HUB_ENABLE_HF_TRANSFER: 1
HF_XET_HIGH_PERFORMANCE: 1
PYTEST_TIMEOUT: 600
RUN_SLOW: no

View File

@@ -33,7 +33,7 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.
RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
hf_transfer \
hf_xet \
setuptools==69.5.1 \
bitsandbytes \
torchao \

View File

@@ -44,6 +44,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
scipy \
tensorboard \
transformers \
hf_transfer
hf_xet
CMD ["/bin/bash"]

View File

@@ -38,13 +38,12 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
datasets \
hf-doc-builder \
huggingface-hub \
hf_transfer \
hf_xet \
Jinja2 \
librosa \
numpy==1.26.4 \
scipy \
tensorboard \
transformers \
hf_transfer
transformers
CMD ["/bin/bash"]

View File

@@ -31,7 +31,7 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.
RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
hf_transfer
hf_xet
RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean

View File

@@ -44,6 +44,6 @@ RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
pytorch-lightning \
hf_transfer
hf_xet
CMD ["/bin/bash"]

View File

@@ -47,6 +47,6 @@ RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
pytorch-lightning \
hf_transfer
hf_xet
CMD ["/bin/bash"]

View File

@@ -44,7 +44,7 @@ RUN uv pip install --no-cache-dir \
accelerate \
numpy==1.26.4 \
pytorch-lightning \
hf_transfer \
hf_xet \
xformers
CMD ["/bin/bash"]

View File

@@ -1,5 +1,4 @@
- title: Get started
sections:
- sections:
- local: index
title: Diffusers
- local: installation
@@ -8,9 +7,8 @@
title: Quickstart
- local: stable_diffusion
title: Basic performance
- title: Pipelines
isExpanded: false
title: Get started
- isExpanded: false
sections:
- local: using-diffusers/loading
title: DiffusionPipeline
@@ -28,9 +26,8 @@
title: Model formats
- local: using-diffusers/push_to_hub
title: Sharing pipelines and models
- title: Adapters
isExpanded: false
title: Pipelines
- isExpanded: false
sections:
- local: tutorials/using_peft_for_inference
title: LoRA
@@ -44,9 +41,8 @@
title: DreamBooth
- local: using-diffusers/textual_inversion_inference
title: Textual inversion
- title: Inference
isExpanded: false
title: Adapters
- isExpanded: false
sections:
- local: using-diffusers/weighted_prompts
title: Prompting
@@ -56,9 +52,8 @@
title: Batch inference
- local: training/distributed_inference
title: Distributed inference
- title: Inference optimization
isExpanded: false
title: Inference
- isExpanded: false
sections:
- local: optimization/fp16
title: Accelerate inference
@@ -70,8 +65,7 @@
title: Reduce memory usage
- local: optimization/speed-memory-optims
title: Compiling and offloading quantized models
- title: Community optimizations
sections:
- sections:
- local: optimization/pruna
title: Pruna
- local: optimization/xformers
@@ -90,9 +84,9 @@
title: ParaAttention
- local: using-diffusers/image_quality
title: FreeU
- title: Hybrid Inference
isExpanded: false
title: Community optimizations
title: Inference optimization
- isExpanded: false
sections:
- local: hybrid_inference/overview
title: Overview
@@ -102,9 +96,8 @@
title: VAE Encode
- local: hybrid_inference/api_reference
title: API Reference
- title: Modular Diffusers
isExpanded: false
title: Hybrid Inference
- isExpanded: false
sections:
- local: modular_diffusers/overview
title: Overview
@@ -126,9 +119,8 @@
title: ComponentsManager
- local: modular_diffusers/guiders
title: Guiders
- title: Training
isExpanded: false
title: Modular Diffusers
- isExpanded: false
sections:
- local: training/overview
title: Overview
@@ -138,8 +130,7 @@
title: Adapt a model to a new task
- local: tutorials/basic_training
title: Train a diffusion model
- title: Models
sections:
- sections:
- local: training/unconditional_training
title: Unconditional image generation
- local: training/text2image
@@ -158,8 +149,8 @@
title: InstructPix2Pix
- local: training/cogvideox
title: CogVideoX
- title: Methods
sections:
title: Models
- sections:
- local: training/text_inversion
title: Textual Inversion
- local: training/dreambooth
@@ -172,9 +163,9 @@
title: Latent Consistency Distillation
- local: training/ddpo
title: Reinforcement learning training with DDPO
- title: Quantization
isExpanded: false
title: Methods
title: Training
- isExpanded: false
sections:
- local: quantization/overview
title: Getting started
@@ -188,9 +179,8 @@
title: quanto
- local: quantization/modelopt
title: NVIDIA ModelOpt
- title: Model accelerators and hardware
isExpanded: false
title: Quantization
- isExpanded: false
sections:
- local: optimization/onnx
title: ONNX
@@ -204,9 +194,8 @@
title: Intel Gaudi
- local: optimization/neuron
title: AWS Neuron
- title: Specific pipeline examples
isExpanded: false
title: Model accelerators and hardware
- isExpanded: false
sections:
- local: using-diffusers/consisid
title: ConsisID
@@ -232,12 +221,10 @@
title: Stable Video Diffusion
- local: using-diffusers/marigold_usage
title: Marigold Computer Vision
- title: Resources
isExpanded: false
title: Specific pipeline examples
- isExpanded: false
sections:
- title: Task recipes
sections:
- sections:
- local: using-diffusers/unconditional_image_generation
title: Unconditional image generation
- local: using-diffusers/conditional_image_generation
@@ -252,6 +239,7 @@
title: Video generation
- local: using-diffusers/depth2img
title: Depth-to-image
title: Task recipes
- local: using-diffusers/write_own_pipeline
title: Understanding pipelines, models and schedulers
- local: community_projects
@@ -266,12 +254,10 @@
title: Diffusers' Ethical Guidelines
- local: conceptual/evaluation
title: Evaluating Diffusion Models
- title: API
isExpanded: false
title: Resources
- isExpanded: false
sections:
- title: Main Classes
sections:
- sections:
- local: api/configuration
title: Configuration
- local: api/logging
@@ -282,8 +268,8 @@
title: Quantization
- local: api/parallel
title: Parallel inference
- title: Modular
sections:
title: Main Classes
- sections:
- local: api/modular_diffusers/pipeline
title: Pipeline
- local: api/modular_diffusers/pipeline_blocks
@@ -294,8 +280,8 @@
title: Components and configs
- local: api/modular_diffusers/guiders
title: Guiders
- title: Loaders
sections:
title: Modular
- sections:
- local: api/loaders/ip_adapter
title: IP-Adapter
- local: api/loaders/lora
@@ -310,14 +296,13 @@
title: SD3Transformer2D
- local: api/loaders/peft
title: PEFT
- title: Models
sections:
title: Loaders
- sections:
- local: api/models/overview
title: Overview
- local: api/models/auto_model
title: AutoModel
- title: ControlNets
sections:
- sections:
- local: api/models/controlnet
title: ControlNetModel
- local: api/models/controlnet_union
@@ -332,8 +317,8 @@
title: SD3ControlNetModel
- local: api/models/controlnet_sparsectrl
title: SparseControlNetModel
- title: Transformers
sections:
title: ControlNets
- sections:
- local: api/models/allegro_transformer3d
title: AllegroTransformer3DModel
- local: api/models/aura_flow_transformer2d
@@ -396,8 +381,8 @@
title: TransformerTemporalModel
- local: api/models/wan_transformer_3d
title: WanTransformer3DModel
- title: UNets
sections:
title: Transformers
- sections:
- local: api/models/stable_cascade_unet
title: StableCascadeUNet
- local: api/models/unet
@@ -412,8 +397,8 @@
title: UNetMotionModel
- local: api/models/uvit2d
title: UViT2DModel
- title: VAEs
sections:
title: UNets
- sections:
- local: api/models/asymmetricautoencoderkl
title: AsymmetricAutoencoderKL
- local: api/models/autoencoder_dc
@@ -446,210 +431,220 @@
title: Tiny AutoEncoder
- local: api/models/vq
title: VQModel
- title: Pipelines
sections:
title: VAEs
title: Models
- sections:
- local: api/pipelines/overview
title: Overview
- local: api/pipelines/allegro
title: Allegro
- local: api/pipelines/amused
title: aMUSEd
- local: api/pipelines/animatediff
title: AnimateDiff
- local: api/pipelines/attend_and_excite
title: Attend-and-Excite
- local: api/pipelines/audioldm
title: AudioLDM
- local: api/pipelines/audioldm2
title: AudioLDM 2
- local: api/pipelines/aura_flow
title: AuraFlow
- sections:
- local: api/pipelines/audioldm
title: AudioLDM
- local: api/pipelines/audioldm2
title: AudioLDM 2
- local: api/pipelines/dance_diffusion
title: Dance Diffusion
- local: api/pipelines/musicldm
title: MusicLDM
- local: api/pipelines/stable_audio
title: Stable Audio
title: Audio
- local: api/pipelines/auto_pipeline
title: AutoPipeline
- local: api/pipelines/blip_diffusion
title: BLIP-Diffusion
- local: api/pipelines/bria_3_2
title: Bria 3.2
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogvideox
title: CogVideoX
- local: api/pipelines/cogview3
title: CogView3
- local: api/pipelines/cogview4
title: CogView4
- local: api/pipelines/consisid
title: ConsisID
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
title: ControlNet
- local: api/pipelines/controlnet_flux
title: ControlNet with Flux.1
- local: api/pipelines/controlnet_hunyuandit
title: ControlNet with Hunyuan-DiT
- local: api/pipelines/controlnet_sd3
title: ControlNet with Stable Diffusion 3
- local: api/pipelines/controlnet_sdxl
title: ControlNet with Stable Diffusion XL
- local: api/pipelines/controlnet_sana
title: ControlNet-Sana
- local: api/pipelines/controlnetxs
title: ControlNet-XS
- local: api/pipelines/controlnetxs_sdxl
title: ControlNet-XS with Stable Diffusion XL
- local: api/pipelines/controlnet_union
title: ControlNetUnion
- local: api/pipelines/cosmos
title: Cosmos
- local: api/pipelines/dance_diffusion
title: Dance Diffusion
- local: api/pipelines/ddim
title: DDIM
- local: api/pipelines/ddpm
title: DDPM
- local: api/pipelines/deepfloyd_if
title: DeepFloyd IF
- local: api/pipelines/diffedit
title: DiffEdit
- local: api/pipelines/dit
title: DiT
- local: api/pipelines/easyanimate
title: EasyAnimate
- local: api/pipelines/flux
title: Flux
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/hidream
title: HiDream-I1
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
- local: api/pipelines/hunyuan_video
title: HunyuanVideo
- local: api/pipelines/i2vgenxl
title: I2VGen-XL
- local: api/pipelines/pix2pix
title: InstructPix2Pix
- local: api/pipelines/kandinsky
title: Kandinsky 2.1
- local: api/pipelines/kandinsky_v22
title: Kandinsky 2.2
- local: api/pipelines/kandinsky3
title: Kandinsky 3
- local: api/pipelines/kolors
title: Kolors
- local: api/pipelines/latent_consistency_models
title: Latent Consistency Models
- local: api/pipelines/latent_diffusion
title: Latent Diffusion
- local: api/pipelines/latte
title: Latte
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/ltx_video
title: LTXVideo
- local: api/pipelines/lumina2
title: Lumina 2.0
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold
title: Marigold
- local: api/pipelines/mochi
title: Mochi
- local: api/pipelines/panorama
title: MultiDiffusion
- local: api/pipelines/musicldm
title: MusicLDM
- local: api/pipelines/omnigen
title: OmniGen
- local: api/pipelines/pag
title: PAG
- local: api/pipelines/paint_by_example
title: Paint by Example
- local: api/pipelines/pia
title: Personalized Image Animator (PIA)
- local: api/pipelines/pixart
title: PixArt-α
- local: api/pipelines/pixart_sigma
title: PixArt-Σ
- local: api/pipelines/qwenimage
title: QwenImage
- local: api/pipelines/sana
title: Sana
- local: api/pipelines/sana_sprint
title: Sana Sprint
- local: api/pipelines/self_attention_guidance
title: Self-Attention Guidance
- local: api/pipelines/semantic_stable_diffusion
title: Semantic Guidance
- local: api/pipelines/shap_e
title: Shap-E
- local: api/pipelines/skyreels_v2
title: SkyReels-V2
- local: api/pipelines/stable_audio
title: Stable Audio
- local: api/pipelines/stable_cascade
title: Stable Cascade
- title: Stable Diffusion
sections:
- local: api/pipelines/stable_diffusion/overview
title: Overview
- local: api/pipelines/stable_diffusion/depth2img
title: Depth-to-image
- local: api/pipelines/stable_diffusion/gligen
title: GLIGEN (Grounded Language-to-Image Generation)
- local: api/pipelines/stable_diffusion/image_variation
title: Image variation
- local: api/pipelines/stable_diffusion/img2img
title: Image-to-image
- sections:
- local: api/pipelines/amused
title: aMUSEd
- local: api/pipelines/animatediff
title: AnimateDiff
- local: api/pipelines/attend_and_excite
title: Attend-and-Excite
- local: api/pipelines/aura_flow
title: AuraFlow
- local: api/pipelines/blip_diffusion
title: BLIP-Diffusion
- local: api/pipelines/bria_3_2
title: Bria 3.2
- local: api/pipelines/chroma
title: Chroma
- local: api/pipelines/cogview3
title: CogView3
- local: api/pipelines/cogview4
title: CogView4
- local: api/pipelines/consistency_models
title: Consistency Models
- local: api/pipelines/controlnet
title: ControlNet
- local: api/pipelines/controlnet_flux
title: ControlNet with Flux.1
- local: api/pipelines/controlnet_hunyuandit
title: ControlNet with Hunyuan-DiT
- local: api/pipelines/controlnet_sd3
title: ControlNet with Stable Diffusion 3
- local: api/pipelines/controlnet_sdxl
title: ControlNet with Stable Diffusion XL
- local: api/pipelines/controlnet_sana
title: ControlNet-Sana
- local: api/pipelines/controlnetxs
title: ControlNet-XS
- local: api/pipelines/controlnetxs_sdxl
title: ControlNet-XS with Stable Diffusion XL
- local: api/pipelines/controlnet_union
title: ControlNetUnion
- local: api/pipelines/cosmos
title: Cosmos
- local: api/pipelines/ddim
title: DDIM
- local: api/pipelines/ddpm
title: DDPM
- local: api/pipelines/deepfloyd_if
title: DeepFloyd IF
- local: api/pipelines/diffedit
title: DiffEdit
- local: api/pipelines/dit
title: DiT
- local: api/pipelines/easyanimate
title: EasyAnimate
- local: api/pipelines/flux
title: Flux
- local: api/pipelines/control_flux_inpaint
title: FluxControlInpaint
- local: api/pipelines/hidream
title: HiDream-I1
- local: api/pipelines/hunyuandit
title: Hunyuan-DiT
- local: api/pipelines/pix2pix
title: InstructPix2Pix
- local: api/pipelines/kandinsky
title: Kandinsky 2.1
- local: api/pipelines/kandinsky_v22
title: Kandinsky 2.2
- local: api/pipelines/kandinsky3
title: Kandinsky 3
- local: api/pipelines/kolors
title: Kolors
- local: api/pipelines/latent_consistency_models
title: Latent Consistency Models
- local: api/pipelines/latent_diffusion
title: Latent Diffusion
- local: api/pipelines/ledits_pp
title: LEDITS++
- local: api/pipelines/lumina2
title: Lumina 2.0
- local: api/pipelines/lumina
title: Lumina-T2X
- local: api/pipelines/marigold
title: Marigold
- local: api/pipelines/panorama
title: MultiDiffusion
- local: api/pipelines/omnigen
title: OmniGen
- local: api/pipelines/pag
title: PAG
- local: api/pipelines/paint_by_example
title: Paint by Example
- local: api/pipelines/pixart
title: PixArt-α
- local: api/pipelines/pixart_sigma
title: PixArt-Σ
- local: api/pipelines/prx
title: PRX
- local: api/pipelines/qwenimage
title: QwenImage
- local: api/pipelines/sana
title: Sana
- local: api/pipelines/sana_sprint
title: Sana Sprint
- local: api/pipelines/self_attention_guidance
title: Self-Attention Guidance
- local: api/pipelines/semantic_stable_diffusion
title: Semantic Guidance
- local: api/pipelines/shap_e
title: Shap-E
- local: api/pipelines/stable_cascade
title: Stable Cascade
- sections:
- local: api/pipelines/stable_diffusion/overview
title: Overview
- local: api/pipelines/stable_diffusion/depth2img
title: Depth-to-image
- local: api/pipelines/stable_diffusion/gligen
title: GLIGEN (Grounded Language-to-Image Generation)
- local: api/pipelines/stable_diffusion/image_variation
title: Image variation
- local: api/pipelines/stable_diffusion/img2img
title: Image-to-image
- local: api/pipelines/stable_diffusion/inpaint
title: Inpainting
- local: api/pipelines/stable_diffusion/k_diffusion
title: K-Diffusion
- local: api/pipelines/stable_diffusion/latent_upscale
title: Latent upscaler
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D
Upscaler
- local: api/pipelines/stable_diffusion/stable_diffusion_safe
title: Safe Stable Diffusion
- local: api/pipelines/stable_diffusion/sdxl_turbo
title: SDXL Turbo
- local: api/pipelines/stable_diffusion/stable_diffusion_2
title: Stable Diffusion 2
- local: api/pipelines/stable_diffusion/stable_diffusion_3
title: Stable Diffusion 3
- local: api/pipelines/stable_diffusion/stable_diffusion_xl
title: Stable Diffusion XL
- local: api/pipelines/stable_diffusion/upscale
title: Super-resolution
- local: api/pipelines/stable_diffusion/adapter
title: T2I-Adapter
- local: api/pipelines/stable_diffusion/text2img
title: Text-to-image
title: Stable Diffusion
- local: api/pipelines/stable_unclip
title: Stable unCLIP
- local: api/pipelines/unclip
title: unCLIP
- local: api/pipelines/unidiffuser
title: UniDiffuser
- local: api/pipelines/value_guided_sampling
title: Value-guided sampling
- local: api/pipelines/visualcloze
title: VisualCloze
- local: api/pipelines/wuerstchen
title: Wuerstchen
title: Image
- sections:
- local: api/pipelines/allegro
title: Allegro
- local: api/pipelines/cogvideox
title: CogVideoX
- local: api/pipelines/consisid
title: ConsisID
- local: api/pipelines/framepack
title: Framepack
- local: api/pipelines/hunyuan_video
title: HunyuanVideo
- local: api/pipelines/i2vgenxl
title: I2VGen-XL
- local: api/pipelines/latte
title: Latte
- local: api/pipelines/ltx_video
title: LTXVideo
- local: api/pipelines/mochi
title: Mochi
- local: api/pipelines/pia
title: Personalized Image Animator (PIA)
- local: api/pipelines/skyreels_v2
title: SkyReels-V2
- local: api/pipelines/stable_diffusion/svd
title: Image-to-video
- local: api/pipelines/stable_diffusion/inpaint
title: Inpainting
- local: api/pipelines/stable_diffusion/k_diffusion
title: K-Diffusion
- local: api/pipelines/stable_diffusion/latent_upscale
title: Latent upscaler
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
- local: api/pipelines/stable_diffusion/stable_diffusion_safe
title: Safe Stable Diffusion
- local: api/pipelines/stable_diffusion/sdxl_turbo
title: SDXL Turbo
- local: api/pipelines/stable_diffusion/stable_diffusion_2
title: Stable Diffusion 2
- local: api/pipelines/stable_diffusion/stable_diffusion_3
title: Stable Diffusion 3
- local: api/pipelines/stable_diffusion/stable_diffusion_xl
title: Stable Diffusion XL
- local: api/pipelines/stable_diffusion/upscale
title: Super-resolution
- local: api/pipelines/stable_diffusion/adapter
title: T2I-Adapter
- local: api/pipelines/stable_diffusion/text2img
title: Text-to-image
- local: api/pipelines/stable_unclip
title: Stable unCLIP
- local: api/pipelines/text_to_video
title: Text-to-video
- local: api/pipelines/text_to_video_zero
title: Text2Video-Zero
- local: api/pipelines/unclip
title: unCLIP
- local: api/pipelines/unidiffuser
title: UniDiffuser
- local: api/pipelines/value_guided_sampling
title: Value-guided sampling
- local: api/pipelines/visualcloze
title: VisualCloze
- local: api/pipelines/wan
title: Wan
- local: api/pipelines/wuerstchen
title: Wuerstchen
- title: Schedulers
sections:
title: Stable Video Diffusion
- local: api/pipelines/text_to_video
title: Text-to-video
- local: api/pipelines/text_to_video_zero
title: Text2Video-Zero
- local: api/pipelines/wan
title: Wan
title: Video
title: Pipelines
- sections:
- local: api/schedulers/overview
title: Overview
- local: api/schedulers/cm_stochastic_iterative
@@ -718,8 +713,8 @@
title: UniPCMultistepScheduler
- local: api/schedulers/vq_diffusion
title: VQDiffusionScheduler
- title: Internal classes
sections:
title: Schedulers
- sections:
- local: api/internal_classes_overview
title: Overview
- local: api/attnprocessor
@@ -736,3 +731,5 @@
title: VAE Image Processor
- local: api/video_processor
title: Video Processor
title: Internal classes
title: API

View File

@@ -0,0 +1,131 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->
# PRX
PRX generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing.
## Available models
PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.
| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) dataset with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |s
Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information.
## Loading the pipeline
Load the pipeline with [`~DiffusionPipeline.from_pretrained`].
```py
from diffusers.pipelines.prx import PRXPipeline
# Load pipeline - VAE and text encoder will be loaded from HuggingFace
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")
prompt = "A front-facing portrait of a lion the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("prx_output.png")
```
### Manual Component Loading
Load components individually to customize the pipeline for instance to use quantized models.
```py
import torch
from diffusers.pipelines.prx import PRXPipeline
from diffusers.models import AutoencoderKL, AutoencoderDC
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5GemmaModel, GemmaTokenizerFast
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as BitsAndBytesConfig
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
# Load transformer
transformer = PRXTransformer2DModel.from_pretrained(
"checkpoints/prx-512-t2i-sft",
subfolder="transformer",
quantization_config=quant_config,
torch_dtype=torch.bfloat16,
)
# Load scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
"checkpoints/prx-512-t2i-sft", subfolder="scheduler"
)
# Load T5Gemma text encoder
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2",
quantization_config=quant_config,
torch_dtype=torch.bfloat16)
text_encoder = t5gemma_model.encoder.to(dtype=torch.bfloat16)
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256
# Load VAE - choose either Flux VAE or DC-AE
# Flux VAE
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev",
subfolder="vae",
quantization_config=quant_config,
torch_dtype=torch.bfloat16)
pipe = PRXPipeline(
transformer=transformer,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae
)
pipe.to("cuda")
```
## Memory Optimization
For memory-constrained environments:
```py
import torch
from diffusers.pipelines.prx import PRXPipeline
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload() # Offload components to CPU when not in use
# Or use sequential CPU offload for even lower memory
pipe.enable_sequential_cpu_offload()
```
## PRXPipeline
[[autodoc]] PRXPipeline
- all
- __call__
## PRXPipelineOutput
[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput

View File

@@ -0,0 +1,345 @@
#!/usr/bin/env python3
"""
Script to convert PRX checkpoint from original codebase to diffusers format.
"""
import argparse
import json
import os
import sys
from dataclasses import asdict, dataclass
from typing import Dict, Tuple
import torch
from safetensors.torch import save_file
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.prx import PRXPipeline
DEFAULT_RESOLUTION = 512
@dataclass(frozen=True)
class PRXBase:
context_in_dim: int = 2304
hidden_size: int = 1792
mlp_ratio: float = 3.5
num_heads: int = 28
depth: int = 16
axes_dim: Tuple[int, int] = (32, 32)
theta: int = 10_000
time_factor: float = 1000.0
time_max_period: int = 10_000
@dataclass(frozen=True)
class PRXFlux(PRXBase):
in_channels: int = 16
patch_size: int = 2
@dataclass(frozen=True)
class PRXDCAE(PRXBase):
in_channels: int = 32
patch_size: int = 1
def build_config(vae_type: str) -> Tuple[dict, int]:
if vae_type == "flux":
cfg = PRXFlux()
elif vae_type == "dc-ae":
cfg = PRXDCAE()
else:
raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
config_dict = asdict(cfg)
config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index]
return config_dict
def create_parameter_mapping(depth: int) -> dict:
"""Create mapping from old parameter names to new diffusers names."""
# Key mappings for structural changes
mapping = {}
# Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention)
for i in range(depth):
# QKV projections moved to attention module
mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"
# QK norm moved to attention module and renamed to match Attention's qk_norm structure
mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"
# K norm for text tokens moved to attention module
mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"
# Attention output projection
mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight"
return mapping
def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]:
"""Convert old checkpoint parameters to new diffusers format."""
print("Converting checkpoint parameters...")
mapping = create_parameter_mapping(depth)
converted_state_dict = {}
for key, value in old_state_dict.items():
new_key = key
# Apply specific mappings if needed
if key in mapping:
new_key = mapping[key]
print(f" Mapped: {key} -> {new_key}")
converted_state_dict[new_key] = value
print(f"✓ Converted {len(converted_state_dict)} parameters")
return converted_state_dict
def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
"""Create and load PRXTransformer2DModel from old checkpoint."""
print(f"Loading checkpoint from: {checkpoint_path}")
# Load old checkpoint
if not os.path.exists(checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
old_checkpoint = torch.load(checkpoint_path, map_location="cpu")
# Handle different checkpoint formats
if isinstance(old_checkpoint, dict):
if "model" in old_checkpoint:
state_dict = old_checkpoint["model"]
elif "state_dict" in old_checkpoint:
state_dict = old_checkpoint["state_dict"]
else:
state_dict = old_checkpoint
else:
state_dict = old_checkpoint
print(f"✓ Loaded checkpoint with {len(state_dict)} parameters")
# Convert parameter names if needed
model_depth = int(config.get("depth", 16))
converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)
# Create transformer with config
print("Creating PRXTransformer2DModel...")
transformer = PRXTransformer2DModel(**config)
# Load state dict
print("Loading converted parameters...")
missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)
if missing_keys:
print(f"⚠ Missing keys: {missing_keys}")
if unexpected_keys:
print(f"⚠ Unexpected keys: {unexpected_keys}")
if not missing_keys and not unexpected_keys:
print("✓ All parameters loaded successfully!")
return transformer
def create_scheduler_config(output_path: str, shift: float):
"""Create FlowMatchEulerDiscreteScheduler config."""
scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift}
scheduler_path = os.path.join(output_path, "scheduler")
os.makedirs(scheduler_path, exist_ok=True)
with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f:
json.dump(scheduler_config, f, indent=2)
print("✓ Created scheduler config")
def download_and_save_vae(vae_type: str, output_path: str):
"""Download and save VAE to local directory."""
from diffusers import AutoencoderDC, AutoencoderKL
vae_path = os.path.join(output_path, "vae")
os.makedirs(vae_path, exist_ok=True)
if vae_type == "flux":
print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...")
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
else: # dc-ae
print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...")
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
vae.save_pretrained(vae_path)
print(f"✓ Saved VAE to {vae_path}")
def download_and_save_text_encoder(output_path: str):
"""Download and save T5Gemma text encoder and tokenizer."""
from transformers import GemmaTokenizerFast
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel
text_encoder_path = os.path.join(output_path, "text_encoder")
tokenizer_path = os.path.join(output_path, "tokenizer")
os.makedirs(text_encoder_path, exist_ok=True)
os.makedirs(tokenizer_path, exist_ok=True)
print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...")
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")
# Extract and save only the encoder
t5gemma_encoder = t5gemma_model.encoder
t5gemma_encoder.save_pretrained(text_encoder_path)
print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}")
print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...")
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256
tokenizer.save_pretrained(tokenizer_path)
print(f"✓ Saved tokenizer to {tokenizer_path}")
def create_model_index(vae_type: str, default_image_size: int, output_path: str):
"""Create model_index.json for the pipeline."""
if vae_type == "flux":
vae_class = "AutoencoderKL"
else: # dc-ae
vae_class = "AutoencoderDC"
model_index = {
"_class_name": "PRXPipeline",
"_diffusers_version": "0.31.0.dev0",
"_name_or_path": os.path.basename(output_path),
"default_sample_size": default_image_size,
"scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
"text_encoder": ["prx", "T5GemmaEncoder"],
"tokenizer": ["transformers", "GemmaTokenizerFast"],
"transformer": ["diffusers", "PRXTransformer2DModel"],
"vae": ["diffusers", vae_class],
}
model_index_path = os.path.join(output_path, "model_index.json")
with open(model_index_path, "w") as f:
json.dump(model_index, f, indent=2)
def main(args):
# Validate inputs
if not os.path.exists(args.checkpoint_path):
raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")
config = build_config(args.vae_type)
# Create output directory
os.makedirs(args.output_path, exist_ok=True)
print(f"✓ Output directory: {args.output_path}")
# Create transformer from checkpoint
transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)
# Save transformer
transformer_path = os.path.join(args.output_path, "transformer")
os.makedirs(transformer_path, exist_ok=True)
# Save config
with open(os.path.join(transformer_path, "config.json"), "w") as f:
json.dump(config, f, indent=2)
# Save model weights as safetensors
state_dict = transformer.state_dict()
save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
print(f"✓ Saved transformer to {transformer_path}")
# Create scheduler config
create_scheduler_config(args.output_path, args.shift)
download_and_save_vae(args.vae_type, args.output_path)
download_and_save_text_encoder(args.output_path)
# Create model_index.json
create_model_index(args.vae_type, args.resolution, args.output_path)
# Verify the pipeline can be loaded
try:
pipeline = PRXPipeline.from_pretrained(args.output_path)
print("Pipeline loaded successfully!")
print(f"Transformer: {type(pipeline.transformer).__name__}")
print(f"VAE: {type(pipeline.vae).__name__}")
print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
print(f"Scheduler: {type(pipeline.scheduler).__name__}")
# Display model info
num_params = sum(p.numel() for p in pipeline.transformer.parameters())
print(f"✓ Transformer parameters: {num_params:,}")
except Exception as e:
print(f"Pipeline verification failed: {e}")
return False
print("Conversion completed successfully!")
print(f"Converted pipeline saved to: {args.output_path}")
print(f"VAE type: {args.vae_type}")
return True
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")
parser.add_argument(
"--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file )"
)
parser.add_argument(
"--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline"
)
parser.add_argument(
"--vae_type",
type=str,
choices=["flux", "dc-ae"],
required=True,
help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)",
)
parser.add_argument(
"--resolution",
type=int,
choices=[256, 512, 1024],
default=DEFAULT_RESOLUTION,
help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.",
)
parser.add_argument(
"--shift",
type=float,
default=3.0,
help="Shift for the scheduler",
)
args = parser.parse_args()
try:
success = main(args)
if not success:
sys.exit(1)
except Exception as e:
print(f"Conversion failed: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -234,6 +234,7 @@ else:
"ParallelConfig",
"PixArtTransformer2DModel",
"PriorTransformer",
"PRXTransformer2DModel",
"QwenImageControlNetModel",
"QwenImageMultiControlNetModel",
"QwenImageTransformer2DModel",
@@ -519,6 +520,7 @@ else:
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
"PixArtSigmaPipeline",
"PRXPipeline",
"QwenImageControlNetInpaintPipeline",
"QwenImageControlNetPipeline",
"QwenImageEditInpaintPipeline",
@@ -928,6 +930,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ParallelConfig,
PixArtTransformer2DModel,
PriorTransformer,
PRXTransformer2DModel,
QwenImageControlNetModel,
QwenImageMultiControlNetModel,
QwenImageTransformer2DModel,
@@ -1183,6 +1186,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
PixArtSigmaPipeline,
PRXPipeline,
QwenImageControlNetInpaintPipeline,
QwenImageControlNetPipeline,
QwenImageEditInpaintPipeline,

View File

@@ -293,7 +293,7 @@ class PeftAdapterMixin:
# For hotswapping, we need the adapter name to be present in the state dict keys
new_sd = {}
for k, v in sd.items():
if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
if k.endswith("lora_A.weight") or k.endswith("lora_B.weight"):
k = k[: -len(".weight")] + f".{adapter_name}.weight"
elif k.endswith("lora_B.bias"): # lora_bias=True option
k = k[: -len(".bias")] + f".{adapter_name}.bias"

View File

@@ -96,6 +96,7 @@ if is_torch_available():
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
@@ -192,6 +193,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
OmniGenTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
PRXTransformer2DModel,
QwenImageTransformer2DModel,
SanaTransformer2DModel,
SD3Transformer2DModel,

View File

@@ -32,6 +32,7 @@ if is_torch_available():
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_prx import PRXTransformer2DModel
from .transformer_qwenimage import QwenImageTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel

View File

@@ -0,0 +1,770 @@
# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from torch.nn.functional import fold, unfold
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ..attention import AttentionMixin, AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
from ..embeddings import get_timestep_embedding
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import RMSNorm
logger = logging.get_logger(__name__)
def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> torch.Tensor:
r"""
Generates 2D patch coordinate indices for a batch of images.
Args:
batch_size (`int`):
Number of images in the batch.
height (`int`):
Height of the input images (in pixels).
width (`int`):
Width of the input images (in pixels).
patch_size (`int`):
Size of the square patches that the image is divided into.
device (`torch.device`):
The device on which to create the tensor.
Returns:
`torch.Tensor`:
Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col) coordinates of each patch in the
image grid.
"""
img_ids = torch.zeros(height // patch_size, width // patch_size, 2, device=device)
img_ids[..., 0] = torch.arange(height // patch_size, device=device)[:, None]
img_ids[..., 1] = torch.arange(width // patch_size, device=device)[None, :]
return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1)
def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
r"""
Applies rotary positional embeddings (RoPE) to a query tensor.
Args:
xq (`torch.Tensor`):
Input tensor of shape `(..., dim)` representing the queries.
freqs_cis (`torch.Tensor`):
Precomputed rotary frequency components of shape `(..., dim/2, 2)` containing cosine and sine pairs.
Returns:
`torch.Tensor`:
Tensor of the same shape as `xq` with rotary embeddings applied.
"""
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
# Ensure freqs_cis is on the same device as queries to avoid device mismatches with offloading
freqs_cis = freqs_cis.to(device=xq.device, dtype=xq_.dtype)
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
return xq_out.reshape(*xq.shape).type_as(xq)
class PRXAttnProcessor2_0:
r"""
Processor for implementing PRX-style attention with multi-source tokens and RoPE. Supports multiple attention
backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
"""
_attention_backend = None
_parallel_config = None
def __init__(self):
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
raise ImportError("PRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn: "PRXAttention",
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
"""
Apply PRX attention using PRXAttention module.
Args:
attn: PRXAttention module containing projection layers
hidden_states: Image tokens [B, L_img, D]
encoder_hidden_states: Text tokens [B, L_txt, D]
attention_mask: Boolean mask for text tokens [B, L_txt]
image_rotary_emb: Rotary positional embeddings [B, 1, L_img, head_dim//2, 2, 2]
"""
if encoder_hidden_states is None:
raise ValueError("PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
# Project image tokens to Q, K, V
img_qkv = attn.img_qkv_proj(hidden_states)
B, L_img, _ = img_qkv.shape
img_qkv = img_qkv.reshape(B, L_img, 3, attn.heads, attn.head_dim)
img_qkv = img_qkv.permute(2, 0, 3, 1, 4) # [3, B, H, L_img, D]
img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2]
# Apply QK normalization to image tokens
img_q = attn.norm_q(img_q)
img_k = attn.norm_k(img_k)
# Project text tokens to K, V
txt_kv = attn.txt_kv_proj(encoder_hidden_states)
B, L_txt, _ = txt_kv.shape
txt_kv = txt_kv.reshape(B, L_txt, 2, attn.heads, attn.head_dim)
txt_kv = txt_kv.permute(2, 0, 3, 1, 4) # [2, B, H, L_txt, D]
txt_k, txt_v = txt_kv[0], txt_kv[1]
# Apply K normalization to text tokens
txt_k = attn.norm_added_k(txt_k)
# Apply RoPE to image queries and keys
if image_rotary_emb is not None:
img_q = apply_rope(img_q, image_rotary_emb)
img_k = apply_rope(img_k, image_rotary_emb)
# Concatenate text and image keys/values
k = torch.cat((txt_k, img_k), dim=2) # [B, H, L_txt + L_img, D]
v = torch.cat((txt_v, img_v), dim=2) # [B, H, L_txt + L_img, D]
# Build attention mask if provided
attn_mask_tensor = None
if attention_mask is not None:
bs, _, l_img, _ = img_q.shape
l_txt = txt_k.shape[2]
if attention_mask.dim() != 2:
raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
if attention_mask.shape[-1] != l_txt:
raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}")
device = img_q.device
ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device)
attention_mask = attention_mask.to(device=device, dtype=torch.bool)
joint_mask = torch.cat([attention_mask, ones_img], dim=-1)
attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1)
# Apply attention using dispatch_attention_fn for backend support
# Reshape to match dispatch_attention_fn expectations: [B, L, H, D]
query = img_q.transpose(1, 2) # [B, L_img, H, D]
key = k.transpose(1, 2) # [B, L_txt + L_img, H, D]
value = v.transpose(1, 2) # [B, L_txt + L_img, H, D]
attn_output = dispatch_attention_fn(
query,
key,
value,
attn_mask=attn_mask_tensor,
backend=self._attention_backend,
parallel_config=self._parallel_config,
)
# Reshape from [B, L_img, H, D] to [B, L_img, H*D]
batch_size, seq_len, num_heads, head_dim = attn_output.shape
attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim)
# Apply output projection
attn_output = attn.to_out[0](attn_output)
if len(attn.to_out) > 1:
attn_output = attn.to_out[1](attn_output) # dropout if present
return attn_output
class PRXAttention(nn.Module, AttentionModuleMixin):
r"""
PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
PRX's architecture.
"""
_default_processor_cls = PRXAttnProcessor2_0
_available_processors = [PRXAttnProcessor2_0]
def __init__(
self,
query_dim: int,
heads: int = 8,
dim_head: int = 64,
bias: bool = False,
out_bias: bool = False,
eps: float = 1e-6,
processor=None,
):
super().__init__()
self.heads = heads
self.head_dim = dim_head
self.inner_dim = dim_head * heads
self.query_dim = query_dim
self.img_qkv_proj = nn.Linear(query_dim, query_dim * 3, bias=bias)
self.norm_q = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
self.norm_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
self.txt_kv_proj = nn.Linear(query_dim, query_dim * 2, bias=bias)
self.norm_added_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias))
self.to_out.append(nn.Dropout(0.0))
if processor is None:
processor = self._default_processor_cls()
self.set_processor(processor)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
return self.processor(
self,
hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
**kwargs,
)
# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
class PRXEmbedND(nn.Module):
r"""
N-dimensional rotary positional embedding.
This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding
dimension. The embeddings are combined and returned as a single tensor
Args:
dim (int):
Base embedding dimension (must be even).
theta (int):
Scaling factor that controls the frequency spectrum of the rotary embeddings.
axes_dim (list[int]):
List of embedding dimensions for each axis (each must be even).
"""
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
assert dim % 2 == 0
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
omega = 1.0 / (theta**scale)
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
# Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
# out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2)
out = out.reshape(*out.shape[:-1], 2, 2)
return out.float()
def forward(self, ids: torch.Tensor) -> torch.Tensor:
n_axes = ids.shape[-1]
emb = torch.cat(
[self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
dim=-3,
)
return emb.unsqueeze(1)
class MLPEmbedder(nn.Module):
r"""
A simple 2-layer MLP used for embedding inputs.
Args:
in_dim (`int`):
Dimensionality of the input features.
hidden_dim (`int`):
Dimensionality of the hidden and output embedding space.
Returns:
`torch.Tensor`:
Tensor of shape `(..., hidden_dim)` containing the embedded representations.
"""
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.silu = nn.SiLU()
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.out_layer(self.silu(self.in_layer(x)))
class Modulation(nn.Module):
r"""
Modulation network that generates scale, shift, and gating parameters.
Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
two tuples `(shift, scale, gate)`.
Args:
dim (`int`):
Dimensionality of the input vector. The output will have `6 * dim` features internally.
Returns:
((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
Two tuples `(shift, scale, gate)`.
"""
def __init__(self, dim: int):
super().__init__()
self.lin = nn.Linear(dim, 6 * dim, bias=True)
nn.init.constant_(self.lin.weight, 0)
nn.init.constant_(self.lin.bias, 0)
def forward(
self, vec: torch.Tensor
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
return tuple(out[:3]), tuple(out[3:])
class PRXBlock(nn.Module):
r"""
Multimodal transformer block with textimage cross-attention, modulation, and MLP.
Args:
hidden_size (`int`):
Dimension of the hidden representations.
num_heads (`int`):
Number of attention heads.
mlp_ratio (`float`, *optional*, defaults to 4.0):
Expansion ratio for the hidden dimension inside the MLP.
qk_scale (`float`, *optional*):
Scale factor for queries and keys. If not provided, defaults to ``head_dim**-0.5``.
Attributes:
img_pre_norm (`nn.LayerNorm`):
Pre-normalization applied to image tokens before attention.
attention (`PRXAttention`):
Multi-head attention module with built-in QKV projections and normalizations for cross-attention between
image and text tokens.
post_attention_layernorm (`nn.LayerNorm`):
Normalization applied after attention.
gate_proj / up_proj / down_proj (`nn.Linear`):
Feedforward layers forming the gated MLP.
mlp_act (`nn.GELU`):
Nonlinear activation used in the MLP.
modulation (`Modulation`):
Produces scale/shift/gating parameters for modulated layers.
Methods:
The forward method performs cross-attention and the MLP with modulation.
"""
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: Optional[float] = None,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
self.head_dim = hidden_size // num_heads
self.scale = qk_scale or self.head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.hidden_size = hidden_size
# Pre-attention normalization for image tokens
self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
# PRXAttention module with built-in projections and norms
self.attention = PRXAttention(
query_dim=hidden_size,
heads=num_heads,
dim_head=self.head_dim,
bias=False,
out_bias=False,
eps=1e-6,
processor=PRXAttnProcessor2_0(),
)
# mlp
self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False)
self.mlp_act = nn.GELU(approximate="tanh")
self.modulation = Modulation(hidden_size)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
temb: torch.Tensor,
image_rotary_emb: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
**kwargs: Dict[str, Any],
) -> torch.Tensor:
r"""
Runs modulation-gated cross-attention and MLP, with residual connections.
Args:
hidden_states (`torch.Tensor`):
Image tokens of shape `(B, L_img, hidden_size)`.
encoder_hidden_states (`torch.Tensor`):
Text tokens of shape `(B, L_txt, hidden_size)`.
temb (`torch.Tensor`):
Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or
broadcastable).
image_rotary_emb (`torch.Tensor`):
Rotary positional embeddings applied inside attention.
attention_mask (`torch.Tensor`, *optional*):
Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
**kwargs:
Additional keyword arguments for API compatibility.
Returns:
`torch.Tensor`:
Updated image tokens of shape `(B, L_img, hidden_size)`.
"""
mod_attn, mod_mlp = self.modulation(temb)
attn_shift, attn_scale, attn_gate = mod_attn
mlp_shift, mlp_scale, mlp_gate = mod_mlp
hidden_states_mod = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift
attn_out = self.attention(
hidden_states=hidden_states_mod,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
hidden_states = hidden_states + attn_gate * attn_out
x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift
hidden_states = hidden_states + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)))
return hidden_states
class FinalLayer(nn.Module):
r"""
Final projection layer with adaptive LayerNorm modulation.
This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level
outputs.
Args:
hidden_size (`int`):
Dimensionality of the input tokens.
patch_size (`int`):
Size of the square image patches.
out_channels (`int`):
Number of output channels per pixel (e.g. RGB = 3).
Forward Inputs:
x (`torch.Tensor`):
Input tokens of shape `(B, L, hidden_size)`, where `L` is the number of patches.
vec (`torch.Tensor`):
Conditioning vector of shape `(B, hidden_size)` used to generate shift and scale parameters for adaptive
LayerNorm.
Returns:
`torch.Tensor`:
Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`.
"""
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x
def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
r"""
Flattens an image tensor into a sequence of non-overlapping patches.
Args:
img (`torch.Tensor`):
Input image tensor of shape `(B, C, H, W)`.
patch_size (`int`):
Size of each square patch. Must evenly divide both `H` and `W`.
Returns:
`torch.Tensor`:
Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
// patch_size)` is the number of patches.
"""
return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
r"""
Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).
Args:
seq (`torch.Tensor`):
Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W //
patch_size)`.
patch_size (`int`):
Size of each square patch.
shape (`tuple` or `torch.Tensor`):
The original image spatial shape `(H, W)`. If a tensor is provided, the first two values are interpreted as
height and width.
Returns:
`torch.Tensor`:
Reconstructed image tensor of shape `(B, C, H, W)`.
"""
if isinstance(shape, tuple):
shape = shape[-2:]
elif isinstance(shape, torch.Tensor):
shape = (int(shape[0]), int(shape[1]))
else:
raise NotImplementedError(f"shape type {type(shape)} not supported")
return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
r"""
Transformer-based 2D model for text to image generation.
Args:
in_channels (`int`, *optional*, defaults to 16):
Number of input channels in the latent image.
patch_size (`int`, *optional*, defaults to 2):
Size of the square patches used to flatten the input image.
context_in_dim (`int`, *optional*, defaults to 2304):
Dimensionality of the text conditioning input.
hidden_size (`int`, *optional*, defaults to 1792):
Dimension of the hidden representation.
mlp_ratio (`float`, *optional*, defaults to 3.5):
Expansion ratio for the hidden dimension inside MLP blocks.
num_heads (`int`, *optional*, defaults to 28):
Number of attention heads.
depth (`int`, *optional*, defaults to 16):
Number of transformer blocks.
axes_dim (`list[int]`, *optional*):
List of dimensions for each positional embedding axis. Defaults to `[32, 32]`.
theta (`int`, *optional*, defaults to 10000):
Frequency scaling factor for rotary embeddings.
time_factor (`float`, *optional*, defaults to 1000.0):
Scaling factor applied in timestep embeddings.
time_max_period (`int`, *optional*, defaults to 10000):
Maximum frequency period for timestep embeddings.
Attributes:
pe_embedder (`EmbedND`):
Multi-axis rotary embedding generator for positional encodings.
img_in (`nn.Linear`):
Projection layer for image patch tokens.
time_in (`MLPEmbedder`):
Embedding layer for timestep embeddings.
txt_in (`nn.Linear`):
Projection layer for text conditioning.
blocks (`nn.ModuleList`):
Stack of transformer blocks (`PRXBlock`).
final_layer (`LastLayer`):
Projection layer mapping hidden tokens back to patch outputs.
Methods:
attn_processors:
Returns a dictionary of all attention processors in the model.
set_attn_processor(processor):
Replaces attention processors across all attention layers.
process_inputs(image_latent, txt):
Converts inputs into patch tokens, encodes text, and produces positional encodings.
compute_timestep_embedding(timestep, dtype):
Creates a timestep embedding of dimension 256, scaled and projected.
forward_transformers(image_latent, cross_attn_conditioning, timestep, time_embedding, attention_mask,
**block_kwargs):
Runs the sequence of transformer blocks over image and text tokens.
forward(image_latent, timestep, cross_attn_conditioning, micro_conditioning, cross_attn_mask=None,
attention_kwargs=None, return_dict=True):
Full forward pass from latent input to reconstructed output image.
Returns:
`Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing:
- `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`.
"""
config_name = "config.json"
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 16,
patch_size: int = 2,
context_in_dim: int = 2304,
hidden_size: int = 1792,
mlp_ratio: float = 3.5,
num_heads: int = 28,
depth: int = 16,
axes_dim: list = None,
theta: int = 10000,
time_factor: float = 1000.0,
time_max_period: int = 10000,
):
super().__init__()
if axes_dim is None:
axes_dim = [32, 32]
# Store parameters directly
self.in_channels = in_channels
self.patch_size = patch_size
self.out_channels = self.in_channels * self.patch_size**2
self.time_factor = time_factor
self.time_max_period = time_max_period
if hidden_size % num_heads != 0:
raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
pe_dim = hidden_size // num_heads
if sum(axes_dim) != pe_dim:
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
self.hidden_size = hidden_size
self.num_heads = num_heads
self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
self.blocks = nn.ModuleList(
[
PRXBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=mlp_ratio,
)
for i in range(depth)
]
)
self.final_layer = FinalLayer(self.hidden_size, 1, self.out_channels)
self.gradient_checkpointing = False
def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
return self.time_in(
get_timestep_embedding(
timesteps=timestep,
embedding_dim=256,
max_period=self.time_max_period,
scale=self.time_factor,
flip_sin_to_cos=True, # Match original cos, sin order
).to(dtype)
)
def forward(
self,
hidden_states: torch.Tensor,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
r"""
Forward pass of the PRXTransformer2DModel.
The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.
Args:
hidden_states (`torch.Tensor`):
Input latent image tensor of shape `(B, C, H, W)`.
timestep (`torch.Tensor`):
Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
encoder_hidden_states (`torch.Tensor`):
Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
attention_mask (`torch.Tensor`, *optional*):
Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
attention_kwargs (`dict`, *optional*):
Additional arguments passed to attention layers.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a `Transformer2DModelOutput` or a tuple.
Returns:
`Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple:
- `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
"""
# Process text conditioning
txt = self.txt_in(encoder_hidden_states)
# Convert image to sequence and embed
img = img2seq(hidden_states, self.patch_size)
img = self.img_in(img)
# Generate positional embeddings
bs, _, h, w = hidden_states.shape
img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device)
pe = self.pe_embedder(img_ids)
# Compute time embedding
vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
# Apply transformer blocks
for block in self.blocks:
if torch.is_grad_enabled() and self.gradient_checkpointing:
img = self._gradient_checkpointing_func(
block.__call__,
img,
txt,
vec,
pe,
attention_mask,
)
else:
img = block(
hidden_states=img,
encoder_hidden_states=txt,
temb=vec,
image_rotary_emb=pe,
attention_mask=attention_mask,
)
# Final layer and convert back to image
img = self.final_layer(img, vec)
output = seq2img(img, self.patch_size, hidden_states.shape)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -144,6 +144,7 @@ else:
"FluxKontextPipeline",
"FluxKontextInpaintPipeline",
]
_import_structure["prx"] = ["PRXPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
@@ -719,6 +720,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .paint_by_example import PaintByExamplePipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .prx import PRXPipeline
from .qwenimage import (
QwenImageControlNetInpaintPipeline,
QwenImageControlNetPipeline,

View File

@@ -0,0 +1,63 @@
from typing import TYPE_CHECKING
from ...utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
_LazyModule,
get_objects_from_module,
is_torch_available,
is_transformers_available,
)
_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["PRXPipelineOutput"]}
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils import dummy_torch_and_transformers_objects # noqa F403
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
_import_structure["pipeline_prx"] = ["PRXPipeline"]
# Import T5GemmaEncoder for pipeline loading compatibility
try:
if is_transformers_available():
import transformers
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
_additional_imports["T5GemmaEncoder"] = T5GemmaEncoder
# Patch transformers module directly for serialization
if not hasattr(transformers, "T5GemmaEncoder"):
transformers.T5GemmaEncoder = T5GemmaEncoder
except ImportError:
pass
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
if not (is_transformers_available() and is_torch_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
else:
from .pipeline_output import PRXPipelineOutput
from .pipeline_prx import PRXPipeline
else:
import sys
sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
module_spec=__spec__,
)
for name, value in _dummy_objects.items():
setattr(sys.modules[__name__], name, value)
for name, value in _additional_imports.items():
setattr(sys.modules[__name__], name, value)

View File

@@ -0,0 +1,35 @@
# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import List, Union
import numpy as np
import PIL.Image
from ...utils import BaseOutput
@dataclass
class PRXPipelineOutput(BaseOutput):
"""
Output class for PRX pipelines.
Args:
images (`List[PIL.Image.Image]` or `np.ndarray`)
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
"""
images: Union[List[PIL.Image.Image], np.ndarray]

View File

@@ -0,0 +1,767 @@
# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import html
import inspect
import re
import urllib.parse as ul
from typing import Callable, Dict, List, Optional, Union
import ftfy
import torch
from transformers import (
AutoTokenizer,
GemmaTokenizerFast,
T5TokenizerFast,
)
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
from diffusers.image_processor import PixArtImageProcessor
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from diffusers.models import AutoencoderDC, AutoencoderKL
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import (
logging,
replace_example_docstring,
)
from diffusers.utils.torch_utils import randn_tensor
DEFAULT_RESOLUTION = 512
ASPECT_RATIO_256_BIN = {
"0.46": [160, 352],
"0.6": [192, 320],
"0.78": [224, 288],
"1.0": [256, 256],
"1.29": [288, 224],
"1.67": [320, 192],
"2.2": [352, 160],
}
ASPECT_RATIO_512_BIN = {
"0.5": [352, 704],
"0.57": [384, 672],
"0.6": [384, 640],
"0.68": [416, 608],
"0.78": [448, 576],
"0.88": [480, 544],
"1.0": [512, 512],
"1.13": [544, 480],
"1.29": [576, 448],
"1.46": [608, 416],
"1.67": [640, 384],
"1.75": [672, 384],
"2.0": [704, 352],
}
logger = logging.get_logger(__name__)
class TextPreprocessor:
"""Text preprocessing utility for PRXPipeline."""
def __init__(self):
"""Initialize text preprocessor."""
self.bad_punct_regex = re.compile(
r"["
+ "#®•©™&@·º½¾¿¡§~"
+ r"\)"
+ r"\("
+ r"\]"
+ r"\["
+ r"\}"
+ r"\{"
+ r"\|"
+ r"\\"
+ r"\/"
+ r"\*"
+ r"]{1,}"
)
def clean_text(self, text: str) -> str:
"""Clean text using comprehensive text processing logic."""
# See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py
text = str(text)
text = ul.unquote_plus(text)
text = text.strip().lower()
text = re.sub("<person>", "person", text)
# Remove all urls:
text = re.sub(
r"\b((?:https?|www):(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@))",
"",
text,
) # regex for urls
# @<nickname>
text = re.sub(r"@[\w\d]+\b", "", text)
# 31C0—31EF CJK Strokes through 4E00—9FFF CJK Unified Ideographs
text = re.sub(r"[\u31c0-\u31ef]+", "", text)
text = re.sub(r"[\u31f0-\u31ff]+", "", text)
text = re.sub(r"[\u3200-\u32ff]+", "", text)
text = re.sub(r"[\u3300-\u33ff]+", "", text)
text = re.sub(r"[\u3400-\u4dbf]+", "", text)
text = re.sub(r"[\u4dc0-\u4dff]+", "", text)
text = re.sub(r"[\u4e00-\u9fff]+", "", text)
# все виды тире / all types of dash --> "-"
text = re.sub(
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
"-",
text,
)
# кавычки к одному стандарту
text = re.sub(r"[`´«»" "¨]", '"', text)
text = re.sub(r"['']", "'", text)
# &quot; and &amp
text = re.sub(r"&quot;?", "", text)
text = re.sub(r"&amp", "", text)
# ip addresses:
text = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", text)
# article ids:
text = re.sub(r"\d:\d\d\s+$", "", text)
# \n
text = re.sub(r"\\n", " ", text)
# "#123", "#12345..", "123456.."
text = re.sub(r"#\d{1,3}\b", "", text)
text = re.sub(r"#\d{5,}\b", "", text)
text = re.sub(r"\b\d{6,}\b", "", text)
# filenames:
text = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", text)
# Clean punctuation
text = re.sub(r"[\"\']{2,}", r'"', text) # """AUSVERKAUFT"""
text = re.sub(r"[\.]{2,}", r" ", text)
text = re.sub(self.bad_punct_regex, r" ", text) # ***AUSVERKAUFT***, #AUSVERKAUFT
text = re.sub(r"\s+\.\s+", r" ", text) # " . "
# this-is-my-cute-cat / this_is_my_cute_cat
regex2 = re.compile(r"(?:\-|\_)")
if len(re.findall(regex2, text)) > 3:
text = re.sub(regex2, " ", text)
# Basic cleaning
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
text = text.strip()
# Clean alphanumeric patterns
text = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", text) # jc6640
text = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", text) # jc6640vc
text = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", text) # 6640vc231
# Common spam patterns
text = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", text)
text = re.sub(r"(free\s)?download(\sfree)?", "", text)
text = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", text)
text = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", text)
text = re.sub(r"\bpage\s+\d+\b", "", text)
text = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", text) # j2d1a2a...
text = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", text)
# Final cleanup
text = re.sub(r"\b\s+\:\s+", r": ", text)
text = re.sub(r"(\D[,\./])\b", r"\1 ", text)
text = re.sub(r"\s+", " ", text)
text.strip()
text = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", text)
text = re.sub(r"^[\'\_,\-\:;]", r"", text)
text = re.sub(r"[\'\_,\-\:\-\+]$", r"", text)
text = re.sub(r"^\.\S+$", "", text)
return text.strip()
EXAMPLE_DOC_STRING = """
Examples:
```py
>>> import torch
>>> from diffusers import PRXPipeline
>>> # Load pipeline with from_pretrained
>>> pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft")
>>> pipe.to("cuda")
>>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
>>> image.save("prx_output.png")
```
"""
class PRXPipeline(
DiffusionPipeline,
LoraLoaderMixin,
FromSingleFileMixin,
TextualInversionLoaderMixin,
):
r"""
Pipeline for text-to-image generation using PRX Transformer.
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
transformer ([`PRXTransformer2DModel`]):
The PRX transformer model to denoise the encoded image latents.
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
text_encoder ([`T5GemmaEncoder`]):
Text encoder model for encoding prompts.
tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]):
Tokenizer for the text encoder.
vae ([`AutoencoderKL`] or [`AutoencoderDC`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression).
"""
model_cpu_offload_seq = "text_encoder->transformer->vae"
_callback_tensor_inputs = ["latents", "prompt_embeds"]
_optional_components = ["vae"]
def __init__(
self,
transformer: PRXTransformer2DModel,
scheduler: FlowMatchEulerDiscreteScheduler,
text_encoder: T5GemmaEncoder,
tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
vae: Optional[Union[AutoencoderKL, AutoencoderDC]] = None,
default_sample_size: Optional[int] = DEFAULT_RESOLUTION,
):
super().__init__()
if PRXTransformer2DModel is None:
raise ImportError(
"PRXTransformer2DModel is not available. Please ensure the transformer_prx module is properly installed."
)
self.text_preprocessor = TextPreprocessor()
self.default_sample_size = default_sample_size
self._guidance_scale = 1.0
self.register_modules(
transformer=transformer,
scheduler=scheduler,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
)
self.register_to_config(default_sample_size=self.default_sample_size)
if vae is not None:
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
else:
self.image_processor = None
@property
def vae_scale_factor(self):
if self.vae is None:
return 8
if hasattr(self.vae, "spatial_compression_ratio"):
return self.vae.spatial_compression_ratio
else: # Flux VAE
return 2 ** (len(self.vae.config.block_out_channels) - 1)
@property
def do_classifier_free_guidance(self):
"""Check if classifier-free guidance is enabled based on guidance scale."""
return self._guidance_scale > 1.0
@property
def guidance_scale(self):
return self._guidance_scale
def get_default_resolution(self):
"""Determine the default resolution based on the loaded VAE and config.
Returns:
int: The default sample size (height/width) to use for generation.
"""
default_from_config = getattr(self.config, "default_sample_size", None)
if default_from_config is not None:
return default_from_config
return DEFAULT_RESOLUTION
def prepare_latents(
self,
batch_size: int,
num_channels_latents: int,
height: int,
width: int,
dtype: torch.dtype,
device: torch.device,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.Tensor] = None,
):
"""Prepare initial latents for the diffusion process."""
if latents is None:
spatial_compression = self.vae_scale_factor
latent_height, latent_width = (
height // spatial_compression,
width // spatial_compression,
)
shape = (batch_size, num_channels_latents, latent_height, latent_width)
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
else:
latents = latents.to(device)
return latents
def encode_prompt(
self,
prompt: Union[str, List[str]],
device: Optional[torch.device] = None,
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
num_images_per_prompt: int = 1,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.BoolTensor] = None,
negative_prompt_attention_mask: Optional[torch.BoolTensor] = None,
):
"""Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings."""
if device is None:
device = self._execution_device
if prompt_embeds is None:
if isinstance(prompt, str):
prompt = [prompt]
# Encode the prompts
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt)
)
# Duplicate embeddings for each generation per prompt
if num_images_per_prompt > 1:
# Repeat prompt embeddings
bs_embed, seq_len, _ = prompt_embeds.shape
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
if prompt_attention_mask is not None:
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
# Repeat negative embeddings if using CFG
if do_classifier_free_guidance and negative_prompt_embeds is not None:
bs_embed, seq_len, _ = negative_prompt_embeds.shape
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
if negative_prompt_attention_mask is not None:
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
return (
prompt_embeds,
prompt_attention_mask,
negative_prompt_embeds if do_classifier_free_guidance else None,
negative_prompt_attention_mask if do_classifier_free_guidance else None,
)
def _tokenize_prompts(self, prompts: List[str], device: torch.device):
"""Tokenize and clean prompts."""
cleaned = [self.text_preprocessor.clean_text(text) for text in prompts]
tokens = self.tokenizer(
cleaned,
padding="max_length",
max_length=self.tokenizer.model_max_length,
truncation=True,
return_attention_mask=True,
return_tensors="pt",
)
return tokens["input_ids"].to(device), tokens["attention_mask"].bool().to(device)
def _encode_prompt_standard(
self,
prompt: List[str],
device: torch.device,
do_classifier_free_guidance: bool = True,
negative_prompt: str = "",
):
"""Encode prompt using standard text encoder and tokenizer with batch processing."""
batch_size = len(prompt)
if do_classifier_free_guidance:
if isinstance(negative_prompt, str):
negative_prompt = [negative_prompt] * batch_size
prompts_to_encode = negative_prompt + prompt
else:
prompts_to_encode = prompt
input_ids, attention_mask = self._tokenize_prompts(prompts_to_encode, device)
with torch.no_grad():
embeddings = self.text_encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_hidden_states=True,
)["last_hidden_state"]
if do_classifier_free_guidance:
uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0)
uncond_cross_attn_mask, cross_attn_mask = attention_mask.split(batch_size, dim=0)
else:
text_embeddings = embeddings
cross_attn_mask = attention_mask
uncond_text_embeddings = None
uncond_cross_attn_mask = None
return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask
def check_inputs(
self,
prompt: Union[str, List[str]],
height: int,
width: int,
guidance_scale: float,
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
):
"""Check that all inputs are in correct format."""
if prompt is not None and prompt_embeds is not None:
raise ValueError(
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
" only forward one of the two."
)
if prompt is None and prompt_embeds is None:
raise ValueError(
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
)
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
if prompt_embeds is not None and guidance_scale > 1.0 and negative_prompt_embeds is None:
raise ValueError(
"When `prompt_embeds` is provided and `guidance_scale > 1.0`, "
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
)
spatial_compression = self.vae_scale_factor
if height % spatial_compression != 0 or width % spatial_compression != 0:
raise ValueError(
f"`height` and `width` have to be divisible by {spatial_compression} but are {height} and {width}."
)
if guidance_scale < 1.0:
raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}")
if callback_on_step_end_tensor_inputs is not None and not isinstance(callback_on_step_end_tensor_inputs, list):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}"
)
if callback_on_step_end_tensor_inputs is not None and not all(
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
):
raise ValueError(
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
)
@torch.no_grad()
@replace_example_docstring(EXAMPLE_DOC_STRING)
def __call__(
self,
prompt: Union[str, List[str]] = None,
negative_prompt: str = "",
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 28,
timesteps: List[int] = None,
guidance_scale: float = 4.0,
num_images_per_prompt: Optional[int] = 1,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.Tensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
prompt_attention_mask: Optional[torch.BoolTensor] = None,
negative_prompt_attention_mask: Optional[torch.BoolTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
use_resolution_binning: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
):
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
negative_prompt (`str`, *optional*, defaults to `""`):
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
if `guidance_scale` is less than `1`).
height (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.transformer.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 28):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
timesteps (`List[int]`, *optional*):
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
passed will be used. Must be in descending order.
guidance_scale (`float`, *optional*, defaults to 4.0):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.Tensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided and `guidance_scale > 1`, negative embeddings will be generated from an
empty string.
prompt_attention_mask (`torch.BoolTensor`, *optional*):
Pre-generated attention mask for `prompt_embeds`. If not provided, attention mask will be generated
from `prompt` input argument.
negative_prompt_attention_mask (`torch.BoolTensor`, *optional*):
Pre-generated attention mask for `negative_prompt_embeds`. If not provided and `guidance_scale > 1`,
attention mask will be generated from an empty string.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generate image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.prx.PRXPipelineOutput`] instead of a plain tuple.
use_resolution_binning (`bool`, *optional*, defaults to `True`):
If set to `True`, the requested height and width are first mapped to the closest resolutions using
predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back
to the requested resolution. Useful for generating non-square images at optimal resolutions.
callback_on_step_end (`Callable`, *optional*):
A function that calls at the end of each denoising steps during the inference. The function is called
with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`.
`callback_kwargs` will include a list of all tensors as specified by
`callback_on_step_end_tensor_inputs`.
callback_on_step_end_tensor_inputs (`List`, *optional*):
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
will be passed as `callback_kwargs` argument. You will only be able to include tensors that are listed
in the `._callback_tensor_inputs` attribute.
Examples:
Returns:
[`~pipelines.prx.PRXPipelineOutput`] or `tuple`: [`~pipelines.prx.PRXPipelineOutput`] if `return_dict` is
True, otherwise a `tuple. When returning a tuple, the first element is a list with the generated images.
"""
# 0. Set height and width
default_resolution = self.get_default_resolution()
height = height or default_resolution
width = width or default_resolution
if use_resolution_binning:
if self.image_processor is None:
raise ValueError(
"Resolution binning requires a VAE with image_processor, but VAE is not available. "
"Set use_resolution_binning=False or provide a VAE."
)
if self.default_sample_size <= 256:
aspect_ratio_bin = ASPECT_RATIO_256_BIN
else:
aspect_ratio_bin = ASPECT_RATIO_512_BIN
# Store original dimensions
orig_height, orig_width = height, width
# Map to closest resolution in the bin
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
# 1. Check inputs
self.check_inputs(
prompt,
height,
width,
guidance_scale,
callback_on_step_end_tensor_inputs,
prompt_embeds,
negative_prompt_embeds,
)
if self.vae is None and output_type not in ["latent", "pt"]:
raise ValueError(
f"VAE is required for output_type='{output_type}' but it is not available. "
"Either provide a VAE or set output_type='latent' or 'pt' to get latent outputs."
)
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
# Use execution device (handles offloading scenarios including group offloading)
device = self._execution_device
self._guidance_scale = guidance_scale
# 2. Encode input prompt
text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt(
prompt,
device,
do_classifier_free_guidance=self.do_classifier_free_guidance,
negative_prompt=negative_prompt,
num_images_per_prompt=num_images_per_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
prompt_attention_mask=prompt_attention_mask,
negative_prompt_attention_mask=negative_prompt_attention_mask,
)
# Expose standard names for callbacks parity
prompt_embeds = text_embeddings
negative_prompt_embeds = uncond_text_embeddings
# 3. Prepare timesteps
if timesteps is not None:
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
timesteps = self.scheduler.timesteps
num_inference_steps = len(timesteps)
else:
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
self.num_timesteps = len(timesteps)
# 4. Prepare latent variables
if self.vae is not None:
num_channels_latents = self.vae.config.latent_channels
else:
# When vae is None, get latent channels from transformer
num_channels_latents = self.transformer.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
text_embeddings.dtype,
device,
generator,
latents,
)
# 5. Prepare extra step kwargs
extra_step_kwargs = {}
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_eta:
extra_step_kwargs["eta"] = 0.0
# 6. Prepare cross-attention embeddings and masks
if self.do_classifier_free_guidance:
ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
ca_mask = None
if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
else:
ca_embed = text_embeddings
ca_mask = cross_attn_mask
# 7. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# Duplicate latents if using classifier-free guidance
if self.do_classifier_free_guidance:
latents_in = torch.cat([latents, latents], dim=0)
# Normalize timestep for the transformer
t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
else:
latents_in = latents
# Normalize timestep for the transformer
t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device)
# Forward through transformer
noise_pred = self.transformer(
hidden_states=latents_in,
timestep=t_cont,
encoder_hidden_states=ca_embed,
attention_mask=ca_mask,
return_dict=False,
)[0]
# Apply CFG
if self.do_classifier_free_guidance:
noise_uncond, noise_text = noise_pred.chunk(2, dim=0)
noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
# Compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
if callback_on_step_end is not None:
callback_kwargs = {}
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_on_step_end(self, i, t, callback_kwargs)
# Call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
# 8. Post-processing
if output_type == "latent" or (output_type == "pt" and self.vae is None):
image = latents
else:
# Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC)
scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
latents = (latents / scaling_factor) + shift_factor
# Decode using VAE (AutoencoderKL or AutoencoderDC)
image = self.vae.decode(latents, return_dict=False)[0]
# Resize back to original resolution if using binning
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
# Use standard image processor for post-processing
image = self.image_processor.postprocess(image, output_type=output_type)
# Offload all models
self.maybe_free_model_hooks()
if not return_dict:
return (image,)
return PRXPipelineOutput(images=image)

View File

@@ -1128,6 +1128,21 @@ class PriorTransformer(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class PRXTransformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class QwenImageControlNetModel(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -1907,6 +1907,21 @@ class PixArtSigmaPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class PRXPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class QwenImageControlNetInpaintPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File

@@ -0,0 +1,83 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import torch
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from ...testing_utils import enable_full_determinism, torch_device
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()
class PRXTransformerTests(ModelTesterMixin, unittest.TestCase):
model_class = PRXTransformer2DModel
main_input_name = "hidden_states"
uses_custom_attn_processor = True
@property
def dummy_input(self):
return self.prepare_dummy_input()
@property
def input_shape(self):
return (16, 16, 16)
@property
def output_shape(self):
return (16, 16, 16)
def prepare_dummy_input(self, height=16, width=16):
batch_size = 1
num_latent_channels = 16
sequence_length = 16
embedding_dim = 1792
hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device)
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
return {
"hidden_states": hidden_states,
"timestep": timestep,
"encoder_hidden_states": encoder_hidden_states,
}
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"in_channels": 16,
"patch_size": 2,
"context_in_dim": 1792,
"hidden_size": 1792,
"mlp_ratio": 3.5,
"num_heads": 28,
"depth": 4, # Smaller depth for testing
"axes_dim": [32, 32],
"theta": 10_000,
}
inputs_dict = self.prepare_dummy_input()
return init_dict, inputs_dict
def test_gradient_checkpointing_is_applied(self):
expected_set = {"PRXTransformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
if __name__ == "__main__":
unittest.main()

View File

View File

@@ -0,0 +1,265 @@
import unittest
import numpy as np
import pytest
import torch
from transformers import AutoTokenizer
from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
from diffusers.models import AutoencoderDC, AutoencoderKL
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.prx.pipeline_prx import PRXPipeline
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import is_transformers_version
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin
@pytest.mark.xfail(
condition=is_transformers_version(">", "4.57.1"),
reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
strict=False,
)
class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
pipeline_class = PRXPipeline
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"])
test_xformers_attention = False
test_layerwise_casting = True
test_group_offloading = True
@classmethod
def setUpClass(cls):
# Ensure PRXPipeline has an _execution_device property expected by __call__
if not isinstance(getattr(PRXPipeline, "_execution_device", None), property):
try:
setattr(PRXPipeline, "_execution_device", property(lambda self: torch.device("cpu")))
except Exception:
pass
def get_dummy_components(self):
torch.manual_seed(0)
transformer = PRXTransformer2DModel(
patch_size=1,
in_channels=4,
context_in_dim=8,
hidden_size=8,
mlp_ratio=2.0,
num_heads=2,
depth=1,
axes_dim=[2, 2],
)
torch.manual_seed(0)
vae = AutoencoderKL(
sample_size=32,
in_channels=3,
out_channels=3,
block_out_channels=(4,),
layers_per_block=1,
latent_channels=4,
norm_num_groups=1,
use_quant_conv=False,
use_post_quant_conv=False,
shift_factor=0.0,
scaling_factor=1.0,
).eval()
torch.manual_seed(0)
scheduler = FlowMatchEulerDiscreteScheduler()
torch.manual_seed(0)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
tokenizer.model_max_length = 64
torch.manual_seed(0)
encoder_params = {
"vocab_size": tokenizer.vocab_size,
"hidden_size": 8,
"intermediate_size": 16,
"num_hidden_layers": 1,
"num_attention_heads": 2,
"num_key_value_heads": 1,
"head_dim": 4,
"max_position_embeddings": 64,
"layer_types": ["full_attention"],
"attention_bias": False,
"attention_dropout": 0.0,
"dropout_rate": 0.0,
"hidden_activation": "gelu_pytorch_tanh",
"rms_norm_eps": 1e-06,
"attn_logit_softcapping": 50.0,
"final_logit_softcapping": 30.0,
"query_pre_attn_scalar": 4,
"rope_theta": 10000.0,
"sliding_window": 4096,
}
encoder_config = T5GemmaModuleConfig(**encoder_params)
text_encoder_config = T5GemmaConfig(encoder=encoder_config, is_encoder_decoder=False, **encoder_params)
text_encoder = T5GemmaEncoder(text_encoder_config)
return {
"transformer": transformer,
"vae": vae,
"scheduler": scheduler,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
}
def get_dummy_inputs(self, device, seed=0):
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device=device).manual_seed(seed)
return {
"prompt": "",
"negative_prompt": "",
"generator": generator,
"num_inference_steps": 2,
"guidance_scale": 1.0,
"height": 32,
"width": 32,
"output_type": "pt",
"use_resolution_binning": False,
}
def test_inference(self):
device = "cpu"
components = self.get_dummy_components()
pipe = PRXPipeline(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
try:
pipe.register_to_config(_execution_device="cpu")
except Exception:
pass
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs)[0]
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 32, 32))
expected_image = torch.zeros(3, 32, 32)
max_diff = np.abs(generated_image - expected_image).max()
self.assertLessEqual(max_diff, 1e10)
def test_callback_inputs(self):
components = self.get_dummy_components()
pipe = PRXPipeline(**components)
pipe = pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
try:
pipe.register_to_config(_execution_device="cpu")
except Exception:
pass
self.assertTrue(
hasattr(pipe, "_callback_tensor_inputs"),
f" {PRXPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
)
def callback_inputs_subset(pipe, i, t, callback_kwargs):
for tensor_name in callback_kwargs.keys():
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
def callback_inputs_all(pipe, i, t, callback_kwargs):
for tensor_name in pipe._callback_tensor_inputs:
assert tensor_name in callback_kwargs
for tensor_name in callback_kwargs.keys():
assert tensor_name in pipe._callback_tensor_inputs
return callback_kwargs
inputs = self.get_dummy_inputs("cpu")
inputs["callback_on_step_end"] = callback_inputs_subset
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
_ = pipe(**inputs)[0]
inputs["callback_on_step_end"] = callback_inputs_all
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
_ = pipe(**inputs)[0]
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):
if not self.test_attention_slicing:
return
components = self.get_dummy_components()
pipe = self.pipeline_class(**components)
for component in pipe.components.values():
if hasattr(component, "set_default_attn_processor"):
component.set_default_attn_processor()
pipe.to("cpu")
pipe.set_progress_bar_config(disable=None)
def to_np_local(tensor):
if isinstance(tensor, torch.Tensor):
return tensor.detach().cpu().numpy()
return tensor
generator_device = "cpu"
inputs = self.get_dummy_inputs(generator_device)
output_without_slicing = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=1)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing1 = pipe(**inputs)[0]
pipe.enable_attention_slicing(slice_size=2)
inputs = self.get_dummy_inputs(generator_device)
output_with_slicing2 = pipe(**inputs)[0]
max_diff1 = np.abs(to_np_local(output_with_slicing1) - to_np_local(output_without_slicing)).max()
max_diff2 = np.abs(to_np_local(output_with_slicing2) - to_np_local(output_without_slicing)).max()
self.assertLess(max(max_diff1, max_diff2), expected_max_diff)
def test_inference_with_autoencoder_dc(self):
"""Test PRXPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL."""
device = "cpu"
components = self.get_dummy_components()
torch.manual_seed(0)
vae_dc = AutoencoderDC(
in_channels=3,
latent_channels=4,
attention_head_dim=2,
encoder_block_types=(
"ResBlock",
"EfficientViTBlock",
),
decoder_block_types=(
"ResBlock",
"EfficientViTBlock",
),
encoder_block_out_channels=(8, 8),
decoder_block_out_channels=(8, 8),
encoder_qkv_multiscales=((), (5,)),
decoder_qkv_multiscales=((), (5,)),
encoder_layers_per_block=(1, 1),
decoder_layers_per_block=(1, 1),
upsample_block_type="interpolate",
downsample_block_type="stride_conv",
decoder_norm_types="rms_norm",
decoder_act_fns="silu",
).eval()
components["vae"] = vae_dc
pipe = PRXPipeline(**components)
pipe.to(device)
pipe.set_progress_bar_config(disable=None)
expected_scale_factor = vae_dc.spatial_compression_ratio
self.assertEqual(pipe.vae_scale_factor, expected_scale_factor)
inputs = self.get_dummy_inputs(device)
image = pipe(**inputs)[0]
generated_image = image[0]
self.assertEqual(generated_image.shape, (3, 32, 32))
expected_image = torch.zeros(3, 32, 32)
max_diff = np.abs(generated_image - expected_image).max()
self.assertLessEqual(max_diff, 1e10)

View File

@@ -29,7 +29,7 @@ The benchmark results for Flux and CogVideoX can be found in [this](https://gith
The tests, and the expected slices, were obtained from the `aws-g6e-xlarge-plus` GPU test runners. To run the slow tests, use the following command or an equivalent:
```bash
HF_HUB_ENABLE_HF_TRANSFER=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests
HF_XET_HIGH_PERFORMANCE=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests
```
`diffusers-cli`: