mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[DDP][Compile] Test to Ensure torch.compile works w/static_graph=True (#114621)
Resolves https://github.com/pytorch/pytorch/issues/93672. This was actually fixed by https://github.com/pytorch/pytorch/pull/103487 but I didn't realize that PR also fixes torch compile at the time. Differential Revision: [D51596148](https://our.internmc.facebook.com/intern/diff/D51596148/) Pull Request resolved: https://github.com/pytorch/pytorch/pull/114621 Approved by: https://github.com/wconstab
This commit is contained in:
committed by
PyTorch MergeBot
parent
6e495eef60
commit
3c78ea4c9d
@ -10072,5 +10072,37 @@ class DistributedTest:
|
||||
model, device_mesh=device_mesh
|
||||
)
|
||||
|
||||
@skip_if_lt_x_gpu(2)
|
||||
@require_world_size(2)
|
||||
@skip_but_pass_in_sandcastle_if(
|
||||
BACKEND not in DistTestCases.backend_feature["ddp"],
|
||||
f"The {BACKEND} backend does not support DistributedDataParallel",
|
||||
)
|
||||
def test_ddp_compile_static_graph(self):
|
||||
"Tests that DDP works with torch compile when static_graph=True"
|
||||
model = torch.nn.Linear(10, 10).cuda(self.rank)
|
||||
model_clone = copy.deepcopy(model)
|
||||
ddp = torch.nn.parallel.DistributedDataParallel(
|
||||
model,
|
||||
device_ids=[self.rank],
|
||||
)
|
||||
ddp_static = torch.nn.parallel.DistributedDataParallel(
|
||||
model_clone,
|
||||
device_ids=[self.rank],
|
||||
static_graph=True
|
||||
)
|
||||
ddp = torch.compile(ddp)
|
||||
ddp_static = torch.compile(ddp_static)
|
||||
input = torch.rand(10, 10).cuda(self.rank)
|
||||
# verify output and gradient parity
|
||||
for _ in range(6):
|
||||
out_ddp = ddp(input).sum()
|
||||
out_ddp_static = ddp_static(input).sum()
|
||||
self.assertEqual(out_ddp, out_ddp_static)
|
||||
out_ddp.backward()
|
||||
out_ddp_static.backward()
|
||||
for p1, p2 in zip(ddp.parameters(), ddp_static.parameters()):
|
||||
self.assertEqual(p1.grad, p2.grad)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(DistributedTest._DistTestBase)
|
||||
|
Reference in New Issue
Block a user