
Add circular padding for artifact-free StableDiffusionPanoramaPipeline (#4025)

* Add circular padding option

* Fix style with black

* Fix corner case with small image size

* Add circular padding test cases

* Fix docstring

* Improve docstring for circular padding, remove slow test case

* Update docs for circular padding argument

* Add images comparison for circular padding
Author: Evgenii Kashin (committed by GitHub)
Date: 2023-07-12 16:19:46 +01:00
Parent: 4b50ecceb0
Commit: af48bf2008
3 changed files with 117 additions and 10 deletions


@@ -60,6 +60,25 @@ and increase the VRAM usage.
</Tip>
<Tip>
Circular padding is applied to ensure there are no stitching artifacts when working with
panoramas that need to seamlessly transition from the rightmost part to the leftmost part.
By enabling circular padding (set `circular_padding=True`), the operation applies additional
crops after the rightmost point of the image, allowing the model to "see" the transition
from the rightmost part to the leftmost part. This helps maintain visual consistency in
a 360-degree sense and creates a proper "panorama" that can be viewed using 360-degree
panorama viewers. When decoding latents in Stable Diffusion, circular padding is applied
to ensure that the decoded latents match in the RGB space.
Without circular padding, there is a stitching artifact (default):
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/indoor_%20no_circular_padding.png)
With circular padding, the right and the left parts are matching (`circular_padding=True`):
![img](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/indoor_%20circular_padding.png)
</Tip>
## StableDiffusionPanoramaPipeline
[[autodoc]] StableDiffusionPanoramaPipeline
- __call__
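
For context, a minimal usage sketch of the new flag (the checkpoint and prompt mirror the panorama docs and are illustrative):

```python
import torch
from diffusers import StableDiffusionPanoramaPipeline, DDIMScheduler

model_id = "stabilityai/stable-diffusion-2-base"  # checkpoint from the panorama docs
scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
pipe = StableDiffusionPanoramaPipeline.from_pretrained(
    model_id, scheduler=scheduler, torch_dtype=torch.float16
).to("cuda")

# circular_padding=True makes the left and right edges match, so the result
# wraps seamlessly in a 360-degree panorama viewer
image = pipe(
    "a photo of the dolomites",
    height=512,
    width=2048,
    circular_padding=True,
).images[0]
```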


@@ -373,6 +373,19 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
return image
def decode_latents_with_padding(self, latents, padding=8):
# Add padding to latents for circular inference
# padding is the number of latent columns to add on each side
# this slightly increases memory usage but removes the boundary artifacts
latents = 1 / self.vae.config.scaling_factor * latents
latents_left = latents[..., :padding]
latents_right = latents[..., -padding:]
latents = torch.cat((latents_right, latents, latents_left), axis=-1)
image = self.vae.decode(latents, return_dict=False)[0]
padding_pix = self.vae_scale_factor * padding
image = image[..., padding_pix:-padding_pix]
return image
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
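
As a sanity check on the `decode_latents_with_padding` helper above, the shape arithmetic works out like this (illustrative numbers, assuming the default `padding=8` and a `vae_scale_factor` of 8):

```python
# Illustrative shape walk-through for decode_latents_with_padding.
latent_width = 2048 // 8                 # 256 latent columns for a 2048-px panorama
padding = 8
padded_width = latent_width + 2 * padding      # 272 columns fed to the VAE decoder
decoded_width = padded_width * 8               # 2176 px out of the decoder
padding_pix = 8 * padding                      # 64 px trimmed from each side
final_width = decoded_width - 2 * padding_pix  # back to 2048 px -- size unchanged
assert final_width == 2048
```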
@@ -457,13 +470,16 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
latents = latents * self.scheduler.init_noise_sigma
return latents
def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
def get_views(self, panorama_height, panorama_width, window_size=64, stride=8, circular_padding=False):
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
# if panorama's height/width < window_size, num_blocks of height/width should return 1
panorama_height /= 8
panorama_width /= 8
num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
if circular_padding:
num_blocks_width = panorama_width // stride if panorama_width > window_size else 1
else:
num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
total_num_blocks = int(num_blocks_height * num_blocks_width)
views = []
for i in range(total_num_blocks):
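
To see what the `num_blocks_width` change does, here is the view count for a typical 512x2048 panorama (a quick sketch with the function's default `window_size=64` and `stride=8`):

```python
# Illustrative view counts from get_views for a 2048-px-wide panorama.
# Dimensions are in latent units (pixels / 8), matching the function body.
panorama_width, window_size, stride = 2048 // 8, 64, 8  # 256 latent columns

# default: the last window stops at the right edge of the panorama
num_default = (panorama_width - window_size) // stride + 1  # 25 views

# circular_padding: windows keep sliding past the right edge and wrap around,
# so there is one view per stride offset
num_circular = panorama_width // stride  # 32 views
```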
@@ -496,6 +512,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
circular_padding: bool = False,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -560,6 +577,10 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
circular_padding (`bool`, *optional*, defaults to `False`):
If set to `True`, circular padding is applied to ensure there are no stitching artifacts. Circular
padding allows the model to seamlessly generate a transition from the rightmost part of the image to
the leftmost part, maintaining consistency in a 360-degree sense.
Examples:
@@ -627,10 +648,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
# 6. Define panorama grid and initialize views for synthesis.
# prepare batch grid
views = self.get_views(height, width)
views = self.get_views(height, width, circular_padding=circular_padding)
views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch)
count = torch.zeros_like(latents)
value = torch.zeros_like(latents)
@@ -655,9 +675,29 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
for j, batch_view in enumerate(views_batch):
vb_size = len(batch_view)
# get the latents corresponding to the current view coordinates
latents_for_view = torch.cat(
[latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view]
)
if circular_padding:
latents_for_view = []
for h_start, h_end, w_start, w_end in batch_view:
if w_end > latents.shape[3]:
# Add circular horizontal padding
latent_view = torch.cat(
(
latents[:, :, h_start:h_end, w_start:],
latents[:, :, h_start:h_end, : w_end - latents.shape[3]],
),
axis=-1,
)
else:
latent_view = latents[:, :, h_start:h_end, w_start:w_end]
latents_for_view.append(latent_view)
latents_for_view = torch.cat(latents_for_view)
else:
latents_for_view = torch.cat(
[
latents[:, :, h_start:h_end, w_start:w_end]
for h_start, h_end, w_start, w_end in batch_view
]
)
# rematch block's scheduler status
self.scheduler.__dict__.update(views_scheduler_status[j])
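
The wrap-around crop in the `circular_padding` branch above can be illustrated in isolation (a toy sketch, not pipeline code):

```python
import torch

# Toy illustration of the wrap-around crop: a view whose window extends past
# the right edge continues from the left edge of the latents.
latents = torch.arange(12).reshape(1, 1, 1, 12)  # latent "width" of 12
w_start, w_end = 8, 16  # window extends 4 columns past the right edge

view = torch.cat(
    (latents[..., w_start:], latents[..., : w_end - latents.shape[3]]),
    dim=-1,
)
print(view.flatten().tolist())  # [8, 9, 10, 11, 0, 1, 2, 3]
```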
@@ -698,8 +738,19 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
latents_denoised_batch.chunk(vb_size), batch_view
):
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
if circular_padding and w_end > latents.shape[3]:
# Case for circular padding
value[:, :, h_start:h_end, w_start:] += latents_view_denoised[
:, :, h_start:h_end, : latents.shape[3] - w_start
]
value[:, :, h_start:h_end, : w_end - latents.shape[3]] += latents_view_denoised[
:, :, h_start:h_end, latents.shape[3] - w_start :
]
count[:, :, h_start:h_end, w_start:] += 1
count[:, :, h_start:h_end, : w_end - latents.shape[3]] += 1
else:
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
count[:, :, h_start:h_end, w_start:w_end] += 1
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
latents = torch.where(count > 0, value / count, value)
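
The `torch.where` line is the MultiDiffusion step from Eq. 5: each latent position takes the mean of all denoised views that cover it, and the wrap-around bookkeeping above simply lets views that cross the right edge contribute to both edges. A minimal sketch:

```python
import torch

# Minimal sketch of the MultiDiffusion averaging step (Eq. 5): every latent
# position takes the mean of all denoised views that cover it.
value = torch.tensor([2.0, 6.0, 3.0, 0.0])  # summed per-view predictions
count = torch.tensor([1.0, 2.0, 3.0, 0.0])  # number of views covering each position
latents = torch.where(count > 0, value / count, value)
print(latents.tolist())  # [2.0, 3.0, 1.0, 0.0]; uncovered positions pass through
```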
@@ -711,7 +762,10 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
callback(i, t, latents)
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if circular_padding:
image = self.decode_latents_with_padding(latents)
else:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents


@@ -125,6 +125,22 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_panorama_circular_padding_case(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionPanoramaPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
image = sd_pipe(**inputs, circular_padding=True).images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
# override to speed up the overall test timing.
def test_inference_batch_consistent(self):
super().test_inference_batch_consistent(batch_sizes=[1, 2])
@@ -170,6 +186,24 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_panorama_views_batch_circular_padding(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()
sd_pipe = StableDiffusionPanoramaPipeline(**components)
sd_pipe = sd_pipe.to(device)
sd_pipe.set_progress_bar_config(disable=None)
inputs = self.get_dummy_inputs(device)
output = sd_pipe(**inputs, circular_padding=True, view_batch_size=2)
image = output.images
image_slice = image[0, -3:, -3:, -1]
assert image.shape == (1, 64, 64, 3)
expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
def test_stable_diffusion_panorama_euler(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator
components = self.get_dummy_components()