mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
replace references to deprecated KeyArray & PRNGKeyArray (#5324)
This commit is contained in:
4
setup.py
4
setup.py
@@ -102,8 +102,8 @@ _deps = [
|
||||
"importlib_metadata",
|
||||
"invisible-watermark>=0.2.0",
|
||||
"isort>=5.5.4",
|
||||
"jax>=0.2.8,!=0.3.2",
|
||||
"jaxlib>=0.1.65",
|
||||
"jax>=0.4.1",
|
||||
"jaxlib>=0.4.1",
|
||||
"Jinja2",
|
||||
"k-diffusion>=0.0.12",
|
||||
"torchsde",
|
||||
|
||||
@@ -15,8 +15,8 @@ deps = {
|
||||
"importlib_metadata": "importlib_metadata",
|
||||
"invisible-watermark": "invisible-watermark>=0.2.0",
|
||||
"isort": "isort>=5.5.4",
|
||||
"jax": "jax>=0.2.8,!=0.3.2",
|
||||
"jaxlib": "jaxlib>=0.1.65",
|
||||
"jax": "jax>=0.4.1",
|
||||
"jaxlib": "jaxlib>=0.4.1",
|
||||
"Jinja2": "Jinja2",
|
||||
"k-diffusion": "k-diffusion>=0.0.12",
|
||||
"torchsde": "torchsde",
|
||||
|
||||
@@ -168,7 +168,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
controlnet_conditioning_channel_order: str = "rgb"
|
||||
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)
|
||||
|
||||
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.Array) -> FrozenDict:
|
||||
# init input tensors
|
||||
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
||||
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
||||
|
||||
@@ -192,7 +192,7 @@ class FlaxModelMixin(PushToHubMixin):
|
||||
```"""
|
||||
return self._cast_floating_to(params, jnp.float16, mask)
|
||||
|
||||
def init_weights(self, rng: jax.random.KeyArray) -> Dict:
|
||||
def init_weights(self, rng: jax.Array) -> Dict:
|
||||
raise NotImplementedError(f"init_weights method has to be implemented for {self}")
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -126,7 +126,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
addition_embed_type_num_heads: int = 64
|
||||
projection_class_embeddings_input_dim: Optional[int] = None
|
||||
|
||||
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.Array) -> FrozenDict:
|
||||
# init input tensors
|
||||
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
||||
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
||||
|
||||
@@ -817,7 +817,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
|
||||
dtype=self.dtype,
|
||||
)
|
||||
|
||||
def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.Array) -> FrozenDict:
|
||||
# init input tensors
|
||||
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
|
||||
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
|
||||
|
||||
@@ -241,7 +241,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
|
||||
prompt_ids: jnp.array,
|
||||
image: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
prng_seed: jax.Array,
|
||||
num_inference_steps: int,
|
||||
guidance_scale: float,
|
||||
latents: Optional[jnp.array] = None,
|
||||
@@ -351,7 +351,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
|
||||
prompt_ids: jnp.array,
|
||||
image: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
prng_seed: jax.Array,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: Union[float, jnp.array] = 7.5,
|
||||
latents: jnp.array = None,
|
||||
@@ -370,7 +370,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline):
|
||||
Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
|
||||
params (`Dict` or `FrozenDict`):
|
||||
Dictionary containing the model parameters/weights.
|
||||
prng_seed (`jax.random.KeyArray` or `jax.Array`):
|
||||
prng_seed (`jax.Array` or `jax.Array`):
|
||||
Array containing random number generator key.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
|
||||
@@ -215,7 +215,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
self,
|
||||
prompt_ids: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
prng_seed: jax.Array,
|
||||
num_inference_steps: int,
|
||||
height: int,
|
||||
width: int,
|
||||
@@ -312,7 +312,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
|
||||
self,
|
||||
prompt_ids: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
prng_seed: jax.Array,
|
||||
num_inference_steps: int = 50,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
|
||||
@@ -235,7 +235,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
||||
prompt_ids: jnp.array,
|
||||
image: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
prng_seed: jax.Array,
|
||||
start_timestep: int,
|
||||
num_inference_steps: int,
|
||||
height: int,
|
||||
@@ -340,7 +340,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
||||
prompt_ids: jnp.array,
|
||||
image: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
prng_seed: jax.Array,
|
||||
strength: float = 0.8,
|
||||
num_inference_steps: int = 50,
|
||||
height: Optional[int] = None,
|
||||
@@ -361,7 +361,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline):
|
||||
Array representing an image batch to be used as the starting point.
|
||||
params (`Dict` or `FrozenDict`):
|
||||
Dictionary containing the model parameters/weights.
|
||||
prng_seed (`jax.random.KeyArray` or `jax.Array`):
|
||||
prng_seed (`jax.Array` or `jax.Array`):
|
||||
Array containing random number generator key.
|
||||
strength (`float`, *optional*, defaults to 0.8):
|
||||
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
|
||||
|
||||
@@ -270,7 +270,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
|
||||
mask: jnp.array,
|
||||
masked_image: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
prng_seed: jax.Array,
|
||||
num_inference_steps: int,
|
||||
height: int,
|
||||
width: int,
|
||||
@@ -398,7 +398,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline):
|
||||
mask: jnp.array,
|
||||
masked_image: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
prng_seed: jax.Array,
|
||||
num_inference_steps: int = 50,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
|
||||
@@ -87,7 +87,7 @@ class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel):
|
||||
module = self.module_class(config=config, dtype=dtype, **kwargs)
|
||||
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)
|
||||
|
||||
def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
def init_weights(self, rng: jax.Array, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
|
||||
# init input tensor
|
||||
clip_input = jax.random.normal(rng, input_shape)
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
|
||||
self,
|
||||
prompt_ids: jax.Array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
prng_seed: jax.Array,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: Union[float, jax.Array] = 7.5,
|
||||
height: Optional[int] = None,
|
||||
@@ -170,7 +170,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline):
|
||||
self,
|
||||
prompt_ids: jnp.array,
|
||||
params: Union[Dict, FrozenDict],
|
||||
prng_seed: jax.random.KeyArray,
|
||||
prng_seed: jax.Array,
|
||||
num_inference_steps: int,
|
||||
height: int,
|
||||
width: int,
|
||||
|
||||
@@ -198,7 +198,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
model_output: jnp.ndarray,
|
||||
timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
key: Optional[jax.random.KeyArray] = None,
|
||||
key: Optional[jax.Array] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
|
||||
"""
|
||||
@@ -211,7 +211,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
timestep (`int`): current discrete timestep in the diffusion chain.
|
||||
sample (`jnp.ndarray`):
|
||||
current instance of sample being created by diffusion process.
|
||||
key (`jax.random.KeyArray`): a PRNG key.
|
||||
key (`jax.Array`): a PRNG key.
|
||||
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -17,6 +17,7 @@ from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import flax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import random
|
||||
|
||||
@@ -139,7 +140,7 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
state: KarrasVeSchedulerState,
|
||||
sample: jnp.ndarray,
|
||||
sigma: float,
|
||||
key: random.KeyArray,
|
||||
key: jax.Array,
|
||||
) -> Tuple[jnp.ndarray, float]:
|
||||
"""
|
||||
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
|
||||
|
||||
@@ -18,6 +18,7 @@ from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import flax
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax import random
|
||||
|
||||
@@ -169,7 +170,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
model_output: jnp.ndarray,
|
||||
timestep: int,
|
||||
sample: jnp.ndarray,
|
||||
key: random.KeyArray,
|
||||
key: jax.Array,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlaxSdeVeOutput, Tuple]:
|
||||
"""
|
||||
@@ -228,7 +229,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin):
|
||||
state: ScoreSdeVeSchedulerState,
|
||||
model_output: jnp.ndarray,
|
||||
sample: jnp.ndarray,
|
||||
key: random.KeyArray,
|
||||
key: jax.Array,
|
||||
return_dict: bool = True,
|
||||
) -> Union[FlaxSdeVeOutput, Tuple]:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user