[dtensor] group all dynamo tests together (#107487)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107487
Approved by: https://github.com/fduwjj
ghstack dependencies: #107472, #107473
Wanchao Liang
2023-08-21 10:18:08 -07:00
committed by PyTorch MergeBot
parent 42f25d49f8
commit 9c2b4a35a3
3 changed files with 105 additions and 133 deletions


@@ -16,7 +16,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     with_comms,
 )
-from torch.testing._internal.distributed.fake_pg import FakeStore
 
 
 class DummyMLP(torch.nn.Module):
@@ -663,94 +662,5 @@ class TestDTensorPlacementTypes(DTensorTestBase):
                 assert_array_equal(expected_is_tensor_empty, is_tensor_empty)
 
 
-class TestDynamoDTensor(torch._dynamo.test_case.TestCase):
-    def setUp(self):
-        super().setUp()
-        fake_store = FakeStore()
-        dist.init_process_group(
-            "fake", store=fake_store, rank=0, world_size=self.world_size
-        )
-
-    def tearDown(self):
-        super().tearDown()
-        dist.destroy_process_group()
-
-    @property
-    def device_type(self) -> str:
-        return "cuda" if torch.cuda.is_available() else "cpu"
-
-    @property
-    def world_size(self) -> int:
-        return 2
-
-    def test_fakify_dtensor(self):
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-
-        # pass in DTensor as inputs/outputs to the function
-        def fn(x):
-            return x
-
-        x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False)
-        ref = fn(x)
-
-        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
-        res = opt_fn(x)
-        self.assertEqual(res, ref)
-
-    def test_dynamo_dtensor(self):
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-
-        # test passing in DTensor as inputs/outputs and run some tensor computation
-        def fn(x):
-            return x * x + 2
-
-        x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False)
-        ref = fn(x)
-
-        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
-        res = opt_fn(x)
-        self.assertEqual(res, ref)
-
-    def test_dynamo_dtensor_from_local(self):
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-
-        # create DTensor inside fn and run some compute
-        def fn(x):
-            dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)
-            return dt.to_local() + 2
-
-        # below is the op approach for reference
-        # from torch.distributed._tensor.api import _FromTorchTensor
-        # def from_local_tensor(x):
-        #     return _FromTorchTensor.apply(x, mesh, [Replicate()], False)
-        # _dt_lib_def = torch.library.Library("dtensor", "DEF")
-        # _dt_lib_def.define("from_local(Tensor self) -> Tensor")
-        # _dt_lib_impl = torch.library.Library("dtensor", "IMPL")
-        # _dt_lib_impl.impl("from_local", from_local_tensor, "Autograd")
-
-        x = torch.ones(1)
-        ref = fn(x)
-
-        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
-        res = opt_fn(x)
-        self.assertEqual(res, ref)
-
-    def test_dynamo_dtensor_from_local_redistribute(self):
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-
-        # pass in tensor as inputs/outputs, create DTensor and run redistribute
-        # (allgather collective) inside the fn
-        def fn(x):
-            dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
-            return dt.redistribute(mesh, [Replicate()]).to_local() + 2
-
-        x = torch.ones(1)
-        ref = fn(x)
-
-        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
-        res = opt_fn(x)
-        self.assertEqual(res, ref)
-
 
 if __name__ == "__main__":
     run_tests()
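
For context, the pattern the relocated tests exercise can be reproduced standalone. Below is a minimal sketch, not part of this diff, assuming only the 2023-era import paths and APIs already visible above (FakeStore, DeviceMesh, DTensor.from_local, torch.compile with the eager backend): a single-process "fake" process group lets Dynamo trace DTensor programs without spawning multiple ranks.

import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate
from torch.testing._internal.distributed.fake_pg import FakeStore

# The "fake" backend satisfies process-group bookkeeping on a single process,
# so a world_size > 1 mesh can be constructed without launching extra ranks.
dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=2)
mesh = DeviceMesh("cpu", torch.arange(2))

def fn(x):
    # Wrap a plain tensor into a replicated DTensor, then unwrap it.
    dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)
    return dt.to_local() + 2

# fullgraph=True makes Dynamo raise if it cannot capture fn in one graph,
# which is exactly the property these tests guard.
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
x = torch.ones(1)
assert torch.equal(opt_fn(x), fn(x))

dist.destroy_process_group()

Comparing the compiled output against the eager reference, as each removed test does, is the cheapest correctness check for this tracing path.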