Mirror of https://github.com/huggingface/diffusers.git
[Community Pipeline] MagicMix (#1839)
* initial
* type hints
* update scheduler type hint
* add to README
* add example generation to README
* v -> mix_factor
* load scheduler from pretrained
@@ -25,6 +25,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
| K-Diffusion Stable Diffusion | Run Stable Diffusion with any of [K-Diffusion's samplers](https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py) | [Stable Diffusion with K Diffusion](#stable-diffusion-with-k-diffusion) | - | [Patrick von Platen](https://github.com/patrickvonplaten/) |
| Checkpoint Merger Pipeline | Diffusion Pipeline that enables merging of saved model checkpoints | [Checkpoint Merger Pipeline](#checkpoint-merger-pipeline) | - | [Naga Sai Abhinay Devarinti](https://github.com/Abhinay1997/) |
| Stable Diffusion v1.1-1.4 Comparison | Run all 4 model checkpoints for Stable Diffusion and compare their results together | [Stable Diffusion Comparison](#stable-diffusion-comparisons) | - | [Suvaditya Mukherjee](https://github.com/suvadityamuk) |
| MagicMix | Diffusion Pipeline for semantic mixing of an image and a text prompt | [MagicMix](#magic-mix) | - | [Partho Das](https://github.com/daspartho) |
@@ -815,6 +816,50 @@ plt.title('Stable Diffusion v1.4')
```python
plt.axis('off')

plt.show()
```
As a result, you can look at a grid of all 4 generated images shown together, which captures the difference in training progress across the 4 checkpoints.
### Magic Mix
Implementation of the [MagicMix: Semantic Mixing with Diffusion Models](https://arxiv.org/abs/2210.16056) paper. This is a Diffusion Pipeline for semantic mixing of an image and a text prompt, creating a new concept while preserving the spatial layout and geometry of the subject in the image. The pipeline takes an image that provides the layout semantics and a prompt that provides the content semantics for the mixing process.

There are 3 parameters for the method (see the short sketch after this list for how they map onto the denoising schedule):
- `mix_factor`: the interpolation constant used in the layout generation phase. The greater the value of `mix_factor`, the greater the influence of the prompt on the layout generation process.
- `kmax` and `kmin`: these determine the range of the layout and content generation processes. A higher value of `kmax` means more of the original image's layout information is destroyed by noise, and a higher value of `kmin` gives more steps to the content generation process.
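
For intuition, here is a minimal sketch of how these parameters are used; it mirrors the logic in `magic_mix.py` shown further below (the `tmin`/`tmax` names follow that file, and the concrete numbers are only illustrative):

```python
# illustrative values, matching the usage example below
steps = 50
kmin, kmax = 0.3, 0.5
mix_factor = 0.5

# the pipeline converts the fractions into indices on the scheduler's timestep list
tmin = steps - int(kmin * steps)  # 35: layout generation runs for indices tmax < i < tmin
tmax = steps - int(kmax * steps)  # 25: denoising starts at this (noisier) timestep index

for i in range(tmax, steps):
    if i < tmin:
        # layout generation phase: the partially denoised latents are blended with a
        # freshly noised copy of the input image, so its layout keeps leaking back in:
        #   latents = mix_factor * latents + (1 - mix_factor) * noised_image_latents
        phase = "layout"
    else:
        # content generation phase: plain classifier-free-guided denoising of the prompt
        phase = "content"

print(f"start at index {tmax}, blend layout until index {tmin}, {steps} steps total")
```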
Here is an example usage:
```python
from diffusers import DiffusionPipeline, DDIMScheduler
from PIL import Image

pipe = DiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    custom_pipeline="magic_mix",
    scheduler=DDIMScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler"),
).to("cuda")

img = Image.open("phone.jpg")
mix_img = pipe(
    img,
    prompt="bed",
    kmin=0.3,
    kmax=0.5,
    mix_factor=0.5,
)
mix_img.save("phone_bed_mix.jpg")
```
`mix_img` is a PIL image that can be saved locally or displayed directly in a Google Colab notebook. The generated image is a mix of the layout semantics of the given image and the content semantics of the prompt.
For example, the above script generates the following image:
`phone.jpg`



`phone_bed_mix.jpg`


For more example generations, check out this [demo notebook](https://github.com/daspartho/MagicMix/blob/main/demo.ipynb).
examples/community/magic_mix.py (new file, 152 lines)
@@ -0,0 +1,152 @@
from typing import Union

import torch
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    DiffusionPipeline,
    LMSDiscreteScheduler,
    PNDMScheduler,
    UNet2DConditionModel,
)
from PIL import Image
from torchvision import transforms as tfms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer


class MagicMixPipeline(DiffusionPipeline):
    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[PNDMScheduler, LMSDiscreteScheduler, DDIMScheduler],
    ):
        super().__init__()

        self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)

    # convert PIL image to latents
    def encode(self, img):
        with torch.no_grad():
            latent = self.vae.encode(tfms.ToTensor()(img).unsqueeze(0).to(self.device) * 2 - 1)
            latent = 0.18215 * latent.latent_dist.sample()  # Stable Diffusion latent scaling factor
        return latent

    # convert latents to PIL image
    def decode(self, latent):
        latent = (1 / 0.18215) * latent
        with torch.no_grad():
            img = self.vae.decode(latent).sample
        img = (img / 2 + 0.5).clamp(0, 1)
        img = img.detach().cpu().permute(0, 2, 3, 1).numpy()
        img = (img * 255).round().astype("uint8")
        return Image.fromarray(img[0])

    # convert prompt into text embeddings, also unconditional embeddings
    def prep_text(self, prompt):
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        text_embedding = self.text_encoder(text_input.input_ids.to(self.device))[0]

        uncond_input = self.tokenizer(
            "",
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

        uncond_embedding = self.text_encoder(uncond_input.input_ids.to(self.device))[0]

        return torch.cat([uncond_embedding, text_embedding])

    def __call__(
        self,
        img: Image.Image,
        prompt: str,
        kmin: float = 0.3,
        kmax: float = 0.6,
        mix_factor: float = 0.5,
        seed: int = 42,
        steps: int = 50,
        guidance_scale: float = 7.5,
    ) -> Image.Image:
        tmin = steps - int(kmin * steps)  # layout generation runs while the step index stays below tmin
        tmax = steps - int(kmax * steps)  # denoising starts from timestep index tmax

        text_embeddings = self.prep_text(prompt)

        self.scheduler.set_timesteps(steps)

        width, height = img.size
        encoded = self.encode(img)

        # noise the image latents up to the starting timestep
        torch.manual_seed(seed)
        noise = torch.randn(
            (1, self.unet.in_channels, height // 8, width // 8),
        ).to(self.device)

        latents = self.scheduler.add_noise(
            encoded,
            noise,
            timesteps=self.scheduler.timesteps[tmax],
        )

        # first denoising step at the starting timestep, with classifier-free guidance
        input = torch.cat([latents] * 2)

        input = self.scheduler.scale_model_input(input, self.scheduler.timesteps[tmax])

        with torch.no_grad():
            pred = self.unet(
                input,
                self.scheduler.timesteps[tmax],
                encoder_hidden_states=text_embeddings,
            ).sample

        pred_uncond, pred_text = pred.chunk(2)
        pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

        latents = self.scheduler.step(pred, self.scheduler.timesteps[tmax], latents).prev_sample

        for i, t in enumerate(tqdm(self.scheduler.timesteps)):
            if i > tmax:
                if i < tmin:  # layout generation phase
                    orig_latents = self.scheduler.add_noise(
                        encoded,
                        noise,
                        timesteps=t,
                    )

                    input = (
                        (mix_factor * latents) + (1 - mix_factor) * orig_latents
                    )  # interpolate between the noised layout latents and the conditionally generated latents to preserve layout semantics
                    input = torch.cat([input] * 2)

                else:  # content generation phase
                    input = torch.cat([latents] * 2)

                input = self.scheduler.scale_model_input(input, t)

                with torch.no_grad():
                    pred = self.unet(
                        input,
                        t,
                        encoder_hidden_states=text_embeddings,
                    ).sample

                pred_uncond, pred_text = pred.chunk(2)
                pred = pred_uncond + guidance_scale * (pred_text - pred_uncond)

                latents = self.scheduler.step(pred, t, latents).prev_sample

        return self.decode(latents)