mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

rename photon to prx

DavidBert
2025-10-21 15:38:55 +00:00
parent cefc2cf82d
commit b7fb0fe9d6
12 changed files with 88 additions and 66 deletions

.gitignore

@@ -178,4 +178,26 @@ tags
 .ruff_cache
 # wandb
-wandb
+wandb
+convert_checkpoints.py
+dcae_mirage_generated_image_.png
+dcae_prx_generated_image.png
+example_usage.py
+META_TENSOR_FIX.md
+mirage_generated_image__.png
+mirage_generated_image_.png
+prx_generated_image.png
+plan.md
+test_existing_checkpoints_with_timestep_change.py
+test_timestep_embedding.py
+test_updated_checkpoint.png
+test_updated_checkpoint.py
+testhf.ipynb
+update_checkpoint_parameters.py
+verify_checkpoint_parameters.py
+for_claude/mirage_layers.py
+for_claude/mirage.py
+for_claude/text_tower.py
+for_claude/vae_tower.py
+prx_/prx_layers.py
+prx_/prx.py

docs/source/en/_toctree.yml

@@ -541,8 +541,8 @@
   title: PAG
 - local: api/pipelines/paint_by_example
   title: Paint by Example
-- local: api/pipelines/photon
-  title: Photon
+- local: api/pipelines/prx
+  title: PRX
 - local: api/pipelines/pixart
   title: PixArt-α
 - local: api/pipelines/pixart_sigma

docs/source/en/api/pipelines/{photon.md → prx.md}

@@ -12,43 +12,43 @@
 # See the License for the specific language governing permissions and
 # limitations under the License. -->

-# Photon
+# PRX

-Photon generates high-quality images from text using a simplified MMDIT architecture in which text tokens are not updated through the transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B-parameter transformer delivers fast inference without sacrificing quality. You can choose between the Flux VAE (8x compression, 16 latent channels) for balanced quality and speed, or DC-AE (32x compression, 32 latent channels) for stronger latent compression and faster processing.
+PRX generates high-quality images from text using a simplified MMDIT architecture in which text tokens are not updated through the transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B-parameter transformer delivers fast inference without sacrificing quality. You can choose between the Flux VAE (8x compression, 16 latent channels) for balanced quality and speed, or DC-AE (32x compression, 32 latent channels) for stronger latent compression and faster processing.

 ## Available models

-Photon offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.
+PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.

 | Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
 |:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
-| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i) | 256 | No | No | Base model pre-trained at 256 with Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
-| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft) | 256 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts | 28 steps, cfg=5.0 | `torch.bfloat16` |
-| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i) | 512 | No | No | Base model pre-trained at 512 with Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
-| [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
-| [`Photoroom/photon-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-sft-distilled) | 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |
-| [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae) | 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
-| [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
-| [`Photoroom/photon-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft-distilled) | 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |
+| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i) | 256 | No | No | Base model pre-trained at 256 with Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
+| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft) | 256 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts | 28 steps, cfg=5.0 | `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i) | 512 | No | No | Base model pre-trained at 512 with Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled) | 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae) | 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `torch.bfloat16` |
+| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `torch.bfloat16` |

-Refer to [this collection](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) for more information.
+Refer to [this collection](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) for more information.

 ## Loading the pipeline

 Load the pipeline with [`~DiffusionPipeline.from_pretrained`].

 ```py
 import torch
-from diffusers.pipelines.photon import PhotonPipeline
+from diffusers.pipelines.prx import PRXPipeline

 # Load the pipeline - the VAE and text encoder are downloaded from the Hugging Face Hub
-pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
+pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
 pipe.to("cuda")

 prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
 image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
-image.save("photon_output.png")
+image.save("prx_output.png")
 ```
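The table above suggests the distilled checkpoints drop to 8 steps with guidance effectively disabled; loading one is otherwise identical. A minimal sketch, assuming the `Photoroom/prx-512-t2i-sft-distilled` weights from the table:

```py
import torch
from diffusers.pipelines.prx import PRXPipeline

# Distilled variant: the table suggests 8 steps with cfg=1.0 (guidance off).
pipe = PRXPipeline.from_pretrained(
    "Photoroom/prx-512-t2i-sft-distilled", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

image = pipe(
    "A front-facing portrait of a lion in the golden savanna at sunset.",
    num_inference_steps=8,
    guidance_scale=1.0,
).images[0]
image.save("prx_distilled_output.png")
```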
### Manual Component Loading
@@ -57,9 +57,9 @@ Load components individually to customize the pipeline for instance to use quant
 ```py
 import torch
-from diffusers.pipelines.photon import PhotonPipeline
+from diffusers.pipelines.prx import PRXPipeline
 from diffusers.models import AutoencoderKL, AutoencoderDC
-from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 from transformers import T5GemmaModel, GemmaTokenizerFast
 from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
@@ -67,8 +67,8 @@ from transformers import BitsAndBytesConfig as BitsAndBytesConfig
 quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)

 # Load transformer
-transformer = PhotonTransformer2DModel.from_pretrained(
-    "checkpoints/photon-512-t2i-sft",
+transformer = PRXTransformer2DModel.from_pretrained(
+    "checkpoints/prx-512-t2i-sft",
     subfolder="transformer",
     quantization_config=quant_config,
     torch_dtype=torch.bfloat16,
@@ -76,7 +76,7 @@ transformer = PhotonTransformer2DModel.from_pretrained(
 # Load scheduler
 scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
-    "checkpoints/photon-512-t2i-sft", subfolder="scheduler"
+    "checkpoints/prx-512-t2i-sft", subfolder="scheduler"
 )

 # Load T5Gemma text encoder
@@ -94,7 +94,7 @@ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev",
     quantization_config=quant_config,
     torch_dtype=torch.bfloat16)

-pipe = PhotonPipeline(
+pipe = PRXPipeline(
     transformer=transformer,
     scheduler=scheduler,
     text_encoder=text_encoder,
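The hunk cuts off mid-constructor; based on the components loaded above and the `model_index` entries later in this commit, the assembly plausibly finishes as below (`tokenizer` is assumed to come from `GemmaTokenizerFast.from_pretrained` in the elided lines):

```py
pipe = PRXPipeline(
    transformer=transformer,
    scheduler=scheduler,
    text_encoder=text_encoder,
    tokenizer=tokenizer,  # assumed: GemmaTokenizerFast, loaded in elided lines
    vae=vae,
)
pipe.to("cuda")
```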
@@ -111,21 +111,21 @@ For memory-constrained environments:
 ```py
 import torch
-from diffusers.pipelines.photon import PhotonPipeline
+from diffusers.pipelines.prx import PRXPipeline

-pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
+pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
 pipe.enable_model_cpu_offload()  # Offload components to CPU when not in use

 # Or use sequential CPU offload for even lower memory
 pipe.enable_sequential_cpu_offload()
 ```

-## PhotonPipeline
+## PRXPipeline

-[[autodoc]] PhotonPipeline
+[[autodoc]] PRXPipeline
   - all
   - __call__

-## PhotonPipelineOutput
+## PRXPipelineOutput

-[[autodoc]] pipelines.photon.pipeline_output.PhotonPipelineOutput
+[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput
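A quick way to see what the two offload modes above save is to use plain PyTorch CUDA memory counters (nothing PRX-specific; peak numbers vary by checkpoint and resolution):

```py
import torch
from diffusers.pipelines.prx import PRXPipeline

pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # swap in enable_sequential_cpu_offload() to compare

torch.cuda.reset_peak_memory_stats()
image = pipe(
    "A front-facing portrait of a lion in the golden savanna at sunset.",
    num_inference_steps=28,
    guidance_scale=5.0,
).images[0]

# Peak VRAM actually allocated during the run, in GiB.
print(f"peak VRAM: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GiB")
```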


@@ -1,6 +1,6 @@
 #!/usr/bin/env python3
 """
-Script to convert Photon checkpoint from original codebase to diffusers format.
+Script to convert PRX checkpoint from original codebase to diffusers format.
 """

 import argparse
@@ -13,15 +13,15 @@ from typing import Dict, Tuple
 import torch
 from safetensors.torch import save_file

-from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
-from diffusers.pipelines.photon import PhotonPipeline
+from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
+from diffusers.pipelines.prx import PRXPipeline

 DEFAULT_RESOLUTION = 512

 @dataclass(frozen=True)
-class PhotonBase:
+class PRXBase:
     context_in_dim: int = 2304
     hidden_size: int = 1792
     mlp_ratio: float = 3.5
@@ -34,22 +34,22 @@ class PhotonBase:

 @dataclass(frozen=True)
-class PhotonFlux(PhotonBase):
+class PRXFlux(PRXBase):
     in_channels: int = 16
     patch_size: int = 2

 @dataclass(frozen=True)
-class PhotonDCAE(PhotonBase):
+class PRXDCAE(PRXBase):
     in_channels: int = 32
     patch_size: int = 1

 def build_config(vae_type: str) -> Tuple[dict, int]:
     if vae_type == "flux":
-        cfg = PhotonFlux()
+        cfg = PRXFlux()
     elif vae_type == "dc-ae":
-        cfg = PhotonDCAE()
+        cfg = PRXDCAE()
     else:
         raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
@@ -64,7 +64,7 @@ def create_parameter_mapping(depth: int) -> dict:
     # Key mappings for structural changes
     mapping = {}

-    # Map old structure (layers in PhotonBlock) to new structure (layers in PhotonAttention)
+    # Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention)
     for i in range(depth):
         # QKV projections moved to attention module
         mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
@@ -108,8 +108,8 @@ def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth
     return converted_state_dict

-def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PhotonTransformer2DModel:
-    """Create and load PhotonTransformer2DModel from old checkpoint."""
+def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
+    """Create and load PRXTransformer2DModel from old checkpoint."""
     print(f"Loading checkpoint from: {checkpoint_path}")
@@ -137,8 +137,8 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Ph
     converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)

     # Create transformer with config
-    print("Creating PhotonTransformer2DModel...")
-    transformer = PhotonTransformer2DModel(**config)
+    print("Creating PRXTransformer2DModel...")
+    transformer = PRXTransformer2DModel(**config)

     # Load state dict
     print("Loading converted parameters...")
@@ -221,14 +221,14 @@ def create_model_index(vae_type: str, default_image_size: int, output_path: str)
         vae_class = "AutoencoderDC"

     model_index = {
-        "_class_name": "PhotonPipeline",
+        "_class_name": "PRXPipeline",
         "_diffusers_version": "0.31.0.dev0",
         "_name_or_path": os.path.basename(output_path),
         "default_sample_size": default_image_size,
         "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
-        "text_encoder": ["photon", "T5GemmaEncoder"],
+        "text_encoder": ["prx", "T5GemmaEncoder"],
         "tokenizer": ["transformers", "GemmaTokenizerFast"],
-        "transformer": ["diffusers", "PhotonTransformer2DModel"],
+        "transformer": ["diffusers", "PRXTransformer2DModel"],
         "vae": ["diffusers", vae_class],
     }
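`DiffusionPipeline.from_pretrained` resolves the pipeline class from `_class_name`, so this dict has to end up as `model_index.json` at the root of the output directory. A minimal sketch of that write step (the commit's actual serialization code is not shown in this hunk):

```py
import json
import os

def write_model_index(model_index: dict, output_path: str) -> None:
    # from_pretrained() looks for model_index.json at the repo/directory root.
    with open(os.path.join(output_path, "model_index.json"), "w") as f:
        json.dump(model_index, f, indent=2)
```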
@@ -275,7 +275,7 @@ def main(args):
     # Verify the pipeline can be loaded
     try:
-        pipeline = PhotonPipeline.from_pretrained(args.output_path)
+        pipeline = PRXPipeline.from_pretrained(args.output_path)
         print("Pipeline loaded successfully!")
         print(f"Transformer: {type(pipeline.transformer).__name__}")
         print(f"VAE: {type(pipeline.vae).__name__}")
@@ -298,10 +298,10 @@
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format")
+    parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")
     parser.add_argument(
-        "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file)"
+        "--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file)"
     )
     parser.add_argument(

src/diffusers/__init__.py

@@ -232,7 +232,7 @@ else:
         "MultiControlNetModel",
         "OmniGenTransformer2DModel",
         "ParallelConfig",
-        "PhotonTransformer2DModel",
+        "PRXTransformer2DModel",
         "PixArtTransformer2DModel",
         "PriorTransformer",
         "QwenImageControlNetModel",

@@ -516,7 +516,7 @@ else:
         "MusicLDMPipeline",
         "OmniGenPipeline",
         "PaintByExamplePipeline",
-        "PhotonPipeline",
+        "PRXPipeline",
         "PIAPipeline",
         "PixArtAlphaPipeline",
         "PixArtSigmaPAGPipeline",

@@ -928,7 +928,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         MultiControlNetModel,
         OmniGenTransformer2DModel,
         ParallelConfig,
-        PhotonTransformer2DModel,
+        PRXTransformer2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
         QwenImageControlNetModel,

@@ -1182,7 +1182,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         MusicLDMPipeline,
         OmniGenPipeline,
         PaintByExamplePipeline,
-        PhotonPipeline,
+        PRXPipeline,
         PIAPipeline,
         PixArtAlphaPipeline,
         PixArtSigmaPAGPipeline,
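With the export lists and the TYPE_CHECKING imports both updated, the renamed classes should resolve identically through the lazy top-level namespace and the explicit submodule paths; a quick sanity check, assuming a build with this commit applied:

```py
# Top-level names registered in src/diffusers/__init__.py (lazily resolved)
from diffusers import PRXPipeline, PRXTransformer2DModel

# Explicit submodule path, as used in the docs earlier in this commit
from diffusers.pipelines.prx import PRXPipeline as PRXPipelineDirect

# Both routes should yield the same class object.
assert PRXPipeline is PRXPipelineDirect
```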

src/diffusers/models/__init__.py

@@ -96,7 +96,7 @@ if is_torch_available():
     _import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
     _import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
     _import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
-    _import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"]
+    _import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
     _import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
     _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]

@@ -191,7 +191,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         LuminaNextDiT2DModel,
         MochiTransformer3DModel,
         OmniGenTransformer2DModel,
-        PhotonTransformer2DModel,
+        PRXTransformer2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
         QwenImageTransformer2DModel,

src/diffusers/models/transformers/__init__.py

@@ -32,7 +32,7 @@ if is_torch_available():
     from .transformer_lumina2 import Lumina2Transformer2DModel
     from .transformer_mochi import MochiTransformer3DModel
     from .transformer_omnigen import OmniGenTransformer2DModel
-    from .transformer_photon import PhotonTransformer2DModel
+    from .transformer_prx import PRXTransformer2DModel
     from .transformer_qwenimage import QwenImageTransformer2DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel

src/diffusers/pipelines/__init__.py

@@ -144,7 +144,7 @@ else:
         "FluxKontextPipeline",
         "FluxKontextInpaintPipeline",
     ]
-    _import_structure["photon"] = ["PhotonPipeline"]
+    _import_structure["prx"] = ["PRXPipeline"]
     _import_structure["audioldm"] = ["AudioLDMPipeline"]
     _import_structure["audioldm2"] = [
         "AudioLDM2Pipeline",

@@ -718,7 +718,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         StableDiffusionXLPAGPipeline,
     )
     from .paint_by_example import PaintByExamplePipeline
-    from .photon import PhotonPipeline
+    from .prx import PRXPipeline
     from .pia import PIAPipeline
     from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
     from .qwenimage import (

src/diffusers/pipelines/{photon → prx}/__init__.py

@@ -12,7 +12,7 @@ from ...utils import (
 _dummy_objects = {}
 _additional_imports = {}
-_import_structure = {"pipeline_output": ["PhotonPipelineOutput"]}
+_import_structure = {"pipeline_output": ["PRXPipelineOutput"]}

 try:
     if not (is_transformers_available() and is_torch_available()):

@@ -22,7 +22,7 @@ except OptionalDependencyNotAvailable:
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
-    _import_structure["pipeline_photon"] = ["PhotonPipeline"]
+    _import_structure["pipeline_prx"] = ["PRXPipeline"]

     # Import T5GemmaEncoder for pipeline loading compatibility
     try:

@@ -44,8 +44,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     except OptionalDependencyNotAvailable:
         from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
     else:
-        from .pipeline_output import PhotonPipelineOutput
-        from .pipeline_photon import PhotonPipeline
+        from .pipeline_output import PRXPipelineOutput
+        from .pipeline_prx import PRXPipeline
 else:
     import sys
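The `else: import sys` tail is the standard diffusers lazy-module idiom; for context, it typically continues as below (reconstructed from that convention, not shown in this diff):

```py
import sys

from diffusers.utils import _LazyModule

# Replace this module with a proxy that imports PRXPipeline /
# PRXPipelineOutput only on first attribute access.
sys.modules[__name__] = _LazyModule(
    __name__,
    globals()["__file__"],
    _import_structure,
    module_spec=__spec__,
)
```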

src/diffusers/pipelines/{photon → prx}/pipeline_output.py

@@ -22,9 +22,9 @@ from ...utils import BaseOutput

 @dataclass
-class PhotonPipelineOutput(BaseOutput):
+class PRXPipelineOutput(BaseOutput):
     """
-    Output class for Photon pipelines.
+    Output class for PRX pipelines.

     Args:
         images (`List[PIL.Image.Image]` or `np.ndarray`)

src/diffusers/utils/dummy_pt_objects.py

@@ -1098,7 +1098,7 @@ class ParallelConfig(metaclass=DummyObject):
         requires_backends(cls, ["torch"])

-class PhotonTransformer2DModel(metaclass=DummyObject):
+class PRXTransformer2DModel(metaclass=DummyObject):
     _backends = ["torch"]

     def __init__(self, *args, **kwargs):

src/diffusers/utils/dummy_torch_and_transformers_objects.py

@@ -1847,7 +1847,7 @@ class PaintByExamplePipeline(metaclass=DummyObject):
         requires_backends(cls, ["torch", "transformers"])

-class PhotonPipeline(metaclass=DummyObject):
+class PRXPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]

     def __init__(self, *args, **kwargs):
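These dummy classes exist so that `from diffusers import PRXPipeline` still succeeds when torch or transformers is missing, failing with a clear error only on use. A small illustration, assuming an environment without those backends:

```py
# In an environment without torch/transformers, the top-level import still
# works, and instantiation raises a helpful error instead:
from diffusers import PRXPipeline  # resolves to the dummy class

try:
    PRXPipeline()
except ImportError as err:
    # requires_backends() names the missing backends in the message.
    print(err)
```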