Tmp: create device mesh beforehand

commit 67be9a69ba
parent 8df21cf54a
Author: S1ro1
Date:   2025-06-19 15:11:30 +00:00


@@ -521,6 +521,23 @@ class Accelerator:
             gradient_accumulation_plugin=gradient_accumulation_plugin,
         )
 
+        if self.is_fsdp2:
+            from torch.distributed.device_mesh import init_device_mesh
+
+            context_parallel_size = self.state.fsdp_plugin.cp_size
+            world_size = self.state.num_processes
+            fsdp_size = world_size // context_parallel_size
+
+            device_mesh = init_device_mesh(
+                device_type=self.device.type,
+                mesh_shape=(fsdp_size, context_parallel_size),
+                mesh_dim_names=("fsdp", "cp"),
+            )
+            device_mesh["fsdp", "cp"]._flatten("fsdp_cp")
+            self.state.torch_device_mesh = device_mesh
+
         self.device_placement = device_placement
         if dataloader_config is None:
             dataloader_config = DataLoaderConfiguration()
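
The constructor now builds the 2-D ("fsdp", "cp") mesh up front and stores it on self.state.torch_device_mesh, so later code paths can reuse it instead of rebuilding it; the second hunk below does exactly that. For orientation, here is a minimal standalone sketch of the same layout. It is not part of the commit: the dim names and the _flatten("fsdp_cp") call come from the diff, while the torchrun launch, the gloo/CPU backend, and cp_size=2 are illustrative assumptions.

# Standalone sketch, not part of the commit: build the same ("fsdp", "cp") mesh
# outside of Accelerator. Assumptions: launched with `torchrun --nproc-per-node 4`,
# gloo/CPU backend, and cp_size=2 standing in for fsdp_plugin.cp_size.
# DeviceMesh._flatten is a private API and needs a recent PyTorch (~2.5+).
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dist.init_process_group(backend="gloo")

world_size = dist.get_world_size()
cp_size = 2                        # illustrative; the commit reads self.state.fsdp_plugin.cp_size
fsdp_size = world_size // cp_size  # remaining ranks are used for FSDP sharding

device_mesh = init_device_mesh(
    device_type="cpu",             # the commit passes self.device.type instead
    mesh_shape=(fsdp_size, cp_size),
    mesh_dim_names=("fsdp", "cp"),
)
# Flatten both dims into one "fsdp_cp" group, as in the diff, so FSDP can shard
# across every rank while the "cp" sub-mesh stays addressable for context parallelism.
device_mesh["fsdp", "cp"]._flatten("fsdp_cp")

print(f"rank {dist.get_rank()}: fsdp={device_mesh['fsdp'].size()} cp={device_mesh['cp'].size()}")
dist.destroy_process_group()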
@@ -1543,26 +1560,13 @@ class Accelerator:
                 f"`cp_size` set to {context_parallel_size}, which is greater than the number of processes {self.state.num_processes}. Please set to 1 to disable context parallel or use a smaller value."
             )
 
-        from torch.distributed.device_mesh import init_device_mesh
         from torch.distributed.tensor.experimental import context_parallel
         from torch.distributed.tensor.experimental._attention import set_rotate_method
 
         cp_comm_strategy = self.state.fsdp_plugin.cp_comm_strategy
         set_rotate_method(cp_comm_strategy)
 
-        world_size = self.state.num_processes
-        fsdp_size = world_size // context_parallel_size
-        device_mesh = init_device_mesh(
-            device_type=self.device.type,
-            mesh_shape=(fsdp_size, context_parallel_size),
-            mesh_dim_names=("fsdp", "cp"),
-        )
-        self.state.torch_device_mesh = device_mesh
-        device_mesh["fsdp", "cp"]._flatten("fsdp_cp")
-        self._cp_context = functools.partial(context_parallel, mesh=device_mesh["cp"])
+        self._cp_context = functools.partial(context_parallel, mesh=self.state.torch_device_mesh["cp"])
 
         # Apply AC if needed
         if self.state.fsdp_plugin.activation_checkpointing:
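
With the mesh created in __init__, this hunk only re-points self._cp_context at self.state.torch_device_mesh["cp"]. The sketch below, not part of the commit, shows how such a pre-bound context_parallel context is used: the NCCL/GPU setup, tensor shapes, and the "allgather" rotate method are assumptions, and context_parallel / set_rotate_method are experimental or private PyTorch APIs.

# Standalone sketch, not part of the commit: how the "cp" sub-mesh is consumed.
# context_parallel() shards the listed buffers along their sequence dim and routes
# scaled_dot_product_attention through ring attention inside the with-block.
# Assumptions: >= 2 CUDA GPUs (`torchrun --nproc-per-node 2`), NCCL backend,
# illustrative tensor shapes, and "allgather" standing in for cp_comm_strategy.
import functools

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import set_rotate_method

dist.init_process_group(backend="nccl")
torch.cuda.set_device(dist.get_rank() % torch.cuda.device_count())

mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("cp",))
set_rotate_method("allgather")  # the diff reads this from fsdp_plugin.cp_comm_strategy

# Same pattern as the diff: a context manager pre-bound to the CP sub-mesh.
cp_context = functools.partial(context_parallel, mesh=mesh["cp"])

q = torch.randn(1, 8, 1024, 64, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

with cp_context(buffers=[q, k, v], buffer_seq_dims=[2, 2, 2]):
    # Each rank now holds a sequence shard of q/k/v.
    out = torch.nn.functional.scaled_dot_product_attention(q, k, v)

dist.destroy_process_group()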