mirror of https://github.com/huggingface/accelerate.git
Feat: first version
@@ -1620,6 +1620,19 @@ class Accelerator:
        self._cp_context = functools.partial(context_parallel, mesh=self.torch_device_mesh["cp"])

        # Prefer the CP-aware block-mask builder when this PyTorch build
        # provides it; otherwise fall back to flex_attention's standard one.
        try:
            from torch.distributed.tensor.experimental._attention import (
                create_cp_block_mask,
            )

            self._create_block_mask_fn = functools.partial(
                create_cp_block_mask, device_mesh=self.torch_device_mesh["cp"]
            )
        except ImportError:
            from torch.nn.attention.flex_attention import create_block_mask

            self._create_block_mask_fn = create_block_mask

        for arg in args:
            if isinstance(arg, torch.nn.Module):
                _attach_context_parallel_hooks(arg)
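
For orientation, a minimal usage sketch (not part of this commit) of how the stored `_cp_context` partial can later be entered. `model` and `batch` are hypothetical names; `context_parallel` is PyTorch's experimental context manager that shards the listed buffers along their sequence dimensions across the bound "cp" mesh:

    # Hypothetical sketch: enter the bound context so attention over `batch`
    # runs context-parallel, sharded on the sequence dimension (dim 1).
    with accelerator._cp_context(buffers=[batch], buffer_seq_dims=[1]):
        outputs = model(batch)  # hypothetical model call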
@@ -4042,6 +4055,22 @@ class Accelerator:
        )
        yield

    def create_block_mask(
        self,
        mask_mod,
        B,
        H,
        Q_LEN,
        KV_LEN,
    ):
        # Delegate to the builder picked during setup: the context-parallel
        # aware variant if available, otherwise flex_attention's create_block_mask.
        return self._create_block_mask_fn(
            mask_mod=mask_mod,
            B=B,
            H=H,
            Q_LEN=Q_LEN,
            KV_LEN=KV_LEN,
        )

    @contextmanager
    def autocast(self, autocast_handler: AutocastKwargs = None):
        """