mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
style
This commit is contained in:
@@ -4,7 +4,7 @@ import os
|
||||
|
||||
# Simple typed wrapper for parameter overrides
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Any, Dict, Optional, Union, List
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from huggingface_hub import create_repo, hf_hub_download, upload_folder
|
||||
from huggingface_hub.utils import (
|
||||
@@ -42,7 +42,7 @@ class MellonParam:
|
||||
fieldOptions: Optional[Dict[str, Any]] = None
|
||||
onChange: Any = None
|
||||
onSignal: Any = None
|
||||
required_block_params: Optional[Union[str, List[str]]] = None
|
||||
required_block_params: Optional[Union[str, List[str]]] = None
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dict for Mellon schema, excluding None values and name."""
|
||||
@@ -59,7 +59,13 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def control_image(cls, display: str = "input") -> "MellonParam":
|
||||
return cls(name="control_image", label="Control Image", type="image", display=display, required_block_params=["control_image"])
|
||||
return cls(
|
||||
name="control_image",
|
||||
label="Control Image",
|
||||
type="image",
|
||||
display=display,
|
||||
required_block_params=["control_image"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def latents(cls, display: str = "input") -> "MellonParam":
|
||||
@@ -67,11 +73,23 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def image_latents(cls, display: str = "input") -> "MellonParam":
|
||||
return cls(name="image_latents", label="Image Latents", type="latents", display=display, required_block_params=["image_latents"])
|
||||
return cls(
|
||||
name="image_latents",
|
||||
label="Image Latents",
|
||||
type="latents",
|
||||
display=display,
|
||||
required_block_params=["image_latents"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def first_frame_latents(cls, display: str = "input") -> "MellonParam":
|
||||
return cls(name="first_frame_latents", label="First Frame Latents", type="latents", display=display, required_block_params=["first_frame_latents"])
|
||||
return cls(
|
||||
name="first_frame_latents",
|
||||
label="First Frame Latents",
|
||||
type="latents",
|
||||
display=display,
|
||||
required_block_params=["first_frame_latents"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def image_latents_with_strength(cls) -> "MellonParam":
|
||||
@@ -97,7 +115,13 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def image_embeds(cls, display: str = "output") -> "MellonParam":
|
||||
return cls(name="image_embeds", label="Image Embeddings", type="image_embeds", display=display, required_block_params=["image_embeds"])
|
||||
return cls(
|
||||
name="image_embeds",
|
||||
label="Image Embeddings",
|
||||
type="image_embeds",
|
||||
display=display,
|
||||
required_block_params=["image_embeds"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def controlnet_conditioning_scale(cls, default: float = 0.5) -> "MellonParam":
|
||||
@@ -140,15 +164,38 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def prompt(cls, default: str = "") -> "MellonParam":
|
||||
return cls(name="prompt", label="Prompt", type="string", default=default, display="textarea", required_block_params=["prompt"])
|
||||
return cls(
|
||||
name="prompt",
|
||||
label="Prompt",
|
||||
type="string",
|
||||
default=default,
|
||||
display="textarea",
|
||||
required_block_params=["prompt"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def negative_prompt(cls, default: str = "") -> "MellonParam":
|
||||
return cls(name="negative_prompt", label="Negative Prompt", type="string", default=default, display="textarea", required_block_params=["negative_prompt"])
|
||||
return cls(
|
||||
name="negative_prompt",
|
||||
label="Negative Prompt",
|
||||
type="string",
|
||||
default=default,
|
||||
display="textarea",
|
||||
required_block_params=["negative_prompt"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def strength(cls, default: float = 0.5) -> "MellonParam":
|
||||
return cls(name="strength", label="Strength", type="float", default=default, min=0.0, max=1.0, step=0.01, required_block_params=["strength"])
|
||||
return cls(
|
||||
name="strength",
|
||||
label="Strength",
|
||||
type="float",
|
||||
default=default,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
required_block_params=["strength"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def guidance_scale(cls, default: float = 5.0) -> "MellonParam":
|
||||
@@ -165,29 +212,73 @@ class MellonParam:
|
||||
|
||||
@classmethod
|
||||
def height(cls, default: int = 1024) -> "MellonParam":
|
||||
return cls(name="height", label="Height", type="int", default=default, min=64, step=8, required_block_params=["height"])
|
||||
return cls(
|
||||
name="height",
|
||||
label="Height",
|
||||
type="int",
|
||||
default=default,
|
||||
min=64,
|
||||
step=8,
|
||||
required_block_params=["height"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def width(cls, default: int = 1024) -> "MellonParam":
|
||||
return cls(name="width", label="Width", type="int", default=default, min=64, step=8, required_block_params=["width"])
|
||||
return cls(
|
||||
name="width", label="Width", type="int", default=default, min=64, step=8, required_block_params=["width"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def seed(cls, default: int = 0) -> "MellonParam":
|
||||
return cls(name="seed", label="Seed", type="int", default=default, min=0, max=4294967295, display="random", required_block_params=["generator"])
|
||||
return cls(
|
||||
name="seed",
|
||||
label="Seed",
|
||||
type="int",
|
||||
default=default,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
display="random",
|
||||
required_block_params=["generator"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def num_inference_steps(cls, default: int = 25) -> "MellonParam":
|
||||
return cls(
|
||||
name="num_inference_steps", label="Steps", type="int", default=default, min=1, max=100, display="slider", required_block_params=["num_inference_steps"]
|
||||
name="num_inference_steps",
|
||||
label="Steps",
|
||||
type="int",
|
||||
default=default,
|
||||
min=1,
|
||||
max=100,
|
||||
display="slider",
|
||||
required_block_params=["num_inference_steps"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def num_frames(cls, default: int = 81) -> "MellonParam":
|
||||
return cls(name="num_frames", label="Frames", type="int", default=default, min=1, max=480, display="slider", required_block_params=["num_frames"])
|
||||
return cls(
|
||||
name="num_frames",
|
||||
label="Frames",
|
||||
type="int",
|
||||
default=default,
|
||||
min=1,
|
||||
max=480,
|
||||
display="slider",
|
||||
required_block_params=["num_frames"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def layers(cls, default: int = 4) -> "MellonParam":
|
||||
return cls(name="layers", label="Layers", type="int", default=default, min=1, max=10, display="slider", required_block_params=["layers"])
|
||||
return cls(
|
||||
name="layers",
|
||||
label="Layers",
|
||||
type="int",
|
||||
default=default,
|
||||
min=1,
|
||||
max=10,
|
||||
display="slider",
|
||||
required_block_params=["layers"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def videos(cls) -> "MellonParam":
|
||||
@@ -201,7 +292,9 @@ class MellonParam:
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(name="vae", label="VAE", type="diffusers_auto_model", display="input", required_block_params=["vae"])
|
||||
return cls(
|
||||
name="vae", label="VAE", type="diffusers_auto_model", display="input", required_block_params=["vae"]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def image_encoder(cls) -> "MellonParam":
|
||||
@@ -211,7 +304,13 @@ class MellonParam:
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(name="image_encoder", label="Image Encoder", type="diffusers_auto_model", display="input", required_block_params=["image_encoder"])
|
||||
return cls(
|
||||
name="image_encoder",
|
||||
label="Image Encoder",
|
||||
type="diffusers_auto_model",
|
||||
display="input",
|
||||
required_block_params=["image_encoder"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def unet(cls) -> "MellonParam":
|
||||
@@ -241,7 +340,13 @@ class MellonParam:
|
||||
Contains keys like 'model_id', 'repo_id', 'execution_device' etc. Use components.get_one(model_id) to retrieve
|
||||
the actual model.
|
||||
"""
|
||||
return cls(name="controlnet", label="ControlNet Model", type="diffusers_auto_model", display="input", required_block_params=["controlnet"])
|
||||
return cls(
|
||||
name="controlnet",
|
||||
label="ControlNet Model",
|
||||
type="diffusers_auto_model",
|
||||
display="input",
|
||||
required_block_params=["controlnet"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def text_encoders(cls) -> "MellonParam":
|
||||
@@ -253,7 +358,13 @@ class MellonParam:
|
||||
'repo_id': '...'
|
||||
} Use components.get_one(model_id) to retrieve each model.
|
||||
"""
|
||||
return cls(name="text_encoders", label="Text Encoders", type="diffusers_auto_models", display="input", required_block_params=["text_encoder"])
|
||||
return cls(
|
||||
name="text_encoders",
|
||||
label="Text Encoders",
|
||||
type="diffusers_auto_models",
|
||||
display="input",
|
||||
required_block_params=["text_encoder"],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def controlnet_bundle(cls, display: str = "input") -> "MellonParam":
|
||||
@@ -268,7 +379,13 @@ class MellonParam:
|
||||
|
||||
Output from Controlnet node, input to Denoise node.
|
||||
"""
|
||||
return cls(name="controlnet_bundle", label="ControlNet", type="custom_controlnet", display=display, required_block_params="controlnet_image")
|
||||
return cls(
|
||||
name="controlnet_bundle",
|
||||
label="ControlNet",
|
||||
type="custom_controlnet",
|
||||
display=display,
|
||||
required_block_params="controlnet_image",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def ip_adapter(cls) -> "MellonParam":
|
||||
@@ -561,7 +678,9 @@ class MellonPipelineConfig:
|
||||
return params
|
||||
|
||||
def __repr__(self) -> str:
|
||||
lines = [f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r})"]
|
||||
lines = [
|
||||
f"MellonPipelineConfig(label={self.label!r}, default_repo={self.default_repo!r}, default_dtype={self.default_dtype!r})"
|
||||
]
|
||||
for node_type, spec in self.node_specs.items():
|
||||
if spec is None:
|
||||
lines.append(f" {node_type}: None")
|
||||
@@ -575,7 +694,6 @@ class MellonPipelineConfig:
|
||||
lines.append(f" outputs: {outputs}")
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to a JSON-serializable dictionary."""
|
||||
return {
|
||||
@@ -729,7 +847,7 @@ class MellonPipelineConfig:
|
||||
@classmethod
|
||||
def from_blocks(
|
||||
cls,
|
||||
blocks: "ModularPipelineBlocks",
|
||||
blocks,
|
||||
template: Dict[str, Optional[Dict[str, Any]]] = None,
|
||||
label: str = "",
|
||||
default_repo: str = "",
|
||||
@@ -740,26 +858,44 @@ class MellonPipelineConfig:
|
||||
"""
|
||||
if template is None:
|
||||
template = DEFAULT_NODE_SPECS
|
||||
|
||||
|
||||
sub_block_map = dict(blocks.sub_blocks)
|
||||
|
||||
|
||||
def filter_spec_for_block(template_spec: Dict[str, Any], block) -> Optional[Dict[str, Any]]:
|
||||
"""Filter template spec params based on what the block actually supports."""
|
||||
block_input_names = set(block.input_names)
|
||||
block_output_names = set(block.intermediate_output_names)
|
||||
block_component_names = set(block.component_names)
|
||||
|
||||
filtered_inputs = [p for p in template_spec.get("inputs", []) if p.required_block_params is None or all(name in block_input_names for name in p.required_block_params)]
|
||||
filtered_model_inputs = [p for p in template_spec.get("model_inputs", []) if p.required_block_params is None or all(name in block_component_names for name in p.required_block_params)]
|
||||
filtered_outputs = [p for p in template_spec.get("outputs", []) if p.required_block_params is None or all(name in block_output_names for name in p.required_block_params)]
|
||||
|
||||
|
||||
filtered_inputs = [
|
||||
p
|
||||
for p in template_spec.get("inputs", [])
|
||||
if p.required_block_params is None
|
||||
or all(name in block_input_names for name in p.required_block_params)
|
||||
]
|
||||
filtered_model_inputs = [
|
||||
p
|
||||
for p in template_spec.get("model_inputs", [])
|
||||
if p.required_block_params is None
|
||||
or all(name in block_component_names for name in p.required_block_params)
|
||||
]
|
||||
filtered_outputs = [
|
||||
p
|
||||
for p in template_spec.get("outputs", [])
|
||||
if p.required_block_params is None
|
||||
or all(name in block_output_names for name in p.required_block_params)
|
||||
]
|
||||
|
||||
filtered_input_names = {p.name for p in filtered_inputs}
|
||||
filtered_model_input_names = {p.name for p in filtered_model_inputs}
|
||||
|
||||
filtered_required_inputs = [r for r in template_spec.get("required_inputs", []) if r in filtered_input_names]
|
||||
filtered_required_model_inputs = [r for r in template_spec.get("required_model_inputs", []) if r in filtered_model_input_names]
|
||||
|
||||
|
||||
filtered_required_inputs = [
|
||||
r for r in template_spec.get("required_inputs", []) if r in filtered_input_names
|
||||
]
|
||||
filtered_required_model_inputs = [
|
||||
r for r in template_spec.get("required_model_inputs", []) if r in filtered_model_input_names
|
||||
]
|
||||
|
||||
return {
|
||||
"inputs": filtered_inputs,
|
||||
"model_inputs": filtered_model_inputs,
|
||||
@@ -768,24 +904,24 @@ class MellonPipelineConfig:
|
||||
"required_model_inputs": filtered_required_model_inputs,
|
||||
"block_name": template_spec.get("block_name"),
|
||||
}
|
||||
|
||||
|
||||
# Build node specs
|
||||
node_specs = {}
|
||||
for node_type, template_spec in template.items():
|
||||
if template_spec is None:
|
||||
node_specs[node_type] = None
|
||||
continue
|
||||
|
||||
|
||||
block_name = template_spec.get("block_name")
|
||||
if block_name is None or block_name not in sub_block_map:
|
||||
node_specs[node_type] = None
|
||||
continue
|
||||
|
||||
|
||||
node_specs[node_type] = filter_spec_for_block(template_spec, sub_block_map[block_name])
|
||||
|
||||
|
||||
return cls(
|
||||
node_specs=node_specs,
|
||||
label=label or getattr(blocks, "model_name", ""),
|
||||
default_repo=default_repo,
|
||||
default_dtype=default_dtype,
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user