From 71f23497639fe69de00d93cf91edc31b08dcd7a4 Mon Sep 17 00:00:00 2001 From: Nan Liu <45443761+nanlliu@users.noreply.github.com> Date: Tue, 13 Dec 2022 18:38:36 -0600 Subject: [PATCH] fix style/quality for code --- examples/community/composable_stable_diffusion.py | 12 +++++++----- examples/textual_inversion/textual_inversion.py | 5 ++++- examples/textual_inversion/textual_inversion_flax.py | 5 ++++- src/diffusers/modeling_flax_pytorch_utils.py | 1 + src/diffusers/models/attention.py | 12 ++++++++---- .../scheduling_euler_ancestral_discrete.py | 8 +++++--- .../schedulers/scheduling_euler_discrete.py | 8 +++++--- tests/test_scheduler.py | 6 ++++-- utils/custom_init_isort.py | 2 ++ 9 files changed, 40 insertions(+), 19 deletions(-) diff --git a/examples/community/composable_stable_diffusion.py b/examples/community/composable_stable_diffusion.py index 8bf2953967..4775f2e76a 100644 --- a/examples/community/composable_stable_diffusion.py +++ b/examples/community/composable_stable_diffusion.py @@ -481,13 +481,13 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline): # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - if '|' in prompt: - prompt = [x.strip() for x in prompt.split('|')] + if "|" in prompt: + prompt = [x.strip() for x in prompt.split("|")] print(f"composing {prompt}...") if not weights: # specify weights for prompts (excluding the unconditional score) - print('using equal positive weights (conjunction) for all prompts...') + print("using equal positive weights (conjunction) for all prompts...") weights = torch.tensor([guidance_scale] * len(prompt), device=self.device).reshape(-1, 1, 1, 1) else: # set prompt weight for each @@ -546,7 +546,9 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline): # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum(dim=0, keepdims=True) + noise_pred = noise_pred_uncond + (weights * (noise_pred_text - noise_pred_uncond)).sum( + dim=0, keepdims=True + ) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample @@ -570,4 +572,4 @@ class ComposableStableDiffusionPipeline(DiffusionPipeline): if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) \ No newline at end of file + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index 380a8b4738..1f5a3c2af7 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -336,7 +336,10 @@ class TextualInversionDataset(Dataset): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - h, w, = ( + ( + h, + w, + ) = ( img.shape[0], img.shape[1], ) diff --git a/examples/textual_inversion/textual_inversion_flax.py b/examples/textual_inversion/textual_inversion_flax.py index a9fa9e3693..390602a74c 100644 --- a/examples/textual_inversion/textual_inversion_flax.py +++ b/examples/textual_inversion/textual_inversion_flax.py @@ -306,7 +306,10 @@ class TextualInversionDataset(Dataset): if self.center_crop: crop = min(img.shape[0], img.shape[1]) - h, w, = ( + ( + h, + w, + ) = ( img.shape[0], img.shape[1], ) diff --git a/src/diffusers/modeling_flax_pytorch_utils.py b/src/diffusers/modeling_flax_pytorch_utils.py index 9c7a5de2ad..cfeecf0f10 100644 --- a/src/diffusers/modeling_flax_pytorch_utils.py +++ b/src/diffusers/modeling_flax_pytorch_utils.py @@ -37,6 +37,7 @@ def rename_key(key): # PyTorch => Flax # ##################### + # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict): diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index 99dd5d8d51..eddeb151c7 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -288,8 +288,10 @@ class AttentionBlock(nn.Module): def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): if not is_xformers_available(): raise ModuleNotFoundError( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers", + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), name="xformers", ) elif not torch.cuda.is_available(): @@ -450,8 +452,10 @@ class BasicTransformerBlock(nn.Module): if not is_xformers_available(): print("Here is how to install it") raise ModuleNotFoundError( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers", + ( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers" + ), name="xformers", ) elif not torch.cuda.is_available(): diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py index f5905a3f83..bc28d81bcf 100644 --- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py @@ -189,9 +189,11 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): or isinstance(timestep, torch.LongTensor) ): raise ValueError( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep.", + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), ) if not self.is_scale_input_called: diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py index 9cb4a1eaa5..0f0a14272c 100644 --- a/src/diffusers/schedulers/scheduling_euler_discrete.py +++ b/src/diffusers/schedulers/scheduling_euler_discrete.py @@ -198,9 +198,11 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin): or isinstance(timestep, torch.LongTensor) ): raise ValueError( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep.", + ( + "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" + " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" + " one of the `scheduler.timesteps` as a timestep." + ), ) if not self.is_scale_input_called: diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index ba109bdf47..4434232ee5 100755 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -535,8 +535,10 @@ class SchedulerCommonTest(unittest.TestCase): ) self.assertTrue( hasattr(scheduler, "scale_model_input"), - f"{scheduler_class} does not implement a required class method `scale_model_input(sample," - " timestep)`", + ( + f"{scheduler_class} does not implement a required class method `scale_model_input(sample," + " timestep)`" + ), ) self.assertTrue( hasattr(scheduler, "step"), diff --git a/utils/custom_init_isort.py b/utils/custom_init_isort.py index 44165d1fce..a204a155fe 100644 --- a/utils/custom_init_isort.py +++ b/utils/custom_init_isort.py @@ -96,6 +96,7 @@ def ignore_underscore(key): def sort_objects(objects, key=None): "Sort a list of `objects` following the rules of isort. `key` optionally maps an object to a str." + # If no key is provided, we use a noop. def noop(x): return x @@ -117,6 +118,7 @@ def sort_objects_in_import(import_statement): """ Return the same `import_statement` but with objects properly sorted. """ + # This inner function sort imports between [ ]. def _replace(match): imports = match.groups()[0]