From 210a07b13cb7589005e1e3b6368e5deb0d815b66 Mon Sep 17 00:00:00 2001
From: Yusuke Suzuki
Date: Tue, 14 Nov 2023 20:16:52 +0900
Subject: [PATCH] fix exception around NSFW filter on flax stable diffusion
 (#5675)

---
 .../stable_diffusion/pipeline_flax_stable_diffusion.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
index bcf2a62177..5598477c92 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py
@@ -410,13 +410,13 @@ class FlaxStableDiffusionPipeline(FlaxDiffusionPipeline):
 
             images_uint8_casted = np.asarray(images_uint8_casted).reshape(num_devices * batch_size, height, width, 3)
             images_uint8_casted, has_nsfw_concept = self._run_safety_checker(images_uint8_casted, safety_params, jit)
-            images = np.asarray(images)
+            images = np.asarray(images).copy()
 
             # block images
             if any(has_nsfw_concept):
                 for i, is_nsfw in enumerate(has_nsfw_concept):
                     if is_nsfw:
-                        images[i] = np.asarray(images_uint8_casted[i])
+                        images[i, 0] = np.asarray(images_uint8_casted[i])
 
             images = images.reshape(num_devices, batch_size, height, width, 3)
         else:
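
The `.copy()` added in the first hunk points at the likely source of the exception: `np.asarray` on a JAX device array (such as the pmapped `images` output) typically returns a NumPy array whose buffer is marked non-writable, so the in-place assignment that blocks flagged images raises "ValueError: assignment destination is read-only". Below is a minimal sketch of that behaviour outside the pipeline, not part of the patch; the array shape is a placeholder standing in for the (num_devices, batch_size, height, width, 3) output.

import jax.numpy as jnp
import numpy as np

# Stand-in for the pipeline's pmapped output: (num_devices, batch_size, H, W, 3).
images = jnp.zeros((1, 1, 64, 64, 3))

view = np.asarray(images)        # backed by JAX's cached host copy, typically non-writable
print(view.flags.writeable)      # usually False; `view[0, 0] = 0` would then raise
                                 # "ValueError: assignment destination is read-only"

writable = np.asarray(images).copy()  # the copy owns its memory and is writable
writable[0, 0] = 0                    # mimics blocking a flagged image in place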