mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Fix Chroma attention padding order and update docs to use lodestones/Chroma1-HD (#12508)
* [Fix] Move attention mask padding after T5 embedding * [Fix] Move attention mask padding after T5 embedding * Clean up whitespace in pipeline_chroma.py Removed unnecessary blank lines for cleaner code. * Fix * Fix * Update model to final Chroma1-HD checkpoint * Update to Chroma1-HD * Update model to Chroma1-HD * Update model to Chroma1-HD * Update Chroma model links to Chroma1-HD * Add comment about padding/masking * Fix checkpoint/repo references * Apply style fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
This commit is contained in:
@@ -12,7 +12,7 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
# ChromaTransformer2DModel
|
||||
|
||||
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma)
|
||||
A modified flux Transformer model from [Chroma](https://huggingface.co/lodestones/Chroma1-HD)
|
||||
|
||||
## ChromaTransformer2DModel
|
||||
|
||||
|
||||
@@ -19,20 +19,21 @@ specific language governing permissions and limitations under the License.
|
||||
|
||||
Chroma is a text to image generation model based on Flux.
|
||||
|
||||
Original model checkpoints for Chroma can be found [here](https://huggingface.co/lodestones/Chroma).
|
||||
Original model checkpoints for Chroma can be found here:
|
||||
* High-resolution finetune: [lodestones/Chroma1-HD](https://huggingface.co/lodestones/Chroma1-HD)
|
||||
* Base model: [lodestones/Chroma1-Base](https://huggingface.co/lodestones/Chroma1-Base)
|
||||
* Original repo with progress checkpoints: [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) (loading this repo with `from_pretrained` will load a Diffusers-compatible version of the `unlocked-v37` checkpoint)
|
||||
|
||||
> [!TIP]
|
||||
> Chroma can use all the same optimizations as Flux.
|
||||
|
||||
## Inference
|
||||
|
||||
The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import ChromaPipeline
|
||||
|
||||
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma", torch_dtype=torch.bfloat16)
|
||||
pipe = ChromaPipeline.from_pretrained("lodestones/Chroma1-HD", torch_dtype=torch.bfloat16)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = [
|
||||
@@ -63,10 +64,10 @@ Then run the following example
|
||||
import torch
|
||||
from diffusers import ChromaTransformer2DModel, ChromaPipeline
|
||||
|
||||
model_id = "lodestones/Chroma"
|
||||
model_id = "lodestones/Chroma1-HD"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors", torch_dtype=dtype)
|
||||
transformer = ChromaTransformer2DModel.from_single_file("https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors", torch_dtype=dtype)
|
||||
|
||||
pipe = ChromaPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=dtype)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
@@ -379,7 +379,7 @@ class ChromaTransformer2DModel(
|
||||
"""
|
||||
The Transformer model introduced in Flux, modified for Chroma.
|
||||
|
||||
Reference: https://huggingface.co/lodestones/Chroma
|
||||
Reference: https://huggingface.co/lodestones/Chroma1-HD
|
||||
|
||||
Args:
|
||||
patch_size (`int`, defaults to `1`):
|
||||
|
||||
@@ -53,8 +53,8 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> import torch
|
||||
>>> from diffusers import ChromaPipeline
|
||||
|
||||
>>> model_id = "lodestones/Chroma"
|
||||
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
|
||||
>>> model_id = "lodestones/Chroma1-HD"
|
||||
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
|
||||
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
|
||||
>>> pipe = ChromaPipeline.from_pretrained(
|
||||
... model_id,
|
||||
@@ -158,7 +158,7 @@ class ChromaPipeline(
|
||||
r"""
|
||||
The Chroma pipeline for text-to-image generation.
|
||||
|
||||
Reference: https://huggingface.co/lodestones/Chroma/
|
||||
Reference: https://huggingface.co/lodestones/Chroma1-HD/
|
||||
|
||||
Args:
|
||||
transformer ([`ChromaTransformer2DModel`]):
|
||||
@@ -233,20 +233,23 @@ class ChromaPipeline(
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask.clone()
|
||||
tokenizer_mask = text_inputs.attention_mask
|
||||
|
||||
# Chroma requires the attention mask to include one padding token
|
||||
seq_lengths = attention_mask.sum(dim=1)
|
||||
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
|
||||
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).bool()
|
||||
tokenizer_mask_device = tokenizer_mask.to(device)
|
||||
|
||||
# unlike FLUX, Chroma uses the attention mask when generating the T5 embedding
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
|
||||
text_input_ids.to(device),
|
||||
output_hidden_states=False,
|
||||
attention_mask=tokenizer_mask_device,
|
||||
)[0]
|
||||
|
||||
dtype = self.text_encoder.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
attention_mask = attention_mask.to(device=device)
|
||||
|
||||
# for the text tokens, chroma requires that all except the first padding token are masked out during the forward pass through the transformer
|
||||
seq_lengths = tokenizer_mask_device.sum(dim=1)
|
||||
mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
|
||||
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
|
||||
@@ -53,8 +53,8 @@ EXAMPLE_DOC_STRING = """
|
||||
>>> import torch
|
||||
>>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
|
||||
|
||||
>>> model_id = "lodestones/Chroma"
|
||||
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
|
||||
>>> model_id = "lodestones/Chroma1-HD"
|
||||
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma1-HD/blob/main/Chroma1-HD.safetensors"
|
||||
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
|
||||
... model_id,
|
||||
... transformer=transformer,
|
||||
@@ -170,7 +170,7 @@ class ChromaImg2ImgPipeline(
|
||||
r"""
|
||||
The Chroma pipeline for image-to-image generation.
|
||||
|
||||
Reference: https://huggingface.co/lodestones/Chroma/
|
||||
Reference: https://huggingface.co/lodestones/Chroma1-HD/
|
||||
|
||||
Args:
|
||||
transformer ([`ChromaTransformer2DModel`]):
|
||||
@@ -247,20 +247,21 @@ class ChromaImg2ImgPipeline(
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
attention_mask = text_inputs.attention_mask.clone()
|
||||
tokenizer_mask = text_inputs.attention_mask
|
||||
|
||||
# Chroma requires the attention mask to include one padding token
|
||||
seq_lengths = attention_mask.sum(dim=1)
|
||||
mask_indices = torch.arange(attention_mask.size(1)).unsqueeze(0).expand(batch_size, -1)
|
||||
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).long()
|
||||
tokenizer_mask_device = tokenizer_mask.to(device)
|
||||
|
||||
prompt_embeds = self.text_encoder(
|
||||
text_input_ids.to(device), output_hidden_states=False, attention_mask=attention_mask.to(device)
|
||||
text_input_ids.to(device),
|
||||
output_hidden_states=False,
|
||||
attention_mask=tokenizer_mask_device,
|
||||
)[0]
|
||||
|
||||
dtype = self.text_encoder.dtype
|
||||
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
||||
attention_mask = attention_mask.to(dtype=dtype, device=device)
|
||||
|
||||
seq_lengths = tokenizer_mask_device.sum(dim=1)
|
||||
mask_indices = torch.arange(tokenizer_mask_device.size(1), device=device).unsqueeze(0).expand(batch_size, -1)
|
||||
attention_mask = (mask_indices <= seq_lengths.unsqueeze(1)).to(dtype=dtype, device=device)
|
||||
|
||||
_, seq_len, _ = prompt_embeds.shape
|
||||
|
||||
|
||||
Reference in New Issue
Block a user