1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
sayakpaul
2025-11-24 14:08:47 +05:30
parent 8048623daf
commit debafc6960
15 changed files with 15 additions and 53 deletions

View File

@@ -19,6 +19,7 @@ import numpy as np
import torch
from ...pipelines import FluxPipeline
from ...pipelines.flux.pipeline_flux_utils import calculate_shift
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
@@ -90,20 +91,6 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"

View File

@@ -74,7 +74,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,

View File

@@ -72,7 +72,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,

View File

@@ -32,7 +32,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_flux_utils import FluxMixin, retrieve_latents, retrieve_timesteps
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput
@@ -93,19 +93,6 @@ PREFERRED_KONTEXT_RESOLUTIONS = [
]
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
class FluxKontextPipeline(
DiffusionPipeline,
FluxMixin,

View File

@@ -22,7 +22,7 @@ from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_flux_utils import FluxMixin, retrieve_latents, retrieve_timesteps
from .pipeline_flux_utils import FluxMixin, calculate_shift, retrieve_latents, retrieve_timesteps
from .pipeline_output import FluxPipelineOutput
@@ -117,19 +117,6 @@ PREFERRED_KONTEXT_RESOLUTIONS = [
]
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,
max_seq_len: int = 4096,
base_shift: float = 0.5,
max_shift: float = 1.15,
):
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
b = base_shift - m * base_seq_len
mu = image_seq_len * m + b
return mu
class FluxKontextInpaintPipeline(
DiffusionPipeline,
FluxMixin,

View File

@@ -83,7 +83,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,

View File

@@ -161,7 +161,7 @@ DEFAULT_PROMPT_TEMPLATE = {
}
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,

View File

@@ -66,7 +66,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,

View File

@@ -134,7 +134,7 @@ def linear_quadratic_schedule(num_steps, threshold_noise=0.025, linear_steps=Non
return torch.tensor(sigma_schedule[:-1])
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,

View File

@@ -71,7 +71,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,

View File

@@ -60,7 +60,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,

View File

@@ -57,6 +57,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,

View File

@@ -71,7 +71,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,

View File

@@ -77,7 +77,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,

View File

@@ -76,7 +76,7 @@ EXAMPLE_DOC_STRING = """
"""
# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
image_seq_len,
base_seq_len: int = 256,