mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
refactor based on dhruv's feedback: remove the class method
This commit is contained in:
@@ -324,6 +324,133 @@ class ConfigSpec:
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
# ======================================================
|
||||
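# InputParam and OutputParam templates: constructor keyword presets keyed by parameter name, looked up by InputParam.template() / OutputParam.template()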
|
||||
# ======================================================
|
||||
|
||||
INPUT_PARAM_TEMPLATES = {
|
||||
"prompt": {
|
||||
"type_hint": str,
|
||||
"required": True,
|
||||
"description": "The prompt or prompts to guide image generation.",
|
||||
},
|
||||
"negative_prompt": {
|
||||
"type_hint": str,
|
||||
"default": None,
|
||||
"description": "The prompt or prompts not to guide the image generation.",
|
||||
},
|
||||
"max_sequence_length": {
|
||||
"type_hint": int,
|
||||
"default": 512,
|
||||
"description": "Maximum sequence length for prompt encoding.",
|
||||
},
|
||||
"height": {
|
||||
"type_hint": int,
|
||||
"description": "The height in pixels of the generated image.",
|
||||
},
|
||||
"width": {
|
||||
"type_hint": int,
|
||||
"description": "The width in pixels of the generated image.",
|
||||
},
|
||||
"num_inference_steps": {
|
||||
"type_hint": int,
|
||||
"default": 50,
|
||||
"description": "The number of denoising steps.",
|
||||
},
|
||||
"num_images_per_prompt": {
|
||||
"type_hint": int,
|
||||
"default": 1,
|
||||
"description": "The number of images to generate per prompt.",
|
||||
},
|
||||
"generator": {
|
||||
"type_hint": torch.Generator,
|
||||
"default": None,
|
||||
"description": "Torch generator for deterministic generation.",
|
||||
},
|
||||
"sigmas": {
|
||||
"type_hint": List[float],
|
||||
"default": None,
|
||||
"description": "Custom sigmas for the denoising process.",
|
||||
},
|
||||
"strength": {
|
||||
"type_hint": float,
|
||||
"default": 0.9,
|
||||
"description": "Strength for img2img/inpainting.",
|
||||
},
|
||||
"image": {
|
||||
"type_hint": PIL.Image.Image,
|
||||
"required": True,
|
||||
"description": "Input image for img2img, editing, or conditioning.",
|
||||
},
|
||||
"mask_image": {
|
||||
"type_hint": PIL.Image.Image,
|
||||
"required": True,
|
||||
"description": "Mask image for inpainting.",
|
||||
},
|
||||
"control_image": {
|
||||
"type_hint": PIL.Image.Image,
|
||||
"required": True,
|
||||
"description": "Control image for ControlNet conditioning.",
|
||||
},
|
||||
"padding_mask_crop": {
|
||||
"type_hint": int,
|
||||
"default": None,
|
||||
"description": "Padding for mask cropping in inpainting.",
|
||||
},
|
||||
"latents": {
|
||||
"type_hint": torch.Tensor,
|
||||
"default": None,
|
||||
"description": "Pre-generated noisy latents for image generation.",
|
||||
},
|
||||
"timesteps": {
|
||||
"type_hint": torch.Tensor,
|
||||
"default": None,
|
||||
"description": "Timesteps for the denoising process.",
|
||||
},
|
||||
"output_type": {
|
||||
"type_hint": str,
|
||||
"default": "pil",
|
||||
"description": "Output format: 'pil', 'np', 'pt'.",
|
||||
},
|
||||
"attention_kwargs": {
|
||||
"type_hint": Dict[str, Any],
|
||||
"default": None,
|
||||
"description": "Additional kwargs for attention processors.",
|
||||
},
|
||||
"denoiser_input_fields": {
|
||||
"kwargs_type": "denoiser_input_fields",
|
||||
"type_hint": torch.Tensor,
|
||||
"description": "conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
|
||||
},
|
||||
"control_guidance_start": {
|
||||
"type_hint": float,
|
||||
"default": 0.0,
|
||||
"description": "When to start applying ControlNet.",
|
||||
},
|
||||
"control_guidance_end": {
|
||||
"type_hint": float,
|
||||
"default": 1.0,
|
||||
"description": "When to stop applying ControlNet.",
|
||||
},
|
||||
"controlnet_conditioning_scale": {
|
||||
"type_hint": float,
|
||||
"default": 1.0,
|
||||
"description": "Scale for ControlNet conditioning.",
|
||||
},
|
||||
}
|
||||
|
||||
OUTPUT_PARAM_TEMPLATES = {
|
||||
"images": {
|
||||
"type_hint": List[PIL.Image.Image],
|
||||
"description": "Generated images.",
|
||||
},
|
||||
"latents": {
|
||||
"type_hint": torch.Tensor,
|
||||
"description": "Denoised latents.",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# YiYi Notes: both inputs and intermediate_inputs are InputParam objects
|
||||
# however some fields are not relevant for intermediate_inputs
|
||||
# e.g. unlike inputs, `required` is only used in the docstring for intermediate_inputs; we do not check whether a required intermediate input is passed
|
||||
@@ -344,190 +471,22 @@ class InputParam:
|
||||
return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>"
|
||||
|
||||
@classmethod
|
||||
def template(cls, name: str) -> Optional["InputParam"]:
|
||||
"""Get template for name if exists, otherwise None."""
|
||||
if hasattr(cls, name) and callable(getattr(cls, name)):
|
||||
return getattr(cls, name)()
|
||||
return None
|
||||
|
||||
# ======================================================
|
||||
# InputParam templates
|
||||
# ======================================================
|
||||
|
||||
@classmethod
|
||||
def prompt(cls) -> "InputParam":
|
||||
return cls(
|
||||
name="prompt", type_hint=str, required=True, description="The prompt or prompts to guide image generation."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def negative_prompt(cls) -> "InputParam":
|
||||
return cls(
|
||||
name="negative_prompt",
|
||||
type_hint=str,
|
||||
default=None,
|
||||
description="The prompt or prompts not to guide the image generation.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def max_sequence_length(cls, default: int = 512) -> "InputParam":
|
||||
return cls(
|
||||
name="max_sequence_length",
|
||||
type_hint=int,
|
||||
default=default,
|
||||
description="Maximum sequence length for prompt encoding.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def height(cls, default: Optional[int] = None) -> "InputParam":
|
||||
return cls(
|
||||
name="height", type_hint=int, default=default, description="The height in pixels of the generated image."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def width(cls, default: Optional[int] = None) -> "InputParam":
|
||||
return cls(
|
||||
name="width", type_hint=int, default=default, description="The width in pixels of the generated image."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def num_inference_steps(cls, default: int = 50) -> "InputParam":
|
||||
return cls(
|
||||
name="num_inference_steps", type_hint=int, default=default, description="The number of denoising steps."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def num_images_per_prompt(cls, default: int = 1) -> "InputParam":
|
||||
return cls(
|
||||
name="num_images_per_prompt",
|
||||
type_hint=int,
|
||||
default=default,
|
||||
description="The number of images to generate per prompt.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generator(cls) -> "InputParam":
|
||||
return cls(
|
||||
name="generator",
|
||||
type_hint=torch.Generator,
|
||||
default=None,
|
||||
description="Torch generator for deterministic generation.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sigmas(cls) -> "InputParam":
|
||||
return cls(
|
||||
name="sigmas", type_hint=List[float], default=None, description="Custom sigmas for the denoising process."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def strength(cls, default: float = 0.9) -> "InputParam":
|
||||
return cls(name="strength", type_hint=float, default=default, description="Strength for img2img/inpainting.")
|
||||
|
||||
# images
|
||||
@classmethod
|
||||
def image(cls) -> "InputParam":
|
||||
return cls(
|
||||
name="image",
|
||||
type_hint=PIL.Image.Image,
|
||||
required=True,
|
||||
description="Input image for img2img, editing, or conditioning.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def mask_image(cls) -> "InputParam":
|
||||
return cls(
|
||||
name="mask_image", type_hint=PIL.Image.Image, required=True, description="Mask image for inpainting."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def control_image(cls) -> "InputParam":
|
||||
return cls(
|
||||
name="control_image",
|
||||
type_hint=PIL.Image.Image,
|
||||
required=True,
|
||||
description="Control image for ControlNet conditioning.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def padding_mask_crop(cls) -> "InputParam":
|
||||
return cls(
|
||||
name="padding_mask_crop",
|
||||
type_hint=int,
|
||||
default=None,
|
||||
description="Padding for mask cropping in inpainting.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def latents(cls) -> "InputParam":
|
||||
return cls(
|
||||
name="latents",
|
||||
type_hint=torch.Tensor,
|
||||
default=None,
|
||||
description="Pre-generated noisy latents for image generation.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def timesteps(cls) -> "InputParam":
|
||||
return cls(
|
||||
name="timesteps", type_hint=torch.Tensor, default=None, description="Timesteps for the denoising process."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def output_type(cls) -> "InputParam":
|
||||
return cls(name="output_type", type_hint=str, default="pil", description="Output format: 'pil', 'np', 'pt''.")
|
||||
|
||||
@classmethod
|
||||
def attention_kwargs(cls) -> "InputParam":
|
||||
return cls(
|
||||
name="attention_kwargs",
|
||||
type_hint=Dict[str, Any],
|
||||
default=None,
|
||||
description="Additional kwargs for attention processors.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def denoiser_input_fields(cls) -> "InputParam":
|
||||
return cls(
|
||||
kwargs_type="denoiser_input_fields",
|
||||
type_hint=torch.Tensor,
|
||||
description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
|
||||
)
|
||||
|
||||
# ControlNet
|
||||
@classmethod
|
||||
def control_guidance_start(cls, default: float = 0.0) -> "InputParam":
|
||||
return cls(
|
||||
name="control_guidance_start",
|
||||
type_hint=float,
|
||||
default=default,
|
||||
description="When to start applying ControlNet.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def control_guidance_end(cls, default: float = 1.0) -> "InputParam":
|
||||
return cls(
|
||||
name="control_guidance_end",
|
||||
type_hint=float,
|
||||
default=default,
|
||||
description="When to stop applying ControlNet.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def controlnet_conditioning_scale(cls, default: float = 1.0) -> "InputParam":
|
||||
return cls(
|
||||
name="controlnet_conditioning_scale",
|
||||
type_hint=float,
|
||||
default=default,
|
||||
description="Scale for ControlNet conditioning.",
|
||||
)
|
||||
def template(cls, name: str, **overrides) -> "InputParam":
|
||||
"""Get template for name if exists, otherwise return basic InputParam with just the name."""
|
||||
if name in INPUT_PARAM_TEMPLATES:
|
||||
kwargs = {"name": name, **INPUT_PARAM_TEMPLATES[name]}
|
||||
# Override with user-provided values
|
||||
for key, value in overrides.items():
|
||||
kwargs[key] = value
|
||||
return cls(**kwargs)
|
||||
return cls(name=name, **overrides)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OutputParam:
|
||||
"""Specification for an output parameter."""
|
||||
|
||||
name: str
|
||||
name: str = None
|
||||
type_hint: Any = None
|
||||
description: str = ""
|
||||
kwargs_type: str = None # YiYi notes: remove this feature (maybe)
|
||||
@@ -538,23 +497,15 @@ class OutputParam:
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def template(cls, name: str) -> Optional["OutputParam"]:
|
||||
"""Get template for name if exists, otherwise None."""
|
||||
if hasattr(cls, name) and callable(getattr(cls, name)):
|
||||
return getattr(cls, name)()
|
||||
return None
|
||||
|
||||
# ======================================================
|
||||
# OutputParam templates
|
||||
# ======================================================
|
||||
|
||||
@classmethod
|
||||
def images(cls) -> "OutputParam":
|
||||
return cls(name="images", type_hint=List[PIL.Image.Image], description="Generated images.")
|
||||
|
||||
@classmethod
|
||||
def latents(cls) -> "OutputParam":
|
||||
return cls(name="latents", type_hint=torch.Tensor, description="Denoised latents.")
|
||||
def template(cls, name: str, **overrides) -> "OutputParam":
|
||||
"""Get template for name if exists, otherwise return basic OutputParam with just the name."""
|
||||
if name in OUTPUT_PARAM_TEMPLATES:
|
||||
kwargs = {"name": name, **OUTPUT_PARAM_TEMPLATES[name]}
|
||||
# Override with user-provided values
|
||||
for key, value in overrides.items():
|
||||
kwargs[key] = value
|
||||
return cls(**kwargs)
|
||||
return cls(name=name, **overrides)
|
||||
|
||||
|
||||
def format_inputs_short(inputs):
|
||||
@@ -890,4 +841,4 @@ def make_doc_string(
|
||||
output += "\n\n"
|
||||
output += format_output_params(outputs, indent_level=2)
|
||||
|
||||
return output
|
||||
return output
|
||||
Reference in New Issue
Block a user