mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix style/quality for code
This commit is contained in:
@@ -481,13 +481,13 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if '|' in prompt:
|
||||
prompt = [x.strip() for x in prompt.split('|')]
|
||||
if "|" in prompt:
|
||||
prompt = [x.strip() for x in prompt.split("|")]
|
||||
print(f"composing {prompt}...")
|
||||
|
||||
if not weights:
|
||||
# specify weights for prompts (excluding the unconditional score)
|
||||
print('using equal positive weights (conjunction) for all prompts...')
|
||||
print("using equal positive weights (conjunction) for all prompts...")
|
||||
weights = torch.tensor([guidance_scale] * len(prompt), device=self.device).reshape(-1, 1, 1, 1)
|
||||
else:
|
||||
# set prompt weight for each
|
||||
@@ -546,7 +546,9 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(dim=0, keepdims=True)
|
||||
noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(
|
||||
dim=0, keepdims=True
|
||||
)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
|
||||
@@ -570,4 +572,4 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
||||
|
||||
@@ -336,7 +336,10 @@ class TextualInversionDataset(Dataset):
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = (
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
|
||||
@@ -306,7 +306,10 @@ class TextualInversionDataset(Dataset):
|
||||
|
||||
if self.center_crop:
|
||||
crop = min(img.shape[0], img.shape[1])
|
||||
h, w, = (
|
||||
(
|
||||
h,
|
||||
w,
|
||||
) = (
|
||||
img.shape[0],
|
||||
img.shape[1],
|
||||
)
|
||||
|
||||
@@ -37,6 +37,7 @@ def rename_key(key):
|
||||
# PyTorch => Flax #
|
||||
#####################
|
||||
|
||||
|
||||
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
|
||||
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
|
||||
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):
|
||||
|
||||
@@ -288,8 +288,10 @@ class AttentionBlock(nn.Module):
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
if not is_xformers_available():
|
||||
raise ModuleNotFoundError(
|
||||
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
||||
" xformers",
|
||||
(
|
||||
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
||||
" xformers"
|
||||
),
|
||||
name="xformers",
|
||||
)
|
||||
elif not torch.cuda.is_available():
|
||||
@@ -450,8 +452,10 @@ class BasicTransformerBlock(nn.Module):
|
||||
if not is_xformers_available():
|
||||
print("Here is how to install it")
|
||||
raise ModuleNotFoundError(
|
||||
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
||||
" xformers",
|
||||
(
|
||||
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
|
||||
" xformers"
|
||||
),
|
||||
name="xformers",
|
||||
)
|
||||
elif not torch.cuda.is_available():
|
||||
|
||||
@@ -189,9 +189,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
raise ValueError(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep.",
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep."
|
||||
),
|
||||
)
|
||||
|
||||
if not self.is_scale_input_called:
|
||||
|
||||
@@ -198,9 +198,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
|
||||
or isinstance(timestep, torch.LongTensor)
|
||||
):
|
||||
raise ValueError(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep.",
|
||||
(
|
||||
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
|
||||
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
|
||||
" one of the `scheduler.timesteps` as a timestep."
|
||||
),
|
||||
)
|
||||
|
||||
if not self.is_scale_input_called:
|
||||
|
||||
@@ -535,8 +535,10 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
)
|
||||
self.assertTrue(
|
||||
hasattr(scheduler, "scale_model_input"),
|
||||
f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
|
||||
" timestep)`",
|
||||
(
|
||||
f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
|
||||
" timestep)`"
|
||||
),
|
||||
)
|
||||
self.assertTrue(
|
||||
hasattr(scheduler, "step"),
|
||||
|
||||
@@ -96,6 +96,7 @@ def ignore_underscore(key):
|
||||
|
||||
def sort_objects(objects, key=None):
|
||||
"Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
|
||||
|
||||
# If no key is provided, we use a noop.
|
||||
def noop(x):
|
||||
return x
|
||||
@@ -117,6 +118,7 @@ def sort_objects_in_import(import_statement):
|
||||
"""
|
||||
Return the same `import_statement` but with objects properly sorted.
|
||||
"""
|
||||
|
||||
# This inner function sort imports between [ ].
|
||||
def _replace(match):
|
||||
imports = match.groups()[0]
|
||||
|
||||
Reference in New Issue
Block a user