diff --git a/examples/community/rerender_a_video.py b/examples/community/rerender_a_video.py index cae5fcb2b9..c421acf354 100644 --- a/examples/community/rerender_a_video.py +++ b/examples/community/rerender_a_video.py @@ -782,7 +782,7 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline): self.attn_state.reset() # 4.1 prepare frames - image = self.image_processor.preprocess(frames[0]).to(dtype=torch.float32) + image = self.image_processor.preprocess(frames[0]).to(dtype=self.dtype) first_image = image[0] # C, H, W # 4.2 Prepare controlnet_conditioning_image @@ -926,8 +926,8 @@ class RerenderAVideoPipeline(StableDiffusionControlNetImg2ImgPipeline): prev_image = frames[idx - 1] control_image = control_frames[idx] # 5.1 prepare frames - image = self.image_processor.preprocess(image).to(dtype=torch.float32) - prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32) + image = self.image_processor.preprocess(image).to(dtype=self.dtype) + prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype) warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask( self.flow_model, first_image, image[0], first_result, False, self.device