From d9b73ffd5155507c7990aa792b439b886c35dbb9 Mon Sep 17 00:00:00 2001
From: DN6
Date: Mon, 15 Dec 2025 16:12:50 +0530
Subject: [PATCH] update

---
 tests/models/testing_utils/parallelism.py | 102 ++++++++++++++++++++++
 1 file changed, 102 insertions(+)
 create mode 100644 tests/models/testing_utils/parallelism.py

diff --git a/tests/models/testing_utils/parallelism.py b/tests/models/testing_utils/parallelism.py
new file mode 100644
index 0000000000..3bbbfe91bb
--- /dev/null
+++ b/tests/models/testing_utils/parallelism.py
@@ -0,0 +1,102 @@
+# coding=utf-8
+# Copyright 2025 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+from diffusers.models._modeling_parallel import ContextParallelConfig
+
+from ...testing_utils import (
+    is_context_parallel,
+    require_torch_multi_accelerator,
+)
+
+
+def _context_parallel_worker(rank, world_size, model_class, init_dict, cp_dict, inputs_dict, result_queue):
+    try:
+        # Setup distributed environment
+        os.environ["MASTER_ADDR"] = "localhost"
+        os.environ["MASTER_PORT"] = "12355"
+
+        torch.distributed.init_process_group(
+            backend="nccl",
+            init_method="env://",
+            world_size=world_size,
+            rank=rank,
+        )
+        torch.cuda.set_device(rank)
+        device = torch.device(f"cuda:{rank}")
+
+        model = model_class(**init_dict)
+        model.to(device)
+        model.eval()
+
+        inputs_on_device = {}
+        for key, value in inputs_dict.items():
+            if isinstance(value, torch.Tensor):
+                inputs_on_device[key] = value.to(device)
+            else:
+                inputs_on_device[key] = value
+
+        cp_config = ContextParallelConfig(**cp_dict)
+        model.enable_parallelism(config=cp_config)
+
+        with torch.no_grad():
+            output = model(**inputs_on_device, return_dict=False)[0]
+
+        if rank == 0:
+            result_queue.put(("success", output.shape))
+
+    except Exception as e:
+        if rank == 0:
+            result_queue.put(("error", str(e)))
+    finally:
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
+
+
+@is_context_parallel
+@require_torch_multi_accelerator
+class ContextParallelTesterMixin:
+    base_precision = 1e-3
+
+    @pytest.mark.parametrize("cp_type", ["ulysses_degree", "ring_degree"], ids=["ulysses", "ring"])
+    def test_context_parallel_inference(self, cp_type):
+        if not torch.distributed.is_available():
+            pytest.skip("torch.distributed is not available.")
+
+        if not hasattr(self.model_class, "_cp_plan") or self.model_class._cp_plan is None:
+            pytest.skip("Model does not have a _cp_plan defined for context parallel inference.")
+
+        world_size = 2
+        init_dict = self.get_init_dict()
+        inputs_dict = self.get_dummy_inputs()
+        cp_dict = {cp_type: world_size}
+
+        ctx = mp.get_context("spawn")
+        result_queue = ctx.Queue()
+
+        mp.spawn(
+            _context_parallel_worker,
+            args=(world_size, self.model_class, init_dict, cp_dict, inputs_dict, result_queue),
+            nprocs=world_size,
+            join=True,
+        )
+
+        status, result = result_queue.get(timeout=60)
+        assert status == "success", f"Context parallel inference failed: {result}"
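
Usage sketch: the mixin above expects the concrete test class to supply `model_class`, `get_init_dict()`, and `get_dummy_inputs()`. A minimal way to wire it up might look like the following; `MyTransformer2DModel`, its init arguments, and the dummy input shapes are hypothetical stand-ins for illustration, not fixtures taken from the diffusers test suite.

# Hypothetical test module, e.g. tests/models/transformers/test_my_transformer.py
import torch

from diffusers import MyTransformer2DModel  # stand-in for any model that defines a `_cp_plan`

from ..testing_utils.parallelism import ContextParallelTesterMixin


class TestMyTransformerContextParallel(ContextParallelTesterMixin):
    model_class = MyTransformer2DModel

    def get_init_dict(self):
        # Keep the model tiny so the two-process NCCL run stays fast.
        return {"num_layers": 2, "num_attention_heads": 2, "attention_head_dim": 8}

    def get_dummy_inputs(self):
        # Shapes are arbitrary; they only need to be valid inputs for the model above.
        return {
            "hidden_states": torch.randn(1, 16, 16),
            "timestep": torch.tensor([1.0]),
        }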