mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Flax pipeline pndm (#583)
* WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline * todo comment * Fix imports * Fix imports * add dummies * Fix empty init * make pipeline work * up * Allow dtype to be overridden on model load. This may be a temporary solution until #567 is addressed. * Convert params to bfloat16 or fp16 after loading. This deals with the weights, not the model. * Use Flax schedulers (typing, docstring) * PNDM: replace control flow with jax functions. Otherwise jitting/parallelization don't work properly as they don't know how to deal with traced objects. I temporarily removed `step_prk`. * Pass latents shape to scheduler set_timesteps() PNDMScheduler uses it to reserve space, other schedulers will just ignore it. * Wrap model imports inside availability checks. * Optionally return state in from_config. Useful for Flax schedulers. * Do not convert model weights to dtype. * Re-enable PRK steps with functional implementation. Values returned still not verified for correctness. * Remove left over has_state var. * make style * Apply suggestion list -> tuple Co-authored-by: Suraj Patil <surajp815@gmail.com> * Apply suggestion list -> tuple Co-authored-by: Suraj Patil <surajp815@gmail.com> * Remove unused comments. * Use zeros instead of empty. Co-authored-by: Mishig Davaadorj <dmishig@gmail.com> Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
@@ -56,5 +56,6 @@ if is_transformers_available() and is_flax_available():
|
||||
images: Union[List[PIL.Image.Image], np.ndarray]
|
||||
nsfw_content_detected: List[bool]
|
||||
|
||||
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
|
||||
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
|
||||
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker
|
||||
|
||||
@@ -186,7 +186,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
|
||||
return latents, scheduler_state
|
||||
|
||||
scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps)
|
||||
scheduler_state = self.scheduler.set_timesteps(
|
||||
params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
|
||||
)
|
||||
|
||||
if debug:
|
||||
# run with python for loop
|
||||
|
||||
@@ -19,6 +19,7 @@ from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import flax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
@@ -155,7 +156,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
def create_state(self):
|
||||
return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
|
||||
|
||||
def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
|
||||
def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState:
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
@@ -196,8 +197,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
|
||||
return state.replace(
|
||||
timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64),
|
||||
ets=jnp.array([]),
|
||||
counter=0,
|
||||
# Reserve space for the state variables
|
||||
cur_model_output=jnp.zeros(shape),
|
||||
cur_sample=jnp.zeros(shape),
|
||||
ets=jnp.zeros((4,) + shape),
|
||||
)
|
||||
|
||||
def step(
|
||||
@@ -227,22 +231,32 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
|
||||
"""
|
||||
if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps:
|
||||
return self.step_prk(
|
||||
state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
|
||||
if self.config.skip_prk_steps:
|
||||
prev_sample, state = self.step_plms(
|
||||
state=state, model_output=model_output, timestep=timestep, sample=sample
|
||||
)
|
||||
else:
|
||||
return self.step_plms(
|
||||
state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
|
||||
prev_sample, state = jax.lax.switch(
|
||||
jnp.where(state.counter < len(state.prk_timesteps), 0, 1),
|
||||
(self.step_prk, self.step_plms),
|
||||
# Args to either branch
|
||||
state,
|
||||
model_output,
|
||||
timestep,
|
||||
sample,
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, state)
|
||||
|
||||
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||
|
||||
def step_prk(
|
||||
self,
|
||||
state: PNDMSchedulerState,
|
||||
model_output: jnp.ndarray,
|
||||
timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlaxSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
|
||||
@@ -266,34 +280,46 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
"Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
|
||||
)
|
||||
|
||||
diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
|
||||
diff_to_prev = jnp.where(
|
||||
state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
|
||||
)
|
||||
prev_timestep = timestep - diff_to_prev
|
||||
timestep = state.prk_timesteps[state.counter // 4 * 4]
|
||||
|
||||
if state.counter % 4 == 0:
|
||||
state = state.replace(
|
||||
cur_model_output=state.cur_model_output + 1 / 6 * model_output,
|
||||
ets=state.ets.append(model_output),
|
||||
cur_sample=sample,
|
||||
def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
|
||||
return (
|
||||
state.replace(
|
||||
cur_model_output=state.cur_model_output + 1 / 6 * model_output,
|
||||
ets=state.ets.at[ets_at].set(model_output),
|
||||
cur_sample=sample,
|
||||
),
|
||||
model_output,
|
||||
)
|
||||
elif (self.counter - 1) % 4 == 0:
|
||||
state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
|
||||
elif (self.counter - 2) % 4 == 0:
|
||||
state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
|
||||
elif (self.counter - 3) % 4 == 0:
|
||||
|
||||
def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
|
||||
return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
|
||||
|
||||
def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
|
||||
return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
|
||||
|
||||
def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
|
||||
model_output = state.cur_model_output + 1 / 6 * model_output
|
||||
state = state.replace(cur_model_output=0)
|
||||
return state.replace(cur_model_output=jnp.zeros_like(state.cur_model_output)), model_output
|
||||
|
||||
# cur_sample should not be `None`
|
||||
cur_sample = state.cur_sample if state.cur_sample is not None else sample
|
||||
state, model_output = jax.lax.switch(
|
||||
state.counter % 4,
|
||||
(remainder_0, remainder_1, remainder_2, remainder_3),
|
||||
# Args to either branch
|
||||
state,
|
||||
model_output,
|
||||
state.counter // 4,
|
||||
)
|
||||
|
||||
cur_sample = state.cur_sample
|
||||
prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
|
||||
state = state.replace(counter=state.counter + 1)
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, state)
|
||||
|
||||
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||
return (prev_sample, state)
|
||||
|
||||
def step_plms(
|
||||
self,
|
||||
@@ -301,7 +327,6 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
model_output: jnp.ndarray,
|
||||
timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlaxSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
|
||||
@@ -334,36 +359,91 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
)
|
||||
|
||||
prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
|
||||
prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)
|
||||
|
||||
if state.counter != 1:
|
||||
state = state.replace(ets=state.ets.append(model_output))
|
||||
else:
|
||||
prev_timestep = timestep
|
||||
timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
|
||||
# Reference:
|
||||
# if state.counter != 1:
|
||||
# state.ets.append(model_output)
|
||||
# else:
|
||||
# prev_timestep = timestep
|
||||
# timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
|
||||
|
||||
if len(state.ets) == 1 and state.counter == 0:
|
||||
model_output = model_output
|
||||
state = state.replace(cur_sample=sample)
|
||||
elif len(state.ets) == 1 and state.counter == 1:
|
||||
model_output = (model_output + state.ets[-1]) / 2
|
||||
sample = state.cur_sample
|
||||
state = state.replace(cur_sample=None)
|
||||
elif len(state.ets) == 2:
|
||||
model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
|
||||
elif len(state.ets) == 3:
|
||||
model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
|
||||
else:
|
||||
model_output = (1 / 24) * (
|
||||
55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
|
||||
prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
|
||||
timestep = jnp.where(
|
||||
state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
|
||||
)
|
||||
|
||||
# Reference:
|
||||
# if len(state.ets) == 1 and state.counter == 0:
|
||||
# model_output = model_output
|
||||
# state.cur_sample = sample
|
||||
# elif len(state.ets) == 1 and state.counter == 1:
|
||||
# model_output = (model_output + state.ets[-1]) / 2
|
||||
# sample = state.cur_sample
|
||||
# state.cur_sample = None
|
||||
# elif len(state.ets) == 2:
|
||||
# model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
|
||||
# elif len(state.ets) == 3:
|
||||
# model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
|
||||
# else:
|
||||
# model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])
|
||||
|
||||
def counter_0(state: PNDMSchedulerState):
|
||||
ets = state.ets.at[0].set(model_output)
|
||||
return state.replace(
|
||||
ets=ets,
|
||||
cur_sample=sample,
|
||||
cur_model_output=jnp.array(model_output, dtype=jnp.float32),
|
||||
)
|
||||
|
||||
def counter_1(state: PNDMSchedulerState):
|
||||
return state.replace(
|
||||
cur_model_output=(model_output + state.ets[0]) / 2,
|
||||
)
|
||||
|
||||
def counter_2(state: PNDMSchedulerState):
|
||||
ets = state.ets.at[1].set(model_output)
|
||||
return state.replace(
|
||||
ets=ets,
|
||||
cur_model_output=(3 * ets[1] - ets[0]) / 2,
|
||||
cur_sample=sample,
|
||||
)
|
||||
|
||||
def counter_3(state: PNDMSchedulerState):
|
||||
ets = state.ets.at[2].set(model_output)
|
||||
return state.replace(
|
||||
ets=ets,
|
||||
cur_model_output=(23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12,
|
||||
cur_sample=sample,
|
||||
)
|
||||
|
||||
def counter_other(state: PNDMSchedulerState):
|
||||
ets = state.ets.at[3].set(model_output)
|
||||
next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0])
|
||||
|
||||
ets = ets.at[0].set(ets[1])
|
||||
ets = ets.at[1].set(ets[2])
|
||||
ets = ets.at[2].set(ets[3])
|
||||
|
||||
return state.replace(
|
||||
ets=ets,
|
||||
cur_model_output=next_model_output,
|
||||
cur_sample=sample,
|
||||
)
|
||||
|
||||
counter = jnp.clip(state.counter, 0, 4)
|
||||
state = jax.lax.switch(
|
||||
counter,
|
||||
[counter_0, counter_1, counter_2, counter_3, counter_other],
|
||||
state,
|
||||
)
|
||||
|
||||
sample = state.cur_sample
|
||||
model_output = state.cur_model_output
|
||||
prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
|
||||
state = state.replace(counter=state.counter + 1)
|
||||
|
||||
if not return_dict:
|
||||
return (prev_sample, state)
|
||||
|
||||
return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
|
||||
return (prev_sample, state)
|
||||
|
||||
def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
|
||||
# See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
|
||||
@@ -379,7 +459,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
|
||||
# model_output -> e_θ(x_t, t)
|
||||
# prev_sample -> x_(t−δ)
|
||||
alpha_prod_t = self.alphas_cumprod[timestep]
|
||||
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
|
||||
alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
|
||||
beta_prod_t = 1 - alpha_prod_t
|
||||
beta_prod_t_prev = 1 - alpha_prod_t_prev
|
||||
|
||||
|
||||
Reference in New Issue
Block a user