mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
Author: DN6
Date: 2025-10-07 15:54:04 +05:30
parent faf61a4877
commit 428399b590
3 changed files with 52 additions and 49 deletions

View File

@@ -79,29 +79,47 @@ class ContextParallelConfig:
        if self.ulysses_degree is None:
            self.ulysses_degree = 1
        if self.ring_degree == 1 and self.ulysses_degree == 1:
            raise ValueError(
                "Either ring_degree or ulysses_degree must be greater than 1 in order to use context parallel inference"
            )
        if self.ring_degree < 1 or self.ulysses_degree < 1:
            raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
        if self.ring_degree > 1 and self.ulysses_degree > 1:
            raise ValueError(
                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
            )
        if self.rotate_method != "allgather":
            raise NotImplementedError(
                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
            )

    @property
    def mesh_shape(self) -> Tuple[int, int]:
        """Shape of the device mesh (ring_degree, ulysses_degree)."""
        return (self.ring_degree, self.ulysses_degree)

    @property
    def mesh_dim_names(self) -> Tuple[str, str]:
        """Dimension names for the device mesh."""
        return ("ring", "ulysses")

    def setup(self, rank: int, world_size: int, device: torch.device, mesh: torch.distributed.device_mesh.DeviceMesh):
        self._rank = rank
        self._world_size = world_size
        self._device = device
        self._mesh = mesh
        if self.ring_degree is None:
            self.ring_degree = 1
        if self.ulysses_degree is None:
            self.ulysses_degree = 1
        if self.rotate_method != "allgather":
            raise NotImplementedError(
                f"Only rotate_method='allgather' is supported for now, but got {self.rotate_method}."
            )
        if self.ulysses_degree * self.ring_degree > world_size:
            raise ValueError(
                f"The product of `ring_degree` ({self.ring_degree}) and `ulysses_degree` ({self.ulysses_degree}) must not exceed the world size ({world_size})."
            )
        if self._flattened_mesh is None:
            self._flattened_mesh = self._mesh._flatten()
        if self._ring_mesh is None:
            self._ring_mesh = self._mesh["ring"]
        if self._ulysses_mesh is None:
            self._ulysses_mesh = self._mesh["ulysses"]
        if self._ring_local_rank is None:
            self._ring_local_rank = self._ring_mesh.get_local_rank()
        if self._ulysses_local_rank is None:
            self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()
        self._flattened_mesh = self._mesh._flatten()
        self._ring_mesh = self._mesh["ring"]
        self._ulysses_mesh = self._mesh["ulysses"]
        self._ring_local_rank = self._ring_mesh.get_local_rank()
        self._ulysses_local_rank = self._ulysses_mesh.get_local_rank()

@dataclass
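
For orientation: taken together, the changes above move all degree validation into ContextParallelConfig.__post_init__ and expose the mesh layout through the new mesh_shape and mesh_dim_names properties. A minimal sketch of the resulting behavior, runnable without any distributed setup (the top-level import path is an assumption, not part of this commit):

    from diffusers import ContextParallelConfig  # import path is an assumption

    # Ulysses attention across 2 ranks; ring_degree defaults to 1.
    config = ContextParallelConfig(ulysses_degree=2)
    print(config.mesh_shape)      # (1, 2) -> (ring_degree, ulysses_degree)
    print(config.mesh_dim_names)  # ("ring", "ulysses")

    # Invalid combinations now fail at construction time rather than inside enable_parallelism.
    try:
        ContextParallelConfig(ring_degree=2, ulysses_degree=2)
    except ValueError as err:
        print(err)  # Unified Ulysses-Ring attention is not yet supported. ...
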
@@ -119,7 +137,7 @@ class ParallelConfig:
    _rank: int = None
    _world_size: int = None
    _device: torch.device = None
    _cp_mesh: torch.distributed.device_mesh.DeviceMesh = None
    _mesh: torch.distributed.device_mesh.DeviceMesh = None

    def setup(
        self,
@@ -127,14 +145,14 @@ class ParallelConfig:
        world_size: int,
        device: torch.device,
        *,
        cp_mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
        mesh: Optional[torch.distributed.device_mesh.DeviceMesh] = None,
    ):
        self._rank = rank
        self._world_size = world_size
        self._device = device
        self._cp_mesh = cp_mesh
        self._mesh = mesh
        if self.context_parallel_config is not None:
            self.context_parallel_config.setup(rank, world_size, device, cp_mesh)
            self.context_parallel_config.setup(rank, world_size, device, mesh)
@dataclass(frozen=True)
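
With the rename from _cp_mesh/cp_mesh to _mesh/mesh, callers build the device mesh from the config's own mesh_shape/mesh_dim_names and hand it to setup through the mesh keyword. A condensed sketch of that flow, using the two dataclasses from this file and assuming torch.distributed is already initialized with more than one process (it mirrors the enable_parallelism hunk in the last file below):

    import torch
    import torch.distributed as dist

    rank, world_size = dist.get_rank(), dist.get_world_size()
    device = torch.device("cuda", rank % torch.cuda.device_count())

    parallel_config = ParallelConfig(context_parallel_config=ContextParallelConfig(ring_degree=world_size))
    cp_config = parallel_config.context_parallel_config
    mesh = torch.distributed.device_mesh.init_device_mesh(
        device_type="cuda",
        mesh_shape=cp_config.mesh_shape,          # (ring_degree, ulysses_degree)
        mesh_dim_names=cp_config.mesh_dim_names,  # ("ring", "ulysses")
    )
    parallel_config.setup(rank, world_size, device, mesh=mesh)  # keyword is now mesh, not cp_mesh
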

View File

@@ -244,11 +244,7 @@ class _AttentionBackendRegistry:
        supports_context_parallel = (
            backend in cls._supports_context_parallel and cls._supports_context_parallel[backend]
        )
        is_degree_greater_than_1 = parallel_config is not None and (
            parallel_config.context_parallel_config.ring_degree > 1
            or parallel_config.context_parallel_config.ulysses_degree > 1
        )
        return supports_context_parallel and is_degree_greater_than_1
        return supports_context_parallel and parallel_config.context_parallel_config is not None
@contextlib.contextmanager
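
The simplified return above leans on the stricter __post_init__ from the first file: if a context_parallel_config exists at all, at least one of its degrees is already guaranteed to be greater than 1, so the explicit degree comparison became redundant. Roughly, the invariant is (hypothetical helper name, not the registry's actual API):

    def _cp_requested(parallel_config) -> bool:
        # ContextParallelConfig.__post_init__ raises when ring_degree == ulysses_degree == 1,
        # so a non-None context_parallel_config already implies "some degree > 1".
        return parallel_config is not None and parallel_config.context_parallel_config is not None
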

View File
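
The last file reworks ModelMixin.enable_parallelism itself (hunk below): the per-degree validation moves out because the config now validates itself, and the device mesh is built from the config's mesh_shape/mesh_dim_names. A rough end-to-end usage sketch under a multi-process torchrun launch; the model class and checkpoint id are placeholders, not taken from this commit:

    # torchrun --nproc-per-node=2 example.py
    import torch
    import torch.distributed as dist
    from diffusers import ContextParallelConfig  # import path is an assumption

    dist.init_process_group("nccl")  # must be initialized before enable_parallelism
    torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

    # Placeholder: any ModelMixin subclass that defines a _cp_plan (or pass cp_plan=... explicitly).
    transformer = MyTransformerModel.from_pretrained("org/checkpoint", subfolder="transformer")
    transformer.enable_parallelism(config=ContextParallelConfig(ring_degree=dist.get_world_size()))
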

@@ -1483,46 +1483,35 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        config: Union[ParallelConfig, ContextParallelConfig],
        cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None,
    ):
        from ..hooks.context_parallel import apply_context_parallel
        from .attention import AttentionModuleMixin
        from .attention_processor import Attention, MochiAttention

        logger.warning(
            "`enable_parallelism` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
        )

        if isinstance(config, ContextParallelConfig):
            config = ParallelConfig(context_parallel_config=config)

        if not torch.distributed.is_initialized():
            raise RuntimeError("torch.distributed must be initialized before calling `enable_parallelism`.")

        from ..hooks.context_parallel import apply_context_parallel
        from .attention import AttentionModuleMixin
        from .attention_processor import Attention, MochiAttention

        if isinstance(config, ContextParallelConfig):
            config = ParallelConfig(context_parallel_config=config)

        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()
        device_type = torch._C._get_accelerator().type
        device_module = torch.get_device_module(device_type)
        device = torch.device(device_type, rank % device_module.device_count())

        cp_mesh = None
        mesh = None
        if config.context_parallel_config is not None:
            cp_config = config.context_parallel_config
            if cp_config.ring_degree < 1 or cp_config.ulysses_degree < 1:
                raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
            if cp_config.ring_degree > 1 and cp_config.ulysses_degree > 1:
                raise ValueError(
                    "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
                )
            if cp_config.ring_degree * cp_config.ulysses_degree > world_size:
                raise ValueError(
                    f"The product of `ring_degree` ({cp_config.ring_degree}) and `ulysses_degree` ({cp_config.ulysses_degree}) must not exceed the world size ({world_size})."
                )
            cp_mesh = torch.distributed.device_mesh.init_device_mesh(
            mesh = torch.distributed.device_mesh.init_device_mesh(
                device_type=device_type,
                mesh_shape=(cp_config.ring_degree, cp_config.ulysses_degree),
                mesh_dim_names=("ring", "ulysses"),
                mesh_shape=cp_config.mesh_shape,
                mesh_dim_names=cp_config.mesh_dim_names,
            )

        config.setup(rank, world_size, device, cp_mesh=cp_mesh)
        config.setup(rank, world_size, device, mesh=mesh)

        if cp_plan is None and self._cp_plan is None:
            raise ValueError(