
Flax pipeline pndm (#583)

* WIP: flax FlaxDiffusionPipeline & FlaxStableDiffusionPipeline

* todo comment

* Fix imports

* Fix imports

* add dummies

* Fix empty init

* make pipeline work

* up

* Allow dtype to be overridden on model load.

This may be a temporary solution until #567 is addressed.
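
A minimal sketch of what this override looks like from the caller's side (model class, checkpoint, and dtype value are illustrative, not taken from this commit):

    import jax.numpy as jnp
    from diffusers import FlaxUNet2DConditionModel

    # Override the module's computation dtype at load time.
    unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="unet", dtype=jnp.bfloat16
    )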

* Convert params to bfloat16 or fp16 after loading.

This deals with the weights, not the model.
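
Since this targets the loaded params pytree rather than the module, it amounts to a plain tree map; a minimal sketch, assuming `params` is the pytree returned by `from_pretrained`:

    import jax
    import jax.numpy as jnp

    # Cast every weight array in the pytree to bfloat16 (or jnp.float16).
    params = jax.tree_map(lambda p: p.astype(jnp.bfloat16), params)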

* Use Flax schedulers (typing, docstring)

* PNDM: replace control flow with jax functions.

Otherwise jitting/parallelization don't work properly, because Python control
flow can't branch on the traced values JAX passes through the function.

I temporarily removed `step_prk`.
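
The underlying constraint, in a minimal illustrative form (not code from this commit): under `jax.jit` the scheduler counter is a tracer, so Python `if`/`elif` on it raises a concretization error, while `jnp.where` and `jax.lax.switch` stay traceable:

    import jax
    import jax.numpy as jnp

    @jax.jit
    def stepped(counter, x):
        # `if counter % 2:` would fail under jit; select by value instead:
        y = jnp.where(counter % 2, x, -x)
        # ...or select a whole branch function by index:
        return jax.lax.switch(jnp.clip(counter, 0, 1), (lambda v: v + 1, lambda v: v - 1), y)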

* Pass latents shape to scheduler set_timesteps()

PNDMScheduler uses it to reserve space; other schedulers simply ignore it.
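
A hedged usage sketch (step count and latent shape are illustrative):

    # `shape` lets FlaxPNDMScheduler preallocate fixed-size buffers before tracing.
    scheduler_state = scheduler.set_timesteps(
        scheduler_state, num_inference_steps=50, shape=(1, 4, 64, 64)
    )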

* Wrap model imports inside availability checks.

* Optionally return state in from_config.

Useful for Flax schedulers.
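
A sketch of the intended call pattern, assuming the tuple-return behavior this bullet adds (the checkpoint path is illustrative):

    from diffusers import FlaxPNDMScheduler

    # Flax schedulers are stateless; creation now hands back the state object too.
    scheduler, scheduler_state = FlaxPNDMScheduler.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="scheduler"
    )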

* Do not convert model weights to dtype.

* Re-enable PRK steps with functional implementation.

Values returned still not verified for correctness.

* Remove left over has_state var.

* make style

* Apply suggestion list -> tuple

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Apply suggestion list -> tuple

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Remove unused comments.

* Use zeros instead of empty.

Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
Co-authored-by: Mishig Davaadorj <mishig.davaadorj@coloradocollege.edu>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Pedro Cuenca committed 2022-09-27 14:16:11 +02:00 (committed by GitHub)
parent c070e5f0c5
commit ab3fd671d7
3 changed files with 135 additions and 52 deletions

src/diffusers/pipelines/stable_diffusion/__init__.py

@@ -56,5 +56,6 @@ if is_transformers_available() and is_flax_available():
         images: Union[List[PIL.Image.Image], np.ndarray]
         nsfw_content_detected: List[bool]
 
+    from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
     from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
     from .safety_checker_flax import FlaxStableDiffusionSafetyChecker

src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py

@@ -186,7 +186,9 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
             latents, scheduler_state = self.scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
             return latents, scheduler_state
 
-        scheduler_state = self.scheduler.set_timesteps(params["scheduler"], num_inference_steps=num_inference_steps)
+        scheduler_state = self.scheduler.set_timesteps(
+            params["scheduler"], num_inference_steps=num_inference_steps, shape=latents.shape
+        )
 
         if debug:
             # run with python for loop

src/diffusers/schedulers/scheduling_pndm_flax.py

@@ -19,6 +19,7 @@ from dataclasses import dataclass
 from typing import Optional, Tuple, Union
 
 import flax
+import jax
 import jax.numpy as jnp
 
 from ..configuration_utils import ConfigMixin, register_to_config
@@ -155,7 +156,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
     def create_state(self):
         return PNDMSchedulerState.create(num_train_timesteps=self.config.num_train_timesteps)
 
-    def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState:
+    def set_timesteps(self, state: PNDMSchedulerState, shape: Tuple, num_inference_steps: int) -> PNDMSchedulerState:
         """
         Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
@@ -196,8 +197,11 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         return state.replace(
             timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64),
-            ets=jnp.array([]),
             counter=0,
+            # Reserve space for the state variables
+            cur_model_output=jnp.zeros(shape),
+            cur_sample=jnp.zeros(shape),
+            ets=jnp.zeros((4,) + shape),
         )
 
     def step(
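
The removed `jnp.array([])` + `append` pattern cannot work under jit, because traced arrays are immutable and fixed-shape; the preallocated buffer is updated out-of-place instead. A minimal illustration (shapes arbitrary):

    import jax.numpy as jnp

    ets = jnp.zeros((4, 2, 2))             # room for 4 entries of shape (2, 2)
    ets = ets.at[0].set(jnp.ones((2, 2)))  # functional update: returns a new array
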
@@ -227,22 +231,32 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
             When returning a tuple, the first element is the sample tensor.
 
         """
-        if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps:
-            return self.step_prk(
-                state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
+        if self.config.skip_prk_steps:
+            prev_sample, state = self.step_plms(
+                state=state, model_output=model_output, timestep=timestep, sample=sample
             )
         else:
-            return self.step_plms(
-                state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict
+            prev_sample, state = jax.lax.switch(
+                jnp.where(state.counter < len(state.prk_timesteps), 0, 1),
+                (self.step_prk, self.step_plms),
+                # Args to either branch
+                state,
+                model_output,
+                timestep,
+                sample,
             )
 
+        if not return_dict:
+            return (prev_sample, state)
+
+        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+
     def step_prk(
         self,
         state: PNDMSchedulerState,
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
-        return_dict: bool = True,
     ) -> Union[FlaxSchedulerOutput, Tuple]:
         """
         Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the
@@ -266,34 +280,46 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
                 "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
             )
 
-        diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2
+        diff_to_prev = jnp.where(
+            state.counter % 2, 0, self.config.num_train_timesteps // state.num_inference_steps // 2
+        )
         prev_timestep = timestep - diff_to_prev
         timestep = state.prk_timesteps[state.counter // 4 * 4]
 
-        if state.counter % 4 == 0:
-            state = state.replace(
-                cur_model_output=state.cur_model_output + 1 / 6 * model_output,
-                ets=state.ets.append(model_output),
-                cur_sample=sample,
+        def remainder_0(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            return (
+                state.replace(
+                    cur_model_output=state.cur_model_output + 1 / 6 * model_output,
+                    ets=state.ets.at[ets_at].set(model_output),
+                    cur_sample=sample,
+                ),
+                model_output,
             )
-        elif (self.counter - 1) % 4 == 0:
-            state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
-        elif (self.counter - 2) % 4 == 0:
-            state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output)
-        elif (self.counter - 3) % 4 == 0:
-            model_output = state.cur_model_output + 1 / 6 * model_output
-            state = state.replace(cur_model_output=0)
 
-        # cur_sample should not be `None`
-        cur_sample = state.cur_sample if state.cur_sample is not None else sample
+        def remainder_1(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
+
+        def remainder_2(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            return state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output), model_output
+
+        def remainder_3(state: PNDMSchedulerState, model_output: jnp.ndarray, ets_at: int):
+            model_output = state.cur_model_output + 1 / 6 * model_output
+            return state.replace(cur_model_output=jnp.zeros_like(state.cur_model_output)), model_output
+
+        state, model_output = jax.lax.switch(
+            state.counter % 4,
+            (remainder_0, remainder_1, remainder_2, remainder_3),
+            # Args to either branch
+            state,
+            model_output,
+            state.counter // 4,
+        )
 
+        cur_sample = state.cur_sample
         prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)
 
-        if not return_dict:
-            return (prev_sample, state)
-
-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return (prev_sample, state)
 
     def step_plms(
         self,
@@ -301,7 +327,6 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         model_output: jnp.ndarray,
         timestep: int,
         sample: jnp.ndarray,
-        return_dict: bool = True,
     ) -> Union[FlaxSchedulerOutput, Tuple]:
         """
         Step function propagating the sample with the linear multi-step method. This has one forward pass with multiple
@@ -334,36 +359,91 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
             )
 
         prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps
+        prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0)
 
-        if state.counter != 1:
-            state = state.replace(ets=state.ets.append(model_output))
-        else:
-            prev_timestep = timestep
-            timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
+        # Reference:
+        # if state.counter != 1:
+        #     state.ets.append(model_output)
+        # else:
+        #     prev_timestep = timestep
+        #     timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps
+
+        prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep)
+        timestep = jnp.where(
+            state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep
+        )
 
-        if len(state.ets) == 1 and state.counter == 0:
-            model_output = model_output
-            state = state.replace(cur_sample=sample)
-        elif len(state.ets) == 1 and state.counter == 1:
-            model_output = (model_output + state.ets[-1]) / 2
-            sample = state.cur_sample
-            state = state.replace(cur_sample=None)
-        elif len(state.ets) == 2:
-            model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
-        elif len(state.ets) == 3:
-            model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
-        else:
-            model_output = (1 / 24) * (
-                55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]
-            )
+        # Reference:
+        # if len(state.ets) == 1 and state.counter == 0:
+        #     model_output = model_output
+        #     state.cur_sample = sample
+        # elif len(state.ets) == 1 and state.counter == 1:
+        #     model_output = (model_output + state.ets[-1]) / 2
+        #     sample = state.cur_sample
+        #     state.cur_sample = None
+        # elif len(state.ets) == 2:
+        #     model_output = (3 * state.ets[-1] - state.ets[-2]) / 2
+        # elif len(state.ets) == 3:
+        #     model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12
+        # else:
+        #     model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4])
+
+        def counter_0(state: PNDMSchedulerState):
+            ets = state.ets.at[0].set(model_output)
+            return state.replace(
+                ets=ets,
+                cur_sample=sample,
+                cur_model_output=jnp.array(model_output, dtype=jnp.float32),
+            )
+
+        def counter_1(state: PNDMSchedulerState):
+            return state.replace(
+                cur_model_output=(model_output + state.ets[0]) / 2,
+            )
+
+        def counter_2(state: PNDMSchedulerState):
+            ets = state.ets.at[1].set(model_output)
+            return state.replace(
+                ets=ets,
+                cur_model_output=(3 * ets[1] - ets[0]) / 2,
+                cur_sample=sample,
+            )
+
+        def counter_3(state: PNDMSchedulerState):
+            ets = state.ets.at[2].set(model_output)
+            return state.replace(
+                ets=ets,
+                cur_model_output=(23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12,
+                cur_sample=sample,
+            )
+
+        def counter_other(state: PNDMSchedulerState):
+            ets = state.ets.at[3].set(model_output)
+            next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0])
+
+            ets = ets.at[0].set(ets[1])
+            ets = ets.at[1].set(ets[2])
+            ets = ets.at[2].set(ets[3])
+
+            return state.replace(
+                ets=ets,
+                cur_model_output=next_model_output,
+                cur_sample=sample,
+            )
+
+        counter = jnp.clip(state.counter, 0, 4)
+        state = jax.lax.switch(
+            counter,
+            [counter_0, counter_1, counter_2, counter_3, counter_other],
+            state,
+        )
 
+        sample = state.cur_sample
+        model_output = state.cur_model_output
         prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output)
         state = state.replace(counter=state.counter + 1)
 
-        if not return_dict:
-            return (prev_sample, state)
-
-        return FlaxSchedulerOutput(prev_sample=prev_sample, state=state)
+        return (prev_sample, state)
 
     def _get_prev_sample(self, sample, timestep, prev_timestep, model_output):
         # See formula (9) of PNDM paper https://arxiv.org/pdf/2202.09778.pdf
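
With `step` now branch-free for the tracer and returning `(prev_sample, state)`, the whole sampling loop can be compiled. A hedged sketch of the resulting call pattern (the constant `noise_pred` stands in for the UNet forward pass):

    import jax

    def denoise(scheduler, scheduler_state, latents):
        def body(i, carry):
            latents, state = carry
            t = state.timesteps[i]
            noise_pred = latents  # stand-in for the UNet call
            latents, state = scheduler.step(state, noise_pred, t, latents, return_dict=False)
            return latents, state

        n = len(scheduler_state.timesteps)
        latents, _ = jax.lax.fori_loop(0, n, body, (latents, scheduler_state))
        return latents
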
@@ -379,7 +459,7 @@ class FlaxPNDMScheduler(SchedulerMixin, ConfigMixin):
         # model_output -> e_θ(x_t, t)
         # prev_sample -> x_(t−δ)
         alpha_prod_t = self.alphas_cumprod[timestep]
-        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+        alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod)
         beta_prod_t = 1 - alpha_prod_t
         beta_prod_t_prev = 1 - alpha_prod_t_prev
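
For reference, the transfer rule `_get_prev_sample` implements is formula (9) of the PNDM paper, reconstructed here from the surrounding code (with \bar\alpha the cumulative alpha product and prev denoting t−δ):

    x_{t-\delta} = \sqrt{\frac{\bar\alpha_{t-\delta}}{\bar\alpha_t}}\, x_t
        - \frac{(\bar\alpha_{t-\delta} - \bar\alpha_t)\, e_\theta(x_t, t)}
               {\bar\alpha_t \sqrt{1-\bar\alpha_{t-\delta}} + \sqrt{\bar\alpha_t\,(1-\bar\alpha_t)\,\bar\alpha_{t-\delta}}}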