mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[tests] fix Pixart Sigma tests (#7966)
* checking tests * checking ii. * remove prints. * test_pixart_1024 * fix 1024.
This commit is contained in:
@@ -336,7 +336,7 @@ class PixArtSigmaPipelineIntegrationTests(unittest.TestCase):
|
||||
image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.0742, 0.0835, 0.2114, 0.0295, 0.0784, 0.2361, 0.1738, 0.2251, 0.3589])
|
||||
expected_slice = np.array([0.4517, 0.4446, 0.4375, 0.449, 0.4399, 0.4365, 0.4583, 0.4629, 0.4473])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice)
|
||||
self.assertLessEqual(max_diff, 1e-4)
|
||||
@@ -344,7 +344,12 @@ class PixArtSigmaPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_pixart_512(self):
|
||||
generator = torch.Generator("cpu").manual_seed(0)
|
||||
|
||||
pipe = PixArtSigmaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16)
|
||||
transformer = Transformer2DModel.from_pretrained(
|
||||
self.ckpt_id_512, subfolder="transformer", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = PixArtSigmaPipeline.from_pretrained(
|
||||
self.ckpt_id_1024, transformer=transformer, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = self.prompt
|
||||
@@ -352,7 +357,7 @@ class PixArtSigmaPipelineIntegrationTests(unittest.TestCase):
|
||||
image = pipe(prompt, generator=generator, num_inference_steps=2, output_type="np").images
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
expected_slice = np.array([0.3477, 0.3882, 0.4541, 0.3413, 0.3821, 0.4463, 0.4001, 0.4409, 0.4958])
|
||||
expected_slice = np.array([0.0479, 0.0378, 0.0217, 0.0942, 0.064, 0.0791, 0.2073, 0.1975, 0.2017])
|
||||
|
||||
max_diff = numpy_cosine_similarity_distance(image_slice.flatten(), expected_slice)
|
||||
self.assertLessEqual(max_diff, 1e-4)
|
||||
@@ -394,7 +399,12 @@ class PixArtSigmaPipelineIntegrationTests(unittest.TestCase):
|
||||
def test_pixart_512_without_resolution_binning(self):
|
||||
generator = torch.manual_seed(0)
|
||||
|
||||
pipe = PixArtSigmaPipeline.from_pretrained(self.ckpt_id_512, torch_dtype=torch.float16)
|
||||
transformer = Transformer2DModel.from_pretrained(
|
||||
self.ckpt_id_512, subfolder="transformer", torch_dtype=torch.float16
|
||||
)
|
||||
pipe = PixArtSigmaPipeline.from_pretrained(
|
||||
self.ckpt_id_1024, transformer=transformer, torch_dtype=torch.float16
|
||||
)
|
||||
pipe.enable_model_cpu_offload()
|
||||
|
||||
prompt = self.prompt
|
||||
|
||||
Reference in New Issue
Block a user