Merge branch 'main' into requirements-custom-blocks
.github/workflows/benchmark.yml (vendored, 2 changes)
@@ -7,7 +7,7 @@ on:

env:
  DIFFUSERS_IS_CI: yes
  HF_HUB_ENABLE_HF_TRANSFER: 1
  HF_XET_HIGH_PERFORMANCE: 1
  HF_HOME: /mnt/cache
  OMP_NUM_THREADS: 8
  MKL_NUM_THREADS: 8

.github/workflows/build_docker_images.yml (vendored, 37 changes)
@@ -42,18 +42,39 @@ jobs:
           CHANGED_FILES: ${{ steps.file_changes.outputs.all }}
         run: |
           echo "$CHANGED_FILES"
-          for FILE in $CHANGED_FILES; do
+          ALLOWED_IMAGES=(
+            diffusers-pytorch-cpu
+            diffusers-pytorch-cuda
+            diffusers-pytorch-xformers-cuda
+            diffusers-pytorch-minimum-cuda
+            diffusers-doc-builder
+          )
+
+          declare -A IMAGES_TO_BUILD=()
+
+          for FILE in $CHANGED_FILES; do
             # skip anything that isn't still on disk
-            if [[ ! -f "$FILE" ]]; then
+            if [[ ! -e "$FILE" ]]; then
              echo "Skipping removed file $FILE"
              continue
            fi
-            if [[ "$FILE" == docker/*Dockerfile ]]; then
-              DOCKER_PATH="${FILE%/Dockerfile}"
-              DOCKER_TAG=$(basename "$DOCKER_PATH")
-              echo "Building Docker image for $DOCKER_TAG"
-              docker build -t "$DOCKER_TAG" "$DOCKER_PATH"
-            fi
+
+            for IMAGE in "${ALLOWED_IMAGES[@]}"; do
+              if [[ "$FILE" == docker/${IMAGE}/* ]]; then
+                IMAGES_TO_BUILD["$IMAGE"]=1
+              fi
+            done
           done
+
+          if [[ ${#IMAGES_TO_BUILD[@]} -eq 0 ]]; then
+            echo "No relevant Docker changes detected."
+            exit 0
+          fi
+
+          for IMAGE in "${!IMAGES_TO_BUILD[@]}"; do
+            DOCKER_PATH="docker/${IMAGE}"
+            echo "Building Docker image for $IMAGE"
+            docker build -t "$IMAGE" "$DOCKER_PATH"
+          done
         if: steps.file_changes.outputs.all != ''

.github/workflows/nightly_tests.yml (vendored, 2 changes)
@@ -7,7 +7,7 @@ on:

env:
  DIFFUSERS_IS_CI: yes
  HF_HUB_ENABLE_HF_TRANSFER: 1
  HF_XET_HIGH_PERFORMANCE: 1
  OMP_NUM_THREADS: 8
  MKL_NUM_THREADS: 8
  PYTEST_TIMEOUT: 600

.github/workflows/pr_modular_tests.yml (vendored, 2 changes)
@@ -26,7 +26,7 @@ concurrency:

env:
  DIFFUSERS_IS_CI: yes
  HF_HUB_ENABLE_HF_TRANSFER: 1
  HF_XET_HIGH_PERFORMANCE: 1
  OMP_NUM_THREADS: 4
  MKL_NUM_THREADS: 4
  PYTEST_TIMEOUT: 60

.github/workflows/pr_tests.yml (vendored, 2 changes)
@@ -22,7 +22,7 @@ concurrency:

env:
  DIFFUSERS_IS_CI: yes
  HF_HUB_ENABLE_HF_TRANSFER: 1
  HF_XET_HIGH_PERFORMANCE: 1
  OMP_NUM_THREADS: 4
  MKL_NUM_THREADS: 4
  PYTEST_TIMEOUT: 60

.github/workflows/pr_tests_gpu.yml (vendored, 2 changes)
@@ -24,7 +24,7 @@ env:
  DIFFUSERS_IS_CI: yes
  OMP_NUM_THREADS: 8
  MKL_NUM_THREADS: 8
  HF_HUB_ENABLE_HF_TRANSFER: 1
  HF_XET_HIGH_PERFORMANCE: 1
  PYTEST_TIMEOUT: 600
  PIPELINE_USAGE_CUTOFF: 1000000000 # set high cutoff so that only always-test pipelines run

.github/workflows/push_tests.yml (vendored, 2 changes)
@@ -14,7 +14,7 @@ env:
  DIFFUSERS_IS_CI: yes
  OMP_NUM_THREADS: 8
  MKL_NUM_THREADS: 8
  HF_HUB_ENABLE_HF_TRANSFER: 1
  HF_XET_HIGH_PERFORMANCE: 1
  PYTEST_TIMEOUT: 600
  PIPELINE_USAGE_CUTOFF: 50000

.github/workflows/push_tests_fast.yml (vendored, 2 changes)
@@ -18,7 +18,7 @@ env:
  HF_HOME: /mnt/cache
  OMP_NUM_THREADS: 8
  MKL_NUM_THREADS: 8
  HF_HUB_ENABLE_HF_TRANSFER: 1
  HF_XET_HIGH_PERFORMANCE: 1
  PYTEST_TIMEOUT: 600
  RUN_SLOW: no

.github/workflows/push_tests_mps.yml (vendored, 2 changes)
@@ -8,7 +8,7 @@ env:
  HF_HOME: /mnt/cache
  OMP_NUM_THREADS: 8
  MKL_NUM_THREADS: 8
  HF_HUB_ENABLE_HF_TRANSFER: 1
  HF_XET_HIGH_PERFORMANCE: 1
  PYTEST_TIMEOUT: 600
  RUN_SLOW: no
@@ -33,7 +33,7 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.
RUN uv pip install --no-cache-dir \
    accelerate \
    numpy==1.26.4 \
    hf_transfer \
    hf_xet \
    setuptools==69.5.1 \
    bitsandbytes \
    torchao \

@@ -44,6 +44,6 @@ RUN python3 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
    scipy \
    tensorboard \
    transformers \
    hf_transfer
    hf_xet

CMD ["/bin/bash"]

@@ -38,13 +38,12 @@ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
    datasets \
    hf-doc-builder \
    huggingface-hub \
    hf_transfer \
    hf_xet \
    Jinja2 \
    librosa \
    numpy==1.26.4 \
    scipy \
    tensorboard \
    transformers \
    hf_transfer
    transformers

CMD ["/bin/bash"]

@@ -31,7 +31,7 @@ RUN uv pip install --no-cache-dir "git+https://github.com/huggingface/diffusers.
RUN uv pip install --no-cache-dir \
    accelerate \
    numpy==1.26.4 \
    hf_transfer
    hf_xet

RUN apt-get clean && rm -rf /var/lib/apt/lists/* && apt-get autoremove && apt-get autoclean

@@ -44,6 +44,6 @@ RUN uv pip install --no-cache-dir \
    accelerate \
    numpy==1.26.4 \
    pytorch-lightning \
    hf_transfer
    hf_xet

CMD ["/bin/bash"]

@@ -47,6 +47,6 @@ RUN uv pip install --no-cache-dir \
    accelerate \
    numpy==1.26.4 \
    pytorch-lightning \
    hf_transfer
    hf_xet

CMD ["/bin/bash"]

@@ -44,7 +44,7 @@ RUN uv pip install --no-cache-dir \
    accelerate \
    numpy==1.26.4 \
    pytorch-lightning \
    hf_transfer \
    hf_xet \
    xformers

CMD ["/bin/bash"]

@@ -1,5 +1,4 @@
|
||||
- title: Get started
|
||||
sections:
|
||||
- sections:
|
||||
- local: index
|
||||
title: Diffusers
|
||||
- local: installation
|
||||
@@ -8,9 +7,8 @@
|
||||
title: Quickstart
|
||||
- local: stable_diffusion
|
||||
title: Basic performance
|
||||
|
||||
- title: Pipelines
|
||||
isExpanded: false
|
||||
title: Get started
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: using-diffusers/loading
|
||||
title: DiffusionPipeline
|
||||
@@ -28,9 +26,8 @@
|
||||
title: Model formats
|
||||
- local: using-diffusers/push_to_hub
|
||||
title: Sharing pipelines and models
|
||||
|
||||
- title: Adapters
|
||||
isExpanded: false
|
||||
title: Pipelines
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: tutorials/using_peft_for_inference
|
||||
title: LoRA
|
||||
@@ -44,9 +41,8 @@
|
||||
title: DreamBooth
|
||||
- local: using-diffusers/textual_inversion_inference
|
||||
title: Textual inversion
|
||||
|
||||
- title: Inference
|
||||
isExpanded: false
|
||||
title: Adapters
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: using-diffusers/weighted_prompts
|
||||
title: Prompting
|
||||
@@ -56,9 +52,8 @@
|
||||
title: Batch inference
|
||||
- local: training/distributed_inference
|
||||
title: Distributed inference
|
||||
|
||||
- title: Inference optimization
|
||||
isExpanded: false
|
||||
title: Inference
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: optimization/fp16
|
||||
title: Accelerate inference
|
||||
@@ -70,8 +65,7 @@
|
||||
title: Reduce memory usage
|
||||
- local: optimization/speed-memory-optims
|
||||
title: Compiling and offloading quantized models
|
||||
- title: Community optimizations
|
||||
sections:
|
||||
- sections:
|
||||
- local: optimization/pruna
|
||||
title: Pruna
|
||||
- local: optimization/xformers
|
||||
@@ -90,9 +84,9 @@
|
||||
title: ParaAttention
|
||||
- local: using-diffusers/image_quality
|
||||
title: FreeU
|
||||
|
||||
- title: Hybrid Inference
|
||||
isExpanded: false
|
||||
title: Community optimizations
|
||||
title: Inference optimization
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: hybrid_inference/overview
|
||||
title: Overview
|
||||
@@ -102,9 +96,8 @@
|
||||
title: VAE Encode
|
||||
- local: hybrid_inference/api_reference
|
||||
title: API Reference
|
||||
|
||||
- title: Modular Diffusers
|
||||
isExpanded: false
|
||||
title: Hybrid Inference
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: modular_diffusers/overview
|
||||
title: Overview
|
||||
@@ -126,9 +119,8 @@
|
||||
title: ComponentsManager
|
||||
- local: modular_diffusers/guiders
|
||||
title: Guiders
|
||||
|
||||
- title: Training
|
||||
isExpanded: false
|
||||
title: Modular Diffusers
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: training/overview
|
||||
title: Overview
|
||||
@@ -138,8 +130,7 @@
|
||||
title: Adapt a model to a new task
|
||||
- local: tutorials/basic_training
|
||||
title: Train a diffusion model
|
||||
- title: Models
|
||||
sections:
|
||||
- sections:
|
||||
- local: training/unconditional_training
|
||||
title: Unconditional image generation
|
||||
- local: training/text2image
|
||||
@@ -158,8 +149,8 @@
|
||||
title: InstructPix2Pix
|
||||
- local: training/cogvideox
|
||||
title: CogVideoX
|
||||
- title: Methods
|
||||
sections:
|
||||
title: Models
|
||||
- sections:
|
||||
- local: training/text_inversion
|
||||
title: Textual Inversion
|
||||
- local: training/dreambooth
|
||||
@@ -172,9 +163,9 @@
|
||||
title: Latent Consistency Distillation
|
||||
- local: training/ddpo
|
||||
title: Reinforcement learning training with DDPO
|
||||
|
||||
- title: Quantization
|
||||
isExpanded: false
|
||||
title: Methods
|
||||
title: Training
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: quantization/overview
|
||||
title: Getting started
|
||||
@@ -188,9 +179,8 @@
|
||||
title: quanto
|
||||
- local: quantization/modelopt
|
||||
title: NVIDIA ModelOpt
|
||||
|
||||
- title: Model accelerators and hardware
|
||||
isExpanded: false
|
||||
title: Quantization
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: optimization/onnx
|
||||
title: ONNX
|
||||
@@ -204,9 +194,8 @@
|
||||
title: Intel Gaudi
|
||||
- local: optimization/neuron
|
||||
title: AWS Neuron
|
||||
|
||||
- title: Specific pipeline examples
|
||||
isExpanded: false
|
||||
title: Model accelerators and hardware
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- local: using-diffusers/consisid
|
||||
title: ConsisID
|
||||
@@ -232,12 +221,10 @@
|
||||
title: Stable Video Diffusion
|
||||
- local: using-diffusers/marigold_usage
|
||||
title: Marigold Computer Vision
|
||||
|
||||
- title: Resources
|
||||
isExpanded: false
|
||||
title: Specific pipeline examples
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- title: Task recipes
|
||||
sections:
|
||||
- sections:
|
||||
- local: using-diffusers/unconditional_image_generation
|
||||
title: Unconditional image generation
|
||||
- local: using-diffusers/conditional_image_generation
|
||||
@@ -252,6 +239,7 @@
|
||||
title: Video generation
|
||||
- local: using-diffusers/depth2img
|
||||
title: Depth-to-image
|
||||
title: Task recipes
|
||||
- local: using-diffusers/write_own_pipeline
|
||||
title: Understanding pipelines, models and schedulers
|
||||
- local: community_projects
|
||||
@@ -266,12 +254,10 @@
|
||||
title: Diffusers' Ethical Guidelines
|
||||
- local: conceptual/evaluation
|
||||
title: Evaluating Diffusion Models
|
||||
|
||||
- title: API
|
||||
isExpanded: false
|
||||
title: Resources
|
||||
- isExpanded: false
|
||||
sections:
|
||||
- title: Main Classes
|
||||
sections:
|
||||
- sections:
|
||||
- local: api/configuration
|
||||
title: Configuration
|
||||
- local: api/logging
|
||||
@@ -282,8 +268,8 @@
|
||||
title: Quantization
|
||||
- local: api/parallel
|
||||
title: Parallel inference
|
||||
- title: Modular
|
||||
sections:
|
||||
title: Main Classes
|
||||
- sections:
|
||||
- local: api/modular_diffusers/pipeline
|
||||
title: Pipeline
|
||||
- local: api/modular_diffusers/pipeline_blocks
|
||||
@@ -294,8 +280,8 @@
|
||||
title: Components and configs
|
||||
- local: api/modular_diffusers/guiders
|
||||
title: Guiders
|
||||
- title: Loaders
|
||||
sections:
|
||||
title: Modular
|
||||
- sections:
|
||||
- local: api/loaders/ip_adapter
|
||||
title: IP-Adapter
|
||||
- local: api/loaders/lora
|
||||
@@ -310,14 +296,13 @@
|
||||
title: SD3Transformer2D
|
||||
- local: api/loaders/peft
|
||||
title: PEFT
|
||||
- title: Models
|
||||
sections:
|
||||
title: Loaders
|
||||
- sections:
|
||||
- local: api/models/overview
|
||||
title: Overview
|
||||
- local: api/models/auto_model
|
||||
title: AutoModel
|
||||
- title: ControlNets
|
||||
sections:
|
||||
- sections:
|
||||
- local: api/models/controlnet
|
||||
title: ControlNetModel
|
||||
- local: api/models/controlnet_union
|
||||
@@ -332,8 +317,8 @@
|
||||
title: SD3ControlNetModel
|
||||
- local: api/models/controlnet_sparsectrl
|
||||
title: SparseControlNetModel
|
||||
- title: Transformers
|
||||
sections:
|
||||
title: ControlNets
|
||||
- sections:
|
||||
- local: api/models/allegro_transformer3d
|
||||
title: AllegroTransformer3DModel
|
||||
- local: api/models/aura_flow_transformer2d
|
||||
@@ -396,8 +381,8 @@
|
||||
title: TransformerTemporalModel
|
||||
- local: api/models/wan_transformer_3d
|
||||
title: WanTransformer3DModel
|
||||
- title: UNets
|
||||
sections:
|
||||
title: Transformers
|
||||
- sections:
|
||||
- local: api/models/stable_cascade_unet
|
||||
title: StableCascadeUNet
|
||||
- local: api/models/unet
|
||||
@@ -412,8 +397,8 @@
|
||||
title: UNetMotionModel
|
||||
- local: api/models/uvit2d
|
||||
title: UViT2DModel
|
||||
- title: VAEs
|
||||
sections:
|
||||
title: UNets
|
||||
- sections:
|
||||
- local: api/models/asymmetricautoencoderkl
|
||||
title: AsymmetricAutoencoderKL
|
||||
- local: api/models/autoencoder_dc
|
||||
@@ -446,210 +431,220 @@
|
||||
title: Tiny AutoEncoder
|
||||
- local: api/models/vq
|
||||
title: VQModel
|
||||
- title: Pipelines
|
||||
sections:
|
||||
title: VAEs
|
||||
title: Models
|
||||
- sections:
|
||||
- local: api/pipelines/overview
|
||||
title: Overview
|
||||
- local: api/pipelines/allegro
|
||||
title: Allegro
|
||||
- local: api/pipelines/amused
|
||||
title: aMUSEd
|
||||
- local: api/pipelines/animatediff
|
||||
title: AnimateDiff
|
||||
- local: api/pipelines/attend_and_excite
|
||||
title: Attend-and-Excite
|
||||
- local: api/pipelines/audioldm
|
||||
title: AudioLDM
|
||||
- local: api/pipelines/audioldm2
|
||||
title: AudioLDM 2
|
||||
- local: api/pipelines/aura_flow
|
||||
title: AuraFlow
|
||||
- sections:
|
||||
- local: api/pipelines/audioldm
|
||||
title: AudioLDM
|
||||
- local: api/pipelines/audioldm2
|
||||
title: AudioLDM 2
|
||||
- local: api/pipelines/dance_diffusion
|
||||
title: Dance Diffusion
|
||||
- local: api/pipelines/musicldm
|
||||
title: MusicLDM
|
||||
- local: api/pipelines/stable_audio
|
||||
title: Stable Audio
|
||||
title: Audio
|
||||
- local: api/pipelines/auto_pipeline
|
||||
title: AutoPipeline
|
||||
- local: api/pipelines/blip_diffusion
|
||||
title: BLIP-Diffusion
|
||||
- local: api/pipelines/bria_3_2
|
||||
title: Bria 3.2
|
||||
- local: api/pipelines/chroma
|
||||
title: Chroma
|
||||
- local: api/pipelines/cogvideox
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/cogview3
|
||||
title: CogView3
|
||||
- local: api/pipelines/cogview4
|
||||
title: CogView4
|
||||
- local: api/pipelines/consisid
|
||||
title: ConsisID
|
||||
- local: api/pipelines/consistency_models
|
||||
title: Consistency Models
|
||||
- local: api/pipelines/controlnet
|
||||
title: ControlNet
|
||||
- local: api/pipelines/controlnet_flux
|
||||
title: ControlNet with Flux.1
|
||||
- local: api/pipelines/controlnet_hunyuandit
|
||||
title: ControlNet with Hunyuan-DiT
|
||||
- local: api/pipelines/controlnet_sd3
|
||||
title: ControlNet with Stable Diffusion 3
|
||||
- local: api/pipelines/controlnet_sdxl
|
||||
title: ControlNet with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_sana
|
||||
title: ControlNet-Sana
|
||||
- local: api/pipelines/controlnetxs
|
||||
title: ControlNet-XS
|
||||
- local: api/pipelines/controlnetxs_sdxl
|
||||
title: ControlNet-XS with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_union
|
||||
title: ControlNetUnion
|
||||
- local: api/pipelines/cosmos
|
||||
title: Cosmos
|
||||
- local: api/pipelines/dance_diffusion
|
||||
title: Dance Diffusion
|
||||
- local: api/pipelines/ddim
|
||||
title: DDIM
|
||||
- local: api/pipelines/ddpm
|
||||
title: DDPM
|
||||
- local: api/pipelines/deepfloyd_if
|
||||
title: DeepFloyd IF
|
||||
- local: api/pipelines/diffedit
|
||||
title: DiffEdit
|
||||
- local: api/pipelines/dit
|
||||
title: DiT
|
||||
- local: api/pipelines/easyanimate
|
||||
title: EasyAnimate
|
||||
- local: api/pipelines/flux
|
||||
title: Flux
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/framepack
|
||||
title: Framepack
|
||||
- local: api/pipelines/hidream
|
||||
title: HiDream-I1
|
||||
- local: api/pipelines/hunyuandit
|
||||
title: Hunyuan-DiT
|
||||
- local: api/pipelines/hunyuan_video
|
||||
title: HunyuanVideo
|
||||
- local: api/pipelines/i2vgenxl
|
||||
title: I2VGen-XL
|
||||
- local: api/pipelines/pix2pix
|
||||
title: InstructPix2Pix
|
||||
- local: api/pipelines/kandinsky
|
||||
title: Kandinsky 2.1
|
||||
- local: api/pipelines/kandinsky_v22
|
||||
title: Kandinsky 2.2
|
||||
- local: api/pipelines/kandinsky3
|
||||
title: Kandinsky 3
|
||||
- local: api/pipelines/kolors
|
||||
title: Kolors
|
||||
- local: api/pipelines/latent_consistency_models
|
||||
title: Latent Consistency Models
|
||||
- local: api/pipelines/latent_diffusion
|
||||
title: Latent Diffusion
|
||||
- local: api/pipelines/latte
|
||||
title: Latte
|
||||
- local: api/pipelines/ledits_pp
|
||||
title: LEDITS++
|
||||
- local: api/pipelines/ltx_video
|
||||
title: LTXVideo
|
||||
- local: api/pipelines/lumina2
|
||||
title: Lumina 2.0
|
||||
- local: api/pipelines/lumina
|
||||
title: Lumina-T2X
|
||||
- local: api/pipelines/marigold
|
||||
title: Marigold
|
||||
- local: api/pipelines/mochi
|
||||
title: Mochi
|
||||
- local: api/pipelines/panorama
|
||||
title: MultiDiffusion
|
||||
- local: api/pipelines/musicldm
|
||||
title: MusicLDM
|
||||
- local: api/pipelines/omnigen
|
||||
title: OmniGen
|
||||
- local: api/pipelines/pag
|
||||
title: PAG
|
||||
- local: api/pipelines/paint_by_example
|
||||
title: Paint by Example
|
||||
- local: api/pipelines/pia
|
||||
title: Personalized Image Animator (PIA)
|
||||
- local: api/pipelines/pixart
|
||||
title: PixArt-α
|
||||
- local: api/pipelines/pixart_sigma
|
||||
title: PixArt-Σ
|
||||
- local: api/pipelines/qwenimage
|
||||
title: QwenImage
|
||||
- local: api/pipelines/sana
|
||||
title: Sana
|
||||
- local: api/pipelines/sana_sprint
|
||||
title: Sana Sprint
|
||||
- local: api/pipelines/self_attention_guidance
|
||||
title: Self-Attention Guidance
|
||||
- local: api/pipelines/semantic_stable_diffusion
|
||||
title: Semantic Guidance
|
||||
- local: api/pipelines/shap_e
|
||||
title: Shap-E
|
||||
- local: api/pipelines/skyreels_v2
|
||||
title: SkyReels-V2
|
||||
- local: api/pipelines/stable_audio
|
||||
title: Stable Audio
|
||||
- local: api/pipelines/stable_cascade
|
||||
title: Stable Cascade
|
||||
- title: Stable Diffusion
|
||||
sections:
|
||||
- local: api/pipelines/stable_diffusion/overview
|
||||
title: Overview
|
||||
- local: api/pipelines/stable_diffusion/depth2img
|
||||
title: Depth-to-image
|
||||
- local: api/pipelines/stable_diffusion/gligen
|
||||
title: GLIGEN (Grounded Language-to-Image Generation)
|
||||
- local: api/pipelines/stable_diffusion/image_variation
|
||||
title: Image variation
|
||||
- local: api/pipelines/stable_diffusion/img2img
|
||||
title: Image-to-image
|
||||
- sections:
|
||||
- local: api/pipelines/amused
|
||||
title: aMUSEd
|
||||
- local: api/pipelines/animatediff
|
||||
title: AnimateDiff
|
||||
- local: api/pipelines/attend_and_excite
|
||||
title: Attend-and-Excite
|
||||
- local: api/pipelines/aura_flow
|
||||
title: AuraFlow
|
||||
- local: api/pipelines/blip_diffusion
|
||||
title: BLIP-Diffusion
|
||||
- local: api/pipelines/bria_3_2
|
||||
title: Bria 3.2
|
||||
- local: api/pipelines/chroma
|
||||
title: Chroma
|
||||
- local: api/pipelines/cogview3
|
||||
title: CogView3
|
||||
- local: api/pipelines/cogview4
|
||||
title: CogView4
|
||||
- local: api/pipelines/consistency_models
|
||||
title: Consistency Models
|
||||
- local: api/pipelines/controlnet
|
||||
title: ControlNet
|
||||
- local: api/pipelines/controlnet_flux
|
||||
title: ControlNet with Flux.1
|
||||
- local: api/pipelines/controlnet_hunyuandit
|
||||
title: ControlNet with Hunyuan-DiT
|
||||
- local: api/pipelines/controlnet_sd3
|
||||
title: ControlNet with Stable Diffusion 3
|
||||
- local: api/pipelines/controlnet_sdxl
|
||||
title: ControlNet with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_sana
|
||||
title: ControlNet-Sana
|
||||
- local: api/pipelines/controlnetxs
|
||||
title: ControlNet-XS
|
||||
- local: api/pipelines/controlnetxs_sdxl
|
||||
title: ControlNet-XS with Stable Diffusion XL
|
||||
- local: api/pipelines/controlnet_union
|
||||
title: ControlNetUnion
|
||||
- local: api/pipelines/cosmos
|
||||
title: Cosmos
|
||||
- local: api/pipelines/ddim
|
||||
title: DDIM
|
||||
- local: api/pipelines/ddpm
|
||||
title: DDPM
|
||||
- local: api/pipelines/deepfloyd_if
|
||||
title: DeepFloyd IF
|
||||
- local: api/pipelines/diffedit
|
||||
title: DiffEdit
|
||||
- local: api/pipelines/dit
|
||||
title: DiT
|
||||
- local: api/pipelines/easyanimate
|
||||
title: EasyAnimate
|
||||
- local: api/pipelines/flux
|
||||
title: Flux
|
||||
- local: api/pipelines/control_flux_inpaint
|
||||
title: FluxControlInpaint
|
||||
- local: api/pipelines/hidream
|
||||
title: HiDream-I1
|
||||
- local: api/pipelines/hunyuandit
|
||||
title: Hunyuan-DiT
|
||||
- local: api/pipelines/pix2pix
|
||||
title: InstructPix2Pix
|
||||
- local: api/pipelines/kandinsky
|
||||
title: Kandinsky 2.1
|
||||
- local: api/pipelines/kandinsky_v22
|
||||
title: Kandinsky 2.2
|
||||
- local: api/pipelines/kandinsky3
|
||||
title: Kandinsky 3
|
||||
- local: api/pipelines/kolors
|
||||
title: Kolors
|
||||
- local: api/pipelines/latent_consistency_models
|
||||
title: Latent Consistency Models
|
||||
- local: api/pipelines/latent_diffusion
|
||||
title: Latent Diffusion
|
||||
- local: api/pipelines/ledits_pp
|
||||
title: LEDITS++
|
||||
- local: api/pipelines/lumina2
|
||||
title: Lumina 2.0
|
||||
- local: api/pipelines/lumina
|
||||
title: Lumina-T2X
|
||||
- local: api/pipelines/marigold
|
||||
title: Marigold
|
||||
- local: api/pipelines/panorama
|
||||
title: MultiDiffusion
|
||||
- local: api/pipelines/omnigen
|
||||
title: OmniGen
|
||||
- local: api/pipelines/pag
|
||||
title: PAG
|
||||
- local: api/pipelines/paint_by_example
|
||||
title: Paint by Example
|
||||
- local: api/pipelines/pixart
|
||||
title: PixArt-α
|
||||
- local: api/pipelines/pixart_sigma
|
||||
title: PixArt-Σ
|
||||
- local: api/pipelines/prx
|
||||
title: PRX
|
||||
- local: api/pipelines/qwenimage
|
||||
title: QwenImage
|
||||
- local: api/pipelines/sana
|
||||
title: Sana
|
||||
- local: api/pipelines/sana_sprint
|
||||
title: Sana Sprint
|
||||
- local: api/pipelines/self_attention_guidance
|
||||
title: Self-Attention Guidance
|
||||
- local: api/pipelines/semantic_stable_diffusion
|
||||
title: Semantic Guidance
|
||||
- local: api/pipelines/shap_e
|
||||
title: Shap-E
|
||||
- local: api/pipelines/stable_cascade
|
||||
title: Stable Cascade
|
||||
- sections:
|
||||
- local: api/pipelines/stable_diffusion/overview
|
||||
title: Overview
|
||||
- local: api/pipelines/stable_diffusion/depth2img
|
||||
title: Depth-to-image
|
||||
- local: api/pipelines/stable_diffusion/gligen
|
||||
title: GLIGEN (Grounded Language-to-Image Generation)
|
||||
- local: api/pipelines/stable_diffusion/image_variation
|
||||
title: Image variation
|
||||
- local: api/pipelines/stable_diffusion/img2img
|
||||
title: Image-to-image
|
||||
- local: api/pipelines/stable_diffusion/inpaint
|
||||
title: Inpainting
|
||||
- local: api/pipelines/stable_diffusion/k_diffusion
|
||||
title: K-Diffusion
|
||||
- local: api/pipelines/stable_diffusion/latent_upscale
|
||||
title: Latent upscaler
|
||||
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
|
||||
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D
|
||||
Upscaler
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_safe
|
||||
title: Safe Stable Diffusion
|
||||
- local: api/pipelines/stable_diffusion/sdxl_turbo
|
||||
title: SDXL Turbo
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_2
|
||||
title: Stable Diffusion 2
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_3
|
||||
title: Stable Diffusion 3
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_xl
|
||||
title: Stable Diffusion XL
|
||||
- local: api/pipelines/stable_diffusion/upscale
|
||||
title: Super-resolution
|
||||
- local: api/pipelines/stable_diffusion/adapter
|
||||
title: T2I-Adapter
|
||||
- local: api/pipelines/stable_diffusion/text2img
|
||||
title: Text-to-image
|
||||
title: Stable Diffusion
|
||||
- local: api/pipelines/stable_unclip
|
||||
title: Stable unCLIP
|
||||
- local: api/pipelines/unclip
|
||||
title: unCLIP
|
||||
- local: api/pipelines/unidiffuser
|
||||
title: UniDiffuser
|
||||
- local: api/pipelines/value_guided_sampling
|
||||
title: Value-guided sampling
|
||||
- local: api/pipelines/visualcloze
|
||||
title: VisualCloze
|
||||
- local: api/pipelines/wuerstchen
|
||||
title: Wuerstchen
|
||||
title: Image
|
||||
- sections:
|
||||
- local: api/pipelines/allegro
|
||||
title: Allegro
|
||||
- local: api/pipelines/cogvideox
|
||||
title: CogVideoX
|
||||
- local: api/pipelines/consisid
|
||||
title: ConsisID
|
||||
- local: api/pipelines/framepack
|
||||
title: Framepack
|
||||
- local: api/pipelines/hunyuan_video
|
||||
title: HunyuanVideo
|
||||
- local: api/pipelines/i2vgenxl
|
||||
title: I2VGen-XL
|
||||
- local: api/pipelines/latte
|
||||
title: Latte
|
||||
- local: api/pipelines/ltx_video
|
||||
title: LTXVideo
|
||||
- local: api/pipelines/mochi
|
||||
title: Mochi
|
||||
- local: api/pipelines/pia
|
||||
title: Personalized Image Animator (PIA)
|
||||
- local: api/pipelines/skyreels_v2
|
||||
title: SkyReels-V2
|
||||
- local: api/pipelines/stable_diffusion/svd
|
||||
title: Image-to-video
|
||||
- local: api/pipelines/stable_diffusion/inpaint
|
||||
title: Inpainting
|
||||
- local: api/pipelines/stable_diffusion/k_diffusion
|
||||
title: K-Diffusion
|
||||
- local: api/pipelines/stable_diffusion/latent_upscale
|
||||
title: Latent upscaler
|
||||
- local: api/pipelines/stable_diffusion/ldm3d_diffusion
|
||||
title: LDM3D Text-to-(RGB, Depth), Text-to-(RGB-pano, Depth-pano), LDM3D Upscaler
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_safe
|
||||
title: Safe Stable Diffusion
|
||||
- local: api/pipelines/stable_diffusion/sdxl_turbo
|
||||
title: SDXL Turbo
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_2
|
||||
title: Stable Diffusion 2
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_3
|
||||
title: Stable Diffusion 3
|
||||
- local: api/pipelines/stable_diffusion/stable_diffusion_xl
|
||||
title: Stable Diffusion XL
|
||||
- local: api/pipelines/stable_diffusion/upscale
|
||||
title: Super-resolution
|
||||
- local: api/pipelines/stable_diffusion/adapter
|
||||
title: T2I-Adapter
|
||||
- local: api/pipelines/stable_diffusion/text2img
|
||||
title: Text-to-image
|
||||
- local: api/pipelines/stable_unclip
|
||||
title: Stable unCLIP
|
||||
- local: api/pipelines/text_to_video
|
||||
title: Text-to-video
|
||||
- local: api/pipelines/text_to_video_zero
|
||||
title: Text2Video-Zero
|
||||
- local: api/pipelines/unclip
|
||||
title: unCLIP
|
||||
- local: api/pipelines/unidiffuser
|
||||
title: UniDiffuser
|
||||
- local: api/pipelines/value_guided_sampling
|
||||
title: Value-guided sampling
|
||||
- local: api/pipelines/visualcloze
|
||||
title: VisualCloze
|
||||
- local: api/pipelines/wan
|
||||
title: Wan
|
||||
- local: api/pipelines/wuerstchen
|
||||
title: Wuerstchen
|
||||
- title: Schedulers
|
||||
sections:
|
||||
title: Stable Video Diffusion
|
||||
- local: api/pipelines/text_to_video
|
||||
title: Text-to-video
|
||||
- local: api/pipelines/text_to_video_zero
|
||||
title: Text2Video-Zero
|
||||
- local: api/pipelines/wan
|
||||
title: Wan
|
||||
title: Video
|
||||
title: Pipelines
|
||||
- sections:
|
||||
- local: api/schedulers/overview
|
||||
title: Overview
|
||||
- local: api/schedulers/cm_stochastic_iterative
|
||||
@@ -718,8 +713,8 @@
|
||||
title: UniPCMultistepScheduler
|
||||
- local: api/schedulers/vq_diffusion
|
||||
title: VQDiffusionScheduler
|
||||
- title: Internal classes
|
||||
sections:
|
||||
title: Schedulers
|
||||
- sections:
|
||||
- local: api/internal_classes_overview
|
||||
title: Overview
|
||||
- local: api/attnprocessor
|
||||
@@ -736,3 +731,5 @@
|
||||
title: VAE Image Processor
|
||||
- local: api/video_processor
|
||||
title: Video Processor
|
||||
title: Internal classes
|
||||
title: API
|
||||
|
||||
docs/source/en/api/pipelines/prx.md (new file, 131 lines)
@@ -0,0 +1,131 @@
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# PRX

PRX generates high-quality images from text using a simplified MMDiT architecture in which the text tokens are not updated through the transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between the Flux VAE (8x compression, 16 latent channels) for balanced quality and speed, or DC-AE (32x compression, 32 latent channels) for higher latent compression and faster processing.

## Available models

PRX offers multiple variants with different VAE configurations, each optimized for a specific resolution. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.

| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |

Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information.

## Loading the pipeline

Load the pipeline with [`~DiffusionPipeline.from_pretrained`].

```py
import torch
from diffusers.pipelines.prx import PRXPipeline

# Load the pipeline - the VAE and text encoder are downloaded from the Hugging Face Hub
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("prx_output.png")
```
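
The distilled checkpoints in the table above are meant for much faster sampling. As a minimal sketch, assuming `Photoroom/prx-512-t2i-sft-distilled` loads exactly like the other PRX checkpoints, the only changes are the suggested parameters: 8 steps with classifier-free guidance disabled (`guidance_scale=1.0`).

```py
import torch
from diffusers.pipelines.prx import PRXPipeline

# Distilled variant: 8 sampling steps, guidance disabled (cfg=1.0), per the table above
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft-distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=8, guidance_scale=1.0).images[0]
image.save("prx_distilled_output.png")
```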

### Manual Component Loading

Load components individually to customize the pipeline, for instance to use quantized models.

```py
import torch
from diffusers.pipelines.prx import PRXPipeline
from diffusers.models import AutoencoderKL, AutoencoderDC
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5GemmaModel, GemmaTokenizerFast
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
from transformers import BitsAndBytesConfig as TransformersBitsAndBytesConfig

# Use the diffusers config for diffusers models and the transformers config for the text encoder
quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
text_encoder_quant_config = TransformersBitsAndBytesConfig(load_in_8bit=True)

# Load transformer
transformer = PRXTransformer2DModel.from_pretrained(
    "checkpoints/prx-512-t2i-sft",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

# Load scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "checkpoints/prx-512-t2i-sft", subfolder="scheduler"
)

# Load T5Gemma text encoder
t5gemma_model = T5GemmaModel.from_pretrained(
    "google/t5gemma-2b-2b-ul2",
    quantization_config=text_encoder_quant_config,
    torch_dtype=torch.bfloat16,
)
text_encoder = t5gemma_model.encoder
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
tokenizer.model_max_length = 256

# Load VAE - choose either the Flux VAE or DC-AE
# Flux VAE
vae = AutoencoderKL.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="vae",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)

pipe = PRXPipeline(
    transformer=transformer,
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    vae=vae,
)
pipe.to("cuda")
```
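
For the DC-AE variants, the VAE is an `AutoencoderDC` instead of the Flux `AutoencoderKL`. A minimal sketch, assuming a DC-AE PRX transformer checkpoint (for example a local `checkpoints/prx-512-t2i-dc-ae-sft` conversion) and the same DC-AE checkpoint that `scripts/convert_prx_to_diffusers.py` downloads:

```py
import torch
from diffusers.models import AutoencoderDC

# DC-AE alternative; pair it with a DC-AE PRX transformer (32 latent channels, patch size 1)
vae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers",
    torch_dtype=torch.bfloat16,
)
```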

## Memory Optimization

For memory-constrained environments:

```py
import torch
from diffusers.pipelines.prx import PRXPipeline

pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # Offload components to CPU when not in use

# Or use sequential CPU offload for even lower memory
pipe.enable_sequential_cpu_offload()
```
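
Offloading can also be combined with sliced or tiled VAE decoding to reduce peak memory during decoding. This is a sketch under the assumption that a Flux-VAE variant is loaded, i.e. `pipe.vae` is an `AutoencoderKL` exposing `enable_slicing` and `enable_tiling`:

```py
import torch
from diffusers.pipelines.prx import PRXPipeline

pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()

# Decode latents in slices/tiles to lower the VAE's peak memory
# (assumes pipe.vae is the Flux AutoencoderKL)
pipe.vae.enable_slicing()
pipe.vae.enable_tiling()
```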

## PRXPipeline

[[autodoc]] PRXPipeline
  - all
  - __call__

## PRXPipelineOutput

[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput
scripts/convert_prx_to_diffusers.py (new file, 345 lines)
@@ -0,0 +1,345 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to convert PRX checkpoint from original codebase to diffusers format.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import torch
|
||||
from safetensors.torch import save_file
|
||||
|
||||
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
|
||||
from diffusers.pipelines.prx import PRXPipeline
|
||||
|
||||
|
||||
DEFAULT_RESOLUTION = 512
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PRXBase:
|
||||
context_in_dim: int = 2304
|
||||
hidden_size: int = 1792
|
||||
mlp_ratio: float = 3.5
|
||||
num_heads: int = 28
|
||||
depth: int = 16
|
||||
axes_dim: Tuple[int, int] = (32, 32)
|
||||
theta: int = 10_000
|
||||
time_factor: float = 1000.0
|
||||
time_max_period: int = 10_000
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PRXFlux(PRXBase):
|
||||
in_channels: int = 16
|
||||
patch_size: int = 2
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PRXDCAE(PRXBase):
|
||||
in_channels: int = 32
|
||||
patch_size: int = 1
|
||||
|
||||
|
||||
def build_config(vae_type: str) -> dict:
|
||||
if vae_type == "flux":
|
||||
cfg = PRXFlux()
|
||||
elif vae_type == "dc-ae":
|
||||
cfg = PRXDCAE()
|
||||
else:
|
||||
raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
|
||||
|
||||
config_dict = asdict(cfg)
|
||||
config_dict["axes_dim"] = list(config_dict["axes_dim"]) # type: ignore[index]
|
||||
return config_dict
|
||||
|
||||
|
||||
def create_parameter_mapping(depth: int) -> dict:
|
||||
"""Create mapping from old parameter names to new diffusers names."""
|
||||
|
||||
# Key mappings for structural changes
|
||||
mapping = {}
|
||||
|
||||
# Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention)
|
||||
for i in range(depth):
|
||||
# QKV projections moved to attention module
|
||||
mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
|
||||
mapping[f"blocks.{i}.txt_kv_proj.weight"] = f"blocks.{i}.attention.txt_kv_proj.weight"
|
||||
|
||||
# QK norm moved to attention module and renamed to match Attention's qk_norm structure
|
||||
mapping[f"blocks.{i}.qk_norm.query_norm.scale"] = f"blocks.{i}.attention.norm_q.weight"
|
||||
mapping[f"blocks.{i}.qk_norm.key_norm.scale"] = f"blocks.{i}.attention.norm_k.weight"
|
||||
mapping[f"blocks.{i}.qk_norm.query_norm.weight"] = f"blocks.{i}.attention.norm_q.weight"
|
||||
mapping[f"blocks.{i}.qk_norm.key_norm.weight"] = f"blocks.{i}.attention.norm_k.weight"
|
||||
|
||||
# K norm for text tokens moved to attention module
|
||||
mapping[f"blocks.{i}.k_norm.scale"] = f"blocks.{i}.attention.norm_added_k.weight"
|
||||
mapping[f"blocks.{i}.k_norm.weight"] = f"blocks.{i}.attention.norm_added_k.weight"
|
||||
|
||||
# Attention output projection
|
||||
mapping[f"blocks.{i}.attn_out.weight"] = f"blocks.{i}.attention.to_out.0.weight"
|
||||
|
||||
return mapping
|
||||
|
||||
|
||||
def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth: int) -> Dict[str, torch.Tensor]:
|
||||
"""Convert old checkpoint parameters to new diffusers format."""
|
||||
|
||||
print("Converting checkpoint parameters...")
|
||||
|
||||
mapping = create_parameter_mapping(depth)
|
||||
converted_state_dict = {}
|
||||
|
||||
for key, value in old_state_dict.items():
|
||||
new_key = key
|
||||
|
||||
# Apply specific mappings if needed
|
||||
if key in mapping:
|
||||
new_key = mapping[key]
|
||||
print(f" Mapped: {key} -> {new_key}")
|
||||
|
||||
converted_state_dict[new_key] = value
|
||||
|
||||
print(f"✓ Converted {len(converted_state_dict)} parameters")
|
||||
return converted_state_dict
|
||||
|
||||
|
||||
def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
|
||||
"""Create and load PRXTransformer2DModel from old checkpoint."""
|
||||
|
||||
print(f"Loading checkpoint from: {checkpoint_path}")
|
||||
|
||||
# Load old checkpoint
|
||||
if not os.path.exists(checkpoint_path):
|
||||
raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
|
||||
|
||||
old_checkpoint = torch.load(checkpoint_path, map_location="cpu")
|
||||
|
||||
# Handle different checkpoint formats
|
||||
if isinstance(old_checkpoint, dict):
|
||||
if "model" in old_checkpoint:
|
||||
state_dict = old_checkpoint["model"]
|
||||
elif "state_dict" in old_checkpoint:
|
||||
state_dict = old_checkpoint["state_dict"]
|
||||
else:
|
||||
state_dict = old_checkpoint
|
||||
else:
|
||||
state_dict = old_checkpoint
|
||||
|
||||
print(f"✓ Loaded checkpoint with {len(state_dict)} parameters")
|
||||
|
||||
# Convert parameter names if needed
|
||||
model_depth = int(config.get("depth", 16))
|
||||
converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)
|
||||
|
||||
# Create transformer with config
|
||||
print("Creating PRXTransformer2DModel...")
|
||||
transformer = PRXTransformer2DModel(**config)
|
||||
|
||||
# Load state dict
|
||||
print("Loading converted parameters...")
|
||||
missing_keys, unexpected_keys = transformer.load_state_dict(converted_state_dict, strict=False)
|
||||
|
||||
if missing_keys:
|
||||
print(f"⚠ Missing keys: {missing_keys}")
|
||||
if unexpected_keys:
|
||||
print(f"⚠ Unexpected keys: {unexpected_keys}")
|
||||
|
||||
if not missing_keys and not unexpected_keys:
|
||||
print("✓ All parameters loaded successfully!")
|
||||
|
||||
return transformer
|
||||
|
||||
|
||||
def create_scheduler_config(output_path: str, shift: float):
|
||||
"""Create FlowMatchEulerDiscreteScheduler config."""
|
||||
|
||||
scheduler_config = {"_class_name": "FlowMatchEulerDiscreteScheduler", "num_train_timesteps": 1000, "shift": shift}
|
||||
|
||||
scheduler_path = os.path.join(output_path, "scheduler")
|
||||
os.makedirs(scheduler_path, exist_ok=True)
|
||||
|
||||
with open(os.path.join(scheduler_path, "scheduler_config.json"), "w") as f:
|
||||
json.dump(scheduler_config, f, indent=2)
|
||||
|
||||
print("✓ Created scheduler config")
|
||||
|
||||
|
||||
def download_and_save_vae(vae_type: str, output_path: str):
|
||||
"""Download and save VAE to local directory."""
|
||||
from diffusers import AutoencoderDC, AutoencoderKL
|
||||
|
||||
vae_path = os.path.join(output_path, "vae")
|
||||
os.makedirs(vae_path, exist_ok=True)
|
||||
|
||||
if vae_type == "flux":
|
||||
print("Downloading FLUX VAE from black-forest-labs/FLUX.1-dev...")
|
||||
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
|
||||
else: # dc-ae
|
||||
print("Downloading DC-AE VAE from mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers...")
|
||||
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers")
|
||||
|
||||
vae.save_pretrained(vae_path)
|
||||
print(f"✓ Saved VAE to {vae_path}")
|
||||
|
||||
|
||||
def download_and_save_text_encoder(output_path: str):
|
||||
"""Download and save T5Gemma text encoder and tokenizer."""
|
||||
from transformers import GemmaTokenizerFast
|
||||
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaModel
|
||||
|
||||
text_encoder_path = os.path.join(output_path, "text_encoder")
|
||||
tokenizer_path = os.path.join(output_path, "tokenizer")
|
||||
os.makedirs(text_encoder_path, exist_ok=True)
|
||||
os.makedirs(tokenizer_path, exist_ok=True)
|
||||
|
||||
print("Downloading T5Gemma model from google/t5gemma-2b-2b-ul2...")
|
||||
t5gemma_model = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2")
|
||||
|
||||
# Extract and save only the encoder
|
||||
t5gemma_encoder = t5gemma_model.encoder
|
||||
t5gemma_encoder.save_pretrained(text_encoder_path)
|
||||
print(f"✓ Saved T5GemmaEncoder to {text_encoder_path}")
|
||||
|
||||
print("Downloading tokenizer from google/t5gemma-2b-2b-ul2...")
|
||||
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
|
||||
tokenizer.model_max_length = 256
|
||||
tokenizer.save_pretrained(tokenizer_path)
|
||||
print(f"✓ Saved tokenizer to {tokenizer_path}")
|
||||
|
||||
|
||||
def create_model_index(vae_type: str, default_image_size: int, output_path: str):
|
||||
"""Create model_index.json for the pipeline."""
|
||||
|
||||
if vae_type == "flux":
|
||||
vae_class = "AutoencoderKL"
|
||||
else: # dc-ae
|
||||
vae_class = "AutoencoderDC"
|
||||
|
||||
model_index = {
|
||||
"_class_name": "PRXPipeline",
|
||||
"_diffusers_version": "0.31.0.dev0",
|
||||
"_name_or_path": os.path.basename(output_path),
|
||||
"default_sample_size": default_image_size,
|
||||
"scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
|
||||
"text_encoder": ["prx", "T5GemmaEncoder"],
|
||||
"tokenizer": ["transformers", "GemmaTokenizerFast"],
|
||||
"transformer": ["diffusers", "PRXTransformer2DModel"],
|
||||
"vae": ["diffusers", vae_class],
|
||||
}
|
||||
|
||||
model_index_path = os.path.join(output_path, "model_index.json")
|
||||
with open(model_index_path, "w") as f:
|
||||
json.dump(model_index, f, indent=2)
|
||||
|
||||
|
||||
def main(args):
|
||||
# Validate inputs
|
||||
if not os.path.exists(args.checkpoint_path):
|
||||
raise FileNotFoundError(f"Checkpoint not found: {args.checkpoint_path}")
|
||||
|
||||
config = build_config(args.vae_type)
|
||||
|
||||
# Create output directory
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
print(f"✓ Output directory: {args.output_path}")
|
||||
|
||||
# Create transformer from checkpoint
|
||||
transformer = create_transformer_from_checkpoint(args.checkpoint_path, config)
|
||||
|
||||
# Save transformer
|
||||
transformer_path = os.path.join(args.output_path, "transformer")
|
||||
os.makedirs(transformer_path, exist_ok=True)
|
||||
|
||||
# Save config
|
||||
with open(os.path.join(transformer_path, "config.json"), "w") as f:
|
||||
json.dump(config, f, indent=2)
|
||||
|
||||
# Save model weights as safetensors
|
||||
state_dict = transformer.state_dict()
|
||||
save_file(state_dict, os.path.join(transformer_path, "diffusion_pytorch_model.safetensors"))
|
||||
print(f"✓ Saved transformer to {transformer_path}")
|
||||
|
||||
# Create scheduler config
|
||||
create_scheduler_config(args.output_path, args.shift)
|
||||
|
||||
download_and_save_vae(args.vae_type, args.output_path)
|
||||
download_and_save_text_encoder(args.output_path)
|
||||
|
||||
# Create model_index.json
|
||||
create_model_index(args.vae_type, args.resolution, args.output_path)
|
||||
|
||||
# Verify the pipeline can be loaded
|
||||
try:
|
||||
pipeline = PRXPipeline.from_pretrained(args.output_path)
|
||||
print("Pipeline loaded successfully!")
|
||||
print(f"Transformer: {type(pipeline.transformer).__name__}")
|
||||
print(f"VAE: {type(pipeline.vae).__name__}")
|
||||
print(f"Text Encoder: {type(pipeline.text_encoder).__name__}")
|
||||
print(f"Scheduler: {type(pipeline.scheduler).__name__}")
|
||||
|
||||
# Display model info
|
||||
num_params = sum(p.numel() for p in pipeline.transformer.parameters())
|
||||
print(f"✓ Transformer parameters: {num_params:,}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Pipeline verification failed: {e}")
|
||||
return False
|
||||
|
||||
print("Conversion completed successfully!")
|
||||
print(f"Converted pipeline saved to: {args.output_path}")
|
||||
print(f"VAE type: {args.vae_type}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file )"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output_path", type=str, required=True, help="Output directory for the converted diffusers pipeline"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vae_type",
|
||||
type=str,
|
||||
choices=["flux", "dc-ae"],
|
||||
required=True,
|
||||
help="VAE type to use: 'flux' for AutoencoderKL (16 channels) or 'dc-ae' for AutoencoderDC (32 channels)",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
choices=[256, 512, 1024],
|
||||
default=DEFAULT_RESOLUTION,
|
||||
help="Target resolution for the model (256, 512, or 1024). Affects the transformer's sample_size.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--shift",
|
||||
type=float,
|
||||
default=3.0,
|
||||
help="Shift for the scheduler",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
success = main(args)
|
||||
if not success:
|
||||
sys.exit(1)
|
||||
except Exception as e:
|
||||
print(f"Conversion failed: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
sys.exit(1)
|
||||
@@ -234,6 +234,7 @@ else:
|
||||
"ParallelConfig",
|
||||
"PixArtTransformer2DModel",
|
||||
"PriorTransformer",
|
||||
"PRXTransformer2DModel",
|
||||
"QwenImageControlNetModel",
|
||||
"QwenImageMultiControlNetModel",
|
||||
"QwenImageTransformer2DModel",
|
||||
@@ -519,6 +520,7 @@ else:
|
||||
"PixArtAlphaPipeline",
|
||||
"PixArtSigmaPAGPipeline",
|
||||
"PixArtSigmaPipeline",
|
||||
"PRXPipeline",
|
||||
"QwenImageControlNetInpaintPipeline",
|
||||
"QwenImageControlNetPipeline",
|
||||
"QwenImageEditInpaintPipeline",
|
||||
@@ -928,6 +930,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
ParallelConfig,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
PRXTransformer2DModel,
|
||||
QwenImageControlNetModel,
|
||||
QwenImageMultiControlNetModel,
|
||||
QwenImageTransformer2DModel,
|
||||
@@ -1183,6 +1186,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
PixArtAlphaPipeline,
|
||||
PixArtSigmaPAGPipeline,
|
||||
PixArtSigmaPipeline,
|
||||
PRXPipeline,
|
||||
QwenImageControlNetInpaintPipeline,
|
||||
QwenImageControlNetPipeline,
|
||||
QwenImageEditInpaintPipeline,
|
||||
|
||||
@@ -293,7 +293,7 @@ class PeftAdapterMixin:
|
||||
# For hotswapping, we need the adapter name to be present in the state dict keys
|
||||
new_sd = {}
|
||||
for k, v in sd.items():
|
||||
if k.endswith("lora_A.weight") or key.endswith("lora_B.weight"):
|
||||
if k.endswith("lora_A.weight") or k.endswith("lora_B.weight"):
|
||||
k = k[: -len(".weight")] + f".{adapter_name}.weight"
|
||||
elif k.endswith("lora_B.bias"): # lora_bias=True option
|
||||
k = k[: -len(".bias")] + f".{adapter_name}.bias"
|
||||
|
||||
@@ -96,6 +96,7 @@ if is_torch_available():
|
||||
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
|
||||
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
|
||||
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
|
||||
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
|
||||
@@ -192,6 +193,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
OmniGenTransformer2DModel,
|
||||
PixArtTransformer2DModel,
|
||||
PriorTransformer,
|
||||
PRXTransformer2DModel,
|
||||
QwenImageTransformer2DModel,
|
||||
SanaTransformer2DModel,
|
||||
SD3Transformer2DModel,
|
||||
|
||||
@@ -32,6 +32,7 @@ if is_torch_available():
|
||||
from .transformer_lumina2 import Lumina2Transformer2DModel
|
||||
from .transformer_mochi import MochiTransformer3DModel
|
||||
from .transformer_omnigen import OmniGenTransformer2DModel
|
||||
from .transformer_prx import PRXTransformer2DModel
|
||||
from .transformer_qwenimage import QwenImageTransformer2DModel
|
||||
from .transformer_sd3 import SD3Transformer2DModel
|
||||
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
|
||||
|
||||
src/diffusers/models/transformers/transformer_prx.py (new file, 770 lines)
@@ -0,0 +1,770 @@
|
||||
# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn.functional import fold, unfold
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...utils import logging
|
||||
from ..attention import AttentionMixin, AttentionModuleMixin
|
||||
from ..attention_dispatch import dispatch_attention_fn
|
||||
from ..embeddings import get_timestep_embedding
|
||||
from ..modeling_outputs import Transformer2DModelOutput
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..normalization import RMSNorm
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_image_ids(batch_size: int, height: int, width: int, patch_size: int, device: torch.device) -> torch.Tensor:
|
||||
r"""
|
||||
Generates 2D patch coordinate indices for a batch of images.
|
||||
|
||||
Args:
|
||||
batch_size (`int`):
|
||||
Number of images in the batch.
|
||||
height (`int`):
|
||||
Height of the input images (in pixels).
|
||||
width (`int`):
|
||||
Width of the input images (in pixels).
|
||||
patch_size (`int`):
|
||||
Size of the square patches that the image is divided into.
|
||||
device (`torch.device`):
|
||||
The device on which to create the tensor.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Tensor of shape `(batch_size, num_patches, 2)` containing the (row, col) coordinates of each patch in the
|
||||
image grid.
|
||||
"""
|
||||
|
||||
img_ids = torch.zeros(height // patch_size, width // patch_size, 2, device=device)
|
||||
img_ids[..., 0] = torch.arange(height // patch_size, device=device)[:, None]
|
||||
img_ids[..., 1] = torch.arange(width // patch_size, device=device)[None, :]
|
||||
return img_ids.reshape((height // patch_size) * (width // patch_size), 2).unsqueeze(0).repeat(batch_size, 1, 1)
|
||||
|
||||
|
||||
def apply_rope(xq: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Applies rotary positional embeddings (RoPE) to a query tensor.
|
||||
|
||||
Args:
|
||||
xq (`torch.Tensor`):
|
||||
Input tensor of shape `(..., dim)` representing the queries.
|
||||
freqs_cis (`torch.Tensor`):
|
||||
Precomputed rotary frequency components of shape `(..., dim/2, 2)` containing cosine and sine pairs.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Tensor of the same shape as `xq` with rotary embeddings applied.
|
||||
"""
|
||||
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
||||
# Ensure freqs_cis is on the same device as queries to avoid device mismatches with offloading
|
||||
freqs_cis = freqs_cis.to(device=xq.device, dtype=xq_.dtype)
|
||||
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
||||
return xq_out.reshape(*xq.shape).type_as(xq)
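# Illustrative sketch (editor-added): with all-zero positions the precomputed
# frequencies reduce to identity 2x2 rotations, so apply_rope returns its input
# unchanged. The shapes below are example values chosen only for this sketch.
#
#   B, H, L, D = 1, 2, 4, 8
#   xq = torch.randn(B, H, L, D)
#   identity = torch.tensor([[1.0, 0.0], [0.0, 1.0]]).expand(B, 1, L, D // 2, 2, 2)
#   torch.testing.assert_close(apply_rope(xq, identity), xq)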
|
||||
|
||||
|
||||
class PRXAttnProcessor2_0:
|
||||
r"""
|
||||
Processor for implementing PRX-style attention with multi-source tokens and RoPE. Supports multiple attention
|
||||
backends (Flash Attention, Sage Attention, etc.) via dispatch_attention_fn.
|
||||
"""
|
||||
|
||||
_attention_backend = None
|
||||
_parallel_config = None
|
||||
|
||||
def __init__(self):
|
||||
if not hasattr(torch.nn.functional, "scaled_dot_product_attention"):
|
||||
raise ImportError("PRXAttnProcessor2_0 requires PyTorch 2.0, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: "PRXAttention",
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply PRX attention using PRXAttention module.
|
||||
|
||||
Args:
|
||||
attn: PRXAttention module containing projection layers
|
||||
hidden_states: Image tokens [B, L_img, D]
|
||||
encoder_hidden_states: Text tokens [B, L_txt, D]
|
||||
attention_mask: Boolean mask for text tokens [B, L_txt]
|
||||
image_rotary_emb: Rotary positional embeddings [B, 1, L_img, head_dim//2, 2, 2]
|
||||
"""
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
raise ValueError("PRXAttnProcessor2_0 requires 'encoder_hidden_states' containing text tokens.")
|
||||
|
||||
# Project image tokens to Q, K, V
|
||||
img_qkv = attn.img_qkv_proj(hidden_states)
|
||||
B, L_img, _ = img_qkv.shape
|
||||
img_qkv = img_qkv.reshape(B, L_img, 3, attn.heads, attn.head_dim)
|
||||
img_qkv = img_qkv.permute(2, 0, 3, 1, 4) # [3, B, H, L_img, D]
|
||||
img_q, img_k, img_v = img_qkv[0], img_qkv[1], img_qkv[2]
|
||||
|
||||
# Apply QK normalization to image tokens
|
||||
img_q = attn.norm_q(img_q)
|
||||
img_k = attn.norm_k(img_k)
|
||||
|
||||
# Project text tokens to K, V
|
||||
txt_kv = attn.txt_kv_proj(encoder_hidden_states)
|
||||
B, L_txt, _ = txt_kv.shape
|
||||
txt_kv = txt_kv.reshape(B, L_txt, 2, attn.heads, attn.head_dim)
|
||||
txt_kv = txt_kv.permute(2, 0, 3, 1, 4) # [2, B, H, L_txt, D]
|
||||
txt_k, txt_v = txt_kv[0], txt_kv[1]
|
||||
|
||||
# Apply K normalization to text tokens
|
||||
txt_k = attn.norm_added_k(txt_k)
|
||||
|
||||
# Apply RoPE to image queries and keys
|
||||
if image_rotary_emb is not None:
|
||||
img_q = apply_rope(img_q, image_rotary_emb)
|
||||
img_k = apply_rope(img_k, image_rotary_emb)
|
||||
|
||||
# Concatenate text and image keys/values
|
||||
k = torch.cat((txt_k, img_k), dim=2) # [B, H, L_txt + L_img, D]
|
||||
v = torch.cat((txt_v, img_v), dim=2) # [B, H, L_txt + L_img, D]
|
||||
|
||||
# Build attention mask if provided
|
||||
attn_mask_tensor = None
|
||||
if attention_mask is not None:
|
||||
bs, _, l_img, _ = img_q.shape
|
||||
l_txt = txt_k.shape[2]
|
||||
|
||||
if attention_mask.dim() != 2:
|
||||
raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
|
||||
if attention_mask.shape[-1] != l_txt:
|
||||
raise ValueError(f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}")
|
||||
|
||||
device = img_q.device
|
||||
ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device)
|
||||
attention_mask = attention_mask.to(device=device, dtype=torch.bool)
|
||||
joint_mask = torch.cat([attention_mask, ones_img], dim=-1)
|
||||
attn_mask_tensor = joint_mask[:, None, None, :].expand(-1, attn.heads, l_img, -1)
|
||||
|
||||
# Apply attention using dispatch_attention_fn for backend support
|
||||
# Reshape to match dispatch_attention_fn expectations: [B, L, H, D]
|
||||
query = img_q.transpose(1, 2) # [B, L_img, H, D]
|
||||
key = k.transpose(1, 2) # [B, L_txt + L_img, H, D]
|
||||
value = v.transpose(1, 2) # [B, L_txt + L_img, H, D]
|
||||
|
||||
attn_output = dispatch_attention_fn(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_mask=attn_mask_tensor,
|
||||
backend=self._attention_backend,
|
||||
parallel_config=self._parallel_config,
|
||||
)
|
||||
|
||||
# Reshape from [B, L_img, H, D] to [B, L_img, H*D]
|
||||
batch_size, seq_len, num_heads, head_dim = attn_output.shape
|
||||
attn_output = attn_output.reshape(batch_size, seq_len, num_heads * head_dim)
|
||||
|
||||
# Apply output projection
|
||||
attn_output = attn.to_out[0](attn_output)
|
||||
if len(attn.to_out) > 1:
|
||||
attn_output = attn.to_out[1](attn_output) # dropout if present
|
||||
|
||||
return attn_output
|
||||
|
||||
|
||||
class PRXAttention(nn.Module, AttentionModuleMixin):
|
||||
r"""
|
||||
PRX-style attention module that handles multi-source tokens and RoPE. Similar to FluxAttention but adapted for
|
||||
PRX's architecture.
|
||||
"""
|
||||
|
||||
_default_processor_cls = PRXAttnProcessor2_0
|
||||
_available_processors = [PRXAttnProcessor2_0]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
query_dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
bias: bool = False,
|
||||
out_bias: bool = False,
|
||||
eps: float = 1e-6,
|
||||
processor=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.heads = heads
|
||||
self.head_dim = dim_head
|
||||
self.inner_dim = dim_head * heads
|
||||
self.query_dim = query_dim
|
||||
|
||||
self.img_qkv_proj = nn.Linear(query_dim, query_dim * 3, bias=bias)
|
||||
|
||||
self.norm_q = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
|
||||
self.norm_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.txt_kv_proj = nn.Linear(query_dim, query_dim * 2, bias=bias)
|
||||
self.norm_added_k = RMSNorm(self.head_dim, eps=eps, elementwise_affine=True)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(self.inner_dim, query_dim, bias=out_bias))
|
||||
self.to_out.append(nn.Dropout(0.0))
|
||||
|
||||
if processor is None:
|
||||
processor = self._default_processor_cls()
|
||||
self.set_processor(processor)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
image_rotary_emb: Optional[torch.Tensor] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.processor(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
# inspired from https://github.com/black-forest-labs/flux/blob/main/src/flux/modules/layers.py
|
||||
class PRXEmbedND(nn.Module):
|
||||
r"""
|
||||
N-dimensional rotary positional embedding.
|
||||
|
||||
This module creates rotary embeddings (RoPE) across multiple axes, where each axis can have its own embedding
|
||||
dimension. The embeddings are combined and returned as a single tensor.
|
||||
|
||||
Args:
|
||||
dim (int):
|
||||
Base embedding dimension (must be even).
|
||||
theta (int):
|
||||
Scaling factor that controls the frequency spectrum of the rotary embeddings.
|
||||
axes_dim (list[int]):
|
||||
List of embedding dimensions for each axis (each must be even).
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
self.theta = theta
|
||||
self.axes_dim = axes_dim
|
||||
|
||||
def rope(self, pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
|
||||
assert dim % 2 == 0
|
||||
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
||||
omega = 1.0 / (theta**scale)
|
||||
out = pos.unsqueeze(-1) * omega.unsqueeze(0)
|
||||
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
||||
# Native PyTorch equivalent of: Rearrange("b n d (i j) -> b n d i j", i=2, j=2)
|
||||
# out shape: (b, n, d, 4) -> reshape to (b, n, d, 2, 2)
|
||||
out = out.reshape(*out.shape[:-1], 2, 2)
|
||||
return out.float()
|
||||
|
||||
def forward(self, ids: torch.Tensor) -> torch.Tensor:
|
||||
n_axes = ids.shape[-1]
|
||||
emb = torch.cat(
|
||||
[self.rope(ids[:, :, i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
||||
dim=-3,
|
||||
)
|
||||
return emb.unsqueeze(1)
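# Illustrative sketch (editor-added): axes_dim=[32, 32] splits a 64-dim head into
# two halves, one rotated by the patch row index and one by the column index.
# The sizes below are example values, not defaults taken from this file.
#
#   embedder = PRXEmbedND(dim=64, theta=10_000, axes_dim=[32, 32])
#   ids = get_image_ids(batch_size=1, height=8, width=8, patch_size=2, device=torch.device("cpu"))
#   embedder(ids).shape  # torch.Size([1, 1, 16, 32, 2, 2]) -> (B, 1, L_img, head_dim/2, 2, 2)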
|
||||
|
||||
|
||||
class MLPEmbedder(nn.Module):
|
||||
r"""
|
||||
A simple 2-layer MLP used for embedding inputs.
|
||||
|
||||
Args:
|
||||
in_dim (`int`):
|
||||
Dimensionality of the input features.
|
||||
hidden_dim (`int`):
|
||||
Dimensionality of the hidden and output embedding space.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Tensor of shape `(..., hidden_dim)` containing the embedded representations.
|
||||
"""
|
||||
|
||||
def __init__(self, in_dim: int, hidden_dim: int):
|
||||
super().__init__()
|
||||
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
||||
self.silu = nn.SiLU()
|
||||
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return self.out_layer(self.silu(self.in_layer(x)))
|
||||
|
||||
|
||||
class Modulation(nn.Module):
|
||||
r"""
|
||||
Modulation network that generates scale, shift, and gating parameters.
|
||||
|
||||
Given an input vector, the module projects it through a linear layer to produce six chunks, which are grouped into
|
||||
two tuples `(shift, scale, gate)`.
|
||||
|
||||
Args:
|
||||
dim (`int`):
|
||||
Dimensionality of the input vector. The output will have `6 * dim` features internally.
|
||||
|
||||
Returns:
|
||||
((`torch.Tensor`, `torch.Tensor`, `torch.Tensor`), (`torch.Tensor`, `torch.Tensor`, `torch.Tensor`)):
|
||||
Two tuples `(shift, scale, gate)`.
|
||||
"""
|
||||
|
||||
def __init__(self, dim: int):
|
||||
super().__init__()
|
||||
self.lin = nn.Linear(dim, 6 * dim, bias=True)
|
||||
nn.init.constant_(self.lin.weight, 0)
|
||||
nn.init.constant_(self.lin.bias, 0)
|
||||
|
||||
def forward(
|
||||
self, vec: torch.Tensor
|
||||
) -> Tuple[Tuple[torch.Tensor, torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
|
||||
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(6, dim=-1)
|
||||
return tuple(out[:3]), tuple(out[3:])
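# Illustrative sketch (editor-added): the projection is zero-initialized, so right
# after construction every shift/scale/gate is zero and a PRXBlock starts out as an
# identity mapping. Example dimensions are arbitrary.
#
#   mod = Modulation(dim=64)
#   (attn_shift, attn_scale, attn_gate), (mlp_shift, mlp_scale, mlp_gate) = mod(torch.randn(2, 64))
#   attn_shift.shape     # torch.Size([2, 1, 64]); all zeros at initialization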
|
||||
|
||||
|
||||
class PRXBlock(nn.Module):
|
||||
r"""
|
||||
Multimodal transformer block with text–image cross-attention, modulation, and MLP.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
Dimension of the hidden representations.
|
||||
num_heads (`int`):
|
||||
Number of attention heads.
|
||||
mlp_ratio (`float`, *optional*, defaults to 4.0):
|
||||
Expansion ratio for the hidden dimension inside the MLP.
|
||||
qk_scale (`float`, *optional*):
|
||||
Scale factor for queries and keys. If not provided, defaults to ``head_dim**-0.5``.
|
||||
|
||||
Attributes:
|
||||
img_pre_norm (`nn.LayerNorm`):
|
||||
Pre-normalization applied to image tokens before attention.
|
||||
attention (`PRXAttention`):
|
||||
Multi-head attention module with built-in QKV projections and normalizations for cross-attention between
|
||||
image and text tokens.
|
||||
post_attention_layernorm (`nn.LayerNorm`):
|
||||
Normalization applied after attention.
|
||||
gate_proj / up_proj / down_proj (`nn.Linear`):
|
||||
Feedforward layers forming the gated MLP.
|
||||
mlp_act (`nn.GELU`):
|
||||
Nonlinear activation used in the MLP.
|
||||
modulation (`Modulation`):
|
||||
Produces scale/shift/gating parameters for modulated layers.
|
||||
|
||||
Methods:
|
||||
The forward method performs cross-attention and the MLP with modulation.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qk_scale: Optional[float] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_dim = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = hidden_size // num_heads
|
||||
self.scale = qk_scale or self.head_dim**-0.5
|
||||
|
||||
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
||||
self.hidden_size = hidden_size
|
||||
|
||||
# Pre-attention normalization for image tokens
|
||||
self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
# PRXAttention module with built-in projections and norms
|
||||
self.attention = PRXAttention(
|
||||
query_dim=hidden_size,
|
||||
heads=num_heads,
|
||||
dim_head=self.head_dim,
|
||||
bias=False,
|
||||
out_bias=False,
|
||||
eps=1e-6,
|
||||
processor=PRXAttnProcessor2_0(),
|
||||
)
|
||||
|
||||
# mlp
|
||||
self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
|
||||
self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
|
||||
self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False)
|
||||
self.mlp_act = nn.GELU(approximate="tanh")
|
||||
|
||||
self.modulation = Modulation(hidden_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
temb: torch.Tensor,
|
||||
image_rotary_emb: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
**kwargs: Dict[str, Any],
|
||||
) -> torch.Tensor:
|
||||
r"""
|
||||
Runs modulation-gated cross-attention and MLP, with residual connections.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor`):
|
||||
Image tokens of shape `(B, L_img, hidden_size)`.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
Text tokens of shape `(B, L_txt, hidden_size)`.
|
||||
temb (`torch.Tensor`):
|
||||
Conditioning vector used by `Modulation` to produce scale/shift/gates, shape `(B, hidden_size)` (or
|
||||
broadcastable).
|
||||
image_rotary_emb (`torch.Tensor`):
|
||||
Rotary positional embeddings applied inside attention.
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
Boolean mask for text tokens of shape `(B, L_txt)`, where `0` marks padding.
|
||||
**kwargs:
|
||||
Additional keyword arguments for API compatibility.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Updated image tokens of shape `(B, L_img, hidden_size)`.
|
||||
"""
|
||||
|
||||
mod_attn, mod_mlp = self.modulation(temb)
|
||||
attn_shift, attn_scale, attn_gate = mod_attn
|
||||
mlp_shift, mlp_scale, mlp_gate = mod_mlp
|
||||
|
||||
hidden_states_mod = (1 + attn_scale) * self.img_pre_norm(hidden_states) + attn_shift
|
||||
|
||||
attn_out = self.attention(
|
||||
hidden_states=hidden_states_mod,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
image_rotary_emb=image_rotary_emb,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + attn_gate * attn_out
|
||||
|
||||
x = (1 + mlp_scale) * self.post_attention_layernorm(hidden_states) + mlp_shift
|
||||
hidden_states = hidden_states + mlp_gate * (self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x)))
|
||||
return hidden_states
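# Illustrative sketch (editor-added): a tiny block wired together with the helpers
# above; all sizes are made up for the example. Because Modulation is
# zero-initialized, a freshly constructed block returns its input unchanged.
#
#   block = PRXBlock(hidden_size=64, num_heads=2)
#   ids = get_image_ids(batch_size=1, height=8, width=8, patch_size=2, device=torch.device("cpu"))
#   pe = PRXEmbedND(dim=32, theta=10_000, axes_dim=[16, 16])(ids)
#   img = torch.randn(1, 16, 64)
#   txt = torch.randn(1, 7, 64)
#   out = block(hidden_states=img, encoder_hidden_states=txt, temb=torch.randn(1, 64), image_rotary_emb=pe)
#   out.shape            # torch.Size([1, 16, 64])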
|
||||
|
||||
|
||||
class FinalLayer(nn.Module):
|
||||
r"""
|
||||
Final projection layer with adaptive LayerNorm modulation.
|
||||
|
||||
This layer applies a normalized and modulated transformation to input tokens and projects them into patch-level
|
||||
outputs.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`):
|
||||
Dimensionality of the input tokens.
|
||||
patch_size (`int`):
|
||||
Size of the square image patches.
|
||||
out_channels (`int`):
|
||||
Number of output channels per pixel (e.g. RGB = 3).
|
||||
|
||||
Forward Inputs:
|
||||
x (`torch.Tensor`):
|
||||
Input tokens of shape `(B, L, hidden_size)`, where `L` is the number of patches.
|
||||
vec (`torch.Tensor`):
|
||||
Conditioning vector of shape `(B, hidden_size)` used to generate shift and scale parameters for adaptive
|
||||
LayerNorm.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Projected patch outputs of shape `(B, L, patch_size * patch_size * out_channels)`.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
||||
super().__init__()
|
||||
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
||||
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
||||
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
||||
|
||||
def forward(self, x: torch.Tensor, vec: torch.Tensor) -> torch.Tensor:
|
||||
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
||||
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
||||
x = self.linear(x)
|
||||
return x
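# Illustrative sketch (editor-added): each token is projected to
# patch_size * patch_size * out_channels values so it can later be folded back into
# the latent grid. Example sizes are arbitrary.
#
#   final = FinalLayer(hidden_size=64, patch_size=2, out_channels=4)
#   final(torch.randn(1, 16, 64), torch.randn(1, 64)).shape  # torch.Size([1, 16, 16])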
|
||||
|
||||
|
||||
def img2seq(img: torch.Tensor, patch_size: int) -> torch.Tensor:
|
||||
r"""
|
||||
Flattens an image tensor into a sequence of non-overlapping patches.
|
||||
|
||||
Args:
|
||||
img (`torch.Tensor`):
|
||||
Input image tensor of shape `(B, C, H, W)`.
|
||||
patch_size (`int`):
|
||||
Size of each square patch. Must evenly divide both `H` and `W`.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Flattened patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W
|
||||
// patch_size)` is the number of patches.
|
||||
"""
|
||||
return unfold(img, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
|
||||
|
||||
|
||||
def seq2img(seq: torch.Tensor, patch_size: int, shape: torch.Tensor) -> torch.Tensor:
|
||||
r"""
|
||||
Reconstructs an image tensor from a sequence of patches (inverse of `img2seq`).
|
||||
|
||||
Args:
|
||||
seq (`torch.Tensor`):
|
||||
Patch sequence of shape `(B, L, C * patch_size * patch_size)`, where `L = (H // patch_size) * (W //
|
||||
patch_size)`.
|
||||
patch_size (`int`):
|
||||
Size of each square patch.
|
||||
shape (`tuple` or `torch.Tensor`):
|
||||
The original image spatial shape `(H, W)`. If a tensor is provided, the first two values are interpreted as
|
||||
height and width.
|
||||
|
||||
Returns:
|
||||
`torch.Tensor`:
|
||||
Reconstructed image tensor of shape `(B, C, H, W)`.
|
||||
"""
|
||||
if isinstance(shape, tuple):
|
||||
shape = shape[-2:]
|
||||
elif isinstance(shape, torch.Tensor):
|
||||
shape = (int(shape[0]), int(shape[1]))
|
||||
else:
|
||||
raise NotImplementedError(f"shape type {type(shape)} not supported")
|
||||
return fold(seq.transpose(1, 2), shape, kernel_size=patch_size, stride=patch_size)
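# Illustrative sketch (editor-added): img2seq/seq2img are exact inverses for
# non-overlapping patches, which is what lets the transformer operate on flat token
# sequences and still reconstruct the latent grid.
#
#   latent = torch.randn(1, 16, 8, 8)
#   seq = img2seq(latent, patch_size=2)                 # (1, 16, 64): 16 patches of 16*2*2 values
#   restored = seq2img(seq, patch_size=2, shape=latent.shape)
#   torch.testing.assert_close(restored, latent)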
|
||||
|
||||
|
||||
class PRXTransformer2DModel(ModelMixin, ConfigMixin, AttentionMixin):
|
||||
r"""
|
||||
Transformer-based 2D model for text to image generation.
|
||||
|
||||
Args:
|
||||
in_channels (`int`, *optional*, defaults to 16):
|
||||
Number of input channels in the latent image.
|
||||
patch_size (`int`, *optional*, defaults to 2):
|
||||
Size of the square patches used to flatten the input image.
|
||||
context_in_dim (`int`, *optional*, defaults to 2304):
|
||||
Dimensionality of the text conditioning input.
|
||||
hidden_size (`int`, *optional*, defaults to 1792):
|
||||
Dimension of the hidden representation.
|
||||
mlp_ratio (`float`, *optional*, defaults to 3.5):
|
||||
Expansion ratio for the hidden dimension inside MLP blocks.
|
||||
num_heads (`int`, *optional*, defaults to 28):
|
||||
Number of attention heads.
|
||||
depth (`int`, *optional*, defaults to 16):
|
||||
Number of transformer blocks.
|
||||
axes_dim (`list[int]`, *optional*):
|
||||
List of dimensions for each positional embedding axis. Defaults to `[32, 32]`.
|
||||
theta (`int`, *optional*, defaults to 10000):
|
||||
Frequency scaling factor for rotary embeddings.
|
||||
time_factor (`float`, *optional*, defaults to 1000.0):
|
||||
Scaling factor applied in timestep embeddings.
|
||||
time_max_period (`int`, *optional*, defaults to 10000):
|
||||
Maximum frequency period for timestep embeddings.
|
||||
|
||||
Attributes:
|
||||
pe_embedder (`PRXEmbedND`):
|
||||
Multi-axis rotary embedding generator for positional encodings.
|
||||
img_in (`nn.Linear`):
|
||||
Projection layer for image patch tokens.
|
||||
time_in (`MLPEmbedder`):
|
||||
Embedding layer for timestep embeddings.
|
||||
txt_in (`nn.Linear`):
|
||||
Projection layer for text conditioning.
|
||||
blocks (`nn.ModuleList`):
|
||||
Stack of transformer blocks (`PRXBlock`).
|
||||
final_layer (`FinalLayer`):
|
||||
Projection layer mapping hidden tokens back to patch outputs.
|
||||
|
||||
Methods:
attn_processors:
Returns a dictionary of all attention processors in the model.
set_attn_processor(processor):
Replaces attention processors across all attention layers.
_compute_timestep_embedding(timestep, dtype):
Creates a 256-dimensional timestep embedding, scaled and projected to the hidden size.
forward(hidden_states, timestep, encoder_hidden_states, attention_mask=None, attention_kwargs=None,
return_dict=True):
Full forward pass from latent input to reconstructed output latent.
|
||||
|
||||
Returns:
|
||||
`Transformer2DModelOutput` if `return_dict=True` (default), otherwise a tuple containing:
|
||||
- `sample` (`torch.Tensor`): Reconstructed image of shape `(B, C, H, W)`.
|
||||
"""
|
||||
|
||||
config_name = "config.json"
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 16,
|
||||
patch_size: int = 2,
|
||||
context_in_dim: int = 2304,
|
||||
hidden_size: int = 1792,
|
||||
mlp_ratio: float = 3.5,
|
||||
num_heads: int = 28,
|
||||
depth: int = 16,
|
||||
axes_dim: list = None,
|
||||
theta: int = 10000,
|
||||
time_factor: float = 1000.0,
|
||||
time_max_period: int = 10000,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if axes_dim is None:
|
||||
axes_dim = [32, 32]
|
||||
|
||||
# Store parameters directly
|
||||
self.in_channels = in_channels
|
||||
self.patch_size = patch_size
|
||||
self.out_channels = self.in_channels * self.patch_size**2
|
||||
|
||||
self.time_factor = time_factor
|
||||
self.time_max_period = time_max_period
|
||||
|
||||
if hidden_size % num_heads != 0:
|
||||
raise ValueError(f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}")
|
||||
|
||||
pe_dim = hidden_size // num_heads
|
||||
|
||||
if sum(axes_dim) != pe_dim:
|
||||
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.num_heads = num_heads
|
||||
self.pe_embedder = PRXEmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
|
||||
self.img_in = nn.Linear(self.in_channels * self.patch_size**2, self.hidden_size, bias=True)
|
||||
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
||||
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[
|
||||
PRXBlock(
|
||||
self.hidden_size,
|
||||
self.num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
)
|
||||
for i in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.final_layer = FinalLayer(self.hidden_size, 1, self.out_channels)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def _compute_timestep_embedding(self, timestep: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
return self.time_in(
|
||||
get_timestep_embedding(
|
||||
timesteps=timestep,
|
||||
embedding_dim=256,
|
||||
max_period=self.time_max_period,
|
||||
scale=self.time_factor,
|
||||
flip_sin_to_cos=True, # Match original cos, sin order
|
||||
).to(dtype)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
timestep: torch.Tensor,
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[Tuple[torch.Tensor, ...], Transformer2DModelOutput]:
|
||||
r"""
|
||||
Forward pass of the PRXTransformer2DModel.
|
||||
|
||||
The latent image is split into patch tokens, combined with text conditioning, and processed through a stack of
|
||||
transformer blocks modulated by the timestep. The output is reconstructed into the latent image space.
|
||||
|
||||
Args:
|
||||
hidden_states (`torch.Tensor`):
|
||||
Input latent image tensor of shape `(B, C, H, W)`.
|
||||
timestep (`torch.Tensor`):
|
||||
Timestep tensor of shape `(B,)` or `(1,)`, used for temporal conditioning.
|
||||
encoder_hidden_states (`torch.Tensor`):
|
||||
Text conditioning tensor of shape `(B, L_txt, context_in_dim)`.
|
||||
attention_mask (`torch.Tensor`, *optional*):
|
||||
Boolean mask of shape `(B, L_txt)`, where `0` marks padding in the text sequence.
|
||||
attention_kwargs (`dict`, *optional*):
|
||||
Additional arguments passed to attention layers.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether to return a `Transformer2DModelOutput` or a tuple.
|
||||
|
||||
Returns:
|
||||
`Transformer2DModelOutput` if `return_dict=True`, otherwise a tuple:
|
||||
|
||||
- `sample` (`torch.Tensor`): Output latent image of shape `(B, C, H, W)`.
|
||||
"""
|
||||
# Process text conditioning
|
||||
txt = self.txt_in(encoder_hidden_states)
|
||||
|
||||
# Convert image to sequence and embed
|
||||
img = img2seq(hidden_states, self.patch_size)
|
||||
img = self.img_in(img)
|
||||
|
||||
# Generate positional embeddings
|
||||
bs, _, h, w = hidden_states.shape
|
||||
img_ids = get_image_ids(bs, h, w, patch_size=self.patch_size, device=hidden_states.device)
|
||||
pe = self.pe_embedder(img_ids)
|
||||
|
||||
# Compute time embedding
|
||||
vec = self._compute_timestep_embedding(timestep, dtype=img.dtype)
|
||||
|
||||
# Apply transformer blocks
|
||||
for block in self.blocks:
|
||||
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
||||
img = self._gradient_checkpointing_func(
|
||||
block.__call__,
|
||||
img,
|
||||
txt,
|
||||
vec,
|
||||
pe,
|
||||
attention_mask,
|
||||
)
|
||||
else:
|
||||
img = block(
|
||||
hidden_states=img,
|
||||
encoder_hidden_states=txt,
|
||||
temb=vec,
|
||||
image_rotary_emb=pe,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
|
||||
# Final layer and convert back to image
|
||||
img = self.final_layer(img, vec)
|
||||
output = seq2img(img, self.patch_size, hidden_states.shape)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
return Transformer2DModelOutput(sample=output)
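# Illustrative smoke test (editor-added): a deliberately tiny configuration, not the
# released checkpoint's; it only demonstrates that the output keeps the input latent
# shape.
#
#   model = PRXTransformer2DModel(
#       in_channels=4, patch_size=2, context_in_dim=32, hidden_size=64,
#       mlp_ratio=2.0, num_heads=2, depth=1, axes_dim=[16, 16],
#   )
#   latents = torch.randn(1, 4, 16, 16)
#   out = model(
#       hidden_states=latents,
#       timestep=torch.tensor([0.5]),
#       encoder_hidden_states=torch.randn(1, 7, 32),
#   ).sample
#   out.shape            # torch.Size([1, 4, 16, 16])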
|
||||
@@ -144,6 +144,7 @@ else:
|
||||
"FluxKontextPipeline",
|
||||
"FluxKontextInpaintPipeline",
|
||||
]
|
||||
_import_structure["prx"] = ["PRXPipeline"]
|
||||
_import_structure["audioldm"] = ["AudioLDMPipeline"]
|
||||
_import_structure["audioldm2"] = [
|
||||
"AudioLDM2Pipeline",
|
||||
@@ -719,6 +720,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
from .paint_by_example import PaintByExamplePipeline
|
||||
from .pia import PIAPipeline
|
||||
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
|
||||
from .prx import PRXPipeline
|
||||
from .qwenimage import (
|
||||
QwenImageControlNetInpaintPipeline,
|
||||
QwenImageControlNetPipeline,
|
||||
|
||||
63
src/diffusers/pipelines/prx/__init__.py
Normal file
@@ -0,0 +1,63 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...utils import (
|
||||
DIFFUSERS_SLOW_IMPORT,
|
||||
OptionalDependencyNotAvailable,
|
||||
_LazyModule,
|
||||
get_objects_from_module,
|
||||
is_torch_available,
|
||||
is_transformers_available,
|
||||
)
|
||||
|
||||
|
||||
_dummy_objects = {}
|
||||
_additional_imports = {}
|
||||
_import_structure = {"pipeline_output": ["PRXPipelineOutput"]}
|
||||
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils import dummy_torch_and_transformers_objects # noqa F403
|
||||
|
||||
_dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
|
||||
else:
|
||||
_import_structure["pipeline_prx"] = ["PRXPipeline"]
|
||||
|
||||
# Import T5GemmaEncoder for pipeline loading compatibility
|
||||
try:
|
||||
if is_transformers_available():
|
||||
import transformers
|
||||
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
|
||||
|
||||
_additional_imports["T5GemmaEncoder"] = T5GemmaEncoder
|
||||
# Patch transformers module directly for serialization
|
||||
if not hasattr(transformers, "T5GemmaEncoder"):
|
||||
transformers.T5GemmaEncoder = T5GemmaEncoder
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
||||
try:
|
||||
if not (is_transformers_available() and is_torch_available()):
|
||||
raise OptionalDependencyNotAvailable()
|
||||
except OptionalDependencyNotAvailable:
|
||||
from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
|
||||
else:
|
||||
from .pipeline_output import PRXPipelineOutput
|
||||
from .pipeline_prx import PRXPipeline
|
||||
|
||||
else:
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = _LazyModule(
|
||||
__name__,
|
||||
globals()["__file__"],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
)
|
||||
|
||||
for name, value in _dummy_objects.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
for name, value in _additional_imports.items():
|
||||
setattr(sys.modules[__name__], name, value)
|
||||
35
src/diffusers/pipelines/prx/pipeline_output.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright 2025 The Photoroom and the HuggingFace Teams. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
|
||||
from ...utils import BaseOutput
|
||||
|
||||
|
||||
@dataclass
|
||||
class PRXPipelineOutput(BaseOutput):
|
||||
"""
|
||||
Output class for PRX pipelines.
|
||||
|
||||
Args:
|
||||
images (`List[PIL.Image.Image]` or `np.ndarray`):
|
||||
List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
|
||||
num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
|
||||
"""
|
||||
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
767
src/diffusers/pipelines/prx/pipeline_prx.py
Normal file
@@ -0,0 +1,767 @@
|
||||
# Copyright 2025 The Photoroom and The HuggingFace Teams. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import html
|
||||
import inspect
|
||||
import re
|
||||
import urllib.parse as ul
|
||||
from typing import Callable, Dict, List, Optional, Union
|
||||
|
||||
import ftfy
|
||||
import torch
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
GemmaTokenizerFast,
|
||||
T5TokenizerFast,
|
||||
)
|
||||
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
|
||||
|
||||
from diffusers.image_processor import PixArtImageProcessor
|
||||
from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
|
||||
from diffusers.models import AutoencoderDC, AutoencoderKL
|
||||
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
|
||||
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
||||
from diffusers.pipelines.prx.pipeline_output import PRXPipelineOutput
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import (
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
)
|
||||
from diffusers.utils.torch_utils import randn_tensor
|
||||
|
||||
|
||||
DEFAULT_RESOLUTION = 512
|
||||
|
||||
ASPECT_RATIO_256_BIN = {
|
||||
"0.46": [160, 352],
|
||||
"0.6": [192, 320],
|
||||
"0.78": [224, 288],
|
||||
"1.0": [256, 256],
|
||||
"1.29": [288, 224],
|
||||
"1.67": [320, 192],
|
||||
"2.2": [352, 160],
|
||||
}
|
||||
|
||||
ASPECT_RATIO_512_BIN = {
|
||||
"0.5": [352, 704],
|
||||
"0.57": [384, 672],
|
||||
"0.6": [384, 640],
|
||||
"0.68": [416, 608],
|
||||
"0.78": [448, 576],
|
||||
"0.88": [480, 544],
|
||||
"1.0": [512, 512],
|
||||
"1.13": [544, 480],
|
||||
"1.29": [576, 448],
|
||||
"1.46": [608, 416],
|
||||
"1.67": [640, 384],
|
||||
"1.75": [672, 384],
|
||||
"2.0": [704, 352],
|
||||
}
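# Illustrative note (editor-added): with resolution binning, a request is snapped to
# the bin whose aspect ratio is closest. For example, a 512x768 request has ratio
# ~0.67, whose nearest key in ASPECT_RATIO_512_BIN is "0.68":
#
#   PixArtImageProcessor.classify_height_width_bin(512, 768, ratios=ASPECT_RATIO_512_BIN)
#   # -> (416, 608); the decoded image is then resized back to 512x768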
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class TextPreprocessor:
|
||||
"""Text preprocessing utility for PRXPipeline."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize text preprocessor."""
|
||||
self.bad_punct_regex = re.compile(
|
||||
r"["
|
||||
+ "#®•©™&@·º½¾¿¡§~"
|
||||
+ r"\)"
|
||||
+ r"\("
|
||||
+ r"\]"
|
||||
+ r"\["
|
||||
+ r"\}"
|
||||
+ r"\{"
|
||||
+ r"\|"
|
||||
+ r"\\"
|
||||
+ r"\/"
|
||||
+ r"\*"
|
||||
+ r"]{1,}"
|
||||
)
|
||||
|
||||
def clean_text(self, text: str) -> str:
|
||||
"""Clean text using comprehensive text processing logic."""
|
||||
# See Deepfloyd https://github.com/deep-floyd/IF/blob/develop/deepfloyd_if/modules/t5.py
|
||||
text = str(text)
|
||||
text = ul.unquote_plus(text)
|
||||
text = text.strip().lower()
|
||||
text = re.sub("<person>", "person", text)
|
||||
|
||||
# Remove all urls:
|
||||
text = re.sub(
|
||||
r"\b((?:https?|www):(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@))",
|
||||
"",
|
||||
text,
|
||||
) # regex for urls
|
||||
|
||||
# @<nickname>
|
||||
text = re.sub(r"@[\w\d]+\b", "", text)
|
||||
|
||||
# 31C0—31EF CJK Strokes through 4E00—9FFF CJK Unified Ideographs
|
||||
text = re.sub(r"[\u31c0-\u31ef]+", "", text)
|
||||
text = re.sub(r"[\u31f0-\u31ff]+", "", text)
|
||||
text = re.sub(r"[\u3200-\u32ff]+", "", text)
|
||||
text = re.sub(r"[\u3300-\u33ff]+", "", text)
|
||||
text = re.sub(r"[\u3400-\u4dbf]+", "", text)
|
||||
text = re.sub(r"[\u4dc0-\u4dff]+", "", text)
|
||||
text = re.sub(r"[\u4e00-\u9fff]+", "", text)
|
||||
|
||||
# all types of dash --> "-"
|
||||
text = re.sub(
|
||||
r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
|
||||
"-",
|
||||
text,
|
||||
)
|
||||
|
||||
# normalize all quotation marks to a single standard
text = re.sub(r"[`´«»“”¨]", '"', text)
text = re.sub(r"[‘’]", "'", text)

# &quot; and &amp
text = re.sub(r"&quot;?", "", text)
text = re.sub(r"&amp", "", text)
|
||||
|
||||
# ip addresses:
|
||||
text = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", text)
|
||||
|
||||
# article ids:
|
||||
text = re.sub(r"\d:\d\d\s+$", "", text)
|
||||
|
||||
# \n
|
||||
text = re.sub(r"\\n", " ", text)
|
||||
|
||||
# "#123", "#12345..", "123456.."
|
||||
text = re.sub(r"#\d{1,3}\b", "", text)
|
||||
text = re.sub(r"#\d{5,}\b", "", text)
|
||||
text = re.sub(r"\b\d{6,}\b", "", text)
|
||||
|
||||
# filenames:
|
||||
text = re.sub(r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", text)
|
||||
|
||||
# Clean punctuation
|
||||
text = re.sub(r"[\"\']{2,}", r'"', text) # """AUSVERKAUFT"""
|
||||
text = re.sub(r"[\.]{2,}", r" ", text)
|
||||
|
||||
text = re.sub(self.bad_punct_regex, r" ", text) # ***AUSVERKAUFT***, #AUSVERKAUFT
|
||||
text = re.sub(r"\s+\.\s+", r" ", text) # " . "
|
||||
|
||||
# this-is-my-cute-cat / this_is_my_cute_cat
|
||||
regex2 = re.compile(r"(?:\-|\_)")
|
||||
if len(re.findall(regex2, text)) > 3:
|
||||
text = re.sub(regex2, " ", text)
|
||||
|
||||
# Basic cleaning
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
text = text.strip()
|
||||
|
||||
# Clean alphanumeric patterns
|
||||
text = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", text) # jc6640
|
||||
text = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", text) # jc6640vc
|
||||
text = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", text) # 6640vc231
|
||||
|
||||
# Common spam patterns
|
||||
text = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", text)
|
||||
text = re.sub(r"(free\s)?download(\sfree)?", "", text)
|
||||
text = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", text)
|
||||
text = re.sub(r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", text)
|
||||
text = re.sub(r"\bpage\s+\d+\b", "", text)
|
||||
|
||||
text = re.sub(r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", text) # j2d1a2a...
|
||||
text = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", text)
|
||||
|
||||
# Final cleanup
|
||||
text = re.sub(r"\b\s+\:\s+", r": ", text)
|
||||
text = re.sub(r"(\D[,\./])\b", r"\1 ", text)
|
||||
text = re.sub(r"\s+", " ", text)
|
||||
|
||||
text = text.strip()
|
||||
|
||||
text = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", text)
|
||||
text = re.sub(r"^[\'\_,\-\:;]", r"", text)
|
||||
text = re.sub(r"[\'\_,\-\:\-\+]$", r"", text)
|
||||
text = re.sub(r"^\.\S+$", "", text)
|
||||
|
||||
return text.strip()
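# Illustrative sketch (editor-added): a rough before/after of the rules above
# (lowercasing, URL stripping, punctuation collapsing); the exact output depends on
# the full rule set.
#
#   TextPreprocessor().clean_text("Check https://example.com  <person> ***SALE***")
#   # -> roughly "check person sale"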
|
||||
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
Examples:
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import PRXPipeline
|
||||
|
||||
>>> # Load pipeline with from_pretrained
|
||||
>>> pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft")
|
||||
>>> pipe.to("cuda")
|
||||
|
||||
>>> prompt = "A digital painting of a rusty, vintage tram on a sandy beach"
|
||||
>>> image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
|
||||
>>> image.save("prx_output.png")
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class PRXPipeline(
|
||||
DiffusionPipeline,
|
||||
LoraLoaderMixin,
|
||||
FromSingleFileMixin,
|
||||
TextualInversionLoaderMixin,
|
||||
):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using PRX Transformer.
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
transformer ([`PRXTransformer2DModel`]):
|
||||
The PRX transformer model to denoise the encoded image latents.
|
||||
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
text_encoder ([`T5GemmaEncoder`]):
|
||||
Text encoder model for encoding prompts.
|
||||
tokenizer ([`T5TokenizerFast` or `GemmaTokenizerFast`]):
|
||||
Tokenizer for the text encoder.
|
||||
vae ([`AutoencoderKL`] or [`AutoencoderDC`]):
|
||||
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
|
||||
Supports both AutoencoderKL (8x compression) and AutoencoderDC (32x compression).
|
||||
"""
|
||||
|
||||
model_cpu_offload_seq = "text_encoder->transformer->vae"
|
||||
_callback_tensor_inputs = ["latents", "prompt_embeds"]
|
||||
_optional_components = ["vae"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transformer: PRXTransformer2DModel,
|
||||
scheduler: FlowMatchEulerDiscreteScheduler,
|
||||
text_encoder: T5GemmaEncoder,
|
||||
tokenizer: Union[T5TokenizerFast, GemmaTokenizerFast, AutoTokenizer],
|
||||
vae: Optional[Union[AutoencoderKL, AutoencoderDC]] = None,
|
||||
default_sample_size: Optional[int] = DEFAULT_RESOLUTION,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if PRXTransformer2DModel is None:
|
||||
raise ImportError(
|
||||
"PRXTransformer2DModel is not available. Please ensure the transformer_prx module is properly installed."
|
||||
)
|
||||
|
||||
self.text_preprocessor = TextPreprocessor()
|
||||
self.default_sample_size = default_sample_size
|
||||
self._guidance_scale = 1.0
|
||||
|
||||
self.register_modules(
|
||||
transformer=transformer,
|
||||
scheduler=scheduler,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
vae=vae,
|
||||
)
|
||||
|
||||
self.register_to_config(default_sample_size=self.default_sample_size)
|
||||
|
||||
if vae is not None:
|
||||
self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
|
||||
else:
|
||||
self.image_processor = None
|
||||
|
||||
@property
|
||||
def vae_scale_factor(self):
|
||||
if self.vae is None:
|
||||
return 8
|
||||
if hasattr(self.vae, "spatial_compression_ratio"):
|
||||
return self.vae.spatial_compression_ratio
|
||||
else: # Flux VAE
|
||||
return 2 ** (len(self.vae.config.block_out_channels) - 1)
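# Editor-added note: a KL VAE with 4 `block_out_channels` entries gives a scale
# factor of 2 ** (4 - 1) = 8 (512x512 pixels -> 64x64 latents), while an
# AutoencoderDC reports its 32x `spatial_compression_ratio` directly.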
|
||||
|
||||
@property
|
||||
def do_classifier_free_guidance(self):
|
||||
"""Check if classifier-free guidance is enabled based on guidance scale."""
|
||||
return self._guidance_scale > 1.0
|
||||
|
||||
@property
|
||||
def guidance_scale(self):
|
||||
return self._guidance_scale
|
||||
|
||||
def get_default_resolution(self):
|
||||
"""Determine the default resolution based on the loaded VAE and config.
|
||||
|
||||
Returns:
|
||||
int: The default sample size (height/width) to use for generation.
|
||||
"""
|
||||
default_from_config = getattr(self.config, "default_sample_size", None)
|
||||
if default_from_config is not None:
|
||||
return default_from_config
|
||||
|
||||
return DEFAULT_RESOLUTION
|
||||
|
||||
def prepare_latents(
|
||||
self,
|
||||
batch_size: int,
|
||||
num_channels_latents: int,
|
||||
height: int,
|
||||
width: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""Prepare initial latents for the diffusion process."""
|
||||
if latents is None:
|
||||
spatial_compression = self.vae_scale_factor
|
||||
latent_height, latent_width = (
|
||||
height // spatial_compression,
|
||||
width // spatial_compression,
|
||||
)
|
||||
shape = (batch_size, num_channels_latents, latent_height, latent_width)
|
||||
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
latents = latents.to(device)
|
||||
return latents
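# Illustrative sketch (editor-added), assuming a pipeline instance `pipe` with a KL
# VAE (scale factor 8) and 16 latent channels:
#
#   latents = pipe.prepare_latents(1, 16, 512, 512, torch.float32, torch.device("cpu"))
#   latents.shape        # torch.Size([1, 16, 64, 64])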
|
||||
|
||||
def encode_prompt(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
device: Optional[torch.device] = None,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: str = "",
|
||||
num_images_per_prompt: int = 1,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
):
|
||||
"""Encode text prompt using standard text encoder and tokenizer, or use precomputed embeddings."""
|
||||
if device is None:
|
||||
device = self._execution_device
|
||||
|
||||
if prompt_embeds is None:
|
||||
if isinstance(prompt, str):
|
||||
prompt = [prompt]
|
||||
# Encode the prompts
|
||||
prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask = (
|
||||
self._encode_prompt_standard(prompt, device, do_classifier_free_guidance, negative_prompt)
|
||||
)
|
||||
|
||||
# Duplicate embeddings for each generation per prompt
|
||||
if num_images_per_prompt > 1:
|
||||
# Repeat prompt embeddings
|
||||
bs_embed, seq_len, _ = prompt_embeds.shape
|
||||
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if prompt_attention_mask is not None:
|
||||
prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
|
||||
prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
# Repeat negative embeddings if using CFG
|
||||
if do_classifier_free_guidance and negative_prompt_embeds is not None:
|
||||
bs_embed, seq_len, _ = negative_prompt_embeds.shape
|
||||
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
||||
negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
||||
|
||||
if negative_prompt_attention_mask is not None:
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
|
||||
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
|
||||
|
||||
return (
|
||||
prompt_embeds,
|
||||
prompt_attention_mask,
|
||||
negative_prompt_embeds if do_classifier_free_guidance else None,
|
||||
negative_prompt_attention_mask if do_classifier_free_guidance else None,
|
||||
)
|
||||
|
||||
def _tokenize_prompts(self, prompts: List[str], device: torch.device):
|
||||
"""Tokenize and clean prompts."""
|
||||
cleaned = [self.text_preprocessor.clean_text(text) for text in prompts]
|
||||
tokens = self.tokenizer(
|
||||
cleaned,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
return_attention_mask=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
return tokens["input_ids"].to(device), tokens["attention_mask"].bool().to(device)
|
||||
|
||||
def _encode_prompt_standard(
|
||||
self,
|
||||
prompt: List[str],
|
||||
device: torch.device,
|
||||
do_classifier_free_guidance: bool = True,
|
||||
negative_prompt: str = "",
|
||||
):
|
||||
"""Encode prompt using standard text encoder and tokenizer with batch processing."""
|
||||
batch_size = len(prompt)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
if isinstance(negative_prompt, str):
|
||||
negative_prompt = [negative_prompt] * batch_size
|
||||
|
||||
prompts_to_encode = negative_prompt + prompt
|
||||
else:
|
||||
prompts_to_encode = prompt
|
||||
|
||||
input_ids, attention_mask = self._tokenize_prompts(prompts_to_encode, device)
|
||||
|
||||
with torch.no_grad():
|
||||
embeddings = self.text_encoder(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
output_hidden_states=True,
|
||||
)["last_hidden_state"]
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
uncond_text_embeddings, text_embeddings = embeddings.split(batch_size, dim=0)
|
||||
uncond_cross_attn_mask, cross_attn_mask = attention_mask.split(batch_size, dim=0)
|
||||
else:
|
||||
text_embeddings = embeddings
|
||||
cross_attn_mask = attention_mask
|
||||
uncond_text_embeddings = None
|
||||
uncond_cross_attn_mask = None
|
||||
|
||||
return text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask
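# Editor-added note on the CFG batch layout: negative prompts are tokenized and
# encoded in the same forward pass, ordered [negatives..., positives...], then split
# back apart, e.g.
#
#   prompts_to_encode = ["", ""] + ["a cat", "a dog"]   # batch_size == 2
#   embeddings.split(2, dim=0)                          # -> (uncond, cond)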
|
||||
|
||||
def check_inputs(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
height: int,
|
||||
width: int,
|
||||
guidance_scale: float,
|
||||
callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
):
|
||||
"""Check that all inputs are in correct format."""
|
||||
if prompt is not None and prompt_embeds is not None:
|
||||
raise ValueError(
|
||||
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
|
||||
" only forward one of the two."
|
||||
)
|
||||
|
||||
if prompt is None and prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
|
||||
)
|
||||
|
||||
if prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
if prompt_embeds is not None and guidance_scale > 1.0 and negative_prompt_embeds is None:
|
||||
raise ValueError(
|
||||
"When `prompt_embeds` is provided and `guidance_scale > 1.0`, "
|
||||
"`negative_prompt_embeds` must also be provided for classifier-free guidance."
|
||||
)
|
||||
|
||||
spatial_compression = self.vae_scale_factor
|
||||
if height % spatial_compression != 0 or width % spatial_compression != 0:
|
||||
raise ValueError(
|
||||
f"`height` and `width` have to be divisible by {spatial_compression} but are {height} and {width}."
|
||||
)
|
||||
|
||||
if guidance_scale < 1.0:
|
||||
raise ValueError(f"guidance_scale has to be >= 1.0 but is {guidance_scale}")
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not isinstance(callback_on_step_end_tensor_inputs, list):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be a list but is {callback_on_step_end_tensor_inputs}"
|
||||
)
|
||||
|
||||
if callback_on_step_end_tensor_inputs is not None and not all(
|
||||
k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]] = None,
|
||||
negative_prompt: str = "",
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 28,
|
||||
timesteps: List[int] = None,
|
||||
guidance_scale: float = 4.0,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.Tensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
prompt_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
negative_prompt_attention_mask: Optional[torch.BoolTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
use_resolution_binning: bool = True,
|
||||
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
||||
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
||||
):
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
|
||||
instead.
|
||||
negative_prompt (`str`, *optional*, defaults to `""`):
|
||||
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
|
||||
if `guidance_scale` is less than `1`).
|
||||
height (`int`, *optional*, defaults to the pipeline's default sample size):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to the pipeline's default sample size):
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 28):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
timesteps (`List[int]`, *optional*):
|
||||
Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
|
||||
in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
|
||||
passed will be used. Must be in descending order.
|
||||
guidance_scale (`float`, *optional*, defaults to 4.0):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.Tensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided and `guidance_scale > 1`, negative embeddings will be generated from an
|
||||
empty string.
|
||||
prompt_attention_mask (`torch.BoolTensor`, *optional*):
|
||||
Pre-generated attention mask for `prompt_embeds`. If not provided, attention mask will be generated
|
||||
from `prompt` input argument.
|
||||
negative_prompt_attention_mask (`torch.BoolTensor`, *optional*):
|
||||
Pre-generated attention mask for `negative_prompt_embeds`. If not provided and `guidance_scale > 1`,
|
||||
attention mask will be generated from an empty string.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.prx.PRXPipelineOutput`] instead of a plain tuple.
|
||||
use_resolution_binning (`bool`, *optional*, defaults to `True`):
|
||||
If set to `True`, the requested height and width are first mapped to the closest resolutions using
|
||||
predefined aspect ratio bins. After the produced latents are decoded into images, they are resized back
|
||||
to the requested resolution. Useful for generating non-square images at optimal resolutions.
|
||||
callback_on_step_end (`Callable`, *optional*):
|
||||
A function that is called at the end of each denoising step during inference. The function is called
|
||||
with the following arguments: `callback_on_step_end(self, step, timestep, callback_kwargs)`.
|
||||
`callback_kwargs` will include a list of all tensors as specified by
|
||||
`callback_on_step_end_tensor_inputs`.
|
||||
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
||||
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
||||
will be passed as `callback_kwargs` argument. You will only be able to include tensors that are listed
|
||||
in the `._callback_tensor_inputs` attribute.
|
||||
|
||||
Examples:
|
||||
|
||||
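A minimal usage sketch (the checkpoint id below is a placeholder for illustration, not a published model, and the sampling settings are arbitrary):

```py
>>> import torch
>>> from diffusers import PRXPipeline

>>> # "your-org/prx-checkpoint" is a hypothetical repo id used only for illustration.
>>> pipe = PRXPipeline.from_pretrained("your-org/prx-checkpoint", torch_dtype=torch.bfloat16)
>>> pipe = pipe.to("cuda")
>>> image = pipe(
...     "A photo of an astronaut riding a horse",
...     num_inference_steps=28,
...     guidance_scale=5.0,
... ).images[0]
>>> image.save("prx.png")
```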
Returns:
|
||||
[`~pipelines.prx.PRXPipelineOutput`] or `tuple`: [`~pipelines.prx.PRXPipelineOutput`] if `return_dict` is
|
||||
`True`, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated images.
|
||||
"""
|
||||
|
||||
# 0. Set height and width
|
||||
default_resolution = self.get_default_resolution()
|
||||
height = height or default_resolution
|
||||
width = width or default_resolution
|
||||
|
||||
if use_resolution_binning:
|
||||
if self.image_processor is None:
|
||||
raise ValueError(
|
||||
"Resolution binning requires a VAE with image_processor, but VAE is not available. "
|
||||
"Set use_resolution_binning=False or provide a VAE."
|
||||
)
|
||||
if self.default_sample_size <= 256:
|
||||
aspect_ratio_bin = ASPECT_RATIO_256_BIN
|
||||
else:
|
||||
aspect_ratio_bin = ASPECT_RATIO_512_BIN
|
||||
|
||||
# Store original dimensions
|
||||
orig_height, orig_width = height, width
|
||||
# Map to closest resolution in the bin
|
||||
height, width = self.image_processor.classify_height_width_bin(height, width, ratios=aspect_ratio_bin)
|
||||
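# Illustrative note (the numbers are examples, not values from the bin tables): with resolution
# binning, a request such as 512x720 is first snapped to the nearest bucketed shape (e.g. roughly
# 512x704) so the transformer sees an aspect ratio it was trained on; the decoded image is then
# resized back to the requested 512x720 by resize_and_crop_tensor further below.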
|
||||
# 1. Check inputs
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
height,
|
||||
width,
|
||||
guidance_scale,
|
||||
callback_on_step_end_tensor_inputs,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
)
|
||||
|
||||
if self.vae is None and output_type not in ["latent", "pt"]:
|
||||
raise ValueError(
|
||||
f"VAE is required for output_type='{output_type}' but it is not available. "
|
||||
"Either provide a VAE or set output_type='latent' or 'pt' to get latent outputs."
|
||||
)
|
||||
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
# Use execution device (handles offloading scenarios including group offloading)
|
||||
device = self._execution_device
|
||||
|
||||
self._guidance_scale = guidance_scale
|
||||
|
||||
# 2. Encode input prompt
|
||||
text_embeddings, cross_attn_mask, uncond_text_embeddings, uncond_cross_attn_mask = self.encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
do_classifier_free_guidance=self.do_classifier_free_guidance,
|
||||
negative_prompt=negative_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
prompt_attention_mask=prompt_attention_mask,
|
||||
negative_prompt_attention_mask=negative_prompt_attention_mask,
|
||||
)
|
||||
# Expose standard names for callbacks parity
|
||||
prompt_embeds = text_embeddings
|
||||
negative_prompt_embeds = uncond_text_embeddings
|
||||
|
||||
# 3. Prepare timesteps
|
||||
if timesteps is not None:
|
||||
self.scheduler.set_timesteps(timesteps=timesteps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
num_inference_steps = len(timesteps)
|
||||
else:
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
self.num_timesteps = len(timesteps)
|
||||
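# Note: `timesteps` lets callers supply an explicit schedule; otherwise the scheduler builds one
# from `num_inference_steps`. A hypothetical illustration, assuming the configured scheduler
# accepts custom timesteps:
#     pipe(prompt, timesteps=[1000, 750, 500, 250, 0])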
|
||||
# 4. Prepare latent variables
|
||||
if self.vae is not None:
|
||||
num_channels_latents = self.vae.config.latent_channels
|
||||
else:
|
||||
# When vae is None, get latent channels from transformer
|
||||
num_channels_latents = self.transformer.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
text_embeddings.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 5. Prepare extra step kwargs
|
||||
extra_step_kwargs = {}
|
||||
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
|
||||
if accepts_eta:
|
||||
extra_step_kwargs["eta"] = 0.0
|
||||
|
||||
# 6. Prepare cross-attention embeddings and masks
|
||||
if self.do_classifier_free_guidance:
|
||||
ca_embed = torch.cat([uncond_text_embeddings, text_embeddings], dim=0)
|
||||
ca_mask = None
|
||||
if cross_attn_mask is not None and uncond_cross_attn_mask is not None:
|
||||
ca_mask = torch.cat([uncond_cross_attn_mask, cross_attn_mask], dim=0)
|
||||
else:
|
||||
ca_embed = text_embeddings
|
||||
ca_mask = cross_attn_mask
|
||||
|
||||
# 7. Denoising loop
|
||||
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
|
||||
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# Duplicate latents if using classifier-free guidance
|
||||
if self.do_classifier_free_guidance:
|
||||
latents_in = torch.cat([latents, latents], dim=0)
|
||||
# Normalize timestep for the transformer
|
||||
t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).repeat(2).to(device)
|
||||
else:
|
||||
latents_in = latents
|
||||
# Normalize timestep for the transformer
|
||||
t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device)
|
||||
|
||||
# Forward through transformer
|
||||
noise_pred = self.transformer(
|
||||
hidden_states=latents_in,
|
||||
timestep=t_cont,
|
||||
encoder_hidden_states=ca_embed,
|
||||
attention_mask=ca_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# Apply CFG
|
||||
if self.do_classifier_free_guidance:
|
||||
noise_uncond, noise_text = noise_pred.chunk(2, dim=0)
|
||||
noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
|
||||
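# With w = guidance_scale this is the classifier-free guidance combination
#     eps = eps_uncond + w * (eps_text - eps_uncond)
# so w == 1 reduces to the purely conditional prediction, and larger w pushes the prediction
# further away from the unconditional branch.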
|
||||
# Compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
|
||||
if callback_on_step_end is not None:
|
||||
callback_kwargs = {}
|
||||
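# Gather the requested tensors (e.g. `latents`, `prompt_embeds`) from the local scope by name;
# only names listed in `_callback_tensor_inputs` are expected to be valid here.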
for k in callback_on_step_end_tensor_inputs:
|
||||
callback_kwargs[k] = locals()[k]
|
||||
callback_on_step_end(self, i, t, callback_kwargs)
|
||||
|
||||
# Update the progress bar on scheduler-order boundaries and at the final step
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
|
||||
# 8. Post-processing
|
||||
if output_type == "latent" or (output_type == "pt" and self.vae is None):
|
||||
image = latents
|
||||
else:
|
||||
# Unscale latents for VAE (supports both AutoencoderKL and AutoencoderDC)
|
||||
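# This inverts the encode-time latent transform (z = (x - shift_factor) * scaling_factor),
# falling back to SD-style defaults when the VAE config does not define these attributes.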
scaling_factor = getattr(self.vae.config, "scaling_factor", 0.18215)
|
||||
shift_factor = getattr(self.vae.config, "shift_factor", 0.0)
|
||||
latents = (latents / scaling_factor) + shift_factor
|
||||
# Decode using VAE (AutoencoderKL or AutoencoderDC)
|
||||
image = self.vae.decode(latents, return_dict=False)[0]
|
||||
# Resize back to original resolution if using binning
|
||||
if use_resolution_binning:
|
||||
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
|
||||
|
||||
# Use standard image processor for post-processing
|
||||
image = self.image_processor.postprocess(image, output_type=output_type)
|
||||
|
||||
# Offload all models
|
||||
self.maybe_free_model_hooks()
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return PRXPipelineOutput(images=image)
|
||||
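As a quick illustration of the callback hooks documented above, here is a hedged sketch; the checkpoint id is a placeholder and the prompt and sampling settings are arbitrary. It requests only `latents`, which the fast tests below also use as a callback tensor input:

```python
import torch
from diffusers import PRXPipeline

# Placeholder repo id, substitute a real PRX checkpoint.
pipe = PRXPipeline.from_pretrained("your-org/prx-checkpoint", torch_dtype=torch.bfloat16).to("cuda")

def log_latents(pipe, step, timestep, callback_kwargs):
    # Only tensors listed in `pipe._callback_tensor_inputs` may be requested.
    latents = callback_kwargs["latents"]
    print(f"step {step}: latents mean {latents.mean().item():.4f}")
    return callback_kwargs

image = pipe(
    "A watercolor fox in a snowy forest",
    num_inference_steps=28,
    guidance_scale=5.0,
    callback_on_step_end=log_latents,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]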
@@ -1128,6 +1128,21 @@ class PriorTransformer(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class PRXTransformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class QwenImageControlNetModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
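For context, these dummy classes keep `import diffusers` working when PyTorch is not installed; touching them raises an informative error instead of failing at import time. A minimal sketch of the expected behaviour, assuming an environment without torch:

```python
# Hypothetical session in an environment where torch is missing.
from diffusers.utils import dummy_pt_objects

try:
    dummy_pt_objects.PRXTransformer2DModel()
except ImportError as err:  # requires_backends raises an ImportError naming the missing backend
    print(err)
```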
@@ -1907,6 +1907,21 @@ class PixArtSigmaPipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class PRXPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch", "transformers"])
|
||||
|
||||
|
||||
class QwenImageControlNetInpaintPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch", "transformers"]
|
||||
|
||||
|
||||
83
tests/models/transformers/test_models_transformer_prx.py
Normal file
83
tests/models/transformers/test_models_transformer_prx.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2025 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
|
||||
|
||||
from ...testing_utils import enable_full_determinism, torch_device
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
class PRXTransformerTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = PRXTransformer2DModel
|
||||
main_input_name = "hidden_states"
|
||||
uses_custom_attn_processor = True
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
return self.prepare_dummy_input()
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (16, 16, 16)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (16, 16, 16)
|
||||
|
||||
def prepare_dummy_input(self, height=16, width=16):
|
||||
batch_size = 1
|
||||
num_latent_channels = 16
|
||||
sequence_length = 16
|
||||
embedding_dim = 1792
|
||||
|
||||
hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(torch_device)
|
||||
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
|
||||
timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
|
||||
|
||||
return {
|
||||
"hidden_states": hidden_states,
|
||||
"timestep": timestep,
|
||||
"encoder_hidden_states": encoder_hidden_states,
|
||||
}
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"in_channels": 16,
|
||||
"patch_size": 2,
|
||||
"context_in_dim": 1792,
|
||||
"hidden_size": 1792,
|
||||
"mlp_ratio": 3.5,
|
||||
"num_heads": 28,
|
||||
"depth": 4, # Smaller depth for testing
|
||||
"axes_dim": [32, 32],
|
||||
"theta": 10_000,
|
||||
}
|
||||
inputs_dict = self.prepare_dummy_input()
|
||||
return init_dict, inputs_dict
|
||||
|
||||
def test_gradient_checkpointing_is_applied(self):
|
||||
expected_set = {"PRXTransformer2DModel"}
|
||||
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
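A hedged sketch of what the fixture above exercises: instantiating the tiny configuration from `prepare_init_args_and_inputs_for_common` and running a single forward pass with tensors shaped like `prepare_dummy_input`:

```python
import torch
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel

model = PRXTransformer2DModel(
    in_channels=16,
    patch_size=2,
    context_in_dim=1792,
    hidden_size=1792,
    mlp_ratio=3.5,
    num_heads=28,
    depth=4,  # smaller depth for testing
    axes_dim=[32, 32],
    theta=10_000,
)

hidden_states = torch.randn(1, 16, 16, 16)        # (batch, latent channels, height, width)
encoder_hidden_states = torch.randn(1, 16, 1792)  # (batch, sequence length, embedding dim)
timestep = torch.tensor([1.0])

with torch.no_grad():
    sample = model(
        hidden_states=hidden_states,
        timestep=timestep,
        encoder_hidden_states=encoder_hidden_states,
        return_dict=False,
    )[0]

print(sample.shape)  # expected to match the input latent shape, (1, 16, 16, 16)
```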
0
tests/pipelines/prx/__init__.py
Normal file
0
tests/pipelines/prx/__init__.py
Normal file
265
tests/pipelines/prx/test_pipeline_prx.py
Normal file
265
tests/pipelines/prx/test_pipeline_prx.py
Normal file
@@ -0,0 +1,265 @@
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.t5gemma.configuration_t5gemma import T5GemmaConfig, T5GemmaModuleConfig
|
||||
from transformers.models.t5gemma.modeling_t5gemma import T5GemmaEncoder
|
||||
|
||||
from diffusers.models import AutoencoderDC, AutoencoderKL
|
||||
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
|
||||
from diffusers.pipelines.prx.pipeline_prx import PRXPipeline
|
||||
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
|
||||
from diffusers.utils import is_transformers_version
|
||||
|
||||
from ..pipeline_params import TEXT_TO_IMAGE_PARAMS
|
||||
from ..test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
@pytest.mark.xfail(
|
||||
condition=is_transformers_version(">", "4.57.1"),
|
||||
reason="See https://github.com/huggingface/diffusers/pull/12456#issuecomment-3424228544",
|
||||
strict=False,
|
||||
)
|
||||
class PRXPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
pipeline_class = PRXPipeline
|
||||
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
|
||||
batch_params = frozenset(["prompt", "negative_prompt", "num_images_per_prompt"])
|
||||
test_xformers_attention = False
|
||||
test_layerwise_casting = True
|
||||
test_group_offloading = True
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
# Ensure PRXPipeline has an _execution_device property expected by __call__
|
||||
if not isinstance(getattr(PRXPipeline, "_execution_device", None), property):
|
||||
try:
|
||||
setattr(PRXPipeline, "_execution_device", property(lambda self: torch.device("cpu")))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
transformer = PRXTransformer2DModel(
|
||||
patch_size=1,
|
||||
in_channels=4,
|
||||
context_in_dim=8,
|
||||
hidden_size=8,
|
||||
mlp_ratio=2.0,
|
||||
num_heads=2,
|
||||
depth=1,
|
||||
axes_dim=[2, 2],
|
||||
)
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae = AutoencoderKL(
|
||||
sample_size=32,
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
block_out_channels=(4,),
|
||||
layers_per_block=1,
|
||||
latent_channels=4,
|
||||
norm_num_groups=1,
|
||||
use_quant_conv=False,
|
||||
use_post_quant_conv=False,
|
||||
shift_factor=0.0,
|
||||
scaling_factor=1.0,
|
||||
).eval()
|
||||
|
||||
torch.manual_seed(0)
|
||||
scheduler = FlowMatchEulerDiscreteScheduler()
|
||||
|
||||
torch.manual_seed(0)
|
||||
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")
|
||||
tokenizer.model_max_length = 64
|
||||
|
||||
torch.manual_seed(0)
|
||||
|
||||
encoder_params = {
|
||||
"vocab_size": tokenizer.vocab_size,
|
||||
"hidden_size": 8,
|
||||
"intermediate_size": 16,
|
||||
"num_hidden_layers": 1,
|
||||
"num_attention_heads": 2,
|
||||
"num_key_value_heads": 1,
|
||||
"head_dim": 4,
|
||||
"max_position_embeddings": 64,
|
||||
"layer_types": ["full_attention"],
|
||||
"attention_bias": False,
|
||||
"attention_dropout": 0.0,
|
||||
"dropout_rate": 0.0,
|
||||
"hidden_activation": "gelu_pytorch_tanh",
|
||||
"rms_norm_eps": 1e-06,
|
||||
"attn_logit_softcapping": 50.0,
|
||||
"final_logit_softcapping": 30.0,
|
||||
"query_pre_attn_scalar": 4,
|
||||
"rope_theta": 10000.0,
|
||||
"sliding_window": 4096,
|
||||
}
|
||||
encoder_config = T5GemmaModuleConfig(**encoder_params)
|
||||
text_encoder_config = T5GemmaConfig(encoder=encoder_config, is_encoder_decoder=False, **encoder_params)
|
||||
text_encoder = T5GemmaEncoder(text_encoder_config)
|
||||
|
||||
return {
|
||||
"transformer": transformer,
|
||||
"vae": vae,
|
||||
"scheduler": scheduler,
|
||||
"text_encoder": text_encoder,
|
||||
"tokenizer": tokenizer,
|
||||
}
|
||||
|
||||
def get_dummy_inputs(self, device, seed=0):
|
||||
if str(device).startswith("mps"):
|
||||
generator = torch.manual_seed(seed)
|
||||
else:
|
||||
generator = torch.Generator(device=device).manual_seed(seed)
|
||||
return {
|
||||
"prompt": "",
|
||||
"negative_prompt": "",
|
||||
"generator": generator,
|
||||
"num_inference_steps": 2,
|
||||
"guidance_scale": 1.0,
|
||||
"height": 32,
|
||||
"width": 32,
|
||||
"output_type": "pt",
|
||||
"use_resolution_binning": False,
|
||||
}
|
||||
|
||||
def test_inference(self):
|
||||
device = "cpu"
|
||||
components = self.get_dummy_components()
|
||||
pipe = PRXPipeline(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
try:
|
||||
pipe.register_to_config(_execution_device="cpu")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs)[0]
|
||||
generated_image = image[0]
|
||||
|
||||
self.assertEqual(generated_image.shape, (3, 32, 32))
|
||||
expected_image = torch.zeros(3, 32, 32)
|
||||
max_diff = np.abs(generated_image - expected_image).max()
|
||||
self.assertLessEqual(max_diff, 1e10)
|
||||
|
||||
def test_callback_inputs(self):
|
||||
components = self.get_dummy_components()
|
||||
pipe = PRXPipeline(**components)
|
||||
pipe = pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
try:
|
||||
pipe.register_to_config(_execution_device="cpu")
|
||||
except Exception:
|
||||
pass
|
||||
self.assertTrue(
|
||||
hasattr(pipe, "_callback_tensor_inputs"),
|
||||
f" {PRXPipeline} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
|
||||
)
|
||||
|
||||
def callback_inputs_subset(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in callback_kwargs.keys():
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
return callback_kwargs
|
||||
|
||||
def callback_inputs_all(pipe, i, t, callback_kwargs):
|
||||
for tensor_name in pipe._callback_tensor_inputs:
|
||||
assert tensor_name in callback_kwargs
|
||||
for tensor_name in callback_kwargs.keys():
|
||||
assert tensor_name in pipe._callback_tensor_inputs
|
||||
return callback_kwargs
|
||||
|
||||
inputs = self.get_dummy_inputs("cpu")
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_subset
|
||||
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
|
||||
_ = pipe(**inputs)[0]
|
||||
|
||||
inputs["callback_on_step_end"] = callback_inputs_all
|
||||
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
|
||||
_ = pipe(**inputs)[0]
|
||||
|
||||
def test_attention_slicing_forward_pass(self, expected_max_diff=1e-3):
|
||||
if not self.test_attention_slicing:
|
||||
return
|
||||
|
||||
components = self.get_dummy_components()
|
||||
pipe = self.pipeline_class(**components)
|
||||
for component in pipe.components.values():
|
||||
if hasattr(component, "set_default_attn_processor"):
|
||||
component.set_default_attn_processor()
|
||||
pipe.to("cpu")
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
def to_np_local(tensor):
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
return tensor.detach().cpu().numpy()
|
||||
return tensor
|
||||
|
||||
generator_device = "cpu"
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_without_slicing = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=1)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing1 = pipe(**inputs)[0]
|
||||
|
||||
pipe.enable_attention_slicing(slice_size=2)
|
||||
inputs = self.get_dummy_inputs(generator_device)
|
||||
output_with_slicing2 = pipe(**inputs)[0]
|
||||
|
||||
max_diff1 = np.abs(to_np_local(output_with_slicing1) - to_np_local(output_without_slicing)).max()
|
||||
max_diff2 = np.abs(to_np_local(output_with_slicing2) - to_np_local(output_without_slicing)).max()
|
||||
self.assertLess(max(max_diff1, max_diff2), expected_max_diff)
|
||||
|
||||
def test_inference_with_autoencoder_dc(self):
|
||||
"""Test PRXPipeline with AutoencoderDC (DCAE) instead of AutoencoderKL."""
|
||||
device = "cpu"
|
||||
|
||||
components = self.get_dummy_components()
|
||||
|
||||
torch.manual_seed(0)
|
||||
vae_dc = AutoencoderDC(
|
||||
in_channels=3,
|
||||
latent_channels=4,
|
||||
attention_head_dim=2,
|
||||
encoder_block_types=(
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
),
|
||||
decoder_block_types=(
|
||||
"ResBlock",
|
||||
"EfficientViTBlock",
|
||||
),
|
||||
encoder_block_out_channels=(8, 8),
|
||||
decoder_block_out_channels=(8, 8),
|
||||
encoder_qkv_multiscales=((), (5,)),
|
||||
decoder_qkv_multiscales=((), (5,)),
|
||||
encoder_layers_per_block=(1, 1),
|
||||
decoder_layers_per_block=(1, 1),
|
||||
upsample_block_type="interpolate",
|
||||
downsample_block_type="stride_conv",
|
||||
decoder_norm_types="rms_norm",
|
||||
decoder_act_fns="silu",
|
||||
).eval()
|
||||
|
||||
components["vae"] = vae_dc
|
||||
|
||||
pipe = PRXPipeline(**components)
|
||||
pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
expected_scale_factor = vae_dc.spatial_compression_ratio
|
||||
self.assertEqual(pipe.vae_scale_factor, expected_scale_factor)
|
||||
|
||||
inputs = self.get_dummy_inputs(device)
|
||||
image = pipe(**inputs)[0]
|
||||
generated_image = image[0]
|
||||
|
||||
self.assertEqual(generated_image.shape, (3, 32, 32))
|
||||
expected_image = torch.zeros(3, 32, 32)
|
||||
max_diff = np.abs(generated_image - expected_image).max()
|
||||
self.assertLessEqual(max_diff, 1e10)
|
||||
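To try the fast tests above in isolation, one option (a sketch, assuming the diffusers repository root is on `sys.path`) is to drive them through `unittest` directly:

```python
import unittest

# Load just the smoke test defined above; the dotted path assumes the repo's test layout.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "tests.pipelines.prx.test_pipeline_prx.PRXPipelineFastTests.test_inference"
)
unittest.TextTestRunner(verbosity=2).run(suite)
```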
@@ -29,7 +29,7 @@ The benchmark results for Flux and CogVideoX can be found in [this](https://gith
|
||||
The tests, and the expected slices, were obtained from the `aws-g6e-xlarge-plus` GPU test runners. To run the slow tests, use the following command or an equivalent:
|
||||
|
||||
```bash
|
||||
HF_HUB_ENABLE_HF_TRANSFER=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests
|
||||
HF_XET_HIGH_PERFORMANCE=1 RUN_SLOW=1 pytest -s tests/quantization/torchao/test_torchao.py::SlowTorchAoTests
|
||||
```
|
||||
|
||||
`diffusers-cli`:
|
||||
|
||||