import torch
import numpy as np
from PIL import Image
from modules import shared, devices, processing, images, sd_vae, sd_samplers, processing_helpers, prompt_parser, token_merge
from modules.sd_hijack_hypertile import hypertile_set


# Re-exported helpers so callers can keep importing them from this module.
create_binary_mask = processing_helpers.create_binary_mask
apply_overlay = processing_helpers.apply_overlay
apply_color_correction = processing_helpers.apply_color_correction
setup_color_correction = processing_helpers.setup_color_correction
images_tensor_to_samples = processing_helpers.images_tensor_to_samples
txt2img_image_conditioning = processing_helpers.txt2img_image_conditioning
img2img_image_conditioning = processing_helpers.img2img_image_conditioning
get_fixed_seed = processing_helpers.get_fixed_seed
create_random_tensors = processing_helpers.create_random_tensors
decode_first_stage = processing_helpers.decode_first_stage
old_hires_fix_first_pass_dimensions = processing_helpers.old_hires_fix_first_pass_dimensions
validate_sample = processing_helpers.validate_sample


def get_conds_with_caching(function, required_prompts, steps, cache):
    """Compute learned conditioning via `function`, memoized in `cache`.

    Args:
        function: conditioning function taking (model, prompts, steps).
        required_prompts: prompts to condition on (part of the cache key).
        steps: number of sampling steps (part of the cache key).
        cache: two-element mutable list `[key, value]` used as a one-slot cache;
            mutated in place so the caller's list carries the result forward.

    Returns:
        The cached or freshly computed conditioning.
    """
    if cache[0] is not None and (required_prompts, steps) == cache[0]:
        return cache[1]
    with devices.autocast():
        cache[1] = function(shared.sd_model, required_prompts, steps)
    cache[0] = (required_prompts, steps)
    return cache[1]


def check_rollback_vae():
    """Validate that the environment supports VAE rollback; disable the flag if not.

    Rollback requires CUDA, Torch >= 2.1, and either capability 0 or >= 8
    (bfloat16-capable hardware); on any failed check the option is turned off
    and an error is logged.
    """
    if shared.cmd_opts.rollback_vae:
        if not torch.cuda.is_available():
            shared.log.error("Rollback VAE functionality requires compatible GPU")
            shared.cmd_opts.rollback_vae = False
        elif torch.__version__.startswith('1.') or torch.__version__.startswith('2.0'):
            shared.log.error("Rollback VAE functionality requires Torch 2.1 or higher")
            shared.cmd_opts.rollback_vae = False
        elif 0 < torch.cuda.get_device_capability()[0] < 8:
            shared.log.error('Rollback VAE functionality device capabilities not met')
            shared.cmd_opts.rollback_vae = False


