mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
refactor
@@ -23,9 +23,8 @@ from ..models._modeling_parallel import (
    ContextParallelInput,
    ContextParallelModelPlan,
    ContextParallelOutput,
    ParallelConfig,
    _InternalParallelConfig,
)
from ..models.attention_dispatch import _parallel_context
from ..utils import get_logger
from ..utils.torch_utils import unwrap_module
from .hooks import HookRegistry, ModelHook

@@ -33,7 +32,6 @@ from .hooks import HookRegistry, ModelHook

logger = get_logger(__name__)  # pylint: disable=invalid-name

_CONTEXT_PARALLEL_MODEL_HOOK = "context_parallel_model_hook"
_CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE = "cp_input---{}"
_CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE = "cp_output---{}"

@@ -76,7 +74,7 @@ class ModuleForwardMetadata:

def apply_context_parallel(
    module: torch.nn.Module,
    parallel_config: ParallelConfig,
    parallel_config: _InternalParallelConfig,
    plan: Dict[str, ContextParallelModelPlan],
) -> None:
    """Apply context parallel on a model."""

@@ -105,45 +103,26 @@ def apply_context_parallel(
            registry = HookRegistry.check_if_exists_or_initialize(m)
            registry.register_hook(hook, hook_name)

    # HACK: we cannot use context managers or setattr or similar solutions in an overwritten forward
    # diffusers hook method because Dynamo fails to trace it. Instead, we make use of module hooks
    # available in pytorch to set the parallel context before/after the forward/backward pass.
    # It is dirty, but fullgraph=True tracing works because of this and I haven't found a better solution yet.
    # The previous/older implementation simply did this:
    #     def new_forward(self, ...):
    #         with _parallel_context(parallel_config):
    #             return self.fn_ref.original_forward(*args, **kwargs)
    # TODO: ask help from Pytorch team on how to improve this
    @torch.compiler.disable
    def forward_pre_hook(module, args):
        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
        module._diffusers_parallel_config_setter_context.__enter__()

    @torch.compiler.disable
    def forward_hook(module, args, output):
        if module._diffusers_parallel_config_setter_context is not None:
            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
        module._diffusers_parallel_config_setter_context = None

    @torch.compiler.disable
    def backward_pre_hook(module, grad_output):
        module._diffusers_parallel_config_setter_context = _parallel_context(parallel_config)
        module._diffusers_parallel_config_setter_context.__enter__()

    @torch.compiler.disable
    def backward_hook(module, grad_output, grad_input):
        if module._diffusers_parallel_config_setter_context is not None:
            module._diffusers_parallel_config_setter_context.__exit__(None, None, None)
        module._diffusers_parallel_config_setter_context = None

    module.register_forward_pre_hook(forward_pre_hook)
    module.register_forward_hook(forward_hook)
    module.register_full_backward_pre_hook(backward_pre_hook)
    module.register_full_backward_hook(backward_hook)


def remove_context_parallel(module: torch.nn.Module, plan: Dict[str, ContextParallelModelPlan]) -> None:
    for module_id, cp_model_plan in plan.items():
        submodule = _get_submodule_by_name(module, module_id)
        if not isinstance(submodule, list):
            submodule = [submodule]

        for m in submodule:
            registry = HookRegistry.check_if_exists_or_initialize(m)
            if isinstance(cp_model_plan, dict):
                hook_name = _CONTEXT_PARALLEL_INPUT_HOOK_TEMPLATE.format(module_id)
            elif isinstance(cp_model_plan, (ContextParallelOutput, list, tuple)):
                hook_name = _CONTEXT_PARALLEL_OUTPUT_HOOK_TEMPLATE.format(module_id)
            else:
                raise ValueError(f"Unsupported context parallel model plan type: {type(cp_model_plan)}")
            registry.remove_hook(hook_name)
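Editor's note: for readers unfamiliar with the pattern the HACK comment above describes, here is a minimal, self-contained sketch (not the diffusers implementation; the helper and attribute names are invented) of how a forward pre-hook/post-hook pair can enter and exit a context manager around a module's forward call without overriding `forward` itself:

import contextlib
import torch

@contextlib.contextmanager
def _dummy_context():
    # Stand-in for _parallel_context(parallel_config).
    print("entering context")
    try:
        yield
    finally:
        print("exiting context")

def attach_context_hooks(module: torch.nn.Module):
    # Enter the context right before forward runs...
    def pre_hook(module, args):
        module._ctx = _dummy_context()
        module._ctx.__enter__()

    # ...and exit it right after forward returns.
    def post_hook(module, args, output):
        if module._ctx is not None:
            module._ctx.__exit__(None, None, None)
        module._ctx = None

    module.register_forward_pre_hook(pre_hook)
    module.register_forward_hook(post_hook)

layer = torch.nn.Linear(4, 4)
attach_context_hooks(layer)
layer(torch.randn(1, 4))  # prints "entering context" / "exiting context" around the forward call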
class ContextParallelSplitHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: _InternalParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config

@@ -228,7 +207,7 @@ class ContextParallelSplitHook(ModelHook):

class ContextParallelGatherHook(ModelHook):
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: ParallelConfig) -> None:
    def __init__(self, metadata: ContextParallelModelPlan, parallel_config: _InternalParallelConfig) -> None:
        super().__init__()
        self.metadata = metadata
        self.parallel_config = parallel_config

@@ -25,6 +25,7 @@ from ..utils import (
_import_structure = {}

if is_torch_available():
    _import_structure["_modeling_parallel"] = ["ParallelConfig"]
    _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
    _import_structure["attention_dispatch"] = ["AttentionBackendName", "attention_backend"]
    _import_structure["auto_model"] = ["AutoModel"]

@@ -112,6 +113,7 @@ if is_flax_available():

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    if is_torch_available():
        from ._modeling_parallel import ParallelConfig
        from .adapter import MultiAdapter, T2IAdapter
        from .attention_dispatch import AttentionBackendName, attention_backend
        from .auto_model import AutoModel

@@ -35,6 +35,18 @@ logger = get_logger(__name__) # pylint: disable=invalid-name

@dataclass
class ParallelConfig:
    ring_degree: Optional[int] = None
    ulysses_degree: Optional[int] = None

    def __post_init__(self):
        if self.ring_degree is None:
            self.ring_degree = 1
        if self.ulysses_degree is None:
            self.ulysses_degree = 1


@dataclass
class _InternalParallelConfig:
    rank: int
    world_size: int
    ring_degree: int
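Editor's note: a rough, standalone sketch of how the new user-facing `ParallelConfig` defaults behave (a simplified copy of the dataclass shown above, for illustration only; the internal rank/world-size fields live on `_InternalParallelConfig`, which users never construct):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ParallelConfig:
    ring_degree: Optional[int] = None
    ulysses_degree: Optional[int] = None

    def __post_init__(self):
        # Any degree left unset falls back to 1 (no parallelism along that axis).
        if self.ring_degree is None:
            self.ring_degree = 1
        if self.ulysses_degree is None:
            self.ulysses_degree = 1

cfg = ParallelConfig(ring_degree=2)
print(cfg.ring_degree, cfg.ulysses_degree)  # 2 1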
@@ -38,8 +38,9 @@ from ..utils import (
)
from ..utils.constants import DIFFUSERS_ATTN_BACKEND, DIFFUSERS_ATTN_CHECKS


if TYPE_CHECKING:
    from ._modeling_parallel import ParallelConfig
    from ._modeling_parallel import _InternalParallelConfig

_REQUIRED_FLASH_VERSION = "2.6.3"
_REQUIRED_SAGE_VERSION = "2.1.1"

@@ -58,12 +59,12 @@ _CAN_USE_XFORMERS_ATTN = is_xformers_available() and is_xformers_version(">=", _

if _CAN_USE_FLASH_ATTN:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.flash_attn_interface import _flash_attn_backward, _flash_attn_forward
    from flash_attn.flash_attn_interface import _wrapped_flash_attn_backward, _wrapped_flash_attn_forward
else:
    flash_attn_func = None
    flash_attn_varlen_func = None
    _flash_attn_forward = None
    _flash_attn_backward = None
    _wrapped_flash_attn_backward = None
    _wrapped_flash_attn_forward = None


if _CAN_USE_FLASH_ATTN_3:

@@ -192,7 +193,7 @@ class _AttentionBackendRegistry:
    _supports_context_parallel = {}
    _active_backend = AttentionBackendName(DIFFUSERS_ATTN_BACKEND)
    _checks_enabled = DIFFUSERS_ATTN_CHECKS
    _parallel_config: Optional["ParallelConfig"] = None
    _parallel_config: Optional["_InternalParallelConfig"] = None

    @classmethod
    def register(

@@ -252,17 +253,6 @@ def attention_backend(backend: Union[str, AttentionBackendName] = AttentionBacke
        _AttentionBackendRegistry._active_backend = old_backend


@contextlib.contextmanager
def _parallel_context(parallel_config: "ParallelConfig"):
    old_parallel_config = _AttentionBackendRegistry._parallel_config
    _AttentionBackendRegistry._parallel_config = parallel_config

    try:
        yield
    finally:
        _AttentionBackendRegistry._parallel_config = old_parallel_config
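Editor's note: a minimal sketch of what this context manager does, using a stand-in registry class instead of the real `_AttentionBackendRegistry`. It swaps the class-level parallel config for the duration of the block and always restores the previous value, even if the body raises:

import contextlib

class _Registry:
    # Stand-in for _AttentionBackendRegistry; holds the active config as class state.
    _parallel_config = None

@contextlib.contextmanager
def _parallel_context(parallel_config):
    old = _Registry._parallel_config
    _Registry._parallel_config = parallel_config
    try:
        yield
    finally:
        _Registry._parallel_config = old

with _parallel_context({"ring_degree": 2}):
    assert _Registry._parallel_config == {"ring_degree": 2}
assert _Registry._parallel_config is None  # previous value restored on exit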
def dispatch_attention_fn(
    query: torch.Tensor,
    key: torch.Tensor,

@@ -637,14 +627,15 @@ class _cudnn_attention_af(torch.autograd.Function):
        if enable_gqa:
            raise ValueError("`enable_gqa` is not yet supported for cuDNN attention.")

        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.attn_mask = attn_mask
        tensors_to_save = ()

        # Contiguous is a must here! Calling cuDNN backend with aten ops produces incorrect results
        # if the input tensors are not contiguous.
        query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
        query = query.transpose(1, 2).contiguous()
        tensors_to_save += (query, key, value)
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()

        out, lse, cum_seq_q, cum_seq_k, max_q, max_k, philox_seed, philox_offset, debug_attn_mask = (
            torch.ops.aten._scaled_dot_product_cudnn_attention(
                query=query,

@@ -659,9 +650,14 @@ class _cudnn_attention_af(torch.autograd.Function):
            )
        )

        tensors_to_save += (out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)
        ctx.save_for_backward(*tensors_to_save)
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.attn_mask = attn_mask
        ctx.max_q = max_q
        ctx.max_k = max_k
        ctx.save_for_backward(query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset)

        out = out.transpose(1, 2).contiguous()
        if lse is not None:

@@ -674,8 +670,12 @@ class _cudnn_attention_af(torch.autograd.Function):
        grad_out: torch.Tensor,
        *args,
    ):
        query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = ctx.saved_tensors
        saved_tensors = ctx.saved_tensors if hasattr(ctx, "saved_tensors") else ctx.to_save
        query, key, value, out, lse, cum_seq_q, cum_seq_k, philox_seed, philox_offset = saved_tensors

        grad_out = grad_out.transpose(1, 2).contiguous()
        key = key.transpose(1, 2).contiguous()
        value = value.transpose(1, 2).contiguous()

        # Cannot pass first 5 arguments as kwargs because: https://github.com/pytorch/pytorch/blob/d26ca5de058dbcf56ac52bb43e84dd98df2ace97/torch/_dynamo/variables/torch.py#L1341
        grad_query, grad_key, grad_value = torch.ops.aten._scaled_dot_product_cudnn_attention_backward(

@@ -726,15 +726,33 @@ class _flash_attention_2_af(torch.autograd.Function):
        softcap = 0.0
        alibi_slopes = None
        deterministic = False
        grad_enabled = any(x.requires_grad for x in (query, key, value))

        if scale is None:
            scale = query.shape[-1] ** (-0.5)

        # flash-attn only returns LSE if dropout_p > 0. So, we need to workaround.
        parallel_config = _AttentionBackendRegistry._parallel_config
        if query.requires_grad or (parallel_config is not None and parallel_config.world_size > 1):
        if grad_enabled or (parallel_config is not None and parallel_config.world_size > 1):
            dropout_p = dropout_p if dropout_p > 0 else 1e-30

        with torch.set_grad_enabled(grad_enabled):
            out, lse, S_dmask, rng_state = _wrapped_flash_attn_forward(
                query,
                key,
                value,
                dropout_p,
                scale,
                is_causal,
                window_size[0],
                window_size[1],
                softcap,
                alibi_slopes,
                return_lse,
            )
            lse = lse.permute(0, 2, 1)

        ctx.save_for_backward(query, key, value, out, lse, rng_state)
        ctx.dropout_p = dropout_p
        ctx.scale = scale
        ctx.is_causal = is_causal

@@ -743,22 +761,6 @@ class _flash_attention_2_af(torch.autograd.Function):
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

        out, lse, S_dmask, rng_state = _flash_attn_forward(
            query,
            key,
            value,
            dropout_p,
            scale,
            is_causal,
            window_size[0],
            window_size[1],
            softcap,
            alibi_slopes,
            return_lse,
        )

        ctx.save_for_backward(query, key, value, out, lse, rng_state)

        return (out, lse) if return_lse else out

    @staticmethod

@@ -767,10 +769,11 @@ class _flash_attention_2_af(torch.autograd.Function):
        grad_out: torch.Tensor,
        *args,
    ):
        query, key, value, out, lse, rng_state = ctx.saved_tensors
        saved_tensors = ctx.saved_tensors if hasattr(ctx, "saved_tensors") else ctx.to_save
        query, key, value, out, lse, rng_state = saved_tensors
        grad_query, grad_key, grad_value = torch.empty_like(query), torch.empty_like(key), torch.empty_like(value)

        lse_d = _flash_attn_backward( # noqa: F841
        lse_d = _wrapped_flash_attn_backward( # noqa: F841
            grad_out,
            query,
            key,
@@ -876,10 +879,18 @@ class TemplatedRingAttention(torch.autograd.Function):
        ring_mesh = parallel_config._ring_mesh
        rank = parallel_config._ring_local_rank
        world_size = parallel_config.ring_degree

        next_rank = (rank + 1) % world_size
        prev_out = prev_lse = None

        ctx.save_for_backward(query, key, value)
        ctx.dropout_p = dropout_p
        ctx.is_causal = is_causal
        ctx.scale = scale
        ctx.enable_gqa = enable_gqa
        ctx.return_lse = return_lse
        ctx.op = op
        ctx.op_ctx = torch.autograd.function.FunctionCtx()

        kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
        kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
        kv_buffer = kv_buffer.chunk(world_size)

@@ -887,11 +898,14 @@ class TemplatedRingAttention(torch.autograd.Function):
        for i in range(world_size):
            if i > 0:
                kv = kv_buffer[next_rank]
                key = kv[: key.numel()].reshape_as(key)
                value = kv[key.numel() :].reshape_as(value)
                key_numel = key.numel()
                key = kv[:key_numel].reshape_as(key)
                value = kv[key_numel:].reshape_as(value)
                next_rank = (next_rank + 1) % world_size

            out, lse = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True)
            out, lse = op.forward(
                ctx.op_ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, True
            )

            if parallel_config.convert_to_fp32:
                out = out.to(torch.float32)

@@ -906,6 +920,7 @@ class TemplatedRingAttention(torch.autograd.Function):

        out = out.to(query.dtype)
        lse = lse.squeeze(-1)

        return (out, lse) if return_lse else out

    @staticmethod

@@ -914,7 +929,55 @@ class TemplatedRingAttention(torch.autograd.Function):
        grad_out: torch.Tensor,
        *args,
    ):
        raise NotImplementedError("Backward pass is not implemented for TemplatedRingAttention.")
        parallel_config = _AttentionBackendRegistry._parallel_config
        ring_mesh = parallel_config._ring_mesh
        rank = parallel_config._ring_local_rank
        world_size = parallel_config.ring_degree
        next_rank = (rank + 1) % world_size
        next_ranks = list(range(1, world_size)) + [0]

        query, key, value = ctx.saved_tensors

        accum_dtype = torch.float32 if parallel_config.convert_to_fp32 else query.dtype
        grad_query = torch.zeros_like(query, dtype=accum_dtype)
        grad_key = torch.zeros_like(key, dtype=accum_dtype)
        grad_value = torch.zeros_like(value, dtype=accum_dtype)
        next_grad_kv = None

        kv_buffer = torch.cat([key.flatten(), value.flatten()]).contiguous()
        kv_buffer = funcol.all_gather_tensor(kv_buffer, gather_dim=0, group=ring_mesh.get_group())
        kv_buffer = kv_buffer.chunk(world_size)

        for i in range(world_size):
            if i > 0:
                kv = kv_buffer[next_rank]
                key_numel = key.numel()
                key = kv[:key_numel].reshape_as(key)
                value = kv[key_numel:].reshape_as(value)
                next_rank = (next_rank + 1) % world_size

            saved_tensors = list(ctx.op_ctx.to_save)
            saved_tensors[1] = key
            saved_tensors[2] = value
            ctx.op_ctx.to_save = tuple(saved_tensors)

            grad_query_op, grad_key_op, grad_value_op, *_ = ctx.op.backward(ctx.op_ctx, grad_out)

            if i > 0:
                grad_kv_buffer = _wait_tensor(next_grad_kv)
                grad_key_numel = grad_key.numel()
                grad_key = grad_kv_buffer[:grad_key_numel].reshape_as(grad_key)
                grad_value = grad_kv_buffer[grad_key_numel:].reshape_as(grad_value)

            grad_query += grad_query_op
            grad_key += grad_key_op
            grad_value += grad_value_op

            if i < world_size - 1:
                grad_kv_buffer = torch.cat([grad_key.flatten(), grad_value.flatten()]).contiguous()
                next_grad_kv = funcol.permute_tensor(grad_kv_buffer, next_ranks, group=ring_mesh.get_group())

        return grad_query, grad_key, grad_value, None, None, None, None, None, None, None

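Editor's note: the loops above walk each rank over every key/value shard in ring order. A tiny single-process simulation of that index rotation (no torch.distributed involved; it only shows which shard each rank processes at each step, mirroring the `next_rank` bookkeeping in the forward and backward loops):

world_size = 4

for rank in range(world_size):
    next_rank = (rank + 1) % world_size
    visited = [rank]  # step 0 uses the rank's own key/value shard
    for i in range(1, world_size):
        visited.append(next_rank)
        next_rank = (next_rank + 1) % world_size
    print(f"rank {rank} processes shards in order {visited}")
    # e.g. rank 1 processes shards in order [1, 2, 3, 0]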
class TemplatedUlyssesAttention(torch.autograd.Function):

@@ -937,6 +1000,8 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
        world_size = parallel_config.ulysses_degree
        group = ulysses_mesh.get_group()

        ctx.op_ctx = torch.autograd.function.FunctionCtx()

        B, S_LOCAL, H, D = query.shape
        H_LOCAL = H // world_size
        query, key, value = (

@@ -948,7 +1013,7 @@ class TemplatedUlyssesAttention(torch.autograd.Function):
        )
        query, key, value = (x.flatten(0, 1).permute(1, 0, 2, 3).contiguous() for x in (query, key, value))

        out = op.apply(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
        out = op.forward(ctx.op_ctx, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa, return_lse)
        if return_lse:
            out, lse, *_ = out

@@ -1291,7 +1356,7 @@ def _native_cudnn_attention(

    lse = None
    if parallel_config is None and not return_lse:
        query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
        query, key, value = (x.permute(0, 2, 1, 3).contiguous() for x in (query, key, value))
        with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.CUDNN_ATTENTION):
            out = torch.nn.functional.scaled_dot_product_attention(
                query=query,
@@ -63,6 +63,7 @@ from ..utils.hub_utils import (
    populate_model_card,
)
from ..utils.torch_utils import empty_device_cache
from ._modeling_parallel import ContextParallelModelPlan, ParallelConfig, _InternalParallelConfig
from .model_loading_utils import (
    _caching_allocator_warmup,
    _determine_device_map,

@@ -1501,19 +1502,20 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
            )

    def parallelize(self, *, ring_degree: int = 1, ulysses_degree: int = 1, cp_plan=None):
        from ..hooks.context_parallel import ParallelConfig, apply_context_parallel
    @contextmanager
    def parallelize(self, *, config: ParallelConfig, cp_plan: Optional[Dict[str, ContextParallelModelPlan]] = None):
        from ..hooks.context_parallel import apply_context_parallel, remove_context_parallel
        from .attention_dispatch import _AttentionBackendRegistry

        # TODO(aryan): add cp_plan type hint
        logger.warning(
            "`parallelize` is an experimental feature. The API may change in the future and breaking changes may be introduced at any time without warning."
        )

        if not torch.distributed.is_initialized():
            raise RuntimeError("torch.distributed must be initialized before calling `parallelize`.")
        if ring_degree < 1 or ulysses_degree < 1:
        if config.ring_degree < 1 or config.ulysses_degree < 1:
            raise ValueError("`ring_degree` and `ulysses_degree` must be greater than or equal to 1.")
        if ring_degree > 1 and ulysses_degree > 1:
        if config.ring_degree > 1 and config.ulysses_degree > 1:
            raise ValueError(
                "Unified Ulysses-Ring attention is not yet supported. Please set either `ring_degree` or `ulysses_degree` to 1."
            )

@@ -1521,9 +1523,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        rank = torch.distributed.get_rank()
        world_size = torch.distributed.get_world_size()

        if ring_degree * ulysses_degree > world_size:
        if config.ring_degree * config.ulysses_degree > world_size:
            raise ValueError(
                f"The product of `ring_degree` ({ring_degree}) and `ulysses_degree` ({ulysses_degree}) must not exceed the world size ({world_size})."
                f"The product of `ring_degree` ({config.ring_degree}) and `ulysses_degree` ({config.ulysses_degree}) must not exceed the world size ({world_size})."
            )

        device_type = torch._C._get_accelerator().type

@@ -1532,14 +1534,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

        cp_mesh = torch.distributed.device_mesh.init_device_mesh(
            device_type=device_type,
            mesh_shape=(ring_degree, ulysses_degree),
            mesh_shape=(config.ring_degree, config.ulysses_degree),
            mesh_dim_names=("ring", "ulysses"),
        )
        parallel_config = ParallelConfig(
        parallel_config = _InternalParallelConfig(
            rank=rank,
            world_size=world_size,
            ring_degree=ring_degree,
            ulysses_degree=ulysses_degree,
            ring_degree=config.ring_degree,
            ulysses_degree=config.ulysses_degree,
            device=device,
            cp_mesh=cp_mesh,
        )

@@ -1550,6 +1552,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        cp_plan = cp_plan if cp_plan is not None else self._cp_plan

        apply_context_parallel(self, parallel_config, cp_plan)
        _AttentionBackendRegistry._parallel_config = parallel_config

        yield

        remove_context_parallel(self, cp_plan)
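Editor's note: based on the new signature above, `parallelize` now takes a `ParallelConfig` and yields, so it is used as a context manager. A rough usage sketch under a torchrun launch; the model id and degree values are placeholders, device placement is omitted, and the exact import path for `ParallelConfig` may differ:

# Hypothetical usage sketch for the new context-manager API.
import torch
import torch.distributed as dist
from diffusers import AutoModel
from diffusers.models import ParallelConfig

dist.init_process_group("nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

transformer = AutoModel.from_pretrained("some/model-id", subfolder="transformer")

# ring_degree * ulysses_degree must not exceed the world size, and only one of the
# two may be > 1 for now (unified Ulysses-Ring attention is not supported yet).
with transformer.parallelize(config=ParallelConfig(ring_degree=dist.get_world_size())):
    ...  # run sharded inference here; the hooks are removed when the block exits

dist.destroy_process_group()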
    @classmethod
    def _load_pretrained_model(