mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add circular padding for artifact-free StableDiffusionPanoramaPipeline (#4025)
* Add circular padding option * Fix style with black * Fix corner case with small image size * Add circular padding test cases * Fix docstring * Improve docstring for circular padding, remove slow test case * Update docs for circular padding argument * Add images comparison for circular padding
This commit is contained in:
@@ -60,6 +60,25 @@ and increase the VRAM usage.
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
Circular padding is applied to ensure there are no stitching artifacts when working with
|
||||
panoramas that need to transition seamlessly from the rightmost part to the leftmost part.
|
||||
By enabling circular padding (set `circular_padding=True`), the operation applies additional
|
||||
crops after the rightmost point of the image, allowing the model to “see” the transition
|
||||
from the rightmost part to the leftmost part. This helps maintain visual consistency in
|
||||
a 360-degree sense and creates a proper “panorama” that can be viewed using 360-degree
|
||||
panorama viewers. When decoding latents in StableDiffusion, circular padding is applied
|
||||
to ensure that the decoded latents match in the RGB space.
|
||||
|
||||
Without circular padding, there is a stitching artifact (default):
|
||||

|
||||
|
||||
With circular padding, the right and left parts match (`circular_padding=True`):
|
||||

|
||||
|
||||
</Tip>
|
||||
|
||||
## StableDiffusionPanoramaPipeline
|
||||
[[autodoc]] StableDiffusionPanoramaPipeline
|
||||
- __call__
|
||||
|
||||
@@ -373,6 +373,19 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
return image
|
||||
|
||||
def decode_latents_with_padding(self, latents, padding=8):
    """
    Decode latents with circular (wrap-around) padding along the last axis.

    The rightmost `padding` latent columns are prepended and the leftmost `padding` columns are appended
    before decoding, and the corresponding border is cropped from the decoded image afterwards. This slightly
    increases memory usage but removes the boundary artifacts at the left/right seam.

    Args:
        latents (`torch.Tensor`):
            Latents to decode; they are divided by `self.vae.config.scaling_factor` before decoding.
        padding (`int`, *optional*, defaults to 8):
            Number of latent columns to add on each side. Values larger than the latent width are clamped to
            the width; values `<= 0` disable the padding and decode the latents directly.

    Returns:
        `torch.Tensor`: The decoded image with the padded border removed.
    """
    latents = 1 / self.vae.config.scaling_factor * latents

    # Never wrap more columns than exist, otherwise the crop below would remove more than was added.
    padding = min(int(padding), latents.shape[-1])
    if padding <= 0:
        # Without this guard `latents[..., -0:]` would duplicate the whole tensor and
        # `image[..., 0:-0]` would yield an empty image.
        return self.vae.decode(latents, return_dict=False)[0]

    latents_left = latents[..., :padding]
    latents_right = latents[..., -padding:]
    latents = torch.cat((latents_right, latents, latents_left), dim=-1)
    image = self.vae.decode(latents, return_dict=False)[0]

    # Border width in decoded space; assumes the VAE upsamples the last axis by `self.vae_scale_factor`.
    padding_pix = self.vae_scale_factor * padding
    return image[..., padding_pix:-padding_pix]
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
|
||||
def prepare_extra_step_kwargs(self, generator, eta):
|
||||
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
|
||||
@@ -457,13 +470,16 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
latents = latents * self.scheduler.init_noise_sigma
|
||||
return latents
|
||||
|
||||
def get_views(self, panorama_height, panorama_width, window_size=64, stride=8):
|
||||
def get_views(self, panorama_height, panorama_width, window_size=64, stride=8, circular_padding=False):
|
||||
# Here, we define the mappings F_i (see Eq. 7 in the MultiDiffusion paper https://arxiv.org/abs/2302.08113)
|
||||
# if panorama's height/width < window_size, num_blocks of height/width should return 1
|
||||
panorama_height /= 8
|
||||
panorama_width /= 8
|
||||
num_blocks_height = (panorama_height - window_size) // stride + 1 if panorama_height > window_size else 1
|
||||
num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
|
||||
if circular_padding:
|
||||
num_blocks_width = panorama_width // stride if panorama_width > window_size else 1
|
||||
else:
|
||||
num_blocks_width = (panorama_width - window_size) // stride + 1 if panorama_width > window_size else 1
|
||||
total_num_blocks = int(num_blocks_height * num_blocks_width)
|
||||
views = []
|
||||
for i in range(total_num_blocks):
|
||||
@@ -496,6 +512,7 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
circular_padding: bool = False,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
@@ -560,6 +577,10 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
|
||||
circular_padding (`bool`, *optional*, defaults to `False`):
|
||||
If set to True, circular padding is applied to ensure there are no stitching artifacts. Circular
|
||||
padding allows the model to seamlessly generate a transition from the rightmost part of the image to
|
||||
the leftmost part, maintaining consistency in a 360-degree sense.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -627,10 +648,9 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
|
||||
# 6. Define panorama grid and initialize views for synthesis.
|
||||
# prepare batch grid
|
||||
views = self.get_views(height, width)
|
||||
views = self.get_views(height, width, circular_padding=circular_padding)
|
||||
views_batch = [views[i : i + view_batch_size] for i in range(0, len(views), view_batch_size)]
|
||||
views_scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(views_batch)
|
||||
|
||||
count = torch.zeros_like(latents)
|
||||
value = torch.zeros_like(latents)
|
||||
|
||||
@@ -655,9 +675,29 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
for j, batch_view in enumerate(views_batch):
|
||||
vb_size = len(batch_view)
|
||||
# get the latents corresponding to the current view coordinates
|
||||
latents_for_view = torch.cat(
|
||||
[latents[:, :, h_start:h_end, w_start:w_end] for h_start, h_end, w_start, w_end in batch_view]
|
||||
)
|
||||
if circular_padding:
|
||||
latents_for_view = []
|
||||
for h_start, h_end, w_start, w_end in batch_view:
|
||||
if w_end > latents.shape[3]:
|
||||
# Add circular horizontal padding
|
||||
latent_view = torch.cat(
|
||||
(
|
||||
latents[:, :, h_start:h_end, w_start:],
|
||||
latents[:, :, h_start:h_end, : w_end - latents.shape[3]],
|
||||
),
|
||||
axis=-1,
|
||||
)
|
||||
else:
|
||||
latent_view = latents[:, :, h_start:h_end, w_start:w_end]
|
||||
latents_for_view.append(latent_view)
|
||||
latents_for_view = torch.cat(latents_for_view)
|
||||
else:
|
||||
latents_for_view = torch.cat(
|
||||
[
|
||||
latents[:, :, h_start:h_end, w_start:w_end]
|
||||
for h_start, h_end, w_start, w_end in batch_view
|
||||
]
|
||||
)
|
||||
|
||||
# rematch block's scheduler status
|
||||
self.scheduler.__dict__.update(views_scheduler_status[j])
|
||||
@@ -698,8 +738,19 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
for latents_view_denoised, (h_start, h_end, w_start, w_end) in zip(
|
||||
latents_denoised_batch.chunk(vb_size), batch_view
|
||||
):
|
||||
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
||||
count[:, :, h_start:h_end, w_start:w_end] += 1
|
||||
if circular_padding and w_end > latents.shape[3]:
|
||||
# Case for circular padding
|
||||
value[:, :, h_start:h_end, w_start:] += latents_view_denoised[
|
||||
:, :, h_start:h_end, : latents.shape[3] - w_start
|
||||
]
|
||||
value[:, :, h_start:h_end, : w_end - latents.shape[3]] += latents_view_denoised[
|
||||
:, :, h_start:h_end, latents.shape[3] - w_start :
|
||||
]
|
||||
count[:, :, h_start:h_end, w_start:] += 1
|
||||
count[:, :, h_start:h_end, : w_end - latents.shape[3]] += 1
|
||||
else:
|
||||
value[:, :, h_start:h_end, w_start:w_end] += latents_view_denoised
|
||||
count[:, :, h_start:h_end, w_start:w_end] += 1
|
||||
|
||||
# take the MultiDiffusion step. Eq. 5 in MultiDiffusion paper: https://arxiv.org/abs/2302.08113
|
||||
latents = torch.where(count > 0, value / count, value)
|
||||
@@ -711,7 +762,10 @@ class StableDiffusionPanoramaPipeline(DiffusionPipeline, TextualInversionLoaderM
|
||||
callback(i, t, latents)
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
if circular_padding:
|
||||
image = self.decode_latents_with_padding(latents)
|
||||
else:
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
|
||||
@@ -125,6 +125,22 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_panorama_circular_padding_case(self):
    """Run the dummy pipeline with `circular_padding=True` and compare a corner slice to reference values."""
    # CPU keeps the device-dependent torch.Generator deterministic
    device = "cpu"

    pipeline_components = self.get_dummy_components()
    sd_pipe = StableDiffusionPanoramaPipeline(**pipeline_components).to(device)
    sd_pipe.set_progress_bar_config(disable=None)

    call_kwargs = self.get_dummy_inputs(device)
    result = sd_pipe(**call_kwargs, circular_padding=True)
    image = result.images

    # bottom-right 3x3 patch of the last channel
    image_slice = image[0, -3:, -3:, -1]
    assert image.shape == (1, 64, 64, 3)

    expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])
    max_abs_diff = np.abs(image_slice.flatten() - expected_slice).max()

    assert max_abs_diff < 1e-2
