DTensor: add comm tests to test_tp_examples (#121669)

This adds some basic comm tests to test_tp_examples. This validates that the expected distributed calls are being made for `test_transformer_training`.

Fixes #121649

Test plan:

```
pytest test/distributed/tensor/parallel/test_tp_examples.py -k test_transformer_training
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121669
Approved by: https://github.com/wanchaol
This commit is contained in:
Tristan Rice
2024-03-15 03:37:45 +00:00
committed by PyTorch MergeBot
parent 02083f5452
commit e4fda049c2

View File

@ -210,18 +210,48 @@ class DistTensorParallelExampleTest(DTensorTestBase):
# Compare outputs on the same input.
output = model(inp)
output_tp = model_tp(inp)
with CommDebugMode() as comm_mode:
output_tp = model_tp(inp)
self.assertEqual(output, output_tp)
if is_seq_parallel:
self.assertDictEqual(comm_mode.get_comm_counts(), {
c10d_functional.all_reduce: 1,
c10d_functional.reduce_scatter_tensor: 4,
c10d_functional.all_gather_into_tensor: 7,
})
else:
self.assertDictEqual(comm_mode.get_comm_counts(), {
c10d_functional.all_reduce: 5,
c10d_functional.all_gather_into_tensor: 2,
})
# Ensure gradients are equal.
output.sum().backward()
output_tp.sum().backward()
with CommDebugMode() as comm_mode:
output_tp.sum().backward()
self._check_module(model, model_tp, check_grad=True)
if is_seq_parallel:
self.assertDictEqual(comm_mode.get_comm_counts(), {
c10d_functional.reduce_scatter_tensor: 4,
c10d_functional.all_gather_into_tensor: 7,
})
else:
self.assertDictEqual(comm_mode.get_comm_counts(), {
c10d_functional.all_reduce: 8,
c10d_functional.all_gather_into_tensor: 1,
})
# Ensure model weights are still the same after update.
optim.step()
optim_tp.step()
with CommDebugMode() as comm_mode:
optim_tp.step()
self._check_module(model, model_tp)
if is_seq_parallel:
self.assertDictEqual(comm_mode.get_comm_counts(), {
c10d_functional.all_reduce: 30,
})
else:
self.assertDictEqual(comm_mode.get_comm_counts(), {})
# Compare outputs on another input.
torch.manual_seed(11)