diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 45556c538a..f8dde1fbd0 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -324,6 +324,133 @@ class ConfigSpec: description: Optional[str] = None +# ====================================================== +# InputParam and OutputParam templates +# ====================================================== + +INPUT_PARAM_TEMPLATES = { + "prompt": { + "type_hint": str, + "required": True, + "description": "The prompt or prompts to guide image generation.", + }, + "negative_prompt": { + "type_hint": str, + "default": None, + "description": "The prompt or prompts not to guide the image generation.", + }, + "max_sequence_length": { + "type_hint": int, + "default": 512, + "description": "Maximum sequence length for prompt encoding.", + }, + "height": { + "type_hint": int, + "description": "The height in pixels of the generated image.", + }, + "width": { + "type_hint": int, + "description": "The width in pixels of the generated image.", + }, + "num_inference_steps": { + "type_hint": int, + "default": 50, + "description": "The number of denoising steps.", + }, + "num_images_per_prompt": { + "type_hint": int, + "default": 1, + "description": "The number of images to generate per prompt.", + }, + "generator": { + "type_hint": torch.Generator, + "default": None, + "description": "Torch generator for deterministic generation.", + }, + "sigmas": { + "type_hint": List[float], + "default": None, + "description": "Custom sigmas for the denoising process.", + }, + "strength": { + "type_hint": float, + "default": 0.9, + "description": "Strength for img2img/inpainting.", + }, + "image": { + "type_hint": PIL.Image.Image, + "required": True, + "description": "Input image for img2img, editing, or conditioning.", + }, + "mask_image": { + "type_hint": PIL.Image.Image, + "required": True, + "description": "Mask image for inpainting.", + }, + "control_image": { + "type_hint": PIL.Image.Image, + "required": True, + "description": "Control image for ControlNet conditioning.", + }, + "padding_mask_crop": { + "type_hint": int, + "default": None, + "description": "Padding for mask cropping in inpainting.", + }, + "latents": { + "type_hint": torch.Tensor, + "default": None, + "description": "Pre-generated noisy latents for image generation.", + }, + "timesteps": { + "type_hint": torch.Tensor, + "default": None, + "description": "Timesteps for the denoising process.", + }, + "output_type": { + "type_hint": str, + "default": "pil", + "description": "Output format: 'pil', 'np', 'pt'.", + }, + "attention_kwargs": { + "type_hint": Dict[str, Any], + "default": None, + "description": "Additional kwargs for attention processors.", + }, + "denoiser_input_fields": { + "kwargs_type": "denoiser_input_fields", + "type_hint": torch.Tensor, + "description": "conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", + }, + "control_guidance_start": { + "type_hint": float, + "default": 0.0, + "description": "When to start applying ControlNet.", + }, + "control_guidance_end": { + "type_hint": float, + "default": 1.0, + "description": "When to stop applying ControlNet.", + }, + "controlnet_conditioning_scale": { + "type_hint": float, + "default": 1.0, + "description": "Scale for ControlNet conditioning.", + }, +} + +OUTPUT_PARAM_TEMPLATES = { + "images": { + "type_hint": List[PIL.Image.Image], + "description": "Generated images.", + }, + "latents": { + "type_hint": torch.Tensor, + "description": "Denoised latents.", + }, +} + + # YiYi Notes: both inputs and intermediate_inputs are InputParam objects # however some fields are not relevant for intermediate_inputs # e.g. unlike inputs, required only used in docstring for intermediate_inputs, we do not check if a required intermediate inputs is passed @@ -344,190 +471,22 @@ class InputParam: return f"<{self.name}: {'required' if self.required else 'optional'}, default={self.default}>" @classmethod - def template(cls, name: str) -> Optional["InputParam"]: - """Get template for name if exists, otherwise None.""" - if hasattr(cls, name) and callable(getattr(cls, name)): - return getattr(cls, name)() - return None - - # ====================================================== - # InputParam templates - # ====================================================== - - @classmethod - def prompt(cls) -> "InputParam": - return cls( - name="prompt", type_hint=str, required=True, description="The prompt or prompts to guide image generation." - ) - - @classmethod - def negative_prompt(cls) -> "InputParam": - return cls( - name="negative_prompt", - type_hint=str, - default=None, - description="The prompt or prompts not to guide the image generation.", - ) - - @classmethod - def max_sequence_length(cls, default: int = 512) -> "InputParam": - return cls( - name="max_sequence_length", - type_hint=int, - default=default, - description="Maximum sequence length for prompt encoding.", - ) - - @classmethod - def height(cls, default: Optional[int] = None) -> "InputParam": - return cls( - name="height", type_hint=int, default=default, description="The height in pixels of the generated image." - ) - - @classmethod - def width(cls, default: Optional[int] = None) -> "InputParam": - return cls( - name="width", type_hint=int, default=default, description="The width in pixels of the generated image." - ) - - @classmethod - def num_inference_steps(cls, default: int = 50) -> "InputParam": - return cls( - name="num_inference_steps", type_hint=int, default=default, description="The number of denoising steps." - ) - - @classmethod - def num_images_per_prompt(cls, default: int = 1) -> "InputParam": - return cls( - name="num_images_per_prompt", - type_hint=int, - default=default, - description="The number of images to generate per prompt.", - ) - - @classmethod - def generator(cls) -> "InputParam": - return cls( - name="generator", - type_hint=torch.Generator, - default=None, - description="Torch generator for deterministic generation.", - ) - - @classmethod - def sigmas(cls) -> "InputParam": - return cls( - name="sigmas", type_hint=List[float], default=None, description="Custom sigmas for the denoising process." - ) - - @classmethod - def strength(cls, default: float = 0.9) -> "InputParam": - return cls(name="strength", type_hint=float, default=default, description="Strength for img2img/inpainting.") - - # images - @classmethod - def image(cls) -> "InputParam": - return cls( - name="image", - type_hint=PIL.Image.Image, - required=True, - description="Input image for img2img, editing, or conditioning.", - ) - - @classmethod - def mask_image(cls) -> "InputParam": - return cls( - name="mask_image", type_hint=PIL.Image.Image, required=True, description="Mask image for inpainting." - ) - - @classmethod - def control_image(cls) -> "InputParam": - return cls( - name="control_image", - type_hint=PIL.Image.Image, - required=True, - description="Control image for ControlNet conditioning.", - ) - - @classmethod - def padding_mask_crop(cls) -> "InputParam": - return cls( - name="padding_mask_crop", - type_hint=int, - default=None, - description="Padding for mask cropping in inpainting.", - ) - - @classmethod - def latents(cls) -> "InputParam": - return cls( - name="latents", - type_hint=torch.Tensor, - default=None, - description="Pre-generated noisy latents for image generation.", - ) - - @classmethod - def timesteps(cls) -> "InputParam": - return cls( - name="timesteps", type_hint=torch.Tensor, default=None, description="Timesteps for the denoising process." - ) - - @classmethod - def output_type(cls) -> "InputParam": - return cls(name="output_type", type_hint=str, default="pil", description="Output format: 'pil', 'np', 'pt''.") - - @classmethod - def attention_kwargs(cls) -> "InputParam": - return cls( - name="attention_kwargs", - type_hint=Dict[str, Any], - default=None, - description="Additional kwargs for attention processors.", - ) - - @classmethod - def denoiser_input_fields(cls) -> "InputParam": - return cls( - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.", - ) - - # ControlNet - @classmethod - def control_guidance_start(cls, default: float = 0.0) -> "InputParam": - return cls( - name="control_guidance_start", - type_hint=float, - default=default, - description="When to start applying ControlNet.", - ) - - @classmethod - def control_guidance_end(cls, default: float = 1.0) -> "InputParam": - return cls( - name="control_guidance_end", - type_hint=float, - default=default, - description="When to stop applying ControlNet.", - ) - - @classmethod - def controlnet_conditioning_scale(cls, default: float = 1.0) -> "InputParam": - return cls( - name="controlnet_conditioning_scale", - type_hint=float, - default=default, - description="Scale for ControlNet conditioning.", - ) + def template(cls, name: str, **overrides) -> "InputParam": + """Get template for name if exists, otherwise return basic InputParam with just the name.""" + if name in INPUT_PARAM_TEMPLATES: + kwargs = {"name": name, **INPUT_PARAM_TEMPLATES[name]} + # Override with user-provided values + for key, value in overrides.items(): + kwargs[key] = value + return cls(**kwargs) + return cls(name=name, **overrides) @dataclass class OutputParam: """Specification for an output parameter.""" - name: str + name: str = None type_hint: Any = None description: str = "" kwargs_type: str = None # YiYi notes: remove this feature (maybe) @@ -538,23 +497,15 @@ class OutputParam: ) @classmethod - def template(cls, name: str) -> Optional["OutputParam"]: - """Get template for name if exists, otherwise None.""" - if hasattr(cls, name) and callable(getattr(cls, name)): - return getattr(cls, name)() - return None - - # ====================================================== - # OutputParam templates - # ====================================================== - - @classmethod - def images(cls) -> "OutputParam": - return cls(name="images", type_hint=List[PIL.Image.Image], description="Generated images.") - - @classmethod - def latents(cls) -> "OutputParam": - return cls(name="latents", type_hint=torch.Tensor, description="Denoised latents.") + def template(cls, name: str, **overrides) -> "OutputParam": + """Get template for name if exists, otherwise return basic OutputParam with just the name.""" + if name in OUTPUT_PARAM_TEMPLATES: + kwargs = {"name": name, **OUTPUT_PARAM_TEMPLATES[name]} + # Override with user-provided values + for key, value in overrides.items(): + kwargs[key] = value + return cls(**kwargs) + return cls(name=name, **overrides) def format_inputs_short(inputs): @@ -890,4 +841,4 @@ def make_doc_string( output += "\n\n" output += format_output_params(outputs, indent_level=2) - return output + return output \ No newline at end of file