DTensor: add comm tests to test_tp_examples (#121669)
This adds some basic comm tests to test_tp_examples. This validates that the expected distributed calls are being made for `test_transformer_training`.

Fixes #121649

Test plan:
```
pytest test/distributed/tensor/parallel/test_tp_examples.py -k test_transformer_training
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121669
Approved by: https://github.com/wanchaol
Committed by: PyTorch MergeBot
Parent: 02083f5452
Commit: e4fda049c2
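For context on the mechanism being tested, here is a minimal standalone sketch of counting collectives with `CommDebugMode`, in the same spirit as the diff below. It is not part of this PR: it assumes a 4-rank job (e.g. launched with `torchrun --nproc_per_node=4`) with one GPU per rank, and uses import paths as of this change; newer releases may expose `CommDebugMode` under a different module.

```python
# Minimal sketch, not part of this PR: count the collectives a DTensor op issues.
# Assumes 4 ranks (e.g. `torchrun --nproc_per_node=4`), one GPU each.
import torch
from torch.distributed._tensor import Shard, distribute_tensor
from torch.distributed._tensor.debug import CommDebugMode
from torch.distributed.device_mesh import init_device_mesh

c10d_functional = torch.ops.c10d_functional

mesh = init_device_mesh("cuda", (4,))  # use "cpu" with a gloo backend if no GPUs
sharded = distribute_tensor(torch.randn(8, 8), mesh, [Shard(0)])

with CommDebugMode() as comm_mode:
    # Materializing the sharded tensor on every rank requires one all-gather.
    sharded.full_tensor()

# get_comm_counts() maps functional collective ops to how often they ran,
# which is what the test below asserts against with assertDictEqual.
counts = comm_mode.get_comm_counts()
assert counts[c10d_functional.all_gather_into_tensor] == 1
```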
```diff
@@ -210,18 +210,48 @@ class DistTensorParallelExampleTest(DTensorTestBase):
         # Compare outputs on the same input.
         output = model(inp)
-        output_tp = model_tp(inp)
+        with CommDebugMode() as comm_mode:
+            output_tp = model_tp(inp)
         self.assertEqual(output, output_tp)
+        if is_seq_parallel:
+            self.assertDictEqual(comm_mode.get_comm_counts(), {
+                c10d_functional.all_reduce: 1,
+                c10d_functional.reduce_scatter_tensor: 4,
+                c10d_functional.all_gather_into_tensor: 7,
+            })
+        else:
+            self.assertDictEqual(comm_mode.get_comm_counts(), {
+                c10d_functional.all_reduce: 5,
+                c10d_functional.all_gather_into_tensor: 2,
+            })

         # Ensure gradients are equal.
         output.sum().backward()
-        output_tp.sum().backward()
+        with CommDebugMode() as comm_mode:
+            output_tp.sum().backward()
         self._check_module(model, model_tp, check_grad=True)
+        if is_seq_parallel:
+            self.assertDictEqual(comm_mode.get_comm_counts(), {
+                c10d_functional.reduce_scatter_tensor: 4,
+                c10d_functional.all_gather_into_tensor: 7,
+            })
+        else:
+            self.assertDictEqual(comm_mode.get_comm_counts(), {
+                c10d_functional.all_reduce: 8,
+                c10d_functional.all_gather_into_tensor: 1,
+            })

         # Ensure model weights are still the same after update.
         optim.step()
-        optim_tp.step()
+        with CommDebugMode() as comm_mode:
+            optim_tp.step()
         self._check_module(model, model_tp)
+        if is_seq_parallel:
+            self.assertDictEqual(comm_mode.get_comm_counts(), {
+                c10d_functional.all_reduce: 30,
+            })
+        else:
+            self.assertDictEqual(comm_mode.get_comm_counts(), {})

         # Compare outputs on another input.
         torch.manual_seed(11)
```
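The differing expectations roughly reflect the two parallel styles: with `is_seq_parallel`, activations stay sharded along the sequence dimension, so forward and backward largely use reduce-scatter/all-gather pairs where plain tensor parallelism would issue all-reduces.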