[FSDP] Use post_reduce_stream.record_event() on hsdp+cpuoffload (#160481)
Fixes https://github.com/pytorch/pytorch/issues/160291

`post_reduce_stream` is `all_reduce_stream` during HSDP, but the CPU-GPU sync was hard-coded to `reduce_scatter_stream`. This hard-coding could fail a unit test on HSDP + CPU offload, so a unit test for that case is added here.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160481
Approved by: https://github.com/weifengpy
Committed by: PyTorch MergeBot
Parent: 3d126e17e0
Commit: e6e45e6ae8
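The gist of the fix can be illustrated outside of the FSDP internals. Below is a minimal, hypothetical sketch (standalone names, assuming a CUDA build and device; it is not the PyTorch-internal code) of why the D2H offload event must be recorded on the stream that actually produces the reduced gradient: under HSDP that is the all-reduce stream, so an event recorded on `reduce_scatter_stream` would not guarantee the copy has finished when the CPU later blocks on the event.

```python
# Hypothetical standalone sketch, assuming a CUDA device; names mirror the
# commit message, but this is not the FSDP implementation.
import torch

assert torch.cuda.is_available()

reduce_scatter_stream = torch.cuda.Stream()
all_reduce_stream = torch.cuda.Stream()
# Under HSDP the post-reduce work (and the D2H offload copy) runs on the
# all-reduce stream, i.e. post_reduce_stream != reduce_scatter_stream.
post_reduce_stream = all_reduce_stream

grad_gpu = torch.randn(1 << 20, device="cuda")
grad_cpu = torch.empty(grad_gpu.shape, dtype=grad_gpu.dtype, pin_memory=True)

post_reduce_stream.wait_stream(torch.cuda.current_stream())  # order after grad production
with torch.cuda.stream(post_reduce_stream):
    grad_cpu.copy_(grad_gpu, non_blocking=True)  # async D2H offload
    # The fix: record the event on post_reduce_stream. An event recorded on
    # reduce_scatter_stream would not be ordered after this copy under HSDP.
    grad_offload_event = post_reduce_stream.record_event()

grad_offload_event.synchronize()  # CPU blocks until the D2H copy is done
assert torch.equal(grad_cpu, grad_gpu.cpu())
```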
@@ -335,7 +335,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         self.run_subtests(
             {
                 "reshard_after_forward": [True, False, 2],
-                "device_type": [device_type.type],
+                "test_device_type": [device_type.type],
                 "offload_policy": [OffloadPolicy()],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
@@ -360,7 +360,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
                     CPUOffloadPolicy(pin_memory=True),
                     CPUOffloadPolicy(pin_memory=False),
                 ],
-                "device_type": [device_type.type],
+                "test_device_type": [device_type.type],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
                 "delay_before_reduce_scatter": [False, True],
@@ -381,7 +381,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         self.run_subtests(
             {
                 "reshard_after_forward": [True],
-                "device_type": [device_type.type],
+                "test_device_type": [device_type.type],
                 "offload_policy": [OffloadPolicy()],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
@@ -396,7 +396,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         self,
         reshard_after_forward: Union[bool, int],
         offload_policy: OffloadPolicy,
-        device_type: str,
+        test_device_type: str,
         delay_after_forward: bool,
         delay_before_all_gather: bool,
         delay_before_reduce_scatter: bool,
@@ -412,7 +412,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
             in (2, 3)
         ):
             return
-        assert device_type in ("cuda", "hpu", "xpu", "cpu"), f"{device_type}"
+        assert test_device_type in ("cuda", "hpu", "xpu", "cpu"), f"{test_device_type}"
         torch.manual_seed(42)
         vocab_size = 1024
         model_args = ModelArgs(
@@ -424,7 +424,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         )
         model = Transformer(model_args)
         ref_model = copy.deepcopy(model)
-        if device_type == device_type:
+        if test_device_type == device_type.type:
             replicate(
                 ref_model.to(device_type),
                 device_ids=[self.rank],
@@ -433,7 +433,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
             gloo_pg = dist.new_group(backend="gloo")
             replicate(ref_model, process_group=gloo_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
-        mesh = init_device_mesh(device_type, (self.world_size,))
+        mesh = init_device_mesh(test_device_type, (self.world_size,))
         fully_shard_fn = functools.partial(
             fully_shard,
             mesh=mesh,
@@ -483,12 +483,12 @@ class TestFullyShard1DTrainingCore(FSDPTest):
                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                 losses.append(_model(inp).sum())
                 if _model is model and delay_after_forward:
-                    torch.get_device_module(device_type)._sleep(
+                    torch.get_device_module(test_device_type)._sleep(
                         int(delay_in_ms * get_cycles_per_ms())
                     )
                 losses[-1].backward()
                 if _model is model and delay_before_optim:
-                    torch.get_device_module(device_type)._sleep(
+                    torch.get_device_module(test_device_type)._sleep(
                         int(delay_in_ms * get_cycles_per_ms())
                     )
                 _optim.step()
@@ -1360,6 +1360,10 @@ class TestFullyShardHSDPTraining(FSDPTest):
                 "use_activation_checkpointing": [False, True],
                 "mlp_dim": [3, 16, 17],
                 "sync_gradients_at_last_batch": [True, False],
+                "offload_policy": [
+                    CPUOffloadPolicy(pin_memory=True),
+                    CPUOffloadPolicy(pin_memory=False),
+                ],
             },
             functools.partial(self._test_train_parity_hsdp, global_mesh),
         )
@@ -1371,6 +1375,7 @@ class TestFullyShardHSDPTraining(FSDPTest):
         use_activation_checkpointing: bool,
         mlp_dim: int,
         sync_gradients_at_last_batch: bool,
+        offload_policy: CPUOffloadPolicy,
     ):
         torch.manual_seed(42)
         model = nn.Sequential(
@@ -1389,10 +1394,16 @@ class TestFullyShardHSDPTraining(FSDPTest):
             if use_activation_checkpointing:
                 checkpoint(mlp)
             fully_shard(
-                mlp, mesh=global_mesh, reshard_after_forward=reshard_after_forward
+                mlp,
+                mesh=global_mesh,
+                reshard_after_forward=reshard_after_forward,
+                offload_policy=offload_policy,
             )
         fully_shard(
-            model, mesh=global_mesh, reshard_after_forward=reshard_after_forward
+            model,
+            mesh=global_mesh,
+            reshard_after_forward=reshard_after_forward,
+            offload_policy=offload_policy,
         )
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)
         check_sharded_parity(self, ref_model, model)
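For context on what the new HSDP subtests exercise, here is a hedged usage sketch (illustrative mesh sizes and module structure, not the unit test itself) of applying `fully_shard` with a 2-D HSDP mesh and `CPUOffloadPolicy`. It assumes a recent PyTorch where `fully_shard` and `CPUOffloadPolicy` are exported from `torch.distributed.fsdp`, a CUDA backend, and a process group already initialized across `replicate_size * shard_size` ranks.

```python
# Illustrative sketch of HSDP + CPU offload with fully_shard; not the unit test.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard


def apply_hsdp_cpu_offload(
    model: nn.Sequential, replicate_size: int, shard_size: int
) -> nn.Sequential:
    # Outer mesh dim replicates (gradients all-reduced), inner dim shards
    # (gradients reduce-scattered): this is the HSDP layout.
    mesh = init_device_mesh(
        "cuda", (replicate_size, shard_size), mesh_dim_names=("replicate", "shard")
    )
    offload_policy = CPUOffloadPolicy(pin_memory=True)
    for submodule in model:
        fully_shard(submodule, mesh=mesh, offload_policy=offload_policy)
    fully_shard(model, mesh=mesh, offload_policy=offload_policy)
    return model
```

With CPU offload, sharded parameters and gradients live in (optionally pinned) CPU memory; the offloaded gradient copy is what the post-reduce-stream event in the hunk below must synchronize.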
@@ -628,7 +628,7 @@ def foreach_reduce(
                 if non_blocking:
                     # Record an event on which to block the CPU thread to
                     # ensure that the D2H copy finishes before the optimizer
-                    fsdp_param.grad_offload_event = reduce_scatter_stream.record_event()
+                    fsdp_param.grad_offload_event = post_reduce_stream.record_event()
            if to_accumulate_grad:
                assert isinstance(fsdp_param.sharded_param.grad, DTensor)
                fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad