This adds a templated version of the ring attention forward function and tests it with memory-efficient attention. It does not add support for memory-efficient attention in DTensor; that will be added in a follow-up PR. The templating is also a proof of concept for supporting other attention ops, such as jagged/nested tensors, and for implementing striped attention in a scalable way.

Misc changes:

* Fixes the all_to_all_single autograd implementation with CUDA and adds an NCCL test
* Adds compile support to the ring attention implementations (required some tweaks to process groups)

Test plan:

```
pytest test/distributed/_tensor/test_attention.py
pytest test/distributed/test_functional_api.py
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124215
Approved by: https://github.com/wanchaol
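For reference, a minimal sketch (not part of the PR) of the autograd-aware functional all-to-all exercised by the tests below, assuming a 2-rank process group has already been initialized:

```
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c

group = dist.group.WORLD.group_name
t = torch.ones((dist.get_world_size(), 2), requires_grad=True)
sizes = [1] * dist.get_world_size()  # one row exchanged with each rank

# Differentiable all-to-all: gradients flow back through the collective.
# The trailing + 0 mirrors the tests below.
out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group) + 0
out.sum().backward()
```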
# Owner(s): ["oncall: distributed"]

import os
import sys
import unittest
from functools import partial, wraps

import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as ft_c
import torch.distributed._tensor as dt
import torch.distributed.distributed_c10d as c10d

from functorch import make_fx
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal.distributed.fake_pg import FakeStore
from torch.utils._triton import has_triton

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)

from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    MultiThreadedTestCase,
    requires_nccl,
    TEST_SKIPS,
)
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    TestCase,
)


def new_subgroups(group_size: int, pg_tag=None):
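    """
    Split the world into consecutive subgroups of ``group_size`` ranks and
    return ``(subgroup containing this rank, list of all subgroups)``.

    Unlike ``dist.new_subgroups``, this helper forwards ``pg_tag`` to
    ``c10d._new_group_with_tag`` so tests can control process-group tagging.
    """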
    world_size = dist.get_world_size()
    subgroups = []
    cur_subgroup = None

    for subgroup_id in range(world_size // group_size):
        start_rank = subgroup_id * group_size
        end_rank = start_rank + group_size
        ranks_in_subgroup = list(range(start_rank, end_rank))
        subgroup = c10d._new_group_with_tag(
            ranks=ranks_in_subgroup,
            pg_tag=pg_tag,
        )
        subgroups.append(subgroup)

        rank = dist.get_rank()
        if rank in ranks_in_subgroup:
            cur_subgroup = subgroup

    return cur_subgroup, subgroups


class TestExpand(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_expand_1d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group([0, 1, 2, 3], "bla")
        self.assertEqual("bla", tag)

    def test_expand_2d_rank_list(self):
        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]])
        self.assertEqual("", tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group([[0, 1], [2, 3]], "blu")
        self.assertEqual("blu", tag)

        with self.assertRaisesRegex(ValueError, "group sizes must be identical"):
            ft_c._expand_group([[0], [1, 2, 3]])

    def test_expand_process_group(self):
        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD)
        self.assertEqual(c10d._get_group_tag(dist.group.WORLD), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        tag, rankset, group_size = ft_c._expand_group(dist.group.WORLD, "bla")
        self.assertEqual("bla", tag)

        my_pg, others = new_subgroups(group_size=2)
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual(c10d._get_group_tag(my_pg), tag)
        self.assertEqual(dist.get_process_group_ranks(my_pg), rankset)
        self.assertEqual(2, group_size)

        my_pg = None
        for i in range(dist.get_world_size()):
            group = c10d._new_group_with_tag([i], pg_tag="my_pg")
            if i == dist.get_rank():
                my_pg = group
        tag, rankset, group_size = ft_c._expand_group(my_pg)
        self.assertEqual("my_pg", tag)
        self.assertEqual([dist.get_rank()], rankset)
        self.assertEqual(1, group_size)

        tag, rankset, group_size = ft_c._expand_group(my_pg, "bla")
        self.assertEqual("bla", tag)

    def test_expand_device_mesh(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

        mesh = dt.DeviceMesh("cpu", torch.arange(4))
        tag, rankset, group_size = ft_c._expand_group(mesh)
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        self.assertEqual([0, 1, 2, 3], rankset)
        self.assertEqual(4, group_size)

    def test_expand_device_mesh_tuple(self):
        mesh = dt.DeviceMesh("cpu", torch.arange(4).view(2, 2))
        with self.assertRaisesRegex(AssertionError, "Only 1D mesh"):
            tag, rankset, group_size = ft_c._expand_group(mesh)

        tag, rankset, group_size = ft_c._expand_group((mesh, 0))
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=0)), tag)
        expected_rankset = [0, 2] if dist.get_rank() in [0, 2] else [1, 3]
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)

        tag, rankset, group_size = ft_c._expand_group((mesh, 1))
        expected_rankset = [0, 1] if dist.get_rank() in [0, 1] else [2, 3]
        self.assertEqual(c10d._get_group_tag(mesh.get_group(mesh_dim=1)), tag)
        self.assertEqual(expected_rankset, rankset)
        self.assertEqual(2, group_size)


class TestPgTag(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    """
    The behavior we want is as follows:

    - rankset + tag will always result in the same PG.
      Do we enforce this by failing creation of new PGs or by returning existing ones?
      Return the existing one.

    - The default tag gives the existing behavior.
      This means we should create duplicates.
    - _expand_group on a default-tagged PG should always resolve to it.
      This means we can't depend on empty tag + rankset.
    """

    def test_pg_creation_with_tag(self):
        my_group, _ = new_subgroups(group_size=2, pg_tag="blu")
        my_group2, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(my_group, my_group2)

        my_group3, _ = new_subgroups(group_size=2, pg_tag="blu2")
        self.assertNotEqual(my_group, my_group3)

        my_group4, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group, my_group4)

        my_group5, _ = new_subgroups(group_size=2)
        self.assertNotEqual(my_group4, my_group5)

    def test_pg_lookup_roundtrip(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="blu2")
        pg_notag0, _ = new_subgroups(group_size=2)
        pg_notag1, _ = new_subgroups(group_size=2)

        def roundtrip(pg):
            tag, rankset, _ = ft_c._expand_group(pg)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag0))
        self.assertEqual(pg_tag1, roundtrip(pg_tag1))
        self.assertEqual(pg_notag0, roundtrip(pg_notag0))
        self.assertEqual(pg_notag1, roundtrip(pg_notag1))

    def test_pg_lookup_with_tag(self):
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        pg_tag1, _ = new_subgroups(group_size=2, pg_tag="bla")
        pg_notag0, _ = new_subgroups(group_size=2)

        def roundtrip(pg, pg_tag):
            tag, rankset, _ = ft_c._expand_group(pg, pg_tag)
            return c10d._find_pg_by_ranks_and_tag(tag, rankset)

        self.assertEqual(pg_tag0, roundtrip(pg_tag1, "blu"))
        self.assertEqual(pg_tag0, roundtrip(pg_notag0, "blu"))
        # Cannot erase the tag of a PG
        self.assertEqual(pg_tag0, roundtrip(pg_tag0, ""))

    def test_find_or_create_pg(self):
        pg = c10d._find_or_create_pg_by_ranks_and_tag("blu", [0, 1, 2, 3], 2)
        pg_tag0, _ = new_subgroups(group_size=2, pg_tag="blu")
        self.assertEqual(pg, pg_tag0)

    def test_find_root_pg(self):
        pg = c10d._find_pg_by_ranks_and_tag("", [0, 1, 2, 3])
        self.assertEqual(dist.group.WORLD, pg)


@instantiate_parametrized_tests
class TestTraceableCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 4

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    @parametrize("device", ["cpu", "cuda"])
    def test_broadcast(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        if dist.get_rank() == 0:
            tensor = torch.ones([4], device=device)
        else:
            tensor = torch.zeros([4], device=device)

        mesh = dt.DeviceMesh(device, torch.arange(4))
        res = ft_c.broadcast(tensor, 0, mesh)
        self.assertEqual(res, torch.ones([4], device=device))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_reduce_eager(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        tensor = torch.ones([4], device=device)
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_reduce(tensor, "sum", mesh)
        self.assertEqual(res, torch.tensor([4, 4, 4, 4], dtype=torch.float))

        mesh = dt.DeviceMesh(device, torch.arange(4).view(2, 2))
        res2 = ft_c.all_reduce(tensor, "sum", (mesh, 1))
        self.assertEqual(res2, torch.tensor([2, 2, 2, 2], dtype=torch.float))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_reduce_coalesced_eager(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        t0 = torch.ones([4], device=device)
        t1 = torch.ones([6], device=device) + 2
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_reduce_coalesced([t0, t1], "sum", mesh)
        self.assertEqual(res[0], t0 * 4)
        self.assertEqual(res[1], t1 * 4)

    @parametrize("device", ["cpu", "cuda"])
    def test_all_gather_tensor(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        # testing 1d/2d mesh
        mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
        mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
        for mesh in [mesh_1d, mesh_2d]:
            dims_to_gather = [0, 1, 2]
            for dim in dims_to_gather:
                output_size = [3, 3, 3]
                output_size[dim] *= mesh.size(0)
                # each rank has its own local tensor; all_gather yields a bigger tensor
                local_tensor = torch.ones([3, 3, 3], device=device)
                gathered_tensor = ft_c.all_gather_tensor(
                    local_tensor, gather_dim=dim, group=(mesh, 0)
                )
                self.assertEqual(gathered_tensor, torch.ones(output_size))

    @parametrize("device", ["cpu", "cuda"])
    def test_all_gather_into_tensor_coalesced(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        tensors = [torch.ones([4], device=device), torch.ones([4], device=device) + 1]
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()], device=device), res[0])
        self.assertEqual(
            torch.ones([4 * dist.get_world_size()], device=device) + 1, res[1]
        )

    @parametrize("device", ["cpu", "cuda"])
    def test_reduce_scatter_tensor(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())

        # testing 1d/2d mesh
        mesh_1d = dt.DeviceMesh(device, torch.arange(self.world_size))
        mesh_2d = dt.DeviceMesh(device, torch.arange(self.world_size).view(2, 2))
        for mesh in [mesh_1d, mesh_2d]:
            dims_to_scatter = [0, 1]
            for dim in dims_to_scatter:
                group_size = mesh.size(0)
                input_size = [3, 3]
                output_size = [3, 3]
                output_size[dim] *= group_size
                input_tensor = torch.ones(output_size, device=device)
                res_num = 1 * group_size
                rs_tensor = ft_c.reduce_scatter_tensor(
                    input_tensor, "sum", scatter_dim=dim, group=(mesh, 0)
                )
                self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)

    @parametrize("device", ["cpu", "cuda"])
    def test_reduce_scatter_into_tensor_coalesced(self, device):
        if device == "cuda":
            if torch.cuda.device_count() < self.world_size:
                self.skipTest("Not enough CUDA devices")
            torch.cuda.set_device(dist.get_rank())
        tensors = [
            torch.ones([4], dtype=torch.int64, device=device),
            torch.ones([4], dtype=torch.int64, device=device) + 1,
        ]
        mesh = dt.DeviceMesh(device, torch.arange(4))

        res = ft_c.reduce_scatter_tensor_coalesced(tensors, "sum", [0, 0], mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.tensor([4], device=device), res[0])
        self.assertEqual(torch.tensor([8], device=device), res[1])


class TestMetaCollectives(TestCase):
    def test_all_reduce(self):
        x = torch.rand((2, 3, 4), device="meta")
        out = ft_c.all_reduce(x, "sum", "0")
        self.assertEqual(x.size(), out.size())


class TestGradCollectives(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def test_all_reduce(self):
        x = torch.rand([4], requires_grad=True)
        y = torch.rand([4], requires_grad=True)
        out = ft_c.all_reduce(x, "sum", dist.group.WORLD)
        (out + y).sum().backward()
        self.assertIsNone(x.grad)


class TestMakeFx(MultiThreadedTestCase):
    @property
    def world_size(self):
        return 2

    def setUp(self):
        super().setUp()
        self._spawn_threads()

    def tearDown(self):
        super().tearDown()

        # A race condition between threads can leave the is_fx_tracing flag set
        # incorrectly, so reset it here and verify it is cleared.
        torch.fx._symbolic_trace._is_fx_tracing_flag = False
        self.assertFalse(torch.fx._symbolic_trace.is_fx_tracing())

    def test_all_reduce_tracing(self):
        def allred(input):
            return ft_c.all_reduce(input, "sum", group=dist.group.WORLD) + 1

        graph = make_fx(allred)(torch.rand(4))
        FileCheck().check("all_reduce").check("wait_tensor").run(str(graph.graph))

        mesh = dt.DeviceMesh("cpu", torch.arange(self.world_size))

        def allred_mesh(input):
            return ft_c.all_reduce(input, "sum", mesh) + 1

        mesh_graph = make_fx(allred_mesh)(torch.rand(4))
        FileCheck().check_not("get_attr").check("wait_tensor").run(
            str(mesh_graph.graph)
        )

        def allred_mesh_dim(input):
            return ft_c.all_reduce(input, "sum", (mesh, 0)) + 1

        mesh_dim_graph = make_fx(allred_mesh_dim)(torch.rand(4))
        FileCheck().check_not("get_attr").check("wait_tensor").run(
            str(mesh_dim_graph.graph)
        )


BACKEND = dist.Backend.NCCL if torch.cuda.is_available() else dist.Backend.GLOO
WORLD_SIZE = 2


def exit_if_lt_x_gpu(x):
    if torch.cuda.device_count() < x:
        sys.exit(TEST_SKIPS[f"multi-gpu-{x}"].exit_code)


def with_comms(func=None):
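    """
    Decorator that initializes the default process group (via self.dist_init())
    before the test and destroys it afterwards. When the backend is NCCL and
    there are fewer visible GPUs than ranks, the test exits with the standard
    multi-gpu skip code instead of running.
    """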
    if func is None:
        return partial(
            with_comms,
        )

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if BACKEND == dist.Backend.NCCL and torch.cuda.device_count() < self.world_size:
            sys.exit(TEST_SKIPS[f"multi-gpu-{self.world_size}"].exit_code)
        self.dist_init()
        func(self)
        self.destroy_comms()

    return wrapper


class TestCollectivesWithNCCL(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank)

    @property
    def world_size(self):
        return WORLD_SIZE

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    @requires_nccl()
    @with_comms()
    def test_all_gather_into_tensor_coalesced(self):
        exit_if_lt_x_gpu(self.world_size)

        tensors = [
            torch.ones([4], device=f"cuda:{self.rank}"),
            torch.ones([4], device=f"cuda:{self.rank}") + 1,
        ]
        mesh = dt.DeviceMesh(f"cuda:{self.rank}", torch.arange(self.world_size))

        res = ft_c.all_gather_into_tensor_coalesced(tensors, mesh)
        self.assertEqual(2, len(res))
        self.assertEqual(torch.ones([4 * dist.get_world_size()]), res[0])
        self.assertEqual(torch.ones([4 * dist.get_world_size()]) + 1, res[1])

    @with_comms()
    def test_all_to_all_single(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        # This rank sends (i + 1) * (rank + 1) rows to peer i, so the local input
        # needs (rank + 1) * (1 + 2 + ... + world_size) rows in total.
        row = self.world_size * (rank + 1) * (self.world_size + 1) / 2
        x = torch.ones(int(row), 5, device=device) * (rank + 1)
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]
        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_1d_input(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        row = self.world_size * (rank + 1) * (self.world_size + 1) / 2
        x = torch.ones(int(row), device=device) * (rank + 1)
        split_sizes = [(i + 1) * (rank + 1) for i in range(self.world_size)]
        y = ft_c.all_to_all_single(
            x, output_split_sizes=split_sizes, input_split_sizes=split_sizes, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.split(x, split_sizes)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @with_comms()
    def test_all_to_all_single_split_sizes_none(self):
        device = "cuda" if BACKEND == dist.Backend.NCCL else "cpu"
        mesh = dt.DeviceMesh(device, torch.arange(self.world_size))
        rank = dist.get_rank()

        x = torch.ones(self.world_size, self.world_size, device=device) * (rank + 1)
        y = ft_c.all_to_all_single(
            x, output_split_sizes=None, input_split_sizes=None, group=mesh
        )
        expected = []
        for idx, tensor in enumerate(torch.chunk(x, self.world_size)):
            expected.append(torch.full_like(tensor, (idx + 1)))
        expected = torch.cat(expected)
        self.assertEqual(y, expected)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    @requires_nccl()
    @with_comms()
    def test_tracing(self):
        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        compiled_allreduce = torch.compile(allreduce, fullgraph=True)
        compiled_allreduce(torch.randn(8, device=self.device), self.process_group)

    @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")
    def test_tracing_with_fakepg(self):
        exit_if_lt_x_gpu(self.world_size)

        def allreduce(t, pg):
            return ft_c.all_reduce(t, "sum", pg)

        compiled_allreduce = torch.compile(allreduce, fullgraph=True)
        dist.init_process_group(
            backend="fake",
            rank=0,
            world_size=8,
            store=FakeStore(),
        )
        allreduce(torch.randn(8, device=self.device), pg=dist.group.WORLD)


class TestNCCLCollectivesWithWorldSize4(TestCollectivesWithNCCL):
    @property
    def world_size(self):
        return 4

    @requires_nccl()
    @with_comms()
    def test_permute_tensor_with_sub_group(self):
        exit_if_lt_x_gpu(self.world_size)

        device = "cuda"
        mesh_dim_names = ["dp", "tp"]

        mesh_2d = dt.init_device_mesh(
            device, (2, self.world_size // 2), mesh_dim_names=mesh_dim_names
        )

        for mesh_name in mesh_dim_names:
            mesh = mesh_2d[mesh_name]
            rank = mesh.get_local_rank()

            # rank0: [0., 1.], rank1: [2., 3.]
            send_tensor = torch.arange(2, dtype=torch.float32, device=device) + 2 * rank
            recvd_tensor = ft_c.permute_tensor(send_tensor, [1, 0], group=mesh)

            # rank0: [2., 3.], rank1: [0., 1.]
            expected = torch.arange(2, dtype=torch.float32, device=device) + 2 * (
                (rank - 1 + 2) % 2
            )
            self.assertEqual(
                recvd_tensor,
                expected,
                msg=f"Expected {expected} on {self.rank=} (local_rank={rank}), "
                f"but received {recvd_tensor} instead.",
            )


@instantiate_parametrized_tests
class TestFunctionalAutograd(MultiThreadedTestCase):
    def setUp(self):
        super().setUp()
        self._spawn_threads()

    @property
    def world_size(self):
        return 2

    @parametrize("compile", [True, False])
    def test_all_to_all_single(self, compile: bool = True) -> None:
        group = dist.group.WORLD.group_name

        t = torch.ones((self.world_size, 2), requires_grad=True)

        def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
            sizes = [1] * world_size
            t = t * 2
            assert t.requires_grad
            out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
            out = out + 0
            return out

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        out = compiled(t, self.world_size)
        self.assertEqual(out.shape, t.shape)
        self.assertEqual(out, torch.full_like(t, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)
        loss = out.sum()
        loss.backward()
        self.assertEqual(t.grad, torch.full_like(t, 2.0))

    def test_all_to_all_single_inductor(self) -> None:
        group = dist.group.WORLD.group_name

        t = torch.rand((self.world_size, 2), requires_grad=True)

        def my_func(t: torch.Tensor, world_size: int) -> torch.Tensor:
            sizes = [1] * world_size
            t = t * 10
            assert t.requires_grad
            out = ft_c.all_to_all_single_autograd(t, sizes, sizes, group)
            out = out + 2
            return out.sum()

        compiled = torch.compile(my_func, fullgraph=True)

        def run_with_backward():
            out = compiled(t, self.world_size)
            out.backward()

        res, codes = run_and_get_code(run_with_backward)
        for code in codes:
            FileCheck().check_count(
                "_c10d_functional.all_to_all_single.default", 1, exactly=True
            ).check_count("_c10d_functional.wait_tensor.default", 1, exactly=True).run(
                code
            )

        self.assertIsNotNone(t.grad)

    @parametrize("compile", [True, False])
    def test_all_gather_tensor(self, compile: bool) -> None:
        group = dist.group.WORLD.group_name

        def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
            assert t.requires_grad
            out = ft_c.all_gather_tensor_autograd(
                t * 1.0,
                gather_dim=dim,
                group=group,
            )
            out = out * 1.0
            return out

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        dims_to_gather = [0, 1, 2]
        for dim in dims_to_gather:
            output_size = [3, 3, 3]
            output_size[dim] *= self.world_size
            # each rank has its own local tensor; all_gather yields a bigger tensor
            local_tensor = torch.ones([3, 3, 3], requires_grad=True)
            gathered_tensor = compiled(local_tensor, dim)
            self.assertEqual(gathered_tensor, torch.ones(output_size))

            gathered_tensor.sum().backward()
            self.assertEqual(
                local_tensor.grad,
                torch.full((3, 3, 3), fill_value=float(self.world_size)),
            )

    @parametrize("compile", [True, False])
    def test_reduce_scatter_tensor(self, compile: bool) -> None:
        group = dist.group.WORLD.group_name

        def my_func(t: torch.Tensor, dim: int) -> torch.Tensor:
            assert t.requires_grad
            # note: closes over `input_tensor` from the loop below rather than
            # using the `t` argument
            rs_tensor = (
                ft_c.reduce_scatter_tensor_autograd(
                    input_tensor * 1.0, "sum", scatter_dim=dim, group=group
                )
                * 1.0
            )
            return rs_tensor

        if compile:
            compiled = torch.compile(my_func, fullgraph=True, backend="aot_eager")
        else:
            compiled = my_func

        dims_to_scatter = [0, 1]
        for dim in dims_to_scatter:
            group_size = self.world_size
            input_size = [3, 3]
            output_size = [3, 3]
            output_size[dim] *= group_size
            input_tensor = torch.ones(output_size, requires_grad=True)
            rs_tensor = compiled(input_tensor, dim)
            res_num = 1 * group_size
            self.assertEqual(rs_tensor, torch.ones(input_size) * res_num)
            rs_tensor.sum().backward()
            self.assertEqual(input_tensor.grad, torch.full(output_size, fill_value=1.0))


class TestFunctionalAutogradWithNCCL(MultiProcessTestCase):
    def setUp(self):
        super().setUp()
        os.environ["WORLD_SIZE"] = str(self.world_size)
        os.environ["BACKEND"] = dist.Backend.NCCL
        self._spawn_processes()

    @property
    def device(self):
        return torch.device(self.rank)

    @property
    def world_size(self):
        return 2

    @property
    def process_group(self):
        return dist.group.WORLD

    def dist_init(self):
        dist.init_process_group(
            backend=BACKEND,
            world_size=self.world_size,
            rank=self.rank,
            init_method=f"file://{self.file_name}",
        )

        # set device for nccl pg for collectives
        if BACKEND == "nccl":
            torch.cuda.set_device(self.rank)

    def destroy_comms(self):
        # Wait for all ranks to reach here before starting shutdown.
        dist.barrier()
        dist.destroy_process_group()

    @requires_nccl()
    @with_comms()
    def test_all_to_all_single(self) -> None:
        group = self.process_group.group_name

        t = torch.ones((self.world_size, 2), requires_grad=True, device=self.device)

        sizes = [1] * self.world_size
        assert t.requires_grad
        out = ft_c.all_to_all_single_autograd(t * 2, sizes, sizes, group) + 0

        self.assertEqual(out.shape, t.shape)
        self.assertEqual(out, torch.full_like(t, 2.0))
        self.assertIsNotNone(out.grad_fn)
        self.assertTrue(out.requires_grad)
        loss = out.sum()
        loss.backward()
        self.assertEqual(t.grad, torch.full_like(t, 2.0))


if __name__ == "__main__":
    run_tests()