Feat: first version

S1ro1
2025-08-23 15:03:28 +00:00
parent 979d81e4a9
commit 91985ab9d7

@@ -1620,6 +1620,19 @@ class Accelerator:
        self._cp_context = functools.partial(context_parallel, mesh=self.torch_device_mesh["cp"])

        try:
            from torch.distributed.tensor.experimental._attention import (
                create_cp_block_mask,
            )

            self._create_block_mask_fn = functools.partial(
                create_cp_block_mask, device_mesh=self.torch_device_mesh["cp"]
            )
        except ImportError:
            from torch.nn.attention.flex_attention import create_block_mask

            self._create_block_mask_fn = create_block_mask

        for arg in args:
            if isinstance(arg, torch.nn.Module):
                _attach_context_parallel_hooks(arg)
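
Note on the hunk above: it prefers the experimental context-parallel-aware mask factory (create_cp_block_mask, pre-bound to the "cp" device mesh via functools.partial) and falls back to the public flex-attention create_block_mask when the experimental import is unavailable. The sketch below is illustrative only, not part of the commit; it shows what the fallback callable does on its own, with a hypothetical causal_mask helper and arbitrary sizes.

# Illustrative only: the fallback path exposes
# torch.nn.attention.flex_attention.create_block_mask, which builds a BlockMask
# from a mask_mod callable plus batch/head/sequence sizes.
from torch.nn.attention.flex_attention import create_block_mask


def causal_mask(b, h, q_idx, kv_idx):
    # mask_mod contract: return True where query position q_idx may attend to key kv_idx
    return q_idx >= kv_idx


# B=None and H=None broadcast the mask over batch and heads; device="cpu" keeps the
# sketch runnable without a GPU. The experimental create_cp_block_mask variant also
# needs the "cp" device_mesh, which is why the commit pre-binds it with functools.partial.
block_mask = create_block_mask(causal_mask, B=None, H=None, Q_LEN=1024, KV_LEN=1024, device="cpu")
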
@@ -4042,6 +4055,22 @@ class Accelerator:
        )

        yield

    def create_block_mask(
        self,
        mask_mod,
        B,
        H,
        Q_LEN,
        KV_LEN,
    ):
        return self._create_block_mask_fn(
            mask_mod=mask_mod,
            B=B,
            H=H,
            Q_LEN=Q_LEN,
            KV_LEN=KV_LEN,
        )

    @contextmanager
    def autocast(self, autocast_handler: AutocastKwargs = None):
        """