# Owner(s): ["module: dynamo"]
from unittest import skipIf

import torch
import torch.distributed as dist
from torch._dynamo.test_case import TestCase as DynamoTestCase
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm
from torch.testing._internal.common_utils import instantiate_parametrized_tests


if dist.is_available():
    from torch.distributed._functional_collectives import (
        all_to_all_single_autograd,
        wait_tensor,
    )
    from torch.distributed.device_mesh import init_device_mesh


def normalize_graph(gm):
    return normalize_gm(gm.print_readable(print_output=False))


@skipIf(not dist.is_available(), "requires distributed")
class TestFakeDistributed(DynamoTestCase):
    def setUp(self):
        # Use FakeProcessGroup to run tests on a single process
        dist.init_process_group(backend="fake", rank=0, world_size=2)
        self.local_rank = 0
        self.world_size = 2

    def tearDown(self):
        dist.destroy_process_group()

    def test_all_to_all_single_autograd(self):
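        # AotEagerAndRecordGraphs runs the aot_eager backend and records the
        # partitioned forward/backward graphs (fw_graphs/bw_graphs) for inspection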
        backend = AotEagerAndRecordGraphs()

        @torch.compile(fullgraph=True, backend=backend)
        def fn(x):
            return all_to_all_single_autograd(
                x,
                None,  # Will use equal splits
                None,  # Will use equal splits
                group=dist.group.WORLD,
            )

        # Test backed shapes
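        # mark_dynamic compiles x with backed symbolic sizes (s77, s27 in the graphs below)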
        x = torch.randn(8, 8, requires_grad=True)
        torch._dynamo.mark_dynamic(x, 0)
        torch._dynamo.mark_dynamic(x, 1)
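        # The functional collective returns an async tensor; wait_tensor waits for it to complete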
        wait_tensor(fn(x))
        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertEqual(len(backend.bw_graphs), 1)
        self.assertExpectedInline(
            normalize_graph(backend.fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s77)", primals_2: "Sym(s27)", primals_3: "f32[s77, s27]"):
        floordiv: "Sym((s77//2))" = primals_1 // 2

        all_to_all_single: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.all_to_all_single.default(primals_3, [floordiv, floordiv], [floordiv, floordiv], '0'); primals_3 = None

        wait_tensor: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single); all_to_all_single = None
        return (wait_tensor, primals_1, primals_2, floordiv)
""",  # noqa: B950
        )
        self.assertExpectedInline(
            normalize_graph(backend.bw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(s77)", primals_2: "Sym(s27)", floordiv: "Sym((s77//2))", tangents_1: "f32[2*((s77//2)), s27]"):
        all_to_all_single_1: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.all_to_all_single.default(tangents_1, [floordiv, floordiv], [floordiv, floordiv], '0'); tangents_1 = floordiv = None
        wait_tensor_1: "f32[2*((s77//2)), s27]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_1); all_to_all_single_1 = None
        return (None, None, wait_tensor_1)
""",  # noqa: B950
        )

        backend.fw_graphs.clear()
        backend.bw_graphs.clear()

        # Test unbacked shapes
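        # mark_unbacked gives x unbacked symbols (u0, u1, u2); the forward graph
        # below carries runtime assertions on their bounds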
        x = torch.randn(8, 8, 8, requires_grad=True)
        torch._dynamo.decorators.mark_unbacked(x, 0)
        torch._dynamo.decorators.mark_unbacked(x, 1)
        torch._dynamo.decorators.mark_unbacked(x, 2)
        wait_tensor(fn(x))
        self.assertEqual(len(backend.fw_graphs), 1)
        self.assertEqual(len(backend.bw_graphs), 1)
        self.assertExpectedInline(
            normalize_graph(backend.fw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(u0)", primals_2: "Sym(u1)", primals_3: "Sym(u2)", primals_4: "f32[u0, u1, u2]"):
        ge_1: "Sym(u0 >= 0)" = primals_1 >= 0
        _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
        ge_3: "Sym(u1 >= 0)" = primals_2 >= 0
        _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
        ge_5: "Sym(u2 >= 0)" = primals_3 >= 0
        _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u2 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None

        floordiv: "Sym((u0//2))" = primals_1 // 2

        all_to_all_single: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.all_to_all_single.default(primals_4, [floordiv, floordiv], [floordiv, floordiv], '0'); primals_4 = None

        wait_tensor: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single); all_to_all_single = None
        return (wait_tensor, primals_1, primals_2, primals_3, floordiv)
""",  # noqa: B950
        )
        self.assertExpectedInline(
            normalize_graph(backend.bw_graphs[0]),
            """\
class GraphModule(torch.nn.Module):
    def forward(self, primals_1: "Sym(u0)", primals_2: "Sym(u1)", primals_3: "Sym(u2)", floordiv: "Sym((u0//2))", tangents_1: "f32[2*((u0//2)), u1, u2]"):
        all_to_all_single_1: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.all_to_all_single.default(tangents_1, [floordiv, floordiv], [floordiv, floordiv], '0'); tangents_1 = floordiv = None
        wait_tensor_1: "f32[2*((u0//2)), u1, u2]" = torch.ops._c10d_functional.wait_tensor.default(all_to_all_single_1); all_to_all_single_1 = None
        return (None, None, None, wait_tensor_1)
""",  # noqa: B950
        )

    def test_device_mesh_get_local_rank(self):
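        # Exercises DeviceMesh.get_local_rank()/get_rank() inside a compiled
        # region (fullgraph=True) on top of the fake process group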
        device_mesh = init_device_mesh(
            device_type="cpu",
            mesh_shape=(self.world_size,),
            mesh_dim_names=("dp",),  # data parallel dimension
        )

        @torch.compile(backend="eager", fullgraph=True)
        def fn(x):
            local_rank = device_mesh.get_local_rank()
            global_rank = device_mesh.get_rank()
            if "dp" not in device_mesh.mesh_dim_names:
                x = x * 2
            return x + local_rank + global_rank

        x = torch.ones(10)
        res = fn(x)
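        # On rank 0 both local_rank and global_rank are 0, and "dp" is a mesh dim
        # name so the `x * 2` branch is skipped; fn should return x unchanged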
        self.assertEqual(res, x)


instantiate_parametrized_tests(TestFakeDistributed)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()