mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
@@ -119,11 +119,11 @@ def forward_backward_consistency_check(fwd_flow, bwd_flow, alpha=0.01, beta=0.5)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def get_warped_and_mask(flow_model, image1, image2, image3=None, pixel_consistency=False):
|
||||
def get_warped_and_mask(flow_model, image1, image2, image3=None, pixel_consistency=False, device=None):
|
||||
if image3 is None:
|
||||
image3 = image1
|
||||
padder = InputPadder(image1.shape, padding_factor=8)
|
||||
image1, image2 = padder.pad(image1[None].cuda(), image2[None].cuda())
|
||||
image1, image2 = padder.pad(image1[None].to(device), image2[None].to(device))
|
||||
results_dict = flow_model(
|
||||
image1, image2, attn_splits_list=[2], corr_radius_list=[-1], prop_radius_list=[-1], pred_bidir_flow=True
|
||||
)
|
||||
@@ -307,6 +307,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
feature_extractor: CLIPImageProcessor,
|
||||
image_encoder=None,
|
||||
requires_safety_checker: bool = True,
|
||||
device=None,
|
||||
):
|
||||
super().__init__(
|
||||
vae,
|
||||
@@ -320,6 +321,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
image_encoder,
|
||||
requires_safety_checker,
|
||||
)
|
||||
self.to(device)
|
||||
|
||||
if safety_checker is None and requires_safety_checker:
|
||||
logger.warning(
|
||||
@@ -374,7 +376,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
attention_type="swin",
|
||||
ffn_dim_expansion=4,
|
||||
num_transformer_layers=6,
|
||||
).to("cuda")
|
||||
).to(self.device)
|
||||
|
||||
checkpoint = torch.utils.model_zoo.load_url(
|
||||
"https://huggingface.co/Anonymous-sub/Rerender/resolve/main/models/gmflow_sintel-0c07dcb3.pth",
|
||||
@@ -928,13 +930,13 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline):
|
||||
prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32)
|
||||
|
||||
warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
|
||||
self.flow_model, first_image, image[0], first_result, False
|
||||
self.flow_model, first_image, image[0], first_result, False, self.device
|
||||
)
|
||||
blend_mask_0 = blur(F.max_pool2d(bwd_occ_0, kernel_size=9, stride=1, padding=4))
|
||||
blend_mask_0 = torch.clamp(blend_mask_0 + bwd_occ_0, 0, 1)
|
||||
|
||||
warped_pre, bwd_occ_pre, bwd_flow_pre = get_warped_and_mask(
|
||||
self.flow_model, prev_image[0], image[0], prev_result, False
|
||||
self.flow_model, prev_image[0], image[0], prev_result, False, self.device
|
||||
)
|
||||
blend_mask_pre = blur(F.max_pool2d(bwd_occ_pre, kernel_size=9, stride=1, padding=4))
|
||||
blend_mask_pre = torch.clamp(blend_mask_pre + bwd_occ_pre, 0, 1)
|
||||
|
||||
Reference in New Issue
Block a user