further deprecate PairwiseParallel and SequenceParallel from test (#114402)

**Remaining Issue**
When replace SequenceParallel, tests would pass even setting `input_layouts=Replicate()`. Still looking into it...

**Summary**
This is a follow-up PR to #114314.

**Test Plan**
`python test_files.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114402
Approved by: https://github.com/wanchaol
This commit is contained in:
Tianyu Liu
2023-11-29 15:41:52 -08:00
committed by PyTorch MergeBot
parent c1e51fcbfc
commit 8ae3835323
6 changed files with 81 additions and 24 deletions

View File

@ -14,7 +14,11 @@ from torch.distributed._tensor import (
init_device_mesh,
)
from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module
from torch.distributed.tensor.parallel import (
ColwiseParallel,
parallelize_module,
RowwiseParallel,
)
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
@ -109,13 +113,17 @@ class DTensorTest(DTensorTestBase):
def test_modules_w_meta_dtensor(self):
model = DummyMLP("meta")
device_mesh = self.build_device_mesh()
model_tp = parallelize_module(model, device_mesh, PairwiseParallel())
parallelize_plan = {
"net1": ColwiseParallel(),
"net2": RowwiseParallel(),
}
model_tp = parallelize_module(model, device_mesh, parallelize_plan)
model_tp.to_empty(device=self.device_type)
model_tp.reset_parameters()
optim = torch.optim.SGD(model_tp.parameters(), lr=0.1)
model_regular = DummyMLP(self.device_type)
model_regular_tp = parallelize_module(
model_regular, device_mesh, PairwiseParallel()
model_regular, device_mesh, parallelize_plan
)
optim_regular = torch.optim.SGD(model_regular_tp.parameters(), lr=0.1)
model_regular_tp.reset_parameters()