[DDP][Compile] Test to Ensure torch.compile works w/static_graph=True (#114621)

Resolves https://github.com/pytorch/pytorch/issues/93672. The issue was
actually fixed by https://github.com/pytorch/pytorch/pull/103487, but I didn't
realize at the time that that PR also fixed torch.compile.
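For reference, a minimal sketch (not part of this commit) of the usage pattern the new test exercises: wrap the module in DDP with static_graph=True, then pass the wrapper through torch.compile. The helper name, rank handling, and process-group setup below are assumptions about the caller's environment, not code from this PR.

import torch
from torch.nn.parallel import DistributedDataParallel as DDP

def build_compiled_static_ddp(rank: int) -> torch.nn.Module:
    # Assumes torch.distributed.init_process_group() has already been called
    # and that `rank` maps to a local CUDA device.
    model = torch.nn.Linear(10, 10).cuda(rank)
    ddp_model = DDP(model, device_ids=[rank], static_graph=True)
    # static_graph=True tells the DDP reducer that the autograd graph is the
    # same every iteration; torch.compile is applied to the DDP wrapper.
    return torch.compile(ddp_model)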

Differential Revision: [D51596148](https://our.internmc.facebook.com/intern/diff/D51596148/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114621
Approved by: https://github.com/wconstab
Rohan Varma
2023-11-30 17:29:02 -08:00
committed by PyTorch MergeBot
parent 6e495eef60
commit 3c78ea4c9d
3 changed files with 57 additions and 9 deletions


@@ -10072,5 +10072,37 @@ class DistributedTest:
                model, device_mesh=device_mesh
            )

        @skip_if_lt_x_gpu(2)
        @require_world_size(2)
        @skip_but_pass_in_sandcastle_if(
            BACKEND not in DistTestCases.backend_feature["ddp"],
            f"The {BACKEND} backend does not support DistributedDataParallel",
        )
        def test_ddp_compile_static_graph(self):
            "Tests that DDP works with torch compile when static_graph=True"
            model = torch.nn.Linear(10, 10).cuda(self.rank)
            model_clone = copy.deepcopy(model)
            ddp = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.rank],
            )
            ddp_static = torch.nn.parallel.DistributedDataParallel(
                model_clone,
                device_ids=[self.rank],
                static_graph=True
            )
            ddp = torch.compile(ddp)
            ddp_static = torch.compile(ddp_static)
            input = torch.rand(10, 10).cuda(self.rank)
            # verify output and gradient parity
            for _ in range(6):
                out_ddp = ddp(input).sum()
                out_ddp_static = ddp_static(input).sum()
                self.assertEqual(out_ddp, out_ddp_static)
                out_ddp.backward()
                out_ddp_static.backward()
                for p1, p2 in zip(ddp.parameters(), ddp_static.parameters()):
                    self.assertEqual(p1.grad, p2.grad)

instantiate_parametrized_tests(DistributedTest._DistTestBase)