
refactor based on dhruv's feedback: remove the class method

This commit is contained in:
yiyixuxu
2026-01-18 00:35:01 +01:00
parent 25c968a38f
commit de03d7f100


@@ -324,6 +324,133 @@ class ConfigSpec:
description: Optional[str] = None
# ======================================================
# InputParam and OutputParam templates
# ======================================================
INPUT_PARAM_TEMPLATES = {
"prompt": {
"type_hint": str,
"required": True,
"description": "The prompt or prompts to guide image generation.",
},
"negative_prompt": {
"type_hint": str,
"default": None,
"description": "The prompt or prompts not to guide the image generation.",
},
"max_sequence_length": {
"type_hint": int,
"default": 512,
"description": "Maximum sequence length for prompt encoding.",
},
"height": {
"type_hint": int,
"description": "The height in pixels of the generated image.",
},
"width": {
"type_hint": int,
"description": "The width in pixels of the generated image.",
},
"num_inference_steps": {
"type_hint": int,
"default": 50,
"description": "The number of denoising steps.",
},
"num_images_per_prompt": {
"type_hint": int,
"default": 1,
"description": "The number of images to generate per prompt.",
},
"generator": {
"type_hint": torch.Generator,
"default": None,
"description": "Torch generator for deterministic generation.",
},
"sigmas": {
"type_hint": List[float],
"default": None,
"description": "Custom sigmas for the denoising process.",
},
"strength": {
"type_hint": float,
"default": 0.9,
"description": "Strength for img2img/inpainting.",
},
"image": {
"type_hint": PIL.Image.Image,
"required": True,
"description": "Input image for img2img, editing, or conditioning.",
},
"mask_image": {
"type_hint": PIL.Image.Image,
"required": True,
"description": "Mask image for inpainting.",
},
"control_image": {
"type_hint": PIL.Image.Image,
"required": True,
"description": "Control image for ControlNet conditioning.",
},
"padding_mask_crop": {
"type_hint": int,
"default": None,
"description": "Padding for mask cropping in inpainting.",
},
"latents": {
"type_hint": torch.Tensor,
"default": None,
"description": "Pre-generated noisy latents for image generation.",
},
"timesteps": {
"type_hint": torch.Tensor,
"default": None,
"description": "Timesteps for the denoising process.",
},
"output_type": {
"type_hint": str,
"default": "pil",
"description": "Output format: 'pil', 'np', 'pt'.",
},
"attention_kwargs": {
"type_hint": Dict[str, Any],
"default": None,
"description": "Additional kwargs for attention processors.",
},
"denoiser_input_fields": {
"kwargs_type": "denoiser_input_fields",
"type_hint": torch.Tensor,
"description": "conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
},
"control_guidance_start": {
"type_hint": float,
"default": 0.0,
"description": "When to start applying ControlNet.",
},
"control_guidance_end": {
"type_hint": float,
"default": 1.0,
"description": "When to stop applying ControlNet.",
},
"controlnet_conditioning_scale": {
"type_hint": float,
"default": 1.0,
"description": "Scale for ControlNet conditioning.",
},
}
OUTPUT_PARAM_TEMPLATES = {
"images": {
"type_hint": List[PIL.Image.Image],
"description": "Generated images.",
},
"latents": {
"type_hint": torch.Tensor,
"description": "Denoised latents.",
},
}
# YiYi Notes: both inputs and intermediate_inputs are InputParam objects,
# however some fields are not relevant for intermediate_inputs.
# e.g. unlike inputs, `required` is only used in the docstring for intermediate_inputs; we do not check whether a required intermediate input is passed.
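
A minimal sketch (not part of this commit) of how a template entry is meant to be consumed, assuming InputParam and INPUT_PARAM_TEMPLATES are in scope; each entry holds the constructor kwargs for InputParam minus the name itself:

# Sketch only: expanding the "prompt" template entry is equivalent to spelling the parameter out by hand.
prompt_from_template = InputParam(name="prompt", **INPUT_PARAM_TEMPLATES["prompt"])
prompt_by_hand = InputParam(
    name="prompt",
    type_hint=str,
    required=True,
    description="The prompt or prompts to guide image generation.",
)
assert prompt_from_template.required is True
assert prompt_from_template.type_hint is str
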
@@ -344,190 +471,22 @@ class InputParam:
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
@classmethod
def template(cls, name: str) -> Optional["InputParam"]:
"""Get template for name if exists, otherwise None."""
if hasattr(cls, name) and callable(getattr(cls, name)):
return getattr(cls, name)()
return None
# ======================================================
# InputParam templates
# ======================================================
@classmethod
def prompt(cls) -> "InputParam":
return cls(
name="prompt", type_hint=str, required=True, description="The prompt or prompts to guide image generation."
)
@classmethod
def negative_prompt(cls) -> "InputParam":
return cls(
name="negative_prompt",
type_hint=str,
default=None,
description="The prompt or prompts not to guide the image generation.",
)
@classmethod
def max_sequence_length(cls, default: int = 512) -> "InputParam":
return cls(
name="max_sequence_length",
type_hint=int,
default=default,
description="Maximum sequence length for prompt encoding.",
)
@classmethod
def height(cls, default: Optional[int] = None) -> "InputParam":
return cls(
name="height", type_hint=int, default=default, description="The height in pixels of the generated image."
)
@classmethod
def width(cls, default: Optional[int] = None) -> "InputParam":
return cls(
name="width", type_hint=int, default=default, description="The width in pixels of the generated image."
)
@classmethod
def num_inference_steps(cls, default: int = 50) -> "InputParam":
return cls(
name="num_inference_steps", type_hint=int, default=default, description="The number of denoising steps."
)
@classmethod
def num_images_per_prompt(cls, default: int = 1) -> "InputParam":
return cls(
name="num_images_per_prompt",
type_hint=int,
default=default,
description="The number of images to generate per prompt.",
)
@classmethod
def generator(cls) -> "InputParam":
return cls(
name="generator",
type_hint=torch.Generator,
default=None,
description="Torch generator for deterministic generation.",
)
@classmethod
def sigmas(cls) -> "InputParam":
return cls(
name="sigmas", type_hint=List[float], default=None, description="Custom sigmas for the denoising process."
)
@classmethod
def strength(cls, default: float = 0.9) -> "InputParam":
return cls(name="strength", type_hint=float, default=default, description="Strength for img2img/inpainting.")
# images
@classmethod
def image(cls) -> "InputParam":
return cls(
name="image",
type_hint=PIL.Image.Image,
required=True,
description="Input image for img2img, editing, or conditioning.",
)
@classmethod
def mask_image(cls) -> "InputParam":
return cls(
name="mask_image", type_hint=PIL.Image.Image, required=True, description="Mask image for inpainting."
)
@classmethod
def control_image(cls) -> "InputParam":
return cls(
name="control_image",
type_hint=PIL.Image.Image,
required=True,
description="Control image for ControlNet conditioning.",
)
@classmethod
def padding_mask_crop(cls) -> "InputParam":
return cls(
name="padding_mask_crop",
type_hint=int,
default=None,
description="Padding for mask cropping in inpainting.",
)
@classmethod
def latents(cls) -> "InputParam":
return cls(
name="latents",
type_hint=torch.Tensor,
default=None,
description="Pre-generated noisy latents for image generation.",
)
@classmethod
def timesteps(cls) -> "InputParam":
return cls(
name="timesteps", type_hint=torch.Tensor, default=None, description="Timesteps for the denoising process."
)
@classmethod
def output_type(cls) -> "InputParam":
return cls(name="output_type", type_hint=str, default="pil", description="Output format: 'pil', 'np', 'pt''.")
@classmethod
def attention_kwargs(cls) -> "InputParam":
return cls(
name="attention_kwargs",
type_hint=Dict[str, Any],
default=None,
description="Additional kwargs for attention processors.",
)
@classmethod
def denoiser_input_fields(cls) -> "InputParam":
return cls(
kwargs_type="denoiser_input_fields",
type_hint=torch.Tensor,
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
)
# ControlNet
@classmethod
def control_guidance_start(cls, default: float = 0.0) -> "InputParam":
return cls(
name="control_guidance_start",
type_hint=float,
default=default,
description="When to start applying ControlNet.",
)
@classmethod
def control_guidance_end(cls, default: float = 1.0) -> "InputParam":
return cls(
name="control_guidance_end",
type_hint=float,
default=default,
description="When to stop applying ControlNet.",
)
@classmethod
def controlnet_conditioning_scale(cls, default: float = 1.0) -> "InputParam":
return cls(
name="controlnet_conditioning_scale",
type_hint=float,
default=default,
description="Scale for ControlNet conditioning.",
)
def template(cls, name: str, **overrides) -> "InputParam":
"""Get template for name if exists, otherwise return basic InputParam with just the name."""
if name in INPUT_PARAM_TEMPLATES:
kwargs = {"name": name, **INPUT_PARAM_TEMPLATES[name]}
# Override with user-provided values
for key, value in overrides.items():
kwargs[key] = value
return cls(**kwargs)
return cls(name=name, **overrides)
@dataclass
class OutputParam:
"""Specification for an output parameter."""
name: str
name: str = None
type_hint: Any = None
description: str = ""
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
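
A minimal usage sketch (not part of this commit) of the new InputParam.template classmethod that replaces the per-parameter classmethods removed above, assuming InputParam is in scope; "custom_flag" is a hypothetical name with no template entry:

# Sketch only: known names pick up their template defaults, any field can be overridden,
# and unknown names fall back to a plain InputParam built from the given kwargs.
steps = InputParam.template("num_inference_steps")                    # default=50 from the template
steps_fast = InputParam.template("num_inference_steps", default=28)   # override the template default
height = InputParam.template("height", default=1024)                  # was InputParam.height(default=1024)
custom = InputParam.template("custom_flag", type_hint=bool, default=False)  # no template entry
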
@@ -538,23 +497,15 @@ class OutputParam:
)
@classmethod
def template(cls, name: str) -> Optional["OutputParam"]:
"""Get template for name if exists, otherwise None."""
if hasattr(cls, name) and callable(getattr(cls, name)):
return getattr(cls, name)()
return None
# ======================================================
# OutputParam templates
# ======================================================
@classmethod
def images(cls) -> "OutputParam":
return cls(name="images", type_hint=List[PIL.Image.Image], description="Generated images.")
@classmethod
def latents(cls) -> "OutputParam":
return cls(name="latents", type_hint=torch.Tensor, description="Denoised latents.")
def template(cls, name: str, **overrides) -> "OutputParam":
"""Get template for name if exists, otherwise return basic OutputParam with just the name."""
if name in OUTPUT_PARAM_TEMPLATES:
kwargs = {"name": name, **OUTPUT_PARAM_TEMPLATES[name]}
# Override with user-provided values
for key, value in overrides.items():
kwargs[key] = value
return cls(**kwargs)
return cls(name=name, **overrides)
def format_inputs_short(inputs):
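
And the analogous sketch (not part of this commit) for the new OutputParam.template, assuming OutputParam is in scope:

# Sketch only: same lookup-with-override behavior as InputParam.template.
images_out = OutputParam.template("images")                                   # type_hint=List[PIL.Image.Image]
latents_out = OutputParam.template("latents", description="Packed latents.")  # override the description
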
@@ -890,4 +841,4 @@ def make_doc_string(
output += "\n\n"
output += format_output_params(outputs, indent_level=2)
return output
return output