<!--Copyright 2023 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# DiffEdit

[[open-in-colab]]

Image editing typically requires a mask of the area to be edited. DiffEdit generates the mask automatically from a text query, which makes it easier to create a mask without image editing software. The DiffEdit algorithm works in three steps:

1. The diffusion model denoises an image conditioned on a query text and on a reference text, which produces different noise estimates for different areas of the image. The difference between the estimates is used to infer a mask that identifies which area of the image needs to change to match the query text.
2. The input image is encoded into latent space with DDIM.
3. The latents are decoded with the diffusion model conditioned on the text query, using the mask as a guide so that pixels outside the mask remain the same as in the input image.

This guide shows you how to use DiffEdit to edit images without manually creating a mask.

Before you begin, make sure you have the following libraries installed:

```py
# uncomment to install the necessary libraries in Colab
#!pip install -q diffusers transformers accelerate
```

The [`StableDiffusionDiffEditPipeline`] requires an image mask and a set of partially inverted latents. The image mask is generated by the [`~StableDiffusionDiffEditPipeline.generate_mask`] function, which takes two parameters, `source_prompt` and `target_prompt`. These parameters determine what to edit in the image. For example, to change a bowl of *fruits* into a bowl of *pears*:

```py
source_prompt = "a bowl of fruits"
target_prompt = "a bowl of pears"
```

The partially inverted latents are generated by the [`~StableDiffusionDiffEditPipeline.invert`] function. It is generally a good idea to include a `prompt` or *caption* describing the image to help guide the inverse latent sampling process. The caption can often be your `source_prompt`, but feel free to experiment with other text descriptions!

Load the pipeline, scheduler, and inverse scheduler, and enable some optimizations to reduce memory usage:

```py
import torch
from diffusers import DDIMScheduler, DDIMInverseScheduler, StableDiffusionDiffEditPipeline

pipeline = StableDiffusionDiffEditPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1",
    torch_dtype=torch.float16,
    safety_checker=None,
    use_safetensors=True,
)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()
```

Load the image to edit:

```py
from diffusers.utils import load_image, make_image_grid

img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
raw_image = load_image(img_url).resize((768, 768))
raw_image
```

Use the [`~StableDiffusionDiffEditPipeline.generate_mask`] function to generate the image mask. Pass it the `source_prompt` and `target_prompt` to specify what to edit in the image:

```py
from PIL import Image

source_prompt = "a bowl of fruits"
target_prompt = "a basket of pears"
mask_image = pipeline.generate_mask(
    image=raw_image,
    source_prompt=source_prompt,
    target_prompt=target_prompt,
)
Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L").resize((768, 768))
```

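To build intuition for step 1 of the algorithm, here is a minimal NumPy sketch of how a mask can be inferred from two noise estimates. It is not the pipeline's implementation: `infer_mask`, the array shapes, and the `threshold` value are assumptions made for illustration, and the synthetic arrays stand in for the noise a diffusion model would predict under the source and target prompts.

```py
import numpy as np


def infer_mask(noise_source, noise_target, threshold=0.5):
    """Binarize the normalized difference between two noise estimates."""
    # Average the absolute difference over the channel axis -> (H, W).
    diff = np.abs(noise_target - noise_source).mean(axis=0)
    # Rescale to [0, 1] so the threshold does not depend on the noise scale.
    diff = (diff - diff.min()) / (diff.max() - diff.min() + 1e-8)
    return (diff >= threshold).astype(np.uint8)


rng = np.random.default_rng(0)
noise_source = rng.normal(size=(4, 8, 8))
noise_target = noise_source.copy()
# Pretend the two prompts disagree only in the top-left 4x4 patch.
noise_target[:, :4, :4] += 3.0

mask = infer_mask(noise_source, noise_target)
print(mask)
```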
Next, create the inverted latents and pass a caption describing the image:

```py
inv_latents = pipeline.invert(prompt=source_prompt, image=raw_image).latents
```

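Inversion with DDIM relies on a deterministic update rule. The sketch below assumes a fixed noise estimate `eps` in place of a real model and a made-up `alphas` schedule, and shows why running the same update towards higher noise and then back again recovers the starting latents. `ddim_step` is a hypothetical helper, not part of Diffusers. With a real model the predicted noise changes at every step, so the round trip is only approximate.

```py
import numpy as np


def ddim_step(x, eps, alpha_from, alpha_to):
    """Deterministic DDIM update from noise level alpha_from to alpha_to."""
    # Predict the clean sample implied by the current noise estimate.
    x0_pred = (x - np.sqrt(1.0 - alpha_from) * eps) / np.sqrt(alpha_from)
    # Re-noise that prediction to the target noise level.
    return np.sqrt(alpha_to) * x0_pred + np.sqrt(1.0 - alpha_to) * eps


# Assumed cumulative alphas: values near 1 are almost clean, smaller is noisier.
alphas = np.linspace(0.999, 0.2, 10)

rng = np.random.default_rng(0)
latents = rng.normal(size=(4, 8, 8))
eps = rng.normal(size=(4, 8, 8))  # stand-in for the model's noise prediction

# Inversion: walk from the almost clean level towards the noisy one.
x = latents
for a_from, a_to in zip(alphas[:-1], alphas[1:]):
    x = ddim_step(x, eps, a_from, a_to)
inverted = x

# Sampling: walk back towards the clean level with the same update.
reversed_alphas = alphas[::-1]
for a_from, a_to in zip(reversed_alphas[:-1], reversed_alphas[1:]):
    x = ddim_step(x, eps, a_from, a_to)

print(np.abs(x - latents).max())  # round-trip error, close to zero
```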
Finally, pass the image mask and inverted latents to the pipeline. The `target_prompt` now becomes the `prompt`, and the `source_prompt` is used as the `negative_prompt`:

```py
output_image = pipeline(
    prompt=target_prompt,
    mask_image=mask_image,
    image_latents=inv_latents,
    negative_prompt=source_prompt,
).images[0]
mask_image = Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L").resize((768, 768))
make_image_grid([raw_image, mask_image, output_image], rows=1, cols=3)
```

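Step 3 of the algorithm keeps everything outside the mask identical to the input. A rough NumPy sketch of that idea follows. `blend_with_mask` is a hypothetical helper, and the pipeline works on latents rather than on finished arrays like these, so treat it as an illustration of the masking arithmetic only.

```py
import numpy as np


def blend_with_mask(original, edited, mask):
    """Keep `original` where mask == 0 and take `edited` where mask == 1."""
    return mask * edited + (1 - mask) * original


original = np.zeros((4, 8, 8))
edited = np.ones((4, 8, 8))
mask = np.zeros((8, 8))
mask[2:6, 2:6] = 1  # region selected for editing

# The (8, 8) mask broadcasts over the leading channel axis.
blended = blend_with_mask(original, edited, mask)
print(blended[0])
```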
<div class="flex gap-4">
  <div>
    <img class="rounded-xl" src="https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">original image</figcaption>
  </div>
  <div>
    <img class="rounded-xl" src="https://github.com/Xiang-cd/DiffEdit-stable-diffusion/blob/main/assets/target.png?raw=true"/>
    <figcaption class="mt-2 text-center text-sm text-gray-500">edited image</figcaption>
  </div>
</div>

## Generate source and target embeddings

Instead of writing the source and target prompts by hand, you can generate them automatically with the [Flan-T5](https://huggingface.co/docs/transformers/model_doc/flan-t5) model and turn them into embeddings.

Load the Flan-T5 model and tokenizer from the 🤗 Transformers library:

```py
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-large")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large", device_map="auto", torch_dtype=torch.float16)
```

Provide some initial text to prompt the model to generate the source and target prompts:

```py
source_concept = "bowl"
target_concept = "basket"

# Parentheses join the two adjacent string literals into a single prompt.
source_text = (
    f"Provide a caption for images containing a {source_concept}. "
    "The captions should be in English and should be no longer than 150 characters."
)

target_text = (
    f"Provide a caption for images containing a {target_concept}. "
    "The captions should be in English and should be no longer than 150 characters."
)
```

Next, create a utility function to generate the prompts:

```py
@torch.no_grad()
def generate_prompts(input_prompt):
    input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids.to("cuda")

    outputs = model.generate(
        input_ids, temperature=0.8, num_return_sequences=16, do_sample=True, max_new_tokens=128, top_k=10
    )
    return tokenizer.batch_decode(outputs, skip_special_tokens=True)

source_prompts = generate_prompts(source_text)
target_prompts = generate_prompts(target_text)
print(source_prompts)
print(target_prompts)
```

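The `temperature` and `top_k` arguments passed to `model.generate` control how each next token is sampled. As a rough illustration, assuming a plain softmax sampler (the Transformers implementation has more moving parts), the sketch below scales a vector of made-up logits by the temperature, keeps only the `top_k` largest, and renormalizes. `top_k_probs` is a hypothetical helper and the logits are random stand-ins for a model's output.

```py
import numpy as np


def top_k_probs(logits, temperature=0.8, top_k=10):
    """Return sampling probabilities after temperature scaling and top-k filtering."""
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    # Keep only the top_k largest logits and mask out the rest.
    k = min(top_k, scaled.size)
    cutoff = np.sort(scaled)[-k]
    filtered = np.where(scaled >= cutoff, scaled, -np.inf)
    # Softmax over the surviving logits (subtract the max for stability).
    exp = np.exp(filtered - filtered.max())
    return exp / exp.sum()


rng = np.random.default_rng(0)
logits = rng.normal(size=50)  # made-up scores for a 50-token vocabulary
probs = top_k_probs(logits)
next_token = rng.choice(len(probs), p=probs)
print(next_token, probs[next_token])
```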
> [!TIP]
> Check out the [generation strategy](https://huggingface.co/docs/transformers/main/en/generation_strategies) guide if you're interested in learning more about strategies for generating text of varying quality.

Load the text encoder model used by the [`StableDiffusionDiffEditPipeline`] to encode the text. You'll use the text encoder to compute the text embeddings:

```py
import torch
from diffusers import StableDiffusionDiffEditPipeline

pipeline = StableDiffusionDiffEditPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16, use_safetensors=True
)
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()

@torch.no_grad()
def embed_prompts(sentences, tokenizer, text_encoder, device="cuda"):
    embeddings = []
    for sent in sentences:
        text_inputs = tokenizer(
            sent,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
        embeddings.append(prompt_embeds)
    return torch.concatenate(embeddings, dim=0).mean(dim=0).unsqueeze(0)

source_embeds = embed_prompts(source_prompts, pipeline.tokenizer, pipeline.text_encoder)
target_embeds = embed_prompts(target_prompts, pipeline.tokenizer, pipeline.text_encoder)
```

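`embed_prompts` averages the embeddings of all generated prompts into a single embedding with a batch axis of size one. The shape bookkeeping can be checked with a NumPy analogue. The sizes below (16 prompts, 77 tokens, 1024 features) are assumptions chosen for illustration, and `average_embeddings` is a hypothetical stand-in for the `torch` expression in the function's return statement.

```py
import numpy as np


def average_embeddings(embeddings):
    """Stack per-prompt embeddings, average over prompts, keep a batch axis."""
    stacked = np.concatenate(embeddings, axis=0)  # (num_prompts, seq_len, dim)
    return stacked.mean(axis=0)[np.newaxis, ...]  # (1, seq_len, dim)


rng = np.random.default_rng(0)
# One (1, seq_len, dim) array per prompt, mimicking the text encoder output.
per_prompt = [rng.normal(size=(1, 77, 1024)) for _ in range(16)]
averaged = average_embeddings(per_prompt)
print(averaged.shape)  # (1, 77, 1024)
```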
Finally, pass the embeddings to the [`~StableDiffusionDiffEditPipeline.generate_mask`] and [`~StableDiffusionDiffEditPipeline.invert`] functions, and to the pipeline, to generate the image:

```diff
  from diffusers import DDIMInverseScheduler, DDIMScheduler
  from diffusers.utils import load_image, make_image_grid
  from PIL import Image

  pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
  pipeline.inverse_scheduler = DDIMInverseScheduler.from_config(pipeline.scheduler.config)

  img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
  raw_image = load_image(img_url).resize((768, 768))

  mask_image = pipeline.generate_mask(
      image=raw_image,
-     source_prompt=source_prompt,
-     target_prompt=target_prompt,
+     source_prompt_embeds=source_embeds,
+     target_prompt_embeds=target_embeds,
  )

  inv_latents = pipeline.invert(
-     prompt=source_prompt,
+     prompt_embeds=source_embeds,
      image=raw_image,
  ).latents

  output_image = pipeline(
      mask_image=mask_image,
      image_latents=inv_latents,
-     prompt=target_prompt,
-     negative_prompt=source_prompt,
+     prompt_embeds=target_embeds,
+     negative_prompt_embeds=source_embeds,
  ).images[0]
  mask_image = Image.fromarray((mask_image.squeeze()*255).astype("uint8"), "L").resize((768, 768))
  make_image_grid([raw_image, mask_image, output_image], rows=1, cols=3)
```

## Generate a caption for inversion

While you can use the `source_prompt` as a caption to help generate the partially inverted latents, you can also use the [BLIP](https://huggingface.co/docs/transformers/model_doc/blip) model to generate a caption automatically.

Load the BLIP model and processor from the 🤗 Transformers library:

```py
import torch
from transformers import BlipForConditionalGeneration, BlipProcessor

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float16, low_cpu_mem_usage=True)
```

Create a utility function to generate a caption from the input image:

```py
@torch.no_grad()
def generate_caption(images, caption_generator, caption_processor):
    text = "a photograph of"

    inputs = caption_processor(images, text, return_tensors="pt").to(device="cuda", dtype=caption_generator.dtype)
    caption_generator.to("cuda")
    outputs = caption_generator.generate(**inputs, max_new_tokens=128)

    # offload caption generator
    caption_generator.to("cpu")

    caption = caption_processor.batch_decode(outputs, skip_special_tokens=True)[0]
    return caption
```

Load an input image and generate a caption for it with the `generate_caption` function:

```py
from diffusers.utils import load_image

img_url = "https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"
raw_image = load_image(img_url).resize((768, 768))
caption = generate_caption(raw_image, model, processor)
```

<div class="flex justify-center">
    <figure>
        <img class="rounded-xl" src="https://github.com/Xiang-cd/DiffEdit-stable-diffusion/raw/main/assets/origin.png"/>
        <figcaption class="text-center">generated caption: "a photograph of a bowl of fruit on a table"</figcaption>
    </figure>
</div>

Now you can pass the caption to the [`~StableDiffusionDiffEditPipeline.invert`] function to generate the partially inverted latents!