1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00
This commit is contained in:
Dhruv Nair
2024-01-22 09:13:08 +00:00
parent 0746cf957a
commit dbfb8f1ea9
4 changed files with 18 additions and 15 deletions

View File

@@ -43,14 +43,13 @@ def build_sub_model_components(
checkpoint,
local_files_only=False,
load_safety_checker=False,
**kwargs,
model_type=None,
image_size=None,
**kwargs
):
if component_name in pipeline_components:
return {}
model_type = kwargs.pop("model_type", None)
image_size = kwargs.pop("image_size", None)
if component_name == "unet":
num_in_channels = kwargs.pop("num_in_channels", None)
unet_components = create_diffusers_unet_model_from_ldm(
@@ -112,10 +111,9 @@ def build_sub_model_components(
def set_additional_components(
pipeline_class_name,
original_config,
**kwargs,
model_type=None,
):
components = {}
model_type = kwargs.get("model_type", None)
if pipeline_class_name in REFINER_PIPELINES:
model_type = infer_model_type(original_config, model_type=model_type)
is_refiner = model_type == "SDXL-Refiner"
@@ -235,6 +233,9 @@ class FromSingleFileMixin:
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
model_type = kwargs.pop("model_type", None)
image_size = kwargs.pop("image_size", None)
init_kwargs = {}
for name in expected_modules:
if name in passed_class_obj:
@@ -247,13 +248,15 @@ class FromSingleFileMixin:
original_config,
checkpoint,
pretrained_model_link_or_path,
model_type=model_type,
image_size=image_size,
**kwargs,
)
if not components:
continue
init_kwargs.update(components)
additional_components = set_additional_components(class_name, original_config, **kwargs)
additional_components = set_additional_components(class_name, original_config, model_type=model_type)
if additional_components:
init_kwargs.update(additional_components)

View File

@@ -191,14 +191,14 @@ SD_2_TEXT_ENCODER_KEYS_TO_IGNORE = [
]
VALID_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
VALID_HF_URL_PREFIXES = ["https://huggingface.co/", "huggingface.co/", "hf.co/", "https://hf.co/"]
def _extract_repo_id_and_weights_name(pretrained_model_name_or_path):
pattern = r"([^/]+)/([^/]+)/(?:blob/main/)?(.+)"
weights_name = None
repo_id = (None,)
for prefix in VALID_URL_PREFIXES:
for prefix in VALID_HF_URL_PREFIXES:
pretrained_model_name_or_path = pretrained_model_name_or_path.replace(prefix, "")
match = re.match(pattern, pretrained_model_name_or_path)
if not match:

View File

@@ -533,7 +533,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
Args:
sample (`torch.FloatTensor`):
The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
The noisy input tensor with the following shape `(batch, channel, num_frames, height, width)`.
timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
encoder_hidden_states (`torch.FloatTensor`):
The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.

View File

@@ -753,14 +753,14 @@ class StableDiffusionInpaintPipelineSlowTests(unittest.TestCase):
def test_download_ckpt_diff_format_is_same(self):
ckpt_path = "https://huggingface.co/runwayml/stable-diffusion-inpainting/blob/main/sd-v1-5-inpainting.ckpt"
pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.unet.set_attn_processor(AttnProcessor())
pipe.to("cuda")
sf_pipe = StableDiffusionInpaintPipeline.from_single_file(ckpt_path)
sf_pipe.scheduler = DDIMScheduler.from_config(sf_pipe.scheduler.config)
sf_pipe.unet.set_attn_processor(AttnProcessor())
sf_pipe.to("cuda")
inputs = self.get_inputs(torch_device)
inputs["num_inference_steps"] = 5
image_ckpt = pipe(**inputs).images[0]
image_ckpt = sf_pipe(**inputs).images[0]
pipe = StableDiffusionInpaintPipeline.from_pretrained("runwayml/stable-diffusion-inpainting")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)