1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
DN6
2025-06-19 19:41:32 +05:30
parent 907ecf72b1
commit 802651e205
3 changed files with 14 additions and 18 deletions

View File

@@ -29,7 +29,7 @@ Chroma can use all the same optimizations as Flux.
## Inference
The Diffusers version of Chroma is based on the `unlocked-v37` version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
The Diffusers version of Chroma is based on the [`unlocked-v37`](https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors) version of the original model, which is available in the [Chroma repository](https://huggingface.co/lodestones/Chroma).
```python
import torch

View File

@@ -52,20 +52,21 @@ EXAMPLE_DOC_STRING = """
>>> import torch
>>> from diffusers import ChromaPipeline
>>> model_id = "lodestones/Chroma"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
>>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
>>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
... "black-forest-labs/FLUX.1-schnell",
>>> pipe = ChromaPipeline.from_pretrained(
... model_id,
... transformer=transformer,
... text_encoder=text_encoder,
... tokenizer=tokenizer,
... torch_dtype=torch.bfloat16,
... )
>>> pipe.enable_model_cpu_offload()
>>> prompt = "A cat holding a sign that says hello world"
>>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
>>> prompt = [
... "A high-fashion close-up portrait of a blonde woman in clear sunglasses. The image uses a bold teal and red color split for dramatic lighting. The background is a simple teal-green. The photo is sharp and well-composed, and is designed for viewing with anaglyph 3D glasses for optimal effect. It looks professionally done."
... ]
>>> negative_prompt = [
... "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
... ]
>>> image = pipe(prompt, negative_prompt=negative_prompt).images[0]
>>> image.save("chroma.png")
```

View File

@@ -51,26 +51,21 @@ EXAMPLE_DOC_STRING = """
```py
>>> import torch
>>> from diffusers import ChromaTransformer2DModel, ChromaImg2ImgPipeline
>>> from transformers import AutoModel, Autotokenizer
>>> model_id = "lodestones/Chroma"
>>> ckpt_path = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v37.safetensors"
>>> transformer = ChromaTransformer2DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
>>> text_encoder = AutoModel.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="text_encoder_2")
>>> tokenizer = AutoTokenizer.from_pretrained("black-forest-labs/FLUX.1-schnell", subfolder="tokenizer_2")
>>> pipe = ChromaImg2ImgPipeline.from_pretrained(
... "black-forest-labs/FLUX.1-schnell",
... model_id,
... transformer=transformer,
... text_encoder=text_encoder,
... tokenizer=tokenizer,
... torch_dtype=torch.bfloat16,
... )
>>> pipe.enable_model_cpu_offload()
>>> image = load_image(
>>> init_image = load_image(
... "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
... )
>>> prompt = "a scenic fastasy landscape with a river and mountains in the background, vibrant colors, detailed, high resolution"
>>> negative_prompt = "low quality, ugly, unfinished, out of focus, deformed, disfigure, blurry, smudged, restricted palette, flat colors"
>>> image = pipe(prompt, image=image, negative_prompt=negative_prompt).images[0]
>>> image = pipe(prompt, image=init_image, negative_prompt=negative_prompt).images[0]
>>> image.save("chroma-img2img.png")
```
"""