mirror of https://github.com/huggingface/diffusers.git
rename photon to prx
.gitignore (vendored)
@@ -178,4 +178,26 @@ tags
.ruff_cache

# wandb
wandb
wandb
convert_checkpoints.py
dcae_mirage_generated_image_.png
dcae_prx_generated_image.png
example_usage.py
META_TENSOR_FIX.md
mirage_generated_image__.png
mirage_generated_image_.png
prx_generated_image.png
plan.md
test_existing_checkpoints_with_timestep_change.py
test_timestep_embedding.py
test_updated_checkpoint.png
test_updated_checkpoint.py
testhf.ipynb
update_checkpoint_parameters.py
verify_checkpoint_parameters.py
for_claude/mirage_layers.py
for_claude/mirage.py
for_claude/text_tower.py
for_claude/vae_tower.py
prx_/prx_layers.py
prx_/prx.py

@@ -541,8 +541,8 @@
title: PAG
- local: api/pipelines/paint_by_example
title: Paint by Example
- local: api/pipelines/photon
title: Photon
- local: api/pipelines/prx
title: PRX
- local: api/pipelines/pixart
title: PixArt-α
- local: api/pipelines/pixart_sigma
@@ -12,43 +12,43 @@
# See the License for the specific language governing permissions and
# limitations under the License. -->

# Photon
# PRX

Photon generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing.
PRX generates high-quality images from text using a simplified MMDIT architecture where text tokens don't update through transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B parameter transformer delivers fast inference without sacrificing quality. You can choose between Flux VAE (8x compression, 16 latent channels) for balanced quality and speed or DC-AE (32x compression, 32 latent channels) for latent compression and faster processing.

## Available models

Photon offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.
PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.

| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:-----------------:|:----------:|:----------:|:----------:|:----------:|:----------:|:----------:|
| [`Photoroom/photon-256-t2i`](https://huggingface.co/Photoroom/photon-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-256-t2i-sft`](https://huggingface.co/Photoroom/photon-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i`](https://huggingface.co/Photoroom/photon-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-sft`](https://huggingface.co/Photoroom/photon-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/photon-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/photon-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/photon-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i)| 256 | No | No | Base model pre-trained at 256 with Flux VAE|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i)| 512 | No | No | Base model pre-trained at 512 with Flux VAE |Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with Flux VAE | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae)| 512 | No | No | Base model pre-trained at 512 with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae)|Works best with detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft)| 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language|28 steps, cfg=5.0| `torch.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled)| 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language|8 steps, cfg=1.0| `torch.bfloat16` |

Refer to [this](https://huggingface.co/collections/Photoroom/photon-models-68e66254c202ebfab99ad38e) collection for more information.
Refer to [this](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) collection for more information.
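
The distilled checkpoints in the table are meant to be sampled with far fewer steps and with classifier-free guidance effectively disabled (8 steps, cfg=1.0). A minimal sketch of using one of them, with the model id and parameters taken from the table (the prompt is only illustrative):

```py
import torch
from diffusers.pipelines.prx import PRXPipeline

pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft-distilled", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Distilled variants: 8 steps and guidance_scale=1.0, per the table above
image = pipe(
    "A watercolor painting of a lighthouse at dawn",
    num_inference_steps=8,
    guidance_scale=1.0,
).images[0]
image.save("prx_distilled_output.png")
```
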
## Loading the pipeline

Load the pipeline with [`~DiffusionPipeline.from_pretrained`].

```py
import torch
from diffusers.pipelines.photon import PhotonPipeline
from diffusers.pipelines.prx import PRXPipeline

# Load pipeline - VAE and text encoder will be loaded from HuggingFace
pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("photon_output.png")
image.save("prx_output.png")
```

### Manual Component Loading
@@ -57,9 +57,9 @@ Load components individually to customize the pipeline for instance to use quant

```py
import torch
from diffusers.pipelines.photon import PhotonPipeline
from diffusers.pipelines.prx import PRXPipeline
from diffusers.models import AutoencoderKL, AutoencoderDC
from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from transformers import T5GemmaModel, GemmaTokenizerFast
from diffusers import BitsAndBytesConfig as DiffusersBitsAndBytesConfig
@@ -67,8 +67,8 @@ from transformers import BitsAndBytesConfig as BitsAndBytesConfig

quant_config = DiffusersBitsAndBytesConfig(load_in_8bit=True)
# Load transformer
transformer = PhotonTransformer2DModel.from_pretrained(
    "checkpoints/photon-512-t2i-sft",
transformer = PRXTransformer2DModel.from_pretrained(
    "checkpoints/prx-512-t2i-sft",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
@@ -76,7 +76,7 @@ transformer = PhotonTransformer2DModel.from_pretrained(

# Load scheduler
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
    "checkpoints/photon-512-t2i-sft", subfolder="scheduler"
    "checkpoints/prx-512-t2i-sft", subfolder="scheduler"
)

# Load T5Gemma text encoder
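
The text encoder and tokenizer loading that follows this comment is not shown in the diff. Purely as an illustration of what that elided step could look like with the imports listed earlier (the T5Gemma repo id below is an assumption, not taken from this commit):

```py
import torch
from transformers import T5GemmaModel, GemmaTokenizerFast

# Hedged sketch only: "google/t5gemma-2b-2b-ul2" is an assumed id for the T5Gemma-2B-2B-UL2 model the docs mention.
text_encoder = T5GemmaModel.from_pretrained("google/t5gemma-2b-2b-ul2", torch_dtype=torch.bfloat16)
tokenizer = GemmaTokenizerFast.from_pretrained("google/t5gemma-2b-2b-ul2")
```
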
@@ -94,7 +94,7 @@ vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16)

pipe = PhotonPipeline(
pipe = PRXPipeline(
    transformer=transformer,
    scheduler=scheduler,
    text_encoder=text_encoder,
@@ -111,21 +111,21 @@ For memory-constrained environments:

```py
import torch
from diffusers.pipelines.photon import PhotonPipeline
from diffusers.pipelines.prx import PRXPipeline

pipe = PhotonPipeline.from_pretrained("Photoroom/photon-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # Offload components to CPU when not in use

# Or use sequential CPU offload for even lower memory
pipe.enable_sequential_cpu_offload()
```

## PhotonPipeline
## PRXPipeline

[[autodoc]] PhotonPipeline
[[autodoc]] PRXPipeline
- all
- __call__

## PhotonPipelineOutput
## PRXPipelineOutput

[[autodoc]] pipelines.photon.pipeline_output.PhotonPipelineOutput
[[autodoc]] pipelines.prx.pipeline_output.PRXPipelineOutput

@@ -1,6 +1,6 @@
#!/usr/bin/env python3
"""
Script to convert Photon checkpoint from original codebase to diffusers format.
Script to convert PRX checkpoint from original codebase to diffusers format.
"""

import argparse
@@ -13,15 +13,15 @@ from typing import Dict, Tuple
import torch
from safetensors.torch import save_file

from diffusers.models.transformers.transformer_photon import PhotonTransformer2DModel
from diffusers.pipelines.photon import PhotonPipeline
from diffusers.models.transformers.transformer_prx import PRXTransformer2DModel
from diffusers.pipelines.prx import PRXPipeline


DEFAULT_RESOLUTION = 512


@dataclass(frozen=True)
class PhotonBase:
class PRXBase:
    context_in_dim: int = 2304
    hidden_size: int = 1792
    mlp_ratio: float = 3.5
@@ -34,22 +34,22 @@ class PhotonBase:

@dataclass(frozen=True)
class PhotonFlux(PhotonBase):
class PRXFlux(PRXBase):
    in_channels: int = 16
    patch_size: int = 2


@dataclass(frozen=True)
class PhotonDCAE(PhotonBase):
class PRXDCAE(PRXBase):
    in_channels: int = 32
    patch_size: int = 1


def build_config(vae_type: str) -> Tuple[dict, int]:
    if vae_type == "flux":
        cfg = PhotonFlux()
        cfg = PRXFlux()
    elif vae_type == "dc-ae":
        cfg = PhotonDCAE()
        cfg = PRXDCAE()
    else:
        raise ValueError(f"Unsupported VAE type: {vae_type}. Use 'flux' or 'dc-ae'")
@@ -64,7 +64,7 @@ def create_parameter_mapping(depth: int) -> dict:
    # Key mappings for structural changes
    mapping = {}

    # Map old structure (layers in PhotonBlock) to new structure (layers in PhotonAttention)
    # Map old structure (layers in PRXBlock) to new structure (layers in PRXAttention)
    for i in range(depth):
        # QKV projections moved to attention module
        mapping[f"blocks.{i}.img_qkv_proj.weight"] = f"blocks.{i}.attention.img_qkv_proj.weight"
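
Once such a mapping exists, the conversion itself is just a key-renaming pass over the checkpoint's state dict. A generic sketch of that idea (not the script's actual `convert_checkpoint_parameters`, whose body is elided from this diff):

```py
from typing import Dict

import torch


def remap_state_dict(state_dict: Dict[str, torch.Tensor], mapping: Dict[str, str]) -> Dict[str, torch.Tensor]:
    # Keys with an entry in `mapping` are renamed; everything else is carried over unchanged.
    return {mapping.get(key, key): tensor for key, tensor in state_dict.items()}
```
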
@@ -108,8 +108,8 @@ def convert_checkpoint_parameters(old_state_dict: Dict[str, torch.Tensor], depth
    return converted_state_dict


def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PhotonTransformer2DModel:
    """Create and load PhotonTransformer2DModel from old checkpoint."""
def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> PRXTransformer2DModel:
    """Create and load PRXTransformer2DModel from old checkpoint."""

    print(f"Loading checkpoint from: {checkpoint_path}")
@@ -137,8 +137,8 @@ def create_transformer_from_checkpoint(checkpoint_path: str, config: dict) -> Ph
    converted_state_dict = convert_checkpoint_parameters(state_dict, depth=model_depth)

    # Create transformer with config
    print("Creating PhotonTransformer2DModel...")
    transformer = PhotonTransformer2DModel(**config)
    print("Creating PRXTransformer2DModel...")
    transformer = PRXTransformer2DModel(**config)

    # Load state dict
    print("Loading converted parameters...")
@@ -221,14 +221,14 @@ def create_model_index(vae_type: str, default_image_size: int, output_path: str)
        vae_class = "AutoencoderDC"

    model_index = {
        "_class_name": "PhotonPipeline",
        "_class_name": "PRXPipeline",
        "_diffusers_version": "0.31.0.dev0",
        "_name_or_path": os.path.basename(output_path),
        "default_sample_size": default_image_size,
        "scheduler": ["diffusers", "FlowMatchEulerDiscreteScheduler"],
        "text_encoder": ["photon", "T5GemmaEncoder"],
        "text_encoder": ["prx", "T5GemmaEncoder"],
        "tokenizer": ["transformers", "GemmaTokenizerFast"],
        "transformer": ["diffusers", "PhotonTransformer2DModel"],
        "transformer": ["diffusers", "PRXTransformer2DModel"],
        "vae": ["diffusers", vae_class],
    }
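
This dictionary is the pipeline's `model_index.json`, the registry diffusers reads to know which library and class to instantiate for each component. A minimal sketch of persisting it next to the converted weights (the script's actual serialization code is not part of this hunk):

```py
import json
import os


def write_model_index(model_index: dict, output_path: str) -> None:
    # Standard diffusers layout: model_index.json sits at the root of the pipeline folder.
    with open(os.path.join(output_path, "model_index.json"), "w") as f:
        json.dump(model_index, f, indent=2)
```
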
@@ -275,7 +275,7 @@ def main(args):

    # Verify the pipeline can be loaded
    try:
        pipeline = PhotonPipeline.from_pretrained(args.output_path)
        pipeline = PRXPipeline.from_pretrained(args.output_path)
        print("Pipeline loaded successfully!")
        print(f"Transformer: {type(pipeline.transformer).__name__}")
        print(f"VAE: {type(pipeline.vae).__name__}")
@@ -298,10 +298,10 @@ def main(args):


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Convert Photon checkpoint to diffusers format")
    parser = argparse.ArgumentParser(description="Convert PRX checkpoint to diffusers format")

    parser.add_argument(
        "--checkpoint_path", type=str, required=True, help="Path to the original Photon checkpoint (.pth file)"
        "--checkpoint_path", type=str, required=True, help="Path to the original PRX checkpoint (.pth file)"
    )

    parser.add_argument(

@@ -232,7 +232,7 @@ else:
"MultiControlNetModel",
"OmniGenTransformer2DModel",
"ParallelConfig",
"PhotonTransformer2DModel",
"PRXTransformer2DModel",
"PixArtTransformer2DModel",
"PriorTransformer",
"QwenImageControlNetModel",
@@ -516,7 +516,7 @@ else:
"MusicLDMPipeline",
"OmniGenPipeline",
"PaintByExamplePipeline",
"PhotonPipeline",
"PRXPipeline",
"PIAPipeline",
"PixArtAlphaPipeline",
"PixArtSigmaPAGPipeline",
@@ -928,7 +928,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MultiControlNetModel,
OmniGenTransformer2DModel,
ParallelConfig,
PhotonTransformer2DModel,
PRXTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
QwenImageControlNetModel,
@@ -1182,7 +1182,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
MusicLDMPipeline,
OmniGenPipeline,
PaintByExamplePipeline,
PhotonPipeline,
PRXPipeline,
PIAPipeline,
PixArtAlphaPipeline,
PixArtSigmaPAGPipeline,
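
Taken together, these `__init__.py` changes swap the names exported from the package root, so downstream imports change accordingly; a minimal sketch:

```py
# New top-level names introduced by this rename
from diffusers import PRXPipeline, PRXTransformer2DModel

# The old names (PhotonPipeline, PhotonTransformer2DModel) are removed from the exports by this commit.
```
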
@@ -96,7 +96,7 @@ if is_torch_available():
_import_structure["transformers.transformer_lumina2"] = ["Lumina2Transformer2DModel"]
_import_structure["transformers.transformer_mochi"] = ["MochiTransformer3DModel"]
_import_structure["transformers.transformer_omnigen"] = ["OmniGenTransformer2DModel"]
_import_structure["transformers.transformer_photon"] = ["PhotonTransformer2DModel"]
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
@@ -191,7 +191,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
LuminaNextDiT2DModel,
MochiTransformer3DModel,
OmniGenTransformer2DModel,
PhotonTransformer2DModel,
PRXTransformer2DModel,
PixArtTransformer2DModel,
PriorTransformer,
QwenImageTransformer2DModel,

@@ -32,7 +32,7 @@ if is_torch_available():
from .transformer_lumina2 import Lumina2Transformer2DModel
from .transformer_mochi import MochiTransformer3DModel
from .transformer_omnigen import OmniGenTransformer2DModel
from .transformer_photon import PhotonTransformer2DModel
from .transformer_prx import PRXTransformer2DModel
from .transformer_qwenimage import QwenImageTransformer2DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel

@@ -144,7 +144,7 @@ else:
"FluxKontextPipeline",
"FluxKontextInpaintPipeline",
]
_import_structure["photon"] = ["PhotonPipeline"]
_import_structure["prx"] = ["PRXPipeline"]
_import_structure["audioldm"] = ["AudioLDMPipeline"]
_import_structure["audioldm2"] = [
"AudioLDM2Pipeline",
@@ -718,7 +718,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
StableDiffusionXLPAGPipeline,
)
from .paint_by_example import PaintByExamplePipeline
from .photon import PhotonPipeline
from .prx import PRXPipeline
from .pia import PIAPipeline
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
from .qwenimage import (

@@ -12,7 +12,7 @@ from ...utils import (

_dummy_objects = {}
_additional_imports = {}
_import_structure = {"pipeline_output": ["PhotonPipelineOutput"]}
_import_structure = {"pipeline_output": ["PRXPipelineOutput"]}

try:
    if not (is_transformers_available() and is_torch_available()):
@@ -22,7 +22,7 @@ except OptionalDependencyNotAvailable:

    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["pipeline_photon"] = ["PhotonPipeline"]
    _import_structure["pipeline_prx"] = ["PRXPipeline"]

# Import T5GemmaEncoder for pipeline loading compatibility
try:
@@ -44,8 +44,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *  # noqa F403
    else:
        from .pipeline_output import PhotonPipelineOutput
        from .pipeline_photon import PhotonPipeline
        from .pipeline_output import PRXPipelineOutput
        from .pipeline_prx import PRXPipeline

else:
    import sys

@@ -22,9 +22,9 @@ from ...utils import BaseOutput


@dataclass
class PhotonPipelineOutput(BaseOutput):
class PRXPipelineOutput(BaseOutput):
    """
    Output class for Photon pipelines.
    Output class for PRX pipelines.

    Args:
        images (`List[PIL.Image.Image]` or `np.ndarray`)

@@ -1098,7 +1098,7 @@ class ParallelConfig(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class PhotonTransformer2DModel(metaclass=DummyObject):
class PRXTransformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):

@@ -1847,7 +1847,7 @@ class PaintByExamplePipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class PhotonPipeline(metaclass=DummyObject):
class PRXPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):