[FSDP] Use post_reduce_stream.record_event() on hsdp+cpuoffload (#160481)

Fixes https://github.com/pytorch/pytorch/issues/160291
Under HSDP, `post_reduce_stream` is `all_reduce_stream`, but the CPU-GPU sync event was hard-coded to `reduce_scatter_stream`. That hard-coding can make the HSDP + CPU offload case fail, so record the event on `post_reduce_stream` instead and add a unit test covering HSDP + CPU offload.
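
For context, the synchronization pattern involved: the reduced gradient's D2H copy is issued on the post-reduce stream, and the event that later blocks the CPU thread must be recorded on that same stream. Below is a minimal sketch of this pattern, not FSDP's internal code; the stream names, the `is_hsdp` flag, and the tensor sizes are illustrative, and it assumes a CUDA device is available.

```python
import torch

# Illustrative streams; under HSDP the post-reduce work runs on the
# all-reduce stream rather than the reduce-scatter stream.
reduce_scatter_stream = torch.cuda.Stream()
all_reduce_stream = torch.cuda.Stream()
is_hsdp = True  # assumption for this sketch
post_reduce_stream = all_reduce_stream if is_hsdp else reduce_scatter_stream

grad_gpu = torch.randn(1024, device="cuda")
grad_cpu = torch.empty(1024, pin_memory=True)

with torch.cuda.stream(post_reduce_stream):
    # Async D2H copy of the (already reduced) gradient.
    grad_cpu.copy_(grad_gpu, non_blocking=True)
    # Record the event on the stream that issued the copy; recording it on
    # reduce_scatter_stream would not cover this copy when is_hsdp is True.
    grad_offload_event = post_reduce_stream.record_event()

grad_offload_event.synchronize()  # CPU may now safely read grad_cpu
```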

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160481
Approved by: https://github.com/weifengpy
Author: mori360
Date: 2025-08-19 02:20:12 +00:00
Committed by: PyTorch MergeBot
Parent: 3d126e17e0
Commit: e6e45e6ae8
2 changed files with 23 additions and 12 deletions


@@ -335,7 +335,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         self.run_subtests(
             {
                 "reshard_after_forward": [True, False, 2],
-                "device_type": [device_type.type],
+                "test_device_type": [device_type.type],
                 "offload_policy": [OffloadPolicy()],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
@@ -360,7 +360,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
                     CPUOffloadPolicy(pin_memory=True),
                     CPUOffloadPolicy(pin_memory=False),
                 ],
-                "device_type": [device_type.type],
+                "test_device_type": [device_type.type],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
                 "delay_before_reduce_scatter": [False, True],
@@ -381,7 +381,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         self.run_subtests(
             {
                 "reshard_after_forward": [True],
-                "device_type": [device_type.type],
+                "test_device_type": [device_type.type],
                 "offload_policy": [OffloadPolicy()],
                 "delay_after_forward": [False, True],
                 "delay_before_all_gather": [False, True],
@@ -396,7 +396,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         self,
         reshard_after_forward: Union[bool, int],
         offload_policy: OffloadPolicy,
-        device_type: str,
+        test_device_type: str,
         delay_after_forward: bool,
         delay_before_all_gather: bool,
         delay_before_reduce_scatter: bool,
@@ -412,7 +412,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
             in (2, 3)
         ):
             return
-        assert device_type in ("cuda", "hpu", "xpu", "cpu"), f"{device_type}"
+        assert test_device_type in ("cuda", "hpu", "xpu", "cpu"), f"{test_device_type}"
         torch.manual_seed(42)
         vocab_size = 1024
         model_args = ModelArgs(
@@ -424,7 +424,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
         )
         model = Transformer(model_args)
         ref_model = copy.deepcopy(model)
-        if device_type == device_type:
+        if test_device_type == device_type.type:
             replicate(
                 ref_model.to(device_type),
                 device_ids=[self.rank],
@@ -433,7 +433,7 @@ class TestFullyShard1DTrainingCore(FSDPTest):
             gloo_pg = dist.new_group(backend="gloo")
             replicate(ref_model, process_group=gloo_pg)
         ref_optim = torch.optim.Adam(ref_model.parameters(), lr=1e-2)
-        mesh = init_device_mesh(device_type, (self.world_size,))
+        mesh = init_device_mesh(test_device_type, (self.world_size,))
         fully_shard_fn = functools.partial(
             fully_shard,
             mesh=mesh,
@@ -483,12 +483,12 @@ class TestFullyShard1DTrainingCore(FSDPTest):
                 _optim.zero_grad(set_to_none=(iter_idx % 2 == 0))
                 losses.append(_model(inp).sum())
                 if _model is model and delay_after_forward:
-                    torch.get_device_module(device_type)._sleep(
+                    torch.get_device_module(test_device_type)._sleep(
                         int(delay_in_ms * get_cycles_per_ms())
                     )
                 losses[-1].backward()
                 if _model is model and delay_before_optim:
-                    torch.get_device_module(device_type)._sleep(
+                    torch.get_device_module(test_device_type)._sleep(
                         int(delay_in_ms * get_cycles_per_ms())
                     )
                 _optim.step()
@@ -1360,6 +1360,10 @@ class TestFullyShardHSDPTraining(FSDPTest):
                 "use_activation_checkpointing": [False, True],
                 "mlp_dim": [3, 16, 17],
                 "sync_gradients_at_last_batch": [True, False],
+                "offload_policy": [
+                    CPUOffloadPolicy(pin_memory=True),
+                    CPUOffloadPolicy(pin_memory=False),
+                ],
             },
             functools.partial(self._test_train_parity_hsdp, global_mesh),
         )
@@ -1371,6 +1375,7 @@
         use_activation_checkpointing: bool,
         mlp_dim: int,
         sync_gradients_at_last_batch: bool,
+        offload_policy: CPUOffloadPolicy,
     ):
         torch.manual_seed(42)
         model = nn.Sequential(
@@ -1389,10 +1394,16 @@
             if use_activation_checkpointing:
                 checkpoint(mlp)
             fully_shard(
-                mlp, mesh=global_mesh, reshard_after_forward=reshard_after_forward
+                mlp,
+                mesh=global_mesh,
+                reshard_after_forward=reshard_after_forward,
+                offload_policy=offload_policy,
             )
         fully_shard(
-            model, mesh=global_mesh, reshard_after_forward=reshard_after_forward
+            model,
+            mesh=global_mesh,
+            reshard_after_forward=reshard_after_forward,
+            offload_policy=offload_policy,
         )
         optim = torch.optim.Adam(model.parameters(), lr=1e-2)
         check_sharded_parity(self, ref_model, model)
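
The HSDP parity test above now also sweeps `offload_policy`, so the regression is exercised with CPU offload on a 2D mesh. A minimal sketch of that kind of setup follows, assuming a 4-GPU `torchrun` launch; the mesh shape, dimension names, model, and import paths are illustrative, not the test's exact code.

```python
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard

# 2D mesh: replicate across one dim, shard across the other (HSDP).
global_mesh = init_device_mesh(
    "cuda", (2, 2), mesh_dim_names=("dp_replicate", "dp_shard")
)

model = nn.Sequential(*(nn.Linear(16, 16) for _ in range(3)))
offload_policy = CPUOffloadPolicy(pin_memory=True)
for layer in model:
    # Shard each layer; gradients are reduced across the mesh and then
    # offloaded to (pinned) CPU memory by the post-reduce work.
    fully_shard(layer, mesh=global_mesh, offload_policy=offload_policy)
fully_shard(model, mesh=global_mesh, offload_policy=offload_policy)
```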


@@ -628,7 +628,7 @@ def foreach_reduce(
            if non_blocking:
                # Record an event on which to block the CPU thread to
                # ensure that the D2H copy finishes before the optimizer
-               fsdp_param.grad_offload_event = reduce_scatter_stream.record_event()
+               fsdp_param.grad_offload_event = post_reduce_stream.record_event()
            if to_accumulate_grad:
                assert isinstance(fsdp_param.sharded_param.grad, DTensor)
                fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad