diff --git a/docs/source/en/modular_diffusers/end_to_end_guide.md b/docs/source/en/modular_diffusers/end_to_end_guide.md index ab4ba8020d..132c4870b7 100644 --- a/docs/source/en/modular_diffusers/end_to_end_guide.md +++ b/docs/source/en/modular_diffusers/end_to_end_guide.md @@ -505,7 +505,7 @@ We provide a auto controlnet input block that you can directly put into your wor ```py ->>> from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks_presets import StableDiffusionXLAutoControlNetInputStep +>>> from diffusers.modular_pipelines.stable_diffusion_xl.modular_blocks import StableDiffusionXLAutoControlNetInputStep >>> control_input_block = StableDiffusionXLAutoControlNetInputStep() >>> print(control_input_block) ``` @@ -613,7 +613,7 @@ to use You can easily share your differential diffusion workflow on the hub, by creating a modular repo like this https://huggingface.co/YiYiXu/modular-diffdiff -To create a Modular Repo and share on hub, you just need to run. Note that if your pipeline contains custom block, you need to manually upload the code to the hub. But we are working on a command line tool to help you upload it very easily. +To create a Modular Repo and share it on the Hub, you just need to run `save_pretrained()` along with the `push_to_hub=True` flag. Note that if your pipeline contains a custom block, you need to manually upload the code to the Hub. But we are working on a command line tool to help you upload it very easily. 
```py dd_pipeline.save_pretrained("YiYiXu/test_modular_doc", push_to_hub=True) @@ -626,7 +626,7 @@ With a modular repo, it is very easy for the community to use the workflow you j >>> import torch >>> from diffusers.utils import load_image >>> ->>> repo_id = "YiYiXu/modular-diffdiff" +>>> repo_id = "YiYiXu/modular-diffdiff-0704" >>> >>> components = ComponentsManager() >>> diff --git a/docs/source/en/modular_diffusers/getting_started.md b/docs/source/en/modular_diffusers/getting_started.md index c742230364..ff1633988e 100644 --- a/docs/source/en/modular_diffusers/getting_started.md +++ b/docs/source/en/modular_diffusers/getting_started.md @@ -31,10 +31,12 @@ Pipeline blocks are the fundamental building blocks of the Modular Diffusers sys - [`PipelineBlock`]: The most granular block - you define the computation logic. - [`SequentialPipelineBlocks`]: A multi-block composed of multiple blocks that run sequentially, passing outputs as inputs to the next block. -- [`LoopSequentialPipelineBlocks`]: A special type of multi-block that forms loops. +- [`LoopSequentialPipelineBlocks`]: A special type of `SequentialPipelineBlocks` that runs the same sequence of blocks multiple times (loops), typically used for iterative processes like denoising steps in diffusion models. - [`AutoPipelineBlocks`]: A multi-block composed of multiple blocks that are selected at runtime based on the inputs. -All blocks have a consistent interface defining their requirements (components, configs, inputs, outputs) and computation logic. They can be used standalone or combined into larger blocks. Blocks are designed to be assembled into workflows for tasks such as image generation, video creation, and inpainting. +All blocks have a consistent interface defining their requirements (components, configs, inputs, outputs) and computation logic. 
They can be defined standalone or combined into larger blocks. They are designed to be assembled into workflows for tasks such as image generation, video creation, and inpainting. However, blocks aren't runnable on their own and they need to be converted into a `ModularPipeline` to actually run. + +**Blocks vs Pipelines**: Blocks are just definitions - they define what components, inputs/outputs, and computation logic are needed, but they don't actually run anything. To execute blocks, you need to put them into a `ModularPipeline`. See the [ModularPipeline from ModularPipelineBlocks](#modularpipeline-from-modularpipelineblocks) section for how to create and run pipelines. It is very easy to use a `ModularPipelineBlocks` officially supported in 🧨 Diffusers @@ -321,10 +323,10 @@ In standard `model_index.json`, each component entry is a `(library, class)` tup ], ``` -In `modular_model_index.json`, each component entry contains 3 elements: `(library, class, loading_specs {})` +In `modular_model_index.json`, each component entry contains 3 elements: `(library, class, loading_specs_dict)` - `library` and `class`: Information about the actual component loaded in the pipeline at the time of saving (will be `null` if not loaded) -- `loading_specs`: A dictionary containing all information required to load this component, including `repo`, `revision`, `subfolder`, `variant`, and `type_hint`. +- `loading_specs_dict`: A dictionary containing all information required to load this component, including `repo`, `revision`, `subfolder`, `variant`, and `type_hint`. 
```py "text_encoder": [ @@ -342,21 +344,8 @@ In `modular_model_index.json`, each component entry contains 3 elements: `(libra } ], ``` -Some components may not have `repo` field, they cannot be loaded from a repository and can only be created with default config from the pipeline -```py - "image_processor": [ - "diffusers", - "VaeImageProcessor", - { - "type_hint": [ - "diffusers", - "VaeImageProcessor" - ] - } - ], -``` -Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the `loading_specs` dictionary. e.g. the `text_encoder` component will be fetched from the "text_encoder" folder in `stabilityai/stable-diffusion-xl-base-1.0` while other components come from different repositories. +Unlike standard repositories where components must be in subfolders within the same repo, modular repositories can fetch components from different repositories based on the `loading_specs_dict`. e.g. the `text_encoder` component will be fetched from the "text_encoder" folder in `stabilityai/stable-diffusion-xl-base-1.0` while other components come from different repositories. 
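As a rough illustration of how entries in this layout can be read, the sketch below parses a small hand-written excerpt in the `(library, class, loading_specs_dict)` format and reports which components are loaded and where each one would be fetched from. The `summarize_index` helper and the excerpt itself are hypothetical, and the `type_hint` value shown for `text_encoder` is an assumption; this is not the Diffusers loading code.

```python
import json

# Hypothetical excerpt that follows the documented three-element layout.
# The type_hint for text_encoder is an assumed value, for illustration only.
MODULAR_INDEX = json.loads("""
{
  "text_encoder": [null, null, {
    "repo": "stabilityai/stable-diffusion-xl-base-1.0",
    "subfolder": "text_encoder",
    "type_hint": ["transformers", "CLIPTextModel"]
  }],
  "unet": ["diffusers", "UNet2DConditionModel", {
    "repo": "RunDiffusion/Juggernaut-XL-v9",
    "subfolder": "unet",
    "variant": "fp16",
    "type_hint": ["diffusers", "UNet2DConditionModel"]
  }]
}
""")


def summarize_index(index):
    """Return, per component, whether it is loaded and where it would load from."""
    summary = {}
    for name, (library, class_name, specs) in index.items():
        summary[name] = {
            # library/class are null (None) until the component is actually loaded
            "loaded": library is not None and class_name is not None,
            "repo": specs.get("repo"),
            "subfolder": specs.get("subfolder"),
        }
    return summary


summary = summarize_index(MODULAR_INDEX)
for name, info in summary.items():
    print(name, info)
```

Each component carries its own `repo`, which is why two entries in the same index can point at different repositories.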
### Creating a `ModularPipeline` from `ModularPipelineBlocks` @@ -370,7 +359,7 @@ Let's convert our `t2i_blocks` (which we created earlier) into a runnable `Modul t2i_blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) # Now convert it to a ModularPipeline -modular_repo_id = "YiYiXu/modular-loader-t2i" +modular_repo_id = "YiYiXu/modular-loader-t2i-0704" t2i_pipeline = t2i_blocks.init_pipeline(modular_repo_id) ``` @@ -398,22 +387,36 @@ You can read more about Components Manager [here](TODO) You can create a `ModularPipeline` from a HuggingFace Hub repository with `from_pretrained` method, as long as it's a modular repo: ```py -# YiYi TODO: this is not yet supported actually 😢, need to add support from diffusers import ModularPipeline -pipeline = ModularPipeline.from_pretrained(repo_id, components_manager=..., collection=...) +pipeline = ModularPipeline.from_pretrained( "YiYiXu/modular-loader-t2i-0704") ``` Loading custom code is also supported: ```py from diffusers import ModularPipeline -modular_repo_id = "YiYiXu/modular-diffdiff" +modular_repo_id = "YiYiXu/modular-diffdiff-0704" diffdiff_pipeline = ModularPipeline.from_pretrained(modular_repo_id, trust_remote_code=True) ``` +This modular repository contains custom code. The [`config.json`](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/config.json) file defines a custom `DiffDiffBlocks` class and points to its implementation: + +```json +{ + "_class_name": "DiffDiffBlocks", + "auto_map": { + "ModularPipelineBlocks": "block.DiffDiffBlocks" + } +} +``` + +The `auto_map` tells the pipeline where to find the custom blocks definition - in this case, it's looking for `DiffDiffBlocks` in the `block.py` file. The actual `DiffDiffBlocks` class is defined in [`block.py`](https://huggingface.co/YiYiXu/modular-diffdiff-0704/blob/main/block.py) within the repository. 
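To make the `auto_map` convention concrete, here is a minimal sketch of how the `"block.DiffDiffBlocks"` string can be read as a module file plus a class name. The `resolve_auto_map` helper is hypothetical and only illustrates the naming convention described above; it is not how Diffusers resolves remote code internally.

```python
import json

# The config.json content shown above, inlined so the sketch is self-contained.
config = json.loads("""
{
  "_class_name": "DiffDiffBlocks",
  "auto_map": {
    "ModularPipelineBlocks": "block.DiffDiffBlocks"
  }
}
""")


def resolve_auto_map(config, key="ModularPipelineBlocks"):
    """Split an auto_map entry into (module file name, class name)."""
    # "block.DiffDiffBlocks" -> module "block", class "DiffDiffBlocks"
    module_name, class_name = config["auto_map"][key].rsplit(".", 1)
    return f"{module_name}.py", class_name


module_file, class_name = resolve_auto_map(config)
print(module_file, class_name)
```

Because the class comes from a file in the repository rather than from the installed library, loading it requires `trust_remote_code=True`, as in the `from_pretrained` call above.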
+ +When `diffdiff_pipeline.blocks` is created, it's based on the `DiffDiffBlocks` definition from the custom code in the repository, allowing you to use specialized blocks that aren't part of the standard diffusers library. + ### Loading components into a `ModularPipeline` -Unlike `DiffusionPipeline`, when you create a `ModularPipeline` instance (whether using `from_pretrained` or converting from pipeline blocks), its components aren't loaded automatically. You need to explicitly load model components using `load_components`: +Unlike `DiffusionPipeline`, when you create a `ModularPipeline` instance (whether using `from_pretrained` or converting from pipeline blocks), its components aren't loaded automatically. You need to explicitly load model components using `load_default_components` or `load_components(names=..,)`: ```py # This will load ALL the expected components into pipeline @@ -428,49 +431,15 @@ All expected components are now loaded into the pipeline. You can also partially >>> t2i_pipeline.load_components(names=["unet", "vae"], torch_dtype=torch.float16) ``` -You can inspect the `loader` attribute of a pipeline to understand what components are expected to load, which ones are already loaded, how they were loaded, and what loading specs are available. The loader is synced with the `modular_model_index.json` from the repository you used during `init_pipeline()` - it takes the loading specs that match the pipeline's component requirements. - -For example, if your pipeline needs a `text_encoder` component, the loader will include the loading spec for `text_encoder` from the modular repo. If the pipeline doesn't need a component (like `controlnet` in a basic text-to-image pipeline), that component won't appear in the loader even if it exists in the modular repo. - -The loader has the same structure as `modular_model_index.json` - each component entry contains the `(library, class, loading_specs)` format. 
You'll need to understand that structure to properly read the loading status below. - - - -💡 **How to read the loader**: -- **`library` and `class` fields**: Show info about actually loaded components. If `null`, the component is not loaded yet. -- **`loading_specs`**: If it does not have `repo` field or if it is `null`, the component cannot be loaded from a repository and can only be created with default config by the pipeline. - - - -Let's inspect the `t2i_pipeline.loader`, you can see all the components expected to load are listed as entries in the loader. The `guider` and `image_processor` components were created using default config (their `library` and `class` field are populated, this means they are initialized, but their loading spec dict is missing loading related fields). The `vae` and `unet` components were loaded using their respective loading specs. The rest of the components (scheduler, text_encoder, text_encoder_2, tokenizer, tokenizer_2) are not loaded yet (their `library`, `class` fields are `null`), but you can examine their loading specs to see where they would be loaded from when you call `load_components()`. - +You can inspect the pipeline's loading status by simply printing the pipeline itself. It helps you understand what components are expected to load, which ones are already loaded, how they were loaded, and what loading specs are available. 
Let's print out the `t2i_pipeline`: ```py ->>> t2i_pipeline.loader -StableDiffusionXLModularLoader { - "_class_name": "StableDiffusionXLModularLoader", - "_diffusers_version": "0.34.0.dev0", +>>> t2i_pipeline +StableDiffusionXLModularPipeline { + "_blocks_class_name": "SequentialPipelineBlocks", + "_class_name": "StableDiffusionXLModularPipeline", + "_diffusers_version": "0.35.0.dev0", "force_zeros_for_empty_prompt": true, - "guider": [ - "diffusers", - "ClassifierFreeGuidance", - { - "type_hint": [ - "diffusers", - "ClassifierFreeGuidance" - ] - } - ], - "image_processor": [ - "diffusers", - "VaeImageProcessor", - { - "type_hint": [ - "diffusers", - "VaeImageProcessor" - ] - } - ], "scheduler": [ null, null, @@ -572,31 +541,42 @@ StableDiffusionXLModularLoader { } ``` +You can see that all the components that will be loaded with the `from_pretrained` method are listed as entries. Each entry contains 3 elements: `(library, class, loading_specs_dict)`: + +- **`library` and `class`**: Show the actual loaded component info. If `null`, the component is not loaded yet. +- **`loading_specs_dict`**: Contains all the information needed to load the component (repo, subfolder, variant, etc.) + +In this example: +- **Loaded components**: `vae` and `unet` (their `library` and `class` fields show the actual loaded models) +- **Not loaded yet**: `scheduler`, `text_encoder`, `text_encoder_2`, `tokenizer`, `tokenizer_2` (their `library` and `class` fields are `null`, but you can see their loading specs to know where they'll be loaded from when you call `load_components()`) + +You're essentially looking at the pipeline's config dict that's synced with the `modular_model_index.json` from the repository you used during `init_pipeline()` - it takes the loading specs that match the pipeline's component requirements. + +For example, if your pipeline needs a `text_encoder` component, it will include the loading spec for `text_encoder` from the modular repo during the `init_pipeline`. 
If the pipeline doesn't need a component (like `controlnet` in a basic text-to-image pipeline), that component won't be included even if it exists in the modular repo. + There are also a few properties that can provide a quick summary of component loading status: ```py # All components expected by the pipeline ->>> t2i_pipeline.loader.component_names +>>> t2i_pipeline.component_names ['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'guider', 'scheduler', 'unet', 'vae', 'image_processor'] # Components that are not loaded yet (will be loaded with from_pretrained) ->>> t2i_pipeline.loader.null_component_names +>>> t2i_pipeline.null_component_names ['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler'] # Components that will be loaded from pretrained models ->>> t2i_pipeline.loader.pretrained_component_names +>>> t2i_pipeline.pretrained_component_names ['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'scheduler', 'unet', 'vae'] # Components that are created with default config (no repo needed) ->>> t2i_pipeline.loader.config_component_names +>>> t2i_pipeline.config_component_names ['guider', 'image_processor'] ``` ### Modifying Loading Specs -When you call `pipeline.load_components(names=...)` or `pipeline.load_default_components()`, it uses the loading specs from the modular repository's `modular_model_index.json`. The pipeline's `loader` attribute is synced with these specs - it shows you exactly what will be loaded and from where. - -You can change where components are loaded from by default by modifying the `modular_model_index.json` in the repository. You can change any field in the loading specs: `repo`, `subfolder`, `variant`, `revision`, etc. +When you call `pipeline.load_components(names=)` or `pipeline.load_default_components()`, it uses the loading specs from the modular repository's `modular_model_index.json`. 
You can change where components are loaded from by default by modifying the `modular_model_index.json` in the repository. You can change any field in the loading specs: `repo`, `subfolder`, `variant`, `revision`, etc. ```py # Original spec in modular_model_index.json @@ -682,6 +662,31 @@ StableDiffusionXLModularLoader { ... } ``` + + +💡 **Modifying Component Specs**: You can get a copy of the current component spec from the pipeline using `get_component_spec()`. This makes it easy to modify the spec and update components. + +```py +>>> unet_spec = t2i_pipeline.get_component_spec("unet") +>>> unet_spec +ComponentSpec( + name='unet', + type_hint=, + repo='RunDiffusion/Juggernaut-XL-v9', + subfolder='unet', + variant='fp16', + default_creation_method='from_pretrained' +) + +# Modify the spec to load from a different repository +>>> unet_spec.repo = "stabilityai/stable-diffusion-xl-base-1.0" + +# Load the component with the modified spec +>>> unet = unet_spec.load() +``` + + + ### Running a `ModularPipeline` @@ -728,7 +733,7 @@ from diffusers.modular_pipelines.stable_diffusion_xl import TEXT2IMAGE_BLOCKS # create pipeline from official blocks preset blocks = SequentialPipelineBlocks.from_blocks_dict(TEXT2IMAGE_BLOCKS) -modular_repo_id = "YiYiXu/modular-loader-t2i" +modular_repo_id = "YiYiXu/modular-loader-t2i-0704" pipeline = blocks.init_pipeline(modular_repo_id) pipeline.load_default_components(torch_dtype=torch.float16) @@ -750,7 +755,7 @@ from diffusers.modular_pipelines.stable_diffusion_xl import IMAGE2IMAGE_BLOCKS # create pipeline from blocks preset blocks = SequentialPipelineBlocks.from_blocks_dict(IMAGE2IMAGE_BLOCKS) -modular_repo_id = "YiYiXu/modular-loader-t2i" +modular_repo_id = "YiYiXu/modular-loader-t2i-0704" pipeline = blocks.init_pipeline(modular_repo_id) pipeline.load_default_components(torch_dtype=torch.float16) @@ -775,7 +780,7 @@ from diffusers.utils import load_image # create pipeline from blocks preset blocks = 
SequentialPipelineBlocks.from_blocks_dict(INPAINT_BLOCKS) -modular_repo_id = "YiYiXu/modular-loader-t2i" +modular_repo_id = "YiYiXu/modular-loader-t2i-0704" pipeline = blocks.init_pipeline(modular_repo_id) pipeline.load_default_components(torch_dtype=torch.float16) @@ -809,7 +814,7 @@ For ControlNet, we provide one auto block you can place at the `denoise` step. L >>> from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS >>> ALL_BLOCKS["controlnet"] InsertableDict([ - 0: ('denoise', ) + 0: ('denoise', ) ]) >>> controlnet_blocks = ALL_BLOCKS["controlnet"]["denoise"]() >>> controlnet_blocks @@ -899,7 +904,7 @@ Let's walk through the steps: >>> from diffusers.modular_pipelines.stable_diffusion_xl import ALL_BLOCKS >>> ALL_BLOCKS["ip_adapter"] InsertableDict([ - 0: ('ip_adapter', ) + 0: ('ip_adapter', ) ]) ``` @@ -932,8 +937,7 @@ StableDiffusionXLAutoIPAdapterStep( Sub-Blocks: • ip_adapter [trigger: ip_adapter_image] (StableDiffusionXLIPAdapterStep) Description: IP Adapter step that prepares ip adapter image embeddings. - Note that this step only prepares the embeddings - in order for it to work correctly, you need to load ip adapter weights into unet via ModularPipeline.loader. - e.g. pipeline.loader.load_ip_adapter() and pipeline.loader.set_ip_adapter_scale(). + Note that this step only prepares the embeddings - in order for it to work correctly, you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale(). 
See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin) for more details ) @@ -958,12 +962,12 @@ modular_repo_id = "YiYiXu/modular-demo-auto" pipeline = blocks.init_pipeline(modular_repo_id) pipeline.load_default_components(torch_dtype=torch.float16) -pipeline.loader.load_ip_adapter( +pipeline.load_ip_adapter( "h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin" ) -pipeline.loader.set_ip_adapter_scale(0.8) +pipeline.set_ip_adapter_scale(0.8) pipeline.to("cuda") ``` @@ -1020,31 +1024,23 @@ components = ComponentsManager() components.enable_auto_cpu_offload(device="cuda") ``` -Since we have a modular setup where different pipelines may share components, we recommend using a standalone loader to load components all at once and add them to each pipeline with `update_components()`. +Since we have a modular setup where different pipelines may share components, we recommend using a separate `ModularPipeline` to load components all at once and add them to each pipeline with `update_components()`. 
- - -💡 **Load components without pipeline blocks**: -- `blocks.init_pipeline(repo)` creates a pipeline with a built-in loader that only includes components its blocks needs -- `StableDiffusionXLModularLoader.from_pretrained(repo)` set up a standalone loader that includes everything in the repo's `modular_model_index.json` - - - ```py -from diffusers import StableDiffusionXLModularLoader +from diffusers import ModularPipeline t2i_repo = "YiYiXu/modular-demo-auto" -t2i_loader = StableDiffusionXLModularLoader.from_pretrained(t2i_repo, components_manager=components, collection="t2i") +t2i_loader_pipe = ModularPipeline.from_pretrained(t2i_repo, components_manager=components, collection="t2i") text_node = text_blocks.init_pipeline(t2i_repo, components_manager=components) decoder_node = decoder_blocks.init_pipeline(t2i_repo, components_manager=components) t2i_pipe = t2i_blocks.init_pipeline(t2i_repo, components_manager=components) ``` -We'll load components in `t2i_loader`. You can get the list of all loadable components from loader's `pretrained_component_names` property. +We'll load components in `t2i_loader_pipe`. You can get the list of all loadable components from loader's `pretrained_component_names` property. ```py ->>> t2i_loader.pretrained_component_names +>>> t2i_loader_pipe.pretrained_component_names ['controlnet', 'image_encoder', 'scheduler', 'text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2', 'unet', 'vae'] ``` @@ -1054,7 +1050,7 @@ It include controlnet and image_encoder for ip-adapter that we don't need now. B import torch # inspect before you load # t2i_loader -t2i_loader.load(t2i_loader.pretrained_component_names, torch_dtype=torch.float16) +t2i_loader_pipe.load_components(names=t2i_loader_pipe.pretrained_component_names, torch_dtype=torch.float16) ``` All the models are registered to components manager under the collection "t2i". @@ -1088,15 +1084,15 @@ Additional Component Info: ``` Let's add the loaded components to each pipeline. 
We'll follow this pattern for each pipeline: -1. Check what components the pipeline needs: inspect `pipeline.loader` or use `loader.null_component_names` +1. Check what components the pipeline needs: inspect `pipeline` or use `pipeline.null_component_names` 2. Get them from the components manager: use its `search_models()`/`get_one`/`get_components_from_names` method 3. Update the pipeline: `pipeline.update_components()` -4. Verify the components are loaded correctly: inspect `pipeline.loader` as well as components manager +4. Verify the components are loaded correctly: inspect `pipeline` as well as components manager We will start with `decoder_node`. First, check what components it needs: ```py ->>> decoder_node.loader.null_component_names +>>> decoder_node.null_component_names ['vae'] ``` The pipeline only needs a `vae`. Looking at the components manager table, there's only one VAE available: @@ -1116,24 +1112,24 @@ decoder_node.update_components(vae=vae) Verify it's correctly loaded: ```py -decoder_node.loader +decoder_node ``` Now let's do the same for `text_node`. 
Get the list of components the pipeline needs to load: ```py ->>> text_node.loader.null_component_names +>>> text_node.null_component_names ['text_encoder', 'text_encoder_2', 'tokenizer', 'tokenizer_2'] ``` Pass the list directly to the components manager to get the components and add it to the pipeline ```py -text_components = components.get_components_by_names(text_node.loader.null_component_names) +text_components = components.get_components_by_names(text_node.null_component_names) # Add components to pipeline text_node.update_components(**text_components) # Verify components are loaded -assert not text_node.loader.null_component_names -text_node.loader +assert not text_node.null_component_names +text_node ``` Finally, let's set up `t2i_pipe`: @@ -1141,12 +1137,12 @@ Finally, let's set up `t2i_pipe`: ```py # Get unet & scheduler from components manager and add to pipeline -comps = components.get_components_by_names(t2i_pipe.loader.null_component_names) +comps = components.get_components_by_names(t2i_pipe.null_component_names) t2i_pipe.update_components(**comps) # Verify everything is loaded -assert not t2i_pipe.loader.null_component_names -t2i_pipe.loader +assert not t2i_pipe.null_component_names +t2i_pipe # Verify components manager hasn't changed (we only reused existing components) components @@ -1183,7 +1179,7 @@ image.save("modular_part2_t2i.png") Now let's add a LoRA to our pipeline. With the modular approach we will be able to reuse intermediate outputs from blocks that otherwise needs to be re-run. Let's load the LoRA weights and see what happens: ```py -t2i_loader.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face") +t2i_loader_pipe.load_lora_weights("CiroN2022/toy-face", weight_name="toy_face_sdxl.safetensors", adapter_name="toy_face") components ``` Notice that the "Additional Component Info" section shows that only the `unet` component has the LoRA adapter loaded. 
This means we can skip the text encoding step and reuse the existing embeddings, making the generation much faster. @@ -1231,12 +1227,12 @@ ipa_node = ipa_blocks.init_pipeline(t2i_repo, components_manager=components) comps = components.get_components_by_names(ipa_node.loader.null_component_names) ipa_node.update_components(**comps) -t2i_loader.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") -t2i_loader.set_ip_adapter_scale(0.6) +t2i_loader_pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin") +t2i_loader_pipe.set_ip_adapter_scale(0.6) # check it's correctly loaded -assert not ipa_node.loader.null_component_names -ipa_node.loader +assert not ipa_node.null_component_names +ipa_node # find out inputs/outputs print(ipa_node.doc) @@ -1305,7 +1301,7 @@ refiner_pipe = refiner_blocks.init_pipeline(refiner_repo, components_manager=com We want to reuse components from the t2i pipeline in the refiner as much as possible. First, let's check the loading status of the refiner pipeline to understand what components are needed: ```py ->>> refiner_pipe.loader +>>> refiner_pipe ``` Looking at the loader output, you can see that `text_encoder` and `tokenizer` have empty loading spec maps (their `repo` fields are `null`), this is because refiner pipeline does not use these two components so they are not listed in the `modular_model_index.json` in `refiner_repo`. The `unet` is different from the one we loaded for text-to-image. The remaining components: `vae`, `text_encoder_2`, `tokenizer_2`, and `scheduler` are already available in the t2i collection, we can reuse them instead of loading duplicates. 
@@ -1314,7 +1310,7 @@ Looking at the loader output, you can see that `text_encoder` and `tokenizer` ha refiner_pipe.load_components(names="unet", torch_dtype=torch.float16) # verify loaded correctly -refiner_pipe.loader +refiner_pipe # veryfiy registered to components manager under refiner components diff --git a/docs/source/en/modular_diffusers/write_own_pipeline_block.md b/docs/source/en/modular_diffusers/write_own_pipeline_block.md index 4739bbc690..f65af4463f 100644 --- a/docs/source/en/modular_diffusers/write_own_pipeline_block.md +++ b/docs/source/en/modular_diffusers/write_own_pipeline_block.md @@ -107,7 +107,7 @@ def __call__(self, components, state): # You can access them like: block_state.image, block_state.processed_image # Update the pipeline state with your updated block_states - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state ``` @@ -140,7 +140,7 @@ When you convert your blocks into a pipeline using `blocks.init_pipeline()`, the That's all you need to define in order to create a `PipelineBlock`. There is no hidden complexity. In fact we are going to create a helper function that take exactly these variables as input and return a pipeline block. We will use this helper function through out the tutorial to create test blocks -Note that for `__call__` method, the only part you should implement differently is the part between `self.get_block_state()` and `self.add_block_state()`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. Our helper function accepts a `block_fn` that does exactly that. +Note that for `__call__` method, the only part you should implement differently is the part between `self.get_block_state()` and `self.set_block_state()`, which can be abstracted into a simple function that takes `block_state` and returns the updated state. Our helper function accepts a `block_fn` that does exactly that. 
**Helper Function** @@ -172,7 +172,7 @@ def make_block(inputs=[], intermediate_inputs=[], intermediate_outputs=[], block block_state = self.get_block_state(state) if block_fn is not None: block_state = block_fn(block_state, state) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state return TestBlock @@ -403,7 +403,7 @@ class DenoiseLoop(PipelineBlock): for t in range(block_state.num_inference_steps): # ... loop logic here pass - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state ``` @@ -455,7 +455,7 @@ class LoopWrapper(LoopSequentialPipelineBlocks): for i in range(block_state.num_steps): # loop_step executes all registered blocks in sequence components, block_state = self.loop_step(components, block_state, i=i) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state ``` @@ -464,7 +464,7 @@ class LoopWrapper(LoopSequentialPipelineBlocks): Loop blocks are standard `PipelineBlock`s, but their `__call__` method works differently: * It receives the iteration variable (e.g., `i`) passed by the loop wrapper * It works directly with `block_state` instead of pipeline state -* No need to call `self.get_block_state()` or `self.add_block_state()` +* No need to call `self.get_block_state()` or `self.set_block_state()` ```py class LoopBlock(PipelineBlock): diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 314a4126d2..885d37fc8e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -240,7 +240,6 @@ else: [ "ComponentsManager", "ComponentSpec", - "ModularLoader", "ModularPipeline", "ModularPipelineBlocks", ] @@ -360,7 +359,7 @@ else: _import_structure["modular_pipelines"].extend( [ "StableDiffusionXLAutoBlocks", - "StableDiffusionXLModularLoader", + "StableDiffusionXLModularPipeline", ] ) _import_structure["pipelines"].extend( @@ -881,7 +880,6 @@ if TYPE_CHECKING or 
DIFFUSERS_SLOW_IMPORT: from .modular_pipelines import ( ComponentsManager, ComponentSpec, - ModularLoader, ModularPipeline, ModularPipelineBlocks, ) @@ -983,7 +981,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: else: from .modular_pipelines import ( StableDiffusionXLAutoBlocks, - StableDiffusionXLModularLoader, + StableDiffusionXLModularPipeline, ) from .pipelines import ( AllegroPipeline, diff --git a/src/diffusers/modular_pipelines/__init__.py b/src/diffusers/modular_pipelines/__init__.py index 9b18c8b048..bf34eed28b 100644 --- a/src/diffusers/modular_pipelines/__init__.py +++ b/src/diffusers/modular_pipelines/__init__.py @@ -29,7 +29,6 @@ else: "AutoPipelineBlocks", "SequentialPipelineBlocks", "LoopSequentialPipelineBlocks", - "ModularLoader", "PipelineState", "BlockState", ] @@ -40,7 +39,7 @@ else: "OutputParam", "InsertableDict", ] - _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularLoader"] + _import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"] _import_structure["components_manager"] = ["ComponentsManager"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -55,7 +54,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: AutoPipelineBlocks, BlockState, LoopSequentialPipelineBlocks, - ModularLoader, ModularPipeline, ModularPipelineBlocks, PipelineBlock, @@ -71,7 +69,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: ) from .stable_diffusion_xl import ( StableDiffusionXLAutoBlocks, - StableDiffusionXLModularLoader, + StableDiffusionXLModularPipeline, ) else: import sys diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index a1bdd86e8c..cf6501ad27 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -38,26 +38,6 @@ if is_accelerate_available(): logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# YiYi 
Notes: copied from modeling_utils.py (decide later where to put this) -def get_memory_footprint(self, return_buffers=True): - r""" - Get the memory footprint of a model. This will return the memory footprint of the current model in bytes. Useful to - benchmark the memory footprint of the current model and design some tests. Solution inspired from the PyTorch - discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2 - - Arguments: - return_buffers (`bool`, *optional*, defaults to `True`): - Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers are - tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch norm - layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2 - """ - mem = sum([param.nelement() * param.element_size() for param in self.parameters()]) - if return_buffers: - mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()]) - mem = mem + mem_bufs - return mem - - class CustomOffloadHook(ModelHook): """ A hook that offloads a model on the CPU until its forward pass is called. It ensures the model and its inputs are @@ -170,6 +150,8 @@ class AutoOffloadStrategy: the available memory on the device. """ + # YiYi TODO: instead of memory_reserve_margin, we should let user set the maximum_total_models_size to keep on device + # the actual memory usage would be higher. 
But it's simpler this way, and can be tested def __init__(self, memory_reserve_margin="3GB"): self.memory_reserve_margin = convert_file_size_to_int(memory_reserve_margin) @@ -177,7 +159,7 @@ class AutoOffloadStrategy: if len(hooks) == 0: return [] - current_module_size = get_memory_footprint(model) + current_module_size = model.get_memory_footprint() mem_on_device = torch.cuda.mem_get_info(execution_device.index)[0] mem_on_device = mem_on_device - self.memory_reserve_margin @@ -190,12 +172,13 @@ class AutoOffloadStrategy: # exclude models that are not currently loaded on the device module_sizes = dict( sorted( - {hook.model_id: get_memory_footprint(hook.model) for hook in hooks}.items(), + {hook.model_id: hook.model.get_memory_footprint() for hook in hooks}.items(), key=lambda x: x[1], reverse=True, ) ) + # YiYi/Dhruv TODO: sort smallest to largest and offload in that order, so that we tend to keep the larger models on GPU more often def search_best_candidate(module_sizes, min_memory_offload): """ search the optimal combination of models to offload to cpu, given a dictionary of module sizes and a @@ -652,7 +635,7 @@ class ComponentsManager: info.update( { "class_name": component.__class__.__name__, - "size_gb": get_memory_footprint(component) / (1024**3), + "size_gb": component.get_memory_footprint() / (1024**3), "adapters": None, # Default to None "has_hook": has_hook, "execution_device": execution_device, diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 99db80d315..d0429a1f45 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -19,6 +19,7 @@ import warnings from collections import OrderedDict from copy import deepcopy from dataclasses import dataclass, field +from types import SimpleNamespace from typing import Any, Dict, List, Optional, Tuple, Union import torch @@ -55,9 +56,15 @@ if is_accelerate_available(): logger = 
logging.get_logger(__name__) # pylint: disable=invalid-name -MODULAR_LOADER_MAPPING = OrderedDict( +MODULAR_PIPELINE_MAPPING = OrderedDict( [ - ("stable-diffusion-xl", "StableDiffusionXLModularLoader"), + ("stable-diffusion-xl", "StableDiffusionXLModularPipeline"), + ] +) + +MODULAR_PIPELINE_BLOCKS_MAPPING = OrderedDict( + [ + ("StableDiffusionXLModularPipeline", "StableDiffusionXLAutoBlocks"), ] ) @@ -73,7 +80,7 @@ class PipelineState: input_kwargs: Dict[str, List[str]] = field(default_factory=dict) intermediate_kwargs: Dict[str, List[str]] = field(default_factory=dict) - def add_input(self, key: str, value: Any, kwargs_type: str = None): + def set_input(self, key: str, value: Any, kwargs_type: str = None): """ Add an input to the pipeline state with optional metadata. @@ -89,7 +96,7 @@ class PipelineState: else: self.input_kwargs[kwargs_type].append(key) - def add_intermediate(self, key: str, value: Any, kwargs_type: str = None): + def set_intermediate(self, key: str, value: Any, kwargs_type: str = None): """ Add an intermediate value to the pipeline state with optional metadata. @@ -329,25 +336,18 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin): collection: Optional[str] = None, ): """ - create a ModularLoader, optionally accept modular_repo to load from hub. + create a ModularPipeline, optionally accept modular_repo to load from hub. 
""" - loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__) + pipeline_class_name = MODULAR_PIPELINE_MAPPING.get(self.model_name, ModularPipeline.__name__) diffusers_module = importlib.import_module("diffusers") - loader_class = getattr(diffusers_module, loader_class_name) + pipeline_class = getattr(diffusers_module, pipeline_class_name) - # Create deep copies to avoid modifying the original specs - component_specs = deepcopy(self.expected_components) - config_specs = deepcopy(self.expected_configs) - # Create the loader with the updated specs - specs = component_specs + config_specs - - loader = loader_class( - specs=specs, + modular_pipeline = pipeline_class( + blocks=deepcopy(self), pretrained_model_name_or_path=pretrained_model_name_or_path, components_manager=components_manager, collection=collection, ) - modular_pipeline = ModularPipeline(blocks=deepcopy(self), loader=loader) return modular_pipeline @@ -512,12 +512,12 @@ class PipelineBlock(ModularPipelineBlocks): data[input_param.kwargs_type][k] = v return BlockState(**data) - def add_block_state(self, state: PipelineState, block_state: BlockState): + def set_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediate_outputs: if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") param = getattr(block_state, output_param.name) - state.add_intermediate(output_param.name, param, output_param.kwargs_type) + state.set_intermediate(output_param.name, param, output_param.kwargs_type) for input_param in self.intermediate_inputs: if hasattr(block_state, input_param.name): @@ -525,7 +525,7 @@ class PipelineBlock(ModularPipelineBlocks): # Only add if the value is different from what's in the state current_value = state.get_intermediate(input_param.name) if current_value is not param: # Using identity comparison to check if object was modified - 
state.add_intermediate(input_param.name, param, input_param.kwargs_type) + state.set_intermediate(input_param.name, param, input_param.kwargs_type) for input_param in self.intermediate_inputs: if input_param.name and hasattr(block_state, input_param.name): @@ -533,7 +533,7 @@ class PipelineBlock(ModularPipelineBlocks): # Only add if the value is different from what's in the state current_value = state.get_intermediate(input_param.name) if current_value is not param: # Using identity comparison to check if object was modified - state.add_intermediate(input_param.name, param, input_param.kwargs_type) + state.set_intermediate(input_param.name, param, input_param.kwargs_type) elif input_param.kwargs_type: # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters # we need to first find out which inputs are and loop through them. @@ -541,7 +541,7 @@ class PipelineBlock(ModularPipelineBlocks): for param_name, current_value in intermediate_kwargs.items(): param = getattr(block_state, param_name) if current_value is not param: # Using identity comparison to check if object was modified - state.add_intermediate(param_name, param, input_param.kwargs_type) + state.set_intermediate(param_name, param, input_param.kwargs_type) def combine_inputs(*named_input_lists: List[Tuple[str, List[InputParam]]]) -> List[InputParam]: @@ -610,7 +610,6 @@ def combine_outputs(*named_output_lists: List[Tuple[str, List[OutputParam]]]) -> return list(combined_dict.values()) -# YiYi TODO: change blocks attribute to a different name, so it is not confused with the blocks attribute in ModularPipeline class AutoPipelineBlocks(ModularPipelineBlocks): """ A class that automatically selects a block to run based on the inputs. 
@@ -1524,12 +1523,12 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): data[input_param.kwargs_type][k] = v return BlockState(**data) - def add_block_state(self, state: PipelineState, block_state: BlockState): + def set_block_state(self, state: PipelineState, block_state: BlockState): for output_param in self.intermediate_outputs: if not hasattr(block_state, output_param.name): raise ValueError(f"Intermediate output '{output_param.name}' is missing in block state") param = getattr(block_state, output_param.name) - state.add_intermediate(output_param.name, param, output_param.kwargs_type) + state.set_intermediate(output_param.name, param, output_param.kwargs_type) for input_param in self.intermediate_inputs: if input_param.name and hasattr(block_state, input_param.name): @@ -1537,7 +1536,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): # Only add if the value is different from what's in the state current_value = state.get_intermediate(input_param.name) if current_value is not param: # Using identity comparison to check if object was modified - state.add_intermediate(input_param.name, param, input_param.kwargs_type) + state.set_intermediate(input_param.name, param, input_param.kwargs_type) elif input_param.kwargs_type: # if it is a kwargs type, e.g. "guider_input_fields", it is likely to be a list of parameters # we need to first find out which inputs are and loop through them. @@ -1547,7 +1546,7 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): continue param = getattr(block_state, param_name) if current_value is not param: # Using identity comparison to check if object was modified - state.add_intermediate(param_name, param, input_param.kwargs_type) + state.set_intermediate(param_name, param, input_param.kwargs_type) @property def doc(self): @@ -1643,123 +1642,22 @@ class LoopSequentialPipelineBlocks(ModularPipelineBlocks): # 2. 
look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) # 3. do we need ConfigSpec? seems pretty unnecessary for loader, can just add and kwargs to the loader # 4. add validator for methods where we accept kwargs to be passed to from_pretrained() -class ModularLoader(ConfigMixin, PushToHubMixin): +class ModularPipeline(ConfigMixin, PushToHubMixin): """ - Base class for all Modular pipelines loaders. + Base class for all Modular pipelines. + Args: + blocks: ModularPipelineBlocks, the blocks to be used in the pipeline """ config_name = "modular_model_index.json" hf_device_map = None - def register_components(self, **kwargs): - """ - Register components with their corresponding specifications. - - This method is responsible for: - 1. Sets component objects as attributes on the loader (e.g., self.unet = unet) - 2. Updates the modular_model_index.json configuration for serialization - 4. Adds components to the component manager if one is attached - - This method is called when: - - Components are first initialized in __init__: - - from_pretrained components not loaded during __init__ so they are registered as None; - - non from_pretrained components are created during __init__ and registered as the object itself - - Components are updated with the `update()` method: e.g. loader.update(unet=unet) or - loader.update(guider=guider_spec) - - (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(names=["unet"]) - - Args: - **kwargs: Keyword arguments where keys are component names and values are component objects. 
- E.g., register_components(unet=unet_model, text_encoder=encoder_model) - - Notes: - - Components must be created from ComponentSpec (have _diffusers_load_id attribute) - - When registering None for a component, it updates the modular_model_index.json config but sets attribute - to None - """ - for name, module in kwargs.items(): - # current component spec - component_spec = self._component_specs.get(name) - if component_spec is None: - logger.warning(f"ModularLoader.register_components: skipping unknown component '{name}'") - continue - - # check if it is the first time registration, i.e. calling from __init__ - is_registered = hasattr(self, name) - - # make sure the component is created from ComponentSpec - if module is not None and not hasattr(module, "_diffusers_load_id"): - raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") - - if module is not None: - # actual library and class name of the module - library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") - - # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config - # e.g. {"repo": "stabilityai/stable-diffusion-2-1", - # "type_hint": ("diffusers", "UNet2DConditionModel"), - # "subfolder": "unet", - # "variant": None, - # "revision": None} - component_spec_dict = self._component_spec_to_dict(component_spec) - - else: - # if module is None, e.g. 
self.register_components(unet=None) during __init__ - # we do not update the spec, - # but we still need to update the modular_model_index.json config based oncomponent spec - library, class_name = None, None - component_spec_dict = self._component_spec_to_dict(component_spec) - register_dict = {name: (library, class_name, component_spec_dict)} - - # set the component as attribute - # if it is not set yet, just set it and skip the process to check and warn below - if not is_registered: - self.register_to_config(**register_dict) - setattr(self, name, module) - if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None: - self._components_manager.add(name, module, self._collection) - continue - - current_module = getattr(self, name, None) - # skip if the component is already registered with the same object - if current_module is module: - logger.info( - f"ModularLoader.register_components: {name} is already registered with same object, skipping" - ) - continue - - # warn if unregister - if current_module is not None and module is None: - logger.info( - f"ModularLoader.register_components: setting '{name}' to None " - f"(was {current_module.__class__.__name__})" - ) - # same type, new instance → replace but send debug log - elif ( - current_module is not None - and module is not None - and isinstance(module, current_module.__class__) - and current_module != module - ): - logger.debug( - f"ModularLoader.register_components: replacing existing '{name}' " - f"(same type {type(current_module).__name__}, new instance)" - ) - - # update modular_model_index.json config - self.register_to_config(**register_dict) - # finally set models - setattr(self, name, module) - # add to component manager if one is attached - if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None: - self._components_manager.add(name, module, self._collection) - # YiYi TODO: add warning for passing multiple 
ComponentSpec/ConfigSpec with the same name def __init__( self, - specs: List[Union[ComponentSpec, ConfigSpec]], - pretrained_model_name_or_path: Optional[str] = None, + blocks: Optional[ModularPipelineBlocks] = None, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None, components_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None, **kwargs, @@ -1767,25 +1665,35 @@ class ModularLoader(ConfigMixin, PushToHubMixin): """ Initialize the loader with a list of component specs and config specs. """ + if blocks is None: + blocks_class_name = MODULAR_PIPELINE_BLOCKS_MAPPING.get(self.__class__.__name__) + if blocks_class_name is not None: + diffusers_module = importlib.import_module("diffusers") + blocks_class = getattr(diffusers_module, blocks_class_name) + blocks = blocks_class() + else: + logger.warning(f"`blocks` is `None`, no default blocks class found for {self.__class__.__name__}") + + self.blocks = blocks self._components_manager = components_manager self._collection = collection - self._component_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ComponentSpec)} - self._config_specs = {spec.name: deepcopy(spec) for spec in specs if isinstance(spec, ConfigSpec)} + self._component_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_components} + self._config_specs = {spec.name: deepcopy(spec) for spec in self.blocks.expected_configs} # update component_specs and config_specs from modular_repo if pretrained_model_name_or_path is not None: config_dict = self.load_config(pretrained_model_name_or_path, **kwargs) for name, value in config_dict.items(): - # only update component_spec for from_pretrained components + # all the components in modular_model_index.json are from_pretrained components if ( name in self._component_specs - and self._component_specs[name].default_creation_method == "from_pretrained" and isinstance(value, (tuple, list)) and len(value) == 3 ): library, class_name, 
component_spec_dict = value component_spec = self._dict_to_component_spec(name, component_spec_dict) + component_spec.default_creation_method = "from_pretrained" self._component_specs[name] = component_spec elif name in self._config_specs: @@ -1805,6 +1713,243 @@ class ModularLoader(ConfigMixin, PushToHubMixin): default_configs[name] = config_spec.default self.register_to_config(**default_configs) + self.register_to_config(_blocks_class_name=self.blocks.__class__.__name__ if self.blocks is not None else None) + + @property + def default_call_parameters(self) -> Dict[str, Any]: + params = {} + for input_param in self.blocks.inputs: + params[input_param.name] = input_param.default + return params + + def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): + """ + Run one or more blocks in sequence, optionally you can pass a previous pipeline state. + """ + if state is None: + state = PipelineState() + + # Make a copy of the input kwargs + passed_kwargs = kwargs.copy() + + # Add inputs to state, using defaults if not provided in the kwargs or the state + # if same input already in the state, will override it if provided in the kwargs + + intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs] + for expected_input_param in self.blocks.inputs: + name = expected_input_param.name + default = expected_input_param.default + kwargs_type = expected_input_param.kwargs_type + if name in passed_kwargs: + if name not in intermediate_inputs: + state.set_input(name, passed_kwargs.pop(name), kwargs_type) + else: + state.set_input(name, passed_kwargs[name], kwargs_type) + elif name not in state.inputs: + state.set_input(name, default, kwargs_type) + + for expected_intermediate_param in self.blocks.intermediate_inputs: + name = expected_intermediate_param.name + kwargs_type = expected_intermediate_param.kwargs_type + if name in passed_kwargs: + state.set_intermediate(name, passed_kwargs.pop(name), kwargs_type) + + # Warn 
about unexpected inputs + if len(passed_kwargs) > 0: + warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. This input will be ignored.") + # Run the pipeline + with torch.no_grad(): + try: + _, state = self.blocks(self, state) + except Exception: + error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n" + logger.error(error_msg) + raise + + if output is None: + return state + + elif isinstance(output, str): + return state.get_intermediate(output) + + elif isinstance(output, (list, tuple)): + return state.get_intermediates(output) + else: + raise ValueError(f"Output '{output}' is not a valid output type") + + def load_default_components(self, **kwargs): + names = [ + name + for name in self._component_specs.keys() + if self._component_specs[name].default_creation_method == "from_pretrained" + ] + self.load_components(names=names, **kwargs) + + @classmethod + @validate_hf_hub_args + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + trust_remote_code: Optional[bool] = None, + components_manager: Optional[ComponentsManager] = None, + collection: Optional[str] = None, + **kwargs, + ): + from ..pipelines.pipeline_loading_utils import _get_pipeline_class + try: + blocks = ModularPipelineBlocks.from_pretrained( + pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs + ) + except EnvironmentError: + blocks = None + + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + proxies = kwargs.pop("proxies", None) + token = kwargs.pop("token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + + load_config_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "token": token, + "local_files_only": local_files_only, + "revision": revision, + } + + config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) + pipeline_class = 
_get_pipeline_class(cls, config=config_dict) + + pipeline = pipeline_class( + blocks=blocks, + pretrained_model_name_or_path=pretrained_model_name_or_path, + components_manager=components_manager, + collection=collection, + **kwargs + ) + return pipeline + + # YiYi TODO: + # 1. should support save some components too! currently only modular_model_index.json is saved + # 2. maybe order the json file to make it more readable: configs first, then components + def save_pretrained( + self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs + ): + + self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) + + @property + def doc(self): + return self.blocks.doc + + + def register_components(self, **kwargs): + """ + Register components with their corresponding specifications. + + This method is responsible for: + 1. Sets component objects as attributes on the loader (e.g., self.unet = unet) + 2. Updates the modular_model_index.json configuration for serialization (only for from_pretrained components) + 3. Adds components to the component manager if one is attached (only for from_pretrained components) + + This method is called when: + - Components are first initialized in __init__: + - from_pretrained components not loaded during __init__ so they are registered as None; + - non from_pretrained components are created during __init__ and registered as the object itself + - Components are updated with the `update()` method: e.g. loader.update(unet=unet) or + loader.update(guider=guider_spec) + - (from_pretrained) Components are loaded with the `load()` method: e.g. loader.load(names=["unet"]) + + Args: + **kwargs: Keyword arguments where keys are component names and values are component objects. 
+ E.g., register_components(unet=unet_model, text_encoder=encoder_model) + + Notes: + - Components must be created from ComponentSpec (have _diffusers_load_id attribute) + - When registering None for a component, it sets attribute to None but still syncs specs with the modular_model_index.json config + """ + for name, module in kwargs.items(): + # current component spec + component_spec = self._component_specs.get(name) + if component_spec is None: + logger.warning(f"ModularPipeline.register_components: skipping unknown component '{name}'") + continue + + # check if it is the first time registration, i.e. calling from __init__ + is_registered = hasattr(self, name) + is_from_pretrained = component_spec.default_creation_method == "from_pretrained" + + # make sure the component is created from ComponentSpec + if module is not None and not hasattr(module, "_diffusers_load_id"): + raise ValueError("`ModularPipeline` only supports components created from `ComponentSpec`.") + + if module is not None: + # actual library and class name of the module + library, class_name = _fetch_class_library_tuple(module) # e.g. ("diffusers", "UNet2DConditionModel") + else: + # if module is None, e.g. self.register_components(unet=None) during __init__ + # we do not update the spec, + # but we still need to update the modular_model_index.json config based on component spec + library, class_name = None, None + + # extract the loading spec from the updated component spec that'll be used as part of modular_model_index.json config + # e.g. 
{"repo": "stabilityai/stable-diffusion-2-1", + # "type_hint": ("diffusers", "UNet2DConditionModel"), + # "subfolder": "unet", + # "variant": None, + # "revision": None} + component_spec_dict = self._component_spec_to_dict(component_spec) + + register_dict = {name: (library, class_name, component_spec_dict)} + + # set the component as attribute + # if it is not set yet, just set it and skip the process to check and warn below + if not is_registered: + if is_from_pretrained: + self.register_to_config(**register_dict) + setattr(self, name, module) + if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None: + self._components_manager.add(name, module, self._collection) + continue + + current_module = getattr(self, name, None) + # skip if the component is already registered with the same object + if current_module is module: + logger.info( + f"ModularPipeline.register_components: {name} is already registered with same object, skipping" + ) + continue + + # warn if unregister + if current_module is not None and module is None: + logger.info( + f"ModularPipeline.register_components: setting '{name}' to None " + f"(was {current_module.__class__.__name__})" + ) + # same type, new instance → replace but send debug log + elif ( + current_module is not None + and module is not None + and isinstance(module, current_module.__class__) + and current_module != module + ): + logger.debug( + f"ModularPipeline.register_components: replacing existing '{name}' " + f"(same type {type(current_module).__name__}, new instance)" + ) + + # update modular_model_index.json config + if is_from_pretrained: + self.register_to_config(**register_dict) + # finally set models + setattr(self, name, module) + # add to component manager if one is attached + if module is not None and module._diffusers_load_id != "null" and self._components_manager is not None: + self._components_manager.add(name, module, self._collection) + + @property def device(self) -> 
torch.device: r""" @@ -1885,7 +2030,10 @@ class ModularLoader(ConfigMixin, PushToHubMixin): # return only components we've actually set as attributes on self return {name: getattr(self, name) for name in self._component_specs.keys() if hasattr(self, name)} - def update(self, **kwargs): + def get_component_spec(self, name: str) -> ComponentSpec: + return deepcopy(self._component_specs[name]) + + def update_components(self, **kwargs): """ Update components and configuration values after the loader has been instantiated. @@ -1938,7 +2086,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): for name, component in passed_components.items(): if not hasattr(component, "_diffusers_load_id"): - raise ValueError("`ModularLoader` only supports components created from `ComponentSpec`.") + raise ValueError("`ModularPipeline` only supports components created from `ComponentSpec`.") # YiYi TODO: remove this if we remove support for non config mixin components in `create()` method if component._diffusers_load_id == "null" and not isinstance(component, ConfigMixin): @@ -1953,7 +2101,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): component, current_component_spec.type_hint ): logger.warning( - f"ModularLoader.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" + f"ModularPipeline.update: adding {name} with new type: {component.__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" ) # update _component_specs based on the new component new_component_spec = ComponentSpec.from_component(name, component) @@ -1975,7 +2123,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): created_components[name], current_component_spec.type_hint ): logger.warning( - f"ModularLoader.update: adding {name} with new type: {created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" + f"ModularPipeline.update: adding {name} with new type: 
{created_components[name].__class__.__name__}, previous type: {current_component_spec.type_hint.__name__}" ) # update _component_specs based on the user passed component_spec self._component_specs[name] = component_spec @@ -1989,7 +2137,7 @@ class ModularLoader(ConfigMixin, PushToHubMixin): self.register_to_config(**config_to_register) # YiYi TODO: support map for additional from_pretrained kwargs - def load(self, names: Union[List[str], str], **kwargs): + def load_components(self, names: Union[List[str], str], **kwargs): """ Load selected components from specs. @@ -2246,58 +2394,16 @@ class ModularLoader(ConfigMixin, PushToHubMixin): ) return self - # YiYi TODO: - # 1. should support save some components too! currently only modular_model_index.json is saved - # 2. maybe order the json file to make it more readable: configs first, then components - def save_pretrained( - self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, spec_only: bool = True, **kwargs - ): - component_names = list(self._component_specs.keys()) - config_names = list(self._config_specs.keys()) - self.register_to_config(_components_names=component_names, _configs_names=config_names) - self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - config = dict(self.config) - config.pop("_components_names", None) - config.pop("_configs_names", None) - self._internal_dict = FrozenDict(config) - - @classmethod - @validate_hf_hub_args - def from_pretrained( - cls, - pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], - spec_only: bool = True, - components_manager: Optional[ComponentsManager] = None, - collection: Optional[str] = None, - **kwargs, - ): - config_dict = cls.load_config(pretrained_model_name_or_path, **kwargs) - expected_component = set(config_dict.pop("_components_names")) - expected_config = set(config_dict.pop("_configs_names")) - - component_specs = [] - config_specs = [] - for name, value in config_dict.items(): - if name in 
expected_component and isinstance(value, (tuple, list)) and len(value) == 3: - library, class_name, component_spec_dict = value - # only pick up pretrained components from the repo - if component_spec_dict.get("repo", None) is not None: - component_spec = cls._dict_to_component_spec(name, component_spec_dict) - component_specs.append(component_spec) - - elif name in expected_config: - config_specs.append(ConfigSpec(name=name, default=value)) - - return cls(component_specs + config_specs, components_manager=components_manager, collection=collection) @staticmethod def _component_spec_to_dict(component_spec: ComponentSpec) -> Any: """ Convert a ComponentSpec into a JSON‐serializable dict for saving in `modular_model_index.json`. + If the default_creation_method is not from_pretrained, return None. This dict contains: - "type_hint": Tuple[str, str] - The fully‐qualified module path and class name of the component. + Library name and class name of the component. (e.g. ("diffusers", "UNet2DConditionModel")) - All loading fields defined by `component_spec.loading_fields()`, typically: - "repo": Optional[str] The model repository (e.g., "stabilityai/stable-diffusion-xl"). @@ -2317,23 +2423,36 @@ class ModularLoader(ConfigMixin, PushToHubMixin): Dict[str, Any]: A mapping suitable for JSON serialization. Example: - >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec >>> from diffusers.models.unet - import UNet2DConditionModel >>> spec = ComponentSpec( ... name="unet", ... type_hint=UNet2DConditionModel, - ... config=None, ... repo="path/to/repo", ... subfolder="subfolder", ... variant=None, ... revision=None, - ... default_creation_method="from_pretrained", ... 
) >>> ModularLoader._component_spec_to_dict(spec) { - "type_hint": ("diffusers.models.unet", "UNet2DConditionModel"), "repo": "path/to/repo", "subfolder": - "subfolder", "variant": None, "revision": None, + >>> from diffusers.pipelines.modular_pipeline_utils import ComponentSpec + >>> from diffusers import UNet2DConditionModel + >>> spec = ComponentSpec( + ... name="unet", + ... type_hint=UNet2DConditionModel, + ... config=None, + ... repo="path/to/repo", + ... subfolder="subfolder", + ... variant=None, + ... revision=None, + ... default_creation_method="from_pretrained", + ... ) + >>> ModularPipeline._component_spec_to_dict(spec) + { + "type_hint": ("diffusers", "UNet2DConditionModel"), + "repo": "path/to/repo", + "subfolder": "subfolder", + "variant": None, + "revision": None, } """ + if component_spec.default_creation_method != "from_pretrained": + return None + if component_spec.type_hint is not None: lib_name, cls_name = _fetch_class_library_tuple(component_spec.type_hint) else: lib_name = None cls_name = None - if component_spec.default_creation_method == "from_pretrained": - load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} - else: - load_spec_dict = {} + load_spec_dict = {k: getattr(component_spec, k) for k in component_spec.loading_fields()} return { "type_hint": (lib_name, cls_name), **load_spec_dict, @@ -2345,7 +2464,51 @@ class ModularLoader(ConfigMixin, PushToHubMixin): spec_dict: Dict[str, Any], ) -> ComponentSpec: """ - Reconstruct a ComponentSpec from a dict. + Reconstruct a ComponentSpec from a loading specdict. + + This method converts a dictionary representation back into a ComponentSpec object. + The dict should contain: + - "type_hint": Tuple[str, str] + Library name and class name of the component. (e.g. 
("diffusers", "UNet2DConditionModel")) + - All loading fields defined by `component_spec.loading_fields()`, typically: + - "repo": Optional[str] + The model repository (e.g., "stabilityai/stable-diffusion-xl"). + - "subfolder": Optional[str] + A subfolder within the repo where this component lives. + - "variant": Optional[str] + An optional variant identifier for the model. + - "revision": Optional[str] + A specific git revision (commit hash, tag, or branch). + - ... any other loading fields defined on the spec. + + Args: + name (str): + The name of the component. + spec_dict (Dict[str, Any]): + A dictionary containing the component specification data. + + Returns: + ComponentSpec: A reconstructed ComponentSpec object. + + Example: + >>> spec_dict = { + ... "type_hint": ("diffusers", "UNet2DConditionModel"), + ... "repo": "stabilityai/stable-diffusion-xl", + ... "subfolder": "unet", + ... "variant": None, + ... "revision": None, + ... } + >>> ModularPipeline._dict_to_component_spec("unet", spec_dict) + ComponentSpec( + name="unet", + type_hint=UNet2DConditionModel, + config=None, + repo="stabilityai/stable-diffusion-xl", + subfolder="unet", + variant=None, + revision=None, + default_creation_method="from_pretrained" + ) """ # make a shallow copy so we can pop() safely spec_dict = spec_dict.copy() @@ -2361,133 +2524,4 @@ class ModularLoader(ConfigMixin, PushToHubMixin): name=name, type_hint=type_hint, **spec_dict, - ) - - -class ModularPipeline: - """ - Base class for all Modular pipelines. 
- - Args: - blocks: ModularPipelineBlocks, the blocks to be used in the pipeline - loader: ModularLoader, the loader to be used in the pipeline - """ - - def __init__(self, blocks: ModularPipelineBlocks, loader: ModularLoader): - self.blocks = blocks - self.loader = loader - - def __repr__(self): - return f"ModularPipeline(\n blocks={repr(self.blocks)},\n loader={repr(self.loader)}\n)" - - @property - def default_call_parameters(self) -> Dict[str, Any]: - params = {} - for input_param in self.blocks.inputs: - params[input_param.name] = input_param.default - return params - - def __call__(self, state: PipelineState = None, output: Union[str, List[str]] = None, **kwargs): - """ - Run one or more blocks in sequence, optionally you can pass a previous pipeline state. - """ - if state is None: - state = PipelineState() - - # Make a copy of the input kwargs - passed_kwargs = kwargs.copy() - - # Add inputs to state, using defaults if not provided in the kwargs or the state - # if same input already in the state, will override it if provided in the kwargs - - intermediate_inputs = [inp.name for inp in self.blocks.intermediate_inputs] - for expected_input_param in self.blocks.inputs: - name = expected_input_param.name - default = expected_input_param.default - kwargs_type = expected_input_param.kwargs_type - if name in passed_kwargs: - if name not in intermediate_inputs: - state.add_input(name, passed_kwargs.pop(name), kwargs_type) - else: - state.add_input(name, passed_kwargs[name], kwargs_type) - elif name not in state.inputs: - state.add_input(name, default, kwargs_type) - - for expected_intermediate_param in self.blocks.intermediate_inputs: - name = expected_intermediate_param.name - kwargs_type = expected_intermediate_param.kwargs_type - if name in passed_kwargs: - state.add_intermediate(name, passed_kwargs.pop(name), kwargs_type) - - # Warn about unexpected inputs - if len(passed_kwargs) > 0: - warnings.warn(f"Unexpected input '{passed_kwargs.keys()}' provided. 
This input will be ignored.") - # Run the pipeline - with torch.no_grad(): - try: - pipeline, state = self.blocks(self.loader, state) - except Exception: - error_msg = f"Error in block: ({self.blocks.__class__.__name__}):\n" - logger.error(error_msg) - raise - - if output is None: - return state - - elif isinstance(output, str): - return state.get_intermediate(output) - - elif isinstance(output, (list, tuple)): - return state.get_intermediates(output) - else: - raise ValueError(f"Output '{output}' is not a valid output type") - - def load_default_components(self, **kwargs): - names = [ - name - for name in self.loader._component_specs.keys() - if self.loader._component_specs[name].default_creation_method == "from_pretrained" - ] - self.loader.load(names=names, **kwargs) - - def load_components(self, names: Union[List[str], str], **kwargs): - self.loader.load(names=names, **kwargs) - - def update_components(self, **kwargs): - self.loader.update(**kwargs) - - @classmethod - @validate_hf_hub_args - def from_pretrained( - cls, - pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], - trust_remote_code: Optional[bool] = None, - components_manager: Optional[ComponentsManager] = None, - collection: Optional[str] = None, - **kwargs, - ): - blocks = ModularPipelineBlocks.from_pretrained( - pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **kwargs - ) - pipeline = blocks.init_pipeline( - pretrained_model_name_or_path, components_manager=components_manager, collection=collection, **kwargs - ) - return pipeline - - def save_pretrained( - self, save_directory: Optional[Union[str, os.PathLike]] = None, push_to_hub: bool = False, **kwargs - ): - self.blocks.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - self.loader.save_pretrained(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs) - - @property - def doc(self): - return self.blocks.doc - - def to(self, *args, **kwargs): - self.loader.to(*args, 
**kwargs) - return self - - @property - def components(self): - return self.loader.components + ) \ No newline at end of file diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index 37696f5dfa..90f4586753 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -191,7 +191,7 @@ class ComponentSpec: # YiYi TODO: I think we should only support ConfigMixin for this method (after we make guider and image_processors config mixin) # otherwise we cannot do spec -> spec.create() -> component -> ComponentSpec.from_component(component) # the config info is lost in the process - # remove error check in from_component spec and ModularLoader.update() if we remove support for non configmixin in `create()` method + # remove error check in from_component spec and ModularPipeline.update_components() if we remove support for non configmixin in `create()` method def create(self, config: Optional[Union[FrozenDict, Dict[str, Any]]] = None, **kwargs) -> Any: """Create component using from_config with config.""" diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py index 95461cfc23..59ec46dc6d 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/__init__.py @@ -22,7 +22,7 @@ except OptionalDependencyNotAvailable: _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) else: _import_structure["encoders"] = ["StableDiffusionXLTextEncoderStep"] - _import_structure["modular_blocks_presets"] = [ + _import_structure["modular_blocks"] = [ "ALL_BLOCKS", "AUTO_BLOCKS", "CONTROLNET_BLOCKS", @@ -36,7 +36,7 @@ else: "StableDiffusionXLAutoIPAdapterStep", "StableDiffusionXLAutoVaeEncoderStep", ] - _import_structure["modular_loader"] = 
["StableDiffusionXLModularLoader"] + _import_structure["modular_pipeline"] = ["StableDiffusionXLModularPipeline"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: @@ -48,7 +48,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: from .encoders import ( StableDiffusionXLTextEncoderStep, ) - from .modular_blocks_presets import ( + from .modular_blocks import ( ALL_BLOCKS, AUTO_BLOCKS, CONTROLNET_BLOCKS, @@ -62,7 +62,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: StableDiffusionXLAutoIPAdapterStep, StableDiffusionXLAutoVaeEncoderStep, ) - from .modular_loader import StableDiffusionXLModularLoader + from .modular_pipeline import StableDiffusionXLModularPipeline else: import sys diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py index 04da975aec..b064a74cbf 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py @@ -30,7 +30,7 @@ from ..modular_pipeline import ( PipelineState, ) from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam -from .modular_loader import StableDiffusionXLModularLoader +from .modular_pipeline import StableDiffusionXLModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -338,7 +338,7 @@ class StableDiffusionXLInputStep(PipelineBlock): ) @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) self.check_inputs(components, block_state) @@ -388,7 +388,7 @@ class StableDiffusionXLInputStep(PipelineBlock): [negative_ip_adapter_embed] * block_state.num_images_per_prompt, dim=0 ) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return 
components, state @@ -491,7 +491,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): return timesteps, num_inference_steps @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device @@ -537,7 +537,7 @@ class StableDiffusionXLImg2ImgSetTimestepsStep(PipelineBlock): ) block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps] - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -576,7 +576,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock): ] @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device @@ -606,7 +606,7 @@ class StableDiffusionXLSetTimestepsStep(PipelineBlock): ) block_state.timesteps = block_state.timesteps[: block_state.num_inference_steps] - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -851,7 +851,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): return mask, masked_image_latents @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype @@ -900,7 +900,7 @@ class StableDiffusionXLInpaintPrepareLatentsStep(PipelineBlock): 
block_state.generator, ) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -961,7 +961,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): ] @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype @@ -981,7 +981,7 @@ class StableDiffusionXLImg2ImgPrepareLatentsStep(PipelineBlock): block_state.add_noise, ) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -1066,7 +1066,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): return latents @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) if block_state.dtype is None: @@ -1091,7 +1091,7 @@ class StableDiffusionXLPrepareLatentsStep(PipelineBlock): block_state.latents, ) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -1249,7 +1249,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): return emb @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device @@ -1304,7 +1304,7 @@ class StableDiffusionXLImg2ImgPrepareAdditionalConditioningStep(PipelineBlock): block_state.guidance_scale_tensor, 
embedding_dim=components.unet.config.time_cond_proj_dim ).to(device=block_state.device, dtype=block_state.latents.dtype) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -1420,7 +1420,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): return emb @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.device = components._execution_device @@ -1475,7 +1475,7 @@ class StableDiffusionXLPrepareAdditionalConditioningStep(PipelineBlock): block_state.guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim ).to(device=block_state.device, dtype=block_state.latents.dtype) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -1590,7 +1590,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock): return image @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) # (1) prepare controlnet inputs @@ -1693,7 +1693,7 @@ class StableDiffusionXLControlNetInputStep(PipelineBlock): block_state.controlnet_cond = block_state.control_image block_state.conditioning_scale = block_state.controlnet_conditioning_scale - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -1824,7 +1824,7 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): return image @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: 
StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) controlnet = unwrap_module(components.controlnet) @@ -1904,6 +1904,6 @@ class StableDiffusionXLControlNetUnionInputStep(PipelineBlock): block_state.controlnet_cond = block_state.control_image block_state.conditioning_scale = block_state.controlnet_conditioning_scale - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py index 92b84b8595..878e991dbf 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/decoders.py @@ -152,7 +152,7 @@ class StableDiffusionXLDecodeStep(PipelineBlock): block_state.images, output_type=block_state.output_type ) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -212,6 +212,6 @@ class StableDiffusionXLInpaintOverlayMaskStep(PipelineBlock): for i in block_state.images ] - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py index fd61c235c2..7fe4a472ee 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/denoise.py @@ -29,7 +29,7 @@ from ..modular_pipeline import ( PipelineState, ) from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam -from .modular_loader import StableDiffusionXLModularLoader +from .modular_pipeline import StableDiffusionXLModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -66,7 +66,7 @@ class 
StableDiffusionXLLoopBeforeDenoiser(PipelineBlock): ] @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) return components, block_state @@ -131,7 +131,7 @@ class StableDiffusionXLInpaintLoopBeforeDenoiser(PipelineBlock): ) @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): self.check_inputs(components, block_state) block_state.scaled_latents = components.scheduler.scale_model_input(block_state.latents, t) @@ -202,7 +202,7 @@ class StableDiffusionXLLoopDenoiser(PipelineBlock): @torch.no_grad() def __call__( - self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int + self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int ) -> PipelineState: # Map the keys we'll see on each `guider_state_batch` (e.g. guider_state_batch.prompt_embeds) # to the corresponding (cond, uncond) fields on block_state. (e.g. 
block_state.prompt_embeds, block_state.negative_prompt_embeds) @@ -347,7 +347,7 @@ class StableDiffusionXLControlNetLoopDenoiser(PipelineBlock): return extra_kwargs @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): extra_controlnet_kwargs = self.prepare_extra_kwargs( components.controlnet.forward, **block_state.controlnet_kwargs ) @@ -494,7 +494,7 @@ class StableDiffusionXLLoopAfterDenoiser(PipelineBlock): return extra_kwargs @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline block_state.extra_step_kwargs = self.prepare_extra_kwargs( components.scheduler.step, generator=block_state.generator, eta=block_state.eta @@ -595,7 +595,7 @@ class StableDiffusionXLInpaintLoopAfterDenoiser(PipelineBlock): raise ValueError(f"noise is required for this step {self.__class__.__name__}") @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, block_state: BlockState, i: int, t: int): + def __call__(self, components: StableDiffusionXLModularPipeline, block_state: BlockState, i: int, t: int): self.check_inputs(components, block_state) # Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline @@ -677,7 +677,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): ] @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.disable_guidance = True if components.unet.config.time_cond_proj_dim is not None else False @@ -698,7 +698,7 @@ class StableDiffusionXLDenoiseLoopWrapper(LoopSequentialPipelineBlocks): ): progress_bar.update() - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py index b4526537d7..bd0e962140 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/encoders.py @@ -37,7 +37,7 @@ from ...utils import ( ) from ..modular_pipeline import PipelineBlock, PipelineState from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam -from .modular_loader import StableDiffusionXLModularLoader +from .modular_pipeline import StableDiffusionXLModularPipeline logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -65,8 +65,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock): return ( "IP Adapter step that prepares ip adapter image embeddings.\n" "Note that this step only prepares the embeddings - in order for it to work correctly, " - "you need to load ip adapter weights into unet via ModularPipeline.loader.\n" - "e.g. 
pipeline.loader.load_ip_adapter() and pipeline.loader.set_ip_adapter_scale().\n" + "you need to load ip adapter weights into unet via ModularPipeline.load_ip_adapter() and pipeline.set_ip_adapter_scale().\n" "See [ModularIPAdapterMixin](https://huggingface.co/docs/diffusers/api/loaders/ip_adapter#diffusers.loaders.ModularIPAdapterMixin)" " for more details" ) @@ -191,7 +190,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock): return ip_adapter_image_embeds @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1 @@ -212,7 +211,7 @@ class StableDiffusionXLIPAdapterStep(PipelineBlock): block_state.negative_ip_adapter_embeds.append(negative_image_embeds) block_state.ip_adapter_embeds[i] = image_embeds - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -537,7 +536,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: # Get inputs and intermediates block_state = self.get_block_state(state) self.check_inputs(block_state) @@ -573,7 +572,7 @@ class StableDiffusionXLTextEncoderStep(PipelineBlock): clip_skip=block_state.clip_skip, ) # Add outputs - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -663,7 +662,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): return image_latents @torch.no_grad() - def __call__(self, 
components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} block_state.device = components._execution_device @@ -687,7 +686,7 @@ class StableDiffusionXLVaeEncoderStep(PipelineBlock): components, image=block_state.image, generator=block_state.generator ) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state @@ -841,7 +840,7 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): return mask, masked_image_latents @torch.no_grad() - def __call__(self, components: StableDiffusionXLModularLoader, state: PipelineState) -> PipelineState: + def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype @@ -898,6 +897,6 @@ class StableDiffusionXLInpaintVaeEncoderStep(PipelineBlock): block_state.generator, ) - self.add_block_state(state, block_state) + self.set_block_state(state, block_state) return components, state diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py similarity index 100% rename from src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks_presets.py rename to src/diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py diff --git a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py similarity index 99% rename from src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py rename to 
src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py index 82c4d6de0f..90850ea536 100644 --- a/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_loader.py +++ b/src/diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py @@ -23,7 +23,7 @@ from ...loaders import ModularIPAdapterMixin, StableDiffusionXLLoraLoaderMixin, from ...pipelines.pipeline_utils import StableDiffusionMixin from ...pipelines.stable_diffusion_xl.pipeline_output import StableDiffusionXLPipelineOutput from ...utils import logging -from ..modular_pipeline import ModularLoader +from ..modular_pipeline import ModularPipeline from ..modular_pipeline_utils import InputParam, OutputParam @@ -32,13 +32,13 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name # YiYi TODO: move to a different file? stable_diffusion_xl_module should have its own folder? # YiYi Notes: model specific components: -## (1) it should inherit from ModularLoader +## (1) it should inherit from ModularPipeline ## (2) acts like a container that holds components and configs ## (3) define default config (related to components), e.g. default_sample_size, vae_scale_factor, num_channels_unet, num_channels_latents ## (4) inherit from model-specific loader class (e.g. StableDiffusionXLLoraLoaderMixin) ## (5) how to use together with Components_manager? 
-class StableDiffusionXLModularLoader( - ModularLoader, +class StableDiffusionXLModularPipeline( + ModularPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, diff --git a/src/diffusers/pipelines/pipeline_loading_utils.py b/src/diffusers/pipelines/pipeline_loading_utils.py index 7e48cca093..b5ac6cc301 100644 --- a/src/diffusers/pipelines/pipeline_loading_utils.py +++ b/src/diffusers/pipelines/pipeline_loading_utils.py @@ -468,7 +468,7 @@ def _get_pipeline_class( revision=revision, ) - if class_obj.__name__ != "DiffusionPipeline": + if class_obj.__name__ != "DiffusionPipeline" and class_obj.__name__ != "ModularPipeline": return class_obj diffusers_module = importlib.import_module(class_obj.__module__.split(".")[0]) diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 496039a436..b192b58531 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -1349,21 +1349,6 @@ class ComponentSpec(metaclass=DummyObject): requires_backends(cls, ["torch"]) -class ModularLoader(metaclass=DummyObject): - _backends = ["torch"] - - def __init__(self, *args, **kwargs): - requires_backends(self, ["torch"]) - - @classmethod - def from_config(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - @classmethod - def from_pretrained(cls, *args, **kwargs): - requires_backends(cls, ["torch"]) - - class ModularPipeline(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index a9daf50a7a..62f1735695 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -17,7 +17,7 @@ class StableDiffusionXLAutoBlocks(metaclass=DummyObject): requires_backends(cls, ["torch", "transformers"]) -class StableDiffusionXLModularLoader(metaclass=DummyObject): +class 
StableDiffusionXLModularPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] def __init__(self, *args, **kwargs):