Support multi-gpu CI for inductor-distributed (#87996)

This test by itself isn't the end goal; it is a minimal test that exercises multiple GPUs, and the focus of this PR is the infra behind enabling that. I'll follow up with more tests using actual models, etc.

cc @malfet @desertfire for awareness/feedback on the infra side
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87996
Approved by: https://github.com/aazzolini
Will Constable
2022-11-02 03:52:17 +00:00
committed by PyTorch MergeBot
parent 95fc0bcaad
commit a51da28551
3 changed files with 74 additions and 2 deletions


@@ -29,6 +29,7 @@ jobs:
{ config: "inductor", shard: 5, num_shards: 7, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor", shard: 6, num_shards: 7, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor", shard: 7, num_shards: 7, runner: "linux.g5.4xlarge.nvidia.gpu" },
{ config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" },
]}
linux-bionic-cuda11_6-py3_10-gcc7-inductor-test:


@@ -250,6 +250,8 @@ test_dynamo_shard() {
}
test_inductor_distributed() {
# This runs on both single-gpu and multi-gpu instances. It should be smart about skipping tests that aren't supported
# if the required number of GPUs isn't available.
PYTORCH_TEST_WITH_INDUCTOR=0 python test/run_test.py --include distributed/test_dynamo_distributed
assert_git_not_dirty
}
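As context for the comment in test_inductor_distributed: the GPU-count-aware skipping happens inside the Python test file rather than in this script, via the skip_if_lt_x_gpu decorator imported in test_dynamo_distributed.py further down. A minimal illustrative sketch (the class and test names below are hypothetical and not part of this PR):

from torch.testing._internal.common_distributed import (
    MultiProcessTestCase,
    skip_if_lt_x_gpu,
)

class ExampleMultiGpuTest(MultiProcessTestCase):  # hypothetical class, for illustration only
    def setUp(self):
        super().setUp()
        self._spawn_processes()  # one child process per rank, as in the test added by this PR

    @skip_if_lt_x_gpu(2)  # reports a skip instead of a failure when torch.cuda.device_count() < 2
    def test_requires_two_gpus(self):
        ...  # body only runs when at least two CUDA devices are visible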
@@ -728,6 +730,10 @@ elif [[ "$TEST_CONFIG" == distributed ]]; then
elif [[ "$TEST_CONFIG" == deploy ]]; then
checkout_install_torchdeploy
test_torch_deploy
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
install_filelock
install_triton
test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *dynamo* && "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then
test_without_numpy
install_torchvision


@@ -2,17 +2,18 @@
import os
import unittest
from unittest.mock import patch
import torch
import torch._dynamo
import torch._dynamo.test_case
import torch.distributed as dist
from contextlib import contextmanager
from torch import nn
from torch._dynamo import config
from torch._dynamo.utils import same
from torch._inductor.utils import has_triton
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import requires_nccl
from torch.testing._internal.common_distributed import MultiProcessTestCase, skip_if_lt_x_gpu, requires_nccl
import torch._dynamo.logging
def init_weights(m):
if isinstance(m, nn.Linear):
@@ -31,6 +32,13 @@ class ToyModel(nn.Module):
def forward(self, inputs):
return self.net(inputs)
def get_model(device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5):
m = ToyModel(in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat).to(device)
m.apply(init_weights)
inputs = torch.rand(bsz, in_feat).to(device)
outputs = m(inputs)
return m, inputs, outputs
class CheckSplitsCompiler:
def __init__(self):
@@ -40,6 +48,63 @@ class CheckSplitsCompiler:
self.compiler_called += 1
return gm
@contextmanager
def _per_rank_init(rank, world_size):
# To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
# just manually implement the most important part of the dynamo behavior to reset/clear.
torch.cuda.set_device(rank)
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '6789'
dist.init_process_group("nccl", rank=rank, world_size=world_size)
torch._dynamo.reset()
torch._dynamo.utils.counters.clear()
yield
torch._dynamo.reset()
torch._dynamo.utils.counters.clear()
dist.destroy_process_group()
@requires_nccl()
class TestDistributedMultiProc(MultiProcessTestCase):
def setUp(self):
super(TestDistributedMultiProc, self).setUp()
self._spawn_processes()
def tearDown(self):
super(TestDistributedMultiProc, self).tearDown()
try:
os.remove(self.file_name)
except OSError:
pass
@property
def world_size(self) -> int:
return torch.cuda.device_count()
@classmethod
def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None:
# Don't enable DDP + ReplicatedTensor, as that breaks Dynamo+DDP
# TODO(whc) why is ReplicatedTensor defaulted to True in MultiProcessTestCase, and should we support it?
# from torch.nn.parallel._replicated_tensor_ddp_utils import _set_ddp_with_replicated_tensor
# _set_ddp_with_replicated_tensor(True)
# The rest is copypasta from MultiProcessTestCase._run
self = cls(test_name)
self.rank = rank
self.file_name = file_name
self.run_test(test_name, parent_pipe)
@skip_if_lt_x_gpu(2)
@patch.object(config, "optimize_ddp", False)
def test_ddp_baseline_aot_eager_multiprocess(self):
with _per_rank_init(self.rank, self.world_size):
self.assertFalse(config.optimize_ddp)
m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
m = DDP(m, device_ids=[self.rank])
m = torch._dynamo.optimize("aot_eager")(m)
outputs = m(inputs)
self.assertTrue(same(correct_outputs, outputs))
@requires_nccl()
class TestDistributed(torch._dynamo.test_case.TestCase):