pytorch/test/dynamo/test_fake_distributed.py

# Owner(s): ["module: dynamo"]
from unittest import skipIf

import torch
import torch.distributed as dist
from torch._dynamo.test_case import TestCase as DynamoTestCase
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
from torch.testing._internal.common_utils import instantiate_parametrized_tests


if dist.is_available():
    from torch.distributed._functional_collectives import (
        all_to_all_single_autograd,
        wait_tensor,
    )
    from torch.distributed.device_mesh import init_device_mesh


def normalize_graph(gm):
    return normalize_gm(gm.print_readable(print_output=False))


@skipIf(not dist.is_available(), "requires distributed")
class TestFakeDistributed(DynamoTestCase):
    def setUp(self):
        # Use FakeProcessGroup to run tests on a single process
        dist.init_process_group(backend="fake", rank=0, world_size=2)
        self.local_rank = 0
        self.world_size = 2

    def tearDown(self):
        dist.destroy_process_group()

    def test_all_to_all_single_autograd(self):
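        # AotEagerAndRecordGraphs compiles through aot_autograd and records the
        # partitioned graphs on backend.fw_graphs / backend.bw_graphs.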
        backend = AotEagerAndRecordGraphs()

        @torch.compile(fullgraph=True, backend=backend)
        def fn(x):
            return all_to_all_single_autograd(
                x,
                None,  # Will use equal splits
                None,  # Will use equal splits
                group=dist.group.WORLD,
            )

        # Test backed shapes
        x = torch.randn(8, 8, requires_grad=True)
        torch._dynamo.mark_dynamic(x, 0)
        torch._dynamo.mark_dynamic(x, 1)
        wait_tensor(fn(x))
        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertEqual(len(backend.bw_graphs), 1)
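        # Both split-size arguments were None, so the collective splits dim 0
        # evenly across the 2 fake ranks; the expected graphs below express the
        # output sizes via `floordiv` (= s77 // 2).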
        self.assertExpectedInline(
            normalize_graph(backend.fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s77)", primals_2: "Sym(s27)", primals_3: "f32[s77, s27]"):
        floordiv: "Sym((s77//2))" = primals_1 // 2
        all_to_all_single: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.all_to_all_single.default(primals_3, [floordiv, floordiv], [floordiv, floordiv], '0'); primals_3 = None
        wait_tensor: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single); all_to_all_single = None
        return (wait_tensor, primals_1, primals_2, floordiv)
""",  # noqa: B950
        )
        self.assertExpectedInline(
            normalize_graph(backend.bw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s77)", primals_2: "Sym(s27)", floordiv: "Sym((s77//2))", tangents_1: "f32[2*((s77//2)), s27]"):
        all_to_all_single_1: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.all_to_all_single.default(tangents_1, [floordiv, floordiv], [floordiv, floordiv], '0'); tangents_1 = floordiv = None
        wait_tensor_1: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_1); all_to_all_single_1 = None
        return (None, None, wait_tensor_1)
""",  # noqa: B950
        )

        backend.fw_graphs.clear()
        backend.bw_graphs.clear()

        # Test unbacked shapes
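        # mark_unbacked gives the sizes unbacked symbols (u0/u1/u2) with no
        # example values, so the forward graph below also carries u* >= 0
        # runtime assertions ahead of the collective.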
        x = torch.randn(8, 8, 8, requires_grad=True)
        torch._dynamo.decorators.mark_unbacked(x, 0)
        torch._dynamo.decorators.mark_unbacked(x, 1)
        torch._dynamo.decorators.mark_unbacked(x, 2)
        wait_tensor(fn(x))
        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertEqual(len(backend.bw_graphs), 1)
        self.assertExpectedInline(
            normalize_graph(backend.fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(u0)", primals_2: "Sym(u1)", primals_3: "Sym(u2)", primals_4: "f32[u0, u1, u2]"):
        ge_1: "Sym(u0 >= 0)" = primals_1 >= 0
        _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
        ge_3: "Sym(u1 >= 0)" = primals_2 >= 0
        _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
        ge_5: "Sym(u2 >= 0)" = primals_3 >= 0
        _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None
        floordiv: "Sym((u0//2))" = primals_1 // 2
        all_to_all_single: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.all_to_all_single.default(primals_4, [floordiv, floordiv], [floordiv, floordiv], '0'); primals_4 = None
        wait_tensor: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single); all_to_all_single = None
        return (wait_tensor, primals_1, primals_2, primals_3, floordiv)
""",  # noqa: B950
        )
        self.assertExpectedInline(
            normalize_graph(backend.bw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(u0)", primals_2: "Sym(u1)", primals_3: "Sym(u2)", floordiv: "Sym((u0//2))", tangents_1: "f32[2*((u0//2)), u1, u2]"):
        all_to_all_single_1: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.all_to_all_single.default(tangents_1, [floordiv, floordiv], [floordiv, floordiv], '0'); tangents_1 = floordiv = None
        wait_tensor_1: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_1); all_to_all_single_1 = None
        return (None, None, None, wait_tensor_1)
""",  # noqa: B950
        )

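    # DeviceMesh rank queries return plain Python ints; with fullgraph=True the
    # test fails if Dynamo cannot trace them without a graph break.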
    def test_device_mesh_get_local_rank(self):
        device_mesh = init_device_mesh(
            device_type="cpu",
            mesh_shape=(self.world_size,),
            mesh_dim_names=("dp",),  # data parallel dimension
        )

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            local_rank = device_mesh.get_local_rank()
            global_rank = device_mesh.get_rank()
            if "dp" not in device_mesh.mesh_dim_names:
                x = x * 2
            return x + local_rank + global_rank
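
        # Rank 0 of the fake 2-rank group: both rank lookups return 0 and "dp"
        # is in mesh_dim_names, so the scaling branch is skipped and fn reduces
        # to the identity.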
        x = torch.ones(10)
        res = fn(x)
        self.assertEqual(res, x)


instantiate_parametrized_tests(TestFakeDistributed)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()