diff --git a/setup.py b/setup.py index 16a77c1fe3..7ad5646d4f 100644 --- a/setup.py +++ b/setup.py @@ -102,8 +102,8 @@ _deps = [ "importlib_metadata", "invisible-watermark>=0.2.0", "isort>=5.5.4", - "jax>=0.2.8,!=0.3.2", - "jaxlib>=0.1.65", + "jax>=0.4.1", + "jaxlib>=0.4.1", "Jinja2", "k-diffusion>=0.0.12", "torchsde", diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index d4b94ba6d4..970013c31a 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -15,8 +15,8 @@ deps = { "importlib_metadata": "importlib_metadata", "invisible-watermark": "invisible-watermark>=0.2.0", "isort": "isort>=5.5.4", - "jax": "jax>=0.2.8,!=0.3.2", - "jaxlib": "jaxlib>=0.1.65", + "jax": "jax>=0.4.1", + "jaxlib": "jaxlib>=0.4.1", "Jinja2": "Jinja2", "k-diffusion": "k-diffusion>=0.0.12", "torchsde": "torchsde", diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py index a826df48e4..076e618321 100644 --- a/src/diffusers/models/controlnet_flax.py +++ b/src/diffusers/models/controlnet_flax.py @@ -168,7 +168,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): controlnet_conditioning_channel_order: str = "rgb" conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256) - def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: + def init_weights(self, rng: jax.Array) -> FrozenDict: # init input tensors sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) diff --git a/src/diffusers/models/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py index 97f7b43bc6..ea4d1bfea5 100644 --- a/src/diffusers/models/modeling_flax_utils.py +++ b/src/diffusers/models/modeling_flax_utils.py @@ -192,7 +192,7 @@ class FlaxModelMixin(PushToHubMixin): ```""" return self._cast_floating_to(params, jnp.float16, mask) - def init_weights(self, rng: jax.random.KeyArray) -> Dict: + def init_weights(self, rng: jax.Array) -> Dict: raise NotImplementedError(f"init_weights method has to be implemented for {self}") @classmethod diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index a56db67b6a..770cbf09cc 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -126,7 +126,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): addition_embed_type_num_heads: int = 64 projection_class_embeddings_input_dim: Optional[int] = None - def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: + def init_weights(self, rng: jax.Array) -> FrozenDict: # init input tensors sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index b8f5b1d0e3..d2dde2ba19 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -817,7 +817,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin): dtype=self.dtype, ) - def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: + def init_weights(self, rng: jax.Array) -> FrozenDict: # init input tensors sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) diff --git a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py index b2c8871aa0..b57e776e49 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py @@ -241,7 +241,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): prompt_ids: jnp.array, image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int, guidance_scale: float, latents: Optional[jnp.array] = None, @@ -351,7 +351,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): prompt_ids: jnp.array, image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int = 50, guidance_scale: Union[float, jnp.array] = 7.5, latents: jnp.array = None, @@ -370,7 +370,7 @@ class FlaxStableDiffusionControlNetPipeline(FlaxDiffusionPipeline): Array representing the ControlNet input condition to provide guidance to the `unet` for generation. params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights. - prng_seed (`jax.random.KeyArray` or `jax.Array`): + prng_seed (`jax.Array` or `jax.Array`): Array containing random number generator key. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 131a7c7bc2..a847cd15c6 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -215,7 +215,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): self, prompt_ids: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int, height: int, width: int, @@ -312,7 +312,7 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline): self, prompt_ids: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int = 50, height: Optional[int] = None, width: Optional[int] = None, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index a9717533fa..42a79db6b2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -235,7 +235,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): prompt_ids: jnp.array, image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, start_timestep: int, num_inference_steps: int, height: int, @@ -340,7 +340,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): prompt_ids: jnp.array, image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, strength: float = 0.8, num_inference_steps: int = 50, height: Optional[int] = None, @@ -361,7 +361,7 @@ class FlaxStableDiffusionImg2ImgPipeline(FlaxDiffusionPipeline): Array representing an image batch to be used as the starting point. params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights. - prng_seed (`jax.random.KeyArray` or `jax.Array`): + prng_seed (`jax.Array` or `jax.Array`): Array containing random number generator key. strength (`float`, *optional*, defaults to 0.8): Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py index b43fa38370..153267da10 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py @@ -270,7 +270,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): mask: jnp.array, masked_image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int, height: int, width: int, @@ -398,7 +398,7 @@ class FlaxStableDiffusionInpaintPipeline(FlaxDiffusionPipeline): mask: jnp.array, masked_image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int = 50, height: Optional[int] = None, width: Optional[int] = None, diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index 3a8c316795..5966600462 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -87,7 +87,7 @@ class FlaxStableDiffusionSafetyChecker(FlaxPreTrainedModel): module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + def init_weights(self, rng: jax.Array, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensor clip_input = jax.random.normal(rng, input_shape) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py index 3acb5ae538..8f043c7c66 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -89,7 +89,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline): self, prompt_ids: jax.Array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int = 50, guidance_scale: Union[float, jax.Array] = 7.5, height: Optional[int] = None, @@ -170,7 +170,7 @@ class FlaxStableDiffusionXLPipeline(FlaxDiffusionPipeline): self, prompt_ids: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int, height: int, width: int, diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index 529d2bd03a..ab7d70f466 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -198,7 +198,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - key: Optional[jax.random.KeyArray] = None, + key: Optional[jax.Array] = None, return_dict: bool = True, ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: """ @@ -211,7 +211,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin): timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. - key (`jax.random.KeyArray`): a PRNG key. + key (`jax.Array`): a PRNG key. return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class Returns: diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index 45c0dbddf7..4a8606007d 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -17,6 +17,7 @@ from dataclasses import dataclass from typing import Optional, Tuple, Union import flax +import jax import jax.numpy as jnp from jax import random @@ -139,7 +140,7 @@ class FlaxKarrasVeScheduler(FlaxSchedulerMixin, ConfigMixin): state: KarrasVeSchedulerState, sample: jnp.ndarray, sigma: float, - key: random.KeyArray, + key: jax.Array, ) -> Tuple[jnp.ndarray, float]: """ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index b6240559fc..935f972a9b 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -18,6 +18,7 @@ from dataclasses import dataclass from typing import Optional, Tuple, Union import flax +import jax import jax.numpy as jnp from jax import random @@ -169,7 +170,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - key: random.KeyArray, + key: jax.Array, return_dict: bool = True, ) -> Union[FlaxSdeVeOutput, Tuple]: """ @@ -228,7 +229,7 @@ class FlaxScoreSdeVeScheduler(FlaxSchedulerMixin, ConfigMixin): state: ScoreSdeVeSchedulerState, model_output: jnp.ndarray, sample: jnp.ndarray, - key: random.KeyArray, + key: jax.Array, return_dict: bool = True, ) -> Union[FlaxSdeVeOutput, Tuple]: """