Mirror of https://github.com/huggingface/accelerate.git (synced 2025-10-20 10:03:46 +08:00)
Tmp: create device mesh beforehand
@@ -521,6 +521,23 @@ class Accelerator:
             gradient_accumulation_plugin=gradient_accumulation_plugin,
         )
 
+        if self.is_fsdp2:
+            from torch.distributed.device_mesh import init_device_mesh
+
+            context_parallel_size = self.state.fsdp_plugin.cp_size
+
+            world_size = self.state.num_processes
+
+            fsdp_size = world_size // context_parallel_size
+
+            device_mesh = init_device_mesh(
+                device_type=self.device.type,
+                mesh_shape=(fsdp_size, context_parallel_size),
+                mesh_dim_names=("fsdp", "cp"),
+            )
+            device_mesh["fsdp", "cp"]._flatten("fsdp_cp")
+            self.state.torch_device_mesh = device_mesh
+
         self.device_placement = device_placement
         if dataloader_config is None:
             dataloader_config = DataLoaderConfiguration()
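For context, the hunk above now builds the 2-D FSDP x context-parallel device mesh once, at Accelerator init time. Below is a minimal standalone sketch of the same layout under assumed values (a hypothetical 8-process launch with cp_size=2); the diff itself uses self.device.type and the FSDP plugin fields rather than the hard-coded values here.

# sketch_mesh.py -- illustrative only; run with `torchrun --nproc_per_node=8 sketch_mesh.py`
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="gloo")           # "nccl" on GPU; gloo keeps the sketch CPU-runnable
world_size = dist.get_world_size()                # 8 in this hypothetical run
context_parallel_size = 2                         # stands in for fsdp_plugin.cp_size
fsdp_size = world_size // context_parallel_size   # 8 // 2 = 4

# Ranks are laid out as a (fsdp_size, cp_size) grid with named dimensions.
device_mesh = init_device_mesh(
    "cpu",                                        # the diff passes self.device.type, e.g. "cuda"
    mesh_shape=(fsdp_size, context_parallel_size),
    mesh_dim_names=("fsdp", "cp"),
)

# Flatten both dims into a single "fsdp_cp" dim spanning all ranks,
# while "fsdp" and "cp" stay addressable as separate sub-meshes.
device_mesh["fsdp", "cp"]._flatten("fsdp_cp")

assert device_mesh["fsdp"].size() == fsdp_size
assert device_mesh["cp"].size() == context_parallel_size
dist.destroy_process_group()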
@@ -1543,26 +1560,13 @@ class Accelerator:
                     f"`cp_size` set to {context_parallel_size}, which is greater than the number of processes {self.state.num_processes}. Please set to 1 to disable context parallel or use a smaller value."
                 )
 
-            from torch.distributed.device_mesh import init_device_mesh
             from torch.distributed.tensor.experimental import context_parallel
             from torch.distributed.tensor.experimental._attention import set_rotate_method
 
             cp_comm_strategy = self.state.fsdp_plugin.cp_comm_strategy
             set_rotate_method(cp_comm_strategy)
 
-            world_size = self.state.num_processes
-
-            fsdp_size = world_size // context_parallel_size
-
-            device_mesh = init_device_mesh(
-                device_type=self.device.type,
-                mesh_shape=(fsdp_size, context_parallel_size),
-                mesh_dim_names=("fsdp", "cp"),
-            )
-            self.state.torch_device_mesh = device_mesh
-            device_mesh["fsdp", "cp"]._flatten("fsdp_cp")
-
-            self._cp_context = functools.partial(context_parallel, mesh=device_mesh["cp"])
+            self._cp_context = functools.partial(context_parallel, mesh=self.state.torch_device_mesh["cp"])
 
             # Apply AC if needed
             if self.state.fsdp_plugin.activation_checkpointing:
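This second hunk drops the duplicate mesh construction and reuses the mesh already stored on state. As a hedged sketch of how such a _cp_context partial is typically entered around a training step: the model, the fake inputs, and the buffers/buffer_seq_dims choices below are hypothetical and not part of this diff, only the imports and the partial mirror the code above.

# illustrative only; `model` is any HF-style causal LM already prepared on this rank,
# `device_mesh` is the 2-D ("fsdp", "cp") mesh built in __init__ above
import functools
import torch
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method


def train_step(model: torch.nn.Module, device_mesh: DeviceMesh) -> None:
    set_rotate_method("allgather")  # stands in for fsdp_plugin.cp_comm_strategy ("allgather" or "alltoall")
    cp_context = functools.partial(context_parallel, mesh=device_mesh["cp"])

    input_ids = torch.randint(0, 32_000, (1, 4096), device="cuda")
    labels = input_ids.clone()

    # context_parallel shards each listed buffer along its sequence dim across the "cp" ranks
    # and swaps scaled_dot_product_attention for a context-parallel variant inside the block.
    with cp_context(buffers=[input_ids, labels], buffer_seq_dims=[1, 1]):
        loss = model(input_ids=input_ids, labels=labels).loss
        loss.backward()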