# Owner(s): ["oncall: distributed"]
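# These tests exercise torch.distributed.algorithms._quantization.quantization.auto_quantize:
# each test wraps a collective (all_gather, all_to_all, all_to_all_single) with FP16 or
# BFP16 quantization and checks the result against full-precision expectations.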

import os
import sys

import torch
import torch.cuda
import torch.distributed as dist
import torch.distributed.algorithms._quantization.quantization as quant
from torch.distributed.algorithms._quantization.quantization import DQuantType
from torch.testing._internal.common_distributed import (
    init_multigpu_helper,
    MultiProcessTestCase,
    requires_gloo,
    requires_nccl,
    skip_if_lt_x_gpu,
    skip_if_rocm_multiprocess,
)
from torch.testing._internal.common_utils import (
    run_tests,
    skip_but_pass_in_sandcastle_if,
    TEST_WITH_DEV_DBG_ASAN,
)


torch.backends.cuda.matmul.allow_tf32 = False

if not dist.is_available():
    print("Distributed not available, skipping tests", file=sys.stderr)
    sys.exit(0)


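# Build a tensor of shape `size` filled with `value` (defaults to `size` itself),
# optionally moved to GPU `device_id`.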
def _build_tensor(size, value=None, dtype=torch.float, device_id=None):
    if value is None:
        value = size
    if device_id is None:
        return torch.empty(size, dtype=dtype).fill_(value)
    else:
        return torch.empty(size, dtype=dtype).fill_(value).cuda(device_id)


if TEST_WITH_DEV_DBG_ASAN:
    print(
        "Skip dev-asan as torch + multiprocessing spawn have known issues",
        file=sys.stderr,
    )
    sys.exit(0)

BACKEND = os.environ["BACKEND"]
if BACKEND == "gloo" or BACKEND == "nccl":

    class DistQuantizationTests(MultiProcessTestCase):
        def setUp(self):
            super().setUp()
            self._spawn_processes()
            torch.backends.cudnn.flags(enabled=True, allow_tf32=False).__enter__()

        def tearDown(self):
            super().tearDown()
            try:
                os.remove(self.file_name)
            except OSError:
                pass

        @property
        def op_timeout_sec(self):
            return 1

        @property
        def world_size(self):
            return int(os.environ["WORLD_SIZE"])

        @requires_gloo()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo", "Only gloo backend supports all_gather_fp16"
        )
        def test_all_gather_fp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="gloo"
            )
            group = list(range(0, self.world_size))
            group_id = dist.group.WORLD
            self._test_all_gather(
                group, group_id, self.rank, dtype=torch.float32, qtype=DQuantType.FP16
            )

        @requires_gloo()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "gloo", "Only gloo backend supports all_gather_bfp16"
        )
        def test_all_gather_bfp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="gloo"
            )
            group = list(range(0, self.world_size))
            group_id = dist.group.WORLD
            self._test_all_gather(
                group, group_id, self.rank, dtype=torch.float32, qtype=DQuantType.BFP16
            )

        @requires_nccl()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only nccl backend supports all_to_all_fp16"
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        @skip_if_rocm_multiprocess
        def test_all_to_all_fp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )
            group = list(range(0, self.world_size))
            group_id = dist.new_group(range(self.world_size))
            rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
            self._test_all_to_all(
                group,
                group_id,
                self.rank,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                dtype=torch.float32,
                qtype=DQuantType.FP16,
            )

        @requires_nccl()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only nccl backend supports all_to_all_bfp16"
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        @skip_if_rocm_multiprocess
        def test_all_to_all_bfp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )
            group = list(range(0, self.world_size))
            group_id = dist.new_group(range(self.world_size))
            rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
            self._test_all_to_all(
                group,
                group_id,
                self.rank,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                dtype=torch.float32,
                qtype=DQuantType.BFP16,
            )

        @requires_nccl()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only nccl backend supports all_to_all_single_fp16"
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_all_to_all_single_fp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )
            group = list(range(0, self.world_size))
            group_id = dist.new_group(range(self.world_size))
            rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
            self._test_all_to_all_single(
                group,
                group_id,
                self.rank,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                dtype=torch.float32,
                qtype=DQuantType.FP16,
            )

        @requires_nccl()
        @skip_but_pass_in_sandcastle_if(
            BACKEND != "nccl", "Only nccl backend supports all_to_all_single_bfp16"
        )
        @skip_if_lt_x_gpu(int(os.environ["WORLD_SIZE"]))
        def test_all_to_all_single_bfp16(self):
            store = dist.FileStore(self.file_name, self.world_size)
            dist.init_process_group(
                store=store, rank=self.rank, world_size=self.world_size, backend="nccl"
            )
            group = list(range(0, self.world_size))
            group_id = dist.new_group(range(self.world_size))
            rank_to_GPU = init_multigpu_helper(self.world_size, BACKEND)
            self._test_all_to_all_single(
                group,
                group_id,
                self.rank,
                cuda=True,
                rank_to_GPU=rank_to_GPU,
                dtype=torch.float32,
                qtype=DQuantType.BFP16,
            )

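        # Wrap dist.all_gather with quant.auto_quantize and check that the gathered
        # tensors match the expected full-precision values on every rank.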
        def _test_all_gather(
            self,
            group,
            group_id,
            rank,
            cuda=False,
            rank_to_GPU=None,
            dtype=torch.float,
            qtype=None,
        ):
            for dest in group:
                tensor = _build_tensor([dest + 1, dest + 1], rank, dtype=dtype)
                tensors = [
                    _build_tensor([dest + 1, dest + 1], -1, dtype=dtype) for i in group
                ]
                expected_tensors = [
                    _build_tensor([dest + 1, dest + 1], i, dtype=dtype) for i in group
                ]
                if cuda:
                    tensor = tensor.cuda(rank_to_GPU[rank][0])
                    tensors = [t.cuda(rank_to_GPU[rank][0]) for t in tensors]
                allgather = quant.auto_quantize(dist.all_gather, qtype, quant_loss=None)
                allgather(tensors, tensor, group=group_id, async_op=False)

                for t1, t2 in zip(tensors, expected_tensors):
                    self.assertEqual(t1, t2)

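        # Wrap dist.all_to_all with quant.auto_quantize; each rank sends tensors filled
        # with its own rank and expects back one tensor per peer filled with that peer's rank.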
        def _test_all_to_all(
            self,
            group,
            group_id,
            rank,
            cuda=False,
            rank_to_GPU=None,
            dtype=torch.float,
            qtype=None,
        ):
            if group_id is not None:
                size = len(group)
                in_splits = [i + 1 for i in group]
                in_tensors = [
                    torch.ones([in_splits[i], size], dtype=dtype) * rank
                    for i, _ in enumerate(group)
                ]
                out_tensors = [
                    torch.ones([(rank + 1), size], dtype=dtype) for _ in group
                ]
                expected_tensors = [
                    torch.ones([rank + 1, size], dtype=dtype) * i for i in group
                ]
                if cuda:
                    in_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in in_tensors]
                    expected_tensors = [
                        t.cuda(rank_to_GPU[rank][0]) for t in expected_tensors
                    ]
                    out_tensors = [t.cuda(rank_to_GPU[rank][0]) for t in out_tensors]
                quantize_alltoall = quant.auto_quantize(
                    dist.all_to_all, qtype, quant_loss=None
                )
                quantize_alltoall(out_tensors, in_tensors, group=group_id)
                for t1, t2 in zip(out_tensors, expected_tensors):
                    self.assertEqual(t1, t2)

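        # Same check for dist.all_to_all_single, using one concatenated input/output
        # tensor with per-rank splits instead of lists of tensors.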
        def _test_all_to_all_single(
            self,
            group,
            group_id,
            rank,
            cuda=False,
            rank_to_GPU=None,
            dtype=torch.float,
            qtype=DQuantType.FP16,
        ):
            if group_id is not None:
                size = len(group)
                in_splits = [i + 1 for i in group]
                out_splits = [rank + 1 for _ in group]
                in_tensor = torch.ones([sum(in_splits), size], dtype=dtype) * rank
                out_tensor = torch.ones([(rank + 1) * size, size], dtype=dtype)
                expected_tensor = torch.cat(
                    [torch.ones([rank + 1, size], dtype=dtype) * i for i in group]
                )
                if cuda:
                    rank_to_GPU = rank_to_GPU[rank][0]
                    in_tensor = in_tensor.cuda(rank_to_GPU)
                    expected_tensor = expected_tensor.cuda(rank_to_GPU)
                    out_tensor = out_tensor.cuda(rank_to_GPU)
                quantize_alltoall_single = quant.auto_quantize(
                    dist.all_to_all_single, qtype, quant_loss=None
                )
                quantize_alltoall_single(
                    out_tensor,
                    in_tensor,
                    out_splits=out_splits,
                    in_splits=in_splits,
                    group=group_id,
                )
                self.assertEqual(out_tensor, expected_tensor)


if __name__ == "__main__":
    run_tests()