|
||||
|
||||
# override to speed the overall test timing up.
|
||||
def test_inference_batch_consistent(self):
|
||||
super().test_inference_batch_consistent(batch_sizes=[1, 2])
|
||||
@@ -170,6 +186,24 @@ class StableDiffusionPanoramaPipelineFastTests(PipelineLatentTesterMixin, Pipeli
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
def test_stable_diffusion_panorama_views_batch_circular_padding(self):
    """Run the dummy pipeline with `circular_padding=True` and `view_batch_size=2` against reference values."""
    # CPU keeps the device-dependent torch.Generator deterministic
    device = "cpu"

    pipeline_components = self.get_dummy_components()
    sd_pipe = StableDiffusionPanoramaPipeline(**pipeline_components).to(device)
    sd_pipe.set_progress_bar_config(disable=None)

    call_kwargs = self.get_dummy_inputs(device)
    image = sd_pipe(**call_kwargs, circular_padding=True, view_batch_size=2).images

    # bottom-right 3x3 patch of the last channel
    image_slice = image[0, -3:, -3:, -1]
    assert image.shape == (1, 64, 64, 3)

    expected_slice = np.array([0.6127, 0.6299, 0.4595, 0.4051, 0.4543, 0.3925, 0.5510, 0.5693, 0.5031])
    max_abs_diff = np.abs(image_slice.flatten() - expected_slice).max()

    assert max_abs_diff < 1e-2
|
||||
|
||||
def test_stable_diffusion_panorama_euler(self):
|
||||
device = "cpu" # ensure determinism for the device-dependent torch.Generator
|
||||
components = self.get_dummy_components()
|
||||
|
||||
Reference in New Issue
Block a user