[dtensor] group all dynamo tests together (#107487)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107487
Approved by: https://github.com/fduwjj
ghstack dependencies: #107472, #107473
Committed by: PyTorch MergeBot
Parent: 42f25d49f8
Commit: 9c2b4a35a3
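The diff below removes the TestDynamoDTensor class from the DTensor test file so that the dynamo/compile-related DTensor tests live together in one module. As a rough illustration of what the grouped module could look like, here is a minimal sketch; the class name TestDTensorCompile, the file layout, and the single passthrough test are assumptions for illustration, not taken from this commit.

# Hypothetical sketch of a consolidated dynamo/DTensor test module.
# Class and file names are assumed; only the fake-process-group setup
# pattern mirrors the tests removed in the diff below.
import torch
import torch._dynamo.test_case
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor, Shard
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed.fake_pg import FakeStore


class TestDTensorCompile(torch._dynamo.test_case.TestCase):
    def setUp(self):
        super().setUp()
        # A single-rank "fake" process group lets compile-only tests run
        # without launching multiple processes.
        dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=2)

    def tearDown(self):
        super().tearDown()
        dist.destroy_process_group()

    def test_compile_dtensor_passthrough(self):
        mesh = DeviceMesh("cpu", torch.arange(2))
        x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False)
        opt_fn = torch.compile(lambda t: t, backend="eager", fullgraph=True)
        self.assertEqual(opt_fn(x), x)


if __name__ == "__main__":
    run_tests()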
@@ -16,7 +16,6 @@ from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     with_comms,
 )
-from torch.testing._internal.distributed.fake_pg import FakeStore
 
 
 class DummyMLP(torch.nn.Module):
@@ -663,94 +662,5 @@ class TestDTensorPlacementTypes(DTensorTestBase):
         assert_array_equal(expected_is_tensor_empty, is_tensor_empty)
 
 
-class TestDynamoDTensor(torch._dynamo.test_case.TestCase):
-    def setUp(self):
-        super().setUp()
-        fake_store = FakeStore()
-        dist.init_process_group(
-            "fake", store=fake_store, rank=0, world_size=self.world_size
-        )
-
-    def tearDown(self):
-        super().tearDown()
-        dist.destroy_process_group()
-
-    @property
-    def device_type(self) -> str:
-        return "cuda" if torch.cuda.is_available() else "cpu"
-
-    @property
-    def world_size(self) -> int:
-        return 2
-
-    def test_fakify_dtensor(self):
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-
-        # pass in DTensor as inputs/outputs to the function
-        def fn(x):
-            return x
-
-        x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False)
-        ref = fn(x)
-
-        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
-        res = opt_fn(x)
-        self.assertEqual(res, ref)
-
-    def test_dynamo_dtensor(self):
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-
-        # test passing in DTensor as inputs/outputs and run some tensor computation
-        def fn(x):
-            return x * x + 2
-
-        x = DTensor.from_local(torch.rand(1), mesh, [Shard(0)], run_check=False)
-        ref = fn(x)
-
-        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
-        res = opt_fn(x)
-        self.assertEqual(res, ref)
-
-    def test_dynamo_dtensor_from_local(self):
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-
-        # create DTensor inside fn and run some compute
-        def fn(x):
-            dt = DTensor.from_local(x, mesh, [Replicate()], run_check=False)
-            return dt.to_local() + 2
-
-        # below is the op approach for reference
-        # from torch.distributed._tensor.api import _FromTorchTensor
-        # def from_local_tensor(x):
-        #     return _FromTorchTensor.apply(x, mesh, [Replicate()], False)
-
-        # _dt_lib_def = torch.library.Library("dtensor", "DEF")
-        # _dt_lib_def.define("from_local(Tensor self) -> Tensor")
-
-        # _dt_lib_impl = torch.library.Library("dtensor", "IMPL")
-        # _dt_lib_impl.impl("from_local", from_local_tensor, "Autograd")
-
-        x = torch.ones(1)
-        ref = fn(x)
-        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
-        res = opt_fn(x)
-        self.assertEqual(res, ref)
-
-    def test_dynamo_dtensor_from_local_redistribute(self):
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-
-        # pass in tensor as inputs/outputs, create DTensor and run redistribute
-        # (allgather collective) inside the fn
-        def fn(x):
-            dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
-            return dt.redistribute(mesh, [Replicate()]).to_local() + 2
-
-        x = torch.ones(1)
-        ref = fn(x)
-        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
-        res = opt_fn(x)
-        self.assertEqual(res, ref)
-
-
 if __name__ == "__main__":
     run_tests()
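For readers browsing this diff, the pattern the removed tests exercise can also be reproduced as a small standalone script: initialize a single-rank "fake" process group, wrap a tensor into a DTensor, and run the function through torch.compile. This is a minimal sketch under those assumptions, not code from the commit.

# Standalone sketch (not from this commit): compile a function that builds a
# DTensor, redistributes it (allgather), and returns a local tensor, using the
# "fake" process-group backend so it runs in a single process.
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore

dist.init_process_group("fake", store=FakeStore(), rank=0, world_size=2)
mesh = DeviceMesh("cpu", torch.arange(2))


def fn(x):
    dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
    return dt.redistribute(mesh, [Replicate()]).to_local() + 2


opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
print(opt_fn(torch.ones(1)))

dist.destroy_process_group()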