Support multi-gpu CI for inductor-distributed (#87996)
This test by itself isn't the end goal, but it is a minimal test that exercises multi-GPU, and the focus of this PR is the infra behind enabling that. I'll follow up with more tests using actual models, etc. cc @malfet @desertfire for awareness/feedback on the infra side.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87996
Approved by: https://github.com/aazzolini
committed by PyTorch MergeBot
parent 95fc0bcaad
commit a51da28551

.github/workflows/inductor.yml (vendored)
@@ -29,6 +29,7 @@ jobs:
          { config: "inductor", shard: 5, num_shards: 7, runner: "linux.g5.4xlarge.nvidia.gpu" },
          { config: "inductor", shard: 6, num_shards: 7, runner: "linux.g5.4xlarge.nvidia.gpu" },
          { config: "inductor", shard: 7, num_shards: 7, runner: "linux.g5.4xlarge.nvidia.gpu" },
          { config: "inductor_distributed", shard: 1, num_shards: 1, runner: "linux.g5.12xlarge.nvidia.gpu" },
        ]}

  linux-bionic-cuda11_6-py3_10-gcc7-inductor-test:
@@ -250,6 +250,8 @@ test_dynamo_shard() {
}

test_inductor_distributed() {
  # This runs on both single-GPU and multi-GPU instances. It should be smart about
  # skipping tests that aren't supported if the required number of GPUs isn't available.
  PYTORCH_TEST_WITH_INDUCTOR=0 python test/run_test.py --include distributed/test_dynamo_distributed
  assert_git_not_dirty
}
@@ -728,6 +730,10 @@ elif [[ "$TEST_CONFIG" == distributed ]]; then
elif [[ "$TEST_CONFIG" == deploy ]]; then
  checkout_install_torchdeploy
  test_torch_deploy
elif [[ "${TEST_CONFIG}" == *inductor_distributed* ]]; then
  install_filelock
  install_triton
  test_inductor_distributed
elif [[ "${TEST_CONFIG}" == *dynamo* && "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then
  test_without_numpy
  install_torchvision
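Note on the "smart about skipping" comment above: the gating lives in the test files rather than in this script, so the same invocation works on both the single-GPU and the multi-GPU runner. A minimal sketch of that pattern, not part of this diff (the real tests below use skip_if_lt_x_gpu from torch.testing._internal.common_distributed; requires_gpus here is a hypothetical stand-in built on plain unittest):

# Minimal sketch (not part of this PR): skip a test when the runner has fewer
# GPUs than it needs, so one test file can run on single-GPU and multi-GPU instances.
import unittest

import torch


def requires_gpus(n):
    # Hypothetical helper; returns a unittest skip decorator based on visible GPU count.
    return unittest.skipIf(
        torch.cuda.device_count() < n,
        f"requires at least {n} GPUs, found {torch.cuda.device_count()}",
    )


class GpuCountGateExample(unittest.TestCase):
    @requires_gpus(2)
    def test_needs_two_gpus(self):
        self.assertGreaterEqual(torch.cuda.device_count(), 2)


if __name__ == "__main__":
    unittest.main()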
@@ -2,17 +2,18 @@
import os
import unittest
from unittest.mock import patch

import torch
import torch._dynamo
import torch._dynamo.test_case
import torch.distributed as dist
from contextlib import contextmanager
from torch import nn
from torch._dynamo import config
from torch._dynamo.utils import same
from torch._inductor.utils import has_triton
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing._internal.common_distributed import requires_nccl
from torch.testing._internal.common_distributed import MultiProcessTestCase, skip_if_lt_x_gpu, requires_nccl
import torch._dynamo.logging

def init_weights(m):
    if isinstance(m, nn.Linear):
@@ -31,6 +32,13 @@ class ToyModel(nn.Module):
    def forward(self, inputs):
        return self.net(inputs)

def get_model(device, bsz=20, in_feat=10, hidden_feat=5000, out_feat=5):
    m = ToyModel(in_feat=in_feat, hidden_feat=hidden_feat, out_feat=out_feat).to(device)
    m.apply(init_weights)
    inputs = torch.rand(bsz, in_feat).to(device)
    outputs = m(inputs)
    return m, inputs, outputs


class CheckSplitsCompiler:
    def __init__(self):
@@ -40,6 +48,63 @@ class CheckSplitsCompiler:
        self.compiler_called += 1
        return gm

@contextmanager
def _per_rank_init(rank, world_size):
    # To avoid multiple inheritance from _dynamo.test_case.TestCase and MultiProcessTestCase,
    # just manually implement the most important part of the dynamo behavior to reset/clear.
    torch.cuda.set_device(rank)
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '6789'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch._dynamo.reset()
    torch._dynamo.utils.counters.clear()
    yield
    torch._dynamo.reset()
    torch._dynamo.utils.counters.clear()
    dist.destroy_process_group()


@requires_nccl()
class TestDistributedMultiProc(MultiProcessTestCase):
    def setUp(self):
        super(TestDistributedMultiProc, self).setUp()
        self._spawn_processes()

    def tearDown(self):
        super(TestDistributedMultiProc, self).tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self) -> int:
        return torch.cuda.device_count()

    @classmethod
    def _run(cls, rank: int, test_name: str, file_name: str, parent_pipe) -> None:
        # Don't enable DDP + ReplicatedTensor, as that breaks Dynamo+DDP
        # TODO(whc) why is ReplicatedTensor defaulted=True in MultiProcessTestCase, and should we support it?
        # from torch.nn.parallel._replicated_tensor_ddp_utils import _set_ddp_with_replicated_tensor
        # _set_ddp_with_replicated_tensor(True)

        # The rest is copypasta from MultiProcessTestCase._run
        self = cls(test_name)
        self.rank = rank
        self.file_name = file_name
        self.run_test(test_name, parent_pipe)

    @skip_if_lt_x_gpu(2)
    @patch.object(config, "optimize_ddp", False)
    def test_ddp_baseline_aot_eager_multiprocess(self):
        with _per_rank_init(self.rank, self.world_size):
            self.assertFalse(config.optimize_ddp)
            m, inputs, correct_outputs = get_model(f"cuda:{self.rank}")
            m = DDP(m, device_ids=[self.rank])
            m = torch._dynamo.optimize("aot_eager")(m)
            outputs = m(inputs)
            self.assertTrue(same(correct_outputs, outputs))


@requires_nccl()
class TestDistributed(torch._dynamo.test_case.TestCase):
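For readers unfamiliar with MultiProcessTestCase: the combination of self._spawn_processes() in setUp and the _per_rank_init context manager above amounts to launching one process per GPU, joining each process to a shared NCCL process group, and resetting dynamo state around the test body. A rough standalone sketch of that flow, outside the test harness (illustrative only; the worker name, the port, and the toy Linear model are arbitrary choices, not part of this PR):

# Standalone sketch (not part of this PR) of what the multi-process test scaffolding does:
# spawn one worker per GPU, set up NCCL, run a DDP model through dynamo, then tear down.
import os

import torch
import torch._dynamo
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP


def _demo_worker(rank, world_size):
    # One process per GPU: pin the device and join a shared NCCL process group.
    torch.cuda.set_device(rank)
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "6789"  # arbitrary free port
    dist.init_process_group("nccl", rank=rank, world_size=world_size)

    # Start each rank from a clean dynamo state, as _per_rank_init does.
    torch._dynamo.reset()

    model = torch.nn.Linear(10, 5).to(f"cuda:{rank}")
    ddp_model = DDP(model, device_ids=[rank])
    opt_model = torch._dynamo.optimize("aot_eager")(ddp_model)
    opt_model(torch.rand(20, 10, device=f"cuda:{rank}"))

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    mp.spawn(_demo_worker, args=(world_size,), nprocs=world_size)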