1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix style/quality for code

This commit is contained in:
Nan Liu
2022-12-13 18:38:36 -06:00
parent 8ee23f22d1
commit 71f2349763
9 changed files with 40 additions and 19 deletions

View File

@@ -481,13 +481,13 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
if '|' in prompt:
prompt = [x.strip() for x in prompt.split('|')]
if "|" in prompt:
prompt = [x.strip() for x in prompt.split("|")]
print(f"composing {prompt}...")
if not weights:
# specify weights for prompts (excluding the unconditional score)
print('using equal positive weights (conjunction) for all prompts...')
print("using equal positive weights (conjunction) for all prompts...")
weights = torch.tensor([guidance_scale] * len(prompt), device=self.device).reshape(-1, 1, 1, 1)
else:
# set prompt weight for each
@@ -546,7 +546,9 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(dim=0, keepdims=True)
noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(
dim=0, keepdims=True
)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
@@ -570,4 +572,4 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline):
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

View File

@@ -336,7 +336,10 @@ class TextualInversionDataset(Dataset):
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = (
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)

View File

@@ -306,7 +306,10 @@ class TextualInversionDataset(Dataset):
if self.center_crop:
crop = min(img.shape[0], img.shape[1])
h, w, = (
(
h,
w,
) = (
img.shape[0],
img.shape[1],
)

View File

@@ -37,6 +37,7 @@ def rename_key(key):
# PyTorch => Flax #
#####################
# Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
# and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict):

View File

@@ -288,8 +288,10 @@ class AttentionBlock(nn.Module):
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers"
),
name="xformers",
)
elif not torch.cuda.is_available():
@@ -450,8 +452,10 @@ class BasicTransformerBlock(nn.Module):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers"
),
name="xformers",
)
elif not torch.cuda.is_available():

View File

@@ -189,9 +189,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep.",
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if not self.is_scale_input_called:

View File

@@ -198,9 +198,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
or isinstance(timestep, torch.LongTensor)
):
raise ValueError(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep.",
(
"Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
" `EulerDiscreteScheduler.step()` is not supported. Make sure to pass"
" one of the `scheduler.timesteps` as a timestep."
),
)
if not self.is_scale_input_called:

View File

@@ -535,8 +535,10 @@ class SchedulerCommonTest(unittest.TestCase):
)
self.assertTrue(
hasattr(scheduler, "scale_model_input"),
f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
" timestep)`",
(
f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
" timestep)`"
),
)
self.assertTrue(
hasattr(scheduler, "step"),

View File

@@ -96,6 +96,7 @@ def ignore_underscore(key):
def sort_objects(objects, key=None):
"Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str."
# If no key is provided, we use a noop.
def noop(x):
return x
@@ -117,6 +118,7 @@ def sort_objects_in_import(import_statement):
"""
Return the same `import_statement` but with objects properly sorted.
"""
# This inner function sort imports between [ ].
def _replace(match):
imports = match.groups()[0]