def process_original(p: processing.StableDiffusionProcessing):
    """Run the original (ldm) sampling pipeline for `p` and decode the result.

    Computes positive/negative conditioning, samples latents, decodes them with
    the VAE, and — if the VAE produces NaNs and rollback is enabled — retries
    the decode once with a bfloat16 VAE reloaded from disk.

    Returns:
        Tensor of decoded images stacked on dim 0, values clamped to [0, 1].
    """
    cached_uc = [None, None]
    cached_c = [None, None]
    sampler_config = sd_samplers.find_sampler_config(p.sampler_name)
    # Second-order samplers evaluate the model twice per step, so conditioning
    # schedules must cover twice the step count.
    step_multiplier = 2 if sampler_config and sampler_config.options.get("second_order", False) else 1
    uc = get_conds_with_caching(prompt_parser.get_learned_conditioning, p.negative_prompts, p.steps * step_multiplier, cached_uc)
    c = get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, p.prompts, p.steps * step_multiplier, cached_c)
    with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
        samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)
    # Decode one latent at a time to limit peak VRAM, moving results to CPU.
    x_samples_ddim = [processing.decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
    try:
        for x in x_samples_ddim:
            devices.test_for_nans(x, "vae")
    except devices.NansException:
        check_rollback_vae()
        if not shared.opts.no_half and not shared.opts.no_half_vae and shared.cmd_opts.rollback_vae:
            shared.log.warning('Tensor with all NaNs was produced in VAE')
            # Retry once with a bfloat16 VAE; fp16 VAEs are a known NaN source.
            devices.dtype_vae = torch.bfloat16
            vae_file, vae_source = sd_vae.resolve_vae(p.sd_model.sd_model_checkpoint)
            sd_vae.load_vae(p.sd_model, vae_file, vae_source)
            x_samples_ddim = [processing.decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
            for x in x_samples_ddim:
                devices.test_for_nans(x, "vae")
        else:
            # was `raise e`; bare raise preserves the original traceback
            raise
    x_samples_ddim = torch.stack(x_samples_ddim).float()
    # VAE output is in [-1, 1]; map to [0, 1] image range.
    x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
    del samples_ddim
    return x_samples_ddim


def sample_txt2img(p: processing.StableDiffusionProcessingTxt2Img, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
    """Sample a txt2img batch, optionally followed by a hires-fix second pass.

    First pass samples from random noise; if hires fix is enabled the latents
    are upscaled (in latent space or via image-space upscaler + re-encode) and
    optionally re-sampled img2img-style at the target resolution.

    Returns:
        The sampled (and possibly upscaled/refined) latent batch.
    """
    p.ops.append('txt2img')
    hypertile_set(p)
    p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
    if hasattr(p.sampler, "initialize"):
        p.sampler.initialize(p)
    # Latent tensor is 1/8 of image resolution with 4 channels.
    x = create_random_tensors([4, p.height // 8, p.width // 8], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
    samples = p.sampler.sample(p, x, conditioning, unconditional_conditioning, image_conditioning=txt2img_image_conditioning(p, x))
    shared.state.nextjob()
    if not p.enable_hr or shared.state.interrupted or shared.state.skipped:
        return samples
    p.init_hr()
    if p.is_hr_pass:
        prev_job = shared.state.job
        target_width = p.hr_upscale_to_x
        target_height = p.hr_upscale_to_y
        decoded_samples = None
        if shared.opts.samples_save and shared.opts.save_images_before_highres_fix and not p.do_not_save_samples:
            # Save intermediate first-pass images before the hires pass runs.
            decoded_samples = decode_first_stage(p.sd_model, samples.to(dtype=devices.dtype_vae))
            decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
            for i, x_sample in enumerate(decoded_samples):
                x_sample = validate_sample(x_sample)
                image = Image.fromarray(x_sample)
                # Temporarily clear extra params and detailer flag so the saved
                # infotext reflects only the first pass; restore afterwards.
                # (fixed: was misspelled `detailer_denabled` on save/disable)
                orig_extra_generation_params, orig_detailer = p.extra_generation_params, p.detailer_enabled
                p.extra_generation_params = {}
                p.detailer_enabled = False
                info = processing.create_infotext(p, p.all_prompts, p.all_seeds, p.all_subseeds, [], iteration=p.iteration, position_in_batch=i)
                p.extra_generation_params, p.detailer_enabled = orig_extra_generation_params, orig_detailer
                images.save_image(image, p.outpath_samples, "", seeds[i], prompts[i], shared.opts.samples_format, info=info, suffix="-before-hires")
        if p.hr_upscaler.lower().startswith('latent'):  # latent-space upscaling (comment fixed: was mislabeled "non-latent")
            p.hr_force = True
            shared.state.job = 'Upscale'
            samples = images.resize_image(1, samples, target_width, target_height, upscaler_name=p.hr_upscaler)
            if getattr(p, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
                image_conditioning = img2img_image_conditioning(p, decode_first_stage(p.sd_model, samples.to(dtype=devices.dtype_vae)), samples)
            else:
                image_conditioning = txt2img_image_conditioning(p, samples.to(dtype=devices.dtype_vae))
        else:  # non-latent upscaling: decode, upscale as images, re-encode
            shared.state.job = 'Upscale'
            if decoded_samples is None:
                decoded_samples = decode_first_stage(p.sd_model, samples.to(dtype=devices.dtype_vae))
                decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
            batch_images = []
            for x_sample in decoded_samples:
                x_sample = validate_sample(x_sample)
                image = Image.fromarray(x_sample)
                image = images.resize_image(1, image, target_width, target_height, upscaler_name=p.hr_upscaler)
                image = np.array(image).astype(np.float32) / 255.0
                image = np.moveaxis(image, 2, 0)  # HWC -> CHW
                batch_images.append(image)
            resized_samples = torch.from_numpy(np.array(batch_images))
            resized_samples = resized_samples.to(device=shared.device, dtype=devices.dtype_vae)
            resized_samples = 2.0 * resized_samples - 1.0  # [0, 1] -> [-1, 1] for the VAE encoder
            if shared.opts.sd_vae_sliced_encode and len(decoded_samples) > 1:
                # Encode one image at a time to reduce peak VRAM usage.
                samples = torch.stack([p.sd_model.get_first_stage_encoding(p.sd_model.encode_first_stage(torch.unsqueeze(resized_sample, 0)))[0] for resized_sample in resized_samples])
            else:
                samples = p.sd_model.get_first_stage_encoding(p.sd_model.encode_first_stage(resized_samples))
            image_conditioning = img2img_image_conditioning(p, resized_samples, samples)
        if p.hr_force:
            shared.state.job = 'HiRes'
            if p.denoising_strength > 0:
                p.ops.append('hires')
                devices.torch_gc()  # GC now before running the next img2img to prevent running out of memory
                p.sampler = sd_samplers.create_sampler(p.hr_sampler_name or p.sampler_name, p.sd_model)
                if hasattr(p.sampler, "initialize"):
                    p.sampler.initialize(p)
                # Crop the latent to the exact target size (truncate split evenly on both edges).
                samples = samples[:, :, p.truncate_y//2:samples.shape[2]-(p.truncate_y+1)//2, p.truncate_x//2:samples.shape[3]-(p.truncate_x+1)//2]
                noise = create_random_tensors(samples.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=subseed_strength, p=p)
                token_merge.apply_token_merging(p.sd_model)
                hypertile_set(p, hr=True)
                samples = p.sampler.sample_img2img(p, samples, noise, conditioning, unconditional_conditioning, steps=p.hr_second_pass_steps or p.steps, image_conditioning=image_conditioning)
                token_merge.apply_token_merging(p.sd_model)
            else:
                p.ops.append('upscale')
        x = None  # drop reference to the first-pass noise tensor
        p.is_hr_pass = False
        shared.state.job = prev_job
        shared.state.nextjob()
    return samples


def sample_img2img(p, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):  # pylint: disable=unused-argument
    """Sample an img2img batch starting from `p.init_latent` plus scaled noise.

    Applies the inpainting mask blend (if any) so masked regions keep the
    original latent content.

    Returns:
        The sampled latent batch.
    """
    hypertile_set(p)
    x = create_random_tensors([4, p.height // 8, p.width // 8], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)
    x *= p.initial_noise_multiplier
    samples = p.sampler.sample_img2img(p, p.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)
    if p.mask is not None:
        # Blend: keep original latent where mask is set, sampled result elsewhere.
        samples = samples * p.nmask + p.init_latent * p.mask
    del x
    devices.torch_gc()
    shared.state.nextjob()
    return samples