mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
further deprecate PairwiseParallel and SequenceParallel from test (#114402)
**Remaining Issue** When replace SequenceParallel, tests would pass even setting `input_layouts=Replicate()`. Still looking into it... **Summary** This is a follow-up PR to #114314. **Test Plan** `python test_files.py` Pull Request resolved: https://github.com/pytorch/pytorch/pull/114402 Approved by: https://github.com/wanchaol
This commit is contained in:
committed by
PyTorch MergeBot
parent
c1e51fcbfc
commit
8ae3835323
@ -14,7 +14,11 @@ from torch.distributed._tensor import (
|
||||
init_device_mesh,
|
||||
)
|
||||
from torch.distributed._tensor.placement_types import _Partial, Replicate, Shard
|
||||
from torch.distributed.tensor.parallel import PairwiseParallel, parallelize_module
|
||||
from torch.distributed.tensor.parallel import (
|
||||
ColwiseParallel,
|
||||
parallelize_module,
|
||||
RowwiseParallel,
|
||||
)
|
||||
|
||||
from torch.testing._internal.common_utils import run_tests
|
||||
from torch.testing._internal.distributed._tensor.common_dtensor import (
|
||||
@ -109,13 +113,17 @@ class DTensorTest(DTensorTestBase):
|
||||
def test_modules_w_meta_dtensor(self):
|
||||
model = DummyMLP("meta")
|
||||
device_mesh = self.build_device_mesh()
|
||||
model_tp = parallelize_module(model, device_mesh, PairwiseParallel())
|
||||
parallelize_plan = {
|
||||
"net1": ColwiseParallel(),
|
||||
"net2": RowwiseParallel(),
|
||||
}
|
||||
model_tp = parallelize_module(model, device_mesh, parallelize_plan)
|
||||
model_tp.to_empty(device=self.device_type)
|
||||
model_tp.reset_parameters()
|
||||
optim = torch.optim.SGD(model_tp.parameters(), lr=0.1)
|
||||
model_regular = DummyMLP(self.device_type)
|
||||
model_regular_tp = parallelize_module(
|
||||
model_regular, device_mesh, PairwiseParallel()
|
||||
model_regular, device_mesh, parallelize_plan
|
||||
)
|
||||
optim_regular = torch.optim.SGD(model_regular_tp.parameters(), lr=0.1)
|
||||
model_regular_tp.reset_parameters()
|
||||
|
Reference in New Issue
Block a user