Mirror of https://github.com/pytorch/pytorch.git (synced 2025-11-02 14:34:54 +08:00)

Compare commits: `attention_` ... `ciflow/tru` (6 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | b3f6c3090b |  |
|  | ab2665d3cc |  |
|  | d7d2e8731b |  |
|  | 84cad15e82 |  |
|  | f4a9ac120c |  |
|  | bdd80556c7 |  |
@@ -392,11 +392,11 @@ class ComposabilityTest(MultiProcessTestCase):
replicate_size = self.world_size // (pp_size)
device_mesh = init_device_mesh(
device_type,
mesh_shape=(replicate_size, 1, pp_size),
mesh_dim_names=("replicate", "shard", "pp"),
mesh_shape=(replicate_size, pp_size),
mesh_dim_names=("replicate", "pp"),
)
torch.manual_seed(42)
dp_mesh = device_mesh["replicate", "shard"]
dp_mesh = device_mesh["replicate"]
pp_mesh = device_mesh["pp"]
pp_group = device_mesh["pp"].get_group()

@@ -416,15 +416,13 @@ class ComposabilityTest(MultiProcessTestCase):
param_dtype=MixedPrecisionParam,
reduce_dtype=torch.float32,
)
replicate_config = {"mp_policy": mp_policy}
replicate_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
for layer_id in range(len(partial_model)):
replicate(
partial_model[layer_id],
device_mesh=dp_mesh,
**replicate_config,
reshard_after_forward=False,
)
dp_model = replicate(partial_model, device_mesh=dp_mesh, **replicate_config)
dp_model = replicate(partial_model, **replicate_config)
return dp_model

# Apply same precision to reference model (without replicate)
@@ -582,11 +580,11 @@ class ComposabilityTest(MultiProcessTestCase):
replicate_size = self.world_size // (pp_size)
device_mesh = init_device_mesh(
device_type,
mesh_shape=(replicate_size, 1, pp_size),
mesh_dim_names=("replicate", "shard", "pp"),
mesh_shape=(replicate_size, pp_size),
mesh_dim_names=("replicate", "pp"),
)
torch.manual_seed(42)
dp_mesh = device_mesh["replicate", "shard"]
dp_mesh = device_mesh["replicate"]
pp_mesh = device_mesh["pp"]
pp_group = device_mesh["pp"].get_group()
dp_group = device_mesh["replicate"].get_group()
@@ -648,10 +646,9 @@ class ComposabilityTest(MultiProcessTestCase):
for layer_id in range(len(partial_model)):
replicate(
partial_model[layer_id],
device_mesh=dp_mesh,
reshard_after_forward=False,
mesh=dp_mesh,
)
dp_model = replicate(partial_model, device_mesh=dp_mesh)
dp_model = replicate(partial_model, mesh=dp_mesh)
return dp_model

def pipelined_models_parameters(start_layer, model):
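The hunks above capture the shape of the change: `replicate()` now takes a 1D data-parallel mesh via `mesh=` instead of a 2D `("replicate", "shard")` mesh via `device_mesh=`, and `reshard_after_forward` is no longer passed through. A minimal sketch of the new pattern, assuming the updated `replicate()` from this diff (run under `torchrun`; the sizes and device type are placeholders):

```python
# Sketch only: assumes the updated replicate() API from this diff (1D mesh,
# passed as `mesh=`); the sizes below are illustrative, not from the source.
from torch.distributed.device_mesh import init_device_mesh

world_size, pp_size = 8, 2              # placeholder sizes
replicate_size = world_size // pp_size
device_mesh = init_device_mesh(
    "cuda",
    mesh_shape=(replicate_size, pp_size),
    mesh_dim_names=("replicate", "pp"),
)
dp_mesh = device_mesh["replicate"]          # 1D data-parallel mesh
pp_group = device_mesh["pp"].get_group()    # pipeline-parallel group

# Old call: replicate(stage, device_mesh=dp_mesh, reshard_after_forward=False)
# New call: replicate(stage, mesh=dp_mesh)
```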
@@ -3,7 +3,7 @@
import copy
import dataclasses
import functools
from typing import Optional, Union
from typing import Optional

import torch
import torch.distributed as dist
@@ -14,7 +14,6 @@ from torch.distributed.fsdp import MixedPrecisionPolicy
from torch.distributed.fsdp._fully_shard._fsdp_collectives import (
_get_gradient_divide_factors,
)
from torch.distributed.tensor import Shard
from torch.testing._internal.common_distributed import (
requires_nccl_version,
SaveForwardInputsModel,
@@ -46,35 +45,20 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):

def _init_models_and_optims(
self,
reshard_after_forward: Union[bool, int],
param_dtype: Optional[torch.dtype],
reduce_dtype: Optional[torch.dtype],
use_shard_placement_fn,
):
torch.manual_seed(42)
model = nn.Sequential(*[MLP(16, torch.device("cpu")) for _ in range(3)])
ref_model = copy.deepcopy(model).to(device_type)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)

def _shard_placement_fn(param: nn.Parameter) -> Optional[Shard]:
largest_dim = -1
largest_dim_size = -1
for dim, dim_size in enumerate(param.shape):
if dim_size > largest_dim_size:
largest_dim = dim
largest_dim_size = dim_size
assert largest_dim >= 0, f"{param.shape}"
return Shard(largest_dim)

mp_policy = MixedPrecisionPolicy(
param_dtype=param_dtype, reduce_dtype=reduce_dtype
)
shard_placement_fn = _shard_placement_fn if use_shard_placement_fn else None
replicate_fn = functools.partial(
replicate,
reshard_after_forward=reshard_after_forward,
mp_policy=mp_policy,
shard_placement_fn=shard_placement_fn,
)
for mlp in model:
replicate_fn(mlp)
@@ -82,27 +66,13 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=True)
return ref_model, ref_optim, model, optim

def _get_use_shard_placement_fn_vals_for_bf16_reduce(self):
use_shard_placement_fn_vals = [False]
if self.world_size == 2:
# For world size >2, gradient elements get reduced in different
# orders for the baseline vs. dim-1 sharding, leading to numeric
# differences for bf16 reduction, so only test world size 2.
use_shard_placement_fn_vals.append(True)
return use_shard_placement_fn_vals

@skipIfRocmVersionLessThan((7, 0))
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_compute_dtype(self):
use_shard_placement_fn_vals = (
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
)
self.run_subtests(
{
"param_dtype": [torch.bfloat16, torch.float16],
"reshard_after_forward": [False, True],
"use_shard_placement_fn": use_shard_placement_fn_vals,
},
self._test_compute_dtype,
)
@@ -110,14 +80,10 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
def _test_compute_dtype(
self,
param_dtype: torch.dtype,
reshard_after_forward: Union[bool, int],
use_shard_placement_fn: bool,
):
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=None,
use_shard_placement_fn=use_shard_placement_fn,
)
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
orig_reduce_scatter = dist.reduce_scatter_tensor
@@ -175,39 +141,14 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
@skip_if_lt_x_gpu(2)
@requires_nccl_version((2, 10), "Need NCCL 2.10+ for bf16 collectives")
def test_reduce_dtype(self):
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_shard_placement_fn": [False, True],
},
self._test_reduce_dtype_fp32_reduce,
)
use_shard_placement_fn_vals = (
self._get_use_shard_placement_fn_vals_for_bf16_reduce()
)
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_shard_placement_fn": use_shard_placement_fn_vals,
},
self._test_reduce_dtype_bf16_reduce,
)
self._test_reduce_dtype_fp32_reduce()
self._test_reduce_dtype_bf16_reduce()

def _test_reduce_dtype_fp32_reduce(
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
):
if (
self.world_size > 2
and isinstance(reshard_after_forward, int)
and use_shard_placement_fn
):
return
def _test_reduce_dtype_fp32_reduce(self):
param_dtype, reduce_dtype = torch.bfloat16, torch.float32
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
use_shard_placement_fn=use_shard_placement_fn,
)
ref_model_bf16 = copy.deepcopy(ref_model).to(param_dtype)
orig_reduce_scatter = dist.reduce_scatter_tensor
@@ -249,14 +190,12 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
check_sharded_parity(self, ref_model, model)

def _test_reduce_dtype_bf16_reduce(
self, reshard_after_forward: Union[bool, int], use_shard_placement_fn: bool
self,
):
param_dtype, reduce_dtype = torch.float32, torch.bfloat16
ref_model, ref_optim, model, optim = self._init_models_and_optims(
reshard_after_forward,
param_dtype=param_dtype,
reduce_dtype=reduce_dtype,
use_shard_placement_fn=use_shard_placement_fn,
)
group = dist.distributed_c10d._get_default_group()
orig_reduce_scatter = dist.reduce_scatter_tensor
@@ -321,12 +260,8 @@ class TestReplicateMixedPrecisionTraining(FSDPTest):
ref_model_compute = copy.deepcopy(ref_model).to(param_dtype)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
for mlp in model:
replicate(
mlp, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
)
replicate(
model, reshard_after_forward=reshard_after_forward, mp_policy=mp_policy
)
replicate(mlp, mp_policy=mp_policy)
replicate(model, mp_policy=mp_policy)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)
orig_reduce_scatter = dist.reduce_scatter_tensor
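The mixed-precision tests now drive `replicate()` with only an `mp_policy`; the `reshard_after_forward` and `shard_placement_fn` knobs disappear from the test surface. A minimal sketch of that usage, assuming the updated API (the model and dtypes are placeholders and the `replicate` import path is not shown in this diff, so the calls are left as comments):

```python
# Sketch only: MixedPrecisionPolicy usage mirroring the updated test.
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy

mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,   # compute dtype for parameters
    reduce_dtype=torch.float32,   # dtype for gradient reduction
)
model = nn.Sequential(*[nn.Linear(16, 16) for _ in range(3)])
# for layer in model:
#     replicate(layer, mp_policy=mp_policy)
# replicate(model, mp_policy=mp_policy)
```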
@@ -108,84 +108,70 @@ class TestReplicateRegisteredParams(FSDPTestMultiThread):
"""Tests the parameter registration after forward."""
device = torch.device(device_type.type, 0)
# Single Replicate group
for reshard_after_forward in (True, False, None):
torch.manual_seed(42)
model = MLP(3, device)
# Since seed is per process, not per thread, we broadcast to ensure
# the same parameters across ranks
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model, reshard_after_forward=reshard_after_forward) # root only
inp = torch.randn((2, 3), device=device_type.type)
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
if reshard_after_forward:
self._assert_dtensor_params(model.parameters())
else:
self._assert_tensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
torch.manual_seed(42)
model = MLP(3, device)
# Since seed is per process, not per thread, we broadcast to ensure
# the same parameters across ranks
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model) # root only
inp = torch.randn((2, 3), device=device_type.type)
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
self._assert_tensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())

# Multiple Replicate groups
for reshard_after_forward in (True, False, None):
torch.manual_seed(42)
model = nn.Sequential(MLP(3, device), MLP(3, device))
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model[0].in_proj, reshard_after_forward=reshard_after_forward)
replicate(model[0].out_proj, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
torch.manual_seed(42)
model = nn.Sequential(MLP(3, device), MLP(3, device))
for param in model.parameters():
dist.broadcast(param, src=0)
ref_model = copy.deepcopy(model)
replicate(model[0].in_proj)
replicate(model[0].out_proj)
replicate(model)

self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
non_root_params = list(model[0].in_proj.parameters()) + list(
model[0].out_proj.parameters()
)
root_params = list(set(model.parameters()) - set(non_root_params))
if reshard_after_forward is None:
self._assert_dtensor_params(non_root_params)
self._assert_tensor_params(root_params)
elif reshard_after_forward:
self._assert_dtensor_params(non_root_params)
self._assert_dtensor_params(root_params)
else:
self._assert_tensor_params(non_root_params)
self._assert_tensor_params(root_params)
self._assert_same_params(model.parameters(), ref_model.parameters())
for module in model.modules():
if isinstance(module, FSDPModule):
module.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())
model(inp)
non_root_params = list(model[0].in_proj.parameters()) + list(
model[0].out_proj.parameters()
)
root_params = list(set(model.parameters()) - set(non_root_params))
self._assert_tensor_params(non_root_params)
self._assert_tensor_params(root_params)
self._assert_same_params(model.parameters(), ref_model.parameters())
for module in model.modules():
if isinstance(module, FSDPModule):
module.reshard() # however, we can manually reshard
self._assert_dtensor_params(model.parameters())
self._assert_same_params(model.parameters(), ref_model.parameters())

@skip_if_lt_x_gpu(1)
def test_param_registration_after_backward(self):
"""Tests the parameter registration after backward."""
device = torch.device(device_type.type, 0)
# Single Replicate group
for reshard_after_forward in (True, False):
model = MLP(8, device)
replicate(model, reshard_after_forward=reshard_after_forward) # root only
inp = torch.randn((2, 8), device=device_type.type)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
model = MLP(8, device)
replicate(model) # root only
inp = torch.randn((2, 8), device=device_type.type)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())

# Multiple Replicate groups
for reshard_after_forward in (True, False):
model = MLP(8, device)
replicate(model.in_proj, reshard_after_forward=reshard_after_forward)
replicate(model.out_proj, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())
model = MLP(8, device)
replicate(model.in_proj)
replicate(model.out_proj)
replicate(model)
self._assert_dtensor_params(model.parameters())
model(inp).sum().backward()
self._assert_dtensor_params(model.parameters())

def _assert_tensor_params(self, params: Iterable[nn.Parameter]):
# need to iterate over the list multiple times
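These registration tests rest on two invariants: parameters are registered as `DTensor`s while resharded and as plain `Tensor`s while unsharded, and `reshard()` can always restore the `DTensor` registration. A helper in the spirit of the test's `_assert_dtensor_params` (a sketch; the real helper lives in the test class):

```python
# Sketch only: mirrors the intent of _assert_dtensor_params in the test above.
from torch.distributed.tensor import DTensor

def assert_dtensor_params(params) -> None:
    params = list(params)  # the test iterates the parameters more than once
    assert len(params) > 0, "expected at least one parameter"
    for p in params:
        assert isinstance(p, DTensor), f"expected DTensor, got {type(p)}"
```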
@@ -287,14 +273,11 @@ class TestReplicate1DTrainingCore(FSDPTest):
[(7, 15), (15, 3)],
[(16, 17), (17, 8)],
],
"use_shard_placement_fn": [False],
},
self._test_train_parity_single_group,
)

def _test_train_parity_single_group(
self, lin_shapes: list[tuple[int, int]], use_shard_placement_fn: bool
):
def _test_train_parity_single_group(self, lin_shapes: list[tuple[int, int]]):
torch.manual_seed(42)
model = nn.Sequential(
nn.Linear(*lin_shapes[0]), nn.ReLU(), nn.Linear(*lin_shapes[1])
@@ -333,7 +316,6 @@ class TestReplicate1DTrainingCore(FSDPTest):
"""
self.run_subtests(
{
"reshard_after_forward": [True, False],
"test_device_type": [device_type.type],
"offload_policy": [OffloadPolicy()],
"delay_after_forward": [False, True],
@@ -354,7 +336,6 @@ class TestReplicate1DTrainingCore(FSDPTest):
"""
self.run_subtests(
{
"reshard_after_forward": [True], # save CI time
"offload_policy": [
CPUOffloadPolicy(pin_memory=True),
CPUOffloadPolicy(pin_memory=False),
@@ -371,7 +352,6 @@ class TestReplicate1DTrainingCore(FSDPTest):

def _test_train_parity_multi_group(
self,
reshard_after_forward: Union[bool, int],
offload_policy: OffloadPolicy,
test_device_type: str,
delay_after_forward: bool,
@@ -405,13 +385,12 @@ class TestReplicate1DTrainingCore(FSDPTest):
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
mesh = init_device_mesh(
test_device_type,
(self.world_size, 1),
mesh_dim_names=("replicate", "shard"),
(self.world_size,),
mesh_dim_names=("replicate",),
)
fully_shard_fn = functools.partial(
replicate,
device_mesh=mesh,
reshard_after_forward=reshard_after_forward,
mesh=mesh,
offload_policy=offload_policy,
)
for module in model.modules():
@@ -527,12 +506,10 @@ class TestReplicate1DTrainingCore(FSDPTest):
Tests parity when running a module that participates multiple
times in forward.
"""
self.run_subtests(
{"reshard_after_forward": [True, False]},
self._test_multi_forward_module,
)

def _test_multi_forward_module(self, reshard_after_forward: Union[bool, int]):
self._test_multi_forward_module()

def _test_multi_forward_module(self):
class MultiForwardModule(nn.Module):
def __init__(self, device: torch.device):
super().__init__()
@@ -687,7 +664,6 @@ class TestReplicateTrainingCompose(FSDPTest):
"""
self.run_subtests(
{
"reshard_after_forward": [True, False],
"checkpoint_impl": ["composable", "utils", "wrapper"],
"module_grouping": ["block", "mem_eff", "mem_eff_weight_tied"],
"test_device_type": [device_type.type],
@@ -697,7 +673,6 @@ class TestReplicateTrainingCompose(FSDPTest):

def _test_train_parity_with_activation_checkpointing(
self,
reshard_after_forward: Union[bool, int],
checkpoint_impl: str,
module_grouping: str,
test_device_type: str,
@@ -740,12 +715,11 @@ class TestReplicateTrainingCompose(FSDPTest):
# Apply Replicate
device_mesh = init_device_mesh(
test_device_type,
(self.world_size, 1),
mesh_dim_names=("replicate", "shard"),
(self.world_size,),
mesh_dim_names=("replicate",),
)
fsdp_kwargs = {
"reshard_after_forward": reshard_after_forward,
"device_mesh": device_mesh,
"mesh": device_mesh,
}
if module_grouping == "mem_eff":
assert model_args.n_layers == 3
@@ -809,7 +783,6 @@ class TestReplicateSharedParams(FSDPTest):
def test_train_parity_with_shared_params(self):
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
},
self._test_train_shared_params,
@@ -817,7 +790,6 @@ class TestReplicateSharedParams(FSDPTest):

def _test_train_shared_params(
self,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
):
torch.manual_seed(42)
@@ -830,8 +802,8 @@ class TestReplicateSharedParams(FSDPTest):
if isinstance(module, TransformerBlock):
if use_activation_checkpointing:
checkpoint(module)
replicate(module, reshard_after_forward=reshard_after_forward)
replicate(model, reshard_after_forward=reshard_after_forward)
replicate(module)
replicate(model)
optim = torch.optim.Adam(model.parameters(), lr=1e-2)

torch.manual_seed(42 + self.rank + 1)
@@ -868,11 +840,11 @@ class TestReplicateGradientAccumulation(FSDPTest):
with/without resharding after backward.
"""

shard_size, replicate_size = 1, self.world_size
replicate_size = self.world_size
meshes = init_device_mesh(
device_type.type,
(replicate_size, shard_size),
mesh_dim_names=("replicate", "shard"),
(replicate_size,),
mesh_dim_names=("replicate",),
)
self.run_subtests(
{
@@ -928,8 +900,7 @@ class TestReplicateGradientAccumulation(FSDPTest):
ref_model = copy.deepcopy(model).to(device_type)
replicate_fn = functools.partial(
replicate,
device_mesh=mesh,
reshard_after_forward=reshard_after_forward,
mesh=mesh,
offload_policy=offload_policy,
)
for mlp in model[1:]:
@@ -1040,8 +1011,8 @@ class TestReplicateGradientAccumulation(FSDPTest):
ref_optim = torch.optim.AdamW(ref_model.parameters(), lr=1e-2)
for module in model.modules():
if isinstance(module, TransformerBlock):
replicate(module, reshard_after_forward=False)
replicate(model, reshard_after_forward=False)
replicate(module)
replicate(model)
optim = torch.optim.AdamW(model.parameters(), lr=1e-2)

num_microbatches = 3
@@ -1145,8 +1116,8 @@ class TestReplicateTPTraining(FSDPTest):
def init_global_mesh(self) -> DeviceMesh:
return init_device_mesh(
device_type.type,
(2, 1, 2),
mesh_dim_names=("dp_replicate", "dp_shard", "tp"),
(2, 2),
mesh_dim_names=("dp_replicate", "tp"),
)

@skip_if_lt_x_gpu(8)
@@ -1154,7 +1125,6 @@ class TestReplicateTPTraining(FSDPTest):
global_mesh = self.init_global_mesh()
self.run_subtests(
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
"mlp_dim": [3, 5, 16, 17],
"foreach": [False],
@@ -1165,12 +1135,11 @@ class TestReplicateTPTraining(FSDPTest):
def _test_replicate_tp(
self,
global_mesh: DeviceMesh,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
mlp_dim: int,
foreach: bool,
):
dp_mesh, tp_mesh = global_mesh["dp_replicate", "dp_shard"], global_mesh["tp"]
dp_mesh, tp_mesh = global_mesh["dp_replicate"], global_mesh["tp"]
dp_pg = dp_mesh._flatten().get_group() # used for `replicate()`

torch.manual_seed(42)
@@ -1197,8 +1166,8 @@ class TestReplicateTPTraining(FSDPTest):
continue
if use_activation_checkpointing:
checkpoint(module)
replicate(module, device_mesh=dp_mesh)
replicate(model, device_mesh=dp_mesh)
replicate(module, mesh=dp_mesh)
replicate(model, mesh=dp_mesh)

# Checking parameters match orig model is critical to validate .full_tensor correctly replicates the
# strided-sharded layers.
@@ -1229,11 +1198,9 @@ class TestReplicateTPTraining(FSDPTest):

for _, p in model.named_parameters():
self.assertIsInstance(p, DTensor)
self.assertEqual(p.device_mesh.ndim, 3)
self.assertEqual(len(p.placements), 3)
self.assertEqual(
p.device_mesh.mesh_dim_names, ("dp_replicate", "dp_shard", "tp")
)
self.assertEqual(p.device_mesh.ndim, 2)
self.assertEqual(len(p.placements), 2)
self.assertEqual(p.device_mesh.mesh_dim_names, ("dp_replicate", "tp"))


if __name__ == "__main__":
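When composed with tensor parallelism, the data-parallel dimension is now a single `dp_replicate` mesh dim, so parameters end up as `DTensor`s on a 2D `("dp_replicate", "tp")` mesh rather than the old 3D layout. A sketch of the mesh setup, assuming the updated API (sizes are placeholders; run under `torchrun` with a matching world size):

```python
# Sketch only: 2D replicate + TP mesh per the updated test; sizes are placeholders.
from torch.distributed.device_mesh import init_device_mesh

global_mesh = init_device_mesh(
    "cuda",
    (2, 2),
    mesh_dim_names=("dp_replicate", "tp"),
)
dp_mesh, tp_mesh = global_mesh["dp_replicate"], global_mesh["tp"]
# Parallelize with TP on tp_mesh first, then apply data parallelism with
# replicate(module, mesh=dp_mesh); each parameter should then be a DTensor
# whose device_mesh has mesh_dim_names ("dp_replicate", "tp").
```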
@@ -120,7 +120,7 @@ class ReplicateTest(MultiProcessTestCase):
if i % 2 == 0:
self.assertTrue("replicate" in _get_registry(layer))
for parameter in layer.parameters():
self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
self.assertEqual(parameter.placements, (Replicate(),))
elif i % 2 == 1:
self.assertTrue("fully_shard" in _get_registry(layer))
for parameter in layer.parameters():
@@ -197,14 +197,14 @@ class ReplicateTest(MultiProcessTestCase):
]

global_mesh = self.init_replicate_tp_mesh()
replicate_mesh = global_mesh["replicate", "shard"]
replicate_mesh = global_mesh["replicate"]

for layer in layers:
replicate(layer, device_mesh=replicate_mesh)
replicate(layer, mesh=replicate_mesh)

for parameter in layer.parameters():
self.assertEqual(parameter.device_mesh.shape, (2, 1))
self.assertEqual(parameter.placements, (Replicate(), Shard(dim=0)))
self.assertEqual(parameter.device_mesh.shape, (2,))
self.assertEqual(parameter.placements, (Replicate(),))

@skip_if_lt_x_gpu(2)
def test_train_replicate_fsdp(self):
@@ -263,7 +263,6 @@ class ReplicateTest(MultiProcessTestCase):
run_subtests(
self,
{
"reshard_after_forward": [False, True],
"use_activation_checkpointing": [False, True],
"mlp_dim": [3, 16, 17],
},
@@ -273,7 +272,6 @@ class ReplicateTest(MultiProcessTestCase):
def _test_train_parity_2d_mlp(
self,
global_mesh: DeviceMesh,
reshard_after_forward: bool,
use_activation_checkpointing: bool,
mlp_dim: int,
):
@@ -287,13 +285,12 @@ class ReplicateTest(MultiProcessTestCase):
torch.manual_seed(42)
model = MLPStack(mlp_dim)
ref_model = copy.deepcopy(model).cuda()
replicate(ref_model, device_mesh=replicate_shard_mesh)
replicate(ref_model, mesh=replicate_mesh)
ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2, foreach=False)
model.parallelize(
tp_mesh,
replicate_shard_mesh,
use_activation_checkpointing,
reshard_after_forward=reshard_after_forward,
)
optim = torch.optim.Adam(model.parameters(), lr=1e-2, foreach=False)
@@ -2,7 +2,7 @@
from __future__ import annotations

import logging
from typing import Optional, TYPE_CHECKING, Union
from typing import Optional, overload, Union

import torch
import torch.distributed as dist
@@ -14,13 +14,12 @@ from torch.distributed.fsdp._fully_shard._fsdp_api import (
OffloadPolicy,
)
from torch.distributed.fsdp._fully_shard._fsdp_common import (
DDPMeshInfo,
detect_compiled_autograd,
HSDPMeshInfo,
)
from torch.distributed.fsdp._fully_shard._fsdp_init import (
_get_device_from_mesh,
_get_managed_states,
_get_post_forward_mesh_info,
_init_default_fully_shard_mesh,
_move_states_to_device,
)
@@ -39,12 +38,6 @@ from torch.distributed.utils import _get_root_modules
from .contract import _get_registry, contract


if TYPE_CHECKING:
from collections.abc import Callable

from torch.distributed.tensor import Shard


cls_to_replicate_cls: dict[type, type] = {}

_ROOT_MODULE_PREFIX = ""
@@ -95,7 +88,7 @@ class _ReplicateState(FSDPState):
modules: tuple[nn.Module, ...],
device: torch.device,
mp_policy: MixedPrecisionPolicy,
auto_reshard_after_forward: bool,
auto_reshard_after_forward: bool = False,
) -> None:
for module in modules:
_insert_module_state(module, self)
@@ -171,8 +164,6 @@ def replicate_impl(
mesh: DeviceMesh,
*,
device_id: Optional[Union[int, torch.device]] = None,
reshard_after_forward: Optional[Union[bool, int]] = None,
shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None,
mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
offload_policy: OffloadPolicy = OffloadPolicy(),
ignored_params: Optional[set[nn.Parameter]] = None,
@@ -184,30 +175,25 @@ def replicate_impl(
)

mesh = mesh or _init_default_fully_shard_mesh()
if mesh.ndim != 2:
raise ValueError(f"replicate expects a 2D DeviceMesh but got {mesh}")
if mesh.ndim != 1:
raise ValueError(f"replicate expects a 1D DeviceMesh but got {mesh}")

else:
if mesh.mesh_dim_names is None:
raise AssertionError(
"Please init the 2D mesh for HSDP with mesh_dim_names specified"
)
mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0)
mesh_info = DDPMeshInfo(mesh, replicate_mesh_dim=0)
device = _get_device_from_mesh(mesh)
auto_reshard_after_forward = reshard_after_forward is None
# If the user does not provide ``reshard_after_forward``, we set it to True.
# During lazy_init, we identify which module is the root and override its value to False
post_forward_mesh_info = _get_post_forward_mesh_info(
reshard_after_forward if not auto_reshard_after_forward else True, # type: ignore[arg-type]
mesh_info,
)

post_forward_mesh_info = None

arg_module = module
modules = (
(module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module))
)
state = replicate.state(modules[0]) # type: ignore[attr-defined] # see [1]
state.init(modules, device, mp_policy, auto_reshard_after_forward)
state.init(modules, device, mp_policy)

managed_modules = _get_managed_modules(modules, ignored_params)
params, buffers = _get_managed_states(managed_modules, ignored_params)
@@ -217,10 +203,10 @@ def replicate_impl(
state._fsdp_param_group = FSDPParamGroup(
params,
modules,
mesh_info,
mesh_info, # type: ignore[arg-type]
post_forward_mesh_info,
device,
shard_placement_fn,
None,
mp_policy,
offload_policy,
)
@@ -237,11 +223,39 @@ def replicate_impl(
return arg_module


@contract(state_cls=_ReplicateState)
@overload
# pyrefly: ignore [inconsistent-overload]
def replicate(
module: nn.Module,
**kwargs,
) -> nn.Module:
*,
mesh: Optional[DeviceMesh] = ...,
mp_policy: MixedPrecisionPolicy = ...,
offload_policy: OffloadPolicy = ...,
ignored_params: Optional[set[nn.Parameter]] = ...,
) -> ReplicateModule: ...


@overload
# pyrefly: ignore [inconsistent-overload]
def replicate(
module: list[nn.Module],
*,
mesh: Optional[DeviceMesh] = ...,
mp_policy: MixedPrecisionPolicy = ...,
offload_policy: OffloadPolicy = ...,
ignored_params: Optional[set[nn.Parameter]] = ...,
) -> list[ReplicateModule]: ...


@contract(state_cls=_ReplicateState) # type: ignore[misc]
def replicate(
module: nn.Module,
*,
mesh: Optional[DeviceMesh] = None,
mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(),
offload_policy: OffloadPolicy = OffloadPolicy(),
ignored_params: Optional[set[nn.Parameter]] = None,
):
r"""Replicates a module

Args:
@@ -253,24 +267,21 @@ def replicate(
>>> replicate(module)
"""

if "device_id" in kwargs:
if not isinstance(kwargs["device_id"], (int, torch.device)):
raise RuntimeError(
"Expected device_id to be int or torch.device, "
f"but got {type(kwargs['device_id'])}"
)

if not is_composable_with_replicate(module):
raise RuntimeError(
"Cannot apply `replicate()` on a Module already managed by `fully_shard`"
)

device_mesh = kwargs.pop("device_mesh", None)
if device_mesh is None:
device_mesh = replicate_mesh()
if mesh is None:
mesh = replicate_mesh()

module = replicate_impl(module, mesh=device_mesh, **kwargs)
return module
return replicate_impl(
module,
mesh,
mp_policy=mp_policy,
offload_policy=offload_policy,
ignored_params=ignored_params,
)


class ReplicateModule(FSDPModule):
@@ -341,8 +352,8 @@ def replicate_mesh():
device = torch._C._get_accelerator()
mesh = init_device_mesh(
device.type,
mesh_shape=(default_pg.size(), 1),
mesh_dim_names=("replicate", "shard"),
mesh_shape=(default_pg.size(),),
mesh_dim_names=("replicate",),
)
return mesh
|
||||
force_sum_reduction_for_comms,
|
||||
)
|
||||
)
|
||||
world_size = reduce_scatter_group.size()
|
||||
|
||||
if reduce_scatter_group is None:
|
||||
world_size = 1
|
||||
else:
|
||||
world_size = reduce_scatter_group.size()
|
||||
device_handle = _get_device_handle(device.type)
|
||||
current_stream = device_handle.current_stream()
|
||||
|
||||
@ -547,7 +551,7 @@ def foreach_reduce(
|
||||
reduce_output.copy_(reduce_scatter_input)
|
||||
reduce_scatter_event = reduce_scatter_stream.record_event()
|
||||
post_reduce_stream = reduce_scatter_stream
|
||||
if all_reduce_group is not None: # HSDP
|
||||
if all_reduce_group is not None: # HSDP or DDP/replicate
|
||||
# Accumulations must run in the reduce-scatter stream
|
||||
if not all_reduce_grads:
|
||||
if partial_reduce_output is not None:
|
||||
@ -690,7 +694,7 @@ def _get_all_gather_input_metadatas(
|
||||
|
||||
|
||||
def _get_gradient_divide_factors(
|
||||
reduce_scatter_group: dist.ProcessGroup,
|
||||
reduce_scatter_group: Optional[dist.ProcessGroup],
|
||||
all_reduce_group: Optional[dist.ProcessGroup],
|
||||
reduce_dtype: torch.dtype,
|
||||
device_type: str = "",
|
||||
@ -709,8 +713,11 @@ def _get_gradient_divide_factors(
|
||||
# For fp32/bf16, we do not need to worry about overflow/underflow, so we
|
||||
# use NCCL's built-in division to avoid separate div kernels
|
||||
overflow_risk = reduce_dtype not in (torch.float32, torch.bfloat16)
|
||||
if reduce_scatter_group is not None:
|
||||
data_parallel_size = reduce_scatter_group.size()
|
||||
else:
|
||||
data_parallel_size = 1
|
||||
|
||||
data_parallel_size = reduce_scatter_group.size()
|
||||
if all_reduce_group is not None:
|
||||
data_parallel_size *= all_reduce_group.size()
|
||||
|
||||
|
||||
@ -11,6 +11,7 @@ import torch.nn as nn
|
||||
from torch._prims_common import make_contiguous_strides_for
|
||||
from torch.distributed._functional_collectives import AsyncCollectiveTensor
|
||||
from torch.distributed.device_mesh import DeviceMesh
|
||||
from torch.distributed.fsdp._fully_shard._fsdp_common import DDPMeshInfo
|
||||
from torch.distributed.tensor import DTensor, Replicate, Shard
|
||||
from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta
|
||||
from torch.distributed.tensor.placement_types import _StridedShard, Placement
|
||||
@ -306,22 +307,29 @@ class FSDPParam:
|
||||
f"or 4 (HSDP+EP+TP) but got {self._spmd_mesh.ndim}."
|
||||
)
|
||||
self._spmd_placements: tuple[Placement, ...]
|
||||
dp_shard_tp_placement = (
|
||||
(
|
||||
_StridedShard(shard_dim, split_factor=split_factor)
|
||||
if split_factor > 1
|
||||
else fsdp_placement
|
||||
),
|
||||
*self._tp_spec.placements,
|
||||
)
|
||||
if dp_mesh.ndim == 1: # FSDP
|
||||
self._spmd_placements = dp_shard_tp_placement
|
||||
else: # HSDP
|
||||
if isinstance(self.mesh_info, FSDPMeshInfo): # FSDP or HSDP
|
||||
dp_shard_tp_placement = (
|
||||
(
|
||||
_StridedShard(shard_dim, split_factor=split_factor)
|
||||
if split_factor > 1
|
||||
else fsdp_placement
|
||||
),
|
||||
*self._tp_spec.placements,
|
||||
)
|
||||
else: # DDP
|
||||
dp_shard_tp_placement = (
|
||||
(Replicate()),
|
||||
*self._tp_spec.placements,
|
||||
)
|
||||
if isinstance(self.mesh_info, HSDPMeshInfo): # HSDP
|
||||
if self.mesh_info.replicate_mesh_dim != 0:
|
||||
raise AssertionError(
|
||||
f"Expected replicate_mesh_dim to be 0, got {self.mesh_info.replicate_mesh_dim}"
|
||||
)
|
||||
self._spmd_placements = (Replicate(),) + dp_shard_tp_placement
|
||||
else: # FSDP or DDP
|
||||
self._spmd_placements = dp_shard_tp_placement
|
||||
|
||||
self._sharding_spec = DTensorSpec(
|
||||
self._spmd_mesh,
|
||||
self._spmd_placements,
|
||||
@ -330,10 +338,12 @@ class FSDPParam:
|
||||
param_data = cast(DTensor, param)._local_tensor
|
||||
else:
|
||||
self._spmd_mesh = self.mesh_info.mesh
|
||||
if isinstance(self.mesh_info, HSDPMeshInfo):
|
||||
if isinstance(self.mesh_info, HSDPMeshInfo): # HSDP
|
||||
self._spmd_placements = (Replicate(), fsdp_placement)
|
||||
else:
|
||||
elif isinstance(self.mesh_info, FSDPMeshInfo): # FSDP
|
||||
self._spmd_placements = (fsdp_placement,)
|
||||
elif isinstance(self.mesh_info, DDPMeshInfo): # DDP
|
||||
self._spmd_placements = (Replicate(),)
|
||||
self._sharding_spec = DTensorSpec(
|
||||
self._spmd_mesh,
|
||||
self._spmd_placements,
|
||||
@ -351,8 +361,13 @@ class FSDPParam:
|
||||
)
|
||||
self._orig_size = param_data.size()
|
||||
self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size)
|
||||
shard_rank = self.mesh_info.shard_mesh_rank
|
||||
shard_world_size = self.mesh_info.shard_mesh_size
|
||||
if isinstance(self.mesh_info, FSDPMeshInfo): # FSDP or HSDP
|
||||
shard_rank = self.mesh_info.shard_mesh_rank
|
||||
shard_world_size = self.mesh_info.shard_mesh_size
|
||||
else: # DDP
|
||||
shard_rank = 0
|
||||
shard_world_size = 1
|
||||
|
||||
if shard_dim > 0 and param_data.size(shard_dim) % shard_world_size != 0:
|
||||
# If sharding on nonzero dim, require even sharding for now because
|
||||
# the uneven sharding (1) requires extra copies before/after FSDP
|
||||
@ -401,12 +416,20 @@ class FSDPParam:
|
||||
if mesh_info is None:
|
||||
raise AssertionError("Expected post_forward_mesh_info to not be None")
|
||||
param_data = param._local_tensor if isinstance(param, DTensor) else param
|
||||
chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
|
||||
self.sharded_post_forward_size = _get_dim_chunked_size(
|
||||
chunks[mesh_info.shard_mesh_rank],
|
||||
param_data.size(),
|
||||
dim=self.fsdp_placement.dim,
|
||||
)
|
||||
if isinstance(mesh_info, FSDPMeshInfo):
|
||||
chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0)
|
||||
self.sharded_post_forward_size = _get_dim_chunked_size(
|
||||
chunks[mesh_info.shard_mesh_rank],
|
||||
param_data.size(),
|
||||
dim=self.fsdp_placement.dim,
|
||||
)
|
||||
else: # DDP
|
||||
chunks = _chunk_with_empty(param_data, 1, dim=0)
|
||||
self.sharded_post_forward_size = _get_dim_chunked_size(
|
||||
chunks[0],
|
||||
param_data.size(),
|
||||
dim=self.fsdp_placement.dim,
|
||||
)
|
||||
self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for(
|
||||
self.sharded_post_forward_size
|
||||
)
|
||||
|
||||
@ -29,6 +29,7 @@ from ._fsdp_collectives import (
|
||||
)
|
||||
from ._fsdp_common import (
|
||||
compiled_autograd_enabled,
|
||||
DDPMeshInfo,
|
||||
FSDPMeshInfo,
|
||||
HSDPMeshInfo,
|
||||
is_bw,
|
||||
@ -315,7 +316,10 @@ class FSDPParamGroup:
|
||||
self._wait_all_gather_streams_on_event(self._reshard_after_forward_event)
|
||||
self._reshard_after_forward_event = None
|
||||
|
||||
world_size = self._all_gather_process_group.size()
|
||||
if isinstance(self.mesh_info, FSDPMeshInfo):
|
||||
world_size = self._all_gather_process_group.size()
|
||||
else:
|
||||
world_size = 1
|
||||
if world_size == 1:
|
||||
# can't skip due to early return in wait_for_unshard if
|
||||
# no self._all_gather_result
|
||||
@ -356,7 +360,10 @@ class FSDPParamGroup:
|
||||
if prev_all_gather_state := self.comm_ctx.all_gather_state:
|
||||
self._wait_all_gather_streams_on_event(prev_all_gather_state.event)
|
||||
self.comm_ctx.all_gather_state = None # free the all-gather result
|
||||
world_size = self._all_gather_process_group.size()
|
||||
if isinstance(self.mesh_info, FSDPMeshInfo):
|
||||
world_size = self._all_gather_process_group.size()
|
||||
else:
|
||||
world_size = 1
|
||||
if world_size == 1:
|
||||
# directly initialize unsharded parameters from sharded parameters
|
||||
|
||||
@ -531,7 +538,11 @@ class FSDPParamGroup:
|
||||
self.comm_ctx.reduce_scatter_state.event
|
||||
)
|
||||
self.comm_ctx.reduce_scatter_state = None
|
||||
all_reduce_pg = self._all_reduce_process_group if self._is_hsdp else None
|
||||
all_reduce_pg = (
|
||||
self._all_reduce_process_group
|
||||
if isinstance(self.mesh_info, DDPMeshInfo)
|
||||
else None
|
||||
)
|
||||
all_reduce_stream: torch.cuda.Stream
|
||||
if all_reduce_pg is None and self._all_reduce_hook_stream is not None:
|
||||
# this means the native HSDP is not enabled,
|
||||
@ -555,14 +566,22 @@ class FSDPParamGroup:
|
||||
) = foreach_reduce(
|
||||
fsdp_params_with_grad,
|
||||
unsharded_grads,
|
||||
self._reduce_scatter_process_group,
|
||||
(
|
||||
self._reduce_scatter_process_group
|
||||
if isinstance(self.mesh_info, FSDPMeshInfo)
|
||||
else None
|
||||
),
|
||||
self.comm_ctx.reduce_scatter_stream,
|
||||
self._reduce_scatter_comm,
|
||||
self._orig_dtype,
|
||||
self._reduce_dtype,
|
||||
self.device,
|
||||
self.gradient_divide_factor,
|
||||
self._all_reduce_process_group if self._is_hsdp else None,
|
||||
(
|
||||
self._all_reduce_process_group
|
||||
if isinstance(self.mesh_info, DDPMeshInfo)
|
||||
else None
|
||||
),
|
||||
all_reduce_stream,
|
||||
self.all_reduce_grads,
|
||||
self._partial_reduce_output,
|
||||
@ -776,9 +795,9 @@ class FSDPParamGroup:
|
||||
|
||||
@property
|
||||
def _all_reduce_process_group(self) -> dist.ProcessGroup:
|
||||
if not isinstance(self.mesh_info, HSDPMeshInfo):
|
||||
if not isinstance(self.mesh_info, DDPMeshInfo):
|
||||
raise AssertionError(
|
||||
f"Expected mesh_info to be HSDPMeshInfo, got {type(self.mesh_info)}"
|
||||
f"Expected mesh_info to be DDPMeshInfo or HSDPMeshInfo, got {type(self.mesh_info)}"
|
||||
)
|
||||
return self.mesh_info.replicate_process_group
|
||||
|
||||
|
||||
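Taken together, the param-group changes route collectives by mesh-info type: only sharded setups (FSDP or HSDP, `FSDPMeshInfo`) get a reduce-scatter group, and anything with a replicate dimension (DDP or HSDP, `DDPMeshInfo`) gets the all-reduce group. A condensed sketch of that selection, assuming `HSDPMeshInfo` satisfies both isinstance checks as the diff's error message suggests:

```python
# Sketch only: how the updated FSDPParamGroup picks process groups by mesh type.
from torch.distributed.fsdp._fully_shard._fsdp_common import DDPMeshInfo, FSDPMeshInfo

def select_comm_groups(mesh_info, reduce_scatter_pg, all_reduce_pg):
    # Reduce-scatter only when there is a shard dim (FSDP or HSDP);
    # all-reduce whenever there is a replicate dim (DDP or HSDP).
    rs = reduce_scatter_pg if isinstance(mesh_info, FSDPMeshInfo) else None
    ar = all_reduce_pg if isinstance(mesh_info, DDPMeshInfo) else None
    return rs, ar
```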