mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
solve merge conflict: manually add back the remote code change to modular_pipeline
This commit is contained in:
@@ -11,12 +11,14 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import inspect
|
||||
|
||||
|
||||
import traceback
|
||||
import warnings
|
||||
from collections import OrderedDict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional, Type
|
||||
from typing import Any, Dict, List, Tuple, Union, Optional
|
||||
from copy import deepcopy
|
||||
|
||||
|
||||
@@ -31,11 +33,10 @@ from huggingface_hub.utils import validate_hf_hub_args
|
||||
from ..configuration_utils import ConfigMixin, FrozenDict
|
||||
from ..utils import (
|
||||
is_accelerate_available,
|
||||
is_accelerate_version,
|
||||
logging,
|
||||
PushToHubMixin,
|
||||
)
|
||||
from ..pipelines.pipeline_loading_utils import _get_pipeline_class, simple_get_class_obj, _fetch_class_library_tuple
|
||||
from ..pipelines.pipeline_loading_utils import simple_get_class_obj, _fetch_class_library_tuple
|
||||
from .modular_pipeline_utils import (
|
||||
ComponentSpec,
|
||||
ConfigSpec,
|
||||
@@ -43,14 +44,12 @@ from .modular_pipeline_utils import (
|
||||
OutputParam,
|
||||
format_components,
|
||||
format_configs,
|
||||
format_input_params,
|
||||
format_inputs_short,
|
||||
format_intermediates_short,
|
||||
format_output_params,
|
||||
format_params,
|
||||
make_doc_string,
|
||||
)
|
||||
from .components_manager import ComponentsManager
|
||||
from ..utils.dynamic_modules_utils import get_class_from_dynamic_module, resolve_trust_remote_code
|
||||
|
||||
from copy import deepcopy
|
||||
if is_accelerate_available():
|
||||
@@ -245,19 +244,76 @@ class BlockState:
|
||||
|
||||
|
||||
|
||||
class ModularPipelineMixin:
|
||||
class ModularPipelineMixin(ConfigMixin):
|
||||
"""
|
||||
Mixin for all PipelineBlocks: PipelineBlock, AutoPipelineBlocks, SequentialPipelineBlocks
|
||||
"""
|
||||
|
||||
config_name = "config.json"
|
||||
|
||||
@classmethod
|
||||
def _get_signature_keys(cls, obj):
|
||||
parameters = inspect.signature(obj.__init__).parameters
|
||||
required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
|
||||
optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
|
||||
expected_modules = set(required_parameters.keys()) - {"self"}
|
||||
|
||||
return expected_modules, optional_parameters
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
trust_remote_code: Optional[bool] = None,
|
||||
**kwargs,
|
||||
):
|
||||
hub_kwargs_names = [
|
||||
"cache_dir",
|
||||
"force_download",
|
||||
"local_files_only",
|
||||
"proxies",
|
||||
"resume_download",
|
||||
"revision",
|
||||
"subfolder",
|
||||
"token",
|
||||
]
|
||||
hub_kwargs = {name: kwargs.pop(name) for name in hub_kwargs_names if name in kwargs}
|
||||
|
||||
config = cls.load_config(pretrained_model_name_or_path)
|
||||
has_remote_code = "auto_map" in config and cls.__name__ in config["auto_map"]
|
||||
trust_remote_code = resolve_trust_remote_code(
|
||||
trust_remote_code, pretrained_model_name_or_path, has_remote_code
|
||||
)
|
||||
if not (has_remote_code and trust_remote_code):
|
||||
raise ValueError("TODO")
|
||||
|
||||
class_ref = config["auto_map"][cls.__name__]
|
||||
module_file, class_name = class_ref.split(".")
|
||||
module_file = module_file + ".py"
|
||||
block_cls = get_class_from_dynamic_module(
|
||||
pretrained_model_name_or_path,
|
||||
module_file=module_file,
|
||||
class_name=class_name,
|
||||
is_modular=True,
|
||||
**hub_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
expected_kwargs, optional_kwargs = block_cls._get_signature_keys(block_cls)
|
||||
block_kwargs = {
|
||||
name: kwargs.pop(name) for name in kwargs if name in expected_kwargs or name in optional_kwargs
|
||||
}
|
||||
print(f"block_kwargs: {block_kwargs}")
|
||||
|
||||
return block_cls(**block_kwargs)
|
||||
|
||||
def setup_loader(self, modular_repo: Optional[Union[str, os.PathLike]] = None, component_manager: Optional[ComponentsManager] = None, collection: Optional[str] = None):
|
||||
"""
|
||||
create a mouldar loader, optionally accept modular_repo to load from hub.
|
||||
create a ModularLoader, optionally accept modular_repo to load from hub.
|
||||
"""
|
||||
|
||||
# Import components loader (it is model-specific class)
|
||||
loader_class_name = MODULAR_LOADER_MAPPING[self.model_name]
|
||||
loader_class_name = MODULAR_LOADER_MAPPING.get(self.model_name, ModularLoader.__name__)
|
||||
diffusers_module = importlib.import_module("diffusers")
|
||||
loader_class = getattr(diffusers_module, loader_class_name)
|
||||
|
||||
@@ -365,7 +421,8 @@ class PipelineBlock(ModularPipelineMixin):
|
||||
@property
|
||||
def description(self) -> str:
|
||||
"""Description of the block. Must be implemented by subclasses."""
|
||||
raise NotImplementedError("description method must be implemented in subclasses")
|
||||
# raise NotImplementedError("description method must be implemented in subclasses")
|
||||
return "TODO: add a description"
|
||||
|
||||
@property
|
||||
def expected_components(self) -> List[ComponentSpec]:
|
||||
|
||||
Reference in New Issue
Block